diff --git a/api/local.env b/api/local.env index 45f1065a6..96ef5b0a4 100644 --- a/api/local.env +++ b/api/local.env @@ -45,9 +45,11 @@ LOGIN_GOV_CLIENT_ID=local_mock_client_id LOGIN_GOV_JWK_ENDPOINT=http://host.docker.internal:5001/issuer1/jwks LOGIN_GOV_AUTH_ENDPOINT=http://localhost:5001/issuer1/authorize +LOGIN_GOV_TOKEN_ENDPOINT=http://host.docker.internal:5001/issuer1/token LOGIN_GOV_ENDPOINT=http://localhost:5001 + LOGIN_FINAL_DESTINATION=http://localhost:8080/v1/users/login/result # These should be set in your override.env file diff --git a/api/poetry.lock b/api/poetry.lock index 98116ff22..9412d487e 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. [[package]] name = "alembic" @@ -2293,4 +2293,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "~3.13" -content-hash = "f326e54890c37c7f0564972089f18963b6c89e647a12959f75cb442b057af645" +content-hash = "e967e5a513ccfa475e80592e446c7b07dc8eaa7bdf76bf3c23378f81c1638fdf" diff --git a/api/src/adapters/oauth/__init__.py b/api/src/adapters/oauth/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/src/adapters/oauth/login_gov/__init__.py b/api/src/adapters/oauth/login_gov/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/src/adapters/oauth/login_gov/login_gov_oauth_client.py b/api/src/adapters/oauth/login_gov/login_gov_oauth_client.py new file mode 100644 index 000000000..5fc782a51 --- /dev/null +++ b/api/src/adapters/oauth/login_gov/login_gov_oauth_client.py @@ -0,0 +1,50 @@ +from typing import Any + +import requests + +from src.adapters.oauth.oauth_client import BaseOauthClient +from src.adapters.oauth.oauth_client_models import OauthTokenRequest, OauthTokenResponse +from src.auth.login_gov_jwt_auth import LoginGovConfig, get_config + + +class LoginGovOauthClient(BaseOauthClient): + + def __init__(self, config: LoginGovConfig | None = None): + if config is None: + config = get_config() + + self.config = config + self.session = self._build_session() + + def _build_session(self, session: requests.Session | None = None) -> requests.Session: + """Set things on the session that should be shared between all requests""" + if not session: + session = requests.Session() + + session.headers.update({"Content-Type": "application/x-www-form-urlencoded"}) + + return session + + def _request(self, method: str, full_url: str, **kwargs: Any) -> requests.Response: + """Utility method for making a request with our session""" + + # By default timeout after 5 seconds + if "timeout" not in kwargs: + kwargs["timeout"] = 5 + + return self.session.request(method, full_url, **kwargs) + + def get_token(self, request: OauthTokenRequest) -> OauthTokenResponse: + """Query the login.gov token endpoint""" + + body = { + "code": request.code, + "grant_type": request.grant_type, + # TODO https://github.com/HHS/simpler-grants-gov/issues/3103 + # when we support client assertion, we need to not add the client_id + "client_id": self.config.client_id, + } + + response = self._request("POST", self.config.login_gov_token_endpoint, data=body) + + return OauthTokenResponse.model_validate_json(response.text) diff --git a/api/src/adapters/oauth/login_gov/mock_login_gov_oauth_client.py b/api/src/adapters/oauth/login_gov/mock_login_gov_oauth_client.py new file mode 100644 index 000000000..e437b8a5a --- /dev/null +++ b/api/src/adapters/oauth/login_gov/mock_login_gov_oauth_client.py @@ -0,0 +1,21 @@ +from src.adapters.oauth.oauth_client import BaseOauthClient +from src.adapters.oauth.oauth_client_models import OauthTokenRequest, OauthTokenResponse + + +class MockLoginGovOauthClient(BaseOauthClient): + + def __init__(self) -> None: + self.responses: dict[str, OauthTokenResponse] = {} + + def add_token_response(self, code: str, response: OauthTokenResponse) -> None: + self.responses[code] = response + + def get_token(self, request: OauthTokenRequest) -> OauthTokenResponse: + response = self.responses.get(request.code, None) + + if response is None: + response = OauthTokenResponse( + error="error", error_description="default mock error description" + ) + + return response diff --git a/api/src/adapters/oauth/oauth_client.py b/api/src/adapters/oauth/oauth_client.py new file mode 100644 index 000000000..6fe06309f --- /dev/null +++ b/api/src/adapters/oauth/oauth_client.py @@ -0,0 +1,14 @@ +import abc + +from src.adapters.oauth.oauth_client_models import OauthTokenRequest, OauthTokenResponse + + +class BaseOauthClient(abc.ABC, metaclass=abc.ABCMeta): + + @abc.abstractmethod + def get_token(self, request: OauthTokenRequest) -> OauthTokenResponse: + """Call the POST token endpoint + + See: https://developers.login.gov/oidc/token/ + """ + pass diff --git a/api/src/adapters/oauth/oauth_client_models.py b/api/src/adapters/oauth/oauth_client_models.py new file mode 100644 index 000000000..e684ad719 --- /dev/null +++ b/api/src/adapters/oauth/oauth_client_models.py @@ -0,0 +1,33 @@ +from dataclasses import dataclass + +from pydantic import BaseModel + + +@dataclass +class OauthTokenRequest: + """https://developers.login.gov/oidc/token/#request-parameters""" + + code: str + grant_type: str = "authorization_code" + + # TODO: https://github.com/HHS/simpler-grants-gov/issues/3103 + # client_assertion: str | None = None + # client_assertion_type: str = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer + + +class OauthTokenResponse(BaseModel): + """https://developers.login.gov/oidc/token/#token-response""" + + # These fields are given defaults so we don't need None-checks + # for them elsewhere, if the response didn't error, they have valid values + id_token: str = "" + access_token: str = "" + token_type: str = "" + expires_in: int = 0 + + # These fields are only set if the response errored + error: str | None = None + error_description: str | None = None + + def is_error_response(self) -> bool: + return self.error is not None diff --git a/api/src/auth/login_gov_jwt_auth.py b/api/src/auth/login_gov_jwt_auth.py index 731c9b9dd..212f85b35 100644 --- a/api/src/auth/login_gov_jwt_auth.py +++ b/api/src/auth/login_gov_jwt_auth.py @@ -38,6 +38,7 @@ class LoginGovConfig(PydanticBaseEnvConfig): login_gov_endpoint: str = Field(alias="LOGIN_GOV_ENDPOINT") login_gov_jwk_endpoint: str = Field(alias="LOGIN_GOV_JWK_ENDPOINT") login_gov_auth_endpoint: str = Field(alias="LOGIN_GOV_AUTH_ENDPOINT") + login_gov_token_endpoint: str = Field(alias="LOGIN_GOV_TOKEN_ENDPOINT") # Where we send a user after they have successfully logged in # for now we'll always send them to the same place (a frontend page) diff --git a/api/src/services/users/login_gov_callback_handler.py b/api/src/services/users/login_gov_callback_handler.py index 56c99b96e..53734f198 100644 --- a/api/src/services/users/login_gov_callback_handler.py +++ b/api/src/services/users/login_gov_callback_handler.py @@ -5,6 +5,8 @@ from sqlalchemy import select import src.adapters.db as db +from src.adapters.oauth.login_gov.login_gov_oauth_client import LoginGovOauthClient +from src.adapters.oauth.oauth_client_models import OauthTokenRequest from src.api.route_utils import raise_flask_error from src.db.models.user_models import LoginGovState from src.util.string_utils import is_valid_uuid @@ -25,6 +27,11 @@ class LoginGovCallbackResponse: is_user_new: bool +def get_login_gov_client() -> LoginGovOauthClient: + """Get the login.gov client, in a method to be overridable in tests""" + return LoginGovOauthClient() + + def handle_login_gov_callback(query_data: dict, db_session: db.Session) -> LoginGovCallbackResponse: """Handle the callback from login.gov after calling the authenticate endpoint @@ -63,11 +70,17 @@ def handle_login_gov_callback(query_data: dict, db_session: db.Session) -> Login if login_gov_state is None: raise_flask_error(404, "OAuth state not found") - # TODO: Call the token endpoint (make a client) + # call the token endpoint (make a client) # https://developers.login.gov/oidc/token/ - # * Request building/call - # * Creating a JWT with the key we gave login.gov - # * Handling the response + # TODO: Creating a JWT with the key we gave login.gov + client = get_login_gov_client() + response = client.get_token(OauthTokenRequest(code=callback_params.code)) + + # If this request failed, we'll assume we're the issue and 500 + # TODO - need to test with actual login.gov if there could be other scenarios + # the mock always returns something as long as the request is well-formatted + if response.is_error_response(): + raise_flask_error(500, response.error_description) # TODO: Process the token response from login.gov # Note that a lot of this is already in https://github.com/HHS/simpler-grants-gov/pull/3004 diff --git a/api/tests/conftest.py b/api/tests/conftest.py index 07670eddb..12224e868 100644 --- a/api/tests/conftest.py +++ b/api/tests/conftest.py @@ -1,7 +1,6 @@ import logging import os import pathlib -import urllib import uuid import _pytest.monkeypatch @@ -18,6 +17,7 @@ import src.app as app_entry import tests.src.db.models.factories as factories from src.adapters import search +from src.adapters.oauth.login_gov.mock_login_gov_oauth_client import MockLoginGovOauthClient from src.constants.schema import Schemas from src.db import models from src.db.models.foreign import metadata as foreign_metadata @@ -26,6 +26,7 @@ from src.db.models.staging import metadata as staging_metadata from src.util.local import load_local_env_vars from tests.lib import db_testing +from tests.lib.auth_test_utils import mock_oauth_endpoint logger = logging.getLogger(__name__) @@ -262,72 +263,25 @@ def other_rsa_key_pair(): #################### -def mock_jwk_endpoint(app): - @app.get("/test-endpoint/jwk") - def jwk_endpoint(): - response = { - "keys": [ - { - "alg": "RS256", - "use": "sig", - "kty": "RSA", - "n": "test_abc123", - "e": "AQAB", - "kid": "xyz123", - } - ] - } - - return flask.jsonify(response) - - -def oauth_param_override(): - """Override endpoint called in the mock authorize endpoint setup below. - - To override you can do the following in your test: - - def override(): - return {"error": "access_denied"} - - monkeypatch.setattr("tests.conftest.oauth_param_override", override) - """ - return {} - - -def mock_oauth_endpoint(app): - # Adds a mock oauth endpoint to the app - # itself for auth purposes - - @app.get("/test-endpoint/oauth-authorize") - def oauth_authorize(): - # This endpoint represents a mocked version of - # https://developers.login.gov/oidc/authorization/ - # and needs to return the state value as well as a code. - query_args = flask.request.args - - params = {"state": query_args.get("state"), "code": str(uuid.uuid4())} - params.update(oauth_param_override()) - encoded_params = urllib.parse.urlencode(params) - - redirect_uri = f"{query_args['redirect_uri']}?{encoded_params}" - - return flask.redirect(redirect_uri) +@pytest.fixture(scope="session") +def mock_oauth_client(): + return MockLoginGovOauthClient() # Make app session scoped so the database connection pool is only created once # for the test session. This speeds up the tests. @pytest.fixture(scope="session") -def app(db_client, opportunity_index_alias, monkeypatch_session, public_rsa_key) -> APIFlask: +def app( + db_client, opportunity_index_alias, monkeypatch_session, private_rsa_key, mock_oauth_client +) -> APIFlask: # Override the OAuth endpoint path before creating the app which loads the config at startup monkeypatch_session.setenv( "LOGIN_GOV_AUTH_ENDPOINT", "http://localhost:8080/test-endpoint/oauth-authorize" ) - - # Create the app app = app_entry.create_app() - # Add the endpoint to the app - mock_oauth_endpoint(app) + # Add endpoints and mocks for handling the external OAuth logic + mock_oauth_endpoint(app, monkeypatch_session, private_rsa_key, mock_oauth_client) return app diff --git a/api/tests/lib/auth_test_utils.py b/api/tests/lib/auth_test_utils.py new file mode 100644 index 000000000..d224bfe78 --- /dev/null +++ b/api/tests/lib/auth_test_utils.py @@ -0,0 +1,113 @@ +"""Utilities for creating and working with auth features in tests""" + +import urllib +import uuid +from datetime import datetime, timedelta + +import flask +import jwt + +from src.adapters.oauth.oauth_client_models import OauthTokenResponse + + +def create_jwt( + user_id: str, + private_key: str | bytes, + email: str = "fake_mail@mail.com", + nonce: str = "abc123", + expires_at: datetime | None = None, + issued_at: datetime | None = None, + not_before: datetime | None = None, + # Note that these values need to match what we set + # in conftest.py::setup_login_gov_auth + issuer: str = "http://localhost:3000", + audience: str = "AUDIENCE_TEST", +): + """Create a JWT in roughly the format login.gov will give us""" + + # Default datetime values are set to clearly not be an issue + if expires_at is None: + expires_at = datetime.now() + timedelta(days=365) + if issued_at is None: + issued_at = datetime.now() - timedelta(days=365) + if not_before is None: + not_before = datetime.now() - timedelta(days=365) + + payload = { + "sub": user_id, + "iss": issuer, + "aud": audience, + "email": email, + "nonce": nonce, + # The jwt encode function automatically turns these datetime + # objects into a UTC timestamp integer + "exp": expires_at, + "iat": issued_at, + "nbf": not_before, + # These values aren't checked by anything at the moment + # but are a part of the token from login.gov + "jti": "abc123", + "at_hash": "abc123", + "c_hash": "abc123", + "acr": "urn:acr.login.gov:auth-only", + } + + return jwt.encode(payload, private_key, algorithm="RS256") + + +def oauth_param_override(): + """Override endpoint called in the mock authorize endpoint setup below. + + To override you can do the following in your test: + + def override(): + return {"error": "access_denied"} + + monkeypatch.setattr("tests.lib.auth_test_utils.oauth_param_override", override) + """ + return {} + + +def mock_oauth_endpoint(app, monkeypatch, private_key, mock_oauth_client): + """Add mock oauth endpoints to the application + + For the initial authorize endpoint, we create an endpoint on the app itself + which redirects back to the configured redirect_uri and also sets up the + mock_oauth_client to have a successful response when calling it later for a token. + """ + + @app.get("/test-endpoint/oauth-authorize") + def oauth_authorize(): + # This endpoint represents a mocked version of + # https://developers.login.gov/oidc/authorization/ + # and needs to return the state value as well as a code. + query_args = flask.request.args + + params = {"state": query_args.get("state"), "code": str(uuid.uuid4())} + params.update(oauth_param_override()) + + # Add a dummy response we'll later get if the token endpoint is called + id_token = create_jwt( + user_id=query_args.get("state"), # Re-use the state as the user ID + private_key=private_key, + nonce=query_args.get("nonce"), + ) + mocked_token_response = OauthTokenResponse( + id_token=id_token, access_token="fake_token", token_type="Bearer", expires_in=300 + ) + mock_oauth_client.add_token_response(params["code"], mocked_token_response) + + encoded_params = urllib.parse.urlencode(params) + + redirect_uri = f"{query_args['redirect_uri']}?{encoded_params}" + + return flask.redirect(redirect_uri) + + # Override our callback endpoint to use this mocked client instead of the real one + def override_get_client(): + """Override the login_gov client we use in unit tests to be the mock version""" + return mock_oauth_client + + monkeypatch.setattr( + "src.services.users.login_gov_callback_handler.get_login_gov_client", override_get_client + ) diff --git a/api/tests/src/adapters/oauth/__init__.py b/api/tests/src/adapters/oauth/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/tests/src/adapters/oauth/login_gov/__init__.py b/api/tests/src/adapters/oauth/login_gov/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/tests/src/adapters/oauth/login_gov/test_login_gov_oauth_client.py b/api/tests/src/adapters/oauth/login_gov/test_login_gov_oauth_client.py new file mode 100644 index 000000000..c048f5599 --- /dev/null +++ b/api/tests/src/adapters/oauth/login_gov/test_login_gov_oauth_client.py @@ -0,0 +1,54 @@ +import json + +import requests + +from src.adapters.oauth.login_gov.login_gov_oauth_client import LoginGovOauthClient +from src.adapters.oauth.oauth_client_models import OauthTokenRequest +from src.auth.login_gov_jwt_auth import LoginGovConfig + + +def mock_response(monkeypatch, mocked_response: dict): + def mock_post(*args, **kwargs): + response = requests.Response() + # _content is fetched by the text method which we use when deserializing + response._content = bytes(json.dumps(mocked_response), "utf-8") + return response + + monkeypatch.setattr("requests.Session.request", mock_post) + + +def test_get_token(monkeypatch): + + mock_response( + monkeypatch, + {"id_token": "abc123", "access_token": "xyz456", "token_type": "Bearer", "expires_in": 300}, + ) + + client = LoginGovOauthClient(LoginGovConfig()) + resp = client.get_token(OauthTokenRequest(code="abc123")) + + assert resp.id_token == "abc123" + assert resp.access_token == "xyz456" + assert resp.token_type == "Bearer" + assert resp.expires_in == 300 + assert resp.error is None + assert resp.error_description is None + assert resp.is_error_response() is False + + +def test_get_token_error(monkeypatch): + mock_response( + monkeypatch, + {"error": "invalid_request", "error_description": "missing required parameter grant_type"}, + ) + + client = LoginGovOauthClient(LoginGovConfig()) + resp = client.get_token(OauthTokenRequest(code="abc123")) + + assert resp.id_token == "" + assert resp.access_token == "" + assert resp.token_type == "" + assert resp.expires_in == 0 + assert resp.error == "invalid_request" + assert resp.error_description == "missing required parameter grant_type" + assert resp.is_error_response() is True diff --git a/api/tests/src/api/users/test_user_route_login.py b/api/tests/src/api/users/test_user_route_login.py index 3b072c191..bd52965a9 100644 --- a/api/tests/src/api/users/test_user_route_login.py +++ b/api/tests/src/api/users/test_user_route_login.py @@ -3,6 +3,7 @@ import src.auth.login_gov_jwt_auth as login_gov_jwt_auth from src.api.route_utils import raise_flask_error +from tests.src.db.models.factories import LoginGovStateFactory # To help illustrate what we are testing, here is a diagram # @@ -17,11 +18,19 @@ # redirects to / # |---------------------- # .________V________. .____________________. -# | | redirects to | | -# | /login/callback | ------------->| /login/result | +# | | calls | | +# | /login/callback | ------------->| /token | # |_________________| |____________________| +# | +# | +# | .____________________. +# | redirects to | | +# -----------------------> | /login/result | +# |____________________| # -# TODO - this'll be more complex when I add the calls to the token endpoint +# For testing, we create an oauth-authorize endpoint that redirects back +# and sets up a few basic pieces of information on the Oauth side that later +# can be picked up in the token endpoint. def test_user_login_flow_302(client): @@ -31,6 +40,8 @@ def test_user_login_flow_302(client): login_gov_config = login_gov_jwt_auth.get_config() resp = client.get("/v1/users/login", follow_redirects=True) + print(resp.history) + # The final endpoint returns a 200 # and dumps the params it was called with assert resp.status_code == 200 @@ -175,7 +186,7 @@ def test_user_login_flow_error_in_auth_response_302(client, monkeypatch): def override(): return {"error": "access_denied", "error_description": "user does not have access"} - monkeypatch.setattr("tests.conftest.oauth_param_override", override) + monkeypatch.setattr("tests.lib.auth_test_utils.oauth_param_override", override) resp = client.get("/v1/users/login", follow_redirects=True) @@ -219,3 +230,28 @@ def test_user_callback_invalid_state_302(client, monkeypatch): assert resp_json["message"] == "error" assert resp_json["error_description"] == "Invalid OAuth state value" + + +def test_user_callback_error_in_token_302(client, enable_factory_create, caplog): + """Test behavior when we call the callback endpoint, but the oauth token endpoint has nothing""" + + # Create state so the callback gets past the check + login_gov_state = LoginGovStateFactory.create() + + resp = client.get( + f"/v1/users/login/callback?state={login_gov_state.login_gov_state_id}&code=xyz456", + follow_redirects=True, + ) + + # The final endpoint returns a 200 even when erroring as it is just a GET endpoint + assert resp.status_code == 200 + resp_json = resp.get_json() + + assert resp_json["message"] == "error" + assert resp_json["error_description"] == "internal error" + + # Verify it errored because of the response from token Oauth + assert ( + "Unexpected error occurred in login flow via raise_flask_error: default mock error description" + in caplog.messages + ) diff --git a/api/tests/src/db/models/factories.py b/api/tests/src/db/models/factories.py index 13c94bba8..e29a491ed 100644 --- a/api/tests/src/db/models/factories.py +++ b/api/tests/src/db/models/factories.py @@ -1956,6 +1956,14 @@ class Meta: email = factory.Faker("email") +class LoginGovStateFactory(BaseFactory): + class Meta: + model = user_models.LoginGovState + + login_gov_state_id = Generators.UuidObj + nonce = Generators.UuidObj + + class ExtractMetadataFactory(BaseFactory): class Meta: model = extract_models.ExtractMetadata