Skip to content

Commit

Permalink
Merge pull request #167 from evo-company/97-fix-aiopg_and_sqla_compat…
Browse files Browse the repository at this point in the history
…ibility

Fix aiopg + sqlalchemy >= 1.4 compatibility issue
  • Loading branch information
kindermax authored Sep 15, 2024
2 parents 38ecdd2 + 712a307 commit b2719ae
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 1 deletion.
11 changes: 11 additions & 0 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,17 @@ jobs:
- name: Run unit tests
run: tox run -- --cov-report=term

test-db:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Install Lets
uses: lets-cli/[email protected]
with:
version: latest
- name: Test database integration
run: timeout 600 lets test-pg

federation-test:
runs-on: ubuntu-latest
steps:
Expand Down
42 changes: 42 additions & 0 deletions hiku/sources/aiopg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
Iterable,
Any,
List,
Iterator,
Tuple,
)

import sqlalchemy
from sqlalchemy.sql import Select
from sqlalchemy import any_
from sqlalchemy.sql.elements import BinaryExpression

Expand All @@ -17,12 +20,51 @@
FETCH_SIZE = 100


def _uniq_fields(fields: List[Field]) -> Iterator[Field]:
visited = set()
for f in fields:
if f.name not in visited:
visited.add(f.name)
yield f


class FieldsQuery(_sa.FieldsQuery):
def in_impl(
self, column: sqlalchemy.Column, values: Iterable
) -> BinaryExpression:
return column == any_(values)

def select_expr(
self, fields_: List[Field], ids: Iterable
) -> Tuple[Select, Callable]:
result_columns = [self.from_clause.c[f.name] for f in fields_]
# aiopg requires unique columns to be passed to select,
# otherwise it will raise an error
query_columns = [
column
for f in _uniq_fields(fields_)
if (column := self.from_clause.c[f.name]) != self.primary_key
]

expr = (
sqlalchemy.select(
*_sa._process_select_params([self.primary_key] + query_columns)
)
.select_from(self.from_clause)
.where(self.in_impl(self.primary_key, ids))
)

def result_proc(rows: List[_sa.Row]) -> List:
rows_map = {
row[self.primary_key]: [row[c] for c in result_columns]
for row in map(_sa._process_result_row, rows)
}

nulls = [None for _ in fields_]
return [rows_map.get(id_, nulls) for id_ in ids]

return expr, result_proc

async def __call__(
self, ctx: Context, fields_: List[Field], ids: Iterable
) -> List:
Expand Down
2 changes: 1 addition & 1 deletion lets.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ commands:
test-pg:
description: Run tests with pg
depends: [_build-tests]
cmd: [docker-compose, run, --rm, test-pg]
cmd: [docker compose, run, --rm, test-pg]

test-tox:
description: Run tests using tox
Expand Down

0 comments on commit b2719ae

Please sign in to comment.