Skip to content

Commit

Permalink
fix: properly serialize unions when types are similar. Closes #749
Browse files Browse the repository at this point in the history
  • Loading branch information
marcosschroh committed Sep 24, 2024
1 parent 879fe7e commit a75a4f3
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 45 deletions.
5 changes: 4 additions & 1 deletion dataclasses_avroschema/faust/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ def standardize_type(self) -> typing.Any:
user-defined pydantic json_encoders prior to passing values
to the standard type conversion factory
"""
return standardize_custom_type(self)
return {
field_name: standardize_custom_type(field_name=field_name, value=value, model=self, base_class=AvroRecord)
for field_name, value in self.asdict().items()
}

def serialize(self, serialization_type: str = AVRO) -> bytes:
"""
Expand Down
27 changes: 8 additions & 19 deletions dataclasses_avroschema/pydantic/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,32 +27,21 @@ def generate_dataclass(cls: Type[CT]) -> Type[CT]:
def json_schema(cls: Type[CT], *args: Any, **kwargs: Any) -> str:
return json.dumps(cls.model_json_schema(*args, **kwargs))

@classmethod
def standardize_type(cls: Type[CT], data: dict) -> Any:
"""
Standardization factory that converts data according to the
user-defined pydantic json_encoders prior to passing values
to the standard type conversion factory
"""
for value in data.values():
if isinstance(value, dict):
cls.standardize_type(value)

return standardize_custom_type(data)

def asdict(self, standardize_factory: Optional[Callable[..., Any]] = None) -> JsonDict:
"""
Returns this model in dictionary form. This method differs from
pydantic's dict by converting all values to their Avro representation.
It also doesn't provide the exclude, include, by_alias, etc.
parameters that dict provides.
"""
data = self.model_dump()
standardize_method = standardize_factory or self.standardize_type
standardize_method = standardize_factory or standardize_custom_type

# the standardize call can be replaced if we have a custom implementation of asdict;
# for now I think it is better to use the native implementation
return standardize_method(data)
return {
field_name: standardize_method(
field_name=field_name, value=getattr(self, field_name), model=self, base_class=AvroBaseModel
)
for field_name, field_info in self.model_fields.items()
}

@classmethod
def parse_obj(cls: Type[CT], data: Dict) -> CT:
Expand All @@ -66,7 +55,7 @@ def serialize(self, serialization_type: str = AVRO) -> bytes:
schema = self.avro_schema_to_python()

return serialization.serialize(
self.asdict(standardize_factory=self.standardize_type),
self.asdict(),
schema,
serialization_type=serialization_type,
)
Expand Down
27 changes: 12 additions & 15 deletions dataclasses_avroschema/pydantic/v1/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, Optional, Type, TypeVar
from typing import Any, Callable, Dict, Optional, Type, TypeVar

from fastavro.validation import validate

Expand Down Expand Up @@ -26,23 +26,21 @@ def generate_dataclass(cls: Type[CT]) -> Type[CT]:
def json_schema(cls: Type[CT], *args: Any, **kwargs: Any) -> str:
return cls.schema_json(*args, **kwargs)

@classmethod
def standardize_type(cls: Type[CT], data: dict) -> Any:
def _standardize_type(self) -> Dict[str, Any]:
"""
Standardization factory that converts data according to the
user-defined pydantic json_encoders prior to passing values
to the standard type conversion factory
"""
encoders = cls.__config__.json_encoders
encoders = self.__config__.json_encoders
data = dict(self)

for k, v in data.items():
v_type = type(v)
if v_type in encoders:
encode_method = encoders[v_type]
data[k] = encode_method(v)
elif isinstance(v, dict):
cls.standardize_type(v)

return standardize_custom_type(data)
return data

def asdict(self, standardize_factory: Optional[Callable[..., Any]] = None) -> JsonDict:
"""
Expand All @@ -51,13 +49,12 @@ def asdict(self, standardize_factory: Optional[Callable[..., Any]] = None) -> Js
It also doesn't provide the exclude, include, by_alias, etc.
parameters that dict provides.
"""
data = dict(self)

standardize_method = standardize_factory or self.standardize_type
standardize_method = standardize_factory or standardize_custom_type

# the standardize call can be replaced if we have a custom implementation of asdict;
# for now I think it is better to use the native implementation
return standardize_method(data)
return {
field_name: standardize_method(field_name=field_name, value=value, model=self, base_class=AvroBaseModel)
for field_name, value in self._standardize_type().items()
}

def serialize(self, serialization_type: str = AVRO) -> bytes:
"""
Expand All @@ -67,7 +64,7 @@ def serialize(self, serialization_type: str = AVRO) -> bytes:
schema = self.avro_schema_to_python()

return serialization.serialize(
self.asdict(standardize_factory=self.standardize_type),
self.asdict(),
schema,
serialization_type=serialization_type,
)
Expand Down
10 changes: 6 additions & 4 deletions dataclasses_avroschema/schema_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,12 @@ def _reset_parser(cls: "Type[CT]") -> None:

def asdict(self, standardize_factory: Optional[Callable[..., Any]] = None) -> JsonDict:
if standardize_factory is not None:
return dataclasses.asdict(
self,
dict_factory=lambda x: {key: standardize_factory(value) for key, value in x},
) # type: ignore
return {
field.name: standardize_factory(
field_name=field.name, value=getattr(self, field.name), model=self, base_class=AvroModel
)
for field in dataclasses.fields(self) # type: ignore
}
return dataclasses.asdict(self) # type: ignore

def serialize(self, serialization_type: str = AVRO) -> bytes:
Expand Down
28 changes: 22 additions & 6 deletions dataclasses_avroschema/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,17 +83,33 @@ def rebuild_annotation(a_type: typing.Any, field_info: FieldInfo) -> typing.Type
return Annotated[a_type, field_info] # type: ignore[return-value]


def standardize_custom_type(value: typing.Any) -> typing.Any:
def standardize_custom_type(*, field_name: str, value: typing.Any, model, base_class) -> typing.Any:
if isinstance(value, dict):
return {k: standardize_custom_type(v) for k, v in value.items()}
return {
k: standardize_custom_type(field_name=field_name, value=v, model=model, base_class=base_class)
for k, v in value.items()
}
elif isinstance(value, list):
return [standardize_custom_type(v) for v in value]
return [
standardize_custom_type(field_name=field_name, value=v, model=model, base_class=base_class) for v in value
]
elif isinstance(value, tuple):
return tuple(standardize_custom_type(v) for v in value)
return tuple(
standardize_custom_type(field_name=field_name, value=v, model=model, base_class=base_class) for v in value
)
elif isinstance(value, enum.Enum):
return value.value
elif is_pydantic_model(type(value)) or is_faust_record(type(value)): # type: ignore[arg-type]
return standardize_custom_type(value.asdict())
elif isinstance(value, base_class):
if is_faust_record(type(value)): # type: ignore[arg-type]
# we need a workaround here because we cannot override asdict from faust;
# once the function interface is introduced we can remove this check
asdict = value.standardize_type()
else:
asdict = value.asdict(standardize_factory=standardize_custom_type)

if is_union(model.__annotations__[field_name]):
asdict["-type"] = value.__class__.__name__
return asdict

return value

Expand Down
32 changes: 32 additions & 0 deletions tests/serialization/test_nested_schema_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,3 +307,35 @@ class User(model_class):
user = User(name="Alex", friends=[Friend(name="Mr. Robot", hobbies=["fishing", "codding"])])

assert User.deserialize(user.serialize()) == user


@parametrize_base_model
def test_union_with_multiple_records(model_class: typing.Type[AvroModel], decorator: typing.Callable):
    """Round-trip a union field whose member records have the same shape.

    ``EventOne`` and ``EventTwo`` declare identical fields and differ only in
    the value of their ``tag`` literal. For each of them the serialized bytes
    are compared against a fixed payload, and deserializing those bytes must
    give back an ``EventManager`` equal to the original instance.
    """

    @decorator
    class EventOne(model_class):
        name: str
        tag: typing.Literal["EventOne"] = "EventOne"

    @decorator
    class EventTwo(model_class):
        name: str
        tag: typing.Literal["EventTwo"] = "EventTwo"

    @decorator
    class EventManager(model_class):
        event: typing.Union[EventOne, EventTwo]
        capacity: int = 100

    # Case 1: the first member of the union, capacity left at its default.
    manager_one = EventManager(event=EventOne(name="hello Event one"))
    payload_one = manager_one.serialize()

    assert payload_one == b"\x00\x1ehello Event one\x10EventOne\xc8\x01"
    assert EventManager.deserialize(payload_one) == manager_one

    # Case 2: the second member of the union, with an explicit capacity.
    manager_two = EventManager(event=EventTwo(name="hello Event two"), capacity=150)
    payload_two = manager_two.serialize()

    assert payload_two == b"\x02\x1ehello Event two\x10EventTwo\xac\x02"
    assert EventManager.deserialize(payload_two) == manager_two

0 comments on commit a75a4f3

Please sign in to comment.