-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
## Summary Fixes #2808 ### Time to review: __10 mins__ ## Changes proposed Adds logic to create our own internal JWT for a user Adds a new user session table for tracking the JWT expiration Adds the beginning of the logic for parsing that JWT ## Context for reviewers Generating a JWT is pretty simple, just give the JWT library a key and whatever you want encoded, and it just kinda works. To follow some general conventions, I'm making a token that has a payload like: ```json { "sub": "abc-123", "iat": 1234567890, // unix timestamp - issued at "aud": "simpler-grants-api", "iss": "simpler-grants-api", } ``` The `sub` value is the main one and is a "token_id" that we store in the new `user_token_session` table and can later use to fetch the token back out. I added some rough parsing logic just for the sake of testing, this will expand significantly in the next PR where we actually build that parsing logic back out. Open question: would it make more sense to have audience be the user_id itself? Logically makes sense, but the parsing on the other side is a bit clunkier. --------- Co-authored-by: nava-platform-bot <[email protected]>
- Loading branch information
1 parent
873c776
commit 4736537
Showing
9 changed files
with
319 additions
and
49 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
import logging | ||
import uuid | ||
from datetime import timedelta | ||
|
||
import jwt | ||
from pydantic import Field | ||
from sqlalchemy import select | ||
|
||
import src.util.datetime_util as datetime_util | ||
from src.adapters import db | ||
from src.auth.auth_errors import JwtValidationError | ||
from src.db.models.user_models import User, UserTokenSession | ||
from src.util.env_config import PydanticBaseEnvConfig | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class ApiJwtConfig(PydanticBaseEnvConfig): | ||
|
||
private_key: str = Field(alias="API_JWT_PRIVATE_KEY") | ||
public_key: str = Field(alias="API_JWT_PUBLIC_KEY") | ||
|
||
issuer: str = Field("simpler-grants-api", alias="API_JWT_ISSUER") | ||
audience: str = Field("simpler-grants-api", alias="API_JWT_AUDIENCE") | ||
|
||
algorithm: str = Field("RS256", alias="API_JWT_ALGORITHM") | ||
|
||
token_expiration_minutes: int = Field(30, alias="API_JWT_TOKEN_EXPIRATION_MINUTES") | ||
|
||
|
||
# Initialize a config at startup that we'll use below | ||
_config: ApiJwtConfig | None = None | ||
|
||
|
||
def initialize() -> None: | ||
global _config | ||
if not _config: | ||
_config = ApiJwtConfig() | ||
logger.info( | ||
"Constructed JWT configuration", | ||
extra={ | ||
# NOTE: We don't just log the entire config | ||
# because that would include the encryption keys | ||
"issuer": _config.issuer, | ||
"audience": _config.audience, | ||
"algorithm": _config.algorithm, | ||
"token_expiration_minutes": _config.token_expiration_minutes, | ||
}, | ||
) | ||
|
||
|
||
def get_config() -> ApiJwtConfig: | ||
global _config | ||
|
||
if _config is None: | ||
raise Exception("No JWT configuration - initialize() must be run first") | ||
|
||
return _config | ||
|
||
|
||
def create_jwt_for_user( | ||
user: User, db_session: db.Session, config: ApiJwtConfig | None = None | ||
) -> str: | ||
if config is None: | ||
config = get_config() | ||
|
||
# Generate a random ID | ||
token_id = uuid.uuid4() | ||
|
||
# Always do all time checks in UTC for consistency | ||
current_time = datetime_util.utcnow() | ||
expiration_time = current_time + timedelta(minutes=config.token_expiration_minutes) | ||
|
||
# Create the session in the DB | ||
db_session.add( | ||
UserTokenSession( | ||
user=user, | ||
token_id=token_id, | ||
expires_at=expiration_time, | ||
) | ||
) | ||
|
||
# Create the JWT with information we'll want to receive back | ||
payload = { | ||
"sub": str(token_id), | ||
# iat -> issued at | ||
"iat": current_time, | ||
"aud": config.audience, | ||
"iss": config.issuer, | ||
} | ||
|
||
return jwt.encode(payload, config.private_key, algorithm="RS256") | ||
|
||
|
||
def parse_jwt_for_user( | ||
token: str, db_session: db.Session, config: ApiJwtConfig | None = None | ||
) -> User: | ||
# TODO - more implementation/validation to come in https://github.com/HHS/simpler-grants-gov/issues/2809 | ||
if config is None: | ||
config = get_config() | ||
|
||
current_timestamp = datetime_util.utcnow() | ||
|
||
try: | ||
parsed_jwt: dict = jwt.decode( | ||
token, | ||
config.public_key, | ||
algorithms=[config.algorithm], | ||
issuer=config.issuer, | ||
audience=config.audience, | ||
options={ | ||
"verify_signature": True, | ||
"verify_iat": True, | ||
"verify_aud": True, | ||
"verify_iss": True, | ||
# We do not set the following fields | ||
# so do not want to validate. | ||
"verify_exp": False, # expiration is managed in the DB | ||
"verify_nbf": False, # Tokens are always fine to use immediately | ||
}, | ||
) | ||
|
||
except jwt.ImmatureSignatureError as e: # IAT errors hit this | ||
raise JwtValidationError("Token not yet valid") from e | ||
except jwt.InvalidIssuerError as e: | ||
raise JwtValidationError("Unknown Issuer") from e | ||
except jwt.InvalidAudienceError as e: | ||
raise JwtValidationError("Unknown Audience") from e | ||
except jwt.PyJWTError as e: | ||
# Every other error case wrap in the same generic error message. | ||
raise JwtValidationError("Unable to process token") from e | ||
|
||
sub_id = parsed_jwt.get("sub", None) | ||
if sub_id is None: | ||
raise JwtValidationError("Token missing sub field") | ||
|
||
token_session: UserTokenSession | None = db_session.execute( | ||
select(UserTokenSession).join(User).where(UserTokenSession.token_id == sub_id) | ||
).scalar_one_or_none() | ||
|
||
# We check both the token expires_at timestamp as well as an | ||
# is_valid flag to make sure the token is still valid. | ||
if token_session is None: | ||
raise JwtValidationError("Token session does not exist") | ||
if token_session.expires_at < current_timestamp: | ||
raise JwtValidationError("Token expired") | ||
if token_session.is_valid is False: | ||
raise JwtValidationError("Token is no longer valid") | ||
|
||
return token_session.user |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
class JwtValidationError(Exception): | ||
""" | ||
Exception we will reraise if there are | ||
any issues processing a JWT that should | ||
cause the endpoint to raise a 401 | ||
""" | ||
|
||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
51 changes: 51 additions & 0 deletions
51
api/src/db/migrations/versions/2024_11_18_user_session_table.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
"""user session table | ||
Revision ID: 16eaca2334c9 | ||
Revises: 7346f6b52c3d | ||
Create Date: 2024-11-18 13:10:37.039657 | ||
""" | ||
|
||
import sqlalchemy as sa | ||
from alembic import op | ||
|
||
# revision identifiers, used by Alembic. | ||
revision = "16eaca2334c9" | ||
down_revision = "7346f6b52c3d" | ||
branch_labels = None | ||
depends_on = None | ||
|
||
|
||
def upgrade(): | ||
# ### commands auto generated by Alembic - please adjust! ### | ||
op.create_table( | ||
"user_token_session", | ||
sa.Column("user_id", sa.UUID(), nullable=False), | ||
sa.Column("token_id", sa.UUID(), nullable=False), | ||
sa.Column("expires_at", sa.TIMESTAMP(timezone=True), nullable=False), | ||
sa.Column("is_valid", sa.Boolean(), nullable=False), | ||
sa.Column( | ||
"created_at", | ||
sa.TIMESTAMP(timezone=True), | ||
server_default=sa.text("now()"), | ||
nullable=False, | ||
), | ||
sa.Column( | ||
"updated_at", | ||
sa.TIMESTAMP(timezone=True), | ||
server_default=sa.text("now()"), | ||
nullable=False, | ||
), | ||
sa.ForeignKeyConstraint( | ||
["user_id"], ["api.user.user_id"], name=op.f("user_token_session_user_id_user_fkey") | ||
), | ||
sa.PrimaryKeyConstraint("user_id", "token_id", name=op.f("user_token_session_pkey")), | ||
schema="api", | ||
) | ||
# ### end Alembic commands ### | ||
|
||
|
||
def downgrade(): | ||
# ### commands auto generated by Alembic - please adjust! ### | ||
op.drop_table("user_token_session", schema="api") | ||
# ### end Alembic commands ### |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import pytest | ||
from cryptography.hazmat.primitives import serialization | ||
from cryptography.hazmat.primitives.asymmetric import rsa | ||
|
||
|
||
def _generate_rsa_key_pair(): | ||
# Rather than define a private/public key, generate one for the tests | ||
key = rsa.generate_private_key(public_exponent=65537, key_size=2048) | ||
|
||
private_key = key.private_bytes( | ||
encoding=serialization.Encoding.PEM, | ||
format=serialization.PrivateFormat.TraditionalOpenSSL, | ||
encryption_algorithm=serialization.NoEncryption(), | ||
) | ||
|
||
public_key = key.public_key().public_bytes( | ||
encoding=serialization.Encoding.PEM, format=serialization.PublicFormat.SubjectPublicKeyInfo | ||
) | ||
|
||
return private_key, public_key | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def rsa_key_pair(): | ||
return _generate_rsa_key_pair() | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def private_rsa_key(rsa_key_pair): | ||
return rsa_key_pair[0] | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def public_rsa_key(rsa_key_pair): | ||
return rsa_key_pair[1] | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def other_rsa_key_pair(): | ||
return _generate_rsa_key_pair() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
from calendar import timegm | ||
from datetime import datetime | ||
|
||
import jwt | ||
import pytest | ||
from freezegun import freeze_time | ||
|
||
from src.auth.api_jwt_auth import ApiJwtConfig, create_jwt_for_user, parse_jwt_for_user | ||
from src.db.models.user_models import UserTokenSession | ||
from tests.src.db.models.factories import UserFactory | ||
|
||
|
||
@pytest.fixture | ||
def jwt_config(private_rsa_key, public_rsa_key): | ||
return ApiJwtConfig( | ||
API_JWT_PRIVATE_KEY=private_rsa_key, | ||
API_JWT_PUBLIC_KEY=public_rsa_key, | ||
) | ||
|
||
|
||
@freeze_time("2024-11-14 12:00:00", tz_offset=0) | ||
def test_create_jwt_for_user(enable_factory_create, db_session, jwt_config): | ||
user = UserFactory.create() | ||
|
||
token = create_jwt_for_user(user, db_session, jwt_config) | ||
|
||
decoded_token = jwt.decode( | ||
token, algorithms=[jwt_config.algorithm], options={"verify_signature": False} | ||
) | ||
|
||
# Verify the issued at timestamp is at the expected (now) timestamp | ||
# note we have to convert it to a unix timestamp | ||
assert decoded_token["iat"] == timegm( | ||
datetime.fromisoformat("2024-11-14 12:00:00+00:00").utctimetuple() | ||
) | ||
assert decoded_token["iss"] == jwt_config.issuer | ||
assert decoded_token["aud"] == jwt_config.audience | ||
|
||
# Verify that the sub_id returned can be used to fetch a UserTokenSession object | ||
token_session = ( | ||
db_session.query(UserTokenSession) | ||
.filter(UserTokenSession.token_id == decoded_token["sub"]) | ||
.one_or_none() | ||
) | ||
|
||
assert token_session.user_id == user.user_id | ||
assert token_session.is_valid is True | ||
# Verify expires_at is set to 30 minutes after now by default | ||
assert token_session.expires_at == datetime.fromisoformat("2024-11-14 12:30:00+00:00") | ||
|
||
# Basic testing that the JWT we create for a user can in turn be fetched and processed later | ||
# TODO - more in https://github.com/HHS/simpler-grants-gov/issues/2809 | ||
parsed_user = parse_jwt_for_user(token, db_session, jwt_config) | ||
assert parsed_user.user_id == user.user_id |
Oops, something went wrong.