Skip to content

Commit

Permalink
[Issue #3102] Call Oauth token endpoint (#3122)
Browse files Browse the repository at this point in the history
## Summary
Fixes #3102

### Time to review: __10 mins__

## Changes proposed
Added a client (and mock) for calling an OAuth token endpoint

A lot of restructuring of test utils for clearer setup

## Context for reviewers
https://developers.login.gov/oidc/token/

This gets the token from the OAuth server and parses the response. There
is more work to do on this later as login.gov requires a special JWT to
also be passed, but a basic version doesn't need that (our local mock
doesn't care), so I'll follow-up on that later.

This approach to setting up a client is following some patterns I've
used before that worked well. Building a mock version alongside the real
one helps with testing.

## Additional information
Still nothing new visually, under the hood it is just one more big step
remaining to process the token

---------

Co-authored-by: nava-platform-bot <[email protected]>
  • Loading branch information
chouinar and nava-platform-bot authored Dec 12, 2024
1 parent 7a7b8a8 commit 08c13b7
Show file tree
Hide file tree
Showing 17 changed files with 365 additions and 66 deletions.
2 changes: 2 additions & 0 deletions api/local.env
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,11 @@ LOGIN_GOV_CLIENT_ID=local_mock_client_id

LOGIN_GOV_JWK_ENDPOINT=http://host.docker.internal:5001/issuer1/jwks
LOGIN_GOV_AUTH_ENDPOINT=http://localhost:5001/issuer1/authorize
LOGIN_GOV_TOKEN_ENDPOINT=http://host.docker.internal:5001/issuer1/token
LOGIN_GOV_ENDPOINT=http://localhost:5001



LOGIN_FINAL_DESTINATION=http://localhost:8080/v1/users/login/result

# These should be set in your override.env file
Expand Down
4 changes: 2 additions & 2 deletions api/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Empty file.
Empty file.
50 changes: 50 additions & 0 deletions api/src/adapters/oauth/login_gov/login_gov_oauth_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import Any

import requests

from src.adapters.oauth.oauth_client import BaseOauthClient
from src.adapters.oauth.oauth_client_models import OauthTokenRequest, OauthTokenResponse
from src.auth.login_gov_jwt_auth import LoginGovConfig, get_config


class LoginGovOauthClient(BaseOauthClient):

def __init__(self, config: LoginGovConfig | None = None):
if config is None:
config = get_config()

self.config = config
self.session = self._build_session()

def _build_session(self, session: requests.Session | None = None) -> requests.Session:
"""Set things on the session that should be shared between all requests"""
if not session:
session = requests.Session()

session.headers.update({"Content-Type": "application/x-www-form-urlencoded"})

return session

def _request(self, method: str, full_url: str, **kwargs: Any) -> requests.Response:
"""Utility method for making a request with our session"""

# By default timeout after 5 seconds
if "timeout" not in kwargs:
kwargs["timeout"] = 5

return self.session.request(method, full_url, **kwargs)

def get_token(self, request: OauthTokenRequest) -> OauthTokenResponse:
"""Query the login.gov token endpoint"""

body = {
"code": request.code,
"grant_type": request.grant_type,
# TODO https://github.com/HHS/simpler-grants-gov/issues/3103
# when we support client assertion, we need to not add the client_id
"client_id": self.config.client_id,
}

response = self._request("POST", self.config.login_gov_token_endpoint, data=body)

return OauthTokenResponse.model_validate_json(response.text)
21 changes: 21 additions & 0 deletions api/src/adapters/oauth/login_gov/mock_login_gov_oauth_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from src.adapters.oauth.oauth_client import BaseOauthClient
from src.adapters.oauth.oauth_client_models import OauthTokenRequest, OauthTokenResponse


class MockLoginGovOauthClient(BaseOauthClient):

def __init__(self) -> None:
self.responses: dict[str, OauthTokenResponse] = {}

def add_token_response(self, code: str, response: OauthTokenResponse) -> None:
self.responses[code] = response

def get_token(self, request: OauthTokenRequest) -> OauthTokenResponse:
response = self.responses.get(request.code, None)

if response is None:
response = OauthTokenResponse(
error="error", error_description="default mock error description"
)

return response
14 changes: 14 additions & 0 deletions api/src/adapters/oauth/oauth_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import abc

from src.adapters.oauth.oauth_client_models import OauthTokenRequest, OauthTokenResponse


class BaseOauthClient(abc.ABC, metaclass=abc.ABCMeta):

@abc.abstractmethod
def get_token(self, request: OauthTokenRequest) -> OauthTokenResponse:
"""Call the POST token endpoint
See: https://developers.login.gov/oidc/token/
"""
pass
33 changes: 33 additions & 0 deletions api/src/adapters/oauth/oauth_client_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from dataclasses import dataclass

from pydantic import BaseModel


@dataclass
class OauthTokenRequest:
"""https://developers.login.gov/oidc/token/#request-parameters"""

code: str
grant_type: str = "authorization_code"

# TODO: https://github.com/HHS/simpler-grants-gov/issues/3103
# client_assertion: str | None = None
# client_assertion_type: str = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer


class OauthTokenResponse(BaseModel):
"""https://developers.login.gov/oidc/token/#token-response"""

# These fields are given defaults so we don't need None-checks
# for them elsewhere, if the response didn't error, they have valid values
id_token: str = ""
access_token: str = ""
token_type: str = ""
expires_in: int = 0

# These fields are only set if the response errored
error: str | None = None
error_description: str | None = None

def is_error_response(self) -> bool:
return self.error is not None
1 change: 1 addition & 0 deletions api/src/auth/login_gov_jwt_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class LoginGovConfig(PydanticBaseEnvConfig):
login_gov_endpoint: str = Field(alias="LOGIN_GOV_ENDPOINT")
login_gov_jwk_endpoint: str = Field(alias="LOGIN_GOV_JWK_ENDPOINT")
login_gov_auth_endpoint: str = Field(alias="LOGIN_GOV_AUTH_ENDPOINT")
login_gov_token_endpoint: str = Field(alias="LOGIN_GOV_TOKEN_ENDPOINT")

# Where we send a user after they have successfully logged in
# for now we'll always send them to the same place (a frontend page)
Expand Down
21 changes: 17 additions & 4 deletions api/src/services/users/login_gov_callback_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from sqlalchemy import select

import src.adapters.db as db
from src.adapters.oauth.login_gov.login_gov_oauth_client import LoginGovOauthClient
from src.adapters.oauth.oauth_client_models import OauthTokenRequest
from src.api.route_utils import raise_flask_error
from src.db.models.user_models import LoginGovState
from src.util.string_utils import is_valid_uuid
Expand All @@ -25,6 +27,11 @@ class LoginGovCallbackResponse:
is_user_new: bool


def get_login_gov_client() -> LoginGovOauthClient:
"""Get the login.gov client, in a method to be overridable in tests"""
return LoginGovOauthClient()


def handle_login_gov_callback(query_data: dict, db_session: db.Session) -> LoginGovCallbackResponse:
"""Handle the callback from login.gov after calling the authenticate endpoint
Expand Down Expand Up @@ -63,11 +70,17 @@ def handle_login_gov_callback(query_data: dict, db_session: db.Session) -> Login
if login_gov_state is None:
raise_flask_error(404, "OAuth state not found")

# TODO: Call the token endpoint (make a client)
# call the token endpoint (make a client)
# https://developers.login.gov/oidc/token/
# * Request building/call
# * Creating a JWT with the key we gave login.gov
# * Handling the response
# TODO: Creating a JWT with the key we gave login.gov
client = get_login_gov_client()
response = client.get_token(OauthTokenRequest(code=callback_params.code))

# If this request failed, we'll assume we're the issue and 500
# TODO - need to test with actual login.gov if there could be other scenarios
# the mock always returns something as long as the request is well-formatted
if response.is_error_response():
raise_flask_error(500, response.error_description)

# TODO: Process the token response from login.gov
# Note that a lot of this is already in https://github.com/HHS/simpler-grants-gov/pull/3004
Expand Down
66 changes: 10 additions & 56 deletions api/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import logging
import os
import pathlib
import urllib
import uuid

import _pytest.monkeypatch
Expand All @@ -18,6 +17,7 @@
import src.app as app_entry
import tests.src.db.models.factories as factories
from src.adapters import search
from src.adapters.oauth.login_gov.mock_login_gov_oauth_client import MockLoginGovOauthClient
from src.constants.schema import Schemas
from src.db import models
from src.db.models.foreign import metadata as foreign_metadata
Expand All @@ -26,6 +26,7 @@
from src.db.models.staging import metadata as staging_metadata
from src.util.local import load_local_env_vars
from tests.lib import db_testing
from tests.lib.auth_test_utils import mock_oauth_endpoint

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -268,72 +269,25 @@ def other_rsa_key_pair():
####################


def mock_jwk_endpoint(app):
@app.get("/test-endpoint/jwk")
def jwk_endpoint():
response = {
"keys": [
{
"alg": "RS256",
"use": "sig",
"kty": "RSA",
"n": "test_abc123",
"e": "AQAB",
"kid": "xyz123",
}
]
}

return flask.jsonify(response)


def oauth_param_override():
"""Override endpoint called in the mock authorize endpoint setup below.
To override you can do the following in your test:
def override():
return {"error": "access_denied"}
monkeypatch.setattr("tests.conftest.oauth_param_override", override)
"""
return {}


def mock_oauth_endpoint(app):
# Adds a mock oauth endpoint to the app
# itself for auth purposes

@app.get("/test-endpoint/oauth-authorize")
def oauth_authorize():
# This endpoint represents a mocked version of
# https://developers.login.gov/oidc/authorization/
# and needs to return the state value as well as a code.
query_args = flask.request.args

params = {"state": query_args.get("state"), "code": str(uuid.uuid4())}
params.update(oauth_param_override())
encoded_params = urllib.parse.urlencode(params)

redirect_uri = f"{query_args['redirect_uri']}?{encoded_params}"

return flask.redirect(redirect_uri)
@pytest.fixture(scope="session")
def mock_oauth_client():
return MockLoginGovOauthClient()


# Make app session scoped so the database connection pool is only created once
# for the test session. This speeds up the tests.
@pytest.fixture(scope="session")
def app(db_client, opportunity_index_alias, monkeypatch_session, public_rsa_key) -> APIFlask:
def app(
db_client, opportunity_index_alias, monkeypatch_session, private_rsa_key, mock_oauth_client
) -> APIFlask:
# Override the OAuth endpoint path before creating the app which loads the config at startup
monkeypatch_session.setenv(
"LOGIN_GOV_AUTH_ENDPOINT", "http://localhost:8080/test-endpoint/oauth-authorize"
)

# Create the app
app = app_entry.create_app()

# Add the endpoint to the app
mock_oauth_endpoint(app)
# Add endpoints and mocks for handling the external OAuth logic
mock_oauth_endpoint(app, monkeypatch_session, private_rsa_key, mock_oauth_client)

return app

Expand Down
Loading

0 comments on commit 08c13b7

Please sign in to comment.