Skip to content

Commit

Permalink
🚋 Add new methods
Browse files Browse the repository at this point in the history
- count
- delete_all
- update_all
Filter can be used multiple times now.
  • Loading branch information
Vitalii Levytskyi committed Oct 1, 2024
1 parent 4f93c29 commit 799ef1a
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 96 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "sqlmodel_repo"
version = "0.0.1"
version = "0.0.2"
description = "Active record mixin for SQLModel"
readme = "README.md"
keywords = ["sqlmodel", "sqlalchemy", "orm", "active record"]
Expand Down
250 changes: 155 additions & 95 deletions sqlmodel_repo.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from contextlib import contextmanager
from typing import Optional

from sqlmodel import Session, SQLModel, select, func, text
from sqlmodel import Session, SQLModel, select, delete, update, func

try:
from fastapi.exceptions import HTTPException
except Exception:
pass


@contextmanager
Expand Down Expand Up @@ -32,71 +37,11 @@ def reuse_session_or_new(db_engine=None, session: Optional[Session] = None):
session.close()


class CollectionResult:
def __init__(self, stmt, model, db_engine, session=None):
self.stmt = stmt
self.model = model
self.db_engine = db_engine
self.session = session

def paginate(
self,
offset: int,
limit: int,
order_by: str,
desc: bool = False
) -> list:
"""Paginate results"""
with reuse_session_or_new(self.db_engine, self.session) as session:
return self._paginate(session, offset, limit, order_by, desc)

def paginate_with_total(
self,
offset: int,
limit: int,
order_by: str,
desc: bool = False
) -> (list, int):
"""Paginate results and fetch total count
Returns:
tuple(list, int) - Items and total count.
"""
with reuse_session_or_new(self.db_engine, self.session) as session:
count_stmt = select(func.count()).select_from(self.stmt.subquery())
count = session.execute(count_stmt).scalar()
results = self._paginate(session, offset, limit, order_by, desc)
return results, count

def _paginate(
self,
session,
offset: int,
limit: int,
order_by: str,
desc: bool = False
) -> list:
order_by = getattr(self.model, order_by)
if desc:
order_by = getattr(order_by, 'desc')()
return session.exec(
self.stmt.order_by(order_by).offset(offset).limit(limit)
).all()

def all(self) -> list:
"""Get all results"""
with reuse_session_or_new(self.db_engine, self.session) as session:
return session.exec(self.stmt).all()

def count(self):
"""Get total results count"""
with reuse_session_or_new(self.db_engine, self.session) as session:
count_stmt = select(func.count()).select_from(self.stmt.subquery())
return session.execute(count_stmt).scalar()


class SQLModelRepo:
def __init__(self, model: SQLModel, db_engine):
def __init__(
self, model: SQLModel, db_engine,
init_stmt=None, session=None
):
""" Generic repository for SQLModel.
Args:
Expand All @@ -109,87 +54,202 @@ def __init__(self, model: SQLModel, db_engine):
users_repo.get_by_id(1)
"""
self.model = model
self._init_stmt = init_stmt
self.db_engine = db_engine
self._session = None
self.session = session

def __call__(self, session):
new_repo = SQLModelRepo(model=self.model, db_engine=self.db_engine)
new_repo._session = session
new_repo.session = session
return new_repo

def create(self, **kwargs):
"""Create a new record and save to the database."""
instance = self.model(**kwargs)
with reuse_session_or_new(self.db_engine, self._session) as session:
with reuse_session_or_new(self.db_engine, self.session) as session:
session.add(instance)
session.commit()
session.refresh(instance)
return instance

def get_by_id(self, id, *fields):
"""Fetch an object by its primary key."""
select_obj = self._get_select_obj(fields)
with reuse_session_or_new(self.db_engine, self._session) as session:
stmt = self.init_stmt(*fields)
with reuse_session_or_new(self.db_engine, self.session) as session:
return session.exec(
select(*select_obj).where(
stmt.where(
getattr(self.model, 'id') == id
)
).first()

def save(self, instance):
"""Save the current object (instance) to the database."""
with reuse_session_or_new(self.db_engine, self._session) as session:
with reuse_session_or_new(self.db_engine, self.session) as session:
session.add(instance)
session.commit()
session.refresh(instance)

def save_or_update(self, instance):
"""Save the current object (instance) to the database."""
with reuse_session_or_new(self.db_engine, self.session) as session:
existing_obj = session.exec(
select(self.model).where(
self.model.id == instance.id
)
).first()
if existing_obj:
for k, v in instance.model_dump().items():
setattr(existing_obj, k, v)
session.add(existing_obj)
instance = existing_obj
else:
session.add(instance)
session.commit()
session.refresh(instance)

def update(self, id, **kwargs):
"""Record partial update."""
with reuse_session_or_new(self.db_engine, self._session) as session:
set_statements = ", ".join(
f"{field} = :{field}" for field in kwargs.keys()
)
kwargs['id'] = id
query = f"""
UPDATE {self.model.__tablename__}
SET {set_statements}
WHERE id = :id
"""
session.execute(text(query), kwargs)
with reuse_session_or_new(self.db_engine, self.session) as session:
update_stmt = update(self.model).where(
self.model.id == id
).values(**kwargs)
session.execute(update_stmt)
session.commit()

def update_all(self, **kwargs):
"""Partial update for all selected records."""
with reuse_session_or_new(self.db_engine, self.session) as session:
if self._init_stmt:
update_stmt = update(self.model).where(
self.init_stmt().whereclause
).values(**kwargs)
else:
update_stmt = update(self.model).values(**kwargs)
session.execute(update_stmt)
session.commit()

def delete(self, instance):
"""Delete an object from the database."""
with reuse_session_or_new(self.db_engine, self._session) as session:
with reuse_session_or_new(self.db_engine, self.session) as session:
session.delete(instance)
session.commit()

def all(self, *fields) -> list:
"""Return all records."""
select_obj = self._get_select_obj(fields)
with reuse_session_or_new(self.db_engine, self._session) as session:
return session.exec(select(*select_obj)).all()
def delete_all(self):
"""Delete all records in query."""
with reuse_session_or_new(self.db_engine, self.session) as session:
if self._init_stmt:
delete_stmt = delete(self.model).where(
self.init_stmt().whereclause
)
else:
delete_stmt = delete(self.model)
session.execute(delete_stmt)
session.commit()

def filter(self, *filters, _fields=(), **kwargs) -> CollectionResult:
def filter(self, *filters, _fields=(), **kwargs) -> 'SQLModelRepo':
"""Filter records based on provided conditions."""
select_obj = self._get_select_obj(_fields)
stmt = select(*select_obj).where(
stmt = self.init_stmt(*_fields).where(
*filters,
*[
getattr(self.model, k) == v
if isinstance(k, str) else k == v
for k, v in kwargs.items()
]
)
return CollectionResult(
stmt=stmt,
return SQLModelRepo(
init_stmt=stmt,
model=self.model,
db_engine=self.db_engine,
session=self._session
session=self.session
)

def paginate(
self,
offset: int,
limit: int,
order_by: str,
desc: bool = False
) -> list:
"""Paginate results"""
with reuse_session_or_new(self.db_engine, self.session) as session:
return self._paginate(session, offset, limit, order_by, desc)

def paginate_with_total(
self,
offset: int,
limit: int,
order_by: str,
desc: bool = False
) -> (list, int):
"""Paginate results and fetch total count
Returns:
tuple(list, int) - Items and total count.
"""
with reuse_session_or_new(self.db_engine, self.session) as session:
count_stmt = select(func.count()).select_from(
self.init_stmt().subquery()
)
count = session.execute(count_stmt).scalar()
results = self._paginate(session, offset, limit, order_by, desc)
return results, count

def _paginate(
self,
session,
offset: int,
limit: int,
order_by: str,
desc: bool = False
) -> list:
order_by = getattr(self.model, order_by)
if desc:
order_by = getattr(order_by, 'desc')()
return session.exec(
self.init_stmt().order_by(order_by).offset(offset).limit(limit)
).all()

def all(self) -> list:
"""Get all results"""
with reuse_session_or_new(self.db_engine, self.session) as session:
return session.exec(self.init_stmt()).all()

def count(self):
"""Get total results count"""
with reuse_session_or_new(self.db_engine, self.session) as session:
count_stmt = select(func.count()).select_from(
self.init_stmt().subquery()
)
return session.execute(count_stmt).scalar()

def first(self):
with reuse_session_or_new(self.db_engine, self.session) as session:
return session.exec(self.init_stmt()).first()

def get_or_404(self, id):
if not (obj := self.get_by_id(id)):
raise HTTPException(
status_code=404,
detail=f'{self.model.__name__.title()} with id {id} not found'
)
return obj

def delete_or_404(self, id):
obj = self.get_or_404(id)
self.delete(obj)

def update_or_404(self, id, **kwargs):
if self.get_or_404(id):
self.update(id, **kwargs)

def _get_select_obj(self, fields=None):
return (
[self.model] if not fields
else [getattr(self.model, f) for f in fields]
)

def init_stmt(self, *fields):
if self._init_stmt is not None:
return self._init_stmt
else:
return select(*self._get_select_obj(fields))
11 changes: 11 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ def test_all():
assert total_count == 12
assert users[0].username == 'user9'

assert users_repo.first()
assert users_repo.count() == 12

# Paginate the results (order by username in ascending order)
users, total_count = (
users_repo.filter()
Expand All @@ -97,6 +100,14 @@ def test_all():
with Session(engine) as session:
assert users_repo(session).all()

users[0].username = '1'
users_repo.save_or_update(users[0])

users_repo.delete_all()

users = users_repo.all()
assert not users


if __name__ == '__main__':
test_all()

0 comments on commit 799ef1a

Please sign in to comment.