Reduce memory overhead for TransportableObject
* Always represent TransportableObject internally as a single array of
bytes. Properties such as `header` and `object_string` decode the
corresponding segments of the byte array.

* Store the serialized object as raw picklebytes without
base64-encoding. As a result, `get_deserialized()` no longer needs to
create a temporary copy of the raw picklebytes; the data segment is
unpickled directly. Base64-encoding is applied to the data segment or
the entire internal buffer only when a print-friendly representation of
the `TransportableObject` is desired.

* Since the properties of `TransportableObject` are simply views into
the underlying buffer, `TransportableObject` itself can be serialized
efficiently by writing out the byte array.

Add backward compatibility layer for deserialization
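
To make the new layout concrete, here is a minimal decoding sketch (illustrative only, not code from this commit; the 4-byte width of the string offset is an assumption, since only DATA_OFFSET_BYTES = 8 appears in the visible diff):

import json

def read_segments(buf: bytes, string_offset_bytes: int = 4, data_offset_bytes: int = 8):
    # Layout (format 0.1): [string_offset][data_offset][header JSON][object string][picklebytes]
    # NOTE: string_offset_bytes=4 is an assumed value for illustration.
    header_offset = string_offset_bytes + data_offset_bytes
    # Both offsets are big-endian unsigned integers counted from the start of the buffer.
    string_offset = int.from_bytes(buf[:string_offset_bytes], "big")
    data_offset = int.from_bytes(buf[string_offset_bytes:header_offset], "big")
    header = json.loads(buf[header_offset:string_offset].decode("utf-8"))
    object_string = buf[string_offset:data_offset].decode("utf-8")
    picklebytes = buf[data_offset:]  # raw cloudpickle bytes, no base64
    return header, object_string, picklebytes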
cjao committed Sep 11, 2024
1 parent 05bf94d commit ef32837
Showing 4 changed files with 255 additions and 181 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added support for Python 3.11
- Removed official support for Python 3.8
- Reduced memory overhead for operations involving TransportableObject

## [0.235.1-rc.0] - 2024-06-10

254 changes: 120 additions & 134 deletions covalent/_workflow/transportable_object.py
@@ -19,7 +19,7 @@
import base64
import json
import platform
from typing import Any, Callable, Tuple
from typing import Any, Callable, Dict, Tuple

import cloudpickle

@@ -29,77 +29,12 @@
DATA_OFFSET_BYTES = 8
HEADER_OFFSET = STRING_OFFSET_BYTES + DATA_OFFSET_BYTES
BYTE_ORDER = "big"


class _TOArchive:
"""Archived transportable object."""

def __init__(self, header: bytes, object_string: bytes, data: bytes):
"""
Initialize TOArchive.
Args:
header: Archived transportable object header.
object_string: Archived transportable object string.
data: Archived transportable object data.
Returns:
None
"""

self.header = header
self.object_string = object_string
self.data = data

def cat(self) -> bytes:
"""
Concatenate TOArchive.
Returns:
Concatenated TOArchive.
"""

header_size = len(self.header)
string_size = len(self.object_string)
data_offset = STRING_OFFSET_BYTES + DATA_OFFSET_BYTES + header_size + string_size
string_offset = STRING_OFFSET_BYTES + DATA_OFFSET_BYTES + header_size

data_offset = data_offset.to_bytes(DATA_OFFSET_BYTES, BYTE_ORDER, signed=False)
string_offset = string_offset.to_bytes(STRING_OFFSET_BYTES, BYTE_ORDER, signed=False)

return string_offset + data_offset + self.header + self.object_string + self.data

@staticmethod
def load(serialized: bytes, header_only: bool, string_only: bool) -> "_TOArchive":
"""
Load TOArchive object from serialized bytes.
Args:
serialized: Serialized transportable object.
header_only: Load header only.
string_only: Load string only.
Returns:
Archived transportable object.
"""

string_offset = TOArchiveUtils.string_offset(serialized)
header = TOArchiveUtils.parse_header(serialized, string_offset)
object_string = b""
data = b""

if not header_only:
data_offset = TOArchiveUtils.data_offset(serialized)
object_string = TOArchiveUtils.parse_string(serialized, string_offset, data_offset)

if not string_only:
data = TOArchiveUtils.parse_data(serialized, data_offset)
return _TOArchive(header, object_string, data)
TOBJ_FMT_STR = "0.1"


class TOArchiveUtils:
"""Utilities for reading serialized TransportableObjects"""

@staticmethod
def data_offset(serialized: bytes) -> int:
size64 = serialized[STRING_OFFSET_BYTES : STRING_OFFSET_BYTES + DATA_OFFSET_BYTES]
@@ -119,24 +54,38 @@ def string_byte_range(serialized: bytes) -> Tuple[int, int]:

@staticmethod
def data_byte_range(serialized: bytes) -> Tuple[int, int]:
"""Return byte range for the b64 picklebytes"""
"""Return byte range for the picklebytes"""
start_byte = TOArchiveUtils.data_offset(serialized)
return start_byte, -1

@staticmethod
def parse_header(serialized: bytes, string_offset: int) -> bytes:
def header(serialized: bytes) -> dict:
string_offset = TOArchiveUtils.string_offset(serialized)
header = serialized[HEADER_OFFSET:string_offset]
return header
return json.loads(header.decode("utf-8"))

@staticmethod
def parse_string(serialized: bytes, string_offset: int, data_offset: int) -> bytes:
def string_segment(serialized: bytes) -> bytes:
string_offset = TOArchiveUtils.string_offset(serialized)
data_offset = TOArchiveUtils.data_offset(serialized)
return serialized[string_offset:data_offset]

@staticmethod
def parse_data(serialized: bytes, data_offset: int) -> bytes:
def data_segment(serialized: bytes) -> bytes:
data_offset = TOArchiveUtils.data_offset(serialized)
return serialized[data_offset:]
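
A usage sketch for these helpers (illustrative, not part of the commit). Passing a memoryview slices the data segment without copying; the (start, -1) pair from data_byte_range presumably signals "from start to the end of the buffer" to downstream file-slicing code:

to = TransportableObject([1, 2, 3])
buf = to.serialize()                                    # the raw internal buffer
header = TOArchiveUtils.header(buf)                     # parsed JSON header (a dict)
pickled = TOArchiveUtils.data_segment(memoryview(buf))  # zero-copy view of the picklebytes
assert cloudpickle.loads(pickled) == [1, 2, 3]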


class _ByteArrayFile:
"""File-like interface for appending to a bytearray."""

def __init__(self, buf: bytearray):
self._buf = buf

def write(self, data: bytes):
self._buf.extend(data)
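
Anything with a write(bytes) method satisfies the pickler's file protocol, so cloudpickle can stream picklebytes straight into the shared bytearray with no intermediate bytes object. A small sketch (illustrative):

buf = bytearray()
cloudpickle.dump({"x": 1}, _ByteArrayFile(buf))  # appends picklebytes in place
assert cloudpickle.loads(bytes(buf)) == {"x": 1}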


class TransportableObject:
"""
A function is converted to a transportable object by serializing it using cloudpickle
@@ -149,13 +98,13 @@
"""

def __init__(self, obj: Any) -> None:
b64object = base64.b64encode(cloudpickle.dumps(obj))
object_string_u8 = str(obj).encode("utf-8")
self._buffer = bytearray()

self._object = b64object.decode("utf-8")
self._object_string = object_string_u8.decode("utf-8")
# Reserve space for the byte offsets to be written at the end
self._buffer.extend(b"\0" * HEADER_OFFSET)

self._header = {
_header = {
"format": TOBJ_FMT_STR,
"py_version": platform.python_version(),
"cloudpickle_version": cloudpickle.__version__,
"attrs": {
@@ -164,23 +113,48 @@ def __init__(self, obj: Any) -> None:
},
}

# Write header and object string
header_u8 = json.dumps(_header).encode("utf-8")
header_len = len(header_u8)

object_string_u8 = str(obj).encode("utf-8")
object_string_len = len(object_string_u8)

self._buffer.extend(header_u8)
self._buffer.extend(object_string_u8)
del object_string_u8

# Append picklebytes (not base64-encoded)
cloudpickle.dump(obj, _ByteArrayFile(self._buffer))

# Write byte offsets
string_offset = HEADER_OFFSET + header_len
data_offset = string_offset + object_string_len

string_offset_bytes = string_offset.to_bytes(STRING_OFFSET_BYTES, BYTE_ORDER)
data_offset_bytes = data_offset.to_bytes(DATA_OFFSET_BYTES, BYTE_ORDER)
self._buffer[:STRING_OFFSET_BYTES] = string_offset_bytes
self._buffer[STRING_OFFSET_BYTES:HEADER_OFFSET] = data_offset_bytes
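
A worked example of the offset bookkeeping (a sketch; STRING_OFFSET_BYTES is defined above the visible hunk):

to = TransportableObject("workflow")   # str(obj) == "workflow", 8 UTF-8 bytes
buf = to.serialize()
# A reader recovers the offsets exactly as TOArchiveUtils does:
string_offset = int.from_bytes(buf[:STRING_OFFSET_BYTES], BYTE_ORDER)
data_offset = int.from_bytes(buf[STRING_OFFSET_BYTES:HEADER_OFFSET], BYTE_ORDER)
assert data_offset - string_offset == len("workflow")  # length of the string segment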

@property
def python_version(self):
return self._header["py_version"]
return self.header["py_version"]

@property
def header(self):
return self._header
return TOArchiveUtils.header(self._buffer)

@property
def attrs(self):
return self._header["attrs"]
return self.header["attrs"]

@property
def object_string(self):
# For compatibility with older Covalent
try:
return self._object_string
return (
TOArchiveUtils.string_segment(memoryview(self._buffer)).tobytes().decode("utf-8")
)
except AttributeError:
return self.__dict__["object_string"]

@@ -201,11 +175,15 @@ def get_deserialized(self) -> Callable:
"""

return cloudpickle.loads(base64.b64decode(self._object.encode("utf-8")))
return cloudpickle.loads(TOArchiveUtils.data_segment(memoryview(self._buffer)))

def to_dict(self) -> dict:
"""Return a JSON-serializable dictionary representation of self"""
return {"type": "TransportableObject", "attributes": self.__dict__.copy()}
attr_dict = {
"buffer_b64": base64.b64encode(memoryview(self._buffer)).decode("utf-8"),
}

return {"type": "TransportableObject", "attributes": attr_dict}

@staticmethod
def from_dict(object_dict) -> "TransportableObject":
@@ -219,7 +197,7 @@ def from_dict(object_dict) -> "TransportableObject":
"""

sc = TransportableObject(None)
sc.__dict__ = object_dict["attributes"]
sc._buffer = base64.b64decode(object_dict["attributes"]["buffer_b64"].encode("utf-8"))
return sc

def get_serialized(self) -> str:
@@ -233,7 +211,9 @@ def get_serialized(self) -> str:
object: The serialized transportable object.
"""

return self._object
# For backward compatibility
data_segment = TOArchiveUtils.data_segment(memoryview(self._buffer))
return base64.b64encode(data_segment).decode("utf-8")
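
Older consumers that expect base64 text therefore keep working. A compatibility sketch (illustrative), mirroring the updated assets test below:

to = TransportableObject(42)
s = to.get_serialized()                            # base64 text, as before
assert cloudpickle.loads(base64.b64decode(s)) == 42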

def serialize(self) -> bytes:
"""
@@ -246,7 +226,7 @@ def serialize(self) -> bytes:
pickled_object: The serialized object along with the Python version.
"""

return _to_archive(self).cat()
return self._buffer

def serialize_to_json(self) -> str:
"""
@@ -295,9 +275,7 @@ def make_transportable(obj) -> "TransportableObject":
return TransportableObject(obj)

@staticmethod
def deserialize(
serialized: bytes, *, header_only: bool = False, string_only: bool = False
) -> "TransportableObject":
def deserialize(serialized: bytes) -> "TransportableObject":
"""
Deserialize the transportable object.
@@ -307,9 +285,58 @@ def deserialize
Returns:
object: The deserialized transportable object.
"""
to = TransportableObject(None)
header = TOArchiveUtils.header(serialized)

# For backward compatibility
if header.get("format") is None:
# Re-encode TObj serialized using older versions of the SDK,
# characterized by the lack of a "format" field in the
# header. TObj was previously serialized as
# [offsets][header][string][b64-encoded picklebytes],
# whereas starting from format 0.1 we store them as
# [offsets][header][string][picklebytes].
to._buffer = TransportableObject._upgrade_tobj_format(serialized, header)
else:
to._buffer = serialized
return to
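
A round-trip sketch for the new-format path (illustrative; old-format buffers are transparently re-encoded by _upgrade_tobj_format below):

buf = TransportableObject("data").serialize()
to = TransportableObject.deserialize(buf)   # header has "format": stored as-is
assert to.get_deserialized() == "data"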

@staticmethod
def _upgrade_tobj_format(serialized: bytes, header: Dict) -> bytes:
"""Re-encode a serialized TObj in the newer format.
This involves adding a format version in the header and
base64-decoding the data segment. Because the header sits at the
beginning of the byte array and its length changes, the string and data
offsets need to be recomputed.
"""
buf = bytearray()

# Upgrade header and recompute byte offsets
header["format"] = TOBJ_FMT_STR
serialized_header = json.dumps(header).encode("utf-8")
string_offset = HEADER_OFFSET + len(serialized_header)

# This is just a view into the bytearray and consumes
# negligible space on its own.
string_segment = TOArchiveUtils.string_segment(serialized)

data_offset = string_offset + len(string_segment)
string_offset_bytes = string_offset.to_bytes(STRING_OFFSET_BYTES, BYTE_ORDER)
data_offset_bytes = data_offset.to_bytes(DATA_OFFSET_BYTES, BYTE_ORDER)

# Write the new byte offsets
buf.extend(b"\0" * HEADER_OFFSET)
buf[:STRING_OFFSET_BYTES] = string_offset_bytes
buf[STRING_OFFSET_BYTES:HEADER_OFFSET] = data_offset_bytes

ar = _TOArchive.load(serialized, header_only, string_only)
return _from_archive(ar)
buf.extend(serialized_header)
buf.extend(string_segment)

# base64-decode the data segment into raw picklebytes
buf.extend(base64.b64decode(TOArchiveUtils.data_segment(serialized)))

return buf

@staticmethod
def deserialize_list(collection: list) -> list:
@@ -356,44 +383,3 @@ def deserialize_dict(collection: dict) -> dict:
else:
raise TypeError("Couldn't deserialize collection")
return new_dict


def _to_archive(to: TransportableObject) -> _TOArchive:
"""
Convert a TransportableObject to a _TOArchive.
Args:
to: Transportable object to be converted.
Returns:
Archived transportable object.
"""

header = json.dumps(to._header).encode("utf-8")
object_string = to._object_string.encode("utf-8")
data = to._object.encode("utf-8")
return _TOArchive(header=header, object_string=object_string, data=data)


def _from_archive(ar: _TOArchive) -> TransportableObject:
"""
Convert a _TOArchive to a TransportableObject.
Args:
ar: Archived transportable object.
Returns:
Transportable object.
"""

decoded_object_str = ar.object_string.decode("utf-8")
decoded_data = ar.data.decode("utf-8")
decoded_header = json.loads(ar.header.decode("utf-8"))
to = TransportableObject(None)
to._header = decoded_header
to._object_string = decoded_object_str or ""
to._object = decoded_data or ""

return to
3 changes: 2 additions & 1 deletion tests/covalent_dispatcher_tests/_service/assets_test.py
@@ -16,6 +16,7 @@

"""Unit tests for the FastAPI asset endpoints"""

import base64
import tempfile
from contextlib import contextmanager
from typing import Generator
@@ -704,7 +705,7 @@ def test_get_pickle_offsets():

start, end = _get_tobj_pickle_offsets(f"file://{write_file.name}")

assert data[start:].decode("utf-8") == tobj.get_serialized()
assert data[start:] == base64.b64decode(tobj.get_serialized().encode("utf-8"))


def test_generate_partial_file_slice():
