From 417d4ccecae46d10196c449d95040e9cb25a90d6 Mon Sep 17 00:00:00 2001 From: pylover Date: Thu, 22 Aug 2024 04:52:58 +0330 Subject: [PATCH] Enahncement: API --- tests/test_extension.py | 14 +++++--------- yhttp/ext/sqlalchemy/install.py | 4 ++++ yhttp/ext/sqlalchemy/orm.py | 29 ++++------------------------- 3 files changed, 13 insertions(+), 34 deletions(-) diff --git a/tests/test_extension.py b/tests/test_extension.py index 5381dd9..719ad9d 100644 --- a/tests/test_extension.py +++ b/tests/test_extension.py @@ -37,7 +37,6 @@ class Foo(Base): @app.route() @json - @app.db def get(req): with app.db.session() as session: result = session.scalars(select(Foo)).all() @@ -45,24 +44,21 @@ def get(req): @app.route() @json - @app.db def got(req): - Foo(title='foo') + with app.db.session.begin(): + app.db.session.add(Foo(title='foo')) raise statuses.created() @app.route() @json - @app.db def err(req): - Foo(title='qux') + with app.db.session() as session: + session.add(Foo(title='qux')) raise statuses.badrequest() def getfoo(title): with app.db.session() as session: - result = session.scalars( - select(Foo).where(Foo.title == title) - ).first() - return result + return session.query(Foo).filter_by(title=title).first() with Given(): assert status == 200 diff --git a/yhttp/ext/sqlalchemy/install.py b/yhttp/ext/sqlalchemy/install.py index ffd83a3..fc82c39 100644 --- a/yhttp/ext/sqlalchemy/install.py +++ b/yhttp/ext/sqlalchemy/install.py @@ -20,4 +20,8 @@ def ready(app): def shutdown(app): app.db.disconnect() + @app.when + def endresponse(response): + app.db.session.reset() + app.db = db diff --git a/yhttp/ext/sqlalchemy/orm.py b/yhttp/ext/sqlalchemy/orm.py index f72ff6f..db0ea85 100644 --- a/yhttp/ext/sqlalchemy/orm.py +++ b/yhttp/ext/sqlalchemy/orm.py @@ -1,9 +1,6 @@ -import functools - from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker, close_all_sessions, Session, \ scoped_session -from yhttp.core import HTTPStatus class ORM: @@ -11,10 +8,7 @@ def __init__(self, basemodel, url=None): self.url = url self.engine = None self.basemodel = basemodel - self._session = scoped_session(sessionmaker()) - - def session(self): - return self._session() + self.session = scoped_session(sessionmaker()) def copy(self, url=None): return ORM(self.basemodel, url=url or self.app.settings.db.url) @@ -28,12 +22,12 @@ def connect(self, url=None): assert u is not None self.engine = create_engine(u, isolation_level='REPEATABLE READ') - self._session.configure(bind=self.engine) + self.session.configure(bind=self.engine) def disconnect(self): close_all_sessions() - self._session.expunge_all() - self._session.remove() + self.session.expunge_all() + self.session.remove() self.engine.dispose() self.engine = None @@ -58,18 +52,3 @@ def connect(self, url=None): ) return super().connect(url=url or self.app.settings.db.url) - - def __call__(self, handler): - @functools.wraps(handler) - def outter(*a, **kw): - try: - return handler(*a, **kw) - except HTTPStatus as ex: - if ex.keepheaders: - return ex - - raise - finally: - self._session.reset() - - return outter