diff --git a/requirements.txt b/requirements.txt index 788eb5d..9489e24 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,7 @@ boto3==1.9.200 botocore==1.12.253 cachetools==3.0.0 certifi==2018.11.29 -cffi==1.11.5 +cffi==1.14.5 chardet==3.0.4 coverage==4.5.2 dnspython==1.16.0 @@ -46,7 +46,7 @@ rsa==4.0 s3transfer==0.2.1 six==1.12.0 sparkpost==1.3.6 -typed-ast==1.2.0 +typed-ast==1.4.2 uritemplate==3.0.0 urllib3==1.24.2 wrapt==1.11.0 diff --git a/src/maglink.py b/src/maglink.py index d1ec328..ae72539 100644 --- a/src/maglink.py +++ b/src/maglink.py @@ -44,7 +44,7 @@ def forgot_user(event, magiclinks, user_coll): return util.add_cors_headers({"statusCode": 200, "body": "Forgot password link has been emailed to you"}) -def director_link(magiclinks, num_links, event, user): +def director_link(magiclinks, event, user): """ Function used to generate magic links for one or more users to be promoted """ @@ -54,7 +54,7 @@ def director_link(magiclinks, num_links, event, user): for i in event['permissions']: permissions.append(i) # for each of the emails requested to be promoted... - for j in range(min(num_links, len(event['emailsTo']))): + for j in range(len(event['emailsTo'])): # a unique magic link is generated as 32 random alphanumeric characters magiclink = ''.join([random.choice(string.ascii_letters + string.digits) for _ in range(32)]) # an object is created to be stored in the database @@ -85,7 +85,6 @@ def director_link(magiclinks, num_links, event, user): "link_base": {"type": "string"}, "permissions": {"type": "array"}, "emailsTo": {"type": "array"}, - "numLinks": {"type": "integer"} }, "required": ["token", "permissions", "emailsTo"] }) @@ -95,8 +94,7 @@ def do_director_link(event, magiclinks, user=None): """ Function used by directors to promote users through magiclinks """ - num_links = event.get('numLinks', 1) - links_list = director_link(magiclinks, num_links, event, user) + links_list = director_link(magiclinks, event, user) return util.add_cors_headers({"statusCode": 200, "body": links_list}) diff --git a/src/read.py b/src/read.py index 76574d9..1a4c78e 100644 --- a/src/read.py +++ b/src/read.py @@ -1,15 +1,13 @@ from src.schemas import * - - def tidy_results(res): """ Function used to clean up read results before returning to the user """ for i in res: - del i['_id'] - del i['password'] + del i["_id"] + del i["password"] return res @@ -36,17 +34,17 @@ def tidy_results(res): }, "required": ["query"] }) -def public_read(event, context): +def aggregate_read(event, context): """ - Function responsible for performing a public read (can be requested by anyone) + Function responsible for performing an aggregate read """ # the fields to be aggregated - fields = event['fields'] + fields = event["query"]["fields"] # filter based on the just_here boolean indicating whether or not to aggregate on checked-in users - match = {"$match": {"registration_status": ("checked-in" if event.get('just_here', False) else {"$ne": "unregistered"})}} + match = {"$match": {"registration_status": ("checked-in" if event.get("just_here", False) else {"$ne": "unregistered"})}} # group by is performed using each of the fields requested group = {"$group": {"_id": {field: "$" + field for field in fields}, "total": {"$sum": 1}}} - user_coll = util.coll('users') + user_coll = util.coll("users") # aggregate's pipelining is used to fetch the results from the user data return {"statusCode": 200, "body": list(user_coll.aggregate([match, group]))} @@ -55,42 +53,55 @@ def user_read(event, context, user): """ Function used by a LCS user to fetch their information """ - # if the desired action is to aggregate, than it is no different from a public read - if event.get('aggregate', False): - return public_read(event, context) + # if the desired action is to aggregate, than perform an aggregate read + if event.get("aggregate", False): + return aggregate_read(event, context) - # otherwise, any reimbursement information is removed before sending that user's data back - if user['registration_status'] in ['unregistered', 'registered', 'rejected']: - if 'travelling_from' in user and 'reimbursement' in user['travelling_from']: - del user['travelling_from']['reimbursement'] + # otherwise, any reimbursement information is removed before sending that user"s data back + if user["registration_status"] in ["unregistered", "registered", "rejected"]: + if "travelling_from" in user and "reimbursement" in user["travelling_from"]: + del user["travelling_from"]["reimbursement"] return {"statusCode": 200, "body": [user]} -@ensure_role([['director', 'organizer']], on_failure=lambda e, c, u, *a: user_read(e, c, u)) +@ensure_role([["director", "organizer"]], on_failure=lambda e, c, u, *a: user_read(e, c, u)) +@ensure_schema({ + "type": "object", + "properties": { + "token": {"type": "string"}, + "query": {"type": "object"}, + "aggregate": {"type": "boolean"} + }, + "required": ["query"] +}) def organizer_read(event, context, user): """ Function responsible for performing an organizer query. In-case of insufficient permissions, falls back on user_read """ # if aggregation is desired, a public read will suffice - if event.get('aggregate', False): - return public_read(event, context) + if event.get("aggregate", False): + return aggregate_read(event, context) # otherwise, the organizer submitted query is ran on the database and results are returned - user_coll = util.coll('users') - return {"statusCode": 200, "body": tidy_results(list(user_coll.find(event['query'])))} + user_coll = util.coll("users") + return {"statusCode": 200, "body": tidy_results(list(user_coll.find(event["query"])))} @ensure_schema({ "type": "object", "properties": { "token": {"type": "string"}, - "query": {"type": "object"}, + "query": { + "oneOf": [ + {"type": "object"}, {"type": "array"} + ] + }, "aggregate": {"type": "boolean"} }, - "required": ["query"] + "required": ["token"] }) -@ensure_logged_in_user(on_failure=lambda e, c, u, *a: public_read(e, c)) -@ensure_role([['director']], on_failure=lambda e, c, u, *a: organizer_read(e, c, u)) +@ensure_logged_in_user() +@ensure_role([["director"]], on_failure=lambda e, c, u, *a: organizer_read(e, c, u)) def read_info(event, context, user=None): """ We allow for an arbitrary mongo query to be passed in. @@ -99,8 +110,19 @@ def read_info(event, context, user=None): If the endpoint is called by a non-LCS user, falls back upon public_read If the endpoint is called by a non-director, falls back upon organizer_read """ - tests = util.coll('users') + tests = util.coll("users") + + if "query" not in event: + return {"statusCode": 400, "body": "Missing parameter 'query' from request. It is required for director read"} - if event.get('aggregate', False): - return {"statusCode": 200, "body": list(tests.aggregate(event['query']))} - return {"statusCode": 200, "body": tidy_results(list(tests.find(event['query'])))} + if event.get("aggregate", False): + if not isinstance(event["query"], list): + return {"statusCode": 400, "body": "Invalid parameter 'query'. " + "Expected the query to be of type array (the aggregation pipeline)"} + else: + return {"statusCode": 200, "body": list(tests.aggregate(event["query"]))} + else: + if not isinstance(event["query"], dict): + return {"statusCode": 400, "body": "Invalid parameter 'query'. Expected the query to be of type object"} + else: + return {"statusCode": 200, "body": tidy_results(list(tests.find(event["query"])))} diff --git a/src/schemas.py b/src/schemas.py index 1cf9d5e..0b88997 100644 --- a/src/schemas.py +++ b/src/schemas.py @@ -18,7 +18,7 @@ def wrapt(event, context, *extras): js.validate(event, schema) return util.add_cors_headers(fn(event, context, *extras)) except js.exceptions.ValidationError as e: - return util.add_cors_headers(on_failure(event, context, e)) + return util.add_cors_headers(on_failure(event, context, e.message)) return wrapt return wrap @@ -31,6 +31,9 @@ def rapper(fn): @wraps(fn) def wrapt(event, context, *args): + if token_key not in event: + return on_failure(event, context, "Missing authentication token", *args) + token = event[token_key] try: decoded_payload = jwt.decode(token, config.JWT_SECRET, algorithms=[config.JWT_ALGO]) diff --git a/tests/test_cal_announce.py b/tests/test_cal_announce.py index 75ade73..2fcdb29 100644 --- a/tests/test_cal_announce.py +++ b/tests/test_cal_announce.py @@ -3,7 +3,6 @@ from config import GOOGLE_CAL - def test_get_cal(): res = cal_announce.google_cal({}, {}) print(res) diff --git a/tests/test_read.py b/tests/test_read.py new file mode 100644 index 0000000..6721ab8 --- /dev/null +++ b/tests/test_read.py @@ -0,0 +1,194 @@ +import importlib +from functools import wraps + + +from src import read +from unittest.mock import patch + + +test_user = { + "email": "test@hackru.org", + "role": { + "hacker": True, + "volunteer": False, + "judge": False, + "sponsor": False, + "mentor": False, + "organizer": False, + "director": False + }, + "registration_status": "unregistered", + "travelling_from": { + "reimbursement": 10 + } +} + +DUMMY_TOKEN = "dummy token" + +test_aggregate_query = { + "token": DUMMY_TOKEN, + "query": { + "fields": ["major"] + }, + "aggregate": True +} + +test_user_query = { + "token": DUMMY_TOKEN +} + +test_elevated_query = { + "token": DUMMY_TOKEN, + "query": { + "email": "test@hackru.org" + } +} + +test_director_aggregate_query = { + "token": DUMMY_TOKEN, + "query": [ + {"$match": {"registration_status": "registered"}}, + {"$group": {"_id": None, "registered_people": {"$sum": "$amount"}}} + ], + "aggregate": True +} + + +def patched_ensure_logged_in_user(role=None): + def ensure_logged_in_user(): + def wrapper(fn, *args1, **kwargs1): + @wraps(fn) + def wrapt(event, context, *args, **kwargs): + tu = test_user.copy() + roles = test_user["role"].copy() + if role: + assert role in roles + roles[role] = True + tu["role"] = roles + return fn(event, context, tu, *args) + return wrapt + return wrapper + return ensure_logged_in_user + + +def test_unauthenticated_read(): + res = read.read_info({}, {}) + assert res["statusCode"] == 400 + + +@patch("src.schemas.ensure_logged_in_user", patched_ensure_logged_in_user()) +def test_missing_query_aggregate_read(): + importlib.reload(read) + tq = test_aggregate_query.copy() + del tq["query"] + res = read.read_info(tq, {}) + assert res["statusCode"] == 400 + assert str(res["body"]).startswith("Error in JSON: 'query' is a required property") + + +@patch("src.schemas.ensure_logged_in_user", patched_ensure_logged_in_user()) +def test_invalid_aggregate_read(): + importlib.reload(read) + tq = test_aggregate_query.copy() + tq["query"] = {"email": "test@hackru.org"} + res = read.read_info(tq, {}) + assert res["statusCode"] == 400 + assert str(res["body"]).startswith("Error in JSON: 'fields' is a required property") + tq["query"] = {"fields": ["bad field"]} + res = read.read_info(tq, {}) + assert res["statusCode"] == 400 + assert str(res["body"]).startswith("Error in JSON: 'bad field' is not one of") + + +@patch("src.schemas.ensure_logged_in_user", patched_ensure_logged_in_user()) +def test_successful_aggregate_read(): + importlib.reload(read) + res = read.read_info(test_aggregate_query, {}) + assert res["statusCode"] == 200 + + +@patch("src.schemas.ensure_logged_in_user", patched_ensure_logged_in_user()) +def test_successful_user_read(): + importlib.reload(read) + res = read.read_info(test_user_query, {}) + assert res["statusCode"] == 200 + # user object is returned as first object in an array + assert res["body"] == [test_user] + + +@patch("src.schemas.ensure_logged_in_user", patched_ensure_logged_in_user("organizer")) +def test_missing_query_aggregate_organizer_read(): + importlib.reload(read) + tq = test_elevated_query.copy() + del tq["query"] + res = read.read_info(tq, {}) + assert res["statusCode"] == 400 + assert str(res["body"]).startswith("Error in JSON: 'query' is a required property") + + +@patch("src.schemas.ensure_logged_in_user", patched_ensure_logged_in_user("organizer")) +def test_invalid_aggregate_organizer_read(): + importlib.reload(read) + tq = test_elevated_query.copy() + tq["aggregate"] = True + res = read.read_info(tq, {}) + assert res["statusCode"] == 400 + assert str(res["body"]).startswith("Error in JSON: 'fields' is a required property") + + +@patch("src.schemas.ensure_logged_in_user", patched_ensure_logged_in_user("organizer")) +def test_successful_aggregate_organizer_read(): + importlib.reload(read) + res = read.read_info(test_aggregate_query, {}) + assert res["statusCode"] == 200 + + +@patch("src.schemas.ensure_logged_in_user", patched_ensure_logged_in_user("organizer")) +def test_successful_query_organizer_read(): + importlib.reload(read) + res = read.read_info(test_elevated_query, {}) + assert res["statusCode"] == 200 + + +@patch("src.schemas.ensure_logged_in_user", patched_ensure_logged_in_user("director")) +def test_missing_query_director_read(): + importlib.reload(read) + tq = test_elevated_query.copy() + del tq["query"] + res = read.read_info(tq, {}) + assert res["statusCode"] == 400 + assert str(res["body"]).startswith("Missing parameter 'query'") + + +@patch("src.schemas.ensure_logged_in_user", patched_ensure_logged_in_user("director")) +def test_invalid_param_director_read(): + importlib.reload(read) + tq = test_elevated_query.copy() + tq["query"] = ["bad query"] + res = read.read_info(tq, {}) + assert res["statusCode"] == 400 + assert str(res["body"]).startswith("Invalid parameter 'query'") + + +@patch("src.schemas.ensure_logged_in_user", patched_ensure_logged_in_user("director")) +def test_successful_director_read(): + importlib.reload(read) + res = read.read_info(test_elevated_query, {}) + assert res["statusCode"] == 200 + + +@patch("src.schemas.ensure_logged_in_user", patched_ensure_logged_in_user("director")) +def test_invalid_param_director_aggregate_read(): + importlib.reload(read) + tq = test_director_aggregate_query.copy() + tq["query"] = {"email": "test@hackru.org"} + res = read.read_info(tq, {}) + assert res["statusCode"] == 400 + assert str(res["body"]).startswith("Invalid parameter 'query'") + + +@patch("src.schemas.ensure_logged_in_user", patched_ensure_logged_in_user("director")) +def test_successful_director_aggregate_read(): + importlib.reload(read) + res = read.read_info(test_director_aggregate_query, {}) + assert res["statusCode"] == 200