Skip to content

Commit

Permalink
Add optimized ToOne sqlalchemy aware field.
Browse files Browse the repository at this point in the history
This version of fields.ToOne() is able to read the local foreign key value
to serialize the reference of the remote object.
It can be useful under some circumstances to avoid performing additional db lookups.
  • Loading branch information
ticosax committed Jan 16, 2019
1 parent 1ae0208 commit a97d956
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 6 deletions.
29 changes: 28 additions & 1 deletion flask_potion/contrib/alchemy/fields.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from flask_potion.fields import Object
from werkzeug.utils import cached_property

from flask_potion.fields import Object, ToOne as GenericToOne
from flask_potion.utils import get_value, route_from


class InlineModel(Object):
Expand All @@ -15,3 +18,27 @@ def converter(self, instance):
if instance is not None:
instance = self.model(**instance)
return instance


class ToOne(GenericToOne):
"""
Same as flask_potion.fields.ToOne
except it will use the local id stored on the ForeignKey field to serialize the field.
This is an optimisation to avoid additional lookups to the database,
in order to prevent fetching the remote object, just to obtain its id,
that we already have.
Limitations:
- It works only if the foreign key is made of a single field.
- It works only if the serialization is using the ForeignKey as source of information to Identify the remote resource.
- `attribute` parameter is ignored.
"""
def output(self, key, obj):
column = getattr(obj.__class__, key)
local_columns = column.property.local_columns
assert len(local_columns) == 1
local_column = list(local_columns)[0]
key = local_column.key
return self.format(get_value(key, obj, self.default))

def formatter(self, item):
return self.formatter_key.format(item, is_local=True)
14 changes: 10 additions & 4 deletions flask_potion/natural_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@


class Key(Schema, ResourceBound):
is_local = False

def matcher_type(self):
type_ = self.response['type']
Expand Down Expand Up @@ -43,11 +44,16 @@ def schema(self):
"additionalProperties": False
}

def _id_uri(self, resource, id_):
return '{}/{}'.format(resource.route_prefix, id_)

def _item_uri(self, resource, item):
# return url_for('{}.instance'.format(self.resource.meta.id_attribute, item, None), get_value(self.resource.meta.id_attribute, item, None))
return '{}/{}'.format(resource.route_prefix, get_value(resource.manager.id_attribute, item, None))

def format(self, item):
def format(self, item, is_local=False):
if is_local:
return {'$ref': self._id_uri(self.resource, item)}
return {"$ref": self._item_uri(self.resource, item)}

def convert(self, value):
Expand All @@ -71,7 +77,7 @@ def rebind(self, resource):
def schema(self):
return self.resource.schema.fields[self.property].request

def format(self, item):
def format(self, item, is_local=False):
return self.resource.schema.fields[self.property].output(self.property, item)

@cached_property
Expand Down Expand Up @@ -101,7 +107,7 @@ def schema(self):
"additionalItems": False
}

def format(self, item):
def format(self, item, is_local=False):
return [self.resource.schema.fields[p].output(p, item) for p in self.properties]

@cached_property
Expand All @@ -123,7 +129,7 @@ def _on_bind(self, resource):
def schema(self):
return self.id_field.request

def format(self, item):
def format(self, item, is_local=False):
return self.id_field.output(self.resource.manager.id_attribute, item)

def convert(self, value):
Expand Down
53 changes: 52 additions & 1 deletion tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from pprint import pformat

from flask import json, Flask
from flask.testing import FlaskClient
from flask_testing import TestCase
import sqlalchemy


class ApiClient(FlaskClient):
Expand Down Expand Up @@ -49,4 +52,52 @@ def create_app(self):
return app

def pp(self, obj):
print(json.dumps(obj, sort_keys=True, indent=4, separators=(',', ': ')))
print(json.dumps(obj, sort_keys=True, indent=4, separators=(',', ': ')))


class DBQueryCounter:
"""
Use as a context manager to count the number of execute()'s performed
against the given sqlalchemy connection.
Usage:
with DBQueryCounter(db.session) as ctr:
db.session.execute("SELECT 1")
db.session.execute("SELECT 1")
ctr.assert_count(2)
"""

def __init__(self, session, reset=True):
self.session = session
self.reset = reset
self.statements = []

def __enter__(self):
if self.reset:
self.session.expire_all()
sqlalchemy.event.listen(
self.session.get_bind(), 'after_execute', self._callback
)
return self

def __exit__(self, *_):
sqlalchemy.event.remove(
self.session.get_bind(), 'after_execute', self._callback
)

def get_count(self):
return len(self.statements)

def _callback(self, conn, clause_element, multiparams, params, result):
self.statements.append((clause_element, multiparams, params))

def display_all(self):
for clause, multiparams, params in self.statements:
print(pformat(str(clause)), multiparams, params)
print('\n')
count = self.get_count()
return 'Counted: {count}'.format(count=count)

def assert_count(self, expected):
count = self.get_count()
assert count == expected, self.display_all()
62 changes: 62 additions & 0 deletions tests/contrib/alchemy/test_fields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from flask_sqlalchemy import SQLAlchemy

from flask_potion import Api, fields
from flask_potion.resource import ModelResource
from flask_potion.contrib.alchemy.fields import ToOne as SAToOne
from tests import BaseTestCase, DBQueryCounter


class SQLAlchemyToOneRemainNoPrefetchTestCase(BaseTestCase):
"""
"""

def setUp(self):
super(SQLAlchemyToOneRemainNoPrefetchTestCase, self).setUp()
self.app.config['SQLALCHEMY_ENGINE'] = 'sqlite://'
self.api = Api(self.app)
self.sa = sa = SQLAlchemy(
self.app, session_options={"autoflush": False})

class Type(sa.Model):
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String(60), nullable=False)

class Machine(sa.Model):
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String(60), nullable=False)

type_id = sa.Column(sa.Integer, sa.ForeignKey(Type.id))
type = sa.relationship(Type, foreign_keys=[type_id])

sa.create_all()

class MachineResource(ModelResource):
class Meta:
model = Machine

class Schema:
type = SAToOne('type')

class TypeResource(ModelResource):
class Meta:
model = Type

self.MachineResource = MachineResource
self.TypeResource = TypeResource

self.api.add_resource(MachineResource)
self.api.add_resource(TypeResource)

def test_relation_serialization_does_not_load_remote_object(self):
response = self.client.post('/type', data={"name": "aaa"})
aaa_uri = response.json["$uri"]
self.client.post(
'/machine', data={"name": "foo", "type": {"$ref": aaa_uri}})
with DBQueryCounter(self.sa.session) as counter:
response = self.client.get('/machine')
self.assert200(response)
self.assertJSONEqual(
[{'$uri': '/machine/1', 'type': {'$ref': aaa_uri}, 'name': 'foo'}],
response.json)
counter.assert_count(1)

0 comments on commit a97d956

Please sign in to comment.