diff --git a/lambda/app.py b/lambda/app.py index ec317326..a259fd22 100644 --- a/lambda/app.py +++ b/lambda/app.py @@ -24,6 +24,8 @@ from opentelemetry import trace from opentelemetry.propagate import inject except ImportError: + # If opentelemetry is not installed, we make dummy objects / functions here + # so we do not need to add conditionals throughout our codebase trace = None def inject(obj): @@ -96,7 +98,13 @@ def wrapper(*args, **kwargs): @ttl_cache(maxsize=2, ttl=10 * 60, timer=time.time) -def get_black_list(): +def get_black_list() -> dict: + """ + Return blacklist if configured + + Looks at the BLACKLIST_ENDPOINT environment variable for a URL to + contact for fetching the blacklist from. + """ endpoint = os.getenv('BLACKLIST_ENDPOINT', '') if endpoint: response = urllib.request.urlopen(endpoint).read().decode('utf-8') @@ -108,7 +116,12 @@ def get_black_list(): @app.middleware('http') -def initialize(event, get_response): +def initialize(event: chalice.app.Request, get_response) -> Response: + """ + Initialize various properties needed for each request. + + This function is called on *every* request before any of the handlers are called + """ JWT_MANAGER.black_list = get_black_list() jwt_keys = retrieve_secret(os.getenv('JWT_KEY_SECRET_NAME')) JWT_MANAGER.public_key = base64.b64decode(jwt_keys.get('rsa_pub_key', '')).decode() @@ -118,7 +131,12 @@ def initialize(event, get_response): @app.middleware('http') -def set_log_context(event: chalice.app.Request, get_response): +def set_log_context(event: chalice.app.Request, get_response) -> Response: + """ + Set context about current request for all log statements + + This function is called for each request before the request handlers are called. 
+ """ origin_request_id = event.headers.get('x-origin-request-id') log_context( @@ -135,11 +153,19 @@ def set_log_context(event: chalice.app.Request, get_response): try: return get_response(event) finally: + # Reset log context after our request is completed log_context(user_id=None, route=None, request_id=None) @app.middleware('http') -def forward_origin_request_id(event: chalice.app.Request, get_response): +def forward_origin_request_id(event: chalice.app.Request, get_response) -> Response: + """ + Normalize x-request-id header to point to aws_request_id for all requests + + The original request_id, if present, is put in an x-origin-request-id + + This function is called for each request before the request handlers are called. + """ response = get_response(event) origin_request_id = event.headers.get('x-origin-request-id') @@ -157,17 +183,29 @@ class TeaException(Exception): class EulaException(TeaException): + """ + Exception indicating that the authorized user has not accepted the EULA for accessing the dataset + """ def __init__(self, payload: dict): self.payload = payload class RequestAuthorizer: + """ + Handle authorization of incoming requests. + + Supports handling traditional OAuth2 with appropriate redirects, as well as using + bearer tokens. + """ def __init__(self): self._response = None self._headers = {} @with_trace() def get_profile(self) -> Optional[UserProfile]: + """ + Return user profile if the user is authenticated + """ user_profile = JWT_MANAGER.get_profile_from_headers(app.current_request.headers) if user_profile is not None: return user_profile @@ -201,7 +239,7 @@ def get_profile(self) -> Optional[UserProfile]: return None @with_trace() - def _handle_auth_bearer_header(self, token) -> Optional[UserProfile]: + def _handle_auth_bearer_header(self, token: str) -> Optional[UserProfile]: """ Will handle the output from get_user_from_token in context of a chalice function. If user_id is determined, returns it. 
If user_id is not determined returns data to be returned @@ -255,6 +293,9 @@ def get_success_response_headers(self) -> dict: @with_trace() def get_request_id() -> str: + """ + Return AWS Lambda Request ID + """ assert app.lambda_context is not None return app.lambda_context.aws_request_id @@ -262,13 +303,22 @@ def get_request_id() -> str: @with_trace() def get_origin_request_id() -> Optional[str]: + """ + Return the *original* AWS Lambda Request ID + """ assert app.current_request is not None return app.current_request.headers.get("x-origin-request-id") @with_trace() -def get_aux_request_headers(): +def get_aux_request_headers() -> dict: + """ + Return common HTTP headers used when making requests to EarthData login servers. + + These are headers we send when making requests as a *client*, not when making + responses as a server. + """ req_headers = {"x-request-id": get_request_id()} origin_request_id = get_origin_request_id() @@ -282,12 +332,15 @@ def get_aux_request_headers(): @with_trace() -def check_for_browser(hdrs): +def check_for_browser(hdrs: dict) -> bool: + """ + Return True if request is being sent by a browser + """ return 'user-agent' in hdrs and hdrs['user-agent'].lower().startswith('mozilla') @with_trace() -def get_user_from_token(token): +def get_user_from_token(token: str) -> Optional[str]: """ This may be moved to rain-api-core.urs_util.py once things stabilize. 
Will query URS for user ID of requesting user based on token sent with request @@ -367,12 +420,18 @@ def get_user_from_token(token): @with_trace() def cumulus_log_message(outcome: str, code: int, http_method: str, k_v: dict): + """ + Emit log message to stdout in a format that cumulus understands + """ k_v.update({'code': code, 'http_method': http_method, 'status': outcome, 'requestid': get_request_id()}) print(json.dumps(k_v)) @with_trace() def restore_bucket_vars(): + """ + Update bucket config by re-fetching it from configured S3 object + """ global b_map # pylint: disable=global-statement log.debug('conf bucket: %s, bucket_map_file: %s', conf_bucket, bucket_map_file) @@ -405,7 +464,7 @@ def restore_bucket_vars(): @with_trace() -def do_auth_and_return(ctxt): +def do_auth_and_return(ctxt) -> Response: log.debug('context: {}'.format(ctxt)) here = ctxt['path'] if os.getenv('DOMAIN_NAME'): @@ -423,7 +482,10 @@ def do_auth_and_return(ctxt): @with_trace() -def add_cors_headers(headers): +def add_cors_headers(headers: dict): + """ + Add CORS headers to allow requests from all configured domains + """ assert app.current_request is not None # send CORS headers if we're configured to use them @@ -438,7 +500,10 @@ def add_cors_headers(headers): @with_trace() -def make_redirect(to_url, headers=None, status_code=301): +def make_redirect(to_url: str, headers: Optional[dict] = None, status_code: str = 301) -> Response: + """ + Return a HTTP Response redirecting users with appropriate headers + """ if headers is None: headers = {} headers['Location'] = to_url @@ -450,7 +515,15 @@ def make_redirect(to_url, headers=None, status_code=301): @with_trace() -def make_html_response(t_vars: dict, headers: dict, status_code: int = 200, template_file: str = 'root.html'): +def make_html_response( + t_vars: dict, + headers: dict, + status_code: int = 200, + template_file: str = 'root.html' +) -> Response: + """ + Return a HTTP response with rendered HTML from the given template + """ 
template_vars = { 'STAGE': STAGE if not os.getenv('DOMAIN_NAME') else None, 'status_code': status_code, @@ -488,7 +561,10 @@ def get_bcconfig(user_id: str) -> dict: # Cache by bucketname only key=lambda _, bucketname: hashkey(bucketname) ) -def get_bucket_region(session, bucketname) -> str: +def get_bucket_region(session, bucketname: str) -> str: + """ + Get the region of the given bucket + """ try: _time = time.time() bucket_region = session.client('s3').get_bucket_location(Bucket=bucketname)['LocationConstraint'] or 'us-east-1' @@ -507,7 +583,10 @@ def get_bucket_region(session, bucketname) -> str: @with_trace() -def get_user_ip(): +def get_user_ip() -> str: + """ + Return IP of the user making the request + """ assert app.current_request is not None x_forwarded_for = app.current_request.headers.get('x-forwarded-for') @@ -522,7 +601,13 @@ def get_user_ip(): @with_trace() -def try_download_from_bucket(bucket, filename, user_profile, headers: dict): +def try_download_from_bucket(bucket: str, filename: str, user_profile: UserProfile, headers: dict) -> Response: + """ + Attempt to redirect to given file from given bucket. + + Returns a redirect response with presigned S3 URL if successful, + or an appropriate error response if unsuccessful. 
+ """ timer = Timer() timer.mark() user_id = None @@ -627,13 +712,16 @@ def try_download_from_bucket(bucket, filename, user_profile, headers: dict): @with_trace() -def get_jwt_field(cookievar: dict, fieldname: str): +def get_jwt_field(cookievar: dict, fieldname: str) -> Optional[str]: return cookievar.get(JWT_COOKIE_NAME, {}).get(fieldname, None) @app.route('/') @with_trace(context={}) -def root(): +def root() -> Response: + """ + Render human readable root page + """ template_vars = {'title': 'Welcome'} user_profile = JWT_MANAGER.get_profile_from_headers(app.current_request.headers) if user_profile is not None: @@ -648,7 +736,7 @@ def root(): @app.route('/logout') @with_trace(context={}) -def logout(): +def logout() -> Response: user_profile = JWT_MANAGER.get_profile_from_headers(app.current_request.headers) template_vars = {'title': 'Logged Out', 'URS_URL': get_urs_url(app.current_request.context)} @@ -667,7 +755,7 @@ def logout(): @app.route('/login') @with_trace(context={}) -def login(): +def login() -> Response: try: headers = {} aux_headers = get_aux_request_headers() @@ -694,7 +782,7 @@ def login(): @app.route('/version') @with_trace(context={}) -def version(): +def version() -> str: log.info("Got a version request!") version_return = {'version_id': ''} @@ -707,7 +795,7 @@ def version(): @app.route('/locate') @with_trace(context={}) -def locate(): +def locate() -> Response: query_params = app.current_request.query_params if query_params is None or query_params.get('bucket_name') is None: return Response(body='Required "bucket_name" query paramater not specified', @@ -727,7 +815,7 @@ def locate(): @with_trace() -def collapse_bucket_configuration(bucket_map): +def collapse_bucket_configuration(bucket_map) -> dict: for k, v in bucket_map.items(): if isinstance(v, dict): if 'bucket' in v: @@ -738,7 +826,10 @@ def collapse_bucket_configuration(bucket_map): @with_trace() -def get_range_header_val(): +def get_range_header_val() -> Optional[str]: + """ + Return 
value of range header if present + """ if 'Range' in app.current_request.headers: return app.current_request.headers['Range'] if 'range' in app.current_request.headers: @@ -775,7 +866,13 @@ def get_data_dl_s3_client(): @with_trace() -def try_download_head(bucket, filename): +def try_download_head(bucket: str, filename: str) -> Response: + """ + Try to handle a HEAD request for given filename in given bucket + + Return a redirect response if given filename exists in the bucket, provide + an error message otherwise. + """ timer = Timer() timer.mark("get_data_dl_s3_client()") @@ -839,6 +936,13 @@ def try_download_head(bucket, filename): @app.route('/{proxy+}', methods=['HEAD']) @with_trace(context={}) def dynamic_url_head(): + """ + Handle HEAD requests for arbitrary files in arbitrary buckets + + The name of the bucket and filename is parsed out of the URL. If the + file is found in the bucket and the request is authenticated properly, + a signed s3 URL is returned. If not, an error response is returned. + """ timer = Timer() timer.mark("restore_bucket_vars()") log.debug('attempting to HEAD a thing') @@ -872,6 +976,13 @@ def dynamic_url_head(): @app.route('/{proxy+}', methods=['GET']) @with_trace(context={}) def dynamic_url(): + """ + Handle GET requests for arbitrary files in arbitrary buckets + + The name of the bucket and filename is parsed out of the URL. If the + file is found in the bucket and the request is authenticated properly, + a signed s3 URL is returned. If not, an error response is returned. 
+ """ timer = Timer() timer.mark("restore_bucket_vars()") @@ -963,6 +1074,9 @@ def dynamic_url(): @app.route('/s3credentials', methods=['GET']) @with_trace(context={}) def s3credentials(): + """ + Return temporary AWS credentials to calling user, with ability to access S3 + """ timer = Timer() timer.mark("restore_bucket_vars()") @@ -1026,7 +1140,7 @@ def s3credentials(): @with_trace() -def get_role_session_name(user_id: str, app_name: str): +def get_role_session_name(user_id: str, app_name: str) -> str: # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRole.html#API_AssumeRole_RequestParameters if not re.match(r"[\w+,.=@-]*", app_name): app_name = "" @@ -1035,7 +1149,7 @@ def get_role_session_name(user_id: str, app_name: str): @with_trace() -def get_s3_credentials(user_id: str, role_session_name: str, policy: dict): +def get_s3_credentials(user_id: str, role_session_name: str, policy: dict) -> dict: client = boto3.client("sts") arn = os.getenv("EGRESS_APP_DOWNLOAD_ROLE_INREGION_ARN") response = client.assume_role( @@ -1053,20 +1167,23 @@ def get_s3_credentials(user_id: str, role_session_name: str, policy: dict): @app.route('/s3credentialsREADME', methods=['GET']) @with_trace(context={}) -def s3credentials_readme(): +def s3credentials_readme() -> Response: + """ + Return a human readable README for how to use /s3credentials + """ return make_html_response({}, {}, 200, "s3credentials_readme.html") @app.route('/profile') @with_trace(context={}) -def profile(): +def profile() -> Response: return Response(body='Profile not available.', status_code=200, headers={}) @app.route('/pubkey', methods=['GET']) @with_trace(context={}) -def pubkey(): +def pubkey() -> Response: thebody = json.dumps({ 'rsa_pub_key': JWT_MANAGER.public_key, 'algorithm': JWT_MANAGER.algorithm