
Avro schemas containing a union whose two subschemas have the same fields but different names are serialized incorrectly #749

Closed
mauro-palsgraaf opened this issue Sep 16, 2024 · 6 comments


@mauro-palsgraaf

Context
We use this library with Kafka, where a single topic contains multiple types of events. They must all be on the same topic because we need to preserve their ordering.

Describe the bug
Unions whose subschemas have the same field names and types are not encoded according to the Avro specification. This breaks compatibility with implementations in other languages that follow the specification. The problem is caused by the following code in AvroBaseModel:

def serialize(self, serialization_type: str = AVRO) -> bytes:
    """
    Overrides the base AvroModel's serialize method to inject this
    class's standardization factory method
    """
    schema = self.avro_schema_to_python()

    return serialization.serialize(
        self.asdict(standardize_factory=self.standardize_type),
        schema,
        serialization_type=serialization_type,
    )

Converting the pydantic object to a dict before passing it on loses the type information. fastavro then determines the union index by choosing the branch that has the most fields in common with the dict. When all fields are the same, as in the example below, it simply picks the first branch.
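A simplified, pure-Python model of that fallback shows why the tie goes to the first branch. This is a sketch, not fastavro's code: it skips the validation fastavro runs on each branch first, handles only record branches, and `pick_record_branch` is a hypothetical name. It assumes ties keep the earlier branch, which matches the behaviour reported in this issue.

```python
from typing import Any, Dict, List


def pick_record_branch(datum: Dict[str, Any], branches: List[Dict[str, Any]]) -> int:
    """Return the index of the record branch sharing the most field names with datum.

    Simplified model of the "most fields in common" fallback used for plain dicts;
    it ignores per-branch validation and only considers record branches.
    """
    best_index = -1
    most_fields = -1
    for index, branch in enumerate(branches):
        branch_fields = {field["name"] for field in branch["fields"]}
        overlap = len(branch_fields & set(datum))
        # Strict ">" means a tie keeps the earlier branch.
        if overlap > most_fields:
            best_index = index
            most_fields = overlap
    return best_index


EVENT_FIELDS = [{"name": "field1", "type": "string"}, {"name": "schema_tag", "type": "string"}]
branches = [
    {"type": "record", "name": "EventOne", "fields": EVENT_FIELDS},
    {"type": "record", "name": "EventTwo", "fields": EVENT_FIELDS},
]

# The dict built from EventTwo carries no type name, so branch 0 wins the tie.
print(pick_record_branch({"field1": "hello world", "schema_tag": "EventTwo"}, branches))
# → 0
```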

A possible solution would be to walk the schema after converting the model to a dict and wrap each union value in a tuple whose first item is the type name. fastavro supports this tuple notation and uses the name, instead of the most-fields heuristic, to determine the union index written to the binary output. See write_union in fastavro below:

def write_union(encoder, datum, schema, named_schemas, fname, options):
    """A union is encoded by first writing a long value indicating the
    zero-based position within the union of the schema of its value. The value
    is then encoded per the indicated schema within the union."""

    best_match_index = -1
    if isinstance(datum, tuple) and not options.get("disable_tuple_notation"):
        (name, datum) = datum
        for index, candidate in enumerate(schema):
            extracted_type = extract_record_type(candidate)
            if extracted_type in NAMED_TYPES:
                schema_name = candidate["name"]
            else:
                schema_name = extracted_type
            if name == schema_name:
                best_match_index = index
                break

        if best_match_index == -1:
            field = f"on field {fname}" if fname else ""
            msg = (
                f"provided union type name {name} not found in schema "
                + f"{schema} {field}"
            )
            raise ValueError(msg)
        index = best_match_index
    else:
        pytype = type(datum)
        most_fields = -1
        # ... (rest of write_union omitted)

To Reproduce
Consider the following schemas (a very minimal example):

{
  "namespace": "com.example",
  "type": "record",
  "name": "TopicEvents",
  "fields": [
    {
      "name": "event", "type": [
        "EventOne",
        "EventTwo"
      ]
    }
  ]
}

{
  "namespace": "com.example",
  "type": "record",
  "name": "EventOne",
  "fields": [
    { "name": "field1", "type": "string" },
    { "name": "schema_tag", "type": "string" }
  ]
}

{
  "namespace": "com.example",
  "type": "record",
  "name": "EventTwo",
  "fields": [
    { "name": "field1", "type": "string" },
    { "name": "schema_tag", "type": "string" }
  ]
}

The following code produces incorrect output:

from typing import Literal
from pydantic import Field
from dataclasses_avroschema.pydantic.main import AvroBaseModel

class EventOne(AvroBaseModel):
    field1: str = Field(...)
    schema_tag: Literal["EventOne"] = Field(
        default="EventOne"
    )

class EventTwo(AvroBaseModel):
    field1: str = Field(...)
    schema_tag: Literal["EventTwo"] = Field(
        default="EventTwo"
    )

class Events(AvroBaseModel):
    event: EventOne | EventTwo = Field(discriminator="schema_tag")
    
serialized = Events(event=EventTwo(field1="hello world")).serialize()

print(serialized)

The result is b'\x00\x16hello world\x10EventTwo', but according to the Avro specification the expected result is b'\x02\x16hello world\x10EventTwo', because the leading int is the index of the chosen type within the union. Currently, Events(event=EventOne(field1="hello world")).serialize() and Events(event=EventTwo(field1="hello world")).serialize() both start with the same union index, \x00.
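The expected bytes follow from the Avro binary encoding: a union is written as its zero-based branch index (an Avro long, zigzag- and varint-encoded) followed by the value, a record is its fields in schema order, and a string is its byte length as a long followed by the UTF-8 bytes. The hand-rolled encoder below is only a sketch for this two-string record; the helper names are hypothetical and not part of dataclasses-avroschema or fastavro.

```python
def encode_long(n: int) -> bytes:
    """Avro long: zigzag-encode, then emit a little-endian base-128 varint."""
    zigzag = (n << 1) ^ (n >> 63)
    out = bytearray()
    while True:
        byte = zigzag & 0x7F
        zigzag >>= 7
        if zigzag:
            out.append(byte | 0x80)  # high bit set: more bytes follow
        else:
            out.append(byte)
            return bytes(out)


def encode_string(value: str) -> bytes:
    """Avro string: byte length as a long, then the UTF-8 bytes."""
    data = value.encode("utf-8")
    return encode_long(len(data)) + data


def encode_event_union(branch_index: int, field1: str, schema_tag: str) -> bytes:
    """Union of two records that both have fields (field1: string, schema_tag: string)."""
    return encode_long(branch_index) + encode_string(field1) + encode_string(schema_tag)


print(encode_event_union(1, "hello world", "EventTwo"))  # branch 1 = EventTwo
# → b'\x02\x16hello world\x10EventTwo'
```

Here \x16 is the zigzag encoding of 11 (the length of "hello world") and \x10 that of 8 (the length of "EventTwo"); only the first byte depends on which union branch is selected.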

Expected behavior
Events(event=EventTwo(field1="hello world")).serialize() should result in the following value: b'\x02\x16hello world\x10EventTwo'

Once we have figured out a solution and agreed on how to fix it, I can help with the implementation if necessary.

@marcosschroh
Owner

Hi, thanks for reporting the bug. I did not know that the int at the beginning indicates the index of the type within the union, but it makes sense. I think the same issue also happens with dataclasses.

This is related to #742 and #584.

If we confirm this, we should reopen both issues.

@mauro-palsgraaf
Author

Didn't know it either, but as a reference: https://avro.apache.org/docs/1.11.1/specification/#complex-types-1

#742 is exactly the same issue, and #584 is indeed the same issue for dataclasses.

It's not a complete fix, but serialization is correct whenever the union value is passed as a (type name, value) tuple:

    def serialize(self, serialization_type: str = AVRO) -> bytes:
        """
        Overrides the base AvroModel's serialize method to inject this
        class's standardization factory method
        """
        schema = self.avro_schema_to_python()

        data = self.asdict(standardize_factory=self.standardize_type)

        data["event"] = (type(self.event).__name__, data["event"])

        return serialization.serialize(
            data,
            schema,
            serialization_type=serialization_type,
        )

Maybe we can locate the unions in the schema, annotate the output of asdict with the type information, and pass that to serialization.serialize()?
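One way to generalize the hard-coded data["event"] line is to walk the schema alongside the original objects, where the type is still known, rather than alongside the asdict output. The sketch below does this for plain dataclasses; `to_tuple_notation` is a hypothetical name, not part of dataclasses-avroschema. It emits the (name, value) tuple notation that write_union checks for, and it assumes record names without namespaces: with namespaced schemas, fastavro probably expects the fully qualified name instead.

```python
import dataclasses
from typing import Any, Dict, Optional, Union


def to_tuple_notation(value: Any, schema: Any, named: Optional[Dict[str, Any]] = None) -> Any:
    """Convert dataclass instances into plain dicts, wrapping each record value
    that sits in a union as a (record_name, dict) tuple.

    Hypothetical helper: handles records, arrays and unions of records only,
    and assumes record names are not namespaced.
    """
    named = {} if named is None else named
    if isinstance(schema, str):
        # A primitive such as "string", or a reference to an already-seen record.
        resolved = named.get(schema)
        return value if resolved is None else to_tuple_notation(value, resolved, named)
    if isinstance(schema, list):
        if not dataclasses.is_dataclass(value):
            return value  # None and primitives: leave branch selection to fastavro
        for branch in schema:
            branch_schema = named.get(branch, branch) if isinstance(branch, str) else branch
            if (
                isinstance(branch_schema, dict)
                and branch_schema.get("type") == "record"
                and branch_schema["name"] == type(value).__name__
            ):
                converted = to_tuple_notation(value, branch_schema, named)
                return (branch_schema["name"], converted)
        raise ValueError(f"no record named {type(value).__name__!r} in union {schema}")
    if isinstance(schema, dict):
        if schema.get("type") == "record":
            named[schema["name"]] = schema  # remember it for later string references
            return {
                field["name"]: to_tuple_notation(getattr(value, field["name"]), field["type"], named)
                for field in schema["fields"]
            }
        if schema.get("type") == "array":
            return [to_tuple_notation(item, schema["items"], named) for item in value]
    return value


@dataclasses.dataclass
class EventOne:
    field1: str
    schema_tag: str = "EventOne"


@dataclasses.dataclass
class EventTwo:
    field1: str
    schema_tag: str = "EventTwo"


@dataclasses.dataclass
class Events:
    event: Union[EventOne, EventTwo]


FIELDS = [{"name": "field1", "type": "string"}, {"name": "schema_tag", "type": "string"}]
schema = {
    "type": "record",
    "name": "Events",
    "fields": [
        {
            "name": "event",
            "type": [
                {"type": "record", "name": "EventOne", "fields": FIELDS},
                {"type": "record", "name": "EventTwo", "fields": FIELDS},
            ],
        }
    ],
}

print(to_tuple_notation(Events(event=EventTwo(field1="hello world")), schema))
# → {'event': ('EventTwo', {'field1': 'hello world', 'schema_tag': 'EventTwo'})}
```

The result could then be handed to serialization.serialize() in place of the asdict output. Walking the objects rather than the dict is the key design choice: once asdict has run, two records with identical fields can no longer be told apart.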

@marcosschroh
Owner

marcosschroh commented Sep 17, 2024

> Didn't know it either, but as a reference: https://avro.apache.org/docs/1.11.1/specification/#complex-types-1

Nice, good to know! So definitely we should fix this problem and attach the type to unions.

I tried your suggestion using dataclasses:

import dataclasses
from typing import Literal

from dataclasses_avroschema import AvroModel


@dataclasses.dataclass
class EventOne(AvroModel):
    field1: str
    schema_tag: Literal["EventOne"] = "EventOne"

@dataclasses.dataclass
class EventTwo(AvroModel):
    field1: str
    schema_tag: Literal["EventTwo"] = "EventTwo"

@dataclasses.dataclass
class Events(AvroModel):
    event: EventOne | EventTwo


# It does not perform data validation with dataclasses
data = {'event': ('EventTwo', {'field1': 'hello world', 'schema_tag': 'EventTwo'})}

event = Events.parse_obj(data=data)
print(event)
# >>> Events(event=('EventTwo', {'field1': 'hello world', 'schema_tag': 'EventTwo'}))

serialized = event.serialize()
print(serialized)

# We have the proper byte
# >>> b'\x02\x16hello world\x10EventTwo'

# Then when deserialize
print(Events.deserialize(serialized))

# fastavro returns {'event': {'field1': 'hello world', 'schema_tag': 'EventTwo'}}

# THE WRONG ONE
# >>> Events(event=EventOne(field1='hello world', schema_tag='EventTwo'))

The deserialize method calls a wrapper around fastavro, so it seems to me that there is also a bug in fastavro's deserialization. Serialization works as expected.
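The union index does survive in the bytes; it is only dropped when the reader turns the record into a plain dict. A minimal, hand-rolled decoder (hypothetical helpers, not fastavro's API) reads it back from the payload above:

```python
from typing import Tuple


def decode_long(data: bytes, pos: int = 0) -> Tuple[int, int]:
    """Read one Avro long (zigzag + base-128 varint); return (value, next_pos)."""
    shift = 0
    result = 0
    while True:
        byte = data[pos]
        pos += 1
        result |= (byte & 0x7F) << shift
        if not byte & 0x80:  # high bit clear: last byte of this varint
            break
        shift += 7
    # Undo zigzag: even values map to n >= 0, odd values to n < 0.
    return (result >> 1) ^ -(result & 1), pos


def decode_string(data: bytes, pos: int) -> Tuple[str, int]:
    """Read one Avro string: a long byte length, then that many UTF-8 bytes."""
    length, pos = decode_long(data, pos)
    return data[pos:pos + length].decode("utf-8"), pos + length


payload = b"\x02\x16hello world\x10EventTwo"
index, pos = decode_long(payload)              # union branch index
field1, pos = decode_string(payload, pos)      # record field "field1"
schema_tag, pos = decode_string(payload, pos)  # record field "schema_tag"
print(index, field1, schema_tag)
# → 1 hello world EventTwo
```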

@mauro-palsgraaf
Author

mauro-palsgraaf commented Sep 17, 2024

I've just had a quick look at the deserialize method in serialization.py; the relevant part looks like this (for reference):

if serialization_type == "avro":
    input_stream: typing.Union[io.BytesIO, io.StringIO] = io.BytesIO(data)

    payload = fastavro.schemaless_reader(
        input_stream,
        writer_schema=writer_schema or schema,
        reader_schema=schema,
    )

The fastavro.schemaless_reader function seems to have options to configure this behavior:

def schemaless_reader(
    fo: IO,
    writer_schema: Schema,
    reader_schema: Optional[Schema] = None,
    return_record_name: bool = False,
    return_record_name_override: bool = False,
    handle_unicode_errors: str = "strict",
    return_named_type: bool = False,
    return_named_type_override: bool = False,
) -> AvroMessage:

I've just passed return_named_type=True to the schemaless_reader call. This gives the following result for dataclasses:

print(Events.deserialize(serialized))
# Events(event=('EventTwo', {'field1': 'hello world', 'schema_tag': 'EventTwo'}))

and the following result for pydantic:

print(Events.deserialize(serialized))
# pydantic_core._pydantic_core.ValidationError: 1 validation error for Events
# event
#   Input should be a valid dictionary or object to extract fields from [type=model_attributes_type, input_value=('EventOne', {'field1': '...chema_tag': 'EventTwo'}), input_type=tuple]
#    For further information visit https://errors.pydantic.dev/2.8/v/model_attributes_type

I think we need some extra processing in parse_obj in AvroModel and AvroBaseModel that uses the first element of the tuple to determine the type of a union value.
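A minimal sketch of that extra step, assuming the reader returns the (type_name, dict) tuples shown above: look the name up among the union's candidate classes and construct that class. `resolve_union_value` and the `candidates` mapping are hypothetical, it uses plain dataclasses instead of AvroModel/AvroBaseModel, and it does not recurse into nested unions.

```python
import dataclasses
from typing import Any, Dict, Type


def resolve_union_value(value: Any, candidates: Dict[str, Type[Any]]) -> Any:
    """Turn a (type_name, dict) union tuple into an instance of the matching class.

    Hypothetical helper: `candidates` maps record names to the classes in the
    union; any namespace prefix on the name is stripped before the lookup.
    """
    if isinstance(value, tuple) and len(value) == 2 and isinstance(value[0], str):
        name, payload = value
        cls = candidates.get(name.rsplit(".", 1)[-1])
        if cls is None:
            raise ValueError(f"unknown union member {name!r}")
        return cls(**payload)
    return value  # not tuple notation: leave it to the normal parsing path


@dataclasses.dataclass
class EventOne:
    field1: str
    schema_tag: str = "EventOne"


@dataclasses.dataclass
class EventTwo:
    field1: str
    schema_tag: str = "EventTwo"


raw = ("EventTwo", {"field1": "hello world", "schema_tag": "EventTwo"})
print(resolve_union_value(raw, {"EventOne": EventOne, "EventTwo": EventTwo}))
# → EventTwo(field1='hello world', schema_tag='EventTwo')
```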

@marcosschroh
Owner

@mauro-palsgraaf this is fixed in the latest version:

from typing import Literal

from pydantic import Field

from dataclasses_avroschema.pydantic.main import AvroBaseModel


class EventOne(AvroBaseModel):
    field1: str = Field(...)
    schema_tag: Literal["EventOne"] = Field(
        default="EventOne"
    )

class EventTwo(AvroBaseModel):
    field1: str = Field(...)
    schema_tag: Literal["EventTwo"] = Field(
        default="EventTwo"
    )

class Events(AvroBaseModel):
    event: EventOne | EventTwo = Field(discriminator="schema_tag")
    
serialized = Events(event=EventTwo(field1="hello world")).serialize()

print(serialized)
# >>> b'\x02\x16hello world\x10EventTwo'

print(Events.deserialize(serialized))
# >>> event=EventTwo(field1='hello world', schema_tag='EventTwo')

PS: It also works without Field(discriminator="schema_tag")

@mauro-palsgraaf
Author

Really nice, thank you for fixing this so quickly! Will happily test this out on Monday 🙂
