Skip to content

Commit

Permalink
[export] Move serialized custom class objs to toplevel (pytorch#114371)
Browse files Browse the repository at this point in the history
Summary:
Move the serialized CustomClassHolder objects to the toplevel SerializedArtifact instead of embedding the bytes in the graph.

Currently the CustomClassHolder objects are embedded in the graph instead of being lifted to the ExportedProgram, so there's some logic introduced to lift it to the higher level of the serialized ExportedProgram. However, once that CustomClassHolder objects get lifted, we can remove the TODOs I added.

Test Plan: CI

Reviewed By: zhxchen17

Differential Revision: D51479125

Pull Request resolved: pytorch#114371
Approved by: https://github.com/ydwu4
  • Loading branch information
angelayi authored and pytorchmergebot committed Nov 22, 2023
1 parent 6a86cf0 commit f961bda
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 22 deletions.
2 changes: 1 addition & 1 deletion torch/_export/serde/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ class GraphArgument:

@dataclass
class CustomObjArgument:
blob: bytes
name: str


# This is actually a union type
Expand Down
56 changes: 35 additions & 21 deletions torch/_export/serde/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import logging
import math
import operator
import pickle
import typing

from contextlib import contextmanager
Expand Down Expand Up @@ -166,7 +165,7 @@ def _reverse_map(d: Dict[Any, Enum]):
class SerializedArtifact:
exported_program: Union[ExportedProgram, bytes]
state_dict: bytes
tensor_constants: bytes
constants: bytes


def deserialize_device(d: Device) -> torch.device:
Expand Down Expand Up @@ -303,7 +302,6 @@ class GraphState:
sym_int_values: Dict[str, SymInt] = field(default_factory=dict)
sym_bool_values: Dict[str, SymBool] = field(default_factory=dict)
is_single_tensor_return: bool = False
constants: Dict[str, torch.Tensor] = field(default_factory=dict)


class GraphModuleSerializer:
Expand All @@ -315,6 +313,7 @@ def __init__(
self.graph_state = GraphState()
self.graph_signature = graph_signature
self.module_call_graph = module_call_graph
self.custom_objs: Dict[str, torch._C.ScriptObject] = {}

@contextmanager
def save_graph_state(self):
Expand Down Expand Up @@ -640,19 +639,20 @@ def serialize_optional_tensor_args(a):
return Argument.create(as_layout=_TORCH_TO_SERIALIZE_LAYOUT[arg])
elif isinstance(arg, torch._C.ScriptObject):
if not (
hasattr(type(arg), "__getstate__") and
hasattr(type(arg), "__setstate__")
arg._has_method("__getstate__") and # type: ignore[attr-defined]
arg._has_method("__setstate__") # type: ignore[attr-defined]
):
raise SerializeError(
f"Unable to serialize ScriptObject {arg}. Please define "
f"Unable to serialize custom class {arg}. Please define "
"serialization methods via def_pickle()."
)
# Custom objects through torchind are serializable with pickle,
# through implementing the .def_pickle function. This should result
# in the object containing a __getstate__ and __setstate__
# serialize/deserialize function.
blob = pickle.dumps(arg)
return Argument.create(as_custom_obj=CustomObjArgument(blob))
custom_obj_name = f"_custom_obj_{len(self.custom_objs)}"
self.custom_objs[custom_obj_name] = arg
return Argument.create(as_custom_obj=CustomObjArgument(custom_obj_name))
else:
raise SerializeError(f"Unsupported argument type: {type(arg)}")

Expand Down Expand Up @@ -949,14 +949,23 @@ def __init__(self, opset_version: Optional[Dict[str, int]] = None):
self.opset_version["aten"] = torch._C._get_max_operator_version()

def serialize(self, exported_program: ep.ExportedProgram) -> SerializedArtifact:
serialized_graph_module = (
GraphModuleSerializer(
exported_program.graph_signature,
exported_program.module_call_graph
).serialize(exported_program.graph_module)
gm_serializer = GraphModuleSerializer(
exported_program.graph_signature,
exported_program.module_call_graph
)
serialized_graph_module = gm_serializer.serialize(exported_program.graph_module)
serialized_range_constraints = serialize_range_constraints(exported_program.range_constraints)

# TODO: Directly serialize exported_program.constants once
# CustomClassHolders get stored in the ExportedProgram rather than in
# the graph
constants = {}
for n, c in gm_serializer.custom_objs.items():
constants[n] = c
for n, t in exported_program.tensor_constants.items():
assert n not in constants
constants[n] = t

return SerializedArtifact(
ExportedProgram(
graph_module=serialized_graph_module,
Expand All @@ -966,7 +975,7 @@ def serialize(self, exported_program: ep.ExportedProgram) -> SerializedArtifact:
dialect=exported_program.dialect,
),
serialize_torch_artifact(exported_program.state_dict),
serialize_torch_artifact(exported_program.tensor_constants),
serialize_torch_artifact(constants),
)


Expand Down Expand Up @@ -1251,6 +1260,7 @@ def deserialize(
self,
serialized_graph_module: GraphModule,
symbol_name_to_range: Optional[Dict[str, symbolic_shapes.ValueRanges]] = None,
constants: Optional[Dict[str, Any]] = None,
) -> Result:
self.shape_env = symbolic_shapes.ShapeEnv(assume_static_by_default=True)
self.fake_tensor_mode = FakeTensorMode(
Expand All @@ -1260,6 +1270,7 @@ def deserialize(
)
self.symbol_name_to_symbol: Dict[str, sympy.Symbol] = {}
self.symbol_name_to_range = {} if symbol_name_to_range is None else symbol_name_to_range
self.constants = {} if constants is None else constants

self.deserialize_graph(serialized_graph_module.graph)

Expand Down Expand Up @@ -1357,10 +1368,7 @@ def deserialize_optional_tensor_args(a):
else:
raise SerializeError(f"Unhandled argument {inp}")
elif isinstance(value, CustomObjArgument):
# Custom objects through torchind are deserializable with pickle,
# through implementing the .def_pickle function.
blob = base64.b64decode(value.blob)
return pickle.loads(blob)
return self.constants[value.name]
else:
raise SerializeError(f"Unhandled argument {inp}")

Expand Down Expand Up @@ -1549,12 +1557,19 @@ def deserialize(
k: symbolic_shapes.ValueRanges(_int_to_sympy_int(v.min_val), _int_to_sympy_int(v.max_val))
for k, v in serialized_artifact.exported_program.range_constraints.items()
}
constants = deserialize_torch_artifact(serialized_artifact.constants)

# TODO: No need to do this once CustomClassHolders are lifted to the ExportedProgram
tensor_constants = {
k: v for k, v in constants.items() if isinstance(v, torch.Tensor)
}

res = (
GraphModuleDeserializer()
.deserialize(
serialized_artifact.exported_program.graph_module,
symbol_name_to_range,
constants,
)
)
range_constraints = self.deserialize_range_constraints(
Expand All @@ -1566,7 +1581,6 @@ def deserialize(
upgrader = GraphModuleOpUpgrader(self.expected_opset_version, model_opset_version)

state_dict = deserialize_torch_artifact(serialized_artifact.state_dict)
tensor_constants = deserialize_torch_artifact(serialized_artifact.tensor_constants)

exported_program = ep.ExportedProgram(
res.graph_module,
Expand Down Expand Up @@ -1648,7 +1662,7 @@ def serialize(
artifact = SerializedArtifact(
json_bytes,
serialized_artifact.state_dict,
serialized_artifact.tensor_constants
serialized_artifact.constants
)
return artifact

Expand Down Expand Up @@ -1705,7 +1719,7 @@ def deserialize(
SerializedArtifact(
serialized_exported_program,
artifact.state_dict,
artifact.tensor_constants
artifact.constants
)
)
)

0 comments on commit f961bda

Please sign in to comment.