diff --git a/.flake8 b/.flake8 index cdabf5d..d75f769 100644 --- a/.flake8 +++ b/.flake8 @@ -4,4 +4,4 @@ omit = dist per-file-ignores = yhttp/ext/sqlalchemy/__init__.py: F401 - tests/conftest.py: F401 + tests/conftest.py: F401, F811 diff --git a/tests/conftest.py b/tests/conftest.py index a2e1ac8..dff9178 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,8 +8,12 @@ @pytest.fixture -def app(): +def app(freshdb): app = Application() + app.settings.merge(f''' + db: + url: {freshdb} + ''') yield app app.shutdown() diff --git a/tests/test_extension.py b/tests/test_extension.py index fad8a46..5381dd9 100644 --- a/tests/test_extension.py +++ b/tests/test_extension.py @@ -8,12 +8,7 @@ from yhttp.ext import sqlalchemy as saext, dbmanager -def test_extension(app, Given, freshdb): - app.settings.merge(f''' - db: - url: {freshdb} - ''') - +def test_extension(Given, freshdb, app): class Base(DeclarativeBase): pass @@ -28,19 +23,24 @@ class Foo(Base): app.ready() app.db.create_objects() - def mockup(): - with app.db.session() as session: - foo = Foo(title='foo 1') - bar = Foo(title='foo 2') - session.add_all([foo, bar]) + with app.db.session() as session: + foo = Foo(title='foo 1') + session.add(foo) + session.commit() + session.reset() - mockup() + with app.db.copy(freshdb) as d, d.session() as session: + bar = Foo(title='foo 2') + session.add(bar) + session.commit() + session.reset() @app.route() @json @app.db def get(req): - result = req.dbsession.scalars(select(Foo)).all() + with app.db.session() as session: + result = session.scalars(select(Foo)).all() return {f.id: f.title for f in result} @app.route() @@ -79,8 +79,6 @@ def getfoo(title): foo = getfoo('foo 1') assert foo is not None - app.shutdown() - def test_exceptions(app, freshdb): class Base(DeclarativeBase): diff --git a/yhttp/ext/sqlalchemy/install.py b/yhttp/ext/sqlalchemy/install.py index 87a803e..ffd83a3 100644 --- a/yhttp/ext/sqlalchemy/install.py +++ b/yhttp/ext/sqlalchemy/install.py @@ -9,15 +9,15 @@ def install(app, basemodel, db=None, cliarguments=None): cli.DatabaseObjectsCommand.__arguments__.extend(cliarguments) if db is None: - db = orm.DatabaseManager(app, basemodel) + db = orm.ApplicationORM(basemodel, app) if db.engine is None: @app.when def ready(app): - app.db.__enter__() + app.db.connect() @app.when def shutdown(app): - app.db.__exit__(None, None, None) + app.db.disconnect() app.db = db diff --git a/yhttp/ext/sqlalchemy/orm.py b/yhttp/ext/sqlalchemy/orm.py index 2a662bf..a34feab 100644 --- a/yhttp/ext/sqlalchemy/orm.py +++ b/yhttp/ext/sqlalchemy/orm.py @@ -1,56 +1,70 @@ import functools from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker, close_all_sessions, Session +from sqlalchemy.orm import sessionmaker, close_all_sessions, Session, \ + scoped_session from yhttp.core import HTTPStatus -class DatabaseManager: - def __init__(self, app, basemodel): - self.app = app +class ORM: + def __init__(self, basemodel, url=None): + self.url = url self.engine = None - self.sessionfactory = sessionmaker() self.basemodel = basemodel + self.session = scoped_session(sessionmaker()) - def __enter__(self) -> sessionmaker: - if self.engine is None: - if 'db' not in self.app.settings: - raise ValueError( - 'Please provide db.url configuration entry, for example: ' - 'postgresql://:@/dbname' - ) - - self.engine = create_engine( - self.app.settings.db.url, - isolation_level='REPEATABLE READ' - ) + def copy(self, url=None): + return ORM(self.basemodel, url=url or self.app.settings.db.url) - self.sessionfactory.configure(bind=self.engine) - return self.sessionfactory + def create_objects(self): + return self.basemodel.metadata.create_all(self.engine) - def __exit__(self, exc_type, exc_value, traceback): + def connect(self, url=None): + u = url or self.url + assert self.engine is None + assert u is not None + + self.engine = create_engine(u, isolation_level='REPEATABLE READ') + self.session.configure(bind=self.engine) + + def disconnect(self): close_all_sessions() self.engine.dispose() + self.engine = None - def create_objects(self): - return self.basemodel.metadata.create_all(self.engine) + def __enter__(self) -> Session: + self.connect() + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.disconnect() - def session(self) -> Session: - return self.sessionfactory.begin() + +class ApplicationORM(ORM): + def __init__(self, basemodel, app): + self.app = app + super().__init__(basemodel) + + def connect(self, url=None): + if 'db' not in self.app.settings or 'url' not in self.app.settings.db: + raise ValueError( + 'Please provide db.url configuration entry, for example: ' + 'postgresql://:@/dbname' + ) + + return super().connect(url=url or self.app.settings.db.url) def __call__(self, handler): @functools.wraps(handler) - def outter(req, *a, **kw): - with self.session() as session: - req.dbsession = session - try: - return handler(req, *a, **kw) - except HTTPStatus as ex: - if ex.keepheaders: - return ex - - raise - finally: - del req.dbsession + def outter(*a, **kw): + try: + return handler(*a, **kw) + except HTTPStatus as ex: + if ex.keepheaders: + return ex + + raise + finally: + self.session.reset() return outter