Skip to content

Commit

Permalink
Merge pull request #9 from gizatechxyz/refactor-osiris-deserializer
Browse files Browse the repository at this point in the history
Refactor osiris deserializer
  • Loading branch information
raphaelDkhn authored Mar 14, 2024
2 parents 058094d + a238162 commit de7518c
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 171 deletions.
228 changes: 72 additions & 156 deletions osiris/cairo/serde/deserialize.py
Original file line number Diff line number Diff line change
@@ -1,172 +1,88 @@
import json

import numpy as np

from .utils import felt_to_int, from_fp


def deserializer(serialized: str, dtype: str):
    """Deserialize runner output into Python/NumPy values according to ``dtype``.

    Args:
        serialized: Either the raw JSON string (converted with
            ``convert_data``) or the already converted list of values.
        dtype: Type descriptor: an integer type name (``u8`` ... ``i128``),
            a name starting with ``FP``, ``Span<...>``, ``Tensor<...>`` or a
            tuple written as ``(T1, T2, ...)``.

    Returns:
        The deserialized value; a ``tuple`` when ``dtype`` is a tuple type.

    Raises:
        ValueError: If ``dtype`` is not supported, or if the number of
            serialized entries does not match the tuple description.
    """
    # Raw string input is first converted to nested lists of integers.
    if isinstance(serialized, str):
        serialized = convert_data(serialized)

    int_types = ("u8", "u16", "u32", "u64", "u128",
                 "i8", "i16", "i32", "i64", "i128")

    def deserialize_element(element, element_type):
        """Dispatch one element to the decoder matching ``element_type``."""
        if element_type in int_types:
            return deserialize_int(element)
        if element_type.startswith("FP"):
            return deserialize_fixed_point(element, element_type)
        if element_type.startswith("Span<") and element_type.endswith(">"):
            inner_type = element_type[5:-1]
            if inner_type.startswith("FP"):
                return deserialize_arr_fixed_point(element, inner_type)
            return deserialize_arr_int(element)
        if element_type.startswith("Tensor<") and element_type.endswith(">"):
            inner_type = element_type[7:-1]
            if inner_type.startswith("FP"):
                return deserialize_tensor_fixed_point(element, inner_type)
            return deserialize_tensor_int(element)
        if element_type.startswith("(") and element_type.endswith(")"):
            # Recursive call for nested tuples
            return deserializer(element, element_type)
        raise ValueError(f"Unsupported data type: {element_type}")

    if not (dtype.startswith("(") and dtype.endswith(")")):
        return deserialize_element(serialized, dtype)

    # Tuple: walk the declared types and consume the entries each one needs.
    types = dtype[1:-1].split(", ")
    deserialized_elements = []
    i = 0
    for ele_type in types:
        # A Tensor takes two entries (shape, data) and a fixed-point scalar
        # takes two entries (magnitude, sign); every other type takes one.
        width = 2 if ele_type.startswith(("Tensor<", "FP")) else 1
        chunk = serialized[i:i + width]
        if len(chunk) != width:
            raise ValueError(
                "Serialized data length does not match tuple length")
        deserialized_elements.append(deserialize_element(chunk, ele_type))
        i += width

    # Left-over entries mean the data does not fit the declared tuple.
    if i != len(serialized):
        raise ValueError(
            "Serialized data length does not match tuple length")

    return tuple(deserialized_elements)


def parse_return_value(return_value):
    """
    Parse a ReturnValue dictionary to extract the integer value or recursively
    parse an array of ReturnValues (cf: OrionRunner ReturnValues).

    Args:
        return_value: A dict holding either an ``'Int'`` key (hexadecimal
            string, with or without a ``0x`` prefix) or an ``'Array'`` key
            (list of nested ReturnValue dicts).

    Returns:
        An ``int`` for ``'Int'`` entries, or a (possibly nested) list of ints
        for ``'Array'`` entries.

    Raises:
        ValueError: If ``return_value`` is not a dict or has neither key.
    """
    # Guard against strings: `'Int' in "..."` would be a substring test and
    # the following lookup would fail with an unrelated TypeError.
    if not isinstance(return_value, dict):
        raise ValueError("Invalid ReturnValue format")

    if 'Int' in return_value:
        # Convert hexadecimal string to integer
        return int(return_value['Int'], 16)
    if 'Array' in return_value:
        # Recursively parse each item in the array
        return [parse_return_value(item) for item in return_value['Array']]
    raise ValueError("Invalid ReturnValue format")


def convert_data(data):
    """
    Convert the given JSON-like data structure to the desired format.

    Args:
        data: JSON text whose top level is a sequence of ReturnValue objects
            (each holding an ``'Int'`` or an ``'Array'`` key).

    Returns:
        A list with one parsed entry per top-level item, as produced by
        ``parse_return_value``.

    Raises:
        ValueError: If an item is not an object with an ``'Array'`` or
            ``'Int'`` key.
    """
    parsed_data = json.loads(data)
    result = []
    for item in parsed_data:
        # Both accepted shapes are handled by parse_return_value; anything
        # else (including non-dict items) is rejected here.
        if not isinstance(item, dict) or not ('Array' in item or 'Int' in item):
            raise ValueError("Invalid data format")
        result.append(parse_return_value(item))
    return result


# ================= INT =================


def deserialize_int(serialized: list) -> np.int64:
    """Decode the first entry of ``serialized`` as a signed 64-bit integer.

    The entry is passed through ``felt_to_int`` and wrapped in ``np.int64``.
    NOTE(review): values outside the int64 range (possible for u64/u128/i128)
    presumably cannot be represented here - confirm against callers.
    """
    raw_felt = serialized[0]
    return np.int64(felt_to_int(raw_felt))


# ================= FIXED POINT =================


def deserialize_fixed_point(serialized: list, impl='FP16x16') -> np.float64:
    """Decode a ``[magnitude, sign]`` pair into a float.

    Args:
        serialized: Sequence whose first entry is the raw magnitude and whose
            second entry is the sign flag (``0`` keeps the value positive,
            anything else negates it).
        impl: Fixed-point implementation name forwarded to ``from_fp``.

    Returns:
        The signed value as ``np.float64``.
    """
    magnitude = from_fp(serialized[0], impl)
    if serialized[1] != 0:
        magnitude = -magnitude
    return np.float64(magnitude)
from osiris.cairo.serde.utils import felt_to_int, from_fp


# ================= ARRAY INT =================


def deserialize_arr_int(serialized):
    """Decode a list-wrapped array of felts into a NumPy integer array.

    ``serialized[0]`` is the list of raw values; each one is converted with
    ``felt_to_int``.
    """
    serialized = serialized[0]

    deserialized = []
    for ele in serialized:
        deserialized.append(felt_to_int(ele))

    return np.array(deserialized)


# ================= ARRAY FIXED POINT =================


def deserialize_arr_fixed_point(serialized: list, impl='FP16x16'):
    """Decode a list-wrapped flat ``[mag, sign, mag, sign, ...]`` array.

    Raises:
        ValueError: If the flat array does not contain an even number of
            entries (every value needs a magnitude and a sign).
    """
    serialized = serialized[0]

    if len(serialized) % 2 != 0:
        raise ValueError("Array length must be even")

    deserialized = []
    for i in range(0, len(serialized), 2):
        mag = serialized[i]
        sign = serialized[i + 1]

        deserialized.append(deserialize_fixed_point([mag, sign], impl))

    return np.array(deserialized)


# ================= TENSOR INT =================


def deserialize_tensor_int(serialized: list) -> np.array:
    """Decode ``[shape, data]`` into an int64 array reshaped to ``shape``."""
    shape = serialized[0]
    data = deserialize_arr_int([serialized[1]])

    return np.array(data, dtype=np.int64).reshape(shape)


# ================= STRING INPUT =================


def deserializer(serialized, dtype):
    """Deserialize a whitespace-separated string according to ``dtype``.

    Args:
        serialized: Text such as ``'42'``, ``'2780037 0'``, ``'[1 2]'`` or
            ``'[2 2] [1 2 3 4]'``.
        dtype: An integer type name (``u8`` ... ``i128``), a name starting
            with ``FP``, ``Span<...>``, ``Tensor<...>`` or a tuple ``(...)``.

    Raises:
        ValueError: If ``dtype`` matches none of the supported forms.
    """
    if dtype in ["u8", "u16", "u32", "u64", "u128", "i8", "i16", "i32", "i64", "i128"]:
        return felt_to_int(int(serialized))

    elif dtype.startswith("FP"):
        # Forward the implementation name so that non-default fixed-point
        # formats are scaled with their own parameters.
        return deserialize_fp(serialized, dtype)

    elif dtype.startswith('Span<'):
        return deserialize_span(serialized, dtype)

    elif dtype.startswith('Tensor<'):
        return deserialize_tensor(serialized, dtype)

    elif dtype.startswith('('):  # Tuple
        return deserialize_tuple(serialized, dtype)

    else:
        raise ValueError(f"Unknown data type: {dtype}")


def deserialize_fp(serialized, impl=None):
    """Decode ``'<magnitude> [<sign>]'`` into a signed fixed-point value.

    Args:
        serialized: Magnitude followed by an optional sign token; a sign of
            ``'1'`` marks a negative value.
        impl: Fixed-point implementation name passed to ``from_fp``. When
            omitted, ``from_fp`` is called with its own default.
    """
    parts = serialized.split()
    magnitude = int(parts[0])
    value = from_fp(magnitude) if impl is None else from_fp(magnitude, impl)
    if len(parts) > 1 and parts[1] == '1':  # Check for negative sign
        value = -value
    return value


# ================= TENSOR FIXED POINT =================
def deserialize_span(serialized, dtype):
    """Decode a bracketed span such as ``'[1 2]'`` into a NumPy array.

    Args:
        serialized: Text of the form ``'[v0 v1 ...]'``.
        dtype: ``Span<T>``; ``T`` selects the per-element decoder.

    Returns:
        A ``float64`` array for fixed-point element types, otherwise an
        ``int64`` array.

    Raises:
        ValueError: If a fixed-point span holds an odd number of tokens.
    """
    inner_type = dtype[5:-1]
    # Drop the surrounding brackets, then split on whitespace.
    elements = serialized.strip()[1:-1].split()
    if inner_type.startswith("FP"):
        # For fixed point, elements consist of two parts (value and sign);
        # an odd count means a truncated pair, so fail loudly.
        if len(elements) % 2 != 0:
            raise ValueError("Array length must be even")
        deserialized_elements = [deserializer(' '.join(elements[i:i + 2]), inner_type)
                                 for i in range(0, len(elements), 2)]
        return np.array(deserialized_elements, dtype=np.float64)
    return np.array([deserializer(e, inner_type) for e in elements], dtype=np.int64)

def deserialize_tensor_fixed_point(serialized: list, impl='FP16x16') -> np.array:
    """Decode ``[shape, data]`` into a float64 array reshaped to ``shape``.

    ``data`` is a flat fixed-point array decoded by
    ``deserialize_arr_fixed_point`` with the given ``impl``.
    """
    dims = serialized[0]
    values = deserialize_arr_fixed_point([serialized[1]], impl)
    as_floats = np.array(values, dtype=np.float64)
    return as_floats.reshape(dims)
def deserialize_tensor(serialized, dtype):
    """Decode ``'[d0 d1 ...] [v0 v1 ...]'`` into a reshaped NumPy array.

    Args:
        serialized: Two bracketed groups: the shape, then the flat data.
        dtype: ``Tensor<T>``; ``T`` selects the per-element decoder.

    Returns:
        The decoded values reshaped to the dimensions of the first group.

    Raises:
        ValueError: If the two bracketed groups cannot be located, or if a
            fixed-point tensor holds an odd number of data tokens.
    """
    inner_type = dtype[7:-1]
    text = serialized.strip()

    # Locate the groups by their brackets rather than by an exact '] ['
    # separator, so any amount of whitespace between them is accepted.
    shape_end = text.find(']')
    data_start = text.find('[', shape_end + 1)
    if (not text.startswith('[') or not text.endswith(']')
            or shape_end < 0 or data_start < 0):
        raise ValueError(f"Malformed serialized tensor: {serialized!r}")

    dims = [int(d) for d in text[1:shape_end].split()]
    values = text[data_start + 1:-1].split()

    if inner_type.startswith("FP"):
        # Fixed-point values are stored as (value, sign) pairs.
        if len(values) % 2 != 0:
            raise ValueError("Array length must be even")
        tensor_data = np.array([deserializer(' '.join(values[i:i + 2]), inner_type)
                                for i in range(0, len(values), 2)])
    else:
        tensor_data = np.array(
            [deserializer(v, inner_type) for v in values])
    return tensor_data.reshape(dims)


def _split_tuple_tokens(serialized):
    """Split text into top-level tokens: ``[...]`` groups and bare scalars."""
    tokens = []
    pos = 0
    length = len(serialized)
    while pos < length:
        char = serialized[pos]
        if char.isspace():
            pos += 1
        elif char == '[':
            end = serialized.find(']', pos)
            if end < 0:
                raise ValueError("Unbalanced '[' in serialized tuple")
            tokens.append(serialized[pos:end + 1])
            pos = end + 1
        else:
            end = pos
            while end < length and not serialized[end].isspace() and serialized[end] != '[':
                end += 1
            tokens.append(serialized[pos:end])
            pos = end
    return tokens


def deserialize_tuple(serialized, dtype):
    """Decode a serialized tuple such as ``'[1 2] 3'`` for ``(Span<u32>, u32)``.

    The text is split into top-level tokens (bracketed groups and bare
    scalars). Each declared element type then consumes the tokens it needs:
    two for ``Tensor<...>`` (shape and data) and for ``FP...`` (value and
    sign), one for everything else. Nested tuple types are not supported.

    Args:
        serialized: Whitespace-separated serialized elements.
        dtype: Tuple descriptor ``(T1, T2, ...)``.

    Returns:
        A tuple with one deserialized value per declared type.

    Raises:
        ValueError: If the tokens do not match the declared types.
    """
    types = [t.strip() for t in dtype[1:-1].split(',')]
    tokens = _split_tuple_tokens(serialized)

    results = []
    cursor = 0
    for ele_type in types:
        width = 2 if ele_type.startswith(('Tensor<', 'FP')) else 1
        chunk = tokens[cursor:cursor + width]
        if len(chunk) != width:
            raise ValueError(
                "Serialized data length does not match tuple length")
        # Re-join so each element decoder receives its usual string form.
        results.append(deserializer(' '.join(chunk), ele_type))
        cursor += width

    if cursor != len(tokens):
        raise ValueError("Serialized data length does not match tuple length")

    return tuple(results)


def find_nth_occurrence(string, sub_string, n):
    """Return the index of the ``n``-th occurrence of ``sub_string``.

    Occurrences are counted from 1 and each search resumes one character
    after the previous match.

    Returns:
        The start index of the ``n``-th occurrence, or ``-1`` when there are
        fewer than ``n`` occurrences or ``n`` is less than 1.
    """
    # There is no "zeroth" occurrence; report it as not found.
    if n < 1:
        return -1

    index = -1
    for _ in range(n):
        index = string.find(sub_string, index + 1)
        if index < 0:
            return -1
    return index
29 changes: 14 additions & 15 deletions tests/test_deserialize.py
Original file line number Diff line number Diff line change
@@ -1,88 +1,87 @@
import numpy as np
import numpy.testing as npt
import pytest
from math import isclose

from osiris.cairo.serde.deserialize import *


def test_deserialize_int():
    """Integers decode from their decimal string form, including negatives."""
    serialized = '42'
    deserialized = deserializer(serialized, 'u32')
    assert deserialized == 42

    # Large field element that is expected to decode to -42.
    serialized = '3618502788666131213697322783095070105623107215331596699973092056135872020439'
    deserialized = deserializer(serialized, 'i32')
    assert deserialized == -42


def test_deserialize_fp():
    """A '<magnitude> <sign>' pair decodes to a signed fixed-point value."""
    serialized = '2780037 0'
    deserialized = deserializer(serialized, 'FP16x16')
    assert isclose(deserialized, 42.42, rel_tol=1e-7)

    # Sign token '1' marks the value as negative.
    serialized = '2780037 1'
    deserialized = deserializer(serialized, 'FP16x16')
    assert isclose(deserialized, -42.42, rel_tol=1e-7)


def test_deserialize_array_int():
    """Integer spans decode to int64 arrays, including negative elements."""
    serialized = '[1 2]'
    deserialized = deserializer(serialized, 'Span<u32>')
    assert np.array_equal(deserialized, np.array([1, 2], dtype=np.int64))

    serialized = '[42 3618502788666131213697322783095070105623107215331596699973092056135872020439]'
    deserialized = deserializer(serialized, 'Span<i32>')
    assert np.array_equal(deserialized, np.array([42, -42], dtype=np.int64))


def test_deserialize_arr_fixed_point():
    """Fixed-point spans decode (value, sign) pairs into a float64 array."""
    serialized = '[2780037 0 2780037 1]'
    deserialized = deserializer(serialized, 'Span<FP16x16>')
    expected = np.array([42.42, -42.42], dtype=np.float64)
    assert np.all(np.isclose(deserialized, expected, atol=1e-7))


def test_deserialize_tensor_int():
    """Integer tensors decode '[shape] [data]' into a reshaped array."""
    serialized = '[2 2] [1 2 3 4]'
    deserialized = deserializer(serialized, 'Tensor<i32>')
    assert np.array_equal(deserialized, np.array(
        ([1, 2], [3, 4]), dtype=np.int64))

    serialized = '[2 2] [42 42 3618502788666131213697322783095070105623107215331596699973092056135872020439 3618502788666131213697322783095070105623107215331596699973092056135872020439]'
    deserialized = deserializer(serialized, 'Tensor<i32>')
    assert np.array_equal(deserialized, np.array([[42, 42], [-42, -42]]))


def test_deserialize_tensor_fixed_point():
    """Fixed-point tensors decode (value, sign) pairs and apply the shape."""
    serialized = '[2 2] [2780037 0 2780037 0 2780037 1 2780037 1]'
    expected_array = np.array([[42.42, 42.42], [-42.42, -42.42]])
    deserialized = deserializer(serialized, 'Tensor<FP16x16>')
    assert np.allclose(deserialized, expected_array, atol=1e-7)


def test_deserialize_tuple_int():
    """A tuple of two integers decodes to a Python tuple of ints."""
    serialized = '1 3'
    deserialized = deserializer(serialized, '(u32, u32)')
    assert deserialized == (1, 3)


def test_deserialize_tuple_span():
    """A (Span, int) tuple decodes to an array followed by a scalar."""
    serialized = '[1 2] 3'
    deserialized = deserializer(serialized, '(Span<u32>, u32)')
    expected = (np.array([1, 2]), 3)
    npt.assert_array_equal(deserialized[0], expected[0])
    assert deserialized[1] == expected[1]


def test_deserialize_tuple_span_tensor_fp():
serialized = '[{"Array":[{"Int":"0x1"},{"Int":"0x2"}]},{"Array": [{"Int": "0x2"}, {"Int": "0x2"}]}, {"Array": [{"Int": "2A6B85"}, {"Int": "0x0"}, {"Int": "2A6B85"}, {"Int": "0x0"}, {"Int": "2A6B85"}, {"Int": "0x1"}, {"Int": "2A6B85"}, {"Int": "0x1"}]}]'
serialized = '[1 2] [2 2] [2780037 0 2780037 0 2780037 1 2780037 1]'
deserialized = deserializer(serialized, '(Span<u32>, Tensor<FP16x16>)')
expected = (np.array([1, 2]), np.array([[42.42, 42.42], [-42.42, -42.42]]))
npt.assert_array_equal(deserialized[0], expected[0])
assert np.allclose(deserialized[1], expected[1], atol=1e-7)

serialized = '[{"Array": [{"Int": "0x2"}, {"Int": "0x2"}]}, {"Array": [{"Int": "2A6B85"}, {"Int": "0x0"}, {"Int": "2A6B85"}, {"Int": "0x0"}, {"Int": "2A6B85"}, {"Int": "0x1"}, {"Int": "2A6B85"}, {"Int": "0x1"}]}, {"Array":[{"Int":"0x1"},{"Int":"0x2"}]}]'
serialized = '[2 2] [2780037 0 2780037 0 2780037 1 2780037 1] [1 2]'
deserialized = deserializer(serialized, '(Tensor<FP16x16>, Span<u32>)')
expected = (np.array([[42.42, 42.42], [-42.42, -42.42]]), np.array([1, 2]))
assert np.allclose(deserialized[0], expected[0], atol=1e-7)
Expand Down

0 comments on commit de7518c

Please sign in to comment.