Skip to content

Commit

Permalink
Add abstraction for counted paginated queries
Browse files Browse the repository at this point in the history
`PaginatedQueryRunner` has a separate `query_count` method to return
the count, but since the total entry count isn't included in the
`PaginatedList` data structure, it's annoying to pass between
components of a service. Either the result plus the count have to be
returned as a tuple, or the service has to add its own data structure
to wrap `PaginatedList`.

Avoid this by introducing a `CountedPaginatedQueryRunner` and a
corresponding `CountedPaginatedList` that always includes a `count`
attribute. This can be used by services such as Gafaelfawr that always
want to count the total number of entries, either because the table
is small or because the count can always be satisfied from the table
indices.
  • Loading branch information
rra committed Dec 16, 2024
1 parent 3335dfa commit 42821ae
Show file tree
Hide file tree
Showing 4 changed files with 243 additions and 20 deletions.
52 changes: 51 additions & 1 deletion docs/user-guide/database/pagination.rst
Original file line number Diff line number Diff line change
Expand Up @@ -264,10 +264,60 @@ Here is a very simplified example of a route handler that sets this header:
Here, ``perform_query`` is a wrapper around `~safir.database.PaginatedQueryRunner` that constructs and runs the query.
A real route handler would have more query parameters and more documentation.

Note that this example also sets a non-standard ``X-Total-Count`` header containing the total count of entries returned by the underlying query without pagination.
Including result counts
-----------------------

The example above also sets a non-standard ``X-Total-Count`` header containing the total count of entries returned by the underlying query without pagination.
`~safir.database.PaginatedQueryRunner.query_count` will return this information.
There is no standard way to return this information to the client, but ``X-Total-Count`` is a widely-used informal standard.

If you will always want to include the count, use `~safir.database.CountedPaginatedQueryRunner` instead.
Its `~safir.database.CountedPaginatedQueryRunner.query_object` and `~safir.database.CountedPaginatedQueryRunner.query_row` methods will return a `~safir.database.CountedPaginatedList`, which contains a ``count`` attribute holding the count.
This is equivalent to calling `~safir.database.PaginatedQueryRunner.query_object` or `~safir.database.PaginatedQueryRunner.query_object` followed by `~safir.database.PaginatedQueryRunner.query_count`, but the encapsulation into a data structure makes it easier to pass the results between components of the service.

Here's the same code above but using that approach:

.. code-block:: python
.. code-block:: python
:emphasize-lines: 27, 34
@router.get("/query", response_class=Model)
async def query(
*,
cursor: Annotated[
str | None,
Query(
title="Pagination cursor",
description="Cursor to navigate paginated results",
),
] = None,
limit: Annotated[
int,
Query(
title="Row limit",
description="Maximum number of entries to return",
examples=[100],
ge=1,
le=100,
),
] = 100,
request: Request,
response: Response,
) -> list[Model]:
parsed_cursor = None
if cursor:
parsed_cursor = ModelCursor.from_str(cursor)
runner = CountedPaginatedQueryRunner(Model, ModelCursor)
stmt = build_query(...)
results = await runner.query_object(
session, stmt, cursor=parsed_cursor, limit=limit
)
if cursor or limit:
response.headers["Link"] = results.link_header(request.url)
response.headers["X-Total-Count"] = str(results.count)
return results.entries
Including links in the response
-------------------------------

Expand Down
4 changes: 4 additions & 0 deletions safir/src/safir/database/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
initialize_database,
)
from ._pagination import (
CountedPaginatedList,
CountedPaginatedQueryRunner,
DatetimeIdCursor,
InvalidCursorError,
PaginatedList,
Expand All @@ -28,6 +30,8 @@

__all__ = [
"AlembicConfigError",
"CountedPaginatedList",
"CountedPaginatedQueryRunner",
"DatabaseInitializationError",
"DatetimeIdCursor",
"InvalidCursorError",
Expand Down
160 changes: 153 additions & 7 deletions safir/src/safir/database/_pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
"""Type of an entry in a paginated list."""

__all__ = [
"CountedPaginatedList",
"CountedPaginatedQueryRunner",
"DatetimeIdCursor",
"InvalidCursorError",
"PaginatedList",
Expand Down Expand Up @@ -321,19 +323,18 @@ def __str__(self) -> str:
class PaginatedList(Generic[E, C]):
"""Paginated SQL results with accompanying pagination metadata.
Holds a paginated list of any Pydantic type, complete with a count and
cursors. Can hold any type of entry and any type of cursor, but implicitly
requires the entry type be one that is meaningfully paginated by that type
of cursor.
Holds a paginated list of any Pydantic type with pagination cursors. Can
hold any type of entry and any type of cursor, but implicitly requires the
entry type be one that is meaningfully paginated by that type of cursor.
"""

entries: list[E]
"""A batch of entries."""

next_cursor: C | None = None
next_cursor: C | None
"""Cursor for the next batch of entries."""

prev_cursor: C | None = None
prev_cursor: C | None
"""Cursor for the previous batch of entries."""

def first_url(self, current_url: URL) -> str:
Expand Down Expand Up @@ -418,7 +419,7 @@ def link_header(self, current_url: URL) -> str:


class PaginatedQueryRunner(Generic[E, C]):
"""Construct and run database queries that return paginated results.
"""Run database queries that return paginated results.
This class implements the logic for keyset pagination based on arbitrary
SQLAlchemy ORM where clauses.
Expand Down Expand Up @@ -688,3 +689,148 @@ async def _paginated_query(
return PaginatedList[E, C](
entries=entries, prev_cursor=prev_cursor, next_cursor=next_cursor
)


@dataclass
class CountedPaginatedList(PaginatedList[E, C]):
"""Paginated SQL results with pagination metadata and total count.
Holds a paginated list of any Pydantic type, complete with a count and
cursors. Can hold any type of entry and any type of cursor, but implicitly
requires the entry type be one that is meaningfully paginated by that type
of cursor.
"""

count: int
"""Total number of entries if queried without pagination."""


class CountedPaginatedQueryRunner(PaginatedQueryRunner[E, C]):
"""Run database queries that return paginated results with counts.
This variation of `PaginatedQueryRunner` always runs a second query to
count the total number of available entries if queried without pagination.
It should only be used on small tables or with queries that can be
satisfied from the table indices; otherwise, the count query could be
undesirably slow.
Parameters
----------
entry_type
Type of each entry returned by the queries. This must be a Pydantic
model.
cursor_type
Type of the pagination cursor, which encapsulates the logic of how
entries are sorted and what set of keys is used to retrieve the next
or previous batch of entries.
"""

async def query_object(
self,
session: async_scoped_session,
stmt: Select[tuple],
*,
cursor: C | None = None,
limit: int | None = None,
) -> CountedPaginatedList[E, C]:
"""Perform a query for objects with an optional cursor and limit.
Perform the query provided in ``stmt`` with appropriate sorting and
pagination as determined by the cursor type. Also performs a second
query to get the total count of entries if retrieved without
pagination.
This method should be used with queries that return a single
SQLAlchemy model. The provided query will be run with the session
`~sqlalchemy.ext.asyncio.async_scoped_session.scalars` method and the
resulting object passed to Pydantic's ``model_validate`` to convert to
``entry_type``. For queries returning a tuple of attributes, use
`query_row` instead.
Unfortunately, this distinction cannot be type-checked, so be careful
to use the correct method.
Parameters
----------
session
Database session within which to run the query.
stmt
Select statement to execute. Pagination and ordering will be
added, so this statement should not already have limits or order
clauses applied. This statement must return a list of SQLAlchemy
ORM models that can be converted to ``entry_type`` by Pydantic.
cursor
If present, continue from the provided keyset cursor.
limit
If present, limit the result count to at most this number of rows.
Returns
-------
CountedPaginatedList
Results of the query wrapped with pagination information and a
count of the total number of entries.
"""
result = await super().query_object(
session, stmt, cursor=cursor, limit=limit
)
count = await self.query_count(session, stmt)
return CountedPaginatedList[E, C](
entries=result.entries,
next_cursor=result.next_cursor,
prev_cursor=result.prev_cursor,
count=count,
)

async def query_row(
self,
session: async_scoped_session,
stmt: Select[tuple],
*,
cursor: C | None = None,
limit: int | None = None,
) -> CountedPaginatedList[E, C]:
"""Perform a query for attributes with an optional cursor and limit.
Perform the query provided in ``stmt`` with appropriate sorting and
pagination as determined by the cursor type. Also performs a second
query to get the total count of entries if retrieved without
pagination.
This method should be used with queries that return a list of
attributes that can be converted to the ``entry_type`` Pydantic model.
For queries returning a single ORM object, use `query_object` instead.
Unfortunately, this distinction cannot be type-checked, so be careful
to use the correct method.
Parameters
----------
session
Database session within which to run the query.
stmt
Select statement to execute. Pagination and ordering will be
added, so this statement should not already have limits or order
clauses applied. This statement must return a tuple of attributes
that can be converted to ``entry_type`` by Pydantic's
``model_validate``.
cursor
If present, continue from the provided keyset cursor.
limit
If present, limit the result count to at most this number of rows.
Returns
-------
CountedPaginatedList
Results of the query wrapped with pagination information and a
count of the total number of entries.
"""
result = await super().query_row(
session, stmt, cursor=cursor, limit=limit
)
count = await self.query_count(session, stmt)
return CountedPaginatedList[E, C](
entries=result.entries,
next_cursor=result.next_cursor,
prev_cursor=result.prev_cursor,
count=count,
)
47 changes: 35 additions & 12 deletions safir/tests/database_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from starlette.datastructures import URL

from safir.database import (
CountedPaginatedQueryRunner,
DatetimeIdCursor,
PaginatedQueryRunner,
PaginationLinkData,
Expand Down Expand Up @@ -410,11 +411,12 @@ async def test_pagination(database_url: str, database_password: str) -> None:

# Query by object and test the pagination cursors going backwards and
# forwards.
builder = PaginatedQueryRunner(PaginationModel, TableCursor)
runner = PaginatedQueryRunner(PaginationModel, TableCursor)
counted_runner = CountedPaginatedQueryRunner(PaginationModel, TableCursor)
async with session.begin():
stmt: Select[tuple] = select(PaginationTable)
assert await builder.query_count(session, stmt) == 7
result = await builder.query_object(session, stmt, limit=2)
assert await runner.query_count(session, stmt) == 7
result = await runner.query_object(session, stmt, limit=2)
assert_model_lists_equal(result.entries, rows[:2])
assert not result.prev_cursor
base_url = URL("https://example.com/query")
Expand All @@ -427,7 +429,15 @@ async def test_pagination(database_url: str, database_password: str) -> None:
assert result.prev_url(base_url) is None
assert str(result.next_cursor) == "1600000000.5_1"

result = await builder.query_object(
counted_result = await counted_runner.query_object(
session, stmt, limit=2
)
assert counted_result.entries == result.entries
assert counted_result.prev_cursor == result.prev_cursor
assert counted_result.next_cursor == result.next_cursor
assert counted_result.count == 7

result = await runner.query_object(
session, stmt, cursor=result.next_cursor, limit=3
)
assert_model_lists_equal(result.entries, rows[2:5])
Expand All @@ -447,7 +457,7 @@ async def test_pagination(database_url: str, database_password: str) -> None:
assert result.prev_url(base_url) == prev_url
next_cursor = result.next_cursor

result = await builder.query_object(
result = await runner.query_object(
session, stmt, cursor=result.prev_cursor
)
assert_model_lists_equal(result.entries, rows[:2])
Expand All @@ -457,7 +467,7 @@ async def test_pagination(database_url: str, database_password: str) -> None:
f'<{base_url!s}&cursor={result.next_cursor}>; rel="next"'
)

result = await builder.query_object(session, stmt, cursor=next_cursor)
result = await runner.query_object(session, stmt, cursor=next_cursor)
assert_model_lists_equal(result.entries, rows[5:])
assert not result.next_cursor
base_url = URL("https://example.com/query")
Expand All @@ -468,14 +478,14 @@ async def test_pagination(database_url: str, database_password: str) -> None:
)
prev_cursor = result.prev_cursor

result = await builder.query_object(session, stmt, cursor=prev_cursor)
result = await runner.query_object(session, stmt, cursor=prev_cursor)
assert_model_lists_equal(result.entries, rows[:5])
assert result.link_header(base_url) == (
f'<{base_url!s}>; rel="first", '
f'<{base_url!s}?cursor={result.next_cursor}>; rel="next"'
)

result = await builder.query_object(
result = await runner.query_object(
session, stmt, cursor=prev_cursor, limit=2
)
assert_model_lists_equal(result.entries, rows[3:5])
Expand All @@ -490,26 +500,39 @@ async def test_pagination(database_url: str, database_password: str) -> None:
# function.
async with session.begin():
stmt = select(PaginationTable.time, PaginationTable.id)
result = await builder.query_row(session, stmt, limit=2)
result = await runner.query_row(session, stmt, limit=2)
assert_model_lists_equal(result.entries, rows[:2])
assert await builder.query_count(session, stmt) == 7
assert await runner.query_count(session, stmt) == 7

counted_result = await counted_runner.query_row(session, stmt, limit=2)
assert counted_result.entries == result.entries
assert counted_result.prev_cursor == result.prev_cursor
assert counted_result.next_cursor == result.next_cursor
assert counted_result.count == 7

# Querying for the entire table should return the everything with no
# pagination cursors. Try this with both an object query and an attribute
# query.
async with session.begin():
result = await builder.query_object(session, select(PaginationTable))
stmt = select(PaginationTable)
result = await runner.query_object(session, stmt)
assert_model_lists_equal(result.entries, rows)
assert not result.next_cursor
assert not result.prev_cursor
stmt = select(PaginationTable.id, PaginationTable.time)
result = await builder.query_row(session, stmt)
result = await runner.query_row(session, stmt)
assert_model_lists_equal(result.entries, rows)
assert not result.next_cursor
assert not result.prev_cursor
base_url = URL("https://example.com/query?foo=b")
assert result.link_header(base_url) == (f'<{base_url!s}>; rel="first"')

counted_result = await counted_runner.query_row(session, stmt)
assert counted_result.entries == result.entries
assert not counted_result.next_cursor
assert not counted_result.prev_cursor
assert counted_result.count == len(counted_result.entries)


def test_link_data() -> None:
header = (
Expand Down

0 comments on commit 42821ae

Please sign in to comment.