diff --git a/api/src/auth/api_jwt_auth.py b/api/src/auth/api_jwt_auth.py
new file mode 100644
index 000000000..14cf2c8c5
--- /dev/null
+++ b/api/src/auth/api_jwt_auth.py
@@ -0,0 +1,173 @@
+import logging
+import uuid
+from datetime import timedelta
+
+import jwt
+from pydantic import Field
+from sqlalchemy import select
+
+import src.util.datetime_util as datetime_util
+from src.adapters import db
+from src.auth.auth_errors import JwtValidationError
+from src.db.models.user_models import User, UserTokenSession
+from src.util.env_config import PydanticBaseEnvConfig
+
+logger = logging.getLogger(__name__)
+
+
+class ApiJwtConfig(PydanticBaseEnvConfig):
+    """Configuration for creating and parsing the API's own JWTs."""
+
+    private_key: str = Field(alias="API_JWT_PRIVATE_KEY")
+    public_key: str = Field(alias="API_JWT_PUBLIC_KEY")
+
+    issuer: str = Field("simpler-grants-api", alias="API_JWT_ISSUER")
+    audience: str = Field("simpler-grants-api", alias="API_JWT_AUDIENCE")
+
+    algorithm: str = Field("RS256", alias="API_JWT_ALGORITHM")
+
+    token_expiration_minutes: int = Field(30, alias="API_JWT_TOKEN_EXPIRATION_MINUTES")
+
+
+# Initialize a config at startup that we'll use below
+_config: ApiJwtConfig | None = None
+
+
+def initialize() -> None:
+    """Construct the module-level JWT config if needed and log its non-secret fields."""
+    global _config
+    if _config is None:
+        _config = ApiJwtConfig()
+        logger.info(
+            "Constructed JWT configuration",
+            extra={
+                # NOTE: We don't just log the entire config
+                # because that would include the encryption keys
+                "issuer": _config.issuer,
+                "audience": _config.audience,
+                "algorithm": _config.algorithm,
+                "token_expiration_minutes": _config.token_expiration_minutes,
+            },
+        )
+
+
+def get_config() -> ApiJwtConfig:
+    """Return the module-level JWT config; raises if initialize() has not been run."""
+    global _config
+
+    if _config is None:
+        raise Exception("No JWT configuration - initialize() must be run first")
+
+    return _config
+
+
+def create_jwt_for_user(
+    user: User, db_session: db.Session, config: ApiJwtConfig | None = None
+) -> str:
+    """
+    Create a token session for the user and return a signed JWT that references it.
+
+    A UserTokenSession is added to db_session (this function does not commit) with a
+    random token_id and an expires_at of config.token_expiration_minutes from now.
+    The JWT carries the token_id as its "sub" claim and has no "exp" claim, as
+    expiration is checked against the session row when the token is parsed.
+
+    If config is None, the module-level config from get_config() is used.
+    """
+    if config is None:
+        config = get_config()
+
+    # Generate a random ID
+    token_id = uuid.uuid4()
+
+    # Always do all time checks in UTC for consistency
+    current_time = datetime_util.utcnow()
+    expiration_time = current_time + timedelta(minutes=config.token_expiration_minutes)
+
+    # Create the session in the DB
+    db_session.add(
+        UserTokenSession(
+            user=user,
+            token_id=token_id,
+            expires_at=expiration_time,
+        )
+    )
+
+    # Create the JWT with information we'll want to receive back
+    payload = {
+        "sub": str(token_id),
+        # iat -> issued at
+        "iat": current_time,
+        "aud": config.audience,
+        "iss": config.issuer,
+    }
+
+    # Sign with the configured algorithm so the token always matches the
+    # algorithms=[config.algorithm] allow-list used by parse_jwt_for_user.
+    return jwt.encode(payload, config.private_key, algorithm=config.algorithm)
+
+
+def parse_jwt_for_user(
+    token: str, db_session: db.Session, config: ApiJwtConfig | None = None
+) -> User:
+    """
+    Validate a JWT issued by create_jwt_for_user and return the user it belongs to.
+
+    Raises JwtValidationError if the token cannot be decoded or verified, has no
+    "sub" claim, or its token session is missing, expired, or flagged as not valid.
+
+    If config is None, the module-level config from get_config() is used.
+    """
+    # TODO - more implementation/validation to come in https://github.com/HHS/simpler-grants-gov/issues/2809
+    if config is None:
+        config = get_config()
+
+    current_timestamp = datetime_util.utcnow()
+
+    try:
+        parsed_jwt: dict = jwt.decode(
+            token,
+            config.public_key,
+            algorithms=[config.algorithm],
+            issuer=config.issuer,
+            audience=config.audience,
+            options={
+                "verify_signature": True,
+                "verify_iat": True,
+                "verify_aud": True,
+                "verify_iss": True,
+                # We do not set the following fields
+                # so do not want to validate.
+                "verify_exp": False,  # expiration is managed in the DB
+                "verify_nbf": False,  # Tokens are always fine to use immediately
+            },
+        )
+
+    except jwt.ImmatureSignatureError as e:  # IAT errors hit this
+        raise JwtValidationError("Token not yet valid") from e
+    except jwt.InvalidIssuerError as e:
+        raise JwtValidationError("Unknown Issuer") from e
+    except jwt.InvalidAudienceError as e:
+        raise JwtValidationError("Unknown Audience") from e
+    except jwt.PyJWTError as e:
+        # Every other error case wrap in the same generic error message.
+        raise JwtValidationError("Unable to process token") from e
+
+    sub_id = parsed_jwt.get("sub", None)
+    if sub_id is None:
+        raise JwtValidationError("Token missing sub field")
+
+    token_session: UserTokenSession | None = db_session.execute(
+        select(UserTokenSession).join(User).where(UserTokenSession.token_id == sub_id)
+    ).scalar_one_or_none()
+
+    # We check both the token expires_at timestamp as well as an
+    # is_valid flag to make sure the token is still valid.
+    if token_session is None:
+        raise JwtValidationError("Token session does not exist")
+    if token_session.expires_at < current_timestamp:
+        raise JwtValidationError("Token expired")
+    if token_session.is_valid is False:
+        raise JwtValidationError("Token is no longer valid")
+
+    return token_session.user
diff --git a/api/src/auth/auth_errors.py b/api/src/auth/auth_errors.py
new file mode 100644
index 000000000..1e2bdee1b
--- /dev/null
+++ b/api/src/auth/auth_errors.py
@@ -0,0 +1,8 @@
+class JwtValidationError(Exception):
+    """
+    Exception we will reraise if there are
+    any issues processing a JWT that should
+    cause the endpoint to raise a 401
+    """
+
+    pass
diff --git a/api/src/auth/login_gov_jwt_auth.py b/api/src/auth/login_gov_jwt_auth.py
index 3df2b5055..44e7cbec7 100644
--- a/api/src/auth/login_gov_jwt_auth.py
+++ b/api/src/auth/login_gov_jwt_auth.py
@@ -4,6 +4,7 @@
 import jwt
 from pydantic import Field
 
+from src.auth.auth_errors import JwtValidationError
 from src.util.env_config import PydanticBaseEnvConfig
 
 logger = logging.getLogger(__name__)
@@ -40,16 +41,6 @@ class LoginGovUser:
     email: str
 
 
-class JwtValidationError(Exception):
-    """
-    Exception we will reraise if there are
-    any issues processing a JWT that should
-    cause the endpoint to raise a 401
-    """
-
-    pass
-
-
 def _refresh_keys(config: LoginGovConfig) -> None:
     """
     WARNING:
diff --git a/api/src/db/migrations/versions/2024_11_18_user_session_table.py b/api/src/db/migrations/versions/2024_11_18_user_session_table.py
new file mode 100644
index 000000000..58e28ad50
--- /dev/null
+++ b/api/src/db/migrations/versions/2024_11_18_user_session_table.py
@@ -0,0 +1,51 @@
+"""user session table
+
+Revision ID: 16eaca2334c9
+Revises: 7346f6b52c3d
+Create Date: 2024-11-18 13:10:37.039657
+
+"""
+
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "16eaca2334c9"
+down_revision = "7346f6b52c3d"
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table(
+        "user_token_session",
+        sa.Column("user_id", sa.UUID(), nullable=False),
+        sa.Column("token_id", sa.UUID(), nullable=False),
+        sa.Column("expires_at", sa.TIMESTAMP(timezone=True), nullable=False),
+        sa.Column("is_valid", sa.Boolean(), nullable=False),
+        sa.Column(
+            "created_at",
+            sa.TIMESTAMP(timezone=True),
+            server_default=sa.text("now()"),
+            nullable=False,
+        ),
+        sa.Column(
+            "updated_at",
+            sa.TIMESTAMP(timezone=True),
+            server_default=sa.text("now()"),
+            nullable=False,
+        ),
+        sa.ForeignKeyConstraint(
+            ["user_id"], ["api.user.user_id"], name=op.f("user_token_session_user_id_user_fkey")
+        ),
+        sa.PrimaryKeyConstraint("user_id", "token_id", name=op.f("user_token_session_pkey")),
+        schema="api",
+    )
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_table("user_token_session", schema="api")
+    # ### end Alembic commands ###
diff --git a/api/src/db/models/user_models.py b/api/src/db/models/user_models.py
index 36c27eacc..c954733a3 100644
--- a/api/src/db/models/user_models.py
+++ b/api/src/db/models/user_models.py
@@ -1,4 +1,5 @@
 import uuid
+from datetime import datetime
 
 from sqlalchemy import ForeignKey
 from sqlalchemy.dialects.postgresql import UUID
@@ -33,3 +34,17 @@ class LinkExternalUser(ApiSchemaTable, TimestampMixin):
     )
 
     email: Mapped[str]
+
+
+class UserTokenSession(ApiSchemaTable, TimestampMixin):
+    __tablename__ = "user_token_session"
+
+    user_id: Mapped[uuid.UUID] = mapped_column(UUID, ForeignKey(User.user_id), primary_key=True)
+    user: Mapped[User] = relationship(User)
+
+    token_id: Mapped[uuid.UUID] = mapped_column(UUID, primary_key=True)
+
+    expires_at: Mapped[datetime]
+
+    # When a user logs out, we set this flag to False.
+    is_valid: Mapped[bool] = mapped_column(default=True)
diff --git a/api/tests/src/auth/conftest.py b/api/tests/src/auth/conftest.py
new file mode 100644
index 000000000..dd0f703ad
--- /dev/null
+++ b/api/tests/src/auth/conftest.py
@@ -0,0 +1,40 @@
+import pytest
+from cryptography.hazmat.primitives import serialization
+from cryptography.hazmat.primitives.asymmetric import rsa
+
+
+def _generate_rsa_key_pair():
+    # Rather than define a private/public key, generate one for the tests
+    key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
+
+    private_key = key.private_bytes(
+        encoding=serialization.Encoding.PEM,
+        format=serialization.PrivateFormat.TraditionalOpenSSL,
+        encryption_algorithm=serialization.NoEncryption(),
+    )
+
+    public_key = key.public_key().public_bytes(
+        encoding=serialization.Encoding.PEM, format=serialization.PublicFormat.SubjectPublicKeyInfo
+    )
+
+    return private_key, public_key
+
+
+@pytest.fixture(scope="session")
+def rsa_key_pair():
+    return _generate_rsa_key_pair()
+
+
+@pytest.fixture(scope="session")
+def private_rsa_key(rsa_key_pair):
+    return rsa_key_pair[0]
+
+
+@pytest.fixture(scope="session")
+def public_rsa_key(rsa_key_pair):
+    return rsa_key_pair[1]
+
+
+@pytest.fixture(scope="session")
+def other_rsa_key_pair():
+    return _generate_rsa_key_pair()
diff --git a/api/tests/src/auth/test_api_jwt_auth.py b/api/tests/src/auth/test_api_jwt_auth.py
new file mode 100644
index 000000000..4770327f0
--- /dev/null
+++ b/api/tests/src/auth/test_api_jwt_auth.py
@@ -0,0 +1,56 @@
+from calendar import timegm
+from datetime import datetime
+
+import jwt
+import pytest
+from freezegun import freeze_time
+
+from src.auth.api_jwt_auth import ApiJwtConfig, create_jwt_for_user, parse_jwt_for_user
+from src.db.models.user_models import UserTokenSession
+from tests.src.db.models.factories import UserFactory
+
+
+@pytest.fixture
+def jwt_config(private_rsa_key, public_rsa_key):
+    return ApiJwtConfig(
+        API_JWT_PRIVATE_KEY=private_rsa_key,
+        API_JWT_PUBLIC_KEY=public_rsa_key,
+    )
+
+
+@freeze_time("2024-11-14 12:00:00", tz_offset=0)
+def test_create_jwt_for_user(enable_factory_create, db_session, jwt_config):
+    user = UserFactory.create()
+
+    token = create_jwt_for_user(user, db_session, jwt_config)
+
+    decoded_token = jwt.decode(
+        token, algorithms=[jwt_config.algorithm], options={"verify_signature": False}
+    )
+
+    # Verify the issued at timestamp is at the expected (now) timestamp
+    # note we have to convert it to a unix timestamp
+    assert decoded_token["iat"] == timegm(
+        datetime.fromisoformat("2024-11-14 12:00:00+00:00").utctimetuple()
+    )
+    assert decoded_token["iss"] == jwt_config.issuer
+    assert decoded_token["aud"] == jwt_config.audience
+    # Verify the token header advertises the configured signing algorithm
+    assert jwt.get_unverified_header(token)["alg"] == jwt_config.algorithm
+
+    # Verify that the sub_id returned can be used to fetch a UserTokenSession object
+    token_session = (
+        db_session.query(UserTokenSession)
+        .filter(UserTokenSession.token_id == decoded_token["sub"])
+        .one_or_none()
+    )
+
+    assert token_session.user_id == user.user_id
+    assert token_session.is_valid is True
+    # Verify expires_at is set to 30 minutes after now by default
+    assert token_session.expires_at == datetime.fromisoformat("2024-11-14 12:30:00+00:00")
+
+    # Basic testing that the JWT we create for a user can in turn be fetched and processed later
+    # TODO - more in https://github.com/HHS/simpler-grants-gov/issues/2809
+    parsed_user = parse_jwt_for_user(token, db_session, jwt_config)
+    assert parsed_user.user_id == user.user_id
diff --git a/api/tests/src/auth/test_login_gov_jwt_auth.py b/api/tests/src/auth/test_login_gov_jwt_auth.py
index cda5fc62d..3ad22566b 100644
--- a/api/tests/src/auth/test_login_gov_jwt_auth.py
+++ b/api/tests/src/auth/test_login_gov_jwt_auth.py
@@ -2,8 +2,6 @@
 
 import jwt
 import pytest
-from cryptography.hazmat.primitives import serialization
-from cryptography.hazmat.primitives.asymmetric import rsa
 
 import src.auth.login_gov_jwt_auth as login_gov_jwt_auth
 from src.auth.login_gov_jwt_auth import JwtValidationError, LoginGovConfig, validate_token
@@ -12,43 +10,6 @@
 DEFAULT_ISSUER = "http://localhost:3000"
 
 
-def _generate_rsa_key_pair():
-    # Rather than define a private/public key, generate one for the tests
-    key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
-
-    private_key = key.private_bytes(
-        encoding=serialization.Encoding.PEM,
-        format=serialization.PrivateFormat.TraditionalOpenSSL,
-        encryption_algorithm=serialization.NoEncryption(),
-    )
-
-    public_key = key.public_key().public_bytes(
-        encoding=serialization.Encoding.PEM, format=serialization.PublicFormat.SubjectPublicKeyInfo
-    )
-
-    return private_key, public_key
-
-
-@pytest.fixture(scope="session")
-def rsa_key_pair():
-    return _generate_rsa_key_pair()
-
-
-@pytest.fixture(scope="session")
-def private_rsa_key(rsa_key_pair):
-    return rsa_key_pair[0]
-
-
-@pytest.fixture(scope="session")
-def public_rsa_key(rsa_key_pair):
-    return rsa_key_pair[1]
-
-
-@pytest.fixture(scope="session")
-def other_rsa_key_pair():
-    return _generate_rsa_key_pair()
-
-
 @pytest.fixture
 def login_gov_config(public_rsa_key):
     # Note this isn't session scoped so it gets remade
diff --git a/documentation/api/database/erds/api-schema.png b/documentation/api/database/erds/api-schema.png
index e96c3e5e9..093fa6d18 100644
Binary files a/documentation/api/database/erds/api-schema.png and b/documentation/api/database/erds/api-schema.png differ