Skip to content

Commit

Permalink
Enahncement: API
Browse files Browse the repository at this point in the history
  • Loading branch information
pylover committed Aug 22, 2024
1 parent 67f4f49 commit 417d4cc
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 34 deletions.
14 changes: 5 additions & 9 deletions tests/test_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,32 +37,28 @@ class Foo(Base):

@app.route()
@json
@app.db
def get(req):
with app.db.session() as session:
result = session.scalars(select(Foo)).all()
return {f.id: f.title for f in result}

@app.route()
@json
@app.db
def got(req):
Foo(title='foo')
with app.db.session.begin():
app.db.session.add(Foo(title='foo'))
raise statuses.created()

@app.route()
@json
@app.db
def err(req):
Foo(title='qux')
with app.db.session() as session:
session.add(Foo(title='qux'))
raise statuses.badrequest()

def getfoo(title):
with app.db.session() as session:
result = session.scalars(
select(Foo).where(Foo.title == title)
).first()
return result
return session.query(Foo).filter_by(title=title).first()

with Given():
assert status == 200
Expand Down
4 changes: 4 additions & 0 deletions yhttp/ext/sqlalchemy/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,8 @@ def ready(app):
def shutdown(app):
app.db.disconnect()

@app.when
def endresponse(response):
app.db.session.reset()

app.db = db
29 changes: 4 additions & 25 deletions yhttp/ext/sqlalchemy/orm.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,14 @@
import functools

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, close_all_sessions, Session, \
scoped_session
from yhttp.core import HTTPStatus


class ORM:
def __init__(self, basemodel, url=None):
self.url = url
self.engine = None
self.basemodel = basemodel
self._session = scoped_session(sessionmaker())

def session(self):
return self._session()
self.session = scoped_session(sessionmaker())

def copy(self, url=None):
return ORM(self.basemodel, url=url or self.app.settings.db.url)
Expand All @@ -28,12 +22,12 @@ def connect(self, url=None):
assert u is not None

self.engine = create_engine(u, isolation_level='REPEATABLE READ')
self._session.configure(bind=self.engine)
self.session.configure(bind=self.engine)

def disconnect(self):
close_all_sessions()
self._session.expunge_all()
self._session.remove()
self.session.expunge_all()
self.session.remove()
self.engine.dispose()
self.engine = None

Expand All @@ -58,18 +52,3 @@ def connect(self, url=None):
)

return super().connect(url=url or self.app.settings.db.url)

def __call__(self, handler):
@functools.wraps(handler)
def outter(*a, **kw):
try:
return handler(*a, **kw)
except HTTPStatus as ex:
if ex.keepheaders:
return ex

raise
finally:
self._session.reset()

return outter

0 comments on commit 417d4cc

Please sign in to comment.