Skip to content

Commit

Permalink
fix: serialiaze properly unions when types are similar. Closes #749
Browse files Browse the repository at this point in the history
  • Loading branch information
marcosschroh committed Sep 20, 2024
1 parent 879fe7e commit 05b629c
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 58 deletions.
7 changes: 6 additions & 1 deletion dataclasses_avroschema/faust/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,12 @@ def standardize_type(self) -> typing.Any:
user-defined pydantic json_encoders prior to passing values
to the standard type conversion factory
"""
return standardize_custom_type(self)
return {
field_name: standardize_custom_type(
field_name=field_name, value=value, model=self, base_class=AvroRecord
)
for field_name, value in self.asdict().items()
}

def serialize(self, serialization_type: str = AVRO) -> bytes:
"""
Expand Down
31 changes: 10 additions & 21 deletions dataclasses_avroschema/pydantic/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,33 +27,22 @@ def generate_dataclass(cls: Type[CT]) -> Type[CT]:
def json_schema(cls: Type[CT], *args: Any, **kwargs: Any) -> str:
return json.dumps(cls.model_json_schema(*args, **kwargs))

@classmethod
def standardize_type(cls: Type[CT], data: dict) -> Any:
"""
Standardization factory that converts data according to the
user-defined pydantic json_encoders prior to passing values
to the standard type conversion factory
"""
for value in data.values():
if isinstance(value, dict):
cls.standardize_type(value)

return standardize_custom_type(data)

def asdict(self, standardize_factory: Optional[Callable[..., Any]] = None) -> JsonDict:
"""
Returns this model in dictionary form. This method differs from
pydantic's dict by converting all values to their Avro representation.
It also doesn't provide the exclude, include, by_alias, etc.
parameters that dict provides.
"""
data = self.model_dump()
standardize_method = standardize_factory or self.standardize_type

# the standardize called can be replaced if we have a custom implementation of asdict
# for now I think is better to use the native implementation
return standardize_method(data)

standardize_method = standardize_factory or standardize_custom_type

return {
field_name: standardize_method(
field_name=field_name, value=value, model=self, base_class=AvroBaseModel
)
for field_name, value in self.model_dump().items()
}

@classmethod
def parse_obj(cls: Type[CT], data: Dict) -> CT:
return cls.model_validate(obj=data)
Expand All @@ -66,7 +55,7 @@ def serialize(self, serialization_type: str = AVRO) -> bytes:
schema = self.avro_schema_to_python()

return serialization.serialize(
self.asdict(standardize_factory=self.standardize_type),
self.asdict(),
schema,
serialization_type=serialization_type,
)
Expand Down
35 changes: 9 additions & 26 deletions dataclasses_avroschema/pydantic/v1/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,38 +26,21 @@ def generate_dataclass(cls: Type[CT]) -> Type[CT]:
def json_schema(cls: Type[CT], *args: Any, **kwargs: Any) -> str:
return cls.schema_json(*args, **kwargs)

@classmethod
def standardize_type(cls: Type[CT], data: dict) -> Any:
"""
Standardization factory that converts data according to the
user-defined pydantic json_encoders prior to passing values
to the standard type conversion factory
"""
encoders = cls.__config__.json_encoders
for k, v in data.items():
v_type = type(v)
if v_type in encoders:
encode_method = encoders[v_type]
data[k] = encode_method(v)
elif isinstance(v, dict):
cls.standardize_type(v)

return standardize_custom_type(data)

def asdict(self, standardize_factory: Optional[Callable[..., Any]] = None) -> JsonDict:
"""
Returns this model in dictionary form. This method differs from
pydantic's dict by converting all values to their Avro representation.
It also doesn't provide the exclude, include, by_alias, etc.
parameters that dict provides.
"""
data = dict(self)

standardize_method = standardize_factory or self.standardize_type

# the standardize called can be replaced if we have a custom implementation of asdict
# for now I think is better to use the native implementation
return standardize_method(data)
standardize_method = standardize_factory or standardize_custom_type

return {
field_name: standardize_method(
field_name=field_name, value=value, model=self, base_class=AvroBaseModel
)
for field_name, value in dict(self).items()
}

def serialize(self, serialization_type: str = AVRO) -> bytes:
"""
Expand All @@ -67,7 +50,7 @@ def serialize(self, serialization_type: str = AVRO) -> bytes:
schema = self.avro_schema_to_python()

return serialization.serialize(
self.asdict(standardize_factory=self.standardize_type),
self.asdict(),
schema,
serialization_type=serialization_type,
)
Expand Down
10 changes: 6 additions & 4 deletions dataclasses_avroschema/schema_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,12 @@ def _reset_parser(cls: "Type[CT]") -> None:

def asdict(self, standardize_factory: Optional[Callable[..., Any]] = None) -> JsonDict:
if standardize_factory is not None:
return dataclasses.asdict(
self,
dict_factory=lambda x: {key: standardize_factory(value) for key, value in x},
) # type: ignore
return {
field.name: standardize_factory(
field_name=field.name, value=getattr(self, field.name), model=self, base_class=AvroModel
)
for field in dataclasses.fields(self)
}
return dataclasses.asdict(self) # type: ignore

def serialize(self, serialization_type: str = AVRO) -> bytes:
Expand Down
24 changes: 18 additions & 6 deletions dataclasses_avroschema/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import dataclasses
import enum
import types
import typing
from datetime import datetime, timezone
from functools import lru_cache
Expand Down Expand Up @@ -83,17 +84,28 @@ def rebuild_annotation(a_type: typing.Any, field_info: FieldInfo) -> typing.Type
return Annotated[a_type, field_info] # type: ignore[return-value]


def standardize_custom_type(value: typing.Any) -> typing.Any:
def standardize_custom_type(*, field_name: str, value: typing.Any, model, base_class) -> typing.Any:
if isinstance(value, dict):
return {k: standardize_custom_type(v) for k, v in value.items()}
return {
k: standardize_custom_type(field_name=field_name, value=v, model=model, base_class=base_class)
for k, v in value.items()
}
elif isinstance(value, list):
return [standardize_custom_type(v) for v in value]
return [
standardize_custom_type(field_name=field_name, value=v, model=model, base_class=base_class) for v in value
]
elif isinstance(value, tuple):
return tuple(standardize_custom_type(v) for v in value)
return tuple(
standardize_custom_type(field_name=field_name, value=v, model=model, base_class=base_class) for v in value
)
elif isinstance(value, enum.Enum):
return value.value
elif is_pydantic_model(type(value)) or is_faust_record(type(value)): # type: ignore[arg-type]
return standardize_custom_type(value.asdict())
elif isinstance(value, base_class):
asdict = value.asdict(standardize_factory=standardize_custom_type)

if isinstance(model.__annotations__[field_name], types.UnionType):
return (value.__class__.__name__, asdict)
return asdict

return value

Expand Down

0 comments on commit 05b629c

Please sign in to comment.