diff --git a/requirements-ci.txt b/requirements-ci.txt index 23f05ee..1c6cbfd 100644 --- a/requirements-ci.txt +++ b/requirements-ci.txt @@ -5,4 +5,4 @@ coveralls pytest >= 4.4.0 pytest-cov flake8 -yhttp-dev >= 3.1.1 +yhttp-dev >= 3.1.2 diff --git a/setup.py b/setup.py index 4a7ccb5..0619d8b 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ dependencies = [ - 'yhttp >= 5, < 6', + 'yhttp >= 5.0.2, < 6', 'yhttp-dbmanager >= 4, < 5', 'sqlalchemy >= 2', ] diff --git a/tests/test_extension.py b/tests/test_extension.py index bc78e84..93a6d66 100644 --- a/tests/test_extension.py +++ b/tests/test_extension.py @@ -5,7 +5,7 @@ from yhttp.core import json, statuses -from yhttp.ext.sqlalchemy import install +from yhttp.ext import sqlalchemy as saext, dbmanager def test_extension(app, Given, freshdb): @@ -23,11 +23,13 @@ class Foo(Base): id: Mapped[int] = mapped_column(primary_key=True) title: Mapped[str] = mapped_column(String(30)) - dbsession = install(app, Base, create_objects=True) + dbmanager.install(app) + saext.install(app, Base) app.ready() + app.db.create_objects() def mockup(): - with app.db.sessionfactory() as session, session.begin(): + with app.db.begin() as session: foo = Foo(title='foo 1') bar = Foo(title='foo 2') session.add_all([foo, bar]) @@ -36,27 +38,27 @@ def mockup(): @app.route() @json - @dbsession + @app.db.session def get(req): - result = req.session.scalars(select(Foo)).all() + result = req.dbsession.scalars(select(Foo)).all() return {f.id: f.title for f in result} @app.route() @json - @dbsession + @app.db.session def got(req): Foo(title='foo') raise statuses.created() @app.route() @json - @dbsession + @app.db.session def err(req): Foo(title='qux') raise statuses.badrequest() def getfoo(title): - with app.db.sessionfactory() as session, session.begin(): + with app.db.begin() as session: result = session.scalars( select(Foo).where(Foo.title == title) ).first() @@ -84,7 +86,7 @@ def test_exceptions(app, freshdb): class Base(DeclarativeBase): pass - install(app, Base) + saext.install(app, Base) if 'db' in app.settings: del app.settings['db'] diff --git a/yhttp/ext/sqlalchemy/orm.py b/yhttp/ext/sqlalchemy/orm.py index f615de5..487c9da 100644 --- a/yhttp/ext/sqlalchemy/orm.py +++ b/yhttp/ext/sqlalchemy/orm.py @@ -1,5 +1,8 @@ +import functools + from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker +from yhttp.core import HTTPStatus class Manager: @@ -33,23 +36,17 @@ def deinitialize(self): def create_objects(self): return self.basemodel.metadata.create_all(self.engine) + def begin(self): + return self.sessionfactory.begin() + def session(self, handler): @functools.wraps(handler) def outter(req, *a, **kw): - app = req.application - try: - req.dbsession = app.db.sessionfactory() - except AttributeError: - print( - 'Please install yhttp-sqlalchemy extention first.', - file=sys.stderr - ) - raise - - with req.dbsession as session, session.begin(): + with self.begin() as session: + req.dbsession = session try: - return func(req, *a, **kw) - except y.HTTPStatus as ex: + return handler(req, *a, **kw) + except HTTPStatus as ex: if ex.keepheaders: return ex