Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Issue #3102] Call Oauth token endpoint #3122

Merged
merged 19 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions api/local.env
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,11 @@ LOGIN_GOV_CLIENT_ID=local_mock_client_id

LOGIN_GOV_JWK_ENDPOINT=http://host.docker.internal:5001/issuer1/jwks
LOGIN_GOV_AUTH_ENDPOINT=http://localhost:5001/issuer1/authorize
LOGIN_GOV_TOKEN_ENDPOINT=http://host.docker.internal:5001/issuer1/token
LOGIN_GOV_ENDPOINT=http://localhost:5001



LOGIN_FINAL_DESTINATION=http://localhost:8080/v1/users/login/result

# These should be set in your override.env file
Expand Down
4 changes: 2 additions & 2 deletions api/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Empty file.
Empty file.
50 changes: 50 additions & 0 deletions api/src/adapters/oauth/login_gov/login_gov_oauth_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import Any

import requests

from src.adapters.oauth.oauth_client import BaseOauthClient
from src.adapters.oauth.oauth_client_models import OauthTokenRequest, OauthTokenResponse
from src.auth.login_gov_jwt_auth import LoginGovConfig, get_config


class LoginGovOauthClient(BaseOauthClient):

def __init__(self, config: LoginGovConfig | None = None):
if config is None:
config = get_config()

self.config = config
self.session = self._build_session()

def _build_session(self, session: requests.Session | None = None) -> requests.Session:
"""Set things on the session that should be shared between all requests"""
if not session:
session = requests.Session()

session.headers.update({"Content-Type": "application/x-www-form-urlencoded"})

return session

def _request(self, method: str, full_url: str, **kwargs: Any) -> requests.Response:
"""Utility method for making a request with our session"""

# By default timeout after 5 seconds
if "timeout" not in kwargs:
kwargs["timeout"] = 5

return self.session.request(method, full_url, **kwargs)

def get_token(self, request: OauthTokenRequest) -> OauthTokenResponse:
"""Query the login.gov token endpoint"""

body = {
"code": request.code,
"grant_type": request.grant_type,
# TODO https://github.com/HHS/simpler-grants-gov/issues/3103
# when we support client assertion, we need to not add the client_id
"client_id": self.config.client_id,
}

response = self._request("POST", self.config.login_gov_token_endpoint, data=body)

return OauthTokenResponse.model_validate_json(response.text)
21 changes: 21 additions & 0 deletions api/src/adapters/oauth/login_gov/mock_login_gov_oauth_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from src.adapters.oauth.oauth_client import BaseOauthClient
from src.adapters.oauth.oauth_client_models import OauthTokenRequest, OauthTokenResponse


class MockLoginGovOauthClient(BaseOauthClient):

def __init__(self) -> None:
self.responses: dict[str, OauthTokenResponse] = {}

def add_token_response(self, code: str, response: OauthTokenResponse) -> None:
self.responses[code] = response

def get_token(self, request: OauthTokenRequest) -> OauthTokenResponse:
response = self.responses.get(request.code, None)

if response is None:
response = OauthTokenResponse(
error="error", error_description="default mock error description"
)

return response
14 changes: 14 additions & 0 deletions api/src/adapters/oauth/oauth_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import abc

from src.adapters.oauth.oauth_client_models import OauthTokenRequest, OauthTokenResponse


class BaseOauthClient(abc.ABC, metaclass=abc.ABCMeta):

@abc.abstractmethod
def get_token(self, request: OauthTokenRequest) -> OauthTokenResponse:
"""Call the POST token endpoint

See: https://developers.login.gov/oidc/token/
"""
pass
33 changes: 33 additions & 0 deletions api/src/adapters/oauth/oauth_client_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from dataclasses import dataclass

from pydantic import BaseModel


@dataclass
class OauthTokenRequest:
"""https://developers.login.gov/oidc/token/#request-parameters"""

code: str
grant_type: str = "authorization_code"

# TODO: https://github.com/HHS/simpler-grants-gov/issues/3103
# client_assertion: str | None = None
# client_assertion_type: str = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer


class OauthTokenResponse(BaseModel):
"""https://developers.login.gov/oidc/token/#token-response"""

# These fields are given defaults so we don't need None-checks
# for them elsewhere, if the response didn't error, they have valid values
id_token: str = ""
access_token: str = ""
token_type: str = ""
expires_in: int = 0

# These fields are only set if the response errored
error: str | None = None
error_description: str | None = None

def is_error_response(self) -> bool:
return self.error is not None
1 change: 1 addition & 0 deletions api/src/auth/login_gov_jwt_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class LoginGovConfig(PydanticBaseEnvConfig):
login_gov_endpoint: str = Field(alias="LOGIN_GOV_ENDPOINT")
login_gov_jwk_endpoint: str = Field(alias="LOGIN_GOV_JWK_ENDPOINT")
login_gov_auth_endpoint: str = Field(alias="LOGIN_GOV_AUTH_ENDPOINT")
login_gov_token_endpoint: str = Field(alias="LOGIN_GOV_TOKEN_ENDPOINT")

# Where we send a user after they have successfully logged in
# for now we'll always send them to the same place (a frontend page)
Expand Down
21 changes: 17 additions & 4 deletions api/src/services/users/login_gov_callback_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from sqlalchemy import select

import src.adapters.db as db
from src.adapters.oauth.login_gov.login_gov_oauth_client import LoginGovOauthClient
from src.adapters.oauth.oauth_client_models import OauthTokenRequest
from src.api.route_utils import raise_flask_error
from src.db.models.user_models import LoginGovState
from src.util.string_utils import is_valid_uuid
Expand All @@ -25,6 +27,11 @@ class LoginGovCallbackResponse:
is_user_new: bool


def get_login_gov_client() -> LoginGovOauthClient:
"""Get the login.gov client, in a method to be overridable in tests"""
return LoginGovOauthClient()


def handle_login_gov_callback(query_data: dict, db_session: db.Session) -> LoginGovCallbackResponse:
"""Handle the callback from login.gov after calling the authenticate endpoint

Expand Down Expand Up @@ -63,11 +70,17 @@ def handle_login_gov_callback(query_data: dict, db_session: db.Session) -> Login
if login_gov_state is None:
raise_flask_error(404, "OAuth state not found")

# TODO: Call the token endpoint (make a client)
# call the token endpoint (make a client)
# https://developers.login.gov/oidc/token/
# * Request building/call
# * Creating a JWT with the key we gave login.gov
# * Handling the response
# TODO: Creating a JWT with the key we gave login.gov
client = get_login_gov_client()
response = client.get_token(OauthTokenRequest(code=callback_params.code))

# If this request failed, we'll assume we're the issue and 500
# TODO - need to test with actual login.gov if there could be other scenarios
# the mock always returns something as long as the request is well-formatted
if response.is_error_response():
raise_flask_error(500, response.error_description)

# TODO: Process the token response from login.gov
# Note that a lot of this is already in https://github.com/HHS/simpler-grants-gov/pull/3004
Expand Down
66 changes: 10 additions & 56 deletions api/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import logging
import os
import pathlib
import urllib
import uuid

import _pytest.monkeypatch
Expand All @@ -18,6 +17,7 @@
import src.app as app_entry
import tests.src.db.models.factories as factories
from src.adapters import search
from src.adapters.oauth.login_gov.mock_login_gov_oauth_client import MockLoginGovOauthClient
from src.constants.schema import Schemas
from src.db import models
from src.db.models.foreign import metadata as foreign_metadata
Expand All @@ -26,6 +26,7 @@
from src.db.models.staging import metadata as staging_metadata
from src.util.local import load_local_env_vars
from tests.lib import db_testing
from tests.lib.auth_test_utils import mock_oauth_endpoint

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -262,72 +263,25 @@ def other_rsa_key_pair():
####################


def mock_jwk_endpoint(app):
@app.get("/test-endpoint/jwk")
def jwk_endpoint():
response = {
"keys": [
{
"alg": "RS256",
"use": "sig",
"kty": "RSA",
"n": "test_abc123",
"e": "AQAB",
"kid": "xyz123",
}
]
}

return flask.jsonify(response)


def oauth_param_override():
"""Override endpoint called in the mock authorize endpoint setup below.

To override you can do the following in your test:

def override():
return {"error": "access_denied"}

monkeypatch.setattr("tests.conftest.oauth_param_override", override)
"""
return {}


def mock_oauth_endpoint(app):
# Adds a mock oauth endpoint to the app
# itself for auth purposes

@app.get("/test-endpoint/oauth-authorize")
def oauth_authorize():
# This endpoint represents a mocked version of
# https://developers.login.gov/oidc/authorization/
# and needs to return the state value as well as a code.
query_args = flask.request.args

params = {"state": query_args.get("state"), "code": str(uuid.uuid4())}
params.update(oauth_param_override())
encoded_params = urllib.parse.urlencode(params)

redirect_uri = f"{query_args['redirect_uri']}?{encoded_params}"

return flask.redirect(redirect_uri)
@pytest.fixture(scope="session")
def mock_oauth_client():
return MockLoginGovOauthClient()


# Make app session scoped so the database connection pool is only created once
# for the test session. This speeds up the tests.
@pytest.fixture(scope="session")
def app(db_client, opportunity_index_alias, monkeypatch_session, public_rsa_key) -> APIFlask:
def app(
db_client, opportunity_index_alias, monkeypatch_session, private_rsa_key, mock_oauth_client
) -> APIFlask:
# Override the OAuth endpoint path before creating the app which loads the config at startup
monkeypatch_session.setenv(
"LOGIN_GOV_AUTH_ENDPOINT", "http://localhost:8080/test-endpoint/oauth-authorize"
)

# Create the app
app = app_entry.create_app()

# Add the endpoint to the app
mock_oauth_endpoint(app)
# Add endpoints and mocks for handling the external OAuth logic
mock_oauth_endpoint(app, monkeypatch_session, private_rsa_key, mock_oauth_client)

return app

Expand Down
Loading