Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow use of Schema hooks on OneOfSchema #130

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 84 additions & 71 deletions marshmallow_oneofschema/one_of_schema.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,28 @@
from marshmallow import Schema, ValidationError
from collections.abc import Mapping
import inspect

from marshmallow import Schema, ValidationError, RAISE


# these helpers copied from marshmallow.utils #


def is_generator(obj) -> bool:
"""Return True if ``obj`` is a generator"""
return inspect.isgeneratorfunction(obj) or inspect.isgenerator(obj)


def is_iterable_but_not_string(obj) -> bool:
"""Return True if ``obj`` is an iterable object that isn't a string."""
return (hasattr(obj, "__iter__") and not hasattr(obj, "strip")) or is_generator(obj)


def is_collection(obj) -> bool:
"""Return True if ``obj`` is a collection type, e.g list, tuple, queryset."""
return is_iterable_but_not_string(obj) and not isinstance(obj, Mapping)


# end of helpers copied from marshmallow.utils #


class OneOfSchema(Schema):
Expand Down Expand Up @@ -63,32 +87,16 @@ def get_obj_type(self, obj):
"""Returns name of object schema"""
return obj.__class__.__name__

def dump(self, obj, *, many=None, **kwargs):
errors = {}
result_data = []
result_errors = {}
many = self.many if many is None else bool(many)
if not many:
result = result_data = self._dump(obj, **kwargs)
else:
for idx, o in enumerate(obj):
try:
result = self._dump(o, **kwargs)
result_data.append(result)
except ValidationError as error:
result_errors[idx] = error.normalized_messages()
result_data.append(error.valid_data)

result = result_data
errors = result_errors

if not errors:
return result
else:
exc = ValidationError(errors, data=obj, valid_data=result)
raise exc

def _dump(self, obj, *, update_fields=True, **kwargs):
# override the `_serialize` method of Schema, rather than `dump`
# this requires that we interact with a private API of marshmallow, but
# `_serialize` is the step that happens between pre_dump and post_dump
# hooks, so by using this rather than `load()`, we get schema hooks to work
def _serialize(self, obj, *, many=False):
if many and obj is not None:
return [self._serialize(subdoc, many=False) for subdoc in obj]
return self._dump_type_schema(obj)

def _dump_type_schema(self, obj):
obj_type = self.get_obj_type(obj)
if not obj_type:
return (
Expand All @@ -104,46 +112,58 @@ def _dump(self, obj, *, update_fields=True, **kwargs):

schema.context.update(getattr(self, "context", {}))

result = schema.dump(obj, many=False, **kwargs)
result = schema.dump(obj, many=False)
if result is not None:
result[self.type_field] = obj_type
return result

def load(self, data, *, many=None, partial=None, unknown=None, **kwargs):
errors = {}
result_data = []
result_errors = {}
many = self.many if many is None else bool(many)
if partial is None:
partial = self.partial
if not many:
try:
result = result_data = self._load(
data, partial=partial, unknown=unknown, **kwargs
)
# result_data.append(result)
except ValidationError as error:
result_errors = error.normalized_messages()
result_data.append(error.valid_data)
else:
for idx, item in enumerate(data):
try:
result = self._load(item, partial=partial, **kwargs)
result_data.append(result)
except ValidationError as error:
result_errors[idx] = error.normalized_messages()
result_data.append(error.valid_data)

result = result_data
errors = result_errors

if not errors:
return result
else:
exc = ValidationError(errors, data=data, valid_data=result)
raise exc

def _load(self, data, *, partial=None, unknown=None, **kwargs):
# override the `_deserialize` method of Schema, rather than `load`
# this requires that we interact with a private API of marshmallow, but
# `_deserialize` is the step that happens between pre_load and validation
# hooks, so by using this rather than `load()`, we get schema hooks to work
def _deserialize(
self,
data,
*,
error_store,
many=False,
partial=False,
unknown=RAISE,
index=None,
):
index = index if self.opts.index_errors else None
# if many, check for non-collection data (error) or iterate and
# re-invoke `_deserialize` on each one with many=False
# this is paraphrased from marshmallow.Schema._deserialize
if many:
if not is_collection(data):
error_store.store_error([self.error_messages["type"]], index=index)
return []
else:
return [
self._deserialize(
subdoc,
error_store=error_store,
many=False,
partial=partial,
unknown=unknown,
index=idx,
)
for idx, subdoc in enumerate(data)
]
if not isinstance(data, Mapping):
error_store.store_error([self.error_messages["type"]], index=index)
return self.dict_class()

try:
result = self._load_type_schema(data, partial=partial, unknown=unknown)
except ValidationError as err:
error_store.store_error(err.messages, index=index)
result = err.valid_data

return result

def _load_type_schema(self, data, *, partial=None, unknown=None):
if not isinstance(data, dict):
raise ValidationError({"_schema": "Invalid data type: %s" % data})

Expand Down Expand Up @@ -173,11 +193,4 @@ def _load(self, data, *, partial=None, unknown=None, **kwargs):

schema.context.update(getattr(self, "context", {}))

return schema.load(data, many=False, partial=partial, unknown=unknown, **kwargs)

def validate(self, data, *, many=None, partial=None):
try:
self.load(data, many=many, partial=partial)
except ValidationError as ve:
return ve.messages
return {}
return schema.load(data, many=False, partial=partial, unknown=unknown)
23 changes: 23 additions & 0 deletions tests/test_one_of_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,29 @@ class TestSchema(OneOfSchema):
TestSchema(unknown="exclude").load({"type": "Bar", "bar": 123})
assert Nonlocal.data["type"] == "Bar"

def test_post_dump_remove_type_field(self):
# test using a @post_dump hook to remove the type field which
# OneOfSchema will add to the data by default

# define a schema without post_dump
class MySchemaVariant1(OneOfSchema):
type_schemas = {"Foo": FooSchema, "Bar": BarSchema}

# and a variant with post_dump
class MySchemaVariant2(MySchemaVariant1):
@m.post_dump
def remove_type_field(self, data, **kwargs):
del data["type"]
return data

# sanity check: `type` should be present in a dump from Variant1
assert MySchemaVariant1().dump(Foo("someval")) == {
"type": "Foo",
"value": "someval",
}
# now check that the post_dump hook fired
assert MySchemaVariant2().dump(Foo("someval")) == {"value": "someval"}

def test_load_non_dict(self):
with pytest.raises(m.ValidationError) as exc_info:
MySchema().load(123)
Expand Down