Skip to content

Commit

Permalink
fix: serialize unions properly when types are similar. Closes #749
Browse files Browse the repository at this point in the history
  • Loading branch information
marcosschroh committed Sep 27, 2024
1 parent 31e3724 commit 5a0a2b9
Show file tree
Hide file tree
Showing 13 changed files with 251 additions and 229 deletions.
11 changes: 2 additions & 9 deletions dataclasses_avroschema/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,5 @@
from .schema_generator import AvroModel # noqa: I001
from .types import (
Int32,
Float32,
TimeMicro,
DateTimeMicro,
condecimal,
confixed
)
from .types import Int32, Float32, TimeMicro, DateTimeMicro, condecimal, confixed
from .model_generator.generator import BaseClassEnum, ModelType, ModelGenerator
from .fields.field_utils import (
BOOLEAN,
Expand Down Expand Up @@ -132,4 +125,4 @@
"DecimalField",
"RecordField",
"AvroField",
]
]
29 changes: 9 additions & 20 deletions dataclasses_avroschema/faust/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,18 @@ def validate_avro(self) -> bool:
schema = self.avro_schema_to_python()
return validate(self.asdict(), schema)

def standardize_type(self) -> typing.Any:
def standardize_type(self, include_type: bool = True) -> typing.Any:
"""
Standardization factory that converts data according to the
user-defined pydantic json_encoders prior to passing values
to the standard type conversion factory
"""
return standardize_custom_type(self)
return {
field_name: standardize_custom_type(
field_name=field_name, value=value, model=self, base_class=AvroRecord, include_type=include_type
)
for field_name, value in self.asdict().items()
}

def serialize(self, serialization_type: str = AVRO) -> bytes:
"""
Expand All @@ -48,25 +53,9 @@ def serialize(self, serialization_type: str = AVRO) -> bytes:
serialization_type=serialization_type,
)

@classmethod
def deserialize(
cls: typing.Type[CT],
data: bytes,
serialization_type: str = AVRO,
create_instance: bool = True,
writer_schema: typing.Optional[typing.Union[JsonDict, typing.Type[CT]]] = None,
) -> typing.Union[JsonDict, CT]:
payload = cls.deserialize_to_python(data, serialization_type, writer_schema)
obj = cls.parse_obj(payload)

if not create_instance:
return obj.standardize_type()
return obj

def to_dict(self) -> JsonDict:
return self.standardize_type()
return self.standardize_type(include_type=False)

@classmethod
def _generate_parser(cls: typing.Type[CT]) -> FaustParser:
cls._metadata = cls.generate_metadata()
return FaustParser(type=cls._klass, metadata=cls._metadata, parent=cls._parent or cls)
return FaustParser(type=cls._klass, metadata=cls.get_metadata(), parent=cls._parent or cls)
2 changes: 1 addition & 1 deletion dataclasses_avroschema/fields/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def fake(self) -> typing.Any:

def exist_type(self) -> int:
# filter by the same field types
same_types = [field.type for field in self.parent._user_defined_types if field.type == self.type]
same_types = [field.model for field in self.parent._user_defined_types if field.model == self.type]

# If length > 0, means that it is the first appearance
# of this type, otherwise exist already.
Expand Down
4 changes: 2 additions & 2 deletions dataclasses_avroschema/fields/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ def get_avro_type(self) -> typing.Union[str, types.JsonDict]:
name = metadata.pop("schema_name", self.type.__name__)

if not self.exist_type():
user_defined_type = utils.UserDefinedType(name=name, type=self.type)
user_defined_type = utils.UserDefinedType(name=name, model=self.type)
self.parent._user_defined_types.add(user_defined_type)
return {
"type": field_utils.ENUM,
Expand Down Expand Up @@ -798,7 +798,7 @@ def get_avro_type(self) -> typing.Union[str, typing.List, typing.Dict]:
name = alias or metadata.schema_name or self.type.__name__

if not self.exist_type() or alias is not None:
user_defined_type = utils.UserDefinedType(name=name, type=self.type)
user_defined_type = utils.UserDefinedType(name=name, model=self.type)
self.parent._user_defined_types.add(user_defined_type)

record_type = self.type.avro_schema_to_python(parent=self.parent)
Expand Down
39 changes: 22 additions & 17 deletions dataclasses_avroschema/pydantic/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Any, Callable, Dict, Optional, Type, TypeVar
from typing import Any, Dict, Type, TypeVar

from fastavro.validation import validate

Expand Down Expand Up @@ -27,37 +27,43 @@ def generate_dataclass(cls: Type[CT]) -> Type[CT]:
def json_schema(cls: Type[CT], *args: Any, **kwargs: Any) -> str:
return json.dumps(cls.model_json_schema(*args, **kwargs))

@classmethod
def standardize_type(cls: Type[CT], data: dict) -> Any:
def _standardize_type(self) -> Dict[str, Any]:
"""
Standardization factory that converts data according to the
user-defined pydantic json_encoders prior to passing values
to the standard type conversion factory
"""
for value in data.values():
if isinstance(value, dict):
cls.standardize_type(value)
encoders = self.model_config.get("json_encoders") or {}
data = dict(self)

return standardize_custom_type(data)
for k, v in data.items():
v_type = type(v)
if v_type in encoders:
encode_method = encoders[v_type]
data[k] = encode_method(v)
return data

def asdict(self, standardize_factory: Optional[Callable[..., Any]] = None) -> JsonDict:
def asdict(self) -> JsonDict:
"""
Returns this model in dictionary form. This method differs from
pydantic's dict by converting all values to their Avro representation.
It also doesn't provide the exclude, include, by_alias, etc.
parameters that dict provides.
"""
data = self.model_dump()
standardize_method = standardize_factory or self.standardize_type

# the standardize called can be replaced if we have a custom implementation of asdict
# for now I think is better to use the native implementation
return standardize_method(data)
return {
field_name: standardize_custom_type(
field_name=field_name, value=field_value, model=self, base_class=AvroBaseModel
)
for field_name, field_value in self._standardize_type().items()
}

@classmethod
def parse_obj(cls: Type[CT], data: Dict) -> CT:
return cls.model_validate(obj=data)

def to_dict(self) -> JsonDict:
return self.model_dump()

def serialize(self, serialization_type: str = AVRO) -> bytes:
"""
Overrides the base AvroModel's serialize method to inject this
Expand All @@ -66,7 +72,7 @@ def serialize(self, serialization_type: str = AVRO) -> bytes:
schema = self.avro_schema_to_python()

return serialization.serialize(
self.asdict(standardize_factory=self.standardize_type),
self.asdict(),
schema,
serialization_type=serialization_type,
)
Expand Down Expand Up @@ -94,5 +100,4 @@ def fake(cls: Type[CT], **data: Any) -> CT:

@classmethod
def _generate_parser(cls: Type[CT]) -> PydanticParser:
cls._metadata = cls.generate_metadata()
return PydanticParser(type=cls._klass, metadata=cls._metadata, parent=cls._parent or cls)
return PydanticParser(type=cls._klass, metadata=cls.get_metadata(), parent=cls._parent or cls)
35 changes: 17 additions & 18 deletions dataclasses_avroschema/pydantic/v1/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, Optional, Type, TypeVar
from typing import Any, Dict, Type, TypeVar

from fastavro.validation import validate

Expand Down Expand Up @@ -26,38 +26,38 @@ def generate_dataclass(cls: Type[CT]) -> Type[CT]:
def json_schema(cls: Type[CT], *args: Any, **kwargs: Any) -> str:
return cls.schema_json(*args, **kwargs)

@classmethod
def standardize_type(cls: Type[CT], data: dict) -> Any:
def _standardize_type(self) -> Dict[str, Any]:
"""
Standardization factory that converts data according to the
user-defined pydantic json_encoders prior to passing values
to the standard type conversion factory
"""
encoders = cls.__config__.json_encoders
encoders = self.__config__.json_encoders
data = dict(self)

for k, v in data.items():
v_type = type(v)
if v_type in encoders:
encode_method = encoders[v_type]
data[k] = encode_method(v)
elif isinstance(v, dict):
cls.standardize_type(v)
return data

return standardize_custom_type(data)

def asdict(self, standardize_factory: Optional[Callable[..., Any]] = None) -> JsonDict:
def asdict(self) -> JsonDict:
"""
Returns this model in dictionary form. This method differs from
pydantic's dict by converting all values to their Avro representation.
It also doesn't provide the exclude, include, by_alias, etc.
parameters that dict provides.
"""
data = dict(self)

standardize_method = standardize_factory or self.standardize_type
return {
field_name: standardize_custom_type(
field_name=field_name, value=value, model=self, base_class=AvroBaseModel
)
for field_name, value in self._standardize_type().items()
}

# the standardize called can be replaced if we have a custom implementation of asdict
# for now I think is better to use the native implementation
return standardize_method(data)
def to_dict(self) -> JsonDict:
return dict(self)

def serialize(self, serialization_type: str = AVRO) -> bytes:
"""
Expand All @@ -67,7 +67,7 @@ def serialize(self, serialization_type: str = AVRO) -> bytes:
schema = self.avro_schema_to_python()

return serialization.serialize(
self.asdict(standardize_factory=self.standardize_type),
self.asdict(),
schema,
serialization_type=serialization_type,
)
Expand Down Expand Up @@ -95,5 +95,4 @@ def fake(cls: Type[CT], **data: Any) -> CT:

@classmethod
def _generate_parser(cls: Type[CT]) -> PydanticV1Parser:
cls._metadata = cls.generate_metadata()
return PydanticV1Parser(type=cls._klass, metadata=cls._metadata, parent=cls._parent or cls)
return PydanticV1Parser(type=cls._klass, metadata=cls.get_metadata(), parent=cls._parent or cls)
Loading

0 comments on commit 5a0a2b9

Please sign in to comment.