From 0e98416f7c1c678b33063c34a9c55e123df94985 Mon Sep 17 00:00:00 2001 From: marcosschroh Date: Fri, 20 Sep 2024 17:47:14 +0200 Subject: [PATCH] fix: serialiaze properly unions when types are similar. Closes #749 --- dataclasses_avroschema/faust/main.py | 7 ++++- dataclasses_avroschema/pydantic/main.py | 31 +++++++--------------- dataclasses_avroschema/pydantic/v1/main.py | 31 +++++++++++----------- dataclasses_avroschema/schema_generator.py | 10 ++++--- dataclasses_avroschema/utils.py | 29 +++++++++++++++----- 5 files changed, 60 insertions(+), 48 deletions(-) diff --git a/dataclasses_avroschema/faust/main.py b/dataclasses_avroschema/faust/main.py index ce0d9483..26286c2b 100644 --- a/dataclasses_avroschema/faust/main.py +++ b/dataclasses_avroschema/faust/main.py @@ -32,7 +32,12 @@ def standardize_type(self) -> typing.Any: user-defined pydantic json_encoders prior to passing values to the standard type conversion factory """ - return standardize_custom_type(self) + return { + field_name: standardize_custom_type( + field_name=field_name, value=value, model=self, base_class=AvroRecord + ) + for field_name, value in self.asdict().items() + } def serialize(self, serialization_type: str = AVRO) -> bytes: """ diff --git a/dataclasses_avroschema/pydantic/main.py b/dataclasses_avroschema/pydantic/main.py index 0b3175ba..2a6b7e6b 100644 --- a/dataclasses_avroschema/pydantic/main.py +++ b/dataclasses_avroschema/pydantic/main.py @@ -27,19 +27,6 @@ def generate_dataclass(cls: Type[CT]) -> Type[CT]: def json_schema(cls: Type[CT], *args: Any, **kwargs: Any) -> str: return json.dumps(cls.model_json_schema(*args, **kwargs)) - @classmethod - def standardize_type(cls: Type[CT], data: dict) -> Any: - """ - Standardization factory that converts data according to the - user-defined pydantic json_encoders prior to passing values - to the standard type conversion factory - """ - for value in data.values(): - if isinstance(value, dict): - cls.standardize_type(value) - - return standardize_custom_type(data) - def asdict(self, standardize_factory: Optional[Callable[..., Any]] = None) -> JsonDict: """ Returns this model in dictionary form. This method differs from @@ -47,13 +34,15 @@ def asdict(self, standardize_factory: Optional[Callable[..., Any]] = None) -> Js It also doesn't provide the exclude, include, by_alias, etc. parameters that dict provides. """ - data = self.model_dump() - standardize_method = standardize_factory or self.standardize_type - - # the standardize called can be replaced if we have a custom implementation of asdict - # for now I think is better to use the native implementation - return standardize_method(data) - + standardize_method = standardize_factory or standardize_custom_type + + return { + field_name: standardize_method( + field_name=field_name, value=value, model=self, base_class=AvroBaseModel + ) + for field_name, value in self.model_dump().items() + } + @classmethod def parse_obj(cls: Type[CT], data: Dict) -> CT: return cls.model_validate(obj=data) @@ -66,7 +55,7 @@ def serialize(self, serialization_type: str = AVRO) -> bytes: schema = self.avro_schema_to_python() return serialization.serialize( - self.asdict(standardize_factory=self.standardize_type), + self.asdict(), schema, serialization_type=serialization_type, ) diff --git a/dataclasses_avroschema/pydantic/v1/main.py b/dataclasses_avroschema/pydantic/v1/main.py index ae4c72a5..c5b74873 100644 --- a/dataclasses_avroschema/pydantic/v1/main.py +++ b/dataclasses_avroschema/pydantic/v1/main.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Optional, Type, TypeVar +from typing import Any, Callable, Dict, Optional, Type, TypeVar from fastavro.validation import validate @@ -26,23 +26,21 @@ def generate_dataclass(cls: Type[CT]) -> Type[CT]: def json_schema(cls: Type[CT], *args: Any, **kwargs: Any) -> str: return cls.schema_json(*args, **kwargs) - @classmethod - def standardize_type(cls: Type[CT], data: dict) -> Any: + def _standardize_type(self) -> Dict[str, Any]: """ Standardization factory that converts data according to the user-defined pydantic json_encoders prior to passing values to the standard type conversion factory """ - encoders = cls.__config__.json_encoders + encoders = self.__config__.json_encoders + data = dict(self) + for k, v in data.items(): v_type = type(v) if v_type in encoders: encode_method = encoders[v_type] data[k] = encode_method(v) - elif isinstance(v, dict): - cls.standardize_type(v) - - return standardize_custom_type(data) + return data def asdict(self, standardize_factory: Optional[Callable[..., Any]] = None) -> JsonDict: """ @@ -51,13 +49,14 @@ def asdict(self, standardize_factory: Optional[Callable[..., Any]] = None) -> Js It also doesn't provide the exclude, include, by_alias, etc. parameters that dict provides. """ - data = dict(self) - - standardize_method = standardize_factory or self.standardize_type - - # the standardize called can be replaced if we have a custom implementation of asdict - # for now I think is better to use the native implementation - return standardize_method(data) + standardize_method = standardize_factory or standardize_custom_type + + return { + field_name: standardize_method( + field_name=field_name, value=value, model=self, base_class=AvroBaseModel + ) + for field_name, value in self._standardize_type().items() + } def serialize(self, serialization_type: str = AVRO) -> bytes: """ @@ -67,7 +66,7 @@ def serialize(self, serialization_type: str = AVRO) -> bytes: schema = self.avro_schema_to_python() return serialization.serialize( - self.asdict(standardize_factory=self.standardize_type), + self.asdict(), schema, serialization_type=serialization_type, ) diff --git a/dataclasses_avroschema/schema_generator.py b/dataclasses_avroschema/schema_generator.py index 6ea830b4..ae45a5c5 100644 --- a/dataclasses_avroschema/schema_generator.py +++ b/dataclasses_avroschema/schema_generator.py @@ -117,10 +117,12 @@ def _reset_parser(cls: "Type[CT]") -> None: def asdict(self, standardize_factory: Optional[Callable[..., Any]] = None) -> JsonDict: if standardize_factory is not None: - return dataclasses.asdict( - self, - dict_factory=lambda x: {key: standardize_factory(value) for key, value in x}, - ) # type: ignore + return { + field.name: standardize_factory( + field_name=field.name, value=getattr(self, field.name), model=self, base_class=AvroModel + ) + for field in dataclasses.fields(self) # type: ignore + } return dataclasses.asdict(self) # type: ignore def serialize(self, serialization_type: str = AVRO) -> bytes: diff --git a/dataclasses_avroschema/utils.py b/dataclasses_avroschema/utils.py index 93cac67a..cb5b02b2 100644 --- a/dataclasses_avroschema/utils.py +++ b/dataclasses_avroschema/utils.py @@ -1,5 +1,6 @@ import dataclasses import enum +import types import typing from datetime import datetime, timezone from functools import lru_cache @@ -83,17 +84,33 @@ def rebuild_annotation(a_type: typing.Any, field_info: FieldInfo) -> typing.Type return Annotated[a_type, field_info] # type: ignore[return-value] -def standardize_custom_type(value: typing.Any) -> typing.Any: +def standardize_custom_type(*, field_name: str, value: typing.Any, model, base_class) -> typing.Any: if isinstance(value, dict): - return {k: standardize_custom_type(v) for k, v in value.items()} + return { + k: standardize_custom_type(field_name=field_name, value=v, model=model, base_class=base_class) + for k, v in value.items() + } elif isinstance(value, list): - return [standardize_custom_type(v) for v in value] + return [ + standardize_custom_type(field_name=field_name, value=v, model=model, base_class=base_class) for v in value + ] elif isinstance(value, tuple): - return tuple(standardize_custom_type(v) for v in value) + return tuple( + standardize_custom_type(field_name=field_name, value=v, model=model, base_class=base_class) for v in value + ) elif isinstance(value, enum.Enum): return value.value - elif is_pydantic_model(type(value)) or is_faust_record(type(value)): # type: ignore[arg-type] - return standardize_custom_type(value.asdict()) + elif isinstance(value, base_class): + if is_faust_record(type(value)): # type: ignore[arg-type] + # we need to do a trick because we can not overrride asdict from faust.. + # once the function interface is introduced we can remove this check + asdict = value.standardize_type() + else: + asdict = value.asdict(standardize_factory=standardize_custom_type) + + if isinstance(model.__annotations__[field_name], types.UnionType): + return (value.__class__.__name__, asdict) + return asdict return value