Fix various issues with pagination support
Fix several issues with pagination support uncovered by converting
Gafaelfawr to the new approach.

* Document the type parameterization of cursors.
* Document the need to implement `from_entry` for cursors derived
  from `DatetimeIdCursor`.
* Move cursor parsing into the body of the suggested handler
  function since FastAPI cannot validate complex models as query
  parameters.
* Change the pagination exception to a new `InvalidCursorError`
  instead of `ValueError` since parsing will be done in the handler
  body and needs to throw an error that can be automatically
  handled and converted to a 422 HTTP status code.
* Rename `PaginatedLinkData` to `PaginationLinkData` for more
  correct grammar.
* Drop `DeclarativeBase` from the prototype of `query_object` because
  tuples are not covariant and this does not pass type checking.
* Document that `query_object` and `query_row` cannot be fully
  type-checked.
rra committed Nov 26, 2024
1 parent c413791 commit b00b61f
Showing 4 changed files with 57 additions and 25 deletions.
32 changes: 22 additions & 10 deletions docs/user-guide/database/pagination.rst
@@ -35,7 +35,8 @@ To use the generic paginated query support, first you must define the cursor tha
The cursor class defines the following information needed for paginated queries:

#. How to construct a cursor to get all entries before or after a given entry.
   Forward cursors must include the entry provided as a basis for the cursor, and reverse curors must exclude it.
   Forward cursors must include the entry provided as a basis for the cursor, and reverse cursors must exclude it.
   The Pydantic model of the entries is a type parameter to the cursor.
#. How to serialize to and deserialize from a string so that the cursor can be returned to an API client and sent back to retrieve the next batch of results.
#. The sort order the cursor represents.
   A cursor class represents one and only one sort order, since keyset cursors rely on the sort order not changing.
@@ -48,15 +49,15 @@ In the general case, your application must define the cursor by creating a subcl
In the very common case that the API results are sorted first by some timestamp in descending order (most recent first) and then by an auto-increment unique key (most recently inserted row first), Safir provides `~safir.database.DatetimeIdCursor`, which is a generic cursor implementation that implements that ordering and keyset pagination policy.
In this case, you need only subclass `~safir.database.DatetimeIdCursor` and provide the SQLAlchemy ORM model columns that correspond to the timestamp and the unique key.

For example, if you are requesting paginated results from a table whose ORM model is named ``Job``, whose timestamp field is ``Job.creation_time``, and whose unique key is ``Job.id``, you can use the following cursor:
For example, if you are requesting paginated results from a table whose ORM model is named ``Job``, whose timestamp field is ``Job.creation_time``, and whose unique key is ``Job.id``, and using a Pydantic model named ``JobModel`` with the same field names, you can use the following cursor:

.. code-block:: python

   from safir.database import DatetimeIdCursor
   from sqlalchemy.orm import InstrumentedAttribute


   class JobCursor(DatetimeIdCursor):
   class JobCursor(DatetimeIdCursor[JobModel]):
       @staticmethod
       def id_column() -> InstrumentedAttribute:
           return Job.id
@@ -65,9 +66,14 @@ For example, if you are requesting paginated results from a table whose ORM mode
       def time_column() -> InstrumentedAttribute:
           return Job.creation_time

       @classmethod
       def from_entry(cls, entry: JobModel, *, reverse: bool = False) -> Self:
           return cls(id=entry.id, time=entry.creation_time, reverse=reverse)

(These are essentially class properties, but due to limitations in Python abstract data types and property decorators, they're implemented as static methods.)

In this case, `~safir.database.DatetimeIdCursor` will handle all of the other details for you, including serialization and deserialization.
The type parameter to `~safir.database.DatetimeIdCursor` must be the Pydantic model that the resulting paginated list will contain.
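
The serialization that the cursor class handles can be illustrated with a small, self-contained sketch. This is not Safir's implementation: ``SketchCursor`` is a hypothetical stand-in, and the string layout (an optional ``p`` prefix for a reverse cursor, then a timestamp, an underscore, and the ID) is only inferred from the cursor strings that appear in this commit's tests, such as ``p1510000000_2``.

```python
from __future__ import annotations

from dataclasses import dataclass
from datetime import datetime, timezone


@dataclass(frozen=True)
class SketchCursor:
    """Hypothetical timestamp-and-ID keyset cursor (not Safir's class)."""

    time: datetime
    id: int
    previous: bool = False

    @classmethod
    def from_str(cls, cursor: str) -> SketchCursor:
        # Assumed layout: optional "p" prefix, then "<timestamp>_<id>".
        previous = cursor.startswith("p")
        if previous:
            cursor = cursor[1:]
        try:
            timestamp, id_str = cursor.split("_", 1)
            time = datetime.fromtimestamp(float(timestamp), tz=timezone.utc)
            return cls(time=time, id=int(id_str), previous=previous)
        except Exception as e:
            # A real application would raise an error type that its
            # framework can turn into an HTTP 422 response.
            raise ValueError(f"Cannot parse cursor: {e!s}") from e

    def __str__(self) -> str:
        prefix = "p" if self.previous else ""
        return f"{prefix}{self.time.timestamp()}_{self.id}"
```

Parsing a string and serializing the result again yields an equivalent cursor, which is the property an API client relies on when it sends a cursor back unchanged.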

Performing paginated queries
============================
@@ -86,14 +92,13 @@ The parameter declaration should generally look something like the following:
   async def query(
       *,
       cursor: Annotated[
           ModelCursor | None,
           str | None,
           Query(
               title="Pagination cursor",
               description=(
                   "Optional cursor used when moving between pages of results"
               ),
           ),
           BeforeValidator(lambda c: ModelCursor.from_str(c) if c else None),
       ] = None,
       limit: Annotated[
           int,
@@ -107,14 +112,21 @@ The parameter declaration should generally look something like the following:
       ] = 100,
       request: Request,
       response: Response,
   ) -> list[Model]: ...
   ) -> list[Model]:
       parsed_cursor = None
       if cursor:
           parsed_cursor = ModelCursor.from_str(cursor)
       ...

You should be able to use your class's implementation of `~safir.database.PaginationCursor.from_str` as a validator, which lets FastAPI validate the syntax of the cursor for you and handle syntax errors.
Since the cursor is optional (the first query won't have a cursor), you'll need a small wrapper to handle `None`, as shown above.
Unfortunately, due to limitations in FastAPI, you cannot annotate the cursor parameter with a validator that returns the appropriate object.
You must instead parse the cursor in the body of the handler.
`~safir.database.PaginationCursor.from_str` should raise `~safir.database.InvalidCursorError` on parse failure, which will be automatically converted into an HTTP 422 response if you use the error handler described in :doc:`fastapi-errors`.

Also note the ``limit`` parameter, which should also be used on any paginated route.
This sets the size of each block of results.
By default, the exception raised by `~safir.database.DatetimeIdCursor` assumes the cursor is coming from a query parameter named ``cursor``.
If this is not true for your application and you are using a cursor derived from that class, either catch the exception, modify ``location`` and ``field_path`` as appropriate for your application, and re-raise it, or override the ``from_str`` method to raise an exception with different metadata.

Note the ``limit`` parameter in the above code example, which should also be used on any paginated route.
This sets the size of each block of results.
As shown here, you will generally want to set some upper limit on how large the limit can be and set a default limit if none was provided.
This ensures that clients cannot retrieve the full list of results with one query.
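
For applications where the cursor does not arrive as a query parameter named ``cursor``, the catch, modify, and re-raise approach described earlier might look roughly like the following. This is a sketch under stated assumptions: ``SketchCursorError`` and ``parse_cursor`` are hypothetical stand-ins written for this example (they are not Safir classes or functions), and plain strings take the place of a location enum.

```python
from __future__ import annotations


class SketchCursorError(Exception):
    """Hypothetical error carrying request-location metadata."""

    def __init__(
        self,
        message: str,
        location: str = "query",
        field_path: list[str] | None = None,
    ) -> None:
        super().__init__(message)
        self.location = location
        self.field_path = field_path or ["cursor"]


def parse_cursor(raw: str) -> int:
    """Hypothetical parser that blames a query parameter named cursor."""
    try:
        return int(raw)
    except ValueError as e:
        raise SketchCursorError(f"Cannot parse cursor: {e!s}") from e


def parse_cursor_from_body(raw: str | None) -> int | None:
    """Parse an optional cursor that was sent in the request body."""
    if not raw:
        return None
    try:
        return parse_cursor(raw)
    except SketchCursorError as e:
        # Correct the metadata so the error points at the real input,
        # then re-raise the same exception object.
        e.location = "body"
        e.field_path = ["pagination", "cursor"]
        raise
```

Re-raising the same exception keeps the original message and traceback while letting the error handler report the correct location to the client.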

5 changes: 3 additions & 2 deletions safir/src/safir/database/__init__.py
@@ -18,21 +18,22 @@
)
from ._pagination import (
    DatetimeIdCursor,
    PaginatedLinkData,
    PaginatedList,
    PaginatedQueryRunner,
    PaginationCursor,
    PaginationLinkData,
)
from ._retry import retry_async_transaction

__all__ = [
    "AlembicConfigError",
    "DatabaseInitializationError",
    "DatetimeIdCursor",
    "PaginatedLinkData",
    "InvalidCursorError",
    "PaginatedList",
    "PaginatedQueryRunner",
    "PaginationCursor",
    "PaginationLinkData",
    "create_async_session",
    "create_database_engine",
    "datetime_from_db",
33 changes: 26 additions & 7 deletions safir/src/safir/database/_pagination.py
@@ -16,9 +16,12 @@
from pydantic import BaseModel
from sqlalchemy import Select, and_, func, or_, select
from sqlalchemy.ext.asyncio import async_scoped_session
from sqlalchemy.orm import DeclarativeBase, InstrumentedAttribute
from sqlalchemy.orm import InstrumentedAttribute
from starlette.datastructures import URL

from safir.fastapi import ClientRequestError
from safir.models import ErrorLocation

from ._datetime import datetime_to_db

_LINK_REGEX = re.compile(r'\s*<(?P<target>[^>]+)>;\s*rel="(?P<type>[^"]+)"')
@@ -32,15 +35,25 @@

__all__ = [
    "DatetimeIdCursor",
    "PaginatedLinkData",
    "InvalidCursorError",
    "PaginatedList",
    "PaginatedQueryRunner",
    "PaginationCursor",
    "PaginationLinkData",
]


class InvalidCursorError(ClientRequestError):
    """The provided cursor was invalid."""

    error = "invalid_cursor"

    def __init__(self, message: str) -> None:
        super().__init__(message, ErrorLocation.query, ["cursor"])


@dataclass
class PaginatedLinkData:
class PaginationLinkData:
    """Holds the data returned in an :rfc:`8288` ``Link`` header."""

    prev_url: str | None
@@ -63,7 +76,7 @@ def from_header(cls, header: str | None) -> Self:
        Returns
        -------
        PaginatedLinkData
        PaginationLinkData
            Parsed form of that header.
        """
        links = {}
@@ -135,7 +148,7 @@ def from_str(cls, cursor: str) -> Self:
        Raises
        ------
        ValueError
        InvalidCursorError
            Raised if the cursor is invalid.
        """

@@ -256,7 +269,7 @@ def from_str(cls, cursor: str) -> Self:
                previous=previous,
            )
        except Exception as e:
            raise ValueError(f"Cannot parse cursor: {e!s}") from e
            raise InvalidCursorError(f"Cannot parse cursor: {e!s}") from e

    @classmethod
    def apply_order(cls, stmt: Select, *, reverse: bool = False) -> Select:
@@ -452,7 +465,7 @@ async def query_count(
    async def query_object(
        self,
        session: async_scoped_session,
        stmt: Select[tuple[DeclarativeBase]],
        stmt: Select[tuple],
        *,
        cursor: C | None = None,
        limit: int | None = None,
@@ -469,6 +482,9 @@ async def query_object(
        ``entry_type``. For queries returning a tuple of attributes, use
        `query_row` instead.

        Unfortunately, this distinction cannot be type-checked, so be careful
        to use the correct method.

        Parameters
        ----------
        session
@@ -512,6 +528,9 @@ async def query_row(
        attributes that can be converted to the ``entry_type`` Pydantic model.
        For queries returning a single ORM object, use `query_object` instead.

        Unfortunately, this distinction cannot be type-checked, so be careful
        to use the correct method.

        Parameters
        ----------
        session
12 changes: 6 additions & 6 deletions safir/tests/database_test.py
@@ -27,8 +27,8 @@

from safir.database import (
    DatetimeIdCursor,
    PaginatedLinkData,
    PaginatedQueryRunner,
    PaginationLinkData,
    create_async_session,
    create_database_engine,
    datetime_from_db,
@@ -516,7 +516,7 @@ def test_link_data() -> None:
        '<https://example.com/query>; rel="first", '
        '<https://example.com/query?cursor=1600000000.5_1>; rel="next"'
    )
    link = PaginatedLinkData.from_header(header)
    link = PaginationLinkData.from_header(header)
    assert not link.prev_url
    assert link.next_url == "https://example.com/query?cursor=1600000000.5_1"
    assert link.first_url == "https://example.com/query"
@@ -526,7 +526,7 @@ def test_link_data() -> None:
        '<https://example.com/query?limit=10&cursor=15_2>; rel="next", '
        '<https://example.com/query?limit=10&cursor=p5_1>; rel="prev"'
    )
    link = PaginatedLinkData.from_header(header)
    link = PaginationLinkData.from_header(header)
    assert link.prev_url == "https://example.com/query?limit=10&cursor=p5_1"
    assert link.next_url == "https://example.com/query?limit=10&cursor=15_2"
    assert link.first_url == "https://example.com/query?limit=10"
@@ -535,18 +535,18 @@ def test_link_data() -> None:
        '<https://example.com/query>; rel="first", '
        '<https://example.com/query?cursor=p1510000000_2>; rel="previous"'
    )
    link = PaginatedLinkData.from_header(header)
    link = PaginationLinkData.from_header(header)
    assert link.prev_url == "https://example.com/query?cursor=p1510000000_2"
    assert not link.next_url
    assert link.first_url == "https://example.com/query"

    header = '<https://example.com/query?foo=b>; rel="first"'
    link = PaginatedLinkData.from_header(header)
    link = PaginationLinkData.from_header(header)
    assert not link.prev_url
    assert not link.next_url
    assert link.first_url == "https://example.com/query?foo=b"

    link = PaginatedLinkData.from_header("")
    link = PaginationLinkData.from_header("")
    assert not link.prev_url
    assert not link.next_url
    assert not link.first_url
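
The header parsing that these tests exercise can be approximated with the regular expression shown in the ``_pagination.py`` diff above. The following is only a sketch: ``parse_link_header`` is a hypothetical helper rather than the body of ``from_header``, it returns a plain dictionary instead of a dataclass, and it assumes that link targets contain no commas.

```python
from __future__ import annotations

import re

# Same pattern as _LINK_REGEX in the diff above.
_LINK_REGEX = re.compile(r'\s*<(?P<target>[^>]+)>;\s*rel="(?P<type>[^"]+)"')


def parse_link_header(header: str | None) -> dict[str, str]:
    """Map each rel type in a Link header to its target URL."""
    links: dict[str, str] = {}
    if not header:
        return links
    # Simplification: split on commas, which breaks if a URL contains one.
    for element in header.split(","):
        match = _LINK_REGEX.match(element)
        if not match:
            continue
        rel = match.group("type")
        # The tests accept both "prev" and "previous" for the same link.
        if rel == "previous":
            rel = "prev"
        links[rel] = match.group("target")
    return links
```

A missing or empty header produces an empty mapping, which corresponds to the final test case above where all three URLs are unset.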
