diff --git a/api/src/alembic/env.py b/api/src/alembic/env.py index 9fdd7519..b6dda829 100644 --- a/api/src/alembic/env.py +++ b/api/src/alembic/env.py @@ -8,6 +8,7 @@ from data_inclusion.api.core import db from data_inclusion.api.decoupage_administratif import models as _ # noqa: F401 F811 from data_inclusion.api.inclusion_data import models as _ # noqa: F401 F811 +from data_inclusion.api.request import models as _ # noqa: F401 F811 # this is the Alembic Config object, which provides # access to the values within the .ini file in use. diff --git a/api/src/alembic/versions/20241217_103425_bc02c1c6f60e_remove_request.py b/api/src/alembic/versions/20241217_103425_bc02c1c6f60e_remove_request.py deleted file mode 100644 index 9915ed97..00000000 --- a/api/src/alembic/versions/20241217_103425_bc02c1c6f60e_remove_request.py +++ /dev/null @@ -1,86 +0,0 @@ -"""remove request - -Revision ID: bc02c1c6f60e -Revises: 89e1ece4f56e -Create Date: 2024-12-17 10:34:25.120009 - -""" - -import sqlalchemy as sa -from alembic import op -from sqlalchemy.dialects import postgresql - -# revision identifiers, used by Alembic. -revision = "bc02c1c6f60e" -down_revision = "89e1ece4f56e" -branch_labels = None -depends_on = None - - -def upgrade() -> None: - op.drop_table("api__requests") - pass - - -def downgrade() -> None: - op.create_table( - "api__requests", - sa.Column("id", sa.Uuid(), nullable=False), - sa.Column( - "created_at", - sa.DateTime(timezone=True), - server_default=sa.text("now()"), - nullable=False, - ), - sa.Column("status_code", sa.Integer(), nullable=False), - sa.Column("method", sa.String(), nullable=False), - sa.Column("path", sa.String(), nullable=False), - sa.Column("base_url", sa.String(), nullable=False), - sa.Column("user", sa.String(), nullable=True), - sa.Column( - "path_params", postgresql.JSONB(astext_type=sa.Text()), nullable=True - ), - sa.Column( - "query_params", postgresql.JSONB(astext_type=sa.Text()), nullable=True - ), - sa.Column("client_host", sa.String(), nullable=True), - sa.Column("client_port", sa.Integer(), nullable=True), - sa.Column("endpoint_name", sa.String(), nullable=True), - sa.PrimaryKeyConstraint("id", name=op.f("pk_api__requests")), - ) - with op.get_context().autocommit_block(): - op.create_index( - op.f("ix_api__requests__created_at"), - "api__requests", - ["created_at"], - unique=False, - postgresql_concurrently=True, - ) - op.create_index( - op.f("ix_api__requests__endpoint_name"), - "api__requests", - ["endpoint_name"], - unique=False, - postgresql_concurrently=True, - ) - op.create_index( - op.f("ix_api__requests__method"), - "api__requests", - ["method"], - unique=False, - postgresql_concurrently=True, - ) - op.create_index( - op.f("ix_api__requests__status_code"), - "api__requests", - ["status_code"], - unique=False, - postgresql_concurrently=True, - ) - op.create_index( - op.f("ix_api__requests__user"), - "api__requests", - ["user"], - unique=False, - postgresql_concurrently=True, - ) diff --git a/api/src/data_inclusion/api/app.py b/api/src/data_inclusion/api/app.py index c87d6ed3..5018b468 100644 --- a/api/src/data_inclusion/api/app.py +++ b/api/src/data_inclusion/api/app.py @@ -13,6 +13,7 @@ from data_inclusion.api.core import db from data_inclusion.api.inclusion_data.routes import router as data_api_router from data_inclusion.api.inclusion_schema.routes import router as schema_api_router +from data_inclusion.api.request.middleware import save_request_middleware API_DESCRIPTION_PATH = Path(__file__).parent / "api_description.md" @@ -73,6 +74,7 @@ def create_app() -> fastapi.FastAPI: setup_debug_toolbar_middleware(app) app.middleware("http")(db.db_session_middleware) + app.middleware("http")(save_request_middleware) app.include_router(v0_api_router) app.include_router( diff --git a/api/src/data_inclusion/api/request/__init__.py b/api/src/data_inclusion/api/request/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/api/src/data_inclusion/api/request/middleware.py b/api/src/data_inclusion/api/request/middleware.py new file mode 100644 index 00000000..e5239025 --- /dev/null +++ b/api/src/data_inclusion/api/request/middleware.py @@ -0,0 +1,22 @@ +import logging + +import fastapi + +from data_inclusion.api.request.services import save_request + +logger = logging.getLogger(__name__) + + +async def save_request_middleware(request: fastapi.Request, call_next): + response = fastapi.Response("Internal server error", status_code=500) + try: + response = await call_next(request) + except: + raise + finally: + try: + save_request(request, response, db_session=request.state.db_session) + except Exception as err: + logger.error(err) + pass + return response diff --git a/api/src/data_inclusion/api/request/models.py b/api/src/data_inclusion/api/request/models.py new file mode 100644 index 00000000..c5319692 --- /dev/null +++ b/api/src/data_inclusion/api/request/models.py @@ -0,0 +1,27 @@ +import sqlalchemy as sqla +from sqlalchemy.orm import Mapped + +from data_inclusion.api.core import db + + +class Request(db.Base): + id: Mapped[db.uuid_pk] + created_at: Mapped[db.timestamp] + status_code: Mapped[int] + method: Mapped[str] + path: Mapped[str] + base_url: Mapped[str] + user: Mapped[str | None] + path_params: Mapped[dict] + query_params: Mapped[dict] + client_host: Mapped[str | None] + client_port: Mapped[int | None] + endpoint_name: Mapped[str | None] + + __table_args__ = ( + sqla.Index(None, "endpoint_name"), + sqla.Index(None, "method"), + sqla.Index(None, "status_code"), + sqla.Index(None, "created_at"), + sqla.Index(None, "user"), + ) diff --git a/api/src/data_inclusion/api/request/services.py b/api/src/data_inclusion/api/request/services.py new file mode 100644 index 00000000..5d4db50b --- /dev/null +++ b/api/src/data_inclusion/api/request/services.py @@ -0,0 +1,47 @@ +import fastapi + +from data_inclusion.api.core import db +from data_inclusion.api.request import models + + +def is_trailing_slash_redirect( + request: fastapi.Request, response: fastapi.Response +) -> bool: + redirect_url = response.headers.get("location") + return response.status_code == 307 and str(request.url) == f"{redirect_url}/" + + +def save_request( + request: fastapi.Request, + response: fastapi.Response, + db_session=fastapi.Depends(db.get_session), +) -> None: + if is_trailing_slash_redirect(request=request, response=response): + return + + endpoint_name = None + if (route := request.scope.get("route")) is not None: + endpoint_name = route.name + + username = None + if (user := request.scope.get("user")) is not None and user.is_authenticated: + username = user.username + + request_instance = models.Request( + status_code=response.status_code, + method=request.method, + path=request.url.path, + base_url=str(request.base_url), + user=username, + path_params=request.path_params, + query_params={ + key: ",".join(request.query_params.getlist(key)) + for key in request.query_params.keys() + }, + client_host=request.client.host if request.client is not None else None, + client_port=request.client.port if request.client is not None else None, + endpoint_name=endpoint_name, + ) # type: ignore + + db_session.add(request_instance) + db_session.commit() diff --git a/api/tests/conftest.py b/api/tests/conftest.py index 8d06d664..6ef12bed 100644 --- a/api/tests/conftest.py +++ b/api/tests/conftest.py @@ -144,6 +144,7 @@ def db_session(db_connection): factories.StructureFactory._meta.sqlalchemy_session = session factories.ServiceFactory._meta.sqlalchemy_session = session + factories.RequestFactory._meta.sqlalchemy_session = session yield session @@ -156,5 +157,6 @@ def predictable_sequences(): import factory.random factory.random.reseed_random(0) + factories.RequestFactory.reset_sequence() factories.ServiceFactory.reset_sequence() factories.StructureFactory.reset_sequence() diff --git a/api/tests/e2e/api/test_request.py b/api/tests/e2e/api/test_request.py new file mode 100644 index 00000000..0084a5a5 --- /dev/null +++ b/api/tests/e2e/api/test_request.py @@ -0,0 +1,57 @@ +import pytest +import sqlalchemy as sqla + +from data_inclusion.api.request import models + + +@pytest.mark.with_token +def test_save_api_request_with_token(api_client, db_session): + url = "/api/v0/structures/foo/bar?baz=1" + response = api_client.get(url) + + assert response.status_code == 404 + assert ( + db_session.scalar(sqla.select(sqla.func.count()).select_from(models.Request)) + == 1 + ) + + request_instance = db_session.scalars(sqla.select(models.Request)).first() + assert request_instance.status_code == 404 + assert request_instance.user == "some_user" + assert request_instance.path == "/api/v0/structures/foo/bar" + assert request_instance.method == "GET" + assert request_instance.path_params == {"source": "foo", "id": "bar"} + assert request_instance.query_params == {"baz": "1"} + assert request_instance.endpoint_name == "retrieve_structure_endpoint" + + +@pytest.mark.with_token +def test_ignore_redirect(api_client, db_session): + url = "/api/v0/structures/" + response = api_client.get(url) + + assert response.status_code == 200 + assert ( + db_session.scalar(sqla.select(sqla.func.count()).select_from(models.Request)) + == 1 + ) + + +def test_save_api_request_without_token(api_client, db_session): + url = "/api/v0/structures" + response = api_client.get(url) + + assert response.status_code == 403 + assert ( + db_session.scalar(sqla.select(sqla.func.count()).select_from(models.Request)) + == 1 + ) + + request_instance = db_session.scalars(sqla.select(models.Request)).first() + assert request_instance.status_code == 403 + assert request_instance.user is None + assert request_instance.path == "/api/v0/structures" + assert request_instance.method == "GET" + assert request_instance.path_params == {} + assert request_instance.query_params == {} + assert request_instance.endpoint_name == "list_structures_endpoint" diff --git a/api/tests/factories.py b/api/tests/factories.py index 52f825f1..f0365f2f 100644 --- a/api/tests/factories.py +++ b/api/tests/factories.py @@ -5,10 +5,19 @@ from data_inclusion import schema as di_schema from data_inclusion.api.inclusion_data import models +from data_inclusion.api.request.models import Request fake = faker.Faker("fr_FR") +class RequestFactory(factory.alchemy.SQLAlchemyModelFactory): + class Meta: + model = Request + sqlalchemy_session_persistence = "commit" + + status_code = 200 + + class StructureFactory(factory.alchemy.SQLAlchemyModelFactory): class Meta: model = models.Structure