Skip to content

Commit

Permalink
[Issue #2808] Create a JWT for a user (#2898)
Browse files Browse the repository at this point in the history
## Summary
Fixes #2808

### Time to review: __10 mins__

## Changes proposed
Adds logic to create our own internal JWT for a user

Adds a new user session table for tracking the JWT expiration

Adds the beginning of the logic for parsing that JWT

## Context for reviewers
Generating a JWT is pretty simple, just give the JWT library a key and
whatever you want encoded, and it just kinda works.

To follow some general conventions, I'm making a token that has a
payload like:
```json
{
    "sub": "abc-123",
    "iat": 1234567890, // unix timestamp - issued at
    "aud": "simpler-grants-api",
    "iss": "simpler-grants-api",
}
```
The `sub` value is the main one and is a "token_id" that we store in the
new `user_token_session` table and can later use to fetch the token back
out.

I added some rough parsing logic just for the sake of testing, this will
expand significantly in the next PR where we actually build that parsing
logic back out.

Open question: would it make more sense to have audience be the user_id
itself? Logically makes sense, but the parsing on the other side is a
bit clunkier.

---------

Co-authored-by: nava-platform-bot <[email protected]>
  • Loading branch information
chouinar and nava-platform-bot authored Nov 21, 2024
1 parent 873c776 commit 4736537
Show file tree
Hide file tree
Showing 9 changed files with 319 additions and 49 deletions.
150 changes: 150 additions & 0 deletions api/src/auth/api_jwt_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import logging
import uuid
from datetime import timedelta

import jwt
from pydantic import Field
from sqlalchemy import select

import src.util.datetime_util as datetime_util
from src.adapters import db
from src.auth.auth_errors import JwtValidationError
from src.db.models.user_models import User, UserTokenSession
from src.util.env_config import PydanticBaseEnvConfig

logger = logging.getLogger(__name__)


class ApiJwtConfig(PydanticBaseEnvConfig):

private_key: str = Field(alias="API_JWT_PRIVATE_KEY")
public_key: str = Field(alias="API_JWT_PUBLIC_KEY")

issuer: str = Field("simpler-grants-api", alias="API_JWT_ISSUER")
audience: str = Field("simpler-grants-api", alias="API_JWT_AUDIENCE")

algorithm: str = Field("RS256", alias="API_JWT_ALGORITHM")

token_expiration_minutes: int = Field(30, alias="API_JWT_TOKEN_EXPIRATION_MINUTES")


# Initialize a config at startup that we'll use below
_config: ApiJwtConfig | None = None


def initialize() -> None:
global _config
if not _config:
_config = ApiJwtConfig()
logger.info(
"Constructed JWT configuration",
extra={
# NOTE: We don't just log the entire config
# because that would include the encryption keys
"issuer": _config.issuer,
"audience": _config.audience,
"algorithm": _config.algorithm,
"token_expiration_minutes": _config.token_expiration_minutes,
},
)


def get_config() -> ApiJwtConfig:
global _config

if _config is None:
raise Exception("No JWT configuration - initialize() must be run first")

return _config


def create_jwt_for_user(
user: User, db_session: db.Session, config: ApiJwtConfig | None = None
) -> str:
if config is None:
config = get_config()

# Generate a random ID
token_id = uuid.uuid4()

# Always do all time checks in UTC for consistency
current_time = datetime_util.utcnow()
expiration_time = current_time + timedelta(minutes=config.token_expiration_minutes)

# Create the session in the DB
db_session.add(
UserTokenSession(
user=user,
token_id=token_id,
expires_at=expiration_time,
)
)

# Create the JWT with information we'll want to receive back
payload = {
"sub": str(token_id),
# iat -> issued at
"iat": current_time,
"aud": config.audience,
"iss": config.issuer,
}

return jwt.encode(payload, config.private_key, algorithm="RS256")


def parse_jwt_for_user(
token: str, db_session: db.Session, config: ApiJwtConfig | None = None
) -> User:
# TODO - more implementation/validation to come in https://github.com/HHS/simpler-grants-gov/issues/2809
if config is None:
config = get_config()

current_timestamp = datetime_util.utcnow()

try:
parsed_jwt: dict = jwt.decode(
token,
config.public_key,
algorithms=[config.algorithm],
issuer=config.issuer,
audience=config.audience,
options={
"verify_signature": True,
"verify_iat": True,
"verify_aud": True,
"verify_iss": True,
# We do not set the following fields
# so do not want to validate.
"verify_exp": False, # expiration is managed in the DB
"verify_nbf": False, # Tokens are always fine to use immediately
},
)

except jwt.ImmatureSignatureError as e: # IAT errors hit this
raise JwtValidationError("Token not yet valid") from e
except jwt.InvalidIssuerError as e:
raise JwtValidationError("Unknown Issuer") from e
except jwt.InvalidAudienceError as e:
raise JwtValidationError("Unknown Audience") from e
except jwt.PyJWTError as e:
# Every other error case wrap in the same generic error message.
raise JwtValidationError("Unable to process token") from e

sub_id = parsed_jwt.get("sub", None)
if sub_id is None:
raise JwtValidationError("Token missing sub field")

token_session: UserTokenSession | None = db_session.execute(
select(UserTokenSession).join(User).where(UserTokenSession.token_id == sub_id)
).scalar_one_or_none()

# We check both the token expires_at timestamp as well as an
# is_valid flag to make sure the token is still valid.
if token_session is None:
raise JwtValidationError("Token session does not exist")
if token_session.expires_at < current_timestamp:
raise JwtValidationError("Token expired")
if token_session.is_valid is False:
raise JwtValidationError("Token is no longer valid")

return token_session.user
8 changes: 8 additions & 0 deletions api/src/auth/auth_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
class JwtValidationError(Exception):
"""
Exception we will reraise if there are
any issues processing a JWT that should
cause the endpoint to raise a 401
"""

pass
11 changes: 1 addition & 10 deletions api/src/auth/login_gov_jwt_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import jwt
from pydantic import Field

from src.auth.auth_errors import JwtValidationError
from src.util.env_config import PydanticBaseEnvConfig

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -40,16 +41,6 @@ class LoginGovUser:
email: str


class JwtValidationError(Exception):
"""
Exception we will reraise if there are
any issues processing a JWT that should
cause the endpoint to raise a 401
"""

pass


def _refresh_keys(config: LoginGovConfig) -> None:
"""
WARNING:
Expand Down
51 changes: 51 additions & 0 deletions api/src/db/migrations/versions/2024_11_18_user_session_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""user session table
Revision ID: 16eaca2334c9
Revises: 7346f6b52c3d
Create Date: 2024-11-18 13:10:37.039657
"""

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "16eaca2334c9"
down_revision = "7346f6b52c3d"
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"user_token_session",
sa.Column("user_id", sa.UUID(), nullable=False),
sa.Column("token_id", sa.UUID(), nullable=False),
sa.Column("expires_at", sa.TIMESTAMP(timezone=True), nullable=False),
sa.Column("is_valid", sa.Boolean(), nullable=False),
sa.Column(
"created_at",
sa.TIMESTAMP(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.TIMESTAMP(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["user_id"], ["api.user.user_id"], name=op.f("user_token_session_user_id_user_fkey")
),
sa.PrimaryKeyConstraint("user_id", "token_id", name=op.f("user_token_session_pkey")),
schema="api",
)
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("user_token_session", schema="api")
# ### end Alembic commands ###
15 changes: 15 additions & 0 deletions api/src/db/models/user_models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import uuid
from datetime import datetime

from sqlalchemy import ForeignKey
from sqlalchemy.dialects.postgresql import UUID
Expand Down Expand Up @@ -33,3 +34,17 @@ class LinkExternalUser(ApiSchemaTable, TimestampMixin):
)

email: Mapped[str]


class UserTokenSession(ApiSchemaTable, TimestampMixin):
__tablename__ = "user_token_session"

user_id: Mapped[uuid.UUID] = mapped_column(UUID, ForeignKey(User.user_id), primary_key=True)
user: Mapped[User] = relationship(User)

token_id: Mapped[uuid.UUID] = mapped_column(UUID, primary_key=True)

expires_at: Mapped[datetime]

# When a user logs out, we set this flag to False.
is_valid: Mapped[bool] = mapped_column(default=True)
40 changes: 40 additions & 0 deletions api/tests/src/auth/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import pytest
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa


def _generate_rsa_key_pair():
# Rather than define a private/public key, generate one for the tests
key = rsa.generate_private_key(public_exponent=65537, key_size=2048)

private_key = key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(),
)

public_key = key.public_key().public_bytes(
encoding=serialization.Encoding.PEM, format=serialization.PublicFormat.SubjectPublicKeyInfo
)

return private_key, public_key


@pytest.fixture(scope="session")
def rsa_key_pair():
return _generate_rsa_key_pair()


@pytest.fixture(scope="session")
def private_rsa_key(rsa_key_pair):
return rsa_key_pair[0]


@pytest.fixture(scope="session")
def public_rsa_key(rsa_key_pair):
return rsa_key_pair[1]


@pytest.fixture(scope="session")
def other_rsa_key_pair():
return _generate_rsa_key_pair()
54 changes: 54 additions & 0 deletions api/tests/src/auth/test_api_jwt_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from calendar import timegm
from datetime import datetime

import jwt
import pytest
from freezegun import freeze_time

from src.auth.api_jwt_auth import ApiJwtConfig, create_jwt_for_user, parse_jwt_for_user
from src.db.models.user_models import UserTokenSession
from tests.src.db.models.factories import UserFactory


@pytest.fixture
def jwt_config(private_rsa_key, public_rsa_key):
return ApiJwtConfig(
API_JWT_PRIVATE_KEY=private_rsa_key,
API_JWT_PUBLIC_KEY=public_rsa_key,
)


@freeze_time("2024-11-14 12:00:00", tz_offset=0)
def test_create_jwt_for_user(enable_factory_create, db_session, jwt_config):
user = UserFactory.create()

token = create_jwt_for_user(user, db_session, jwt_config)

decoded_token = jwt.decode(
token, algorithms=[jwt_config.algorithm], options={"verify_signature": False}
)

# Verify the issued at timestamp is at the expected (now) timestamp
# note we have to convert it to a unix timestamp
assert decoded_token["iat"] == timegm(
datetime.fromisoformat("2024-11-14 12:00:00+00:00").utctimetuple()
)
assert decoded_token["iss"] == jwt_config.issuer
assert decoded_token["aud"] == jwt_config.audience

# Verify that the sub_id returned can be used to fetch a UserTokenSession object
token_session = (
db_session.query(UserTokenSession)
.filter(UserTokenSession.token_id == decoded_token["sub"])
.one_or_none()
)

assert token_session.user_id == user.user_id
assert token_session.is_valid is True
# Verify expires_at is set to 30 minutes after now by default
assert token_session.expires_at == datetime.fromisoformat("2024-11-14 12:30:00+00:00")

# Basic testing that the JWT we create for a user can in turn be fetched and processed later
# TODO - more in https://github.com/HHS/simpler-grants-gov/issues/2809
parsed_user = parse_jwt_for_user(token, db_session, jwt_config)
assert parsed_user.user_id == user.user_id
Loading

0 comments on commit 4736537

Please sign in to comment.