diff --git a/pyproject.toml b/pyproject.toml index ff440f1..cab9be5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "sqlmodel_repo" -version = "0.0.1" +version = "0.0.2" description = "Active record mixin for SQLModel" readme = "README.md" keywords = ["sqlmodel", "sqlalchemy", "orm", "active record"] diff --git a/sqlmodel_repo.py b/sqlmodel_repo.py index 562096b..c38bf00 100644 --- a/sqlmodel_repo.py +++ b/sqlmodel_repo.py @@ -1,7 +1,12 @@ from contextlib import contextmanager from typing import Optional -from sqlmodel import Session, SQLModel, select, func, text +from sqlmodel import Session, SQLModel, select, delete, update, func + +try: + from fastapi.exceptions import HTTPException +except Exception: + pass @contextmanager @@ -32,71 +37,11 @@ def reuse_session_or_new(db_engine=None, session: Optional[Session] = None): session.close() -class CollectionResult: - def __init__(self, stmt, model, db_engine, session=None): - self.stmt = stmt - self.model = model - self.db_engine = db_engine - self.session = session - - def paginate( - self, - offset: int, - limit: int, - order_by: str, - desc: bool = False - ) -> list: - """Paginate results""" - with reuse_session_or_new(self.db_engine, self.session) as session: - return self._paginate(session, offset, limit, order_by, desc) - - def paginate_with_total( - self, - offset: int, - limit: int, - order_by: str, - desc: bool = False - ) -> (list, int): - """Paginate results and fetch total count - - Returns: - tuple(list, int) - Items and total count. - """ - with reuse_session_or_new(self.db_engine, self.session) as session: - count_stmt = select(func.count()).select_from(self.stmt.subquery()) - count = session.execute(count_stmt).scalar() - results = self._paginate(session, offset, limit, order_by, desc) - return results, count - - def _paginate( - self, - session, - offset: int, - limit: int, - order_by: str, - desc: bool = False - ) -> list: - order_by = getattr(self.model, order_by) - if desc: - order_by = getattr(order_by, 'desc')() - return session.exec( - self.stmt.order_by(order_by).offset(offset).limit(limit) - ).all() - - def all(self) -> list: - """Get all results""" - with reuse_session_or_new(self.db_engine, self.session) as session: - return session.exec(self.stmt).all() - - def count(self): - """Get total results count""" - with reuse_session_or_new(self.db_engine, self.session) as session: - count_stmt = select(func.count()).select_from(self.stmt.subquery()) - return session.execute(count_stmt).scalar() - - class SQLModelRepo: - def __init__(self, model: SQLModel, db_engine): + def __init__( + self, model: SQLModel, db_engine, + init_stmt=None, session=None + ): """ Generic repository for SQLModel. Args: @@ -109,18 +54,19 @@ def __init__(self, model: SQLModel, db_engine): users_repo.get_by_id(1) """ self.model = model + self._init_stmt = init_stmt self.db_engine = db_engine - self._session = None + self.session = session def __call__(self, session): new_repo = SQLModelRepo(model=self.model, db_engine=self.db_engine) - new_repo._session = session + new_repo.session = session return new_repo def create(self, **kwargs): """Create a new record and save to the database.""" instance = self.model(**kwargs) - with reuse_session_or_new(self.db_engine, self._session) as session: + with reuse_session_or_new(self.db_engine, self.session) as session: session.add(instance) session.commit() session.refresh(instance) @@ -128,52 +74,81 @@ def create(self, **kwargs): def get_by_id(self, id, *fields): """Fetch an object by its primary key.""" - select_obj = self._get_select_obj(fields) - with reuse_session_or_new(self.db_engine, self._session) as session: + stmt = self.init_stmt(*fields) + with reuse_session_or_new(self.db_engine, self.session) as session: return session.exec( - select(*select_obj).where( + stmt.where( getattr(self.model, 'id') == id ) ).first() def save(self, instance): """Save the current object (instance) to the database.""" - with reuse_session_or_new(self.db_engine, self._session) as session: + with reuse_session_or_new(self.db_engine, self.session) as session: session.add(instance) session.commit() session.refresh(instance) + def save_or_update(self, instance): + """Save the current object (instance) to the database.""" + with reuse_session_or_new(self.db_engine, self.session) as session: + existing_obj = session.exec( + select(self.model).where( + self.model.id == instance.id + ) + ).first() + if existing_obj: + for k, v in instance.model_dump().items(): + setattr(existing_obj, k, v) + session.add(existing_obj) + instance = existing_obj + else: + session.add(instance) + session.commit() + session.refresh(instance) + def update(self, id, **kwargs): """Record partial update.""" - with reuse_session_or_new(self.db_engine, self._session) as session: - set_statements = ", ".join( - f"{field} = :{field}" for field in kwargs.keys() - ) - kwargs['id'] = id - query = f""" - UPDATE {self.model.__tablename__} - SET {set_statements} - WHERE id = :id - """ - session.execute(text(query), kwargs) + with reuse_session_or_new(self.db_engine, self.session) as session: + update_stmt = update(self.model).where( + self.model.id == id + ).values(**kwargs) + session.execute(update_stmt) + session.commit() + + def update_all(self, **kwargs): + """Partial update for all selected records.""" + with reuse_session_or_new(self.db_engine, self.session) as session: + if self._init_stmt: + update_stmt = update(self.model).where( + self.init_stmt().whereclause + ).values(**kwargs) + else: + update_stmt = update(self.model).values(**kwargs) + session.execute(update_stmt) session.commit() def delete(self, instance): """Delete an object from the database.""" - with reuse_session_or_new(self.db_engine, self._session) as session: + with reuse_session_or_new(self.db_engine, self.session) as session: session.delete(instance) session.commit() - def all(self, *fields) -> list: - """Return all records.""" - select_obj = self._get_select_obj(fields) - with reuse_session_or_new(self.db_engine, self._session) as session: - return session.exec(select(*select_obj)).all() + def delete_all(self): + """Delete all records in query.""" + with reuse_session_or_new(self.db_engine, self.session) as session: + if self._init_stmt: + delete_stmt = delete(self.model).where( + self.init_stmt().whereclause + ) + else: + delete_stmt = delete(self.model) + session.execute(delete_stmt) + session.commit() - def filter(self, *filters, _fields=(), **kwargs) -> CollectionResult: + def filter(self, *filters, _fields=(), **kwargs) -> 'SQLModelRepo': """Filter records based on provided conditions.""" - select_obj = self._get_select_obj(_fields) - stmt = select(*select_obj).where( + stmt = self.init_stmt(*_fields).where( *filters, *[ getattr(self.model, k) == v @@ -181,15 +156,100 @@ def filter(self, *filters, _fields=(), **kwargs) -> CollectionResult: for k, v in kwargs.items() ] ) - return CollectionResult( - stmt=stmt, + return SQLModelRepo( + init_stmt=stmt, model=self.model, db_engine=self.db_engine, - session=self._session + session=self.session ) + def paginate( + self, + offset: int, + limit: int, + order_by: str, + desc: bool = False + ) -> list: + """Paginate results""" + with reuse_session_or_new(self.db_engine, self.session) as session: + return self._paginate(session, offset, limit, order_by, desc) + + def paginate_with_total( + self, + offset: int, + limit: int, + order_by: str, + desc: bool = False + ) -> (list, int): + """Paginate results and fetch total count + + Returns: + tuple(list, int) - Items and total count. + """ + with reuse_session_or_new(self.db_engine, self.session) as session: + count_stmt = select(func.count()).select_from( + self.init_stmt().subquery() + ) + count = session.execute(count_stmt).scalar() + results = self._paginate(session, offset, limit, order_by, desc) + return results, count + + def _paginate( + self, + session, + offset: int, + limit: int, + order_by: str, + desc: bool = False + ) -> list: + order_by = getattr(self.model, order_by) + if desc: + order_by = getattr(order_by, 'desc')() + return session.exec( + self.init_stmt().order_by(order_by).offset(offset).limit(limit) + ).all() + + def all(self) -> list: + """Get all results""" + with reuse_session_or_new(self.db_engine, self.session) as session: + return session.exec(self.init_stmt()).all() + + def count(self): + """Get total results count""" + with reuse_session_or_new(self.db_engine, self.session) as session: + count_stmt = select(func.count()).select_from( + self.init_stmt().subquery() + ) + return session.execute(count_stmt).scalar() + + def first(self): + with reuse_session_or_new(self.db_engine, self.session) as session: + return session.exec(self.init_stmt()).first() + + def get_or_404(self, id): + if not (obj := self.get_by_id(id)): + raise HTTPException( + status_code=404, + detail=f'{self.model.__name__.title()} with id {id} not found' + ) + return obj + + def delete_or_404(self, id): + obj = self.get_or_404(id) + self.delete(obj) + + def update_or_404(self, id, **kwargs): + if self.get_or_404(id): + self.update(id, **kwargs) + def _get_select_obj(self, fields=None): return ( [self.model] if not fields else [getattr(self.model, f) for f in fields] ) + + def init_stmt(self, *fields): + if self._init_stmt is not None: + return self._init_stmt + else: + return select(*self._get_select_obj(fields)) diff --git a/test.py b/test.py index 5028c5f..7644a02 100644 --- a/test.py +++ b/test.py @@ -87,6 +87,9 @@ def test_all(): assert total_count == 12 assert users[0].username == 'user9' + assert users_repo.first() + assert users_repo.count() == 12 + # Paginate the results (order by username in ascending order) users, total_count = ( users_repo.filter() @@ -97,6 +100,14 @@ def test_all(): with Session(engine) as session: assert users_repo(session).all() + users[0].username = '1' + users_repo.save_or_update(users[0]) + + users_repo.delete_all() + + users = users_repo.all() + assert not users + if __name__ == '__main__': test_all()