Skip to content

Commit

Permalink
chore: There's a problem with deserializing nested enum with simple f…
Browse files Browse the repository at this point in the history
…ield. Dataclasses_json doesn't like that.
  • Loading branch information
kulikthebird committed Nov 22, 2024
1 parent 121bfb4 commit a5381b0
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 43 deletions.
102 changes: 60 additions & 42 deletions packages/cw-schema-codegen/playground/playground.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,54 @@
from dataclasses import dataclass, field
from dataclasses import dataclass, field, fields
from dataclasses_json import dataclass_json, config
from typing import Optional, Iterable
import sys
import json


# TODO tkulik: try to get rid of the `dataclasses_json` dependency

# TODO tkulik: Iterable['SomeEnum'] does not check the types. It doesn't even call from_json...

enum_field = lambda: field(default=None, metadata=config(exclude=lambda x: x is None))

def serialize_enum(func, *simple_variants):
def serialize(self):
for variant in simple_variants:
if getattr(self, variant) is not None:
return f'"{variant}"'
return func(self)
return serialize

def deserialize_enum(func, *simple_variants):
def deserialize(json):
if not ":" in json:
for variant in simple_variants:
if json == f'"{variant}"':
kwargs = { f"{variant}": SomeEnum.VariantIndicator() }
return SomeEnum(**kwargs)
raise Exception(f"Deserialization error, undefined variant: {json}")
return func(json)
return deserialize

def unit_structure_serialize(_slf):
return 'null'

def unit_structure_deserialize(json):
if json == 'null':
return UnitStructure()
else:
raise Exception(f"Deserialization error, undefined value: {json}")

def tuple_serialize(func):
def serialize(self):
return json.dumps(self.Tuple)
return serialize

def tuple_deserialize(func):
def deserialize(json):
return func(f'{{ "Tuple": {json} }}')
return deserialize


@dataclass_json
@dataclass
class SomeEnum:
Expand All @@ -28,57 +67,36 @@ class Field5Type:
Field3: Optional[Field3Type] = enum_field()
Field4: Optional[Iterable['SomeEnum']] = enum_field()
Field5: Optional[Field5Type] = enum_field()

def deserialize(json):
if not ":" in json:
if json == '"Field1"':
return SomeEnum(Field1=SomeEnum.VariantIndicator())
else:
raise Exception(f"Deserialization error, undefined variant: {json}")
else:
return SomeEnum.from_json(json)

def serialize(self):
if self.Field1 is not None:
return '"Field1"'
else:
return SomeEnum.to_json(self)


SomeEnum.to_json = serialize_enum(SomeEnum.to_json, "Field1")
SomeEnum.from_json = deserialize_enum(SomeEnum.from_json, "Field1")


@dataclass_json
@dataclass
class UnitStructure:
def deserialize(json):
if json == "null":
return UnitStructure()
else:
Exception(f"Deserialization error, undefined value: {json}")

def serialize(self):
return 'null'
pass

UnitStructure.to_json = unit_structure_serialize
UnitStructure.from_json = unit_structure_deserialize


@dataclass_json
@dataclass
class TupleStructure:
Tuple: tuple[int, str, int]

def deserialize(json):
return TupleStructure.from_json(f'{{ "Tuple": {json} }}')

def serialize(self):
return json.dumps(self.Tuple)
TupleStructure.to_json = tuple_serialize(TupleStructure.to_json)
TupleStructure.from_json = tuple_deserialize(TupleStructure.from_json)


@dataclass_json
@dataclass
class NamedStructure:
a: str
b: int
c: Iterable['SomeEnum']
c: SomeEnum

def deserialize(json):
return NamedStructure.from_json(json)

def serialize(self):
return self.to_json()

###
### TESTS:
Expand All @@ -88,16 +106,16 @@ def serialize(self):
input = input.rstrip()
try:
if index < 5:
deserialized = SomeEnum.deserialize(input)
deserialized = SomeEnum.from_json(input)
elif index == 5:
deserialized = UnitStructure.deserialize(input)
deserialized = UnitStructure.from_json(input)
elif index == 6:
deserialized = TupleStructure.deserialize(input)
deserialized = TupleStructure.from_json(input)
else:
deserialized = NamedStructure.deserialize(input)
deserialized = NamedStructure.from_json(input)
except:
raise(Exception(f"This json can't be deserialized: {input}"))
serialized = deserialized.serialize()
serialized = deserialized.to_json()
print(serialized)


Expand Down
3 changes: 2 additions & 1 deletion packages/cw-schema-codegen/playground/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub enum SomeEnum {
},
Field4(Box<SomeEnum>),
Field5 { a: Box<SomeEnum> },
Field6
}

#[derive(Serialize, Deserialize)]
Expand All @@ -37,7 +38,7 @@ fn main() {
println!("{}", serde_json::to_string(&SomeEnum::Field5 { a: Box::new(SomeEnum::Field1) }).unwrap());
println!("{}", serde_json::to_string(&UnitStructure {}).unwrap());
println!("{}", serde_json::to_string(&TupleStructure(10, "aasdf".to_string(), 2)).unwrap());
println!("{}", serde_json::to_string(&NamedStructure {a: "awer".to_string(), b: 4, c: SomeEnum::Field1}).unwrap());
println!("{}", serde_json::to_string(&NamedStructure {a: "awer".to_string(), b: 4, c: SomeEnum::Field6}).unwrap());
}

#[cfg(feature = "deserialize")]
Expand Down

0 comments on commit a5381b0

Please sign in to comment.