Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Issue #3102] Call Oauth token endpoint #3122

Merged
merged 19 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions api/local.env
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,11 @@ LOGIN_GOV_CLIENT_ID=local_mock_client_id

LOGIN_GOV_JWK_ENDPOINT=http://host.docker.internal:5001/issuer1/jwks
LOGIN_GOV_AUTH_ENDPOINT=http://localhost:5001/issuer1/authorize
LOGIN_GOV_TOKEN_ENDPOINT=http://host.docker.internal:5001/issuer1/token
LOGIN_GOV_ENDPOINT=http://localhost:5001



LOGIN_FINAL_DESTINATION=http://localhost:8080/v1/users/login/result

# These should be set in your override.env file
Expand Down
Empty file.
Empty file.
50 changes: 50 additions & 0 deletions api/src/adapters/oauth/login_gov/login_gov_oauth_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import Any

import requests

from src.adapters.oauth.oauth_client import BaseOauthClient
from src.adapters.oauth.oauth_client_models import OauthTokenRequest, OauthTokenResponse
from src.auth.login_gov_jwt_auth import LoginGovConfig, get_config


class LoginGovOauthClient(BaseOauthClient):

def __init__(self, config: LoginGovConfig | None = None):
if config is None:
config = get_config()

self.config = config
self.session = self._build_session()

def _build_session(self, session: requests.Session | None = None) -> requests.Session:
"""Set things on the session that should be shared between all requests"""
if not session:
session = requests.Session()

session.headers.update({"Content-Type": "application/x-www-form-urlencoded"})

return session

def _request(self, method: str, full_url: str, **kwargs: Any) -> requests.Response:
"""Utility method for making a request with our session"""

# By default timeout after 5 seconds
if "timeout" not in kwargs:
kwargs["timeout"] = 5

return self.session.request(method, full_url, **kwargs)

def get_token(self, request: OauthTokenRequest) -> OauthTokenResponse:
"""Query the login.gov token endpoint"""

body = {
"code": request.code,
"grant_type": request.grant_type,
# TODO https://github.com/HHS/simpler-grants-gov/issues/3103
# when we support client assertion, we need to not add the client_id
"client_id": self.config.client_id,
}

response = self._request("POST", self.config.login_gov_token_endpoint, data=body)

return OauthTokenResponse.model_validate_json(response.text)
21 changes: 21 additions & 0 deletions api/src/adapters/oauth/login_gov/mock_login_gov_oauth_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from src.adapters.oauth.oauth_client import BaseOauthClient
from src.adapters.oauth.oauth_client_models import OauthTokenRequest, OauthTokenResponse


class MockLoginGovOauthClient(BaseOauthClient):

def __init__(self) -> None:
self.responses: dict[str, OauthTokenResponse] = {}

def add_token_response(self, code: str, response: OauthTokenResponse) -> None:
self.responses[code] = response

def get_token(self, request: OauthTokenRequest) -> OauthTokenResponse:
response = self.responses.get(request.code, None)

if response is None:
response = OauthTokenResponse(
error="error", error_description="default mock error description"
)

return response
14 changes: 14 additions & 0 deletions api/src/adapters/oauth/oauth_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import abc

from src.adapters.oauth.oauth_client_models import OauthTokenRequest, OauthTokenResponse


class BaseOauthClient(abc.ABC, metaclass=abc.ABCMeta):

@abc.abstractmethod
def get_token(self, request: OauthTokenRequest) -> OauthTokenResponse:
"""Call the POST token endpoint

See: https://developers.login.gov/oidc/token/
"""
pass
33 changes: 33 additions & 0 deletions api/src/adapters/oauth/oauth_client_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from dataclasses import dataclass

from pydantic import BaseModel


@dataclass
class OauthTokenRequest:
"""https://developers.login.gov/oidc/token/#request-parameters"""

code: str
grant_type: str = "authorization_code"

# TODO: https://github.com/HHS/simpler-grants-gov/issues/3103
# client_assertion: str | None = None
# client_assertion_type: str = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer


class OauthTokenResponse(BaseModel):
"""https://developers.login.gov/oidc/token/#token-response"""

# These fields are given defaults so we don't need None-checks
# for them elsewhere, if the response didn't error, they have valid values
id_token: str = ""
access_token: str = ""
token_type: str = ""
expires_in: int = 0

# These fields are only set if the response errored
error: str | None = None
error_description: str | None = None

def is_error_response(self) -> bool:
return self.error is not None
25 changes: 10 additions & 15 deletions api/src/api/users/user_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from src.auth.login_gov_jwt_auth import get_final_redirect_uri, get_login_gov_redirect_uri
from src.db.models.user_models import UserTokenSession
from src.services.users.get_user import get_user
from src.services.users.login_gov_callback_handler import handle_login_gov_callback

logger = logging.getLogger(__name__)

Expand All @@ -36,17 +37,21 @@
@user_blueprint.get("/login")
@user_blueprint.doc(responses=[302], description=LOGIN_DESCRIPTION)
@with_login_redirect_error_handler()
def user_login() -> flask.Response:
@flask_db.with_db_session()
def user_login(db_session: db.Session) -> flask.Response:
logger.info("GET /v1/users/login")
with db_session.begin():
redirect_uri = get_login_gov_redirect_uri(db_session)

return response.redirect_response(get_login_gov_redirect_uri())
return response.redirect_response(redirect_uri)


@user_blueprint.get("/login/callback")
@user_blueprint.input(user_schemas.UserLoginGovCallbackSchema, location="query")
@user_blueprint.doc(responses=[302], hide=True)
@with_login_redirect_error_handler()
def user_login_callback(query_data: dict) -> flask.Response:
@flask_db.with_db_session()
def user_login_callback(db_session: db.Session, query_data: dict) -> flask.Response:
logger.info("GET /v1/users/login/callback")

# TODO: Do not launch with this, just keeping this here for debugging
Expand All @@ -66,18 +71,8 @@ def user_login_callback(query_data: dict) -> flask.Response:
#
# The JWT we will process is the id_token returned

#########################################
# TODO - implementation remaining
# Process the data coming back from login.gov after the redirect
## Fetch the state UUID from the DB - validate we have it

# Call the token endpoint with the code
## Need to also account for making a JWT to call login.gov (not needed locally)
## Probably want to make a "client" for easier mocking

# Process the token response from login.gov + create a token (Existing draft PR for all of this)

# Docs - see if there is a way to either describe the "return" values or consider just hiding this route and document it manually.
with db_session.begin():
handle_login_gov_callback(query_data, db_session)

# Redirect to the final location for the user
return response.redirect_response(get_final_redirect_uri("success", "abc123xyz456", False))
Expand Down
8 changes: 7 additions & 1 deletion api/src/auth/login_gov_jwt_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
import jwt
from pydantic import Field

from src.adapters import db
from src.auth.auth_errors import JwtValidationError
from src.db.models.user_models import LoginGovState
from src.util.env_config import PydanticBaseEnvConfig

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -36,6 +38,7 @@ class LoginGovConfig(PydanticBaseEnvConfig):
login_gov_endpoint: str = Field(alias="LOGIN_GOV_ENDPOINT")
login_gov_jwk_endpoint: str = Field(alias="LOGIN_GOV_JWK_ENDPOINT")
login_gov_auth_endpoint: str = Field(alias="LOGIN_GOV_AUTH_ENDPOINT")
login_gov_token_endpoint: str = Field(alias="LOGIN_GOV_TOKEN_ENDPOINT")

# Where we send a user after they have successfully logged in
# for now we'll always send them to the same place (a frontend page)
Expand Down Expand Up @@ -102,7 +105,7 @@ def _refresh_keys(config: LoginGovConfig) -> None:
config.public_keys = list(public_keys)


def get_login_gov_redirect_uri(config: LoginGovConfig | None = None) -> str:
def get_login_gov_redirect_uri(db_session: db.Session, config: LoginGovConfig | None = None) -> str:
if config is None:
config = get_config()

Expand All @@ -129,6 +132,9 @@ def get_login_gov_redirect_uri(config: LoginGovConfig | None = None) -> str:
}
)

# Add the state to the DB
db_session.add(LoginGovState(login_gov_state_id=state, nonce=nonce))

return f"{config.login_gov_auth_endpoint}?{encoded_params}"


Expand Down
46 changes: 46 additions & 0 deletions api/src/db/migrations/versions/2024_12_04_login_gov_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""login gov state

Revision ID: 6a23520d2c3c
Revises: 16eaca2334c9
Create Date: 2024-12-04 16:35:29.200758

"""

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "6a23520d2c3c"
down_revision = "16eaca2334c9"
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"login_gov_state",
sa.Column("login_gov_state_id", sa.UUID(), nullable=False),
sa.Column("nonce", sa.UUID(), nullable=False),
sa.Column(
"created_at",
sa.TIMESTAMP(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.TIMESTAMP(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.PrimaryKeyConstraint("login_gov_state_id", name=op.f("login_gov_state_pkey")),
schema="api",
)
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("login_gov_state", schema="api")
# ### end Alembic commands ###
11 changes: 11 additions & 0 deletions api/src/db/models/user_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,14 @@ class UserTokenSession(ApiSchemaTable, TimestampMixin):

# When a user logs out, we set this flag to False.
is_valid: Mapped[bool] = mapped_column(default=True)


class LoginGovState(ApiSchemaTable, TimestampMixin):
"""Table used to store temporary state during the OAuth login flow"""

__tablename__ = "login_gov_state"

login_gov_state_id: Mapped[uuid.UUID] = mapped_column(UUID, primary_key=True)

# https://openid.net/specs/openid-connect-core-1_0.html#NonceNotes
nonce: Mapped[uuid.UUID]
89 changes: 89 additions & 0 deletions api/src/services/users/login_gov_callback_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import logging
from dataclasses import dataclass

from pydantic import BaseModel
from sqlalchemy import select

import src.adapters.db as db
from src.adapters.oauth.login_gov.login_gov_oauth_client import LoginGovOauthClient
from src.adapters.oauth.oauth_client_models import OauthTokenRequest
from src.api.route_utils import raise_flask_error
from src.db.models.user_models import LoginGovState
from src.util.string_utils import is_valid_uuid

logger = logging.getLogger(__name__)


class CallbackParams(BaseModel):
code: str
state: str
error: str | None = None
error_description: str | None = None


@dataclass
class LoginGovCallbackResponse:
token: str
is_user_new: bool


def get_login_gov_client() -> LoginGovOauthClient:
"""Get the login.gov client, in a method to be overridable in tests"""
return LoginGovOauthClient()


def handle_login_gov_callback(query_data: dict, db_session: db.Session) -> LoginGovCallbackResponse:
"""Handle the callback from login.gov after calling the authenticate endpoint

NOTE: Any errors thrown here will actually lead to a redirect due to the
with_login_redirect_error_handler handler we have attached to the route
"""

# Process the data coming back from login.gov via the redirect query params
# see: https://developers.login.gov/oidc/authorization/#authorization-response
callback_params = CallbackParams.model_validate(query_data)

# If we got an error back in the callback, raise an exception
# The only two documented error values are access_denied and invalid_request
# which would both indicate an issue in our configuration and we'll treat as a 5xx internal error
if callback_params.error is not None:
error_message = f"{callback_params.error} {callback_params.error_description}"
raise_flask_error(500, error_message)

# If the state value we received isn't a valid UUID
# then it's likely someone randomly calling the endpoint
# We don't want this validation on the schema as it would
# occur before our error catching that handles redirects
if not is_valid_uuid(callback_params.state):
raise_flask_error(422, "Invalid OAuth state value")

login_gov_state = db_session.execute(
select(LoginGovState).where(LoginGovState.login_gov_state_id == callback_params.state)
).scalar_one_or_none()

# If we don't have the state value in our DB, that either means:
# * login.gov is very broken and sending us bad data
# * Someone called this endpoint directly with a random value
#
# There isn't a way to truly separate those here, so we'll assume the more likely second one
# and raise a 404 to say we have no idea what they passed us.
if login_gov_state is None:
raise_flask_error(404, "OAuth state not found")

# call the token endpoint (make a client)
# https://developers.login.gov/oidc/token/
# TODO: Creating a JWT with the key we gave login.gov
client = get_login_gov_client()
response = client.get_token(OauthTokenRequest(code=callback_params.code))

# If this request failed, we'll assume we're the issue and 500
# TODO - need to test with actual login.gov if there could be other scenarios
# the mock always returns something as long as the request is well-formatted
if response.is_error_response():
raise_flask_error(500, response.error_description)

# TODO: Process the token response from login.gov
# Note that a lot of this is already in https://github.com/HHS/simpler-grants-gov/pull/3004

# TODO - connect this to the above logic once implemented
return LoginGovCallbackResponse(token="abc123xyz456", is_user_new=False)
9 changes: 9 additions & 0 deletions api/src/util/string_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import uuid
from typing import Optional


Expand All @@ -12,3 +13,11 @@ def join_list(joining_list: Optional[list], join_txt: str = "\n") -> str:
return ""

return join_txt.join(joining_list)


def is_valid_uuid(value: str) -> bool:
try:
uuid.UUID(value)
return True
except ValueError:
return False
Loading