From a897e0168777c1bd483ae81b807546311bd35af4 Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Wed, 21 Sep 2022 12:10:26 -0700 Subject: [PATCH 01/45] checkpoint --- wheel/Cargo.lock | 82 ++--- wheel/Cargo.toml | 5 +- wheel/python/clvm_rs/EvalError.py | 4 + wheel/python/clvm_rs/__init__.py | 6 + wheel/python/clvm_rs/clvm_tree.py | 88 +++++ wheel/python/clvm_rs/curry.py | 85 +++++ wheel/python/clvm_rs/curry_and_treehash.py | 66 ++++ wheel/python/clvm_rs/deser.py | 68 ++++ wheel/python/clvm_rs/program.py | 353 +++++++++++++++++++++ wheel/python/clvm_rs/run_program.py | 27 ++ wheel/python/clvm_rs/serialize.py | 193 +++++++++++ wheel/python/tests/__init__.py | 0 wheel/python/tests/as_python_test.py | 222 +++++++++++++ wheel/python/tests/test_program.py | 98 ++++++ wheel/python/tests/to_sexp_test.py | 177 +++++++++++ 15 files changed, 1420 insertions(+), 54 deletions(-) create mode 100644 wheel/python/clvm_rs/EvalError.py create mode 100644 wheel/python/clvm_rs/__init__.py create mode 100644 wheel/python/clvm_rs/clvm_tree.py create mode 100644 wheel/python/clvm_rs/curry.py create mode 100644 wheel/python/clvm_rs/curry_and_treehash.py create mode 100644 wheel/python/clvm_rs/deser.py create mode 100644 wheel/python/clvm_rs/program.py create mode 100644 wheel/python/clvm_rs/run_program.py create mode 100644 wheel/python/clvm_rs/serialize.py create mode 100644 wheel/python/tests/__init__.py create mode 100644 wheel/python/tests/as_python_test.py create mode 100644 wheel/python/tests/test_program.py create mode 100644 wheel/python/tests/to_sexp_test.py diff --git a/wheel/Cargo.lock b/wheel/Cargo.lock index 7c92860b..7739e468 100644 --- a/wheel/Cargo.lock +++ b/wheel/Cargo.lock @@ -179,26 +179,9 @@ checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" [[package]] name = "indoc" -version = "0.3.6" +version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "47741a8bc60fb26eb8d6e0238bbb26d8575ff623fdc97b1a2c00c050b9684ed8" -dependencies = [ - "indoc-impl", - "proc-macro-hack", -] - -[[package]] -name = "indoc-impl" -version = "0.3.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce046d161f000fffde5f432a0d034d0341dc152643b2598ed5bfce44c4f3a8f0" -dependencies = [ - "proc-macro-hack", - "proc-macro2", - "quote", - "syn", - "unindent", -] +checksum = "adab1eaa3408fb7f0c777a73e7465fd5656136fc93b670eb6df3c88c2c1344e3" [[package]] name = "instant" @@ -350,37 +333,12 @@ dependencies = [ "winapi", ] -[[package]] -name = "paste" -version = "0.1.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45ca20c77d80be666aef2b45486da86238fabe33e38306bd3118fe4af33fa880" -dependencies = [ - "paste-impl", - "proc-macro-hack", -] - -[[package]] -name = "paste-impl" -version = "0.1.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d95a7db200b97ef370c8e6de0088252f7e0dfff7d047a28528e47456c0fc98b6" -dependencies = [ - "proc-macro-hack", -] - [[package]] name = "pkg-config" version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1df8c4ec4b0627e53bdf214615ad287367e482558cf84b109250b37464dc03ae" -[[package]] -name = "proc-macro-hack" -version = "0.5.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dbf0c48bc1d91375ae5c3cd81e3722dff1abcf81a30960240640d223f59fe0e5" - [[package]] name = "proc-macro2" version = "1.0.42" @@ -392,35 +350,47 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.15.1" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7cf01dbf1c05af0a14c7779ed6f3aa9deac9c3419606ac9de537a2d649005720" +checksum = "9d1a3df45cb95bd954fac00bd9609062640fd7fb9e9946a660092c9e015421fb" dependencies = [ "cfg-if", "indoc", "libc", "parking_lot", - "paste", "pyo3-build-config", + "pyo3-ffi", "pyo3-macros", "unindent", ] [[package]] name = "pyo3-build-config" -version = "0.15.2" +version = "0.16.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "779239fc40b8e18bc8416d3a37d280ca9b9fb04bda54b98037bb6748595c2410" +checksum = "9c819d397859445928609d0ec5afc2da5204e0d0f73d6bf9e153b04e83c9cdc2" dependencies = [ "once_cell", + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9a4e2f74dc77eea5ce11d19f0afaeb632b6590f8cbb1d5ee2f1330b766803e8" +dependencies = [ + "libc", + "pyo3-build-config", ] [[package]] name = "pyo3-macros" -version = "0.15.1" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67701eb32b1f9a9722b4bc54b548ff9d7ebfded011c12daece7b9063be1fd755" +checksum = "dbff3a1579934968a53bcc78ac33663ed2577accda05d484097679cc8d28e52d" dependencies = [ + "proc-macro2", "pyo3-macros-backend", "quote", "syn", @@ -428,9 +398,9 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.15.1" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f44f09e825ee49a105f2c7b23ebee50886a9aee0746f4dd5a704138a64b0218a" +checksum = "90644126c8c1ac7b47f794dd20a5729f8646b91c49edb31689c90f8cb3c33ea9" dependencies = [ "proc-macro2", "pyo3-build-config", @@ -514,6 +484,12 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" +[[package]] +name = "target-lexicon" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c02424087780c9b71cc96799eaeddff35af2bc513278cda5c99fc1f5d026d3c1" + [[package]] name = "typenum" version = "1.15.0" diff --git a/wheel/Cargo.toml b/wheel/Cargo.toml index 8ce23c6a..1994cf58 100644 --- a/wheel/Cargo.toml +++ b/wheel/Cargo.toml @@ -16,7 +16,10 @@ path = "src/lib.rs" [dependencies] clvmr = { path = ".." } -pyo3 = { version = "=0.15.1", features = ["abi3-py37", "extension-module"] } +pyo3 = { version = "=0.16.0", features = ["abi3-py37", "extension-module"] } [features] openssl = ["clvmr/openssl"] + +[package.metadata.maturin] +python-source = "python" diff --git a/wheel/python/clvm_rs/EvalError.py b/wheel/python/clvm_rs/EvalError.py new file mode 100644 index 00000000..f71f912a --- /dev/null +++ b/wheel/python/clvm_rs/EvalError.py @@ -0,0 +1,4 @@ +class EvalError(Exception): + def __init__(self, message: str, sexp): + super().__init__(message) + self._sexp = sexp diff --git a/wheel/python/clvm_rs/__init__.py b/wheel/python/clvm_rs/__init__.py new file mode 100644 index 00000000..03ab351e --- /dev/null +++ b/wheel/python/clvm_rs/__init__.py @@ -0,0 +1,6 @@ +from .clvm_rs import * + + +__doc__ = clvm_rs.__doc__ +if hasattr(clvm_rs, "__all__"): + __all__ = clvm_rs.__all__ diff --git a/wheel/python/clvm_rs/clvm_tree.py b/wheel/python/clvm_rs/clvm_tree.py new file mode 100644 index 00000000..dcbc2221 --- /dev/null +++ b/wheel/python/clvm_rs/clvm_tree.py @@ -0,0 +1,88 @@ +from clvm_rs import deserialize_as_triples + + +from typing import List, Optional, Tuple + + +class CLVMTree: + """ + This object conforms with the `CLVMObject` protocol. It's optimized for + deserialization, and keeps a reference to the serialized blob and to a + list of triples of integers, each of which corresponds to a subtree. + + It turns out every atom serialized to a blob contains a substring that + exactly matches that atom, so it ends up being not very wasteful to + simply use the blob for atom storage (especially if it's a `memoryview`, + from which you can take substrings without copying). Additionally, the + serialization for every object in the tree exactly corresponds to a + substring in the blob, so by clever caching we can very quickly generate + serializations for any subtree. + + The deserializer iterates through the blob and caches a triple of + integers for each subtree: the first two integers represent the + `(start_offset, end_offset)` within the blob that corresponds to the + serialization of that object. You can check the contents of + `blob[start_offset]` to determine if the object is a pair (in which case + that byte is 0xff) or an atom (anything else). For a pair, the third + number corresponds to the index of the array that is the "rest" of the + pair (the "first" is always this object's index plus one, so we don't + need to save that); for an atom, the third number corresponds to an + offset of where the atom's binary data is relative to + `blob[start_offset]` (so the atom data is at `blob[triple[0] + + triple[2]:triple[1]]`) + + Since each `CLVMTree` subtree keeps a reference to the original + serialized data and the list of triples, no memory is released until all + objects in the tree are garbage-collected. This happens pretty naturally + in well-behaved python code. + """ + + @classmethod + def from_bytes(cls, blob: bytes) -> "CLVMTree": + return cls(memoryview(blob), deserialize_as_triples(blob), 0) + + def __init__( + self, blob: bytes, int_triples: List[Tuple[int, int, int]], index: int + ): + self.blob = blob + self.int_triples = int_triples + self.index = index + + @property + def atom(self) -> Optional[bytes]: + if not hasattr(self, "_atom"): + start, end, atom_offset = self.int_triples[self.index] + # if `self.blob[start]` is 0xff, it's a pair + if self.blob[start] == 0xFF: + self._atom = None + else: + self._atom = bytes(self.blob[start + atom_offset:end]) + return self._atom + + @property + def pair(self) -> Optional[Tuple["CLVMTree", "CLVMTree"]]: + if not hasattr(self, "_pair"): + triples = self.int_triples + start, end, right_index = triples[self.index] + # if `self.blob[start]` is 0xff, it's a pair + if self.blob[start] == 0xFF: + left = self.__class__(self.blob, triples, self.index + 1) + right = self.__class__(self.blob, triples, right_index) + self._pair = (left, right) + else: + self._pair = None + return self._pair + + @property + def _cached_serialization(self) -> bytes: + start, end, _ = self.int_triples[self.index] + return self.blob[start:end] + + def __bytes__(self) -> bytes: + return bytes(self._cached_serialization) + + def __str__(self) -> str: + return bytes(self).hex() + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}: {self}>" diff --git a/wheel/python/clvm_rs/curry.py b/wheel/python/clvm_rs/curry.py new file mode 100644 index 00000000..64b60ea5 --- /dev/null +++ b/wheel/python/clvm_rs/curry.py @@ -0,0 +1,85 @@ +from typing import Any, Optional, Tuple + +CLVM = Any + + +def at(sexp: CLVM, position: str) -> Optional[CLVM]: + """ + Take a string of only `f` and `r` characters and follow the corresponding path. + + Example: + + `assert Program.to(17) == Program.to([10, 20, 30, [15, 17], 40, 50]).at("rrrfrf")` + + """ + v = sexp + for c in position.lower(): + p = v.pair + if p is None: + return p + if c not in "rf": + raise ValueError( + f"`at` got illegal character `{c}`. Only `f` & `r` allowed" + ) + v = p[0 if c == "f" else 1] + return v + + +# Replicates the curry function from clvm_tools, taking advantage of *args +# being a list. We iterate through args in reverse building the code to +# create a clvm list. +# +# Given arguments to a function addressable by the '1' reference in clvm +# +# fixed_args = 1 +# +# Each arg is prepended as fixed_args = (c (q . arg) fixed_args) +# +# The resulting argument list is interpreted with apply (2) +# +# (2 (1 . self) rest) +# +# Resulting in a function which places its own arguments after those +# curried in in the form of a proper list. + + +def curry(sexp: CLVM, *args) -> CLVM: + fixed_args: Any = 1 + while args: + arg = args.pop() + fixed_args = [4, (1, arg), fixed_args] + return sexp.to([2, (1, sexp), fixed_args]) + + +# UNCURRY_PATTERN_FUNCTION = assemble("(a (q . (: . function)) (: . core))") +# UNCURRY_PATTERN_CORE = assemble("(c (q . (: . parm)) (: . core))") + + +ONE_PATH = Q_KW = bytes([1]) +C_KW = bytes([2]) +A_KW = bytes([4]) +NULL = bytes([]) + + +def uncurry(sexp: CLVM) -> Optional[Tuple[CLVM, CLVM]]: + if ( + at(sexp, "f").atom != A_KW + or at(sexp, "rf").atom != Q_KW + or at(sexp, "rrr").atom != NULL + ): + return None + uncurried_function = at(sexp, "rr") + core_items = [] + core = at(sexp, "rrf") + while core.atom != ONE_PATH: + if ( + at(core, "f").atom != C_KW + or at(core, "rf").atom != Q_KW + or at(sexp, "rrr").atom != NULL + ): + return None + new_item = at(core, "rr") + core_items.append(new_item) + core = at(core, "rrf") + core_items.reverse() + return uncurried_function, core_items diff --git a/wheel/python/clvm_rs/curry_and_treehash.py b/wheel/python/clvm_rs/curry_and_treehash.py new file mode 100644 index 00000000..2ae64137 --- /dev/null +++ b/wheel/python/clvm_rs/curry_and_treehash.py @@ -0,0 +1,66 @@ +from typing import List + +from .bytes32 import bytes32 +from .tree_hash import shatree_atom, shatree_pair + + +NULL = bytes.fromhex("") +ONE = bytes.fromhex("01") +TWO = bytes.fromhex("02") +Q_KW = bytes.fromhex("01") +A_KW = bytes.fromhex("02") +C_KW = bytes.fromhex("04") + + +Q_KW_TREEHASH = shatree_atom(Q_KW) +A_KW_TREEHASH = shatree_atom(A_KW) +C_KW_TREEHASH = shatree_atom(C_KW) +ONE_TREEHASH = shatree_atom(ONE) +NULL_TREEHASH = shatree_atom(NULL) + + +# The environment `E = (F . R)` recursively expands out to +# `(c . ((q . F) . EXPANSION(R)))` if R is not 0 +# `1` if R is 0 + + +def curried_values_tree_hash(arguments: List[bytes32]) -> bytes32: + if len(arguments) == 0: + return ONE_TREEHASH + + inner_curried_values = curried_values_tree_hash(arguments[1:]) + + return shatree_pair( + C_KW_TREEHASH, + shatree_pair( + shatree_pair(Q_KW_TREEHASH, arguments[0]), + shatree_pair(inner_curried_values, NULL_TREEHASH), + ), + ) + + +# The curry pattern is `(a . ((q . F) . (E . 0)))` == `(a (q . F) E) +# where `F` is the `mod` and `E` is the curried environment + + +def curry_and_treehash( + hash_of_quoted_mod_hash: bytes32, *hashed_arguments: bytes32 +) -> bytes32: + """ + `hash_of_quoted_mod_hash` : tree hash of `(q . MOD)` where `MOD` + is template to be curried + `arguments` : tree hashes of arguments to be curried + """ + + curried_values = curried_values_tree_hash(list(hashed_arguments)) + return shatree_pair( + A_KW_TREEHASH, + shatree_pair( + hash_of_quoted_mod_hash, + shatree_pair(curried_values, NULL_TREEHASH), + ), + ) + + +def calculate_hash_of_quoted_mod_hash(mod_hash: bytes32) -> bytes32: + return shatree_pair(Q_KW_TREEHASH, mod_hash) diff --git a/wheel/python/clvm_rs/deser.py b/wheel/python/clvm_rs/deser.py new file mode 100644 index 00000000..220b0b01 --- /dev/null +++ b/wheel/python/clvm_rs/deser.py @@ -0,0 +1,68 @@ +from typing import Tuple + +from .clvm_tree import CLVMTree + +MAX_SINGLE_BYTE = 0x7F +CONS_BOX_MARKER = 0xFF + + +# ATOM: serialize_offset, serialize_end, atom_offset +# PAIR: serialize_offset, serialize_end, right_index + + +def deserialized_in_place(blob: bytes, cursor: int = 0) -> CLVMTree: + def save_cursor(index, blob, cursor, obj_list, op_stack): + obj_list[index] = (obj_list[index][0], cursor, obj_list[index][2]) + return cursor + + def save_index(index, blob, cursor, obj_list, op_stack): + obj_list[index][2] = len(obj_list) + return cursor + + def parse_obj(blob, cursor, obj_list, op_stack): + if cursor >= len(blob): + raise ValueError("bad encoding") + + if blob[cursor] == CONS_BOX_MARKER: + index = len(obj_list) + obj_list.append([cursor, None, None]) + op_stack.append(lambda *args: save_cursor(index, *args)) + op_stack.append(parse_obj) + op_stack.append(lambda *args: save_index(index, *args)) + op_stack.append(parse_obj) + return cursor + 1 + atom_offset, new_cursor = _atom_size_from_cursor(blob, cursor) + obj_list.append((cursor, new_cursor, atom_offset)) + return new_cursor + + obj_list = [] + op_stack = [parse_obj] + while op_stack: + f = op_stack.pop() + cursor = f(blob, cursor, obj_list, op_stack) + + v = CLVMTree(blob, obj_list, 0) + return v + + +def _atom_size_from_cursor(blob, cursor) -> Tuple[int, int]: + # return `(size_of_prefix, cursor)` + b = blob[cursor] + if b == 0x80: + return 1, cursor + 1 + if b <= MAX_SINGLE_BYTE: + return 0, cursor + 1 + bit_count = 0 + bit_mask = 0x80 + while b & bit_mask: + bit_count += 1 + b &= 0xFF ^ bit_mask + bit_mask >>= 1 + size_blob = bytes([b]) + if bit_count > 1: + breakpoint() + size_blob += blob[cursor + 1 : cursor + bit_count] + size = int.from_bytes(size_blob, "big") + if size >= 0x400000000: + raise ValueError("blob too large") + return bit_count, cursor + size + bit_count diff --git a/wheel/python/clvm_rs/program.py b/wheel/python/clvm_rs/program.py new file mode 100644 index 00000000..a61cb304 --- /dev/null +++ b/wheel/python/clvm_rs/program.py @@ -0,0 +1,353 @@ +from __future__ import annotations +import io +from typing import Dict, Iterator, List, Tuple, Optional, Any + +# from clvm import Program +from .base import CLVMObject +from .casts import to_sexp_type +from clvm_rs.clvm_rs import run_serialized_program +from clvm_rs.serialize import sexp_from_stream, sexp_to_stream +from clvm_rs.tree_hash import sha256_treehash +from .clvm_tree import CLVMTree +from .bytes32 import bytes32 + +# from chia.util.hash import std_hash +# from chia.util.byte_types import hexstr_to_bytes +# from chia.types.spend_bundle_conditions import SpendBundleConditions + + +INFINITE_COST = 0x7FFFFFFFFFFFFFFF + +NULL = bytes.fromhex("") +ONE = bytes.fromhex("01") +TWO = bytes.fromhex("02") +Q_KW = bytes.fromhex("01") +A_KW = bytes.fromhex("02") +C_KW = bytes.fromhex("04") + + +class Program(CLVMObject): + """ + A thin wrapper around s-expression data intended to be invoked with "eval". + """ + + # serialization/deserialization + + @classmethod + def parse(cls, f) -> Program: + return sexp_from_stream(f, cls.new_pair, cls.new_atom) + + def stream(self, f): + sexp_to_stream(self, f) + + @classmethod + def from_bytes(cls, blob: bytes) -> Program: + obj, cursor = cls.from_bytes_with_cursor(blob, 0) + return obj + + @classmethod + def from_bytes_with_cursor( + cls, blob: bytes, cursor: int + ) -> Tuple[Program, int]: + tree = CLVMTree.from_bytes(blob) + new_cursor = tree[-1][1] + obj = cls.wrap(tree) + return obj, new_cursor + + @classmethod + def fromhex(cls, hexstr: str) -> Program: + return cls.from_bytes(bytes.fromhex(hexstr)) + + def __bytes__(self) -> bytes: + f = io.BytesIO() + self.stream(f) # noqa + return f.getvalue() + + # high level casting with `.to` + + @classmethod + def to(cls, v: Any) -> Program: + return to_sexp_type(v, cls.new_atom, cls.new_pair) + + @classmethod + def wrap(cls, v: CLVMObject) -> Program: + if isinstance(v, Program): + return v + o = cls() + o.atom = v.atom + o.pair = v.pair + return o + + # new object creation on the python heap + + @classmethod + def new_atom(cls, v: bytes) -> Program: + o = cls() + o.atom = v + o.pair = None + return o + + @classmethod + def new_pair(cls, left: CLVMObject, right: CLVMObject) -> Program: + o = cls() + o.atom = None + o.pair = (left, right) + return o + + @classmethod + def null(cls) -> Program: + return NULL_PROGRAM + + # display + + def __str__(self) -> str: + return bytes(self).hex() + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({str(self)})" + + def __eq__(self, other) -> bool: + stack = [(self, Program.to(other))] + while stack: + p1, p2 = stack.pop() + if p1.atom is None: + if p2.atom is not None: + return False + stack.append((p1.pair[1], p2.pair[1])) + stack.append((p1.pair[0], p2.pair[0])) + else: + if p1.atom != p2.atom: + return False + return True + + def __ne__(self, other) -> bool: + return not self.__eq__(other) + + def first(self) -> Optional[Program]: + if self.pair: + return self.wrap(self.pair[0]) + return None + + def rest(self) -> Optional[Program]: + if self.pair: + return self.wrap(self.pair[1]) + return None + + def as_pair(self) -> Optional[Tuple[Program, Program]]: + if self.pair: + return tuple(self.wrap(_) for _ in self.pair) + return None + + def as_atom(self) -> Optional[bytes]: + return self.atom + + def listp(self) -> bool: + return self.pair is not None + + def nullp(self) -> bool: + return self.atom == b"" + + def list_len(self) -> int: + c = 0 + v = self + while v.pair: + v = v.pair[1] + c += 1 + return c + + def at(self, position: str) -> "Program": + """ + Take a string of `f` and `r` characters and follow that path. + + Example: + + ``` + p1 = Program.to([10, 20, 30, [15, 17], 40, 50]) + assert Program.to(17) == at(p1, "rrrfrf") + ``` + + Returns `None` if an atom is hit at some intermediate node. + + ``` + p1 = Program.to(10) + assert None == at(p1, "rr") + ``` + + """ + v = self + for c in position.lower(): + if c == "f": + v = v.first() + elif c == "r": + v = v.rest() + else: + raise ValueError( + f"`at` got illegal character `{c}`. Only `f` & `r` allowed" + ) + return v + + def replace(self, **kwargs) -> "Program": + """ + Create a new program replacing the given paths (using `at` syntax). + Example: + ``` + >>> p1 = Program.to([100, 200, 300]) + >>> print(p1.replace(f=105) == Program.to([105, 200, 300])) + True + >>> p2 = [100, 200, [301, 302]] + >>> print(p1.replace(rrf=[301, 302]) == Program.to(p2)) + True + >>> p2 = [105, 200, [301, 302]] + >>> print(p1.replace(f=105, rrf=[301, 302]) == Program.to(p2)) + True + ``` + + This is a convenience method intended for use in the wallet or + command-line hacks where it would be easier to morph elements + of an existing clvm object tree than to rebuild one from scratch. + + Note that `Program` objects are immutable. This function returns a + new object; the original is left as-is. + """ + return _replace(self, **kwargs) + + def tree_hash(self) -> bytes32: + return sha256_treehash(bytes(self)) + + def run_with_cost(self, max_cost: int, args) -> Tuple[int, "Program"]: + prog_bytes = bytes(self) + args_bytes = bytes(self.to(args)) + cost, r = run_serialized_program(prog_bytes, args_bytes, max_cost, 0) + return cost, Program.to(r) + + def run(self, args) -> "Program": + cost, r = self.run_with_cost(INFINITE_COST, args) + return r + + # Replicates the curry function from clvm_tools, taking advantage of *args + # being a list. We iterate through args in reverse building the code to + # create a clvm list. + # + # Given arguments to a function addressable by the '1' reference in clvm + # + # fixed_args = 1 + # + # Each arg is prepended as fixed_args = (c (q . arg) fixed_args) + # + # The resulting argument list is interpreted with apply (2) + # + # (2 (1 . self) rest) + # + # Resulting in a function which places its own arguments after those + # curried in in the form of a proper list. + def curry(self, *args) -> "Program": + fixed_args: Any = 1 + for arg in reversed(args): + fixed_args = [4, (1, arg), fixed_args] + return Program.to([2, (1, self), fixed_args]) + + def uncurry(self) -> Optional[Tuple[Program, Program]]: + if ( + self.at("f").atom != A_KW + or self.at("rf").atom != Q_KW + or self.at("rrr").atom != NULL + ): + return None + uncurried_function = self.at("rr") + core_items = [] + core = self.at("rrf") + while core.atom != ONE: + if ( + core.at("f").atom != C_KW + or core.at("rf").atom != Q_KW + or core.at("rrr").atom != NULL + ): + return None + new_item = core.at("rr") + core_items.append(new_item) + core = core.at("rrf") + core_items.reverse() + return uncurried_function, core_items + + def as_int(self) -> int: + return int_from_bytes(self.as_atom()) + + def as_iter(self) -> Iterator[Program]: + v = self + while v.pair: + yield v.pair[0] + v = v.pair[1] + + def as_atom_iter(self) -> Iterator[bytes]: + """ + Pretend `self` is a list of atoms. Yield the corresponding atoms. + + At each step, we always assume a node to be an atom or a pair. + If the assumption is wrong, we exit early. This way we never fail + and always return SOMETHING. + """ + obj = self + while obj.pair is not None: + left, obj = obj.pair + atom = left.atom + if atom is None: + break + yield atom + + def as_atom_list(self) -> List[bytes]: + """ + Pretend `self` is a list of atoms. Return the corresponding + python list of atoms. + + At each step, we always assume a node to be an atom or a pair. + If the assumption is wrong, we exit early. This way we never fail + and always return SOMETHING. + """ + return list(self.as_atom_iter()) + + def __deepcopy__(self, memo): + return type(self).from_bytes(bytes(self)) + + +NULL_PROGRAM = Program.from_bytes(b"\x80") + + +def _replace(program: Program, **kwargs) -> Program: + # if `kwargs == {}` then `return program` unchanged + if len(kwargs) == 0: + return program + + if "" in kwargs: + if len(kwargs) > 1: + raise ValueError("conflicting paths") + return kwargs[""] + + # we've confirmed that no `kwargs` is the empty string. + # Now split `kwargs` into two groups: those + # that start with `f` and those that start with `r` + + args_by_prefix: Dict[str, Program] = {} + for k, v in kwargs.items(): + c = k[0] + if c not in "fr": + raise ValueError( + f"bad path containing {c}: must only contain `f` and `r`" + ) + args_by_prefix.setdefault(c, dict())[k[1:]] = v + + pair = program.pair + if pair is None: + raise ValueError("path into atom") + + # recurse down the tree + new_f = _replace(pair[0], **args_by_prefix.get("f", {})) + new_r = _replace(pair[1], **args_by_prefix.get("r", {})) + + return program.new_pair((new_f, new_r)) + + +def int_from_bytes(blob): + size = len(blob) + if size == 0: + return 0 + return int.from_bytes(blob, "big", signed=True) diff --git a/wheel/python/clvm_rs/run_program.py b/wheel/python/clvm_rs/run_program.py new file mode 100644 index 00000000..70932636 --- /dev/null +++ b/wheel/python/clvm_rs/run_program.py @@ -0,0 +1,27 @@ +from typing import Tuple + +from clvm_rs import CLVMObject, run_serialized_program, NO_NEG_DIV + + +from .EvalError import EvalError +from .serialize import sexp_to_bytes + + +DEFAULT_MAX_COST = (1 << 64) - 1 +DEFAULT_FLAGS = NO_NEG_DIV + + +def run_program( + program: CLVMObject, + args: CLVMObject, + max_cost=DEFAULT_MAX_COST, + flags=DEFAULT_FLAGS, +) -> Tuple[int, CLVMObject]: + program_blob = sexp_to_bytes(program) + args_blob = sexp_to_bytes(args) + cost_or_err_str, result = run_serialized_program( + program_blob, args_blob, max_cost, flags + ) + if isinstance(cost_or_err_str, str): + raise EvalError(cost_or_err_str, result) + return cost_or_err_str, result diff --git a/wheel/python/clvm_rs/serialize.py b/wheel/python/clvm_rs/serialize.py new file mode 100644 index 00000000..bbffa44e --- /dev/null +++ b/wheel/python/clvm_rs/serialize.py @@ -0,0 +1,193 @@ +# decoding: +# read a byte +# if it's 0xfe, it's nil (which might be same as 0) +# if it's 0xff, it's a cons box. Read two items, build cons +# otherwise, number of leading set bits is length in bytes to read size +# 0-0x7f are literal one byte values +# leading bits is the count of bytes to read of size +# 0x80-0xbf is a size of one byte (`and` of first byte with 0x3f for size) +# 0xc0-0xdf is a size of two bytes (`and` of first byte with 0x1f) +# 0xe0-0xef is 3 bytes (`and` of first byte with 0xf) +# 0xf0-0xf7 is 4 bytes (`and` of first byte with 0x7) +# 0xf7-0xfb is 5 bytes (`and` of first byte with 0x3) + + +MAX_SINGLE_BYTE = 0x7F +CONS_BOX_MARKER = 0xFF + + +def sexp_to_byte_iterator(sexp): + todo_stack = [sexp] + while todo_stack: + sexp = todo_stack.pop() + r = getattr(sexp, "_cached_serialization", None) + if r is not None: + yield r + continue + pair = sexp.pair + if pair: + yield bytes([CONS_BOX_MARKER]) + todo_stack.append(pair[1]) + todo_stack.append(pair[0]) + else: + yield from atom_to_byte_iterator(sexp.atom) + + +def atom_to_byte_iterator(as_atom): + size = len(as_atom) + if size == 0: + yield b"\x80" + return + if size == 1: + if as_atom[0] <= MAX_SINGLE_BYTE: + yield as_atom + return + if size < 0x40: + size_blob = bytes([0x80 | size]) + elif size < 0x2000: + size_blob = bytes([0xC0 | (size >> 8), (size >> 0) & 0xFF]) + elif size < 0x100000: + size_blob = bytes([0xE0 | (size >> 16), (size >> 8) & 0xFF, (size >> 0) & 0xFF]) + elif size < 0x8000000: + size_blob = bytes( + [ + 0xF0 | (size >> 24), + (size >> 16) & 0xFF, + (size >> 8) & 0xFF, + (size >> 0) & 0xFF, + ] + ) + elif size < 0x400000000: + size_blob = bytes( + [ + 0xF8 | (size >> 32), + (size >> 24) & 0xFF, + (size >> 16) & 0xFF, + (size >> 8) & 0xFF, + (size >> 0) & 0xFF, + ] + ) + else: + raise ValueError("sexp too long %s" % as_atom) + + yield size_blob + yield as_atom + + +def sexp_to_stream(sexp, f): + for b in sexp_to_byte_iterator(sexp): + f.write(b) + + +def sexp_to_bytes(sexp) -> bytes: + b = bytearray() + for _ in sexp_to_byte_iterator(sexp): + b.extend(_) + return bytes(b) + + +def _op_read_sexp(op_stack, val_stack, f, new_pair_f, new_atom_f): + blob = f.read(1) + if len(blob) == 0: + raise ValueError("bad encoding") + b = blob[0] + if b == CONS_BOX_MARKER: + op_stack.append(_op_cons) + op_stack.append(_op_read_sexp) + op_stack.append(_op_read_sexp) + return + val_stack.append(_atom_from_stream(f, b, new_atom_f)) + + +def _op_cons(op_stack, val_stack, f, new_pair_f, new_atom_f): + right = val_stack.pop() + left = val_stack.pop() + val_stack.append(new_pair_f((left, right))) + + +def sexp_from_stream(f, new_pair_f, new_atom_f): + op_stack = [_op_read_sexp] + val_stack = [] + + while op_stack: + func = op_stack.pop() + func(op_stack, val_stack, f, new_pair_f, new_atom_f) + return val_stack.pop() + + +def _op_consume_sexp(f): + blob = f.read(1) + if len(blob) == 0: + raise ValueError("bad encoding") + b = blob[0] + if b == CONS_BOX_MARKER: + return (blob, 2) + return (_consume_atom(f, b), 0) + + +def _consume_atom(f, b): + if b == 0x80: + return bytes([b]) + if b <= MAX_SINGLE_BYTE: + return bytes([b]) + bit_count = 0 + bit_mask = 0x80 + ll = b + while ll & bit_mask: + bit_count += 1 + ll &= 0xFF ^ bit_mask + bit_mask >>= 1 + size_blob = bytes([ll]) + if bit_count > 1: + ll = f.read(bit_count - 1) + if len(ll) != bit_count - 1: + raise ValueError("bad encoding") + size_blob += ll + size = int.from_bytes(size_blob, "big") + if size >= 0x400000000: + raise ValueError("blob too large") + blob = f.read(size) + if len(blob) != size: + raise ValueError("bad encoding") + return bytes([b]) + size_blob[1:] + blob + + +# instead of parsing the input stream, this function pulls out all the bytes +# that represent on S-expression tree, and returns them. This is more efficient +# than parsing and returning a python S-expression tree. +def sexp_buffer_from_stream(f): + ret = b"" + + depth = 1 + while depth > 0: + depth -= 1 + buf, d = _op_consume_sexp(f) + depth += d + ret += buf + return ret + + +def _atom_from_stream(f, b, new_atom_f): + if b == 0x80: + return new_atom_f(b"") + if b <= MAX_SINGLE_BYTE: + return new_atom_f(bytes([b])) + bit_count = 0 + bit_mask = 0x80 + while b & bit_mask: + bit_count += 1 + b &= 0xFF ^ bit_mask + bit_mask >>= 1 + size_blob = bytes([b]) + if bit_count > 1: + b = f.read(bit_count - 1) + if len(b) != bit_count - 1: + raise ValueError("bad encoding") + size_blob += b + size = int.from_bytes(size_blob, "big") + if size >= 0x400000000: + raise ValueError("blob too large") + blob = f.read(size) + if len(blob) != size: + raise ValueError("bad encoding") + return new_atom_f(blob) diff --git a/wheel/python/tests/__init__.py b/wheel/python/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/wheel/python/tests/as_python_test.py b/wheel/python/tests/as_python_test.py new file mode 100644 index 00000000..01fb5dfa --- /dev/null +++ b/wheel/python/tests/as_python_test.py @@ -0,0 +1,222 @@ +import unittest + +from clvm_rs.program import Program + +from blspy import G1Element + + +class dummy_class: + def __init__(self): + self.i = 0 + + +def gen_tree(depth: int) -> Program: + if depth == 0: + return Program.to(1337) + subtree = gen_tree(depth - 1) + return Program.to((subtree, subtree)) + + +fh = bytes.fromhex +H01 = fh("01") +H02 = fh("02") + + +class AsPythonTest(unittest.TestCase): + def check_as_atom_list(self, p): + v = Program.to(p) + p1 = v.as_atom_list() + self.assertEqual(p, p1) + + def test_null(self): + self.check_as_atom_list([]) + + def test_single_bytes(self): + for _ in range(256): + self.check_as_atom_list([bytes([_])]) + + def test_short_lists(self): + self.check_as_atom_list([]) + for _ in range(256): + for size in range(1, 5): + self.check_as_atom_list([bytes([_])] * size) + + def test_int(self): + v = Program.to(42) + self.assertEqual(v.atom, bytes([42])) + + def test_none(self): + v = Program.to(None) + self.assertEqual(v.atom, b"") + + def test_empty_list(self): + v = Program.to([]) + self.assertEqual(v.atom, b"") + + def test_list_of_one(self): + v = Program.to([1]) + self.assertEqual(type(v.pair[0]), Program) + self.assertEqual(type(v.pair[1]), Program) + self.assertEqual(type(v.as_pair()[0]), Program) + self.assertEqual(type(v.as_pair()[1]), Program) + self.assertEqual(v.pair[0].atom, b"\x01") + self.assertEqual(v.pair[1].atom, b"") + + def test_g1element(self): + b = fh( + "b3b8ac537f4fd6bde9b26221d49b54b17a506be147347dae5" + "d081c0a6572b611d8484e338f3432971a9823976c6a232b" + ) + v = Program.to(G1Element(b)) + self.assertEqual(v.atom, b) + + def test_complex(self): + self.check_as_atom_list([b"foo"]) + self.check_as_atom_list([b"2", b"1"]) + self.check_as_atom_list([b"", b"2", b"1"]) + self.check_as_atom_list([b"", b"1", b"2", b"30", b"40", b"90", b"600"]) + + def test_listp(self): + self.assertEqual(Program.to(42).listp(), False) + self.assertEqual(Program.to(b"").listp(), False) + self.assertEqual(Program.to(b"1337").listp(), False) + + self.assertEqual(Program.to((1337, 42)).listp(), True) + self.assertEqual(Program.to([1337, 42]).listp(), True) + + def test_nullp(self): + self.assertEqual(Program.to(b"").nullp(), True) + self.assertEqual(Program.to(b"1337").nullp(), False) + self.assertEqual(Program.to((b"", b"")).nullp(), False) + + def test_constants(self): + self.assertEqual(Program.null().nullp(), True) + + def test_list_len(self): + v = Program.to(42) + for i in range(100): + self.assertEqual(v.list_len(), i) + v = Program.to((42, v)) + self.assertEqual(v.list_len(), 100) + + def test_list_len_atom(self): + v = Program.to(42) + self.assertEqual(v.list_len(), 0) + + def test_as_int(self): + self.assertEqual(Program.to(fh("80")).as_int(), -128) + self.assertEqual(Program.to(fh("ff")).as_int(), -1) + self.assertEqual(Program.to(fh("0080")).as_int(), 128) + self.assertEqual(Program.to(fh("00ff")).as_int(), 255) + + def test_string(self): + self.assertEqual(Program.to("foobar").as_atom(), b"foobar") + + def test_deep_recursion(self): + d = b"2" + for i in range(1000): + d = [d] + v = Program.to(d) + for i in range(1000): + self.assertEqual(v.as_pair()[1].as_atom(), Program.null()) + v = v.as_pair()[0] + d = d[0] + + self.assertEqual(v.as_atom(), b"2") + self.assertEqual(d, b"2") + + def test_long_linked_list(self): + d = b"" + for i in range(1000): + d = (b"2", d) + v = Program.to(d) + for i in range(1000): + self.assertEqual(v.as_pair()[0].as_atom(), d[0]) + v = v.as_pair()[1] + d = d[1] + + self.assertEqual(v.as_atom(), b"") + self.assertEqual(d, b"") + + def test_long_list(self): + d = [1337] * 1000 + v = Program.to(d) + for i in range(1000): + self.assertEqual(v.as_pair()[0].as_int(), d[i]) + v = v.as_pair()[1] + + self.assertEqual(v.as_atom(), b"") + + def test_invalid_tuple(self): + with self.assertRaises(ValueError): + s = Program.to((dummy_class, dummy_class)) + + with self.assertRaises(ValueError): + s = Program.to((dummy_class, dummy_class, dummy_class)) + + def test_clvm_object_tuple(self): + o1 = Program.to(b"foo") + o2 = Program.to(b"bar") + self.assertEqual(Program.to((o1, o2)), (o1, o2)) + + def test_first(self): + val = Program.to(1) + self.assertEqual(val.first(), None) + val = Program.to((42, val)) + self.assertEqual(val.first(), Program.to(42)) + + def test_rest(self): + val = Program.to(1) + self.assertEqual(val.first(), None) + val = Program.to((42, val)) + self.assertEqual(val.rest(), Program.to(1)) + + def test_as_iter(self): + val = list(Program.to((1, (2, (3, (4, b""))))).as_iter()) + self.assertEqual(val, [1, 2, 3, 4]) + + val = list(Program.to(b"").as_iter()) + self.assertEqual(val, []) + + val = list(Program.to((1, b"")).as_iter()) + self.assertEqual(val, [1]) + + # these fail because the lists are not null-terminated + self.assertEqual(list(Program.to(1).as_iter()), []) + self.assertEqual(list(Program.to((1, (2, (3, (4, 5))))).as_iter()), [1, 2, 3, 4]) + + def test_eq(self): + val = Program.to(1) + + self.assertTrue(val == 1) + self.assertFalse(val == 2) + + # mismatching types + self.assertFalse(val == [1]) + self.assertFalse(val == [1, 2]) + self.assertFalse(val == (1, 2)) + self.assertRaises(ValueError, lambda: val == (dummy_class, dummy_class)) + + def test_eq_tree(self): + val1 = gen_tree(2) + val2 = gen_tree(2) + val3 = gen_tree(3) + + self.assertTrue(val1 == val2) + self.assertTrue(val2 == val1) + self.assertFalse(val1 == val3) + self.assertFalse(val3 == val1) + + def test_str(self): + self.assertEqual(str(Program.to(1)), "01") + self.assertEqual(str(Program.to(1337)), "820539") + self.assertEqual(str(Program.to(-1)), "81ff") + self.assertEqual(str(gen_tree(1)), "ff820539820539") + self.assertEqual(str(gen_tree(2)), "ffff820539820539ff820539820539") + + def test_repr(self): + self.assertEqual(repr(Program.to(1)), "Program(01)") + self.assertEqual(repr(Program.to(1337)), "Program(820539)") + self.assertEqual(repr(Program.to(-1)), "Program(81ff)") + self.assertEqual(repr(gen_tree(1)), "Program(ff820539820539)") + self.assertEqual(repr(gen_tree(2)), "Program(ffff820539820539ff820539820539)") diff --git a/wheel/python/tests/test_program.py b/wheel/python/tests/test_program.py new file mode 100644 index 00000000..b9c38d27 --- /dev/null +++ b/wheel/python/tests/test_program.py @@ -0,0 +1,98 @@ +from unittest import TestCase + +from clvm_rs.program import Program +from clvm_rs.EvalError import EvalError +#from clvm.operators import KEYWORD_TO_ATOM +#from clvm_tools.binutils import assemble, disassemble + + +class TestProgram(TestCase): + def test_at(self): + p = Program.to([10, 20, 30, [15, 17], 40, 50]) + + self.assertEqual(p.first(), p.at("f")) + self.assertEqual(Program.to(10), p.at("f")) + + self.assertEqual(p.rest(), p.at("r")) + self.assertEqual(Program.to([20, 30, [15, 17], 40, 50]), p.at("r")) + + self.assertEqual(p.rest().rest().rest().first().rest().first(), p.at("rrrfrf")) + self.assertEqual(Program.to(17), p.at("rrrfrf")) + + self.assertRaises(ValueError, lambda: p.at("q")) + self.assertEqual(None, p.at("ff")) + + def test_replace(self): + p1 = Program.to([100, 200, 300]) + self.assertEqual(p1.replace(f=105), Program.to([105, 200, 300])) + self.assertEqual(p1.replace(rrf=[301, 302]), Program.to([100, 200, [301, 302]])) + self.assertEqual(p1.replace(f=105, rrf=[301, 302]), Program.to([105, 200, [301, 302]])) + self.assertEqual(p1.replace(f=100, r=200), Program.to((100, 200))) + + def test_replace_conflicts(self): + p1 = Program.to([100, 200, 300]) + self.assertRaises(ValueError, lambda: p1.replace(rr=105, rrf=200)) + + def test_replace_conflicting_paths(self): + p1 = Program.to([100, 200, 300]) + self.assertRaises(ValueError, lambda: p1.replace(ff=105)) + + def test_replace_bad_path(self): + p1 = Program.to([100, 200, 300]) + self.assertRaises(ValueError, lambda: p1.replace(q=105)) + self.assertRaises(ValueError, lambda: p1.replace(rq=105)) + + +def check_idempotency(f, *args): + prg = Program.to(f) + curried = prg.curry(*args) + + r = disassemble(curried) + f_0, args_0 = curried.uncurry() + + assert disassemble(f_0) == disassemble(f) + assert disassemble(args_0) == disassemble(Program.to(list(args))) + return r + + +def ztest_curry_uncurry(): + PLUS = KEYWORD_TO_ATOM["+"][0] + f = assemble("(+ 2 5)") + actual_disassembly = check_idempotency(f, 200, 30) + assert actual_disassembly == f"(a (q {PLUS} 2 5) (c (q . 200) (c (q . 30) 1)))" + + f = assemble("(+ 2 5)") + args = assemble("(+ (q . 50) (q . 60))") + # passing "args" here wraps the arguments in a list + actual_disassembly = check_idempotency(f, args) + assert actual_disassembly == f"(a (q {PLUS} 2 5) (c (q {PLUS} (q . 50) (q . 60)) 1))" + + +def ztest_uncurry_not_curried(): + # this function has not been curried + plus = Program.to(assemble("(+ 2 5)")) + assert plus.uncurry() == (plus, Program.to(0)) + + +def ztest_uncurry(): + # this is a positive test + plus = Program.to(assemble("(2 (q . (+ 2 5)) (c (q . 1) 1))")) + assert plus.uncurry() == (Program.to(assemble("(+ 2 5)")), Program.to([1])) + + +def ztest_uncurry_top_level_garbage(): + # there's garbage at the end of the top-level list + plus = Program.to(assemble("(2 (q . 1) (c (q . 1) (q . 1)) (q . 0x1337))")) + assert plus.uncurry() == (plus, Program.to(0)) + + +def ztest_uncurry_not_pair(): + # the second item in the list is expected to be a pair, with a qoute + plus = Program.to(assemble("(2 1 (c (q . 1) (q . 1)))")) + assert plus.uncurry() == (plus, Program.to(0)) + + +def ztest_uncurry_args_garbage(): + # there's garbage at the end of the args list + plus = Program.to(assemble("(2 (q . 1) (c (q . 1) (q . 1) (q . 0x1337)))")) + assert plus.uncurry() == (plus, Program.to(0)) diff --git a/wheel/python/tests/to_sexp_test.py b/wheel/python/tests/to_sexp_test.py new file mode 100644 index 00000000..5a65b8d5 --- /dev/null +++ b/wheel/python/tests/to_sexp_test.py @@ -0,0 +1,177 @@ +import unittest + +from typing import Optional, Tuple, Any +from clvm_rs.program import Program +# from clvm.CLVMObject import CLVMObject + + +def convert_atom_to_bytes(castable: Any) -> Optional[bytes]: + return Program.to(castable).atom + + +def looks_like_clvm_object(o: Any) -> bool: + d = dir(o) + return "atom" in d and "pair" in d + + +def validate_program(program): + validate_stack = [program] + while validate_stack: + v = validate_stack.pop() + assert isinstance(v, Program) + if v.pair: + assert isinstance(v.pair, tuple) + v1, v2 = v.pair + assert looks_like_clvm_object(v1) + assert looks_like_clvm_object(v2) + s1, s2 = v.as_pair() + validate_stack.append(s1) + validate_stack.append(s2) + else: + assert isinstance(v.atom, bytes) + + +def print_leaves(tree: Program) -> str: + a = tree.as_atom() + if a is not None: + if len(a) == 0: + return "() " + return "%d " % a[0] + + ret = "" + for i in tree.as_pair(): + ret += print_leaves(i) + + return ret + + +def print_tree(tree: Program) -> str: + a = tree.as_atom() + if a is not None: + if len(a) == 0: + return "() " + return "%d " % a[0] + + ret = "(" + for i in tree.as_pair(): + ret += print_tree(i) + ret += ")" + return ret + + +class ToProgramTest(unittest.TestCase): + def test_cast_1(self): + # this was a problem in `clvm_tools` and is included + # to prevent regressions + program = Program.to(b"foo") + t1 = program.to([1, program]) + validate_program(t1) + + def test_wrap_program(self): + # it's a bit of a layer violation that CLVMObject unwraps Program, but we + # rely on that in a fair number of places for now. We should probably + # work towards phasing that out + o = Program.to(Program.to(1)) + assert o.atom == bytes([1]) + + def test_arbitrary_underlying_tree(self): + + # Program provides a view on top of a tree of arbitrary types, as long as + # those types implement the CLVMObject protocol. This is an example of + # a tree that's generated + class GeneratedTree: + + depth: int = 4 + val: int = 0 + + def __init__(self, depth, val): + assert depth >= 0 + self.depth = depth + self.val = val + + @property + def atom(self) -> Optional[bytes]: + if self.depth > 0: + return None + return bytes([self.val]) + + @property + def pair(self) -> Optional[Tuple[Any, Any]]: + if self.depth == 0: + return None + new_depth: int = self.depth - 1 + return (GeneratedTree(new_depth, self.val), GeneratedTree(new_depth, self.val + 2**new_depth)) + + tree = Program.to(GeneratedTree(5, 0)) + assert print_leaves(tree) == "0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 " + \ + "16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 " + + tree = Program.to(GeneratedTree(3, 0)) + assert print_leaves(tree) == "0 1 2 3 4 5 6 7 " + + tree = Program.to(GeneratedTree(3, 10)) + assert print_leaves(tree) == "10 11 12 13 14 15 16 17 " + + def test_looks_like_clvm_object(self): + + # this function can't look at the values, that would cause a cascade of + # eager evaluation/conversion + class dummy: + pass + + obj = dummy() + obj.pair = None + obj.atom = None + print(dir(obj)) + assert looks_like_clvm_object(obj) + + obj = dummy() + obj.pair = None + assert not looks_like_clvm_object(obj) + + obj = dummy() + obj.atom = None + assert not looks_like_clvm_object(obj) + + def test_list_conversions(self): + a = Program.to([1, 2, 3]) + assert print_tree(a) == "(1 (2 (3 () )))" + + def test_string_conversions(self): + a = Program.to("foobar") + assert a.as_atom() == "foobar".encode() + + def test_int_conversions(self): + a = Program.to(1337) + assert a.as_atom() == bytes([0x5, 0x39]) + + def test_none_conversions(self): + a = Program.to(None) + assert a.as_atom() == b"" + + def test_empty_list_conversions(self): + a = Program.to([]) + assert a.as_atom() == b"" + + def test_eager_conversion(self): + with self.assertRaises(ValueError): + Program.to(("foobar", (1, {}))) + + def test_convert_atom(self): + assert convert_atom_to_bytes(0x133742) == bytes([0x13, 0x37, 0x42]) + assert convert_atom_to_bytes(0x833742) == bytes([0x00, 0x83, 0x37, 0x42]) + assert convert_atom_to_bytes(0) == b"" + + assert convert_atom_to_bytes("foobar") == "foobar".encode() + assert convert_atom_to_bytes("") == b"" + + assert convert_atom_to_bytes(b"foobar") == b"foobar" + assert convert_atom_to_bytes(None) == b"" + assert convert_atom_to_bytes([]) == b"" + + assert convert_atom_to_bytes([1, 2, 3]) == None + + assert convert_atom_to_bytes((1, 2)) == None + + with self.assertRaises(ValueError): + assert convert_atom_to_bytes({}) From 01e11d085d1d2df42554636707ca3b7fac8017e5 Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Wed, 21 Sep 2022 14:23:24 -0700 Subject: [PATCH 02/45] Add serialize tests. --- wheel/python/clvm_rs/clvm_tree.py | 2 +- wheel/python/clvm_rs/deser.py | 12 ++- wheel/python/clvm_rs/program.py | 7 +- wheel/python/clvm_rs/serialize.py | 54 +----------- wheel/python/tests/serialize_test.py | 126 +++++++++++++++++++++++++++ 5 files changed, 138 insertions(+), 63 deletions(-) create mode 100644 wheel/python/tests/serialize_test.py diff --git a/wheel/python/clvm_rs/clvm_tree.py b/wheel/python/clvm_rs/clvm_tree.py index dcbc2221..2d25417e 100644 --- a/wheel/python/clvm_rs/clvm_tree.py +++ b/wheel/python/clvm_rs/clvm_tree.py @@ -1,4 +1,4 @@ -from clvm_rs import deserialize_as_triples +from .deser import deserialize_as_triples from typing import List, Optional, Tuple diff --git a/wheel/python/clvm_rs/deser.py b/wheel/python/clvm_rs/deser.py index 220b0b01..6b7cec4b 100644 --- a/wheel/python/clvm_rs/deser.py +++ b/wheel/python/clvm_rs/deser.py @@ -1,6 +1,4 @@ -from typing import Tuple - -from .clvm_tree import CLVMTree +from typing import List, Tuple MAX_SINGLE_BYTE = 0x7F CONS_BOX_MARKER = 0xFF @@ -10,7 +8,9 @@ # PAIR: serialize_offset, serialize_end, right_index -def deserialized_in_place(blob: bytes, cursor: int = 0) -> CLVMTree: +def deserialize_as_triples( + blob: bytes, cursor: int = 0 +) -> List[Tuple[int, int, int]]: def save_cursor(index, blob, cursor, obj_list, op_stack): obj_list[index] = (obj_list[index][0], cursor, obj_list[index][2]) return cursor @@ -40,9 +40,7 @@ def parse_obj(blob, cursor, obj_list, op_stack): while op_stack: f = op_stack.pop() cursor = f(blob, cursor, obj_list, op_stack) - - v = CLVMTree(blob, obj_list, 0) - return v + return obj_list def _atom_size_from_cursor(blob, cursor) -> Tuple[int, int]: diff --git a/wheel/python/clvm_rs/program.py b/wheel/python/clvm_rs/program.py index a61cb304..bc080ca8 100644 --- a/wheel/python/clvm_rs/program.py +++ b/wheel/python/clvm_rs/program.py @@ -50,8 +50,11 @@ def from_bytes_with_cursor( cls, blob: bytes, cursor: int ) -> Tuple[Program, int]: tree = CLVMTree.from_bytes(blob) - new_cursor = tree[-1][1] - obj = cls.wrap(tree) + if tree.atom is not None: + obj = cls.new_atom(tree.atom) + else: + obj = cls.wrap(tree) + new_cursor = len(bytes(tree)) + cursor return obj, new_cursor @classmethod diff --git a/wheel/python/clvm_rs/serialize.py b/wheel/python/clvm_rs/serialize.py index bbffa44e..b6033ef2 100644 --- a/wheel/python/clvm_rs/serialize.py +++ b/wheel/python/clvm_rs/serialize.py @@ -102,7 +102,7 @@ def _op_read_sexp(op_stack, val_stack, f, new_pair_f, new_atom_f): def _op_cons(op_stack, val_stack, f, new_pair_f, new_atom_f): right = val_stack.pop() left = val_stack.pop() - val_stack.append(new_pair_f((left, right))) + val_stack.append(new_pair_f(left, right)) def sexp_from_stream(f, new_pair_f, new_atom_f): @@ -115,58 +115,6 @@ def sexp_from_stream(f, new_pair_f, new_atom_f): return val_stack.pop() -def _op_consume_sexp(f): - blob = f.read(1) - if len(blob) == 0: - raise ValueError("bad encoding") - b = blob[0] - if b == CONS_BOX_MARKER: - return (blob, 2) - return (_consume_atom(f, b), 0) - - -def _consume_atom(f, b): - if b == 0x80: - return bytes([b]) - if b <= MAX_SINGLE_BYTE: - return bytes([b]) - bit_count = 0 - bit_mask = 0x80 - ll = b - while ll & bit_mask: - bit_count += 1 - ll &= 0xFF ^ bit_mask - bit_mask >>= 1 - size_blob = bytes([ll]) - if bit_count > 1: - ll = f.read(bit_count - 1) - if len(ll) != bit_count - 1: - raise ValueError("bad encoding") - size_blob += ll - size = int.from_bytes(size_blob, "big") - if size >= 0x400000000: - raise ValueError("blob too large") - blob = f.read(size) - if len(blob) != size: - raise ValueError("bad encoding") - return bytes([b]) + size_blob[1:] + blob - - -# instead of parsing the input stream, this function pulls out all the bytes -# that represent on S-expression tree, and returns them. This is more efficient -# than parsing and returning a python S-expression tree. -def sexp_buffer_from_stream(f): - ret = b"" - - depth = 1 - while depth > 0: - depth -= 1 - buf, d = _op_consume_sexp(f) - depth += d - ret += buf - return ret - - def _atom_from_stream(f, b, new_atom_f): if b == 0x80: return new_atom_f(b"") diff --git a/wheel/python/tests/serialize_test.py b/wheel/python/tests/serialize_test.py new file mode 100644 index 00000000..0a44f83f --- /dev/null +++ b/wheel/python/tests/serialize_test.py @@ -0,0 +1,126 @@ +import io +import unittest + +from clvm_rs.program import Program +from clvm_rs.serialize import atom_to_byte_iterator + + +TEXT = b"the quick brown fox jumps over the lazy dogs" + + +class InfiniteStream(io.TextIOBase): + def __init__(self, b): + self.buf = b + + def read(self, n): + ret = b"" + while n > 0 and len(self.buf) > 0: + ret += self.buf[0:1] + self.buf = self.buf[1:] + n -= 1 + ret += b" " * n + return ret + + +class LargeAtom: + def __len__(self): + return 0x400000001 + + +class SerializeTest(unittest.TestCase): + def check_serde(self, s): + v = Program.to(s) + b = bytes(v) + v1 = Program.parse(io.BytesIO(b)) + if v != v1: + print("%s: %d %s %s" % (v, len(b), b, v1)) + breakpoint() + b = bytes(v) + v1 = Program.parse(io.BytesIO(b)) + self.assertEqual(v, v1) + + def test_zero(self): + v = Program.to(b"\x00") + self.assertEqual(bytes(v), b"\x00") + + def test_empty(self): + v = Program.to(b"") + self.assertEqual(bytes(v), b"\x80") + + def test_empty_string(self): + self.check_serde(b"") + + def test_single_bytes(self): + for _ in range(256): + self.check_serde(bytes([_])) + + def test_short_lists(self): + self.check_serde([]) + for _ in range(0, 2048, 8): + for size in range(1, 5): + self.check_serde([_] * size) + + def test_cons_box(self): + self.check_serde((None, None)) + self.check_serde((None, [1, 2, 30, 40, 600, (None, 18)])) + self.check_serde((100, (TEXT, (30, (50, (90, (TEXT, TEXT + TEXT))))))) + + def test_long_blobs(self): + text = TEXT * 300 + for _, t in enumerate(text): + t1 = text[:_] + self.check_serde(t1) + + def test_blob_limit(self): + with self.assertRaises(ValueError): + for b in atom_to_byte_iterator(LargeAtom()): + print("%02x" % b) + + def test_very_long_blobs(self): + for size in [0x40, 0x2000, 0x100000, 0x8000000]: + count = size // len(TEXT) + text = TEXT * count + assert len(text) < size + self.check_serde(text) + text = TEXT * (count + 1) + assert len(text) > size + self.check_serde(text) + + def test_very_deep_tree(self): + blob = b"a" + for depth in [10, 100, 1000, 10000, 100000]: + s = Program.to(blob) + for _ in range(depth): + s = Program.to((s, blob)) + self.check_serde(s) + + def test_deserialize_empty(self): + bytes_in = b"" + with self.assertRaises(ValueError): + Program.parse(io.BytesIO(bytes_in)) + + def test_deserialize_truncated_size(self): + # fe means the total number of bytes in the length-prefix is 7 + # one for each bit set. 5 bytes is too few + bytes_in = b"\xfe " + with self.assertRaises(ValueError): + Program.parse(io.BytesIO(bytes_in)) + + def test_deserialize_truncated_blob(self): + # this is a complete length prefix. The blob is supposed to be 63 bytes + # the blob itself is truncated though, it's less than 63 bytes + bytes_in = b"\xbf " + + with self.assertRaises(ValueError): + Program.parse(io.BytesIO(bytes_in)) + + def test_deserialize_large_blob(self): + # this length prefix is 7 bytes long, the last 6 bytes specifies the + # length of the blob, which is 0xffffffffffff, or (2^48 - 1) + # we don't support blobs this large, and we should fail immediately + # when exceeding the max blob size, rather than trying to read this + # many bytes from the stream + bytes_in = b"\xfe" + b"\xff" * 6 + + with self.assertRaises(ValueError): + Program.parse(InfiniteStream(bytes_in)) From f8993f17b33691dae1fc182c3f09bd373213ca73 Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Wed, 21 Sep 2022 22:58:34 -0700 Subject: [PATCH 03/45] Test improvements. --- wheel/python/clvm_rs/curry.py | 85 ----------------- wheel/python/clvm_rs/keywords.py | 6 ++ wheel/python/clvm_rs/program.py | 27 +++--- .../{as_python_test.py => test_as_python.py} | 4 +- wheel/python/tests/test_program.py | 94 ++++++++++++------- .../{serialize_test.py => test_serialize.py} | 0 .../{to_sexp_test.py => test_to_program.py} | 13 ++- 7 files changed, 92 insertions(+), 137 deletions(-) delete mode 100644 wheel/python/clvm_rs/curry.py create mode 100644 wheel/python/clvm_rs/keywords.py rename wheel/python/tests/{as_python_test.py => test_as_python.py} (98%) rename wheel/python/tests/{serialize_test.py => test_serialize.py} (100%) rename wheel/python/tests/{to_sexp_test.py => test_to_program.py} (93%) diff --git a/wheel/python/clvm_rs/curry.py b/wheel/python/clvm_rs/curry.py deleted file mode 100644 index 64b60ea5..00000000 --- a/wheel/python/clvm_rs/curry.py +++ /dev/null @@ -1,85 +0,0 @@ -from typing import Any, Optional, Tuple - -CLVM = Any - - -def at(sexp: CLVM, position: str) -> Optional[CLVM]: - """ - Take a string of only `f` and `r` characters and follow the corresponding path. - - Example: - - `assert Program.to(17) == Program.to([10, 20, 30, [15, 17], 40, 50]).at("rrrfrf")` - - """ - v = sexp - for c in position.lower(): - p = v.pair - if p is None: - return p - if c not in "rf": - raise ValueError( - f"`at` got illegal character `{c}`. Only `f` & `r` allowed" - ) - v = p[0 if c == "f" else 1] - return v - - -# Replicates the curry function from clvm_tools, taking advantage of *args -# being a list. We iterate through args in reverse building the code to -# create a clvm list. -# -# Given arguments to a function addressable by the '1' reference in clvm -# -# fixed_args = 1 -# -# Each arg is prepended as fixed_args = (c (q . arg) fixed_args) -# -# The resulting argument list is interpreted with apply (2) -# -# (2 (1 . self) rest) -# -# Resulting in a function which places its own arguments after those -# curried in in the form of a proper list. - - -def curry(sexp: CLVM, *args) -> CLVM: - fixed_args: Any = 1 - while args: - arg = args.pop() - fixed_args = [4, (1, arg), fixed_args] - return sexp.to([2, (1, sexp), fixed_args]) - - -# UNCURRY_PATTERN_FUNCTION = assemble("(a (q . (: . function)) (: . core))") -# UNCURRY_PATTERN_CORE = assemble("(c (q . (: . parm)) (: . core))") - - -ONE_PATH = Q_KW = bytes([1]) -C_KW = bytes([2]) -A_KW = bytes([4]) -NULL = bytes([]) - - -def uncurry(sexp: CLVM) -> Optional[Tuple[CLVM, CLVM]]: - if ( - at(sexp, "f").atom != A_KW - or at(sexp, "rf").atom != Q_KW - or at(sexp, "rrr").atom != NULL - ): - return None - uncurried_function = at(sexp, "rr") - core_items = [] - core = at(sexp, "rrf") - while core.atom != ONE_PATH: - if ( - at(core, "f").atom != C_KW - or at(core, "rf").atom != Q_KW - or at(sexp, "rrr").atom != NULL - ): - return None - new_item = at(core, "rr") - core_items.append(new_item) - core = at(core, "rrf") - core_items.reverse() - return uncurried_function, core_items diff --git a/wheel/python/clvm_rs/keywords.py b/wheel/python/clvm_rs/keywords.py new file mode 100644 index 00000000..f246b1e3 --- /dev/null +++ b/wheel/python/clvm_rs/keywords.py @@ -0,0 +1,6 @@ +NULL = bytes.fromhex("") +ONE = bytes.fromhex("01") +TWO = bytes.fromhex("02") +Q_KW = bytes.fromhex("01") +A_KW = bytes.fromhex("02") +C_KW = bytes.fromhex("04") diff --git a/wheel/python/clvm_rs/program.py b/wheel/python/clvm_rs/program.py index bc080ca8..9de023f5 100644 --- a/wheel/python/clvm_rs/program.py +++ b/wheel/python/clvm_rs/program.py @@ -249,27 +249,26 @@ def curry(self, *args) -> "Program": fixed_args = [4, (1, arg), fixed_args] return Program.to([2, (1, self), fixed_args]) - def uncurry(self) -> Optional[Tuple[Program, Program]]: + def uncurry(self) -> Tuple[Program, Optional[Program]]: if ( - self.at("f").atom != A_KW - or self.at("rf").atom != Q_KW - or self.at("rrr").atom != NULL + self.at("f") != A_KW + or self.at("rff") != Q_KW + or self.at("rrr") != NULL ): - return None - uncurried_function = self.at("rr") + return self, None + uncurried_function = self.at("rfr") core_items = [] core = self.at("rrf") - while core.atom != ONE: + while core != ONE: if ( - core.at("f").atom != C_KW - or core.at("rf").atom != Q_KW - or core.at("rrr").atom != NULL + core.at("f") != C_KW + or core.at("rff") != Q_KW + or core.at("rrr") != NULL ): - return None - new_item = core.at("rr") + return self, None + new_item = core.at("rfr") core_items.append(new_item) core = core.at("rrf") - core_items.reverse() return uncurried_function, core_items def as_int(self) -> int: @@ -346,7 +345,7 @@ def _replace(program: Program, **kwargs) -> Program: new_f = _replace(pair[0], **args_by_prefix.get("f", {})) new_r = _replace(pair[1], **args_by_prefix.get("r", {})) - return program.new_pair((new_f, new_r)) + return program.new_pair(Program.to(new_f), Program.to(new_r)) def int_from_bytes(blob): diff --git a/wheel/python/tests/as_python_test.py b/wheel/python/tests/test_as_python.py similarity index 98% rename from wheel/python/tests/as_python_test.py rename to wheel/python/tests/test_as_python.py index 01fb5dfa..b6760410 100644 --- a/wheel/python/tests/as_python_test.py +++ b/wheel/python/tests/test_as_python.py @@ -183,7 +183,9 @@ def test_as_iter(self): # these fail because the lists are not null-terminated self.assertEqual(list(Program.to(1).as_iter()), []) - self.assertEqual(list(Program.to((1, (2, (3, (4, 5))))).as_iter()), [1, 2, 3, 4]) + self.assertEqual( + list(Program.to((1, (2, (3, (4, 5))))).as_iter()), [1, 2, 3, 4] + ) def test_eq(self): val = Program.to(1) diff --git a/wheel/python/tests/test_program.py b/wheel/python/tests/test_program.py index b9c38d27..aa7c37c1 100644 --- a/wheel/python/tests/test_program.py +++ b/wheel/python/tests/test_program.py @@ -2,8 +2,8 @@ from clvm_rs.program import Program from clvm_rs.EvalError import EvalError -#from clvm.operators import KEYWORD_TO_ATOM -#from clvm_tools.binutils import assemble, disassemble + +from clvm_rs.keywords import A_KW, C_KW, Q_KW class TestProgram(TestCase): @@ -26,7 +26,9 @@ def test_replace(self): p1 = Program.to([100, 200, 300]) self.assertEqual(p1.replace(f=105), Program.to([105, 200, 300])) self.assertEqual(p1.replace(rrf=[301, 302]), Program.to([100, 200, [301, 302]])) - self.assertEqual(p1.replace(f=105, rrf=[301, 302]), Program.to([105, 200, [301, 302]])) + self.assertEqual( + p1.replace(f=105, rrf=[301, 302]), Program.to([105, 200, [301, 302]]) + ) self.assertEqual(p1.replace(f=100, r=200), Program.to((100, 200))) def test_replace_conflicts(self): @@ -43,56 +45,82 @@ def test_replace_bad_path(self): self.assertRaises(ValueError, lambda: p1.replace(rq=105)) -def check_idempotency(f, *args): - prg = Program.to(f) - curried = prg.curry(*args) +def check_idempotency(p, *args): + curried = p.curry(*args) - r = disassemble(curried) f_0, args_0 = curried.uncurry() - assert disassemble(f_0) == disassemble(f) - assert disassemble(args_0) == disassemble(Program.to(list(args))) - return r + assert f_0 == p + assert len(args_0) == len(args) + for a, a0 in zip(args, args_0): + assert a == a0 + return curried -def ztest_curry_uncurry(): - PLUS = KEYWORD_TO_ATOM["+"][0] - f = assemble("(+ 2 5)") - actual_disassembly = check_idempotency(f, 200, 30) - assert actual_disassembly == f"(a (q {PLUS} 2 5) (c (q . 200) (c (q . 30) 1)))" - f = assemble("(+ 2 5)") - args = assemble("(+ (q . 50) (q . 60))") - # passing "args" here wraps the arguments in a list - actual_disassembly = check_idempotency(f, args) - assert actual_disassembly == f"(a (q {PLUS} 2 5) (c (q {PLUS} (q . 50) (q . 60)) 1))" +def test_curry_uncurry(): + PLUS = Program.fromhex("10") # `+` + + p = Program.fromhex("ff10ff02ff0580") # `(+ 2 5)` + + curried_p = check_idempotency(p) + assert curried_p == [A_KW, [Q_KW, PLUS, 2, 5], 1] + + curried_p = check_idempotency(p, b"dogs") + assert curried_p == [A_KW, [Q_KW, PLUS, 2, 5], [C_KW, (Q_KW, "dogs"), 1]] + curried_p = check_idempotency(p, 200, 30) + assert curried_p == [ + A_KW, + [Q_KW, PLUS, 2, 5], + [C_KW, (Q_KW, 200), [C_KW, (Q_KW, 30), 1]], + ] -def ztest_uncurry_not_curried(): + # passing "args" here wraps the arguments in a list + curried_p = check_idempotency(p, 50, 60, 70, 80) + assert curried_p == [ + A_KW, + [Q_KW, PLUS, 2, 5], + [ + C_KW, + (Q_KW, 50), + [C_KW, (Q_KW, 60), [C_KW, (Q_KW, 70), [C_KW, (Q_KW, 80), 1]]], + ], + ] + + +def test_uncurry_not_curried(): # this function has not been curried - plus = Program.to(assemble("(+ 2 5)")) - assert plus.uncurry() == (plus, Program.to(0)) + plus = Program.fromhex("ff10ff02ff0580") # `(+ 2 5)` + assert plus.uncurry() == (plus, None) -def ztest_uncurry(): +def test_uncurry(): # this is a positive test - plus = Program.to(assemble("(2 (q . (+ 2 5)) (c (q . 1) 1))")) - assert plus.uncurry() == (Program.to(assemble("(+ 2 5)")), Program.to([1])) + # `(a (q . (+ 2 5)) (c (q . 1) 1))` + plus = Program.fromhex("ff02ffff01ff10ff02ff0580ffff04ffff0101ff018080") + prog = Program.fromhex("ff10ff02ff0580") # `(+ 2 5)` + args = Program.fromhex("01") # `1` + assert plus.uncurry() == (prog, [args]) -def ztest_uncurry_top_level_garbage(): + +def test_uncurry_top_level_garbage(): # there's garbage at the end of the top-level list - plus = Program.to(assemble("(2 (q . 1) (c (q . 1) (q . 1)) (q . 0x1337))")) - assert plus.uncurry() == (plus, Program.to(0)) + # `(a (q . 1) (c (q . 1) (q . 1)) (q . 0x1337))` + plus = Program.fromhex("ff02ffff0101ffff04ffff0101ffff010180ffff0182133780") + assert plus.uncurry() == (plus, None) -def ztest_uncurry_not_pair(): +def test_uncurry_not_pair(): # the second item in the list is expected to be a pair, with a qoute - plus = Program.to(assemble("(2 1 (c (q . 1) (q . 1)))")) + # `(a 1 (c (q . 1) (q . 1)))` + plus = Program.fromhex("ff02ff01ffff04ffff0101ffff01018080") assert plus.uncurry() == (plus, Program.to(0)) -def ztest_uncurry_args_garbage(): +def test_uncurry_args_garbage(): # there's garbage at the end of the args list - plus = Program.to(assemble("(2 (q . 1) (c (q . 1) (q . 1) (q . 0x1337)))")) + # `(a (q . 1) (c (q . 1) (q . 1) (q . 4919)))` + plus = Program.fromhex("ff02ffff0101ffff04ffff0101ffff0101ffff018213378080") assert plus.uncurry() == (plus, Program.to(0)) diff --git a/wheel/python/tests/serialize_test.py b/wheel/python/tests/test_serialize.py similarity index 100% rename from wheel/python/tests/serialize_test.py rename to wheel/python/tests/test_serialize.py diff --git a/wheel/python/tests/to_sexp_test.py b/wheel/python/tests/test_to_program.py similarity index 93% rename from wheel/python/tests/to_sexp_test.py rename to wheel/python/tests/test_to_program.py index 5a65b8d5..cb092614 100644 --- a/wheel/python/tests/to_sexp_test.py +++ b/wheel/python/tests/test_to_program.py @@ -2,7 +2,6 @@ from typing import Optional, Tuple, Any from clvm_rs.program import Program -# from clvm.CLVMObject import CLVMObject def convert_atom_to_bytes(castable: Any) -> Optional[bytes]: @@ -100,11 +99,17 @@ def pair(self) -> Optional[Tuple[Any, Any]]: if self.depth == 0: return None new_depth: int = self.depth - 1 - return (GeneratedTree(new_depth, self.val), GeneratedTree(new_depth, self.val + 2**new_depth)) + return ( + GeneratedTree(new_depth, self.val), + GeneratedTree(new_depth, self.val + 2**new_depth), + ) tree = Program.to(GeneratedTree(5, 0)) - assert print_leaves(tree) == "0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 " + \ - "16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 " + assert ( + print_leaves(tree) + == "0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 " + + "16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 " + ) tree = Program.to(GeneratedTree(3, 0)) assert print_leaves(tree) == "0 1 2 3 4 5 6 7 " From 91c24921db41977084b54f9c67869ec6c2776943 Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Wed, 28 Sep 2022 13:52:09 -0700 Subject: [PATCH 04/45] Some speed improvements. --- wheel/python/clvm_rs/casts.py | 128 ++++++++++++++++++ wheel/python/clvm_rs/curry_and_treehash.py | 9 +- wheel/python/clvm_rs/program.py | 34 +++-- wheel/python/tests/test_curry_and_treehash.py | 21 +++ wheel/src/adapt_response.rs | 4 +- wheel/src/lazy_node.rs | 9 +- 6 files changed, 178 insertions(+), 27 deletions(-) create mode 100644 wheel/python/clvm_rs/casts.py create mode 100644 wheel/python/tests/test_curry_and_treehash.py diff --git a/wheel/python/clvm_rs/casts.py b/wheel/python/clvm_rs/casts.py new file mode 100644 index 00000000..84d86828 --- /dev/null +++ b/wheel/python/clvm_rs/casts.py @@ -0,0 +1,128 @@ +from typing import Any, Tuple, Union + +AtomCastableType = Union[ + bytes, + str, + int, + None, +] + + +CastableType = Union[ + AtomCastableType, + list, + Tuple["CastableType", "CastableType"], +] + + +def looks_like_clvm_object(o: Any) -> bool: + d = dir(o) + return "atom" in d and "pair" in d + + +def int_to_bytes(v): + byte_count = (v.bit_length() + 8) >> 3 + if v == 0: + return b"" + r = v.to_bytes(byte_count, "big", signed=True) + # make sure the string returned is minimal + # ie. no leading 00 or ff bytes that are unnecessary + while len(r) > 1 and r[0] == (0xFF if r[1] & 0x80 else 0): + r = r[1:] + return r + + +NULL = b"" + + +def to_atom_type(v: AtomCastableType) -> bytes: + + if isinstance(v, bytes): + return v + if isinstance(v, str): + return v.encode() + if isinstance(v, int): + return int_to_bytes(v) + if hasattr(v, "__bytes__"): + return bytes(v) + if v is None: + return NULL + if v == []: + return NULL + + raise ValueError("can't cast %s (%s) to bytes" % (type(v), v)) + + +def to_clvm_object( + v: CastableType, + to_atom_f, + to_pair_f, +): + stack = [v] + ops = [(0, None)] # convert + + while len(ops) > 0: + op, target = ops.pop() + # convert value + if op == 0: + if looks_like_clvm_object(stack[-1]): + obj = stack.pop() + if obj.pair is None: + new_obj = to_atom_f(obj.atom) + else: + new_obj = to_pair_f(obj.pair[0], obj.pair[1]) + stack.append(new_obj) + continue + v = stack.pop() + if isinstance(v, tuple): + if len(v) != 2: + raise ValueError("can't cast tuple of size %d" % len(v)) + left, right = v + target = len(stack) + ll_right = looks_like_clvm_object(right) + ll_left = looks_like_clvm_object(left) + if ll_right and ll_left: + stack.append(to_pair_f(left, right)) + else: + ops.append((3, None)) # cons + stack.append(right) + ops.append((0, None)) # convert + ops.append((2, None)) # roll + stack.append(left) + ops.append((0, None)) # convert + continue + if isinstance(v, list): + target = len(stack) + stack.append(to_atom_f(NULL)) + for _ in v: + stack.append(_) + ops.append((1, target)) # prepend list + # we only need to convert if it's not already the right + # type + if not looks_like_clvm_object(_): + ops.append((0, None)) # convert + continue + stack.append(to_atom_f(to_atom_type(v))) + continue + + if op == 1: # prepend list + stack[target] = to_pair_f(stack.pop(), stack[target]) + continue + if op == 2: # roll + p1 = stack.pop() + p2 = stack.pop() + stack.append(p1) + stack.append(p2) + continue + if op == 3: # cons + right = stack.pop() + left = stack.pop() + obj = to_pair_f(left, right) + stack.append(obj) + continue + # there's exactly one item left at this point + if len(stack) != 1: + raise ValueError("internal error") + + # stack[0] implements the clvm object protocol + return stack[0] diff --git a/wheel/python/clvm_rs/curry_and_treehash.py b/wheel/python/clvm_rs/curry_and_treehash.py index 2ae64137..45d50b29 100644 --- a/wheel/python/clvm_rs/curry_and_treehash.py +++ b/wheel/python/clvm_rs/curry_and_treehash.py @@ -1,17 +1,10 @@ from typing import List from .bytes32 import bytes32 +from .keywords import NULL, ONE, TWO, Q_KW, A_KW, C_KW from .tree_hash import shatree_atom, shatree_pair -NULL = bytes.fromhex("") -ONE = bytes.fromhex("01") -TWO = bytes.fromhex("02") -Q_KW = bytes.fromhex("01") -A_KW = bytes.fromhex("02") -C_KW = bytes.fromhex("04") - - Q_KW_TREEHASH = shatree_atom(Q_KW) A_KW_TREEHASH = shatree_atom(A_KW) C_KW_TREEHASH = shatree_atom(C_KW) diff --git a/wheel/python/clvm_rs/program.py b/wheel/python/clvm_rs/program.py index 9de023f5..c680ba6b 100644 --- a/wheel/python/clvm_rs/program.py +++ b/wheel/python/clvm_rs/program.py @@ -4,26 +4,20 @@ # from clvm import Program from .base import CLVMObject -from .casts import to_sexp_type +from .casts import to_clvm_object from clvm_rs.clvm_rs import run_serialized_program from clvm_rs.serialize import sexp_from_stream, sexp_to_stream from clvm_rs.tree_hash import sha256_treehash from .clvm_tree import CLVMTree from .bytes32 import bytes32 +from .keywords import NULL, ONE, TWO, Q_KW, A_KW, C_KW # from chia.util.hash import std_hash # from chia.util.byte_types import hexstr_to_bytes # from chia.types.spend_bundle_conditions import SpendBundleConditions -INFINITE_COST = 0x7FFFFFFFFFFFFFFF - -NULL = bytes.fromhex("") -ONE = bytes.fromhex("01") -TWO = bytes.fromhex("02") -Q_KW = bytes.fromhex("01") -A_KW = bytes.fromhex("02") -C_KW = bytes.fromhex("04") +MAX_COST = 0x7FFFFFFFFFFFFFFF class Program(CLVMObject): @@ -68,9 +62,15 @@ def __bytes__(self) -> bytes: # high level casting with `.to` + def __init__(self) -> Program: + self.atom = b"" + self.pair = None + self.wrapped = self + self._cached_sha256_treehash = None + @classmethod def to(cls, v: Any) -> Program: - return to_sexp_type(v, cls.new_atom, cls.new_pair) + return cls.wrap(to_clvm_object(v, cls.new_atom, cls.new_pair)) @classmethod def wrap(cls, v: CLVMObject) -> Program: @@ -79,8 +79,12 @@ def wrap(cls, v: CLVMObject) -> Program: o = cls() o.atom = v.atom o.pair = v.pair + o.wrapped = v return o + def unwrap(self) -> CLVMObject: + return self.wrapped + # new object creation on the python heap @classmethod @@ -215,16 +219,18 @@ def replace(self, **kwargs) -> "Program": return _replace(self, **kwargs) def tree_hash(self) -> bytes32: - return sha256_treehash(bytes(self)) + return sha256_treehash(self.unwrap()) - def run_with_cost(self, max_cost: int, args) -> Tuple[int, "Program"]: + def run_with_cost( + self, args, max_cost: int = MAX_COST + ) -> Tuple[int, "Program"]: prog_bytes = bytes(self) args_bytes = bytes(self.to(args)) cost, r = run_serialized_program(prog_bytes, args_bytes, max_cost, 0) return cost, Program.to(r) def run(self, args) -> "Program": - cost, r = self.run_with_cost(INFINITE_COST, args) + cost, r = self.run_with_cost(args, MAX_COST) return r # Replicates the curry function from clvm_tools, taking advantage of *args @@ -311,7 +317,7 @@ def __deepcopy__(self, memo): return type(self).from_bytes(bytes(self)) -NULL_PROGRAM = Program.from_bytes(b"\x80") +NULL_PROGRAM = Program.fromhex("80") def _replace(program: Program, **kwargs) -> Program: diff --git a/wheel/python/tests/test_curry_and_treehash.py b/wheel/python/tests/test_curry_and_treehash.py new file mode 100644 index 00000000..3bf8b7b5 --- /dev/null +++ b/wheel/python/tests/test_curry_and_treehash.py @@ -0,0 +1,21 @@ +from clvm_rs.program import Program +from clvm_rs.curry_and_treehash import calculate_hash_of_quoted_mod_hash, curry_and_treehash + + +def test_curry_and_treehash() -> None: + + arbitrary_mod = Program.fromhex("ff10ff02ff0580") # `(+ 2 5)` + arbitrary_mod_hash = arbitrary_mod.tree_hash() + + # we don't really care what `arbitrary_mod` is. We just need some code + + quoted_mod_hash = calculate_hash_of_quoted_mod_hash(arbitrary_mod_hash) + + for v in range(500): + args = [v, v * v, v * v * v] + # we don't really care about the arguments either + puzzle = arbitrary_mod.curry(*args) + puzzle_hash_via_curry = puzzle.tree_hash() + hashed_args = [Program.to(_).tree_hash() for _ in args] + puzzle_hash_via_f = curry_and_treehash(quoted_mod_hash, *hashed_args) + assert puzzle_hash_via_curry == puzzle_hash_via_f diff --git a/wheel/src/adapt_response.rs b/wheel/src/adapt_response.rs index 6b30886a..18d55209 100644 --- a/wheel/src/adapt_response.rs +++ b/wheel/src/adapt_response.rs @@ -15,11 +15,11 @@ pub fn adapt_response( ) -> PyResult<(u64, LazyNode)> { match response { Ok(reduction) => { - let val = LazyNode::new(Rc::new(allocator), reduction.1); + let val = LazyNode::new(py, Rc::new(allocator), reduction.1); Ok((reduction.0, val)) } Err(eval_err) => { - let sexp = LazyNode::new(Rc::new(allocator), eval_err.0).to_object(py); + let sexp = LazyNode::new(py, Rc::new(allocator), eval_err.0).to_object(py); let msg = eval_err.1.to_object(py); let tuple = PyTuple::new(py, [msg, sexp]); let value_error: PyErr = PyValueError::new_err(tuple.to_object(py)); diff --git a/wheel/src/lazy_node.rs b/wheel/src/lazy_node.rs index d3495eae..425664dc 100644 --- a/wheel/src/lazy_node.rs +++ b/wheel/src/lazy_node.rs @@ -9,6 +9,8 @@ use pyo3::types::{PyBytes, PyTuple}; pub struct LazyNode { allocator: Rc, node: NodePtr, + #[pyo3(get, set)] + _cached_sha256_treehash: PyObject, } impl ToPyObject for LazyNode { @@ -25,8 +27,8 @@ impl LazyNode { pub fn pair(&self, py: Python) -> PyResult> { match &self.allocator.sexp(self.node) { SExp::Pair(p1, p2) => { - let r1 = Self::new(self.allocator.clone(), *p1); - let r2 = Self::new(self.allocator.clone(), *p2); + let r1 = Self::new(py, self.allocator.clone(), *p1); + let r2 = Self::new(py, self.allocator.clone(), *p2); let v: &PyTuple = PyTuple::new(py, &[r1, r2]); Ok(Some(v.into())) } @@ -44,10 +46,11 @@ impl LazyNode { } impl LazyNode { - pub const fn new(a: Rc, n: NodePtr) -> Self { + pub fn new(py: Python, a: Rc, n: NodePtr) -> Self { Self { allocator: a, node: n, + _cached_sha256_treehash: py.None(), } } } From 62142b962e208c88793f99e0cf734ebf21d920cb Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Wed, 28 Sep 2022 17:29:54 -0700 Subject: [PATCH 05/45] More improvements. --- wheel/python/benchmarks/deserialization.py | 77 ++++++++++++++++++++++ wheel/python/clvm_rs/__init__.py | 1 + wheel/python/clvm_rs/base.py | 21 ++++++ wheel/python/clvm_rs/bytes32.py | 1 + wheel/python/clvm_rs/clvm_tree.py | 30 +++++---- wheel/python/clvm_rs/curry_and_treehash.py | 2 +- wheel/python/clvm_rs/deser.py | 22 +++++-- wheel/python/clvm_rs/program.py | 2 +- wheel/python/clvm_rs/serialize.py | 4 +- wheel/python/clvm_rs/tree_hash.py | 74 +++++++++++++++++++++ 10 files changed, 213 insertions(+), 21 deletions(-) create mode 100644 wheel/python/benchmarks/deserialization.py create mode 100644 wheel/python/clvm_rs/base.py create mode 100644 wheel/python/clvm_rs/bytes32.py create mode 100644 wheel/python/clvm_rs/tree_hash.py diff --git a/wheel/python/benchmarks/deserialization.py b/wheel/python/benchmarks/deserialization.py new file mode 100644 index 00000000..f05a1871 --- /dev/null +++ b/wheel/python/benchmarks/deserialization.py @@ -0,0 +1,77 @@ +import io +import time + +from clvm_rs.program import Program + + + +def bench(f, name: str): + start = time.time() + r = f() + end = time.time() + d = end - start + print(f"{name}: {end-start:1.4f} s") + print() + return r + + + +sha_prog = Program.fromhex("ff0bff0180") + +print(sha_prog.run("food")) +#breakpoint() + + +obj = bench(lambda: Program.parse(open("block-2500014.compressed.bin", "rb")), "parse") +bench(lambda: bytes(obj), "to_bytes") + +obj1 = bench(lambda: Program.from_bytes(open("block-2500014.compressed.bin", "rb").read()), "from_bytes") +bench(lambda: bytes(obj1), "to_bytes") + +cost, output = bench(lambda: obj.run_with_cost(0), "run") + +print(f"cost = {cost}") +blob = bench(lambda: print(f"output = {len(bytes(output))}"), "serialize LazyNode") +blob = bench(lambda: bytes(output), "serialize LazyNode again") + +bench(lambda: print(output.tree_hash().hex()), "print run tree hash LazyNode") +bench(lambda: print(output.tree_hash().hex()), "print run tree hash again LazyNode") + +des_output = bench(lambda: Program.from_bytes(blob), "from_bytes output") +bench(lambda: print(des_output.tree_hash().hex()), "print from_bytes tree hash") +bench(lambda: print(des_output.tree_hash().hex()), "print from_bytes tree hash again") + +reparsed_output = bench(lambda: Program.parse(io.BytesIO(blob)), "parse output") +bench(lambda: print(reparsed_output.tree_hash().hex()), "print parsed tree hash") +bench(lambda: print(reparsed_output.tree_hash().hex()), "print parsed tree hash again") + + +foo = Program.to("foo") +o0 = Program.to((foo, obj)) +o1 = Program.to((foo, obj1)) + + +def compare(): + assert o0 == o1 + + +bench(compare, "compare") + +bench(lambda: bytes(o0), "to_bytes o0") +bench(lambda: bytes(o1), "to_bytes o1") + +bench(lambda: print(o0.tree_hash().hex()), "o0 tree_hash") +bench(lambda: print(o0.tree_hash().hex()), "o0 tree_hash (again)") + +bench(lambda: print(o1.tree_hash().hex()), "o1 tree_hash") +bench(lambda: print(o1.tree_hash().hex()), "o1 tree_hash (again)") + +o2 = Program.to((foo, output)) + +bench(lambda: print(o2.tree_hash().hex()), "o2 tree_hash") +bench(lambda: print(o2.tree_hash().hex()), "o2 tree_hash (again)") + +# start = time.time() +# obj1 = sexp_from_stream(io.BytesIO(out), SExp.to, allow_backrefs=True) +# end = time.time() +# print(end-start) diff --git a/wheel/python/clvm_rs/__init__.py b/wheel/python/clvm_rs/__init__.py index 03ab351e..c418e1ea 100644 --- a/wheel/python/clvm_rs/__init__.py +++ b/wheel/python/clvm_rs/__init__.py @@ -1,5 +1,6 @@ from .clvm_rs import * +from .base import CLVMObject __doc__ = clvm_rs.__doc__ if hasattr(clvm_rs, "__all__"): diff --git a/wheel/python/clvm_rs/base.py b/wheel/python/clvm_rs/base.py new file mode 100644 index 00000000..90982859 --- /dev/null +++ b/wheel/python/clvm_rs/base.py @@ -0,0 +1,21 @@ +from typing import Optional, Protocol, Tuple + + +class CLVMObjectStore(Protocol): + atom: Optional[bytes] + pair: Optional[Tuple["CLVMObjectStore", "CLVMObjectStore"]] + + @classmethod + def new_atom(cls, v: bytes) -> "CLVMObjectStore": + raise NotImplementedError() + + @classmethod + def new_pair(cls, p1, p2) -> "CLVMObjectStore": + raise NotImplementedError() + + +CLVMObject = CLVMObjectStore + + +class PythonHeapCLVMObject(CLVMObjectStore): + pass diff --git a/wheel/python/clvm_rs/bytes32.py b/wheel/python/clvm_rs/bytes32.py new file mode 100644 index 00000000..f5803599 --- /dev/null +++ b/wheel/python/clvm_rs/bytes32.py @@ -0,0 +1 @@ +bytes32 = bytes diff --git a/wheel/python/clvm_rs/clvm_tree.py b/wheel/python/clvm_rs/clvm_tree.py index 2d25417e..42eaad30 100644 --- a/wheel/python/clvm_rs/clvm_tree.py +++ b/wheel/python/clvm_rs/clvm_tree.py @@ -1,4 +1,4 @@ -from .deser import deserialize_as_triples +from .deser import deserialize_as_tuples from typing import List, Optional, Tuple @@ -8,7 +8,7 @@ class CLVMTree: """ This object conforms with the `CLVMObject` protocol. It's optimized for deserialization, and keeps a reference to the serialized blob and to a - list of triples of integers, each of which corresponds to a subtree. + list of tuples of integers, each of which corresponds to a subtree. It turns out every atom serialized to a blob contains a substring that exactly matches that atom, so it ends up being not very wasteful to @@ -32,42 +32,46 @@ class CLVMTree: triple[2]:triple[1]]`) Since each `CLVMTree` subtree keeps a reference to the original - serialized data and the list of triples, no memory is released until all + serialized data and the list of tuples, no memory is released until all objects in the tree are garbage-collected. This happens pretty naturally in well-behaved python code. """ @classmethod def from_bytes(cls, blob: bytes) -> "CLVMTree": - return cls(memoryview(blob), deserialize_as_triples(blob), 0) + return cls(memoryview(blob), deserialize_as_tuples(blob), 0) def __init__( - self, blob: bytes, int_triples: List[Tuple[int, int, int]], index: int + self, + blob: bytes, + int_tuples: List[Tuple[int, int, int, bytes]], + index: int, ): self.blob = blob - self.int_triples = int_triples + self.int_tuples = int_tuples self.index = index + self._cached_sha256_treehash = int_tuples[index][3] @property def atom(self) -> Optional[bytes]: if not hasattr(self, "_atom"): - start, end, atom_offset = self.int_triples[self.index] + start, end, atom_offset, hash = self.int_tuples[self.index] # if `self.blob[start]` is 0xff, it's a pair if self.blob[start] == 0xFF: self._atom = None else: - self._atom = bytes(self.blob[start + atom_offset:end]) + self._atom = self.blob[start + atom_offset : end] return self._atom @property def pair(self) -> Optional[Tuple["CLVMTree", "CLVMTree"]]: if not hasattr(self, "_pair"): - triples = self.int_triples - start, end, right_index = triples[self.index] + tuples = self.int_tuples + start, end, right_index, hash = tuples[self.index] # if `self.blob[start]` is 0xff, it's a pair if self.blob[start] == 0xFF: - left = self.__class__(self.blob, triples, self.index + 1) - right = self.__class__(self.blob, triples, right_index) + left = self.__class__(self.blob, tuples, self.index + 1) + right = self.__class__(self.blob, tuples, right_index) self._pair = (left, right) else: self._pair = None @@ -75,7 +79,7 @@ def pair(self) -> Optional[Tuple["CLVMTree", "CLVMTree"]]: @property def _cached_serialization(self) -> bytes: - start, end, _ = self.int_triples[self.index] + start, end, _, _ = self.int_tuples[self.index] return self.blob[start:end] def __bytes__(self) -> bytes: diff --git a/wheel/python/clvm_rs/curry_and_treehash.py b/wheel/python/clvm_rs/curry_and_treehash.py index 45d50b29..d53f9c2f 100644 --- a/wheel/python/clvm_rs/curry_and_treehash.py +++ b/wheel/python/clvm_rs/curry_and_treehash.py @@ -1,7 +1,7 @@ from typing import List from .bytes32 import bytes32 -from .keywords import NULL, ONE, TWO, Q_KW, A_KW, C_KW +from .keywords import NULL, ONE, Q_KW, A_KW, C_KW from .tree_hash import shatree_atom, shatree_pair diff --git a/wheel/python/clvm_rs/deser.py b/wheel/python/clvm_rs/deser.py index 6b7cec4b..bf9be040 100644 --- a/wheel/python/clvm_rs/deser.py +++ b/wheel/python/clvm_rs/deser.py @@ -1,5 +1,8 @@ from typing import List, Tuple +from .tree_hash import shatree_atom, shatree_pair + + MAX_SINGLE_BYTE = 0x7F CONS_BOX_MARKER = 0xFF @@ -8,11 +11,20 @@ # PAIR: serialize_offset, serialize_end, right_index -def deserialize_as_triples( +def deserialize_as_tuples( blob: bytes, cursor: int = 0 -) -> List[Tuple[int, int, int]]: +) -> List[Tuple[int, int, int, bytes]]: def save_cursor(index, blob, cursor, obj_list, op_stack): - obj_list[index] = (obj_list[index][0], cursor, obj_list[index][2]) + assert blob[obj_list[index][0]] == 0xFF + left_hash = obj_list[index + 1][3] + right_hash = obj_list[obj_list[index][2]][3] + my_hash = shatree_pair(left_hash, right_hash) + obj_list[index] = ( + obj_list[index][0], + cursor, + obj_list[index][2], + my_hash, + ) return cursor def save_index(index, blob, cursor, obj_list, op_stack): @@ -32,7 +44,8 @@ def parse_obj(blob, cursor, obj_list, op_stack): op_stack.append(parse_obj) return cursor + 1 atom_offset, new_cursor = _atom_size_from_cursor(blob, cursor) - obj_list.append((cursor, new_cursor, atom_offset)) + my_hash = shatree_atom(blob[cursor + atom_offset : new_cursor]) + obj_list.append((cursor, new_cursor, atom_offset, my_hash)) return new_cursor obj_list = [] @@ -58,7 +71,6 @@ def _atom_size_from_cursor(blob, cursor) -> Tuple[int, int]: bit_mask >>= 1 size_blob = bytes([b]) if bit_count > 1: - breakpoint() size_blob += blob[cursor + 1 : cursor + bit_count] size = int.from_bytes(size_blob, "big") if size >= 0x400000000: diff --git a/wheel/python/clvm_rs/program.py b/wheel/python/clvm_rs/program.py index c680ba6b..01451a39 100644 --- a/wheel/python/clvm_rs/program.py +++ b/wheel/python/clvm_rs/program.py @@ -10,7 +10,7 @@ from clvm_rs.tree_hash import sha256_treehash from .clvm_tree import CLVMTree from .bytes32 import bytes32 -from .keywords import NULL, ONE, TWO, Q_KW, A_KW, C_KW +from .keywords import NULL, ONE, Q_KW, A_KW, C_KW # from chia.util.hash import std_hash # from chia.util.byte_types import hexstr_to_bytes diff --git a/wheel/python/clvm_rs/serialize.py b/wheel/python/clvm_rs/serialize.py index b6033ef2..0283c1b2 100644 --- a/wheel/python/clvm_rs/serialize.py +++ b/wheel/python/clvm_rs/serialize.py @@ -47,7 +47,9 @@ def atom_to_byte_iterator(as_atom): elif size < 0x2000: size_blob = bytes([0xC0 | (size >> 8), (size >> 0) & 0xFF]) elif size < 0x100000: - size_blob = bytes([0xE0 | (size >> 16), (size >> 8) & 0xFF, (size >> 0) & 0xFF]) + size_blob = bytes( + [0xE0 | (size >> 16), (size >> 8) & 0xFF, (size >> 0) & 0xFF] + ) elif size < 0x8000000: size_blob = bytes( [ diff --git a/wheel/python/clvm_rs/tree_hash.py b/wheel/python/clvm_rs/tree_hash.py new file mode 100644 index 00000000..e71ef098 --- /dev/null +++ b/wheel/python/clvm_rs/tree_hash.py @@ -0,0 +1,74 @@ +""" +This is an implementation of `sha256_treehash`, used to calculate +puzzle hashes in clvm. + +This implementation goes to great pains to be non-recursive so we don't +have to worry about blowing out the python stack. +""" + +from hashlib import sha256 + +from clvm_rs import CLVMObject + +bytes32 = bytes + + +ONE = bytes.fromhex("01") +TWO = bytes.fromhex("02") + + +def shatree_atom(atom: bytes) -> bytes32: + s = sha256() + s.update(ONE) + s.update(atom) + return bytes32(s.digest()) + + +def shatree_pair(left_hash: bytes32, right_hash: bytes32) -> bytes32: + s = sha256() + s.update(TWO) + s.update(left_hash) + s.update(right_hash) + return bytes32(s.digest()) + + +def sha256_treehash(sexp: CLVMObject) -> bytes32: + def handle_sexp(sexp_stack, op_stack) -> None: + sexp = sexp_stack.pop() + r = getattr(sexp, "_cached_sha256_treehash", None) + if r is not None: + sexp_stack.append(r) + elif sexp.pair: + p0, p1 = sexp.pair + sexp_stack.append(p0) + sexp_stack.append(p1) + op_stack.append(handle_pair) + op_stack.append(handle_sexp) + op_stack.append(roll) + op_stack.append(handle_sexp) + else: + r = shatree_atom(sexp.atom) + sexp_stack.append(r) + if hasattr(sexp, "_cached_sha256_treehash"): + sexp._cached_sha256_treehash = r + + def handle_pair(sexp_stack, op_stack) -> None: + p0 = sexp_stack.pop() + p1 = sexp_stack.pop() + r = shatree_pair(p0, p1) + sexp_stack.append(r) + if hasattr(sexp, "_cached_sha256_treehash"): + sexp._cached_sha256_treehash = r + + def roll(sexp_stack, op_stack) -> None: + p0 = sexp_stack.pop() + p1 = sexp_stack.pop() + sexp_stack.append(p0) + sexp_stack.append(p1) + + sexp_stack = [sexp] + op_stack = [handle_sexp] + while len(op_stack) > 0: + op = op_stack.pop() + op(sexp_stack, op_stack) + return bytes32(sexp_stack[0]) From 94080d61eb73f50f56592ff9a1f584847c2b4d58 Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Tue, 4 Oct 2022 18:13:06 -0700 Subject: [PATCH 06/45] Improve coverage. --- src/lib.rs | 1 + src/serialize.rs | 821 ++++++++++++++++++++++++++ wheel/python/clvm_rs/base.py | 9 +- wheel/python/clvm_rs/casts.py | 8 +- wheel/python/clvm_rs/deser.py | 7 +- wheel/python/clvm_rs/program.py | 15 +- wheel/python/tests/test_as_python.py | 17 +- wheel/python/tests/test_program.py | 33 ++ wheel/python/tests/test_serialize.py | 25 +- wheel/python/tests/test_to_program.py | 12 +- wheel/src/api.rs | 29 + 11 files changed, 944 insertions(+), 33 deletions(-) create mode 100644 src/serialize.rs diff --git a/src/lib.rs b/src/lib.rs index 34184ae7..9bf0d5ee 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,6 +13,7 @@ pub mod reduction; pub mod run_program; pub mod runtime_dialect; pub mod serde; +pub mod serialize; pub mod sha2; pub mod traverse_path; diff --git a/src/serialize.rs b/src/serialize.rs new file mode 100644 index 00000000..b97b2dfe --- /dev/null +++ b/src/serialize.rs @@ -0,0 +1,821 @@ +use std::io; +use std::io::{Cursor, ErrorKind, Read, Seek, SeekFrom}; + +use crate::allocator::{Allocator, NodePtr, SExp}; +use crate::node::Node; + +const MAX_SINGLE_BYTE: u8 = 0x7f; +const CONS_BOX_MARKER: u8 = 0xff; + +fn bad_encoding() -> io::Error { + io::Error::new(ErrorKind::InvalidInput, "bad encoding") +} + +fn internal_error() -> io::Error { + io::Error::new(ErrorKind::InvalidInput, "internal error") +} + +/// all atoms serialize their contents verbatim. All expect those one-byte atoms +/// from 0x00-0x7f also have a prefix encoding their length. This function +/// writes the correct prefix for an atom of size `size` whose first byte is `atom_0`. +/// If the atom is of size 0, use any placeholder first byte, as it's ignored anyway. + +fn write_atom_encoding_prefix_with_size( + f: &mut dyn io::Write, + atom_0: u8, + size: u64, +) -> io::Result<()> { + if size == 0 { + f.write_all(&[0x80]) + } else if size == 1 && atom_0 < 0x80 { + Ok(()) + } else if size < 0x40 { + f.write_all(&[0x80 | (size as u8)]) + } else if size < 0x2000 { + f.write_all(&[0xc0 | (size >> 8) as u8, size as u8]) + } else if size < 0x10_0000 { + f.write_all(&[ + (0xe0 | (size >> 16)) as u8, + ((size >> 8) & 0xff) as u8, + ((size) & 0xff) as u8, + ]) + } else if size < 0x800_0000 { + f.write_all(&[ + (0xf0 | (size >> 24)) as u8, + ((size >> 16) & 0xff) as u8, + ((size >> 8) & 0xff) as u8, + ((size) & 0xff) as u8, + ]) + } else if size < 0x4_0000_0000 { + f.write_all(&[ + (0xf8 | (size >> 32)) as u8, + ((size >> 24) & 0xff) as u8, + ((size >> 16) & 0xff) as u8, + ((size >> 8) & 0xff) as u8, + ((size) & 0xff) as u8, + ]) + } else { + Err(io::Error::new(ErrorKind::InvalidData, "atom too big")) + } +} + +/// serialize an atom +fn write_atom(f: &mut dyn io::Write, atom: &[u8]) -> io::Result<()> { + let u8_0 = if !atom.is_empty() { atom[0] } else { 0 }; + write_atom_encoding_prefix_with_size(f, u8_0, atom.len() as u64)?; + f.write_all(atom) +} + +/// serialize a node +pub fn node_to_stream(node: &Node, f: &mut dyn io::Write) -> io::Result<()> { + let mut values: Vec = vec![node.node]; + let a = node.allocator; + while !values.is_empty() { + let v = values.pop().unwrap(); + let n = a.sexp(v); + match n { + SExp::Atom(atom_ptr) => { + let atom = a.buf(&atom_ptr); + write_atom(f, atom)?; + } + SExp::Pair(left, right) => { + f.write_all(&[CONS_BOX_MARKER as u8])?; + values.push(right); + values.push(left); + } + } + } + Ok(()) +} + +/// This data structure is used with `parse_triples`, which returns a triple of +/// integer values for each clvm object in a tree. + +#[derive(Debug, PartialEq, Eq)] +pub enum ParsedTriple { + Atom { + start: u64, + end: u64, + atom_offset: u32, + }, + Pair { + start: u64, + end: u64, + right_index: u32, + }, +} + +enum ParseOpRef { + ParseObj, + SaveCursor(usize), + SaveIndex(usize), +} + +fn skip_bytes(f: &mut R, skip_size: u64) -> io::Result { + io::copy(&mut f.by_ref().take(skip_size), &mut io::sink()) +} + +/// parse a serialized clvm object tree to an array of `ParsedTriple` objects + +/// This alternative mechanism of deserialization generates an array of +/// references to each clvm object. A reference contains three values: +/// a start offset within the blob, an end offset, and a third value that +/// is either: an atom offset (relative to the start offset) where the atom +/// data starts (and continues to the end offset); or an index in the array +/// corresponding to the "right" element of the pair (in which case, the +/// "left" element corresponds to the next index in the array). +/// +/// Since these values are offsets into the original buffer, that buffer needs +/// to be kept around to get the original atoms. + +pub fn parse_triples(f: &mut R) -> io::Result> { + let mut r = Vec::new(); + let mut op_stack = vec![ParseOpRef::ParseObj]; + let mut cursor: u64 = 0; + loop { + match op_stack.pop() { + None => { + break; + } + Some(op) => match op { + ParseOpRef::ParseObj => { + let mut b: [u8; 1] = [0]; + f.read_exact(&mut b)?; + let start = cursor as u64; + cursor += 1; + let b = b[0]; + if b == CONS_BOX_MARKER { + let index = r.len(); + let new_obj = ParsedTriple::Pair { + start, + end: 0, + right_index: 0, + }; + r.push(new_obj); + op_stack.push(ParseOpRef::SaveCursor(index)); + op_stack.push(ParseOpRef::ParseObj); + op_stack.push(ParseOpRef::SaveIndex(index)); + op_stack.push(ParseOpRef::ParseObj); + } else { + let (start, end, atom_offset) = { + if b <= MAX_SINGLE_BYTE { + (start, start + 1, 0) + } else { + let (atom_offset, atom_size) = decode_size(f, b)?; + skip_bytes(f, atom_size)?; + let end = start + (atom_offset as u64) + (atom_size as u64); + (start, end, atom_offset as u32) + } + }; + let new_obj = ParsedTriple::Atom { + start, + end, + atom_offset, + }; + cursor = end; + r.push(new_obj); + } + } + ParseOpRef::SaveCursor(index) => { + if let ParsedTriple::Pair { + start, + end: _, + right_index, + } = r[index] + { + r[index] = ParsedTriple::Pair { + start, + end: cursor, + right_index, + }; + } + } + ParseOpRef::SaveIndex(index) => { + if let ParsedTriple::Pair { + start, + end, + right_index: _, + } = r[index] + { + r[index] = ParsedTriple::Pair { + start, + end, + right_index: r.len() as u32, + }; + } + } + }, + } + } + Ok(r) +} + +/// decode the length prefix for an atom. Atoms whose value fit in 7 bits +/// don't have a length prefix, so those should be handled specially and +/// never passed to this function. +fn decode_size(f: &mut R, initial_b: u8) -> io::Result<(u8, u64)> { + debug_assert!((initial_b & 0x80) != 0); + if (initial_b & 0x80) == 0 { + return Err(internal_error()); + } + + let mut atom_start_offset = 0; + let mut bit_mask: u8 = 0x80; + let mut b = initial_b; + while b & bit_mask != 0 { + atom_start_offset += 1; + b &= 0xff ^ bit_mask; + bit_mask >>= 1; + } + let mut size_blob: Vec = Vec::new(); + size_blob.resize(atom_start_offset, 0); + size_blob[0] = b; + if atom_start_offset > 1 { + let remaining_buffer = &mut size_blob[1..atom_start_offset]; + f.read_exact(remaining_buffer)?; + } + // need to convert size_blob to an int + let mut atom_size: u64 = 0; + if size_blob.len() > 6 { + return Err(bad_encoding()); + } + for b in &size_blob { + atom_size <<= 8; + atom_size += *b as u64; + } + if atom_size >= 0x400000000 { + return Err(bad_encoding()); + } + Ok((atom_start_offset as u8, atom_size)) +} + +enum ParseOp { + SExp, + Cons, +} + +/// deserialize a clvm node from a `std::io::Cursor` +pub fn node_from_stream(allocator: &mut Allocator, f: &mut Cursor<&[u8]>) -> io::Result { + let mut values: Vec = Vec::new(); + let mut ops = vec![ParseOp::SExp]; + + let mut b = [0; 1]; + loop { + let op = ops.pop(); + if op.is_none() { + break; + } + match op.unwrap() { + ParseOp::SExp => { + f.read_exact(&mut b)?; + if b[0] == CONS_BOX_MARKER { + ops.push(ParseOp::Cons); + ops.push(ParseOp::SExp); + ops.push(ParseOp::SExp); + } else if b[0] == 0x01 { + values.push(allocator.one()); + } else if b[0] == 0x80 { + values.push(allocator.null()); + } else if b[0] <= MAX_SINGLE_BYTE { + values.push(allocator.new_atom(&b)?); + } else { + let (_prefix_size, blob_size) = decode_size(f, b[0])?; + if (f.get_ref().len()) < blob_size as usize { + return Err(bad_encoding()); + } + let mut blob: Vec = vec![0; blob_size as usize]; + f.read_exact(&mut blob)?; + values.push(allocator.new_atom(&blob)?); + } + } + ParseOp::Cons => { + // cons + let v2 = values.pop(); + let v1 = values.pop(); + values.push(allocator.new_pair(v1.unwrap(), v2.unwrap())?); + } + } + } + Ok(values.pop().unwrap()) +} + +pub fn node_from_bytes(allocator: &mut Allocator, b: &[u8]) -> io::Result { + let mut buffer = Cursor::new(b); + node_from_stream(allocator, &mut buffer) +} + +pub fn node_to_bytes(node: &Node) -> io::Result> { + let mut buffer = Cursor::new(Vec::new()); + + node_to_stream(node, &mut buffer)?; + let vec = buffer.into_inner(); + Ok(vec) +} + +pub fn serialized_length_from_bytes(b: &[u8]) -> io::Result { + let mut f = Cursor::new(b); + let mut ops = vec![ParseOp::SExp]; + let mut b = [0; 1]; + loop { + let op = ops.pop(); + if op.is_none() { + break; + } + match op.unwrap() { + ParseOp::SExp => { + f.read_exact(&mut b)?; + if b[0] == CONS_BOX_MARKER { + // since all we're doing is to determing the length of the + // serialized buffer, we don't need to do anything about + // "cons". So we skip pushing it to lower the pressure on + // the op stack + //ops.push(ParseOp::Cons); + ops.push(ParseOp::SExp); + ops.push(ParseOp::SExp); + } else if b[0] == 0x80 || b[0] <= MAX_SINGLE_BYTE { + // This one byte we just read was the whole atom. + // or the + // special case of NIL + } else { + let (_prefix_size, blob_size) = decode_size(&mut f, b[0])?; + f.seek(SeekFrom::Current(blob_size as i64))?; + if (f.get_ref().len() as u64) < f.position() { + return Err(bad_encoding()); + } + } + } + ParseOp::Cons => { + // cons. No need to construct any structure here. Just keep + // going + } + } + } + Ok(f.position()) +} + +use crate::sha2::{Digest, Sha256}; + +fn hash_atom(buf: &[u8]) -> [u8; 32] { + let mut ctx = Sha256::new(); + ctx.update(&[1_u8]); + ctx.update(buf); + ctx.finalize().into() +} + +fn hash_pair(left: &[u8; 32], right: &[u8; 32]) -> [u8; 32] { + let mut ctx = Sha256::new(); + ctx.update(&[2_u8]); + ctx.update(left); + ctx.update(right); + ctx.finalize().into() +} + +// computes the tree-hash of a CLVM structure in serialized form +pub fn tree_hash_from_stream(f: &mut Cursor<&[u8]>) -> io::Result<[u8; 32]> { + let mut values: Vec<[u8; 32]> = Vec::new(); + let mut ops = vec![ParseOp::SExp]; + + let mut b = [0; 1]; + loop { + let op = ops.pop(); + if op.is_none() { + break; + } + match op.unwrap() { + ParseOp::SExp => { + f.read_exact(&mut b)?; + if b[0] == CONS_BOX_MARKER { + ops.push(ParseOp::Cons); + ops.push(ParseOp::SExp); + ops.push(ParseOp::SExp); + } else if b[0] == 0x80 { + values.push(hash_atom(&[])); + } else if b[0] <= MAX_SINGLE_BYTE { + values.push(hash_atom(&b)); + } else { + let (_, blob_size) = decode_size(f, b[0])?; + let blob = &f.get_ref()[f.position() as usize..]; + if blob.len() < blob_size as usize { + return Err(bad_encoding()); + } + let blob_size = blob_size as u64; + f.set_position(f.position() + blob_size); + values.push(hash_atom(&blob[..blob_size as usize])); + } + } + ParseOp::Cons => { + // cons + let v2 = values.pop(); + let v1 = values.pop(); + values.push(hash_pair(&v1.unwrap(), &v2.unwrap())); + } + } + } + Ok(values.pop().unwrap()) +} + +#[test] +fn test_tree_hash_max_single_byte() { + let mut ctx = Sha256::new(); + ctx.update(&[1_u8]); + ctx.update(&[0x7f_u8]); + let mut cursor = Cursor::<&[u8]>::new(&[0x7f_u8]); + assert_eq!( + tree_hash_from_stream(&mut cursor).unwrap(), + ctx.finalize().as_slice() + ); +} + +#[test] +fn test_tree_hash_one() { + let mut ctx = Sha256::new(); + ctx.update(&[1_u8]); + ctx.update(&[1_u8]); + let mut cursor = Cursor::<&[u8]>::new(&[1_u8]); + assert_eq!( + tree_hash_from_stream(&mut cursor).unwrap(), + ctx.finalize().as_slice() + ); +} + +#[test] +fn test_tree_hash_zero() { + let mut ctx = Sha256::new(); + ctx.update(&[1_u8]); + ctx.update(&[0_u8]); + let mut cursor = Cursor::<&[u8]>::new(&[0_u8]); + assert_eq!( + tree_hash_from_stream(&mut cursor).unwrap(), + ctx.finalize().as_slice() + ); +} + +#[test] +fn test_tree_hash_nil() { + let mut ctx = Sha256::new(); + ctx.update(&[1_u8]); + let mut cursor = Cursor::<&[u8]>::new(&[0x80_u8]); + assert_eq!( + tree_hash_from_stream(&mut cursor).unwrap(), + ctx.finalize().as_slice() + ); +} + +#[test] +fn test_tree_hash_overlong() { + let mut cursor = Cursor::<&[u8]>::new(&[0x8f, 0xff]); + let e = tree_hash_from_stream(&mut cursor).unwrap_err(); + assert_eq!(e.kind(), bad_encoding().kind()); + + let mut cursor = Cursor::<&[u8]>::new(&[0b11001111, 0xff]); + let e = tree_hash_from_stream(&mut cursor).unwrap_err(); + assert_eq!(e.kind(), bad_encoding().kind()); + + let mut cursor = Cursor::<&[u8]>::new(&[0b11001111, 0xff, 0, 0]); + let e = tree_hash_from_stream(&mut cursor).unwrap_err(); + assert_eq!(e.kind(), bad_encoding().kind()); +} + +#[cfg(test)] +use hex::FromHex; + +// these test cases were produced by: + +// from chia.types.blockchain_format.program import Program +// a = Program.to(...) +// print(bytes(a).hex()) +// print(a.get_tree_hash().hex()) + +#[test] +fn test_tree_hash_list() { + // this is the list (1 (2 (3 (4 (5 ()))))) + let buf = Vec::from_hex("ff01ff02ff03ff04ff0580").unwrap(); + let mut cursor = Cursor::<&[u8]>::new(&buf); + assert_eq!( + tree_hash_from_stream(&mut cursor).unwrap().to_vec(), + Vec::from_hex("123190dddde51acfc61f48429a879a7b905d1726a52991f7d63349863d06b1b6").unwrap() + ); +} + +#[test] +fn test_tree_hash_tree() { + // this is the tree ((1, 2), (3, 4)) + let buf = Vec::from_hex("ffff0102ff0304").unwrap(); + let mut cursor = Cursor::<&[u8]>::new(&buf); + assert_eq!( + tree_hash_from_stream(&mut cursor).unwrap().to_vec(), + Vec::from_hex("2824018d148bc6aed0847e2c86aaa8a5407b916169f15b12cea31fa932fc4c8d").unwrap() + ); +} + +#[test] +fn test_tree_hash_tree_large_atom() { + // this is the tree ((1, 2), (3, b"foobar")) + let buf = Vec::from_hex("ffff0102ff0386666f6f626172").unwrap(); + let mut cursor = Cursor::<&[u8]>::new(&buf); + assert_eq!( + tree_hash_from_stream(&mut cursor).unwrap().to_vec(), + Vec::from_hex("b28d5b401bd02b65b7ed93de8e916cfc488738323e568bcca7e032c3a97a12e4").unwrap() + ); +} + +#[test] +fn test_serialized_length_from_bytes() { + assert_eq!( + serialized_length_from_bytes(&[0x7f, 0x00, 0x00, 0x00]).unwrap(), + 1 + ); + assert_eq!( + serialized_length_from_bytes(&[0x80, 0x00, 0x00, 0x00]).unwrap(), + 1 + ); + assert_eq!( + serialized_length_from_bytes(&[0xff, 0x00, 0x00, 0x00]).unwrap(), + 3 + ); + assert_eq!( + serialized_length_from_bytes(&[0xff, 0x01, 0xff, 0x80, 0x80, 0x00]).unwrap(), + 5 + ); + + let e = serialized_length_from_bytes(&[0x8f, 0xff]).unwrap_err(); + assert_eq!(e.kind(), bad_encoding().kind()); + assert_eq!(e.to_string(), "bad encoding"); + + let e = serialized_length_from_bytes(&[0b11001111, 0xff]).unwrap_err(); + assert_eq!(e.kind(), bad_encoding().kind()); + assert_eq!(e.to_string(), "bad encoding"); + + let e = serialized_length_from_bytes(&[0b11001111, 0xff, 0, 0]).unwrap_err(); + assert_eq!(e.kind(), bad_encoding().kind()); + assert_eq!(e.to_string(), "bad encoding"); + + assert_eq!( + serialized_length_from_bytes(&[0x8f, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]).unwrap(), + 16 + ); +} + +#[test] +fn test_write_atom_encoding_prefix_with_size() { + let mut buf = Vec::::new(); + assert!(write_atom_encoding_prefix_with_size(&mut buf, 0, 0).is_ok()); + assert_eq!(buf, vec![0x80]); + + for v in 0..0x7f { + let mut buf = Vec::::new(); + assert!(write_atom_encoding_prefix_with_size(&mut buf, v, 1).is_ok()); + assert_eq!(buf, vec![]); + } + + for v in 0x80..0xff { + let mut buf = Vec::::new(); + assert!(write_atom_encoding_prefix_with_size(&mut buf, v, 1).is_ok()); + assert_eq!(buf, vec![0x81]); + } + + for size in 0x1_u8..0x3f_u8 { + let mut buf = Vec::::new(); + assert!(write_atom_encoding_prefix_with_size(&mut buf, 0xaa, size as u64).is_ok()); + assert_eq!(buf, vec![0x80 + size]); + } + + let mut buf = Vec::::new(); + assert!(write_atom_encoding_prefix_with_size(&mut buf, 0xaa, 0b111111).is_ok()); + assert_eq!(buf, vec![0b10111111]); + + let mut buf = Vec::::new(); + assert!(write_atom_encoding_prefix_with_size(&mut buf, 0xaa, 0b1000000).is_ok()); + assert_eq!(buf, vec![0b11000000, 0b1000000]); + + let mut buf = Vec::::new(); + assert!(write_atom_encoding_prefix_with_size(&mut buf, 0xaa, 0xfffff).is_ok()); + assert_eq!(buf, vec![0b11101111, 0xff, 0xff]); + + let mut buf = Vec::::new(); + assert!(write_atom_encoding_prefix_with_size(&mut buf, 0xaa, 0xffffff).is_ok()); + assert_eq!(buf, vec![0b11110000, 0xff, 0xff, 0xff]); + + let mut buf = Vec::::new(); + assert!(write_atom_encoding_prefix_with_size(&mut buf, 0xaa, 0xffffffff).is_ok()); + assert_eq!(buf, vec![0b11111000, 0xff, 0xff, 0xff, 0xff]); + + // this is the largest possible atom size + let mut buf = Vec::::new(); + assert!(write_atom_encoding_prefix_with_size(&mut buf, 0xaa, 0x3ffffffff).is_ok()); + assert_eq!(buf, vec![0b11111011, 0xff, 0xff, 0xff, 0xff]); + + // this is too large + let mut buf = Vec::::new(); + assert!(!write_atom_encoding_prefix_with_size(&mut buf, 0xaa, 0x400000000).is_ok()); + + for (size, expected_prefix) in [ + (0x1, vec![0x81]), + (0x2, vec![0x82]), + (0x3f, vec![0xbf]), + (0x40, vec![0xc0, 0x40]), + (0x1fff, vec![0xdf, 0xff]), + (0x2000, vec![0xe0, 0x20, 0x00]), + (0xf_ffff, vec![0xef, 0xff, 0xff]), + (0x10_0000, vec![0xf0, 0x10, 0x00, 0x00]), + (0x7ff_ffff, vec![0xf7, 0xff, 0xff, 0xff]), + (0x800_0000, vec![0xf8, 0x08, 0x00, 0x00, 0x00]), + (0x3_ffff_ffff, vec![0xfb, 0xff, 0xff, 0xff, 0xff]), + ] { + let mut buf = Vec::::new(); + assert!(write_atom_encoding_prefix_with_size(&mut buf, 0xaa, size).is_ok()); + assert_eq!(buf, expected_prefix); + } +} + +#[test] +fn test_write_atom() { + let mut buf = Vec::::new(); + assert!(write_atom(&mut buf, &vec![]).is_ok()); + assert_eq!(buf, vec![0b10000000]); + + let mut buf = Vec::::new(); + assert!(write_atom(&mut buf, &vec![0x00]).is_ok()); + assert_eq!(buf, vec![0b00000000]); + + let mut buf = Vec::::new(); + assert!(write_atom(&mut buf, &vec![0x7f]).is_ok()); + assert_eq!(buf, vec![0x7f]); + + let mut buf = Vec::::new(); + assert!(write_atom(&mut buf, &vec![0x80]).is_ok()); + assert_eq!(buf, vec![0x81, 0x80]); + + let mut buf = Vec::::new(); + assert!(write_atom(&mut buf, &vec![0xff]).is_ok()); + assert_eq!(buf, vec![0x81, 0xff]); + + let mut buf = Vec::::new(); + assert!(write_atom(&mut buf, &vec![0xaa, 0xbb]).is_ok()); + assert_eq!(buf, vec![0x82, 0xaa, 0xbb]); + + for (size, mut expected_prefix) in [ + (0x1, vec![0x81]), + (0x2, vec![0x82]), + (0x3f, vec![0xbf]), + (0x40, vec![0xc0, 0x40]), + (0x1fff, vec![0xdf, 0xff]), + (0x2000, vec![0xe0, 0x20, 0x00]), + (0xf_ffff, vec![0xef, 0xff, 0xff]), + (0x10_0000, vec![0xf0, 0x10, 0x00, 0x00]), + (0x7ff_ffff, vec![0xf7, 0xff, 0xff, 0xff]), + (0x800_0000, vec![0xf8, 0x08, 0x00, 0x00, 0x00]), + // the next one represents 17 GB of memory, which it then has to serialize + // so let's not do it until some time in the future when all machines have + // 64 GB of memory + // (0x3_ffff_ffff, vec![0xfb, 0xff, 0xff, 0xff, 0xff]), + ] { + let mut buf = Vec::::new(); + let atom = vec![0xaa; size]; + assert!(write_atom(&mut buf, &atom).is_ok()); + expected_prefix.extend(atom); + assert_eq!(buf, expected_prefix); + } +} + +#[test] +fn test_decode_size() { + // single-byte length prefix + let mut buffer = Cursor::new(&[]); + assert_eq!(decode_size(&mut buffer, 0x80 | 0x20).unwrap(), (1, 0x20)); + + // two-byte length prefix + let first = 0b11001111; + let mut buffer = Cursor::new(&[0xaa]); + assert_eq!(decode_size(&mut buffer, first).unwrap(), (2, 0xfaa)); +} + +#[test] +fn test_large_decode_size() { + // this is an atom length-prefix 0xffffffffffff, or (2^48 - 1). + // We don't support atoms this large and we should fail before attempting to + // allocate this much memory + let first = 0b11111110; + let mut buffer = Cursor::new(&[0xff, 0xff, 0xff, 0xff, 0xff, 0xff]); + let ret = decode_size(&mut buffer, first); + let e = ret.unwrap_err(); + assert_eq!(e.kind(), bad_encoding().kind()); + assert_eq!(e.to_string(), "bad encoding"); + + // this is still too large + let first = 0b11111100; + let mut buffer = Cursor::new(&[0x4, 0, 0, 0, 0]); + let ret = decode_size(&mut buffer, first); + let e = ret.unwrap_err(); + assert_eq!(e.kind(), bad_encoding().kind()); + assert_eq!(e.to_string(), "bad encoding"); + + // But this is *just* within what we support + // Still a very large blob, probably enough for a DoS attack + let first = 0b11111100; + let mut buffer = Cursor::new(&[0x3, 0xff, 0xff, 0xff, 0xff]); + assert_eq!(decode_size(&mut buffer, first).unwrap(), (6, 0x3ffffffff)); +} + +#[test] +fn test_truncated_decode_size() { + // the stream is truncated + let first = 0b11111100; + let mut buffer = Cursor::new(&[0x4, 0, 0, 0]); + let ret = decode_size(&mut buffer, first); + let e = ret.unwrap_err(); + assert_eq!(e.kind(), ErrorKind::UnexpectedEof); +} + +#[cfg(test)] +fn check_parse_triple(h: &str, expected: Vec) -> () { + let b = Vec::from_hex(h).unwrap(); + println!("{:?}", b); + let mut f = Cursor::new(b); + let p = parse_triples(&mut f).unwrap(); + assert_eq!(p, expected); +} + +#[test] +fn test_parse_triple() { + check_parse_triple( + "80", + vec![ParsedTriple::Atom { + start: 0, + end: 1, + atom_offset: 1, + }], + ); + + check_parse_triple( + "ff648200c8", + vec![ + ParsedTriple::Pair { + start: 0, + end: 5, + right_index: 2, + }, + ParsedTriple::Atom { + start: 1, + end: 2, + atom_offset: 0, + }, + ParsedTriple::Atom { + start: 2, + end: 5, + atom_offset: 1, + }, + ], + ); + + check_parse_triple( + "ff83666f6fff83626172ff8362617a80", // `(foo bar baz)` + vec![ + ParsedTriple::Pair { + start: 0, + end: 16, + right_index: 2, + }, + ParsedTriple::Atom { + start: 1, + end: 5, + atom_offset: 1, + }, + ParsedTriple::Pair { + start: 5, + end: 16, + right_index: 4, + }, + ParsedTriple::Atom { + start: 6, + end: 10, + atom_offset: 1, + }, + ParsedTriple::Pair { + start: 10, + end: 16, + right_index: 6, + }, + ParsedTriple::Atom { + start: 11, + end: 15, + atom_offset: 1, + }, + ParsedTriple::Atom { + start: 15, + end: 16, + atom_offset: 1, + }, + ], + ); + + let s = "c0a0".to_owned() + &hex::encode([0x31u8; 160]); + check_parse_triple( + &s, + vec![ParsedTriple::Atom { + start: 0, + end: 162, + atom_offset: 2, + }], + ); +} diff --git a/wheel/python/clvm_rs/base.py b/wheel/python/clvm_rs/base.py index 90982859..b64ed62f 100644 --- a/wheel/python/clvm_rs/base.py +++ b/wheel/python/clvm_rs/base.py @@ -1,7 +1,7 @@ from typing import Optional, Protocol, Tuple -class CLVMObjectStore(Protocol): +class CLVMObject(Protocol): atom: Optional[bytes] pair: Optional[Tuple["CLVMObjectStore", "CLVMObjectStore"]] @@ -12,10 +12,3 @@ def new_atom(cls, v: bytes) -> "CLVMObjectStore": @classmethod def new_pair(cls, p1, p2) -> "CLVMObjectStore": raise NotImplementedError() - - -CLVMObject = CLVMObjectStore - - -class PythonHeapCLVMObject(CLVMObjectStore): - pass diff --git a/wheel/python/clvm_rs/casts.py b/wheel/python/clvm_rs/casts.py index 84d86828..38972507 100644 --- a/wheel/python/clvm_rs/casts.py +++ b/wheel/python/clvm_rs/casts.py @@ -27,8 +27,7 @@ def int_to_bytes(v): r = v.to_bytes(byte_count, "big", signed=True) # make sure the string returned is minimal # ie. no leading 00 or ff bytes that are unnecessary - while len(r) > 1 and r[0] == (0xFF if r[1] & 0x80 else 0): - r = r[1:] + assert not (len(r) > 1 and r[0] == (0xFF if r[1] & 0x80 else 0)) return r @@ -47,8 +46,6 @@ def to_atom_type(v: AtomCastableType) -> bytes: return bytes(v) if v is None: return NULL - if v == []: - return NULL raise ValueError("can't cast %s (%s) to bytes" % (type(v), v)) @@ -121,8 +118,7 @@ def to_clvm_object( stack.append(obj) continue # there's exactly one item left at this point - if len(stack) != 1: - raise ValueError("internal error") + assert len(stack) == 1 # stack[0] implements the clvm object protocol return stack[0] diff --git a/wheel/python/clvm_rs/deser.py b/wheel/python/clvm_rs/deser.py index bf9be040..f2ea9d92 100644 --- a/wheel/python/clvm_rs/deser.py +++ b/wheel/python/clvm_rs/deser.py @@ -73,6 +73,7 @@ def _atom_size_from_cursor(blob, cursor) -> Tuple[int, int]: if bit_count > 1: size_blob += blob[cursor + 1 : cursor + bit_count] size = int.from_bytes(size_blob, "big") - if size >= 0x400000000: - raise ValueError("blob too large") - return bit_count, cursor + size + bit_count + new_cursor = cursor + size + bit_count + if new_cursor > len(blob): + raise ValueError("end of stream") + return bit_count, new_cursor diff --git a/wheel/python/clvm_rs/program.py b/wheel/python/clvm_rs/program.py index 01451a39..b480a53e 100644 --- a/wheel/python/clvm_rs/program.py +++ b/wheel/python/clvm_rs/program.py @@ -5,8 +5,9 @@ # from clvm import Program from .base import CLVMObject from .casts import to_clvm_object +from .EvalError import EvalError from clvm_rs.clvm_rs import run_serialized_program -from clvm_rs.serialize import sexp_from_stream, sexp_to_stream +from clvm_rs.serialize import sexp_from_stream, sexp_to_stream, sexp_to_bytes from clvm_rs.tree_hash import sha256_treehash from .clvm_tree import CLVMTree from .bytes32 import bytes32 @@ -56,9 +57,7 @@ def fromhex(cls, hexstr: str) -> Program: return cls.from_bytes(bytes.fromhex(hexstr)) def __bytes__(self) -> bytes: - f = io.BytesIO() - self.stream(f) # noqa - return f.getvalue() + return sexp_to_bytes(self) # high level casting with `.to` @@ -227,7 +226,10 @@ def run_with_cost( prog_bytes = bytes(self) args_bytes = bytes(self.to(args)) cost, r = run_serialized_program(prog_bytes, args_bytes, max_cost, 0) - return cost, Program.to(r) + r = Program.to(r) + if isinstance(cost, str): + raise EvalError(cost, r) + return cost, r def run(self, args) -> "Program": cost, r = self.run_with_cost(args, MAX_COST) @@ -313,9 +315,6 @@ def as_atom_list(self) -> List[bytes]: """ return list(self.as_atom_iter()) - def __deepcopy__(self, memo): - return type(self).from_bytes(bytes(self)) - NULL_PROGRAM = Program.fromhex("80") diff --git a/wheel/python/tests/test_as_python.py b/wheel/python/tests/test_as_python.py index b6760410..fe38ca30 100644 --- a/wheel/python/tests/test_as_python.py +++ b/wheel/python/tests/test_as_python.py @@ -6,8 +6,7 @@ class dummy_class: - def __init__(self): - self.i = 0 + pass def gen_tree(depth: int) -> Program: @@ -41,9 +40,19 @@ def test_short_lists(self): for size in range(1, 5): self.check_as_atom_list([bytes([_])] * size) + def test_non_list(self): + v = Program.to([1, 2, 3, (5, 6), 7]) + p1 = v.as_atom_list() + expected = [Program.to(_).atom for _ in [1, 2, 3]] + self.assertEqual(p1, expected) + def test_int(self): v = Program.to(42) self.assertEqual(v.atom, bytes([42])) + self.assertEqual(v.as_int(), 42) + v = Program.to(0) + self.assertEqual(v.atom, bytes([])) + self.assertEqual(v.as_int(), 0) def test_none(self): v = Program.to(None) @@ -149,10 +158,10 @@ def test_long_list(self): def test_invalid_tuple(self): with self.assertRaises(ValueError): - s = Program.to((dummy_class, dummy_class)) + Program.to((dummy_class, dummy_class)) with self.assertRaises(ValueError): - s = Program.to((dummy_class, dummy_class, dummy_class)) + Program.to((dummy_class, dummy_class, dummy_class)) def test_clvm_object_tuple(self): o1 = Program.to(b"foo") diff --git a/wheel/python/tests/test_program.py b/wheel/python/tests/test_program.py index aa7c37c1..2299f66a 100644 --- a/wheel/python/tests/test_program.py +++ b/wheel/python/tests/test_program.py @@ -1,5 +1,6 @@ from unittest import TestCase +from clvm_rs import CLVMObject from clvm_rs.program import Program from clvm_rs.EvalError import EvalError @@ -44,6 +45,38 @@ def test_replace_bad_path(self): self.assertRaises(ValueError, lambda: p1.replace(q=105)) self.assertRaises(ValueError, lambda: p1.replace(rq=105)) + def test_protocol(self): + nil = Program.to(0) + self.assertRaises(NotImplementedError, lambda: CLVMObject.new_atom(nil)) + self.assertRaises(NotImplementedError, lambda: CLVMObject.new_pair(nil, nil)) + + def test_first_rest(self): + p = Program.to([4, 5]) + self.assertEqual(p.first(), 4) + self.assertEqual(p.rest(), [5]) + p = Program.to(4) + self.assertEqual(p.as_pair(), None) + self.assertEqual(p.first(), None) + self.assertEqual(p.rest(), None) + + def test_simple_run(self): + p = Program.fromhex("ff10ff02ff0580") # `(+ 2 5)` + args = Program.fromhex("ff32ff3c80") # `(50 60)` + r = p.run(args) + self.assertEqual(r, 110) + + def test_run_exception(self): + p = Program.fromhex( + "ff08ffff0183666f6fffff018362617280" + ) # `(x (q . foo) (q . bar))` + err = None + try: + p.run(p) + except EvalError as ee: + err = ee + self.assertEqual(err.args, ("clvm raise",)) + self.assertEqual(err._sexp, ["foo", "bar"]) + def check_idempotency(p, *args): curried = p.curry(*args) diff --git a/wheel/python/tests/test_serialize.py b/wheel/python/tests/test_serialize.py index 0a44f83f..294817ab 100644 --- a/wheel/python/tests/test_serialize.py +++ b/wheel/python/tests/test_serialize.py @@ -31,6 +31,10 @@ class SerializeTest(unittest.TestCase): def check_serde(self, s): v = Program.to(s) b = bytes(v) + f = io.BytesIO() + v.stream(f) + b1 = f.getvalue() + self.assertEqual(b, b1) v1 = Program.parse(io.BytesIO(b)) if v != v1: print("%s: %d %s %s" % (v, len(b), b, v1)) @@ -73,8 +77,7 @@ def test_long_blobs(self): def test_blob_limit(self): with self.assertRaises(ValueError): - for b in atom_to_byte_iterator(LargeAtom()): - print("%02x" % b) + next(atom_to_byte_iterator(LargeAtom())) def test_very_long_blobs(self): for size in [0x40, 0x2000, 0x100000, 0x8000000]: @@ -124,3 +127,21 @@ def test_deserialize_large_blob(self): with self.assertRaises(ValueError): Program.parse(InfiniteStream(bytes_in)) + + def test_repr_clvm_tree(self): + o = Program.fromhex("ff8080") + self.assertEqual(repr(o.unwrap()), "") + + def test_bad_blob(self): + self.assertRaises(ValueError, lambda: Program.fromhex("ff")) + + def test_large_atom(self): + s = "foo" * 100 + p = Program.to(s) + blob = bytes(p) + p1 = Program.from_bytes(blob) + self.assertEqual(p, p1) + + def test_too_large_atom(self): + self.assertRaises(ValueError, lambda: Program.fromhex("fc")) + self.assertRaises(ValueError, lambda: Program.fromhex("fc8000000000")) diff --git a/wheel/python/tests/test_to_program.py b/wheel/python/tests/test_to_program.py index cb092614..b95ec0e3 100644 --- a/wheel/python/tests/test_to_program.py +++ b/wheel/python/tests/test_to_program.py @@ -117,6 +117,9 @@ def pair(self) -> Optional[Tuple[Any, Any]]: tree = Program.to(GeneratedTree(3, 10)) assert print_leaves(tree) == "10 11 12 13 14 15 16 17 " + # this is just for `coverage` + assert print_leaves(Program.to(0)) == "() " + def test_looks_like_clvm_object(self): # this function can't look at the values, that would cause a cascade of @@ -174,9 +177,14 @@ def test_convert_atom(self): assert convert_atom_to_bytes(None) == b"" assert convert_atom_to_bytes([]) == b"" - assert convert_atom_to_bytes([1, 2, 3]) == None + assert convert_atom_to_bytes([1, 2, 3]) is None - assert convert_atom_to_bytes((1, 2)) == None + assert convert_atom_to_bytes((1, 2)) is None with self.assertRaises(ValueError): assert convert_atom_to_bytes({}) + + def test_to_nil(self): + self.assertEqual(Program.to([]), 0) + self.assertEqual(Program.to(0), 0) + self.assertEqual(Program.to(b""), 0) diff --git a/wheel/src/api.rs b/wheel/src/api.rs index 1611462c..5c316aca 100644 --- a/wheel/src/api.rs +++ b/wheel/src/api.rs @@ -1,3 +1,5 @@ +use std::io; + use super::lazy_node::LazyNode; use crate::adapt_response::adapt_response; use clvmr::allocator::Allocator; @@ -6,8 +8,10 @@ use clvmr::cost::Cost; use clvmr::reduction::Response; use clvmr::run_program::run_program; use clvmr::serde::{node_from_bytes, serialized_length_from_bytes}; +use clvmr::serialize::{parse_triples, ParsedTriple}; use clvmr::{LIMIT_HEAP, LIMIT_STACK, MEMPOOL_MODE, NO_UNKNOWN_OPS}; use pyo3::prelude::*; +use pyo3::types::PyTuple; use pyo3::wrap_pyfunction; #[pyfunction] @@ -39,10 +43,35 @@ pub fn run_serialized_chia_program( adapt_response(py, allocator, r) } +fn tuple_for_parsed_triple(py: Python<'_>, p: &ParsedTriple) -> PyObject { + let tuple = match p { + ParsedTriple::Atom { + start, + end, + atom_offset, + } => PyTuple::new(py, [*start, *end, *atom_offset as u64]), + ParsedTriple::Pair { + start, + end, + right_index, + } => PyTuple::new(py, [*start, *end, *right_index as u64]), + }; + tuple.into_py(py) +} + +#[pyfunction] +fn deserialize_as_triples(py: Python, blob: &[u8]) -> PyResult> { + let mut cursor = io::Cursor::new(blob); + let r = parse_triples(&mut cursor)?; + let r = r.iter().map(|pt| tuple_for_parsed_triple(py, pt)).collect(); + Ok(r) +} + #[pymodule] fn clvm_rs(_py: Python, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(run_serialized_chia_program, m)?)?; m.add_function(wrap_pyfunction!(serialized_length, m)?)?; + m.add_function(wrap_pyfunction!(deserialize_as_triples, m)?)?; m.add("NO_UNKNOWN_OPS", NO_UNKNOWN_OPS)?; m.add("LIMIT_HEAP", LIMIT_HEAP)?; From e6255a3d3d93a020120b3599deffd9b3810fe682 Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Sat, 8 Oct 2022 11:32:28 -0700 Subject: [PATCH 07/45] Make deserialization time tree-hashing optional. --- wheel/python/clvm_rs/__init__.py | 8 +------- wheel/python/clvm_rs/base.py | 6 +++--- wheel/python/clvm_rs/clvm_tree.py | 21 ++++++++++++--------- wheel/python/clvm_rs/deser.py | 26 ++++++++++++++++---------- wheel/python/clvm_rs/program.py | 30 +++++------------------------- wheel/python/clvm_rs/serialize.py | 4 +--- 6 files changed, 38 insertions(+), 57 deletions(-) diff --git a/wheel/python/clvm_rs/__init__.py b/wheel/python/clvm_rs/__init__.py index c418e1ea..1250efbc 100644 --- a/wheel/python/clvm_rs/__init__.py +++ b/wheel/python/clvm_rs/__init__.py @@ -1,7 +1 @@ -from .clvm_rs import * - -from .base import CLVMObject - -__doc__ = clvm_rs.__doc__ -if hasattr(clvm_rs, "__all__"): - __all__ = clvm_rs.__all__ +from .base import CLVMObject # noqa: F401 diff --git a/wheel/python/clvm_rs/base.py b/wheel/python/clvm_rs/base.py index b64ed62f..cc492194 100644 --- a/wheel/python/clvm_rs/base.py +++ b/wheel/python/clvm_rs/base.py @@ -3,12 +3,12 @@ class CLVMObject(Protocol): atom: Optional[bytes] - pair: Optional[Tuple["CLVMObjectStore", "CLVMObjectStore"]] + pair: Optional[Tuple["CLVMObject", "CLVMObject"]] @classmethod - def new_atom(cls, v: bytes) -> "CLVMObjectStore": + def new_atom(cls, v: bytes) -> "CLVMObject": raise NotImplementedError() @classmethod - def new_pair(cls, p1, p2) -> "CLVMObjectStore": + def new_pair(cls, p1, p2) -> "CLVMObject": raise NotImplementedError() diff --git a/wheel/python/clvm_rs/clvm_tree.py b/wheel/python/clvm_rs/clvm_tree.py index 42eaad30..b68f6c78 100644 --- a/wheel/python/clvm_rs/clvm_tree.py +++ b/wheel/python/clvm_rs/clvm_tree.py @@ -39,23 +39,26 @@ class CLVMTree: @classmethod def from_bytes(cls, blob: bytes) -> "CLVMTree": - return cls(memoryview(blob), deserialize_as_tuples(blob), 0) + int_tuples, tree_hashes = deserialize_as_tuples(blob) + return cls(memoryview(blob), int_tuples, tree_hashes, 0) def __init__( self, blob: bytes, - int_tuples: List[Tuple[int, int, int, bytes]], + int_tuples: List[Tuple[int, int, int]], + tree_hashes: List[Optional[bytes]], index: int, ): self.blob = blob self.int_tuples = int_tuples + self.tree_hashes = tree_hashes self.index = index - self._cached_sha256_treehash = int_tuples[index][3] + self._cached_sha256_treehash = self.tree_hashes[index] @property def atom(self) -> Optional[bytes]: if not hasattr(self, "_atom"): - start, end, atom_offset, hash = self.int_tuples[self.index] + start, end, atom_offset = self.int_tuples[self.index] # if `self.blob[start]` is 0xff, it's a pair if self.blob[start] == 0xFF: self._atom = None @@ -66,12 +69,12 @@ def atom(self) -> Optional[bytes]: @property def pair(self) -> Optional[Tuple["CLVMTree", "CLVMTree"]]: if not hasattr(self, "_pair"): - tuples = self.int_tuples - start, end, right_index, hash = tuples[self.index] + tuples, tree_hashes = self.int_tuples, self.tree_hashes + start, end, right_index = tuples[self.index] # if `self.blob[start]` is 0xff, it's a pair if self.blob[start] == 0xFF: - left = self.__class__(self.blob, tuples, self.index + 1) - right = self.__class__(self.blob, tuples, right_index) + left = self.__class__(self.blob, tuples, tree_hashes, self.index + 1) + right = self.__class__(self.blob, tuples, tree_hashes, right_index) self._pair = (left, right) else: self._pair = None @@ -79,7 +82,7 @@ def pair(self) -> Optional[Tuple["CLVMTree", "CLVMTree"]]: @property def _cached_serialization(self) -> bytes: - start, end, _, _ = self.int_tuples[self.index] + start, end, _ = self.int_tuples[self.index] return self.blob[start:end] def __bytes__(self) -> bytes: diff --git a/wheel/python/clvm_rs/deser.py b/wheel/python/clvm_rs/deser.py index f2ea9d92..f0f84f03 100644 --- a/wheel/python/clvm_rs/deser.py +++ b/wheel/python/clvm_rs/deser.py @@ -1,4 +1,4 @@ -from typing import List, Tuple +from typing import List, Optional, Tuple from .tree_hash import shatree_atom, shatree_pair @@ -12,18 +12,19 @@ def deserialize_as_tuples( - blob: bytes, cursor: int = 0 -) -> List[Tuple[int, int, int, bytes]]: + blob: bytes, cursor: int = 0, calculate_tree_hash: bool = True +) -> Tuple[List[Tuple[int, int, int]], List[Optional[bytes]]]: def save_cursor(index, blob, cursor, obj_list, op_stack): assert blob[obj_list[index][0]] == 0xFF - left_hash = obj_list[index + 1][3] - right_hash = obj_list[obj_list[index][2]][3] - my_hash = shatree_pair(left_hash, right_hash) + left_hash = tree_hash_list[index + 1] + right_hash = tree_hash_list[obj_list[index][2]] + tree_hash_list[index] = None + if calculate_tree_hash: + tree_hash_list[index] = shatree_pair(left_hash, right_hash) obj_list[index] = ( obj_list[index][0], cursor, obj_list[index][2], - my_hash, ) return cursor @@ -37,6 +38,7 @@ def parse_obj(blob, cursor, obj_list, op_stack): if blob[cursor] == CONS_BOX_MARKER: index = len(obj_list) + tree_hash_list.append(None) obj_list.append([cursor, None, None]) op_stack.append(lambda *args: save_cursor(index, *args)) op_stack.append(parse_obj) @@ -44,16 +46,20 @@ def parse_obj(blob, cursor, obj_list, op_stack): op_stack.append(parse_obj) return cursor + 1 atom_offset, new_cursor = _atom_size_from_cursor(blob, cursor) - my_hash = shatree_atom(blob[cursor + atom_offset : new_cursor]) - obj_list.append((cursor, new_cursor, atom_offset, my_hash)) + my_hash = None + if calculate_tree_hash: + my_hash = shatree_atom(blob[cursor + atom_offset : new_cursor]) + tree_hash_list.append(my_hash) + obj_list.append((cursor, new_cursor, atom_offset)) return new_cursor obj_list = [] + tree_hash_list = [] op_stack = [parse_obj] while op_stack: f = op_stack.pop() cursor = f(blob, cursor, obj_list, op_stack) - return obj_list + return obj_list, tree_hash_list def _atom_size_from_cursor(blob, cursor) -> Tuple[int, int]: diff --git a/wheel/python/clvm_rs/program.py b/wheel/python/clvm_rs/program.py index b480a53e..4ef560f4 100644 --- a/wheel/python/clvm_rs/program.py +++ b/wheel/python/clvm_rs/program.py @@ -1,8 +1,6 @@ from __future__ import annotations -import io from typing import Dict, Iterator, List, Tuple, Optional, Any -# from clvm import Program from .base import CLVMObject from .casts import to_clvm_object from .EvalError import EvalError @@ -13,10 +11,6 @@ from .bytes32 import bytes32 from .keywords import NULL, ONE, Q_KW, A_KW, C_KW -# from chia.util.hash import std_hash -# from chia.util.byte_types import hexstr_to_bytes -# from chia.types.spend_bundle_conditions import SpendBundleConditions - MAX_COST = 0x7FFFFFFFFFFFFFFF @@ -41,9 +35,7 @@ def from_bytes(cls, blob: bytes) -> Program: return obj @classmethod - def from_bytes_with_cursor( - cls, blob: bytes, cursor: int - ) -> Tuple[Program, int]: + def from_bytes_with_cursor(cls, blob: bytes, cursor: int) -> Tuple[Program, int]: tree = CLVMTree.from_bytes(blob) if tree.atom is not None: obj = cls.new_atom(tree.atom) @@ -220,9 +212,7 @@ def replace(self, **kwargs) -> "Program": def tree_hash(self) -> bytes32: return sha256_treehash(self.unwrap()) - def run_with_cost( - self, args, max_cost: int = MAX_COST - ) -> Tuple[int, "Program"]: + def run_with_cost(self, args, max_cost: int = MAX_COST) -> Tuple[int, "Program"]: prog_bytes = bytes(self) args_bytes = bytes(self.to(args)) cost, r = run_serialized_program(prog_bytes, args_bytes, max_cost, 0) @@ -258,21 +248,13 @@ def curry(self, *args) -> "Program": return Program.to([2, (1, self), fixed_args]) def uncurry(self) -> Tuple[Program, Optional[Program]]: - if ( - self.at("f") != A_KW - or self.at("rff") != Q_KW - or self.at("rrr") != NULL - ): + if self.at("f") != A_KW or self.at("rff") != Q_KW or self.at("rrr") != NULL: return self, None uncurried_function = self.at("rfr") core_items = [] core = self.at("rrf") while core != ONE: - if ( - core.at("f") != C_KW - or core.at("rff") != Q_KW - or core.at("rrr") != NULL - ): + if core.at("f") != C_KW or core.at("rff") != Q_KW or core.at("rrr") != NULL: return self, None new_item = core.at("rfr") core_items.append(new_item) @@ -337,9 +319,7 @@ def _replace(program: Program, **kwargs) -> Program: for k, v in kwargs.items(): c = k[0] if c not in "fr": - raise ValueError( - f"bad path containing {c}: must only contain `f` and `r`" - ) + raise ValueError(f"bad path containing {c}: must only contain `f` and `r`") args_by_prefix.setdefault(c, dict())[k[1:]] = v pair = program.pair diff --git a/wheel/python/clvm_rs/serialize.py b/wheel/python/clvm_rs/serialize.py index 0283c1b2..b6033ef2 100644 --- a/wheel/python/clvm_rs/serialize.py +++ b/wheel/python/clvm_rs/serialize.py @@ -47,9 +47,7 @@ def atom_to_byte_iterator(as_atom): elif size < 0x2000: size_blob = bytes([0xC0 | (size >> 8), (size >> 0) & 0xFF]) elif size < 0x100000: - size_blob = bytes( - [0xE0 | (size >> 16), (size >> 8) & 0xFF, (size >> 0) & 0xFF] - ) + size_blob = bytes([0xE0 | (size >> 16), (size >> 8) & 0xFF, (size >> 0) & 0xFF]) elif size < 0x8000000: size_blob = bytes( [ From d797d660c35527f2f6a7f2f0868cbfef670fcf3b Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Sat, 8 Oct 2022 11:39:45 -0700 Subject: [PATCH 08/45] Tree hash on deserialization is now optional. --- wheel/python/benchmarks/deserialization.py | 20 ++++++++++++++----- wheel/python/clvm_rs/clvm_tree.py | 6 ++++-- wheel/python/clvm_rs/deser.py | 2 +- wheel/python/clvm_rs/program.py | 12 +++++++---- wheel/python/tests/test_curry_and_treehash.py | 5 ++++- 5 files changed, 32 insertions(+), 13 deletions(-) diff --git a/wheel/python/benchmarks/deserialization.py b/wheel/python/benchmarks/deserialization.py index f05a1871..fbe381d5 100644 --- a/wheel/python/benchmarks/deserialization.py +++ b/wheel/python/benchmarks/deserialization.py @@ -4,7 +4,6 @@ from clvm_rs.program import Program - def bench(f, name: str): start = time.time() r = f() @@ -15,17 +14,19 @@ def bench(f, name: str): return r - sha_prog = Program.fromhex("ff0bff0180") print(sha_prog.run("food")) -#breakpoint() +# breakpoint() obj = bench(lambda: Program.parse(open("block-2500014.compressed.bin", "rb")), "parse") bench(lambda: bytes(obj), "to_bytes") -obj1 = bench(lambda: Program.from_bytes(open("block-2500014.compressed.bin", "rb").read()), "from_bytes") +obj1 = bench( + lambda: Program.from_bytes(open("block-2500014.compressed.bin", "rb").read()), + "from_bytes", +) bench(lambda: bytes(obj1), "to_bytes") cost, output = bench(lambda: obj.run_with_cost(0), "run") @@ -37,7 +38,16 @@ def bench(f, name: str): bench(lambda: print(output.tree_hash().hex()), "print run tree hash LazyNode") bench(lambda: print(output.tree_hash().hex()), "print run tree hash again LazyNode") -des_output = bench(lambda: Program.from_bytes(blob), "from_bytes output") +des_output = bench( + lambda: Program.from_bytes(blob), "from_bytes output (with tree hashing)" +) +bench(lambda: print(des_output.tree_hash().hex()), "print from_bytes tree hash") +bench(lambda: print(des_output.tree_hash().hex()), "print from_bytes tree hash again") + +des_output = bench( + lambda: Program.from_bytes(blob, calculate_tree_hash=False), + "from_bytes output (with no tree hashing)", +) bench(lambda: print(des_output.tree_hash().hex()), "print from_bytes tree hash") bench(lambda: print(des_output.tree_hash().hex()), "print from_bytes tree hash again") diff --git a/wheel/python/clvm_rs/clvm_tree.py b/wheel/python/clvm_rs/clvm_tree.py index b68f6c78..9b213035 100644 --- a/wheel/python/clvm_rs/clvm_tree.py +++ b/wheel/python/clvm_rs/clvm_tree.py @@ -38,8 +38,10 @@ class CLVMTree: """ @classmethod - def from_bytes(cls, blob: bytes) -> "CLVMTree": - int_tuples, tree_hashes = deserialize_as_tuples(blob) + def from_bytes(cls, blob: bytes, calculate_tree_hash: bool = True) -> "CLVMTree": + int_tuples, tree_hashes = deserialize_as_tuples( + blob, 0, calculate_tree_hash=calculate_tree_hash + ) return cls(memoryview(blob), int_tuples, tree_hashes, 0) def __init__( diff --git a/wheel/python/clvm_rs/deser.py b/wheel/python/clvm_rs/deser.py index f0f84f03..414c211a 100644 --- a/wheel/python/clvm_rs/deser.py +++ b/wheel/python/clvm_rs/deser.py @@ -12,7 +12,7 @@ def deserialize_as_tuples( - blob: bytes, cursor: int = 0, calculate_tree_hash: bool = True + blob: bytes, cursor: int, calculate_tree_hash: bool ) -> Tuple[List[Tuple[int, int, int]], List[Optional[bytes]]]: def save_cursor(index, blob, cursor, obj_list, op_stack): assert blob[obj_list[index][0]] == 0xFF diff --git a/wheel/python/clvm_rs/program.py b/wheel/python/clvm_rs/program.py index 4ef560f4..7de86693 100644 --- a/wheel/python/clvm_rs/program.py +++ b/wheel/python/clvm_rs/program.py @@ -30,13 +30,17 @@ def stream(self, f): sexp_to_stream(self, f) @classmethod - def from_bytes(cls, blob: bytes) -> Program: - obj, cursor = cls.from_bytes_with_cursor(blob, 0) + def from_bytes(cls, blob: bytes, calculate_tree_hash: bool = True) -> Program: + obj, cursor = cls.from_bytes_with_cursor( + blob, 0, calculate_tree_hash=calculate_tree_hash + ) return obj @classmethod - def from_bytes_with_cursor(cls, blob: bytes, cursor: int) -> Tuple[Program, int]: - tree = CLVMTree.from_bytes(blob) + def from_bytes_with_cursor( + cls, blob: bytes, cursor: int, calculate_tree_hash: bool = True + ) -> Tuple[Program, int]: + tree = CLVMTree.from_bytes(blob, calculate_tree_hash=calculate_tree_hash) if tree.atom is not None: obj = cls.new_atom(tree.atom) else: diff --git a/wheel/python/tests/test_curry_and_treehash.py b/wheel/python/tests/test_curry_and_treehash.py index 3bf8b7b5..17ad15bb 100644 --- a/wheel/python/tests/test_curry_and_treehash.py +++ b/wheel/python/tests/test_curry_and_treehash.py @@ -1,5 +1,8 @@ from clvm_rs.program import Program -from clvm_rs.curry_and_treehash import calculate_hash_of_quoted_mod_hash, curry_and_treehash +from clvm_rs.curry_and_treehash import ( + calculate_hash_of_quoted_mod_hash, + curry_and_treehash, +) def test_curry_and_treehash() -> None: From f7f5eef40b36265192bfc102e1a08b69c079e92c Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Mon, 17 Oct 2022 16:54:57 -0700 Subject: [PATCH 09/45] Refactor, rename. --- src/deserialize_tree.rs | 228 ++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 2 +- src/serde/mod.rs | 1 + src/serialize.rs | 217 +------------------------------------- wheel/src/api.rs | 13 +-- 5 files changed, 238 insertions(+), 223 deletions(-) create mode 100644 src/deserialize_tree.rs diff --git a/src/deserialize_tree.rs b/src/deserialize_tree.rs new file mode 100644 index 00000000..eba211d5 --- /dev/null +++ b/src/deserialize_tree.rs @@ -0,0 +1,228 @@ +use std::io; +use std::io::Read; + +use crate::serde::decode_size; + +const MAX_SINGLE_BYTE: u8 = 0x7f; +const CONS_BOX_MARKER: u8 = 0xff; + +/// This data structure is used with `deserialize_tree`, which returns a triple of +/// integer values for each clvm object in a tree. + +#[derive(Debug, PartialEq, Eq)] +pub enum CLVMTreeBoundary { + Atom { + start: u64, + end: u64, + atom_offset: u32, + }, + Pair { + start: u64, + end: u64, + right_index: u32, + }, +} + +enum ParseOpRef { + ParseObj, + SaveCursor(usize), + SaveIndex(usize), +} + +fn skip_bytes(f: &mut R, skip_size: u64) -> io::Result { + io::copy(&mut f.by_ref().take(skip_size), &mut io::sink()) +} + +/// parse a serialized clvm object tree to an array of `CLVMTreeBoundary` objects + +/// This alternative mechanism of deserialization generates an array of +/// references to each clvm object. A reference contains three values: +/// a start offset within the blob, an end offset, and a third value that +/// is either: an atom offset (relative to the start offset) where the atom +/// data starts (and continues to the end offset); or an index in the array +/// corresponding to the "right" element of the pair (in which case, the +/// "left" element corresponds to the next index in the array). +/// +/// Since these values are offsets into the original buffer, that buffer needs +/// to be kept around to get the original atoms. + +pub fn deserialize_tree(f: &mut R) -> io::Result> { + let mut r = Vec::new(); + let mut op_stack = vec![ParseOpRef::ParseObj]; + let mut cursor: u64 = 0; + loop { + match op_stack.pop() { + None => { + break; + } + Some(op) => match op { + ParseOpRef::ParseObj => { + let mut b: [u8; 1] = [0]; + f.read_exact(&mut b)?; + let start = cursor as u64; + cursor += 1; + let b = b[0]; + if b == CONS_BOX_MARKER { + let index = r.len(); + let new_obj = CLVMTreeBoundary::Pair { + start, + end: 0, + right_index: 0, + }; + r.push(new_obj); + op_stack.push(ParseOpRef::SaveCursor(index)); + op_stack.push(ParseOpRef::ParseObj); + op_stack.push(ParseOpRef::SaveIndex(index)); + op_stack.push(ParseOpRef::ParseObj); + } else { + let (start, end, atom_offset) = { + if b <= MAX_SINGLE_BYTE { + (start, start + 1, 0) + } else { + let (atom_offset, atom_size) = decode_size(f, b)?; + skip_bytes(f, atom_size)?; + let end = start + (atom_offset as u64) + (atom_size as u64); + (start, end, atom_offset as u32) + } + }; + let new_obj = CLVMTreeBoundary::Atom { + start, + end, + atom_offset, + }; + cursor = end; + r.push(new_obj); + } + } + ParseOpRef::SaveCursor(index) => { + if let CLVMTreeBoundary::Pair { + start, + end: _, + right_index, + } = r[index] + { + r[index] = CLVMTreeBoundary::Pair { + start, + end: cursor, + right_index, + }; + } + } + ParseOpRef::SaveIndex(index) => { + if let CLVMTreeBoundary::Pair { + start, + end, + right_index: _, + } = r[index] + { + r[index] = CLVMTreeBoundary::Pair { + start, + end, + right_index: r.len() as u32, + }; + } + } + }, + } + } + Ok(r) +} + +#[cfg(test)] +use std::io::Cursor; + +#[cfg(test)] +use hex::FromHex; + +#[cfg(test)] +fn check_parse_tree(h: &str, expected: Vec) -> () { + let b = Vec::from_hex(h).unwrap(); + println!("{:?}", b); + let mut f = Cursor::new(b); + let p = deserialize_tree(&mut f).unwrap(); + assert_eq!(p, expected); +} + +#[test] +fn test_parse_tree() { + check_parse_tree( + "80", + vec![CLVMTreeBoundary::Atom { + start: 0, + end: 1, + atom_offset: 1, + }], + ); + + check_parse_tree( + "ff648200c8", + vec![ + CLVMTreeBoundary::Pair { + start: 0, + end: 5, + right_index: 2, + }, + CLVMTreeBoundary::Atom { + start: 1, + end: 2, + atom_offset: 0, + }, + CLVMTreeBoundary::Atom { + start: 2, + end: 5, + atom_offset: 1, + }, + ], + ); + + check_parse_tree( + "ff83666f6fff83626172ff8362617a80", // `(foo bar baz)` + vec![ + CLVMTreeBoundary::Pair { + start: 0, + end: 16, + right_index: 2, + }, + CLVMTreeBoundary::Atom { + start: 1, + end: 5, + atom_offset: 1, + }, + CLVMTreeBoundary::Pair { + start: 5, + end: 16, + right_index: 4, + }, + CLVMTreeBoundary::Atom { + start: 6, + end: 10, + atom_offset: 1, + }, + CLVMTreeBoundary::Pair { + start: 10, + end: 16, + right_index: 6, + }, + CLVMTreeBoundary::Atom { + start: 11, + end: 15, + atom_offset: 1, + }, + CLVMTreeBoundary::Atom { + start: 15, + end: 16, + atom_offset: 1, + }, + ], + ); + + let s = "c0a0".to_owned() + &hex::encode([0x31u8; 160]); + check_parse_tree( + &s, + vec![CLVMTreeBoundary::Atom { + start: 0, + end: 162, + atom_offset: 2, + }], + ); +} diff --git a/src/lib.rs b/src/lib.rs index 9bf0d5ee..5c9aeb1c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,6 +2,7 @@ pub mod allocator; pub mod chia_dialect; pub mod core_ops; pub mod cost; +pub mod deserialize_tree; pub mod dialect; pub mod err_utils; pub mod f_table; @@ -13,7 +14,6 @@ pub mod reduction; pub mod run_program; pub mod runtime_dialect; pub mod serde; -pub mod serialize; pub mod sha2; pub mod traverse_path; diff --git a/src/serde/mod.rs b/src/serde/mod.rs index 41dcf0f6..98136d3f 100644 --- a/src/serde/mod.rs +++ b/src/serde/mod.rs @@ -17,4 +17,5 @@ pub use de::node_from_bytes; pub use de_br::node_from_bytes_backrefs; pub use ser::node_to_bytes; pub use ser_br::node_to_bytes_backrefs; +pub use parse_atom::decode_size; pub use tools::{serialized_length_from_bytes, tree_hash_from_stream}; diff --git a/src/serialize.rs b/src/serialize.rs index b97b2dfe..d6ea63ea 100644 --- a/src/serialize.rs +++ b/src/serialize.rs @@ -88,132 +88,10 @@ pub fn node_to_stream(node: &Node, f: &mut dyn io::Write) -> io::Result<()> { Ok(()) } -/// This data structure is used with `parse_triples`, which returns a triple of -/// integer values for each clvm object in a tree. - -#[derive(Debug, PartialEq, Eq)] -pub enum ParsedTriple { - Atom { - start: u64, - end: u64, - atom_offset: u32, - }, - Pair { - start: u64, - end: u64, - right_index: u32, - }, -} - -enum ParseOpRef { - ParseObj, - SaveCursor(usize), - SaveIndex(usize), -} - -fn skip_bytes(f: &mut R, skip_size: u64) -> io::Result { - io::copy(&mut f.by_ref().take(skip_size), &mut io::sink()) -} - -/// parse a serialized clvm object tree to an array of `ParsedTriple` objects - -/// This alternative mechanism of deserialization generates an array of -/// references to each clvm object. A reference contains three values: -/// a start offset within the blob, an end offset, and a third value that -/// is either: an atom offset (relative to the start offset) where the atom -/// data starts (and continues to the end offset); or an index in the array -/// corresponding to the "right" element of the pair (in which case, the -/// "left" element corresponds to the next index in the array). -/// -/// Since these values are offsets into the original buffer, that buffer needs -/// to be kept around to get the original atoms. - -pub fn parse_triples(f: &mut R) -> io::Result> { - let mut r = Vec::new(); - let mut op_stack = vec![ParseOpRef::ParseObj]; - let mut cursor: u64 = 0; - loop { - match op_stack.pop() { - None => { - break; - } - Some(op) => match op { - ParseOpRef::ParseObj => { - let mut b: [u8; 1] = [0]; - f.read_exact(&mut b)?; - let start = cursor as u64; - cursor += 1; - let b = b[0]; - if b == CONS_BOX_MARKER { - let index = r.len(); - let new_obj = ParsedTriple::Pair { - start, - end: 0, - right_index: 0, - }; - r.push(new_obj); - op_stack.push(ParseOpRef::SaveCursor(index)); - op_stack.push(ParseOpRef::ParseObj); - op_stack.push(ParseOpRef::SaveIndex(index)); - op_stack.push(ParseOpRef::ParseObj); - } else { - let (start, end, atom_offset) = { - if b <= MAX_SINGLE_BYTE { - (start, start + 1, 0) - } else { - let (atom_offset, atom_size) = decode_size(f, b)?; - skip_bytes(f, atom_size)?; - let end = start + (atom_offset as u64) + (atom_size as u64); - (start, end, atom_offset as u32) - } - }; - let new_obj = ParsedTriple::Atom { - start, - end, - atom_offset, - }; - cursor = end; - r.push(new_obj); - } - } - ParseOpRef::SaveCursor(index) => { - if let ParsedTriple::Pair { - start, - end: _, - right_index, - } = r[index] - { - r[index] = ParsedTriple::Pair { - start, - end: cursor, - right_index, - }; - } - } - ParseOpRef::SaveIndex(index) => { - if let ParsedTriple::Pair { - start, - end, - right_index: _, - } = r[index] - { - r[index] = ParsedTriple::Pair { - start, - end, - right_index: r.len() as u32, - }; - } - } - }, - } - } - Ok(r) -} - /// decode the length prefix for an atom. Atoms whose value fit in 7 bits /// don't have a length prefix, so those should be handled specially and /// never passed to this function. -fn decode_size(f: &mut R, initial_b: u8) -> io::Result<(u8, u64)> { +pub fn decode_size(f: &mut R, initial_b: u8) -> io::Result<(u8, u64)> { debug_assert!((initial_b & 0x80) != 0); if (initial_b & 0x80) == 0 { return Err(internal_error()); @@ -726,96 +604,3 @@ fn test_truncated_decode_size() { let e = ret.unwrap_err(); assert_eq!(e.kind(), ErrorKind::UnexpectedEof); } - -#[cfg(test)] -fn check_parse_triple(h: &str, expected: Vec) -> () { - let b = Vec::from_hex(h).unwrap(); - println!("{:?}", b); - let mut f = Cursor::new(b); - let p = parse_triples(&mut f).unwrap(); - assert_eq!(p, expected); -} - -#[test] -fn test_parse_triple() { - check_parse_triple( - "80", - vec![ParsedTriple::Atom { - start: 0, - end: 1, - atom_offset: 1, - }], - ); - - check_parse_triple( - "ff648200c8", - vec![ - ParsedTriple::Pair { - start: 0, - end: 5, - right_index: 2, - }, - ParsedTriple::Atom { - start: 1, - end: 2, - atom_offset: 0, - }, - ParsedTriple::Atom { - start: 2, - end: 5, - atom_offset: 1, - }, - ], - ); - - check_parse_triple( - "ff83666f6fff83626172ff8362617a80", // `(foo bar baz)` - vec![ - ParsedTriple::Pair { - start: 0, - end: 16, - right_index: 2, - }, - ParsedTriple::Atom { - start: 1, - end: 5, - atom_offset: 1, - }, - ParsedTriple::Pair { - start: 5, - end: 16, - right_index: 4, - }, - ParsedTriple::Atom { - start: 6, - end: 10, - atom_offset: 1, - }, - ParsedTriple::Pair { - start: 10, - end: 16, - right_index: 6, - }, - ParsedTriple::Atom { - start: 11, - end: 15, - atom_offset: 1, - }, - ParsedTriple::Atom { - start: 15, - end: 16, - atom_offset: 1, - }, - ], - ); - - let s = "c0a0".to_owned() + &hex::encode([0x31u8; 160]); - check_parse_triple( - &s, - vec![ParsedTriple::Atom { - start: 0, - end: 162, - atom_offset: 2, - }], - ); -} diff --git a/wheel/src/api.rs b/wheel/src/api.rs index 5c316aca..e32a107e 100644 --- a/wheel/src/api.rs +++ b/wheel/src/api.rs @@ -5,10 +5,11 @@ use crate::adapt_response::adapt_response; use clvmr::allocator::Allocator; use clvmr::chia_dialect::ChiaDialect; use clvmr::cost::Cost; +use clvmr::deserialize_tree::{deserialize_tree, CLVMTreeBoundary}; use clvmr::reduction::Response; use clvmr::run_program::run_program; use clvmr::serde::{node_from_bytes, serialized_length_from_bytes}; -use clvmr::serialize::{parse_triples, ParsedTriple}; +use clvmr::deserialize_tree::{parse_triples, ParsedTriple}; use clvmr::{LIMIT_HEAP, LIMIT_STACK, MEMPOOL_MODE, NO_UNKNOWN_OPS}; use pyo3::prelude::*; use pyo3::types::PyTuple; @@ -43,14 +44,14 @@ pub fn run_serialized_chia_program( adapt_response(py, allocator, r) } -fn tuple_for_parsed_triple(py: Python<'_>, p: &ParsedTriple) -> PyObject { +fn tuple_for_parsed_triple(py: Python<'_>, p: &CLVMTreeBoundary) -> PyObject { let tuple = match p { - ParsedTriple::Atom { + CLVMTreeBoundary::Atom { start, end, atom_offset, } => PyTuple::new(py, [*start, *end, *atom_offset as u64]), - ParsedTriple::Pair { + CLVMTreeBoundary::Pair { start, end, right_index, @@ -60,9 +61,9 @@ fn tuple_for_parsed_triple(py: Python<'_>, p: &ParsedTriple) -> PyObject { } #[pyfunction] -fn deserialize_as_triples(py: Python, blob: &[u8]) -> PyResult> { +fn deserialize_as_tree(py: Python, blob: &[u8]) -> PyResult> { let mut cursor = io::Cursor::new(blob); - let r = parse_triples(&mut cursor)?; + let r = deserialize_tree(&mut cursor)?; let r = r.iter().map(|pt| tuple_for_parsed_triple(py, pt)).collect(); Ok(r) } From a154d7ea01eddcd7644d6ab18984936872be8cd2 Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Mon, 17 Oct 2022 19:01:04 -0700 Subject: [PATCH 10/45] First crack at rust tree hashes. --- src/deserialize_tree.rs | 120 +++++++++++++++++++++++++++++++++++++--- wheel/src/api.rs | 8 ++- 2 files changed, 117 insertions(+), 11 deletions(-) diff --git a/src/deserialize_tree.rs b/src/deserialize_tree.rs index eba211d5..384e8775 100644 --- a/src/deserialize_tree.rs +++ b/src/deserialize_tree.rs @@ -1,11 +1,28 @@ +use std::convert::TryInto; use std::io; -use std::io::Read; +use std::io::{Read, Write}; + +use sha2::Digest; + +use crate::sha2::Sha256; use crate::serde::decode_size; const MAX_SINGLE_BYTE: u8 = 0x7f; const CONS_BOX_MARKER: u8 = 0xff; +struct ShaWrapper(Sha256); + +impl Write for ShaWrapper { + fn write(&mut self, blob: &[u8]) -> std::result::Result { + self.0.update(blob); + Ok(blob.len()) + } + fn flush(&mut self) -> std::result::Result<(), std::io::Error> { + Ok(()) + } +} + /// This data structure is used with `deserialize_tree`, which returns a triple of /// integer values for each clvm object in a tree. @@ -29,10 +46,52 @@ enum ParseOpRef { SaveIndex(usize), } +fn sha_blobs(blobs: &[&[u8]]) -> [u8; 32] { + let mut h = Sha256::new(); + for blob in blobs { + h.update(blob); + } + h.finalize() + .as_slice() + .try_into() + .expect("wrong slice length") +} + +fn tree_hash_for_byte(b: u8, calculate_tree_hashes: bool) -> Option<[u8; 32]> { + if calculate_tree_hashes { + Some(sha_blobs(&[&[1, b]])) + } else { + None + } +} + fn skip_bytes(f: &mut R, skip_size: u64) -> io::Result { io::copy(&mut f.by_ref().take(skip_size), &mut io::sink()) } +fn skip_or_sha_bytes( + f: &mut R, + skip_size: u64, + calculate_tree_hashes: bool, +) -> io::Result> { + if calculate_tree_hashes { + let mut f = &mut f.by_ref().take(skip_size); + let mut h = Sha256::new(); + h.update(&[1]); + let mut w = ShaWrapper(h); + io::copy(&mut f, &mut w)?; + let r: [u8; 32] = + w.0.finalize() + .as_slice() + .try_into() + .expect("wrong slice length"); + Ok(Some(r)) + } else { + skip_bytes(f, skip_size)?; + Ok(None) + } +} + /// parse a serialized clvm object tree to an array of `CLVMTreeBoundary` objects /// This alternative mechanism of deserialization generates an array of @@ -46,8 +105,12 @@ fn skip_bytes(f: &mut R, skip_size: u64) -> io::Result { /// Since these values are offsets into the original buffer, that buffer needs /// to be kept around to get the original atoms. -pub fn deserialize_tree(f: &mut R) -> io::Result> { +pub fn deserialize_tree( + f: &mut R, + calculate_tree_hashes: bool, +) -> io::Result<(Vec, Option>)> { let mut r = Vec::new(); + let mut tree_hashes = Vec::new(); let mut op_stack = vec![ParseOpRef::ParseObj]; let mut cursor: u64 = 0; loop { @@ -70,21 +133,32 @@ pub fn deserialize_tree(f: &mut R) -> io::Result(f: &mut R) -> io::Result(f: &mut R) -> io::Result) -> () { +fn check_parse_tree(h: &str, expected: Vec, expected_sha_tree_hex: &str) -> () { let b = Vec::from_hex(h).unwrap(); println!("{:?}", b); let mut f = Cursor::new(b); - let p = deserialize_tree(&mut f).unwrap(); + let (p, tree_hash) = deserialize_tree(&mut f, false).unwrap(); assert_eq!(p, expected); + assert_eq!(tree_hash, None); + + let b = Vec::from_hex(h).unwrap(); + let mut f = Cursor::new(b); + let (p, tree_hash) = deserialize_tree(&mut f, true).unwrap(); + assert_eq!(p, expected); + + let est = Vec::from_hex(expected_sha_tree_hex).unwrap(); + assert_eq!(tree_hash.unwrap()[0].to_vec(), est); } #[test] @@ -152,6 +250,7 @@ fn test_parse_tree() { end: 1, atom_offset: 1, }], + "4bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a", ); check_parse_tree( @@ -173,6 +272,7 @@ fn test_parse_tree() { atom_offset: 1, }, ], + "247f7d3f63b346ea93ca47f571cd0f4455392348b888a4286072bef0ac6069b5", ); check_parse_tree( @@ -214,6 +314,7 @@ fn test_parse_tree() { atom_offset: 1, }, ], + "47f30bf9935e25e4262023124fb5e986d755b9ed65a28ac78925c933bfd57dbd", ); let s = "c0a0".to_owned() + &hex::encode([0x31u8; 160]); @@ -224,5 +325,6 @@ fn test_parse_tree() { end: 162, atom_offset: 2, }], + "d1c109981a9c5a3bbe2d98795a186a0f057dc9a3a7f5e1eb4dfb63a1636efa2d", ); } diff --git a/wheel/src/api.rs b/wheel/src/api.rs index e32a107e..85ec03ca 100644 --- a/wheel/src/api.rs +++ b/wheel/src/api.rs @@ -61,9 +61,13 @@ fn tuple_for_parsed_triple(py: Python<'_>, p: &CLVMTreeBoundary) -> PyObject { } #[pyfunction] -fn deserialize_as_tree(py: Python, blob: &[u8]) -> PyResult> { +fn deserialize_as_tree( + py: Python, + blob: &[u8], + calculate_tree_hashes: bool, +) -> PyResult> { let mut cursor = io::Cursor::new(blob); - let r = deserialize_tree(&mut cursor)?; + let (r, _tree_hashes) = deserialize_tree(&mut cursor, calculate_tree_hashes)?; let r = r.iter().map(|pt| tuple_for_parsed_triple(py, pt)).collect(); Ok(r) } From 9a2ca753fe250daeba876e8dd9ea2209df357e84 Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Thu, 12 Jan 2023 17:05:30 -0800 Subject: [PATCH 11/45] Use rust parsing if present. --- wheel/python/benchmarks/deserialization.py | 6 +++++- wheel/python/clvm_rs/casts.py | 5 +++-- wheel/python/clvm_rs/clvm_tree.py | 3 +-- wheel/python/clvm_rs/deser.py | 16 ++++++++++++++-- wheel/python/clvm_rs/program.py | 20 ++++++++++++++++++-- wheel/src/api.rs | 9 +++++---- 6 files changed, 46 insertions(+), 13 deletions(-) diff --git a/wheel/python/benchmarks/deserialization.py b/wheel/python/benchmarks/deserialization.py index fbe381d5..cd4c155a 100644 --- a/wheel/python/benchmarks/deserialization.py +++ b/wheel/python/benchmarks/deserialization.py @@ -3,13 +3,15 @@ from clvm_rs.program import Program +from clvm_rs.clvm_rs import serialized_length + def bench(f, name: str): start = time.time() r = f() end = time.time() d = end - start - print(f"{name}: {end-start:1.4f} s") + print(f"{name}: {d:1.4f} s") print() return r @@ -44,6 +46,8 @@ def bench(f, name: str): bench(lambda: print(des_output.tree_hash().hex()), "print from_bytes tree hash") bench(lambda: print(des_output.tree_hash().hex()), "print from_bytes tree hash again") +bench(lambda: print(serialized_length(blob)), "print serialized_length") + des_output = bench( lambda: Program.from_bytes(blob, calculate_tree_hash=False), "from_bytes output (with no tree hashing)", diff --git a/wheel/python/clvm_rs/casts.py b/wheel/python/clvm_rs/casts.py index 38972507..206bd8a5 100644 --- a/wheel/python/clvm_rs/casts.py +++ b/wheel/python/clvm_rs/casts.py @@ -27,7 +27,8 @@ def int_to_bytes(v): r = v.to_bytes(byte_count, "big", signed=True) # make sure the string returned is minimal # ie. no leading 00 or ff bytes that are unnecessary - assert not (len(r) > 1 and r[0] == (0xFF if r[1] & 0x80 else 0)) + while len(r) > 1 and r[0] == (0xFF if r[1] & 0x80 else 0): + r = r[1:] return r @@ -42,7 +43,7 @@ def to_atom_type(v: AtomCastableType) -> bytes: return v.encode() if isinstance(v, int): return int_to_bytes(v) - if hasattr(v, "__bytes__"): + if hasattr(v, "__bytes__") or isinstance(v, memoryview): return bytes(v) if v is None: return NULL diff --git a/wheel/python/clvm_rs/clvm_tree.py b/wheel/python/clvm_rs/clvm_tree.py index 9b213035..c0548507 100644 --- a/wheel/python/clvm_rs/clvm_tree.py +++ b/wheel/python/clvm_rs/clvm_tree.py @@ -1,6 +1,5 @@ from .deser import deserialize_as_tuples - from typing import List, Optional, Tuple @@ -65,7 +64,7 @@ def atom(self) -> Optional[bytes]: if self.blob[start] == 0xFF: self._atom = None else: - self._atom = self.blob[start + atom_offset : end] + self._atom = bytes(self.blob[start + atom_offset : end]) return self._atom @property diff --git a/wheel/python/clvm_rs/deser.py b/wheel/python/clvm_rs/deser.py index 414c211a..817710af 100644 --- a/wheel/python/clvm_rs/deser.py +++ b/wheel/python/clvm_rs/deser.py @@ -2,6 +2,11 @@ from .tree_hash import shatree_atom, shatree_pair +try: + from clvm_rs.clvm_rs import deserialize_as_tree +except ImportError: + deserialize_as_tree = None + MAX_SINGLE_BYTE = 0x7F CONS_BOX_MARKER = 0xFF @@ -14,6 +19,13 @@ def deserialize_as_tuples( blob: bytes, cursor: int, calculate_tree_hash: bool ) -> Tuple[List[Tuple[int, int, int]], List[Optional[bytes]]]: + + if deserialize_as_tree: + tree, hashes = deserialize_as_tree(blob, calculate_tree_hash) + if not calculate_tree_hash: + hashes = [None] * len(tree) + return tree, hashes + def save_cursor(index, blob, cursor, obj_list, op_stack): assert blob[obj_list[index][0]] == 0xFF left_hash = tree_hash_list[index + 1] @@ -48,7 +60,7 @@ def parse_obj(blob, cursor, obj_list, op_stack): atom_offset, new_cursor = _atom_size_from_cursor(blob, cursor) my_hash = None if calculate_tree_hash: - my_hash = shatree_atom(blob[cursor + atom_offset : new_cursor]) + my_hash = shatree_atom(blob[cursor + atom_offset:new_cursor]) tree_hash_list.append(my_hash) obj_list.append((cursor, new_cursor, atom_offset)) return new_cursor @@ -77,7 +89,7 @@ def _atom_size_from_cursor(blob, cursor) -> Tuple[int, int]: bit_mask >>= 1 size_blob = bytes([b]) if bit_count > 1: - size_blob += blob[cursor + 1 : cursor + bit_count] + size_blob += blob[cursor + 1:cursor + bit_count] size = int.from_bytes(size_blob, "big") new_cursor = cursor + size + bit_count if new_cursor > len(blob): diff --git a/wheel/python/clvm_rs/program.py b/wheel/python/clvm_rs/program.py index 7de86693..6aba803c 100644 --- a/wheel/python/clvm_rs/program.py +++ b/wheel/python/clvm_rs/program.py @@ -2,7 +2,7 @@ from typing import Dict, Iterator, List, Tuple, Optional, Any from .base import CLVMObject -from .casts import to_clvm_object +from .casts import to_clvm_object, int_to_bytes from .EvalError import EvalError from clvm_rs.clvm_rs import run_serialized_program from clvm_rs.serialize import sexp_from_stream, sexp_to_stream, sexp_to_bytes @@ -55,6 +55,20 @@ def fromhex(cls, hexstr: str) -> Program: def __bytes__(self) -> bytes: return sexp_to_bytes(self) + def __int__(self) -> int: + return self.as_int() + + def __hash__(self): + return self.tree_hash().__hash__() + + @classmethod + def int_from_bytes(cls, b: bytes) -> int: + return int_from_bytes(b) + + @classmethod + def int_to_bytes(cls, i: int) -> bytes: + return int_to_bytes(i) + # high level casting with `.to` def __init__(self) -> Program: @@ -85,8 +99,9 @@ def unwrap(self) -> CLVMObject: @classmethod def new_atom(cls, v: bytes) -> Program: o = cls() - o.atom = v + o.atom = bytes(v) o.pair = None + o.wrapped = o return o @classmethod @@ -94,6 +109,7 @@ def new_pair(cls, left: CLVMObject, right: CLVMObject) -> Program: o = cls() o.atom = None o.pair = (left, right) + o.wrapped = o return o @classmethod diff --git a/wheel/src/api.rs b/wheel/src/api.rs index 85ec03ca..80af234f 100644 --- a/wheel/src/api.rs +++ b/wheel/src/api.rs @@ -12,7 +12,7 @@ use clvmr::serde::{node_from_bytes, serialized_length_from_bytes}; use clvmr::deserialize_tree::{parse_triples, ParsedTriple}; use clvmr::{LIMIT_HEAP, LIMIT_STACK, MEMPOOL_MODE, NO_UNKNOWN_OPS}; use pyo3::prelude::*; -use pyo3::types::PyTuple; +use pyo3::types::{PyBytes, PyTuple}; use pyo3::wrap_pyfunction; #[pyfunction] @@ -65,11 +65,12 @@ fn deserialize_as_tree( py: Python, blob: &[u8], calculate_tree_hashes: bool, -) -> PyResult> { +) -> PyResult<(Vec, Option>)> { let mut cursor = io::Cursor::new(blob); - let (r, _tree_hashes) = deserialize_tree(&mut cursor, calculate_tree_hashes)?; + let (r, tree_hashes) = deserialize_tree(&mut cursor, calculate_tree_hashes)?; let r = r.iter().map(|pt| tuple_for_parsed_triple(py, pt)).collect(); - Ok(r) + let s = tree_hashes.map(|ths| ths.iter().map(|b| PyBytes::new(py, b).into()).collect()); + Ok((r, s)) } #[pymodule] From 49fe40f9acd2c6cb2832dd15bdd50afdc59b673d Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Thu, 12 Jan 2023 17:33:46 -0800 Subject: [PATCH 12/45] checkpoint --- src/lib.rs | 1 - src/{deserialize_tree.rs => serde/de_tree.rs} | 2 +- src/serde/mod.rs | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) rename src/{deserialize_tree.rs => serde/de_tree.rs} (99%) diff --git a/src/lib.rs b/src/lib.rs index 5c9aeb1c..34184ae7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,7 +2,6 @@ pub mod allocator; pub mod chia_dialect; pub mod core_ops; pub mod cost; -pub mod deserialize_tree; pub mod dialect; pub mod err_utils; pub mod f_table; diff --git a/src/deserialize_tree.rs b/src/serde/de_tree.rs similarity index 99% rename from src/deserialize_tree.rs rename to src/serde/de_tree.rs index 384e8775..97d3916a 100644 --- a/src/deserialize_tree.rs +++ b/src/serde/de_tree.rs @@ -6,7 +6,7 @@ use sha2::Digest; use crate::sha2::Sha256; -use crate::serde::decode_size; +use super::decode_size; const MAX_SINGLE_BYTE: u8 = 0x7f; const CONS_BOX_MARKER: u8 = 0xff; diff --git a/src/serde/mod.rs b/src/serde/mod.rs index 98136d3f..8a6f7e36 100644 --- a/src/serde/mod.rs +++ b/src/serde/mod.rs @@ -1,6 +1,7 @@ mod bytes32; mod de; mod de_br; +mod de_tree; mod errors; mod object_cache; mod parse_atom; @@ -17,5 +18,4 @@ pub use de::node_from_bytes; pub use de_br::node_from_bytes_backrefs; pub use ser::node_to_bytes; pub use ser_br::node_to_bytes_backrefs; -pub use parse_atom::decode_size; pub use tools::{serialized_length_from_bytes, tree_hash_from_stream}; From 8ecaee47026cdd85f9e8a30fb061664af46e272e Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Thu, 5 Jan 2023 15:01:54 -0800 Subject: [PATCH 13/45] tests pass --- src/serde/de_tree.rs | 83 +-- src/serde/mod.rs | 1 + src/serde/parse_atom.rs | 26 +- src/serialize.rs | 606 --------------------- wheel/python/benchmarks/deserialization.py | 14 +- wheel/python/tests/test_serialize.py | 6 +- wheel/src/api.rs | 14 +- 7 files changed, 74 insertions(+), 676 deletions(-) delete mode 100644 src/serialize.rs diff --git a/src/serde/de_tree.rs b/src/serde/de_tree.rs index 97d3916a..486c63c5 100644 --- a/src/serde/de_tree.rs +++ b/src/serde/de_tree.rs @@ -1,12 +1,11 @@ use std::convert::TryInto; -use std::io; -use std::io::{Read, Write}; +use std::io::{copy, sink, Error, Read, Result, Write}; use sha2::Digest; use crate::sha2::Sha256; -use super::decode_size; +use super::parse_atom::decode_size_with_offset; const MAX_SINGLE_BYTE: u8 = 0x7f; const CONS_BOX_MARKER: u8 = 0xff; @@ -14,20 +13,20 @@ const CONS_BOX_MARKER: u8 = 0xff; struct ShaWrapper(Sha256); impl Write for ShaWrapper { - fn write(&mut self, blob: &[u8]) -> std::result::Result { + fn write(&mut self, blob: &[u8]) -> std::result::Result { self.0.update(blob); Ok(blob.len()) } - fn flush(&mut self) -> std::result::Result<(), std::io::Error> { + fn flush(&mut self) -> std::result::Result<(), Error> { Ok(()) } } -/// This data structure is used with `deserialize_tree`, which returns a triple of +/// This data structure is used with `parse_triples`, which returns a triple of /// integer values for each clvm object in a tree. #[derive(Debug, PartialEq, Eq)] -pub enum CLVMTreeBoundary { +pub enum ParsedTriple { Atom { start: u64, end: u64, @@ -65,21 +64,21 @@ fn tree_hash_for_byte(b: u8, calculate_tree_hashes: bool) -> Option<[u8; 32]> { } } -fn skip_bytes(f: &mut R, skip_size: u64) -> io::Result { - io::copy(&mut f.by_ref().take(skip_size), &mut io::sink()) +fn skip_bytes(f: &mut R, skip_size: u64) -> Result { + copy(&mut f.by_ref().take(skip_size), &mut sink()) } -fn skip_or_sha_bytes( +fn skip_or_sha_bytes( f: &mut R, skip_size: u64, calculate_tree_hashes: bool, -) -> io::Result> { +) -> Result> { if calculate_tree_hashes { let mut f = &mut f.by_ref().take(skip_size); let mut h = Sha256::new(); - h.update(&[1]); + h.update([1]); let mut w = ShaWrapper(h); - io::copy(&mut f, &mut w)?; + copy(&mut f, &mut w)?; let r: [u8; 32] = w.0.finalize() .as_slice() @@ -92,7 +91,7 @@ fn skip_or_sha_bytes( } } -/// parse a serialized clvm object tree to an array of `CLVMTreeBoundary` objects +/// parse a serialized clvm object tree to an array of `ParsedTriple` objects /// This alternative mechanism of deserialization generates an array of /// references to each clvm object. A reference contains three values: @@ -100,15 +99,17 @@ fn skip_or_sha_bytes( /// is either: an atom offset (relative to the start offset) where the atom /// data starts (and continues to the end offset); or an index in the array /// corresponding to the "right" element of the pair (in which case, the -/// "left" element corresponds to the next index in the array). +/// "left" element corresponds to the current index + 1). /// /// Since these values are offsets into the original buffer, that buffer needs /// to be kept around to get the original atoms. -pub fn deserialize_tree( +type ParsedTriplesOutput = (Vec, Option>); + +pub fn parse_triples( f: &mut R, calculate_tree_hashes: bool, -) -> io::Result<(Vec, Option>)> { +) -> Result { let mut r = Vec::new(); let mut tree_hashes = Vec::new(); let mut op_stack = vec![ParseOpRef::ParseObj]; @@ -122,12 +123,12 @@ pub fn deserialize_tree( ParseOpRef::ParseObj => { let mut b: [u8; 1] = [0]; f.read_exact(&mut b)?; - let start = cursor as u64; + let start = cursor; cursor += 1; let b = b[0]; if b == CONS_BOX_MARKER { let index = r.len(); - let new_obj = CLVMTreeBoundary::Pair { + let new_obj = ParsedTriple::Pair { start, end: 0, right_index: 0, @@ -150,8 +151,8 @@ pub fn deserialize_tree( tree_hash_for_byte(b, calculate_tree_hashes), ) } else { - let (atom_offset, atom_size) = decode_size(f, b)?; - let end = start + (atom_offset as u64) + (atom_size as u64); + let (atom_offset, atom_size) = decode_size_with_offset(f, b)?; + let end = start + (atom_offset as u64) + atom_size; let h = skip_or_sha_bytes(f, atom_size, calculate_tree_hashes)?; (start, end, atom_offset as u32, h) } @@ -159,7 +160,7 @@ pub fn deserialize_tree( if calculate_tree_hashes { tree_hashes.push(tree_hash.expect("failed unwrap")) } - let new_obj = CLVMTreeBoundary::Atom { + let new_obj = ParsedTriple::Atom { start, end, atom_offset, @@ -169,7 +170,7 @@ pub fn deserialize_tree( } } ParseOpRef::SaveCursor(index) => { - if let CLVMTreeBoundary::Pair { + if let ParsedTriple::Pair { start, end: _, right_index, @@ -183,7 +184,7 @@ pub fn deserialize_tree( ]); tree_hashes[index] = h; } - r[index] = CLVMTreeBoundary::Pair { + r[index] = ParsedTriple::Pair { start, end: cursor, right_index, @@ -191,13 +192,13 @@ pub fn deserialize_tree( } } ParseOpRef::SaveIndex(index) => { - if let CLVMTreeBoundary::Pair { + if let ParsedTriple::Pair { start, end, right_index: _, } = r[index] { - r[index] = CLVMTreeBoundary::Pair { + r[index] = ParsedTriple::Pair { start, end, right_index: r.len() as u32, @@ -224,17 +225,17 @@ use std::io::Cursor; use hex::FromHex; #[cfg(test)] -fn check_parse_tree(h: &str, expected: Vec, expected_sha_tree_hex: &str) -> () { +fn check_parse_tree(h: &str, expected: Vec, expected_sha_tree_hex: &str) -> () { let b = Vec::from_hex(h).unwrap(); println!("{:?}", b); let mut f = Cursor::new(b); - let (p, tree_hash) = deserialize_tree(&mut f, false).unwrap(); + let (p, tree_hash) = parse_triples(&mut f, false).unwrap(); assert_eq!(p, expected); assert_eq!(tree_hash, None); let b = Vec::from_hex(h).unwrap(); let mut f = Cursor::new(b); - let (p, tree_hash) = deserialize_tree(&mut f, true).unwrap(); + let (p, tree_hash) = parse_triples(&mut f, true).unwrap(); assert_eq!(p, expected); let est = Vec::from_hex(expected_sha_tree_hex).unwrap(); @@ -245,7 +246,7 @@ fn check_parse_tree(h: &str, expected: Vec, expected_sha_tree_ fn test_parse_tree() { check_parse_tree( "80", - vec![CLVMTreeBoundary::Atom { + vec![ParsedTriple::Atom { start: 0, end: 1, atom_offset: 1, @@ -256,17 +257,17 @@ fn test_parse_tree() { check_parse_tree( "ff648200c8", vec![ - CLVMTreeBoundary::Pair { + ParsedTriple::Pair { start: 0, end: 5, right_index: 2, }, - CLVMTreeBoundary::Atom { + ParsedTriple::Atom { start: 1, end: 2, atom_offset: 0, }, - CLVMTreeBoundary::Atom { + ParsedTriple::Atom { start: 2, end: 5, atom_offset: 1, @@ -278,37 +279,37 @@ fn test_parse_tree() { check_parse_tree( "ff83666f6fff83626172ff8362617a80", // `(foo bar baz)` vec![ - CLVMTreeBoundary::Pair { + ParsedTriple::Pair { start: 0, end: 16, right_index: 2, }, - CLVMTreeBoundary::Atom { + ParsedTriple::Atom { start: 1, end: 5, atom_offset: 1, }, - CLVMTreeBoundary::Pair { + ParsedTriple::Pair { start: 5, end: 16, right_index: 4, }, - CLVMTreeBoundary::Atom { + ParsedTriple::Atom { start: 6, end: 10, atom_offset: 1, }, - CLVMTreeBoundary::Pair { + ParsedTriple::Pair { start: 10, end: 16, right_index: 6, }, - CLVMTreeBoundary::Atom { + ParsedTriple::Atom { start: 11, end: 15, atom_offset: 1, }, - CLVMTreeBoundary::Atom { + ParsedTriple::Atom { start: 15, end: 16, atom_offset: 1, @@ -320,7 +321,7 @@ fn test_parse_tree() { let s = "c0a0".to_owned() + &hex::encode([0x31u8; 160]); check_parse_tree( &s, - vec![CLVMTreeBoundary::Atom { + vec![ParsedTriple::Atom { start: 0, end: 162, atom_offset: 2, diff --git a/src/serde/mod.rs b/src/serde/mod.rs index 8a6f7e36..d6ff0f8c 100644 --- a/src/serde/mod.rs +++ b/src/serde/mod.rs @@ -16,6 +16,7 @@ mod test; pub use de::node_from_bytes; pub use de_br::node_from_bytes_backrefs; +pub use de_tree::{parse_triples, ParsedTriple}; pub use ser::node_to_bytes; pub use ser_br::node_to_bytes_backrefs; pub use tools::{serialized_length_from_bytes, tree_hash_from_stream}; diff --git a/src/serde/parse_atom.rs b/src/serde/parse_atom.rs index 3dddb29e..9fd0abb3 100644 --- a/src/serde/parse_atom.rs +++ b/src/serde/parse_atom.rs @@ -9,40 +9,44 @@ const MAX_SINGLE_BYTE: u8 = 0x7f; /// decode the length prefix for an atom. Atoms whose value fit in 7 bits /// don't have a length prefix, so those should be handled specially and /// never passed to this function. -pub fn decode_size(f: &mut R, initial_b: u8) -> Result { +pub fn decode_size_with_offset(f: &mut R, initial_b: u8) -> Result<(u8, u64)> { debug_assert!((initial_b & 0x80) != 0); if (initial_b & 0x80) == 0 { return Err(internal_error()); } - let mut bit_count = 0; + let mut atom_start_offset = 0; let mut bit_mask: u8 = 0x80; let mut b = initial_b; while b & bit_mask != 0 { - bit_count += 1; + atom_start_offset += 1; b &= 0xff ^ bit_mask; bit_mask >>= 1; } let mut size_blob: Vec = Vec::new(); - size_blob.resize(bit_count, 0); + size_blob.resize(atom_start_offset, 0); size_blob[0] = b; - if bit_count > 1 { - let remaining_buffer = &mut size_blob[1..]; + if atom_start_offset > 1 { + let remaining_buffer = &mut size_blob[1..atom_start_offset]; f.read_exact(remaining_buffer)?; } // need to convert size_blob to an int - let mut v: u64 = 0; + let mut atom_size: u64 = 0; if size_blob.len() > 6 { return Err(bad_encoding()); } for b in &size_blob { - v <<= 8; - v += *b as u64; + atom_size <<= 8; + atom_size += *b as u64; } - if v >= 0x400000000 { + if atom_size >= 0x400000000 { return Err(bad_encoding()); } - Ok(v) + Ok((atom_start_offset as u8, atom_size)) +} + +pub fn decode_size(f: &mut R, initial_b: u8) -> Result { + decode_size_with_offset(f, initial_b).map(|v| v.1) } /// parse an atom from the stream and return a pointer to it diff --git a/src/serialize.rs b/src/serialize.rs deleted file mode 100644 index d6ea63ea..00000000 --- a/src/serialize.rs +++ /dev/null @@ -1,606 +0,0 @@ -use std::io; -use std::io::{Cursor, ErrorKind, Read, Seek, SeekFrom}; - -use crate::allocator::{Allocator, NodePtr, SExp}; -use crate::node::Node; - -const MAX_SINGLE_BYTE: u8 = 0x7f; -const CONS_BOX_MARKER: u8 = 0xff; - -fn bad_encoding() -> io::Error { - io::Error::new(ErrorKind::InvalidInput, "bad encoding") -} - -fn internal_error() -> io::Error { - io::Error::new(ErrorKind::InvalidInput, "internal error") -} - -/// all atoms serialize their contents verbatim. All expect those one-byte atoms -/// from 0x00-0x7f also have a prefix encoding their length. This function -/// writes the correct prefix for an atom of size `size` whose first byte is `atom_0`. -/// If the atom is of size 0, use any placeholder first byte, as it's ignored anyway. - -fn write_atom_encoding_prefix_with_size( - f: &mut dyn io::Write, - atom_0: u8, - size: u64, -) -> io::Result<()> { - if size == 0 { - f.write_all(&[0x80]) - } else if size == 1 && atom_0 < 0x80 { - Ok(()) - } else if size < 0x40 { - f.write_all(&[0x80 | (size as u8)]) - } else if size < 0x2000 { - f.write_all(&[0xc0 | (size >> 8) as u8, size as u8]) - } else if size < 0x10_0000 { - f.write_all(&[ - (0xe0 | (size >> 16)) as u8, - ((size >> 8) & 0xff) as u8, - ((size) & 0xff) as u8, - ]) - } else if size < 0x800_0000 { - f.write_all(&[ - (0xf0 | (size >> 24)) as u8, - ((size >> 16) & 0xff) as u8, - ((size >> 8) & 0xff) as u8, - ((size) & 0xff) as u8, - ]) - } else if size < 0x4_0000_0000 { - f.write_all(&[ - (0xf8 | (size >> 32)) as u8, - ((size >> 24) & 0xff) as u8, - ((size >> 16) & 0xff) as u8, - ((size >> 8) & 0xff) as u8, - ((size) & 0xff) as u8, - ]) - } else { - Err(io::Error::new(ErrorKind::InvalidData, "atom too big")) - } -} - -/// serialize an atom -fn write_atom(f: &mut dyn io::Write, atom: &[u8]) -> io::Result<()> { - let u8_0 = if !atom.is_empty() { atom[0] } else { 0 }; - write_atom_encoding_prefix_with_size(f, u8_0, atom.len() as u64)?; - f.write_all(atom) -} - -/// serialize a node -pub fn node_to_stream(node: &Node, f: &mut dyn io::Write) -> io::Result<()> { - let mut values: Vec = vec![node.node]; - let a = node.allocator; - while !values.is_empty() { - let v = values.pop().unwrap(); - let n = a.sexp(v); - match n { - SExp::Atom(atom_ptr) => { - let atom = a.buf(&atom_ptr); - write_atom(f, atom)?; - } - SExp::Pair(left, right) => { - f.write_all(&[CONS_BOX_MARKER as u8])?; - values.push(right); - values.push(left); - } - } - } - Ok(()) -} - -/// decode the length prefix for an atom. Atoms whose value fit in 7 bits -/// don't have a length prefix, so those should be handled specially and -/// never passed to this function. -pub fn decode_size(f: &mut R, initial_b: u8) -> io::Result<(u8, u64)> { - debug_assert!((initial_b & 0x80) != 0); - if (initial_b & 0x80) == 0 { - return Err(internal_error()); - } - - let mut atom_start_offset = 0; - let mut bit_mask: u8 = 0x80; - let mut b = initial_b; - while b & bit_mask != 0 { - atom_start_offset += 1; - b &= 0xff ^ bit_mask; - bit_mask >>= 1; - } - let mut size_blob: Vec = Vec::new(); - size_blob.resize(atom_start_offset, 0); - size_blob[0] = b; - if atom_start_offset > 1 { - let remaining_buffer = &mut size_blob[1..atom_start_offset]; - f.read_exact(remaining_buffer)?; - } - // need to convert size_blob to an int - let mut atom_size: u64 = 0; - if size_blob.len() > 6 { - return Err(bad_encoding()); - } - for b in &size_blob { - atom_size <<= 8; - atom_size += *b as u64; - } - if atom_size >= 0x400000000 { - return Err(bad_encoding()); - } - Ok((atom_start_offset as u8, atom_size)) -} - -enum ParseOp { - SExp, - Cons, -} - -/// deserialize a clvm node from a `std::io::Cursor` -pub fn node_from_stream(allocator: &mut Allocator, f: &mut Cursor<&[u8]>) -> io::Result { - let mut values: Vec = Vec::new(); - let mut ops = vec![ParseOp::SExp]; - - let mut b = [0; 1]; - loop { - let op = ops.pop(); - if op.is_none() { - break; - } - match op.unwrap() { - ParseOp::SExp => { - f.read_exact(&mut b)?; - if b[0] == CONS_BOX_MARKER { - ops.push(ParseOp::Cons); - ops.push(ParseOp::SExp); - ops.push(ParseOp::SExp); - } else if b[0] == 0x01 { - values.push(allocator.one()); - } else if b[0] == 0x80 { - values.push(allocator.null()); - } else if b[0] <= MAX_SINGLE_BYTE { - values.push(allocator.new_atom(&b)?); - } else { - let (_prefix_size, blob_size) = decode_size(f, b[0])?; - if (f.get_ref().len()) < blob_size as usize { - return Err(bad_encoding()); - } - let mut blob: Vec = vec![0; blob_size as usize]; - f.read_exact(&mut blob)?; - values.push(allocator.new_atom(&blob)?); - } - } - ParseOp::Cons => { - // cons - let v2 = values.pop(); - let v1 = values.pop(); - values.push(allocator.new_pair(v1.unwrap(), v2.unwrap())?); - } - } - } - Ok(values.pop().unwrap()) -} - -pub fn node_from_bytes(allocator: &mut Allocator, b: &[u8]) -> io::Result { - let mut buffer = Cursor::new(b); - node_from_stream(allocator, &mut buffer) -} - -pub fn node_to_bytes(node: &Node) -> io::Result> { - let mut buffer = Cursor::new(Vec::new()); - - node_to_stream(node, &mut buffer)?; - let vec = buffer.into_inner(); - Ok(vec) -} - -pub fn serialized_length_from_bytes(b: &[u8]) -> io::Result { - let mut f = Cursor::new(b); - let mut ops = vec![ParseOp::SExp]; - let mut b = [0; 1]; - loop { - let op = ops.pop(); - if op.is_none() { - break; - } - match op.unwrap() { - ParseOp::SExp => { - f.read_exact(&mut b)?; - if b[0] == CONS_BOX_MARKER { - // since all we're doing is to determing the length of the - // serialized buffer, we don't need to do anything about - // "cons". So we skip pushing it to lower the pressure on - // the op stack - //ops.push(ParseOp::Cons); - ops.push(ParseOp::SExp); - ops.push(ParseOp::SExp); - } else if b[0] == 0x80 || b[0] <= MAX_SINGLE_BYTE { - // This one byte we just read was the whole atom. - // or the - // special case of NIL - } else { - let (_prefix_size, blob_size) = decode_size(&mut f, b[0])?; - f.seek(SeekFrom::Current(blob_size as i64))?; - if (f.get_ref().len() as u64) < f.position() { - return Err(bad_encoding()); - } - } - } - ParseOp::Cons => { - // cons. No need to construct any structure here. Just keep - // going - } - } - } - Ok(f.position()) -} - -use crate::sha2::{Digest, Sha256}; - -fn hash_atom(buf: &[u8]) -> [u8; 32] { - let mut ctx = Sha256::new(); - ctx.update(&[1_u8]); - ctx.update(buf); - ctx.finalize().into() -} - -fn hash_pair(left: &[u8; 32], right: &[u8; 32]) -> [u8; 32] { - let mut ctx = Sha256::new(); - ctx.update(&[2_u8]); - ctx.update(left); - ctx.update(right); - ctx.finalize().into() -} - -// computes the tree-hash of a CLVM structure in serialized form -pub fn tree_hash_from_stream(f: &mut Cursor<&[u8]>) -> io::Result<[u8; 32]> { - let mut values: Vec<[u8; 32]> = Vec::new(); - let mut ops = vec![ParseOp::SExp]; - - let mut b = [0; 1]; - loop { - let op = ops.pop(); - if op.is_none() { - break; - } - match op.unwrap() { - ParseOp::SExp => { - f.read_exact(&mut b)?; - if b[0] == CONS_BOX_MARKER { - ops.push(ParseOp::Cons); - ops.push(ParseOp::SExp); - ops.push(ParseOp::SExp); - } else if b[0] == 0x80 { - values.push(hash_atom(&[])); - } else if b[0] <= MAX_SINGLE_BYTE { - values.push(hash_atom(&b)); - } else { - let (_, blob_size) = decode_size(f, b[0])?; - let blob = &f.get_ref()[f.position() as usize..]; - if blob.len() < blob_size as usize { - return Err(bad_encoding()); - } - let blob_size = blob_size as u64; - f.set_position(f.position() + blob_size); - values.push(hash_atom(&blob[..blob_size as usize])); - } - } - ParseOp::Cons => { - // cons - let v2 = values.pop(); - let v1 = values.pop(); - values.push(hash_pair(&v1.unwrap(), &v2.unwrap())); - } - } - } - Ok(values.pop().unwrap()) -} - -#[test] -fn test_tree_hash_max_single_byte() { - let mut ctx = Sha256::new(); - ctx.update(&[1_u8]); - ctx.update(&[0x7f_u8]); - let mut cursor = Cursor::<&[u8]>::new(&[0x7f_u8]); - assert_eq!( - tree_hash_from_stream(&mut cursor).unwrap(), - ctx.finalize().as_slice() - ); -} - -#[test] -fn test_tree_hash_one() { - let mut ctx = Sha256::new(); - ctx.update(&[1_u8]); - ctx.update(&[1_u8]); - let mut cursor = Cursor::<&[u8]>::new(&[1_u8]); - assert_eq!( - tree_hash_from_stream(&mut cursor).unwrap(), - ctx.finalize().as_slice() - ); -} - -#[test] -fn test_tree_hash_zero() { - let mut ctx = Sha256::new(); - ctx.update(&[1_u8]); - ctx.update(&[0_u8]); - let mut cursor = Cursor::<&[u8]>::new(&[0_u8]); - assert_eq!( - tree_hash_from_stream(&mut cursor).unwrap(), - ctx.finalize().as_slice() - ); -} - -#[test] -fn test_tree_hash_nil() { - let mut ctx = Sha256::new(); - ctx.update(&[1_u8]); - let mut cursor = Cursor::<&[u8]>::new(&[0x80_u8]); - assert_eq!( - tree_hash_from_stream(&mut cursor).unwrap(), - ctx.finalize().as_slice() - ); -} - -#[test] -fn test_tree_hash_overlong() { - let mut cursor = Cursor::<&[u8]>::new(&[0x8f, 0xff]); - let e = tree_hash_from_stream(&mut cursor).unwrap_err(); - assert_eq!(e.kind(), bad_encoding().kind()); - - let mut cursor = Cursor::<&[u8]>::new(&[0b11001111, 0xff]); - let e = tree_hash_from_stream(&mut cursor).unwrap_err(); - assert_eq!(e.kind(), bad_encoding().kind()); - - let mut cursor = Cursor::<&[u8]>::new(&[0b11001111, 0xff, 0, 0]); - let e = tree_hash_from_stream(&mut cursor).unwrap_err(); - assert_eq!(e.kind(), bad_encoding().kind()); -} - -#[cfg(test)] -use hex::FromHex; - -// these test cases were produced by: - -// from chia.types.blockchain_format.program import Program -// a = Program.to(...) -// print(bytes(a).hex()) -// print(a.get_tree_hash().hex()) - -#[test] -fn test_tree_hash_list() { - // this is the list (1 (2 (3 (4 (5 ()))))) - let buf = Vec::from_hex("ff01ff02ff03ff04ff0580").unwrap(); - let mut cursor = Cursor::<&[u8]>::new(&buf); - assert_eq!( - tree_hash_from_stream(&mut cursor).unwrap().to_vec(), - Vec::from_hex("123190dddde51acfc61f48429a879a7b905d1726a52991f7d63349863d06b1b6").unwrap() - ); -} - -#[test] -fn test_tree_hash_tree() { - // this is the tree ((1, 2), (3, 4)) - let buf = Vec::from_hex("ffff0102ff0304").unwrap(); - let mut cursor = Cursor::<&[u8]>::new(&buf); - assert_eq!( - tree_hash_from_stream(&mut cursor).unwrap().to_vec(), - Vec::from_hex("2824018d148bc6aed0847e2c86aaa8a5407b916169f15b12cea31fa932fc4c8d").unwrap() - ); -} - -#[test] -fn test_tree_hash_tree_large_atom() { - // this is the tree ((1, 2), (3, b"foobar")) - let buf = Vec::from_hex("ffff0102ff0386666f6f626172").unwrap(); - let mut cursor = Cursor::<&[u8]>::new(&buf); - assert_eq!( - tree_hash_from_stream(&mut cursor).unwrap().to_vec(), - Vec::from_hex("b28d5b401bd02b65b7ed93de8e916cfc488738323e568bcca7e032c3a97a12e4").unwrap() - ); -} - -#[test] -fn test_serialized_length_from_bytes() { - assert_eq!( - serialized_length_from_bytes(&[0x7f, 0x00, 0x00, 0x00]).unwrap(), - 1 - ); - assert_eq!( - serialized_length_from_bytes(&[0x80, 0x00, 0x00, 0x00]).unwrap(), - 1 - ); - assert_eq!( - serialized_length_from_bytes(&[0xff, 0x00, 0x00, 0x00]).unwrap(), - 3 - ); - assert_eq!( - serialized_length_from_bytes(&[0xff, 0x01, 0xff, 0x80, 0x80, 0x00]).unwrap(), - 5 - ); - - let e = serialized_length_from_bytes(&[0x8f, 0xff]).unwrap_err(); - assert_eq!(e.kind(), bad_encoding().kind()); - assert_eq!(e.to_string(), "bad encoding"); - - let e = serialized_length_from_bytes(&[0b11001111, 0xff]).unwrap_err(); - assert_eq!(e.kind(), bad_encoding().kind()); - assert_eq!(e.to_string(), "bad encoding"); - - let e = serialized_length_from_bytes(&[0b11001111, 0xff, 0, 0]).unwrap_err(); - assert_eq!(e.kind(), bad_encoding().kind()); - assert_eq!(e.to_string(), "bad encoding"); - - assert_eq!( - serialized_length_from_bytes(&[0x8f, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]).unwrap(), - 16 - ); -} - -#[test] -fn test_write_atom_encoding_prefix_with_size() { - let mut buf = Vec::::new(); - assert!(write_atom_encoding_prefix_with_size(&mut buf, 0, 0).is_ok()); - assert_eq!(buf, vec![0x80]); - - for v in 0..0x7f { - let mut buf = Vec::::new(); - assert!(write_atom_encoding_prefix_with_size(&mut buf, v, 1).is_ok()); - assert_eq!(buf, vec![]); - } - - for v in 0x80..0xff { - let mut buf = Vec::::new(); - assert!(write_atom_encoding_prefix_with_size(&mut buf, v, 1).is_ok()); - assert_eq!(buf, vec![0x81]); - } - - for size in 0x1_u8..0x3f_u8 { - let mut buf = Vec::::new(); - assert!(write_atom_encoding_prefix_with_size(&mut buf, 0xaa, size as u64).is_ok()); - assert_eq!(buf, vec![0x80 + size]); - } - - let mut buf = Vec::::new(); - assert!(write_atom_encoding_prefix_with_size(&mut buf, 0xaa, 0b111111).is_ok()); - assert_eq!(buf, vec![0b10111111]); - - let mut buf = Vec::::new(); - assert!(write_atom_encoding_prefix_with_size(&mut buf, 0xaa, 0b1000000).is_ok()); - assert_eq!(buf, vec![0b11000000, 0b1000000]); - - let mut buf = Vec::::new(); - assert!(write_atom_encoding_prefix_with_size(&mut buf, 0xaa, 0xfffff).is_ok()); - assert_eq!(buf, vec![0b11101111, 0xff, 0xff]); - - let mut buf = Vec::::new(); - assert!(write_atom_encoding_prefix_with_size(&mut buf, 0xaa, 0xffffff).is_ok()); - assert_eq!(buf, vec![0b11110000, 0xff, 0xff, 0xff]); - - let mut buf = Vec::::new(); - assert!(write_atom_encoding_prefix_with_size(&mut buf, 0xaa, 0xffffffff).is_ok()); - assert_eq!(buf, vec![0b11111000, 0xff, 0xff, 0xff, 0xff]); - - // this is the largest possible atom size - let mut buf = Vec::::new(); - assert!(write_atom_encoding_prefix_with_size(&mut buf, 0xaa, 0x3ffffffff).is_ok()); - assert_eq!(buf, vec![0b11111011, 0xff, 0xff, 0xff, 0xff]); - - // this is too large - let mut buf = Vec::::new(); - assert!(!write_atom_encoding_prefix_with_size(&mut buf, 0xaa, 0x400000000).is_ok()); - - for (size, expected_prefix) in [ - (0x1, vec![0x81]), - (0x2, vec![0x82]), - (0x3f, vec![0xbf]), - (0x40, vec![0xc0, 0x40]), - (0x1fff, vec![0xdf, 0xff]), - (0x2000, vec![0xe0, 0x20, 0x00]), - (0xf_ffff, vec![0xef, 0xff, 0xff]), - (0x10_0000, vec![0xf0, 0x10, 0x00, 0x00]), - (0x7ff_ffff, vec![0xf7, 0xff, 0xff, 0xff]), - (0x800_0000, vec![0xf8, 0x08, 0x00, 0x00, 0x00]), - (0x3_ffff_ffff, vec![0xfb, 0xff, 0xff, 0xff, 0xff]), - ] { - let mut buf = Vec::::new(); - assert!(write_atom_encoding_prefix_with_size(&mut buf, 0xaa, size).is_ok()); - assert_eq!(buf, expected_prefix); - } -} - -#[test] -fn test_write_atom() { - let mut buf = Vec::::new(); - assert!(write_atom(&mut buf, &vec![]).is_ok()); - assert_eq!(buf, vec![0b10000000]); - - let mut buf = Vec::::new(); - assert!(write_atom(&mut buf, &vec![0x00]).is_ok()); - assert_eq!(buf, vec![0b00000000]); - - let mut buf = Vec::::new(); - assert!(write_atom(&mut buf, &vec![0x7f]).is_ok()); - assert_eq!(buf, vec![0x7f]); - - let mut buf = Vec::::new(); - assert!(write_atom(&mut buf, &vec![0x80]).is_ok()); - assert_eq!(buf, vec![0x81, 0x80]); - - let mut buf = Vec::::new(); - assert!(write_atom(&mut buf, &vec![0xff]).is_ok()); - assert_eq!(buf, vec![0x81, 0xff]); - - let mut buf = Vec::::new(); - assert!(write_atom(&mut buf, &vec![0xaa, 0xbb]).is_ok()); - assert_eq!(buf, vec![0x82, 0xaa, 0xbb]); - - for (size, mut expected_prefix) in [ - (0x1, vec![0x81]), - (0x2, vec![0x82]), - (0x3f, vec![0xbf]), - (0x40, vec![0xc0, 0x40]), - (0x1fff, vec![0xdf, 0xff]), - (0x2000, vec![0xe0, 0x20, 0x00]), - (0xf_ffff, vec![0xef, 0xff, 0xff]), - (0x10_0000, vec![0xf0, 0x10, 0x00, 0x00]), - (0x7ff_ffff, vec![0xf7, 0xff, 0xff, 0xff]), - (0x800_0000, vec![0xf8, 0x08, 0x00, 0x00, 0x00]), - // the next one represents 17 GB of memory, which it then has to serialize - // so let's not do it until some time in the future when all machines have - // 64 GB of memory - // (0x3_ffff_ffff, vec![0xfb, 0xff, 0xff, 0xff, 0xff]), - ] { - let mut buf = Vec::::new(); - let atom = vec![0xaa; size]; - assert!(write_atom(&mut buf, &atom).is_ok()); - expected_prefix.extend(atom); - assert_eq!(buf, expected_prefix); - } -} - -#[test] -fn test_decode_size() { - // single-byte length prefix - let mut buffer = Cursor::new(&[]); - assert_eq!(decode_size(&mut buffer, 0x80 | 0x20).unwrap(), (1, 0x20)); - - // two-byte length prefix - let first = 0b11001111; - let mut buffer = Cursor::new(&[0xaa]); - assert_eq!(decode_size(&mut buffer, first).unwrap(), (2, 0xfaa)); -} - -#[test] -fn test_large_decode_size() { - // this is an atom length-prefix 0xffffffffffff, or (2^48 - 1). - // We don't support atoms this large and we should fail before attempting to - // allocate this much memory - let first = 0b11111110; - let mut buffer = Cursor::new(&[0xff, 0xff, 0xff, 0xff, 0xff, 0xff]); - let ret = decode_size(&mut buffer, first); - let e = ret.unwrap_err(); - assert_eq!(e.kind(), bad_encoding().kind()); - assert_eq!(e.to_string(), "bad encoding"); - - // this is still too large - let first = 0b11111100; - let mut buffer = Cursor::new(&[0x4, 0, 0, 0, 0]); - let ret = decode_size(&mut buffer, first); - let e = ret.unwrap_err(); - assert_eq!(e.kind(), bad_encoding().kind()); - assert_eq!(e.to_string(), "bad encoding"); - - // But this is *just* within what we support - // Still a very large blob, probably enough for a DoS attack - let first = 0b11111100; - let mut buffer = Cursor::new(&[0x3, 0xff, 0xff, 0xff, 0xff]); - assert_eq!(decode_size(&mut buffer, first).unwrap(), (6, 0x3ffffffff)); -} - -#[test] -fn test_truncated_decode_size() { - // the stream is truncated - let first = 0b11111100; - let mut buffer = Cursor::new(&[0x4, 0, 0, 0]); - let ret = decode_size(&mut buffer, first); - let e = ret.unwrap_err(); - assert_eq!(e.kind(), ErrorKind::UnexpectedEof); -} diff --git a/wheel/python/benchmarks/deserialization.py b/wheel/python/benchmarks/deserialization.py index cd4c155a..3884b777 100644 --- a/wheel/python/benchmarks/deserialization.py +++ b/wheel/python/benchmarks/deserialization.py @@ -22,14 +22,14 @@ def bench(f, name: str): # breakpoint() -obj = bench(lambda: Program.parse(open("block-2500014.compressed.bin", "rb")), "parse") -bench(lambda: bytes(obj), "to_bytes") +obj = bench(lambda: Program.parse(open("block-2500014.compressed.bin", "rb")), "obj = Program.parse(open([file]))") +bench(lambda: bytes(obj), "bytes(obj)") obj1 = bench( lambda: Program.from_bytes(open("block-2500014.compressed.bin", "rb").read()), - "from_bytes", + "obj = Program.from_bytes([blob])", ) -bench(lambda: bytes(obj1), "to_bytes") +bench(lambda: bytes(obj1), "bytes(obj)") cost, output = bench(lambda: obj.run_with_cost(0), "run") @@ -55,9 +55,9 @@ def bench(f, name: str): bench(lambda: print(des_output.tree_hash().hex()), "print from_bytes tree hash") bench(lambda: print(des_output.tree_hash().hex()), "print from_bytes tree hash again") -reparsed_output = bench(lambda: Program.parse(io.BytesIO(blob)), "parse output") -bench(lambda: print(reparsed_output.tree_hash().hex()), "print parsed tree hash") -bench(lambda: print(reparsed_output.tree_hash().hex()), "print parsed tree hash again") +reparsed_output = bench(lambda: Program.parse(io.BytesIO(blob)), "reparse output") +bench(lambda: print(reparsed_output.tree_hash().hex()), "print reparsed tree hash") +bench(lambda: print(reparsed_output.tree_hash().hex()), "print reparsed tree hash again") foo = Program.to("foo") diff --git a/wheel/python/tests/test_serialize.py b/wheel/python/tests/test_serialize.py index 294817ab..20bec45f 100644 --- a/wheel/python/tests/test_serialize.py +++ b/wheel/python/tests/test_serialize.py @@ -133,7 +133,7 @@ def test_repr_clvm_tree(self): self.assertEqual(repr(o.unwrap()), "") def test_bad_blob(self): - self.assertRaises(ValueError, lambda: Program.fromhex("ff")) + self.assertRaises(OSError, lambda: Program.fromhex("ff")) def test_large_atom(self): s = "foo" * 100 @@ -143,5 +143,5 @@ def test_large_atom(self): self.assertEqual(p, p1) def test_too_large_atom(self): - self.assertRaises(ValueError, lambda: Program.fromhex("fc")) - self.assertRaises(ValueError, lambda: Program.fromhex("fc8000000000")) + self.assertRaises(OSError, lambda: Program.fromhex("fc")) + self.assertRaises(OSError, lambda: Program.fromhex("fc8000000000")) diff --git a/wheel/src/api.rs b/wheel/src/api.rs index 80af234f..dc46ac92 100644 --- a/wheel/src/api.rs +++ b/wheel/src/api.rs @@ -5,11 +5,9 @@ use crate::adapt_response::adapt_response; use clvmr::allocator::Allocator; use clvmr::chia_dialect::ChiaDialect; use clvmr::cost::Cost; -use clvmr::deserialize_tree::{deserialize_tree, CLVMTreeBoundary}; use clvmr::reduction::Response; use clvmr::run_program::run_program; -use clvmr::serde::{node_from_bytes, serialized_length_from_bytes}; -use clvmr::deserialize_tree::{parse_triples, ParsedTriple}; +use clvmr::serde::{node_from_bytes, parse_triples, serialized_length_from_bytes, ParsedTriple}; use clvmr::{LIMIT_HEAP, LIMIT_STACK, MEMPOOL_MODE, NO_UNKNOWN_OPS}; use pyo3::prelude::*; use pyo3::types::{PyBytes, PyTuple}; @@ -44,14 +42,14 @@ pub fn run_serialized_chia_program( adapt_response(py, allocator, r) } -fn tuple_for_parsed_triple(py: Python<'_>, p: &CLVMTreeBoundary) -> PyObject { +fn tuple_for_parsed_triple(py: Python<'_>, p: &ParsedTriple) -> PyObject { let tuple = match p { - CLVMTreeBoundary::Atom { + ParsedTriple::Atom { start, end, atom_offset, } => PyTuple::new(py, [*start, *end, *atom_offset as u64]), - CLVMTreeBoundary::Pair { + ParsedTriple::Pair { start, end, right_index, @@ -67,7 +65,7 @@ fn deserialize_as_tree( calculate_tree_hashes: bool, ) -> PyResult<(Vec, Option>)> { let mut cursor = io::Cursor::new(blob); - let (r, tree_hashes) = deserialize_tree(&mut cursor, calculate_tree_hashes)?; + let (r, tree_hashes) = parse_triples(&mut cursor, calculate_tree_hashes)?; let r = r.iter().map(|pt| tuple_for_parsed_triple(py, pt)).collect(); let s = tree_hashes.map(|ths| ths.iter().map(|b| PyBytes::new(py, b).into()).collect()); Ok((r, s)) @@ -77,7 +75,7 @@ fn deserialize_as_tree( fn clvm_rs(_py: Python, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(run_serialized_chia_program, m)?)?; m.add_function(wrap_pyfunction!(serialized_length, m)?)?; - m.add_function(wrap_pyfunction!(deserialize_as_triples, m)?)?; + m.add_function(wrap_pyfunction!(deserialize_as_tree, m)?)?; m.add("NO_UNKNOWN_OPS", NO_UNKNOWN_OPS)?; m.add("LIMIT_HEAP", LIMIT_HEAP)?; From 4a11b0fa0491699eea86dff47f4eebcb00005296 Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Thu, 19 Jan 2023 14:19:26 -0800 Subject: [PATCH 14/45] checkpoint --- wheel/python/clvm_rs/__init__.py | 1 - wheel/python/clvm_rs/base.py | 27 +++-- wheel/python/clvm_rs/casts.py | 52 ++++++---- wheel/python/clvm_rs/clvm_tree.py | 26 +++-- wheel/python/clvm_rs/deser.py | 57 ++++++---- wheel/python/clvm_rs/program.py | 144 +++++++++++++++----------- wheel/python/clvm_rs/run_program.py | 27 ----- wheel/python/clvm_rs/tree_hash.py | 38 +++---- wheel/python/tests/test_program.py | 7 +- wheel/python/tests/test_serialize.py | 5 +- wheel/python/tests/test_to_program.py | 31 +++--- 11 files changed, 221 insertions(+), 194 deletions(-) delete mode 100644 wheel/python/clvm_rs/run_program.py diff --git a/wheel/python/clvm_rs/__init__.py b/wheel/python/clvm_rs/__init__.py index 1250efbc..e69de29b 100644 --- a/wheel/python/clvm_rs/__init__.py +++ b/wheel/python/clvm_rs/__init__.py @@ -1 +0,0 @@ -from .base import CLVMObject # noqa: F401 diff --git a/wheel/python/clvm_rs/base.py b/wheel/python/clvm_rs/base.py index cc492194..fa0c33fd 100644 --- a/wheel/python/clvm_rs/base.py +++ b/wheel/python/clvm_rs/base.py @@ -1,14 +1,27 @@ -from typing import Optional, Protocol, Tuple +from typing import Optional, Protocol, Tuple, runtime_checkable -class CLVMObject(Protocol): +@runtime_checkable +class CLVMStorage(Protocol): atom: Optional[bytes] - pair: Optional[Tuple["CLVMObject", "CLVMObject"]] + _cached_sha256_treehash: Optional[bytes] + @property + def pair(self) -> Optional[Tuple["CLVMStorage", "CLVMStorage"]]: + ... + + # optional fields used to speed implementations: + + # `_cached_sha256_treehash: Optional[bytes]` is used by `sha256_treehash` + # `_cached_serialization: bytes` is used by `sexp_to_byte_iterator` to speed up serialization + + +@runtime_checkable +class CLVMStorageFactory(Protocol): @classmethod - def new_atom(cls, v: bytes) -> "CLVMObject": - raise NotImplementedError() + def new_atom(cls, v: bytes) -> "CLVMStorage": + ... @classmethod - def new_pair(cls, p1, p2) -> "CLVMObject": - raise NotImplementedError() + def new_pair(cls, p1: "CLVMStorage", p2: "CLVMStorage") -> "CLVMStorage": + ... diff --git a/wheel/python/clvm_rs/casts.py b/wheel/python/clvm_rs/casts.py index 206bd8a5..31cc21fb 100644 --- a/wheel/python/clvm_rs/casts.py +++ b/wheel/python/clvm_rs/casts.py @@ -1,26 +1,32 @@ -from typing import Any, Tuple, Union +from typing import Any, Callable, List, Optional, SupportsBytes, Tuple, Union + +from .base import CLVMStorage AtomCastableType = Union[ bytes, str, int, + SupportsBytes, None, ] CastableType = Union[ AtomCastableType, - list, + List["CastableType"], Tuple["CastableType", "CastableType"], + CLVMStorage, ] -def looks_like_clvm_object(o: Any) -> bool: - d = dir(o) - return "atom" in d and "pair" in d +def int_from_bytes(blob): + size = len(blob) + if size == 0: + return 0 + return int.from_bytes(blob, "big", signed=True) -def int_to_bytes(v): +def int_to_bytes(v) -> bytes: byte_count = (v.bit_length() + 8) >> 3 if v == 0: return b"" @@ -43,7 +49,7 @@ def to_atom_type(v: AtomCastableType) -> bytes: return v.encode() if isinstance(v, int): return int_to_bytes(v) - if hasattr(v, "__bytes__") or isinstance(v, memoryview): + if isinstance(v, (memoryview, SupportsBytes)): return bytes(v) if v is None: return NULL @@ -53,32 +59,33 @@ def to_atom_type(v: AtomCastableType) -> bytes: def to_clvm_object( v: CastableType, - to_atom_f, - to_pair_f, + to_atom_f: Callable[[bytes], CLVMStorage], + to_pair_f: Callable[[CLVMStorage, CLVMStorage], CLVMStorage], ): - stack = [v] - ops = [(0, None)] # convert + stack: List[CastableType] = [v] + ops: List[Tuple[int, Optional[CastableType]]] = [(0, None)] # convert while len(ops) > 0: op, target = ops.pop() # convert value if op == 0: - if looks_like_clvm_object(stack[-1]): - obj = stack.pop() - if obj.pair is None: - new_obj = to_atom_f(obj.atom) + v = stack.pop() + if isinstance(v, CLVMStorage): + if v.pair is None: + atom = v.atom + assert atom is not None + new_obj = to_atom_f(to_atom_type(atom)) else: - new_obj = to_pair_f(obj.pair[0], obj.pair[1]) + new_obj = to_pair_f(v.pair[0], v.pair[1]) stack.append(new_obj) continue - v = stack.pop() if isinstance(v, tuple): if len(v) != 2: raise ValueError("can't cast tuple of size %d" % len(v)) left, right = v target = len(stack) - ll_right = looks_like_clvm_object(right) - ll_left = looks_like_clvm_object(left) + ll_right = isinstance(right, CLVMStorage) + ll_left = isinstance(left, CLVMStorage) if ll_right and ll_left: stack.append(to_pair_f(left, right)) else: @@ -97,14 +104,17 @@ def to_clvm_object( ops.append((1, target)) # prepend list # we only need to convert if it's not already the right # type - if not looks_like_clvm_object(_): + if not isinstance(_, CLVMStorage): ops.append((0, None)) # convert continue stack.append(to_atom_f(to_atom_type(v))) continue if op == 1: # prepend list - stack[target] = to_pair_f(stack.pop(), stack[target]) + left = stack.pop() + assert isinstance(target, int) + right = stack[target] + stack[target] = to_pair_f(left, right) continue if op == 2: # roll p1 = stack.pop() diff --git a/wheel/python/clvm_rs/clvm_tree.py b/wheel/python/clvm_rs/clvm_tree.py index c0548507..4099d448 100644 --- a/wheel/python/clvm_rs/clvm_tree.py +++ b/wheel/python/clvm_rs/clvm_tree.py @@ -1,11 +1,12 @@ +from .base import CLVMStorage from .deser import deserialize_as_tuples from typing import List, Optional, Tuple -class CLVMTree: +class CLVMTree(CLVMStorage): """ - This object conforms with the `CLVMObject` protocol. It's optimized for + This object conforms with the `CLVMStorage` protocol. It's optimized for deserialization, and keeps a reference to the serialized blob and to a list of tuples of integers, each of which corresponds to a subtree. @@ -36,6 +37,8 @@ class CLVMTree: in well-behaved python code. """ + _pair: Optional[Tuple["CLVMStorage", "CLVMStorage"]] + @classmethod def from_bytes(cls, blob: bytes, calculate_tree_hash: bool = True) -> "CLVMTree": int_tuples, tree_hashes = deserialize_as_tuples( @@ -55,20 +58,15 @@ def __init__( self.tree_hashes = tree_hashes self.index = index self._cached_sha256_treehash = self.tree_hashes[index] + start, end, atom_offset = self.int_tuples[self.index] + if self.blob[start] == 0xFF: + self.atom = None + else: + self.atom = bytes(self.blob[start + atom_offset : end]) + self._pair = None @property - def atom(self) -> Optional[bytes]: - if not hasattr(self, "_atom"): - start, end, atom_offset = self.int_tuples[self.index] - # if `self.blob[start]` is 0xff, it's a pair - if self.blob[start] == 0xFF: - self._atom = None - else: - self._atom = bytes(self.blob[start + atom_offset : end]) - return self._atom - - @property - def pair(self) -> Optional[Tuple["CLVMTree", "CLVMTree"]]: + def pair(self) -> Optional[Tuple["CLVMStorage", "CLVMStorage"]]: if not hasattr(self, "_pair"): tuples, tree_hashes = self.int_tuples, self.tree_hashes start, end, right_index = tuples[self.index] diff --git a/wheel/python/clvm_rs/deser.py b/wheel/python/clvm_rs/deser.py index 817710af..ac1b7b24 100644 --- a/wheel/python/clvm_rs/deser.py +++ b/wheel/python/clvm_rs/deser.py @@ -1,5 +1,6 @@ -from typing import List, Optional, Tuple +from typing import Callable, List, Optional, Tuple, Union +from .base import CLVMStorage from .tree_hash import shatree_atom, shatree_pair try: @@ -15,6 +16,9 @@ # ATOM: serialize_offset, serialize_end, atom_offset # PAIR: serialize_offset, serialize_end, right_index +Triple = Tuple[int, int, int] +DeserOp = Callable[[bytes, int, List[Triple], List], int] + def deserialize_as_tuples( blob: bytes, cursor: int, calculate_tree_hash: bool @@ -26,32 +30,49 @@ def deserialize_as_tuples( hashes = [None] * len(tree) return tree, hashes - def save_cursor(index, blob, cursor, obj_list, op_stack): - assert blob[obj_list[index][0]] == 0xFF + def save_cursor( + index: int, + blob: bytes, + cursor: int, + obj_list: List[Triple], + op_stack: List[DeserOp], + ) -> int: + blob_index = obj_list[index][0] + assert blob[blob_index] == 0xFF left_hash = tree_hash_list[index + 1] - right_hash = tree_hash_list[obj_list[index][2]] + hash_index = obj_list[index][2] + right_hash = tree_hash_list[hash_index] tree_hash_list[index] = None if calculate_tree_hash: + assert left_hash is not None + assert right_hash is not None tree_hash_list[index] = shatree_pair(left_hash, right_hash) - obj_list[index] = ( - obj_list[index][0], - cursor, - obj_list[index][2], - ) + v0 = obj_list[index][0] + v2 = obj_list[index][2] + obj_list[index] = (v0, cursor, v2) return cursor - def save_index(index, blob, cursor, obj_list, op_stack): - obj_list[index][2] = len(obj_list) + def save_index( + index: int, + blob: bytes, + cursor: int, + obj_list: List[Triple], + op_stack: List[DeserOp], + ) -> int: + e = obj_list[index] + obj_list[index] = (e[0], e[1], len(obj_list)) return cursor - def parse_obj(blob, cursor, obj_list, op_stack): + def parse_obj( + blob: bytes, cursor: int, obj_list: List[Triple], op_stack: List[DeserOp] + ) -> int: if cursor >= len(blob): raise ValueError("bad encoding") if blob[cursor] == CONS_BOX_MARKER: index = len(obj_list) tree_hash_list.append(None) - obj_list.append([cursor, None, None]) + obj_list.append((cursor, 0, 0)) op_stack.append(lambda *args: save_cursor(index, *args)) op_stack.append(parse_obj) op_stack.append(lambda *args: save_index(index, *args)) @@ -60,14 +81,14 @@ def parse_obj(blob, cursor, obj_list, op_stack): atom_offset, new_cursor = _atom_size_from_cursor(blob, cursor) my_hash = None if calculate_tree_hash: - my_hash = shatree_atom(blob[cursor + atom_offset:new_cursor]) + my_hash = shatree_atom(blob[cursor + atom_offset : new_cursor]) tree_hash_list.append(my_hash) obj_list.append((cursor, new_cursor, atom_offset)) return new_cursor - obj_list = [] - tree_hash_list = [] - op_stack = [parse_obj] + obj_list: List[Triple] = [] + tree_hash_list: List[Optional[bytes]] = [] + op_stack: List[DeserOp] = [parse_obj] while op_stack: f = op_stack.pop() cursor = f(blob, cursor, obj_list, op_stack) @@ -89,7 +110,7 @@ def _atom_size_from_cursor(blob, cursor) -> Tuple[int, int]: bit_mask >>= 1 size_blob = bytes([b]) if bit_count > 1: - size_blob += blob[cursor + 1:cursor + bit_count] + size_blob += blob[cursor + 1 : cursor + bit_count] size = int.from_bytes(size_blob, "big") new_cursor = cursor + size + bit_count if new_cursor > len(blob): diff --git a/wheel/python/clvm_rs/program.py b/wheel/python/clvm_rs/program.py index 6aba803c..69909a26 100644 --- a/wheel/python/clvm_rs/program.py +++ b/wheel/python/clvm_rs/program.py @@ -1,8 +1,8 @@ from __future__ import annotations from typing import Dict, Iterator, List, Tuple, Optional, Any -from .base import CLVMObject -from .casts import to_clvm_object, int_to_bytes +from .base import CLVMStorage +from .casts import to_clvm_object, int_from_bytes, int_to_bytes from .EvalError import EvalError from clvm_rs.clvm_rs import run_serialized_program from clvm_rs.serialize import sexp_from_stream, sexp_to_stream, sexp_to_bytes @@ -15,7 +15,7 @@ MAX_COST = 0x7FFFFFFFFFFFFFFF -class Program(CLVMObject): +class Program(CLVMStorage): """ A thin wrapper around s-expression data intended to be invoked with "eval". """ @@ -41,10 +41,7 @@ def from_bytes_with_cursor( cls, blob: bytes, cursor: int, calculate_tree_hash: bool = True ) -> Tuple[Program, int]: tree = CLVMTree.from_bytes(blob, calculate_tree_hash=calculate_tree_hash) - if tree.atom is not None: - obj = cls.new_atom(tree.atom) - else: - obj = cls.wrap(tree) + obj = cls.wrap(tree) new_cursor = len(bytes(tree)) + cursor return obj, new_cursor @@ -71,45 +68,49 @@ def int_to_bytes(cls, i: int) -> bytes: # high level casting with `.to` - def __init__(self) -> Program: + def __init__(self): self.atom = b"" - self.pair = None - self.wrapped = self + self._pair = None + self._unwrapped_pair = None self._cached_sha256_treehash = None + @property + def pair(self) -> Optional[Tuple["Program", "Program"]]: + if self._pair is None and self.atom is None: + pair = self._unwrapped_pair + self._pair = (self.wrap(pair[0]), self.wrap(pair[1])) + return self._pair + @classmethod def to(cls, v: Any) -> Program: return cls.wrap(to_clvm_object(v, cls.new_atom, cls.new_pair)) @classmethod - def wrap(cls, v: CLVMObject) -> Program: + def wrap(cls, v: CLVMStorage) -> Program: if isinstance(v, Program): return v o = cls() o.atom = v.atom - o.pair = v.pair - o.wrapped = v + o._pair = None + o._unwrapped_pair = v.pair return o - def unwrap(self) -> CLVMObject: - return self.wrapped - # new object creation on the python heap @classmethod def new_atom(cls, v: bytes) -> Program: o = cls() o.atom = bytes(v) - o.pair = None - o.wrapped = o + o._pair = None + o._unwrapped_pair = None return o @classmethod - def new_pair(cls, left: CLVMObject, right: CLVMObject) -> Program: + def new_pair(cls, left: CLVMStorage, right: CLVMStorage) -> Program: o = cls() o.atom = None - o.pair = (left, right) - o.wrapped = o + o._pair = None + o._unwrapped_pair = (left, right) return o @classmethod @@ -125,14 +126,18 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}({str(self)})" def __eq__(self, other) -> bool: - stack = [(self, Program.to(other))] + stack: List[Tuple[CLVMStorage, CLVMStorage]] = [(self, Program.to(other))] while stack: p1, p2 = stack.pop() if p1.atom is None: if p2.atom is not None: return False - stack.append((p1.pair[1], p2.pair[1])) - stack.append((p1.pair[0], p2.pair[0])) + pair_1 = p1.pair + pair_2 = p2.pair + assert pair_1 is not None + assert pair_2 is not None + stack.append((pair_1[1], pair_2[1])) + stack.append((pair_1[0], pair_2[0])) else: if p1.atom != p2.atom: return False @@ -143,18 +148,16 @@ def __ne__(self, other) -> bool: def first(self) -> Optional[Program]: if self.pair: - return self.wrap(self.pair[0]) + return self.pair[0] return None def rest(self) -> Optional[Program]: if self.pair: - return self.wrap(self.pair[1]) + return self.pair[1] return None def as_pair(self) -> Optional[Tuple[Program, Program]]: - if self.pair: - return tuple(self.wrap(_) for _ in self.pair) - return None + return self.pair def as_atom(self) -> Optional[bytes]: return self.atom @@ -167,13 +170,13 @@ def nullp(self) -> bool: def list_len(self) -> int: c = 0 - v = self - while v.pair: + v: CLVMStorage = self + while v.pair is not None: v = v.pair[1] c += 1 return c - def at(self, position: str) -> "Program": + def at(self, position: str) -> Optional["Program"]: """ Take a string of `f` and `r` characters and follow that path. @@ -192,8 +195,10 @@ def at(self, position: str) -> "Program": ``` """ - v = self + v: Optional[Program] = self for c in position.lower(): + if v is None: + return v if c == "f": v = v.first() elif c == "r": @@ -230,13 +235,13 @@ def replace(self, **kwargs) -> "Program": return _replace(self, **kwargs) def tree_hash(self) -> bytes32: - return sha256_treehash(self.unwrap()) + return sha256_treehash(self) def run_with_cost(self, args, max_cost: int = MAX_COST) -> Tuple[int, "Program"]: prog_bytes = bytes(self) args_bytes = bytes(self.to(args)) cost, r = run_serialized_program(prog_bytes, args_bytes, max_cost, 0) - r = Program.to(r) + r = self.wrap(r) if isinstance(cost, str): raise EvalError(cost, r) return cost, r @@ -245,38 +250,58 @@ def run(self, args) -> "Program": cost, r = self.run_with_cost(args, MAX_COST) return r - # Replicates the curry function from clvm_tools, taking advantage of *args - # being a list. We iterate through args in reverse building the code to - # create a clvm list. - # - # Given arguments to a function addressable by the '1' reference in clvm - # - # fixed_args = 1 - # - # Each arg is prepended as fixed_args = (c (q . arg) fixed_args) - # - # The resulting argument list is interpreted with apply (2) - # - # (2 (1 . self) rest) - # - # Resulting in a function which places its own arguments after those - # curried in in the form of a proper list. + """ + Replicates the curry function from clvm_tools, taking advantage of *args + being a list. We iterate through args in reverse building the code to + create a clvm list. + + Given arguments to a function addressable by the '1' reference in clvm + + fixed_args = 1 + + Each arg is prepended as fixed_args = (c (q . arg) fixed_args) + + The resulting argument list is interpreted with apply (2) + + (2 (1 . self) rest) + + Resulting in a function which places its own arguments after those + curried in in the form of a proper list. + """ + def curry(self, *args) -> "Program": fixed_args: Any = 1 for arg in reversed(args): fixed_args = [4, (1, arg), fixed_args] - return Program.to([2, (1, self), fixed_args]) + return self.to([2, (1, self), fixed_args]) + + """ + uncurry the given program - def uncurry(self) -> Tuple[Program, Optional[Program]]: + returns `mod, [arg1, arg2, ...]` + + if the program is not a valid curry, return `self, NULL` + + This distinguishes it from the case of a valid curry of 0 arguments + (which is rather pointless but possible), which returns `self, []` + """ + + def uncurry(self) -> Tuple[Program, Optional[List[Program]]]: if self.at("f") != A_KW or self.at("rff") != Q_KW or self.at("rrr") != NULL: return self, None uncurried_function = self.at("rfr") + if uncurried_function is None: + return self, None core_items = [] core = self.at("rrf") while core != ONE: + if core is None: + return self, None if core.at("f") != C_KW or core.at("rff") != Q_KW or core.at("rrr") != NULL: return self, None new_item = core.at("rfr") + if new_item is None: + return self, None core_items.append(new_item) core = core.at("rrf") return uncurried_function, core_items @@ -335,12 +360,12 @@ def _replace(program: Program, **kwargs) -> Program: # Now split `kwargs` into two groups: those # that start with `f` and those that start with `r` - args_by_prefix: Dict[str, Program] = {} + args_by_prefix: Dict[str, Dict[str, Program]] = dict(f={}, r={}) for k, v in kwargs.items(): c = k[0] if c not in "fr": raise ValueError(f"bad path containing {c}: must only contain `f` and `r`") - args_by_prefix.setdefault(c, dict())[k[1:]] = v + args_by_prefix[c][k[1:]] = program.to(v) pair = program.pair if pair is None: @@ -350,11 +375,4 @@ def _replace(program: Program, **kwargs) -> Program: new_f = _replace(pair[0], **args_by_prefix.get("f", {})) new_r = _replace(pair[1], **args_by_prefix.get("r", {})) - return program.new_pair(Program.to(new_f), Program.to(new_r)) - - -def int_from_bytes(blob): - size = len(blob) - if size == 0: - return 0 - return int.from_bytes(blob, "big", signed=True) + return program.new_pair(new_f, new_r) diff --git a/wheel/python/clvm_rs/run_program.py b/wheel/python/clvm_rs/run_program.py deleted file mode 100644 index 70932636..00000000 --- a/wheel/python/clvm_rs/run_program.py +++ /dev/null @@ -1,27 +0,0 @@ -from typing import Tuple - -from clvm_rs import CLVMObject, run_serialized_program, NO_NEG_DIV - - -from .EvalError import EvalError -from .serialize import sexp_to_bytes - - -DEFAULT_MAX_COST = (1 << 64) - 1 -DEFAULT_FLAGS = NO_NEG_DIV - - -def run_program( - program: CLVMObject, - args: CLVMObject, - max_cost=DEFAULT_MAX_COST, - flags=DEFAULT_FLAGS, -) -> Tuple[int, CLVMObject]: - program_blob = sexp_to_bytes(program) - args_blob = sexp_to_bytes(args) - cost_or_err_str, result = run_serialized_program( - program_blob, args_blob, max_cost, flags - ) - if isinstance(cost_or_err_str, str): - raise EvalError(cost_or_err_str, result) - return cost_or_err_str, result diff --git a/wheel/python/clvm_rs/tree_hash.py b/wheel/python/clvm_rs/tree_hash.py index e71ef098..2d38d49b 100644 --- a/wheel/python/clvm_rs/tree_hash.py +++ b/wheel/python/clvm_rs/tree_hash.py @@ -7,8 +7,9 @@ """ from hashlib import sha256 +from typing import List -from clvm_rs import CLVMObject +from clvm_rs.base import CLVMStorage bytes32 = bytes @@ -32,43 +33,36 @@ def shatree_pair(left_hash: bytes32, right_hash: bytes32) -> bytes32: return bytes32(s.digest()) -def sha256_treehash(sexp: CLVMObject) -> bytes32: - def handle_sexp(sexp_stack, op_stack) -> None: +def sha256_treehash(sexp: CLVMStorage) -> bytes32: + def handle_sexp(sexp_stack, hash_stack, op_stack) -> None: sexp = sexp_stack.pop() r = getattr(sexp, "_cached_sha256_treehash", None) if r is not None: - sexp_stack.append(r) + hash_stack.append(r) + return elif sexp.pair: p0, p1 = sexp.pair sexp_stack.append(p0) sexp_stack.append(p1) op_stack.append(handle_pair) op_stack.append(handle_sexp) - op_stack.append(roll) op_stack.append(handle_sexp) else: r = shatree_atom(sexp.atom) - sexp_stack.append(r) - if hasattr(sexp, "_cached_sha256_treehash"): - sexp._cached_sha256_treehash = r - - def handle_pair(sexp_stack, op_stack) -> None: - p0 = sexp_stack.pop() - p1 = sexp_stack.pop() - r = shatree_pair(p0, p1) - sexp_stack.append(r) - if hasattr(sexp, "_cached_sha256_treehash"): + hash_stack.append(r) sexp._cached_sha256_treehash = r - def roll(sexp_stack, op_stack) -> None: - p0 = sexp_stack.pop() - p1 = sexp_stack.pop() - sexp_stack.append(p0) - sexp_stack.append(p1) + def handle_pair(sexp_stack, hash_stack, op_stack) -> None: + p0 = hash_stack.pop() + p1 = hash_stack.pop() + r = shatree_pair(p0, p1) + hash_stack.append(r) + sexp._cached_sha256_treehash = r sexp_stack = [sexp] op_stack = [handle_sexp] + hash_stack: List[bytes32] = [] while len(op_stack) > 0: op = op_stack.pop() - op(sexp_stack, op_stack) - return bytes32(sexp_stack[0]) + op(sexp_stack, hash_stack, op_stack) + return hash_stack[0] diff --git a/wheel/python/tests/test_program.py b/wheel/python/tests/test_program.py index 2299f66a..26f13d35 100644 --- a/wheel/python/tests/test_program.py +++ b/wheel/python/tests/test_program.py @@ -1,6 +1,6 @@ from unittest import TestCase -from clvm_rs import CLVMObject +from clvm_rs.base import CLVMStorage from clvm_rs.program import Program from clvm_rs.EvalError import EvalError @@ -45,11 +45,6 @@ def test_replace_bad_path(self): self.assertRaises(ValueError, lambda: p1.replace(q=105)) self.assertRaises(ValueError, lambda: p1.replace(rq=105)) - def test_protocol(self): - nil = Program.to(0) - self.assertRaises(NotImplementedError, lambda: CLVMObject.new_atom(nil)) - self.assertRaises(NotImplementedError, lambda: CLVMObject.new_pair(nil, nil)) - def test_first_rest(self): p = Program.to([4, 5]) self.assertEqual(p.first(), 4) diff --git a/wheel/python/tests/test_serialize.py b/wheel/python/tests/test_serialize.py index 20bec45f..3acee12c 100644 --- a/wheel/python/tests/test_serialize.py +++ b/wheel/python/tests/test_serialize.py @@ -129,8 +129,9 @@ def test_deserialize_large_blob(self): Program.parse(InfiniteStream(bytes_in)) def test_repr_clvm_tree(self): - o = Program.fromhex("ff8080") - self.assertEqual(repr(o.unwrap()), "") + o = Program.fromhex("ff8085") + self.assertEqual(repr(o._unwrapped_pair[0]), "") + self.assertEqual(repr(o._unwrapped_pair[1]), "") def test_bad_blob(self): self.assertRaises(OSError, lambda: Program.fromhex("ff")) diff --git a/wheel/python/tests/test_to_program.py b/wheel/python/tests/test_to_program.py index b95ec0e3..1d1cffba 100644 --- a/wheel/python/tests/test_to_program.py +++ b/wheel/python/tests/test_to_program.py @@ -1,6 +1,7 @@ import unittest from typing import Optional, Tuple, Any +from clvm_rs.base import CLVMStorage from clvm_rs.program import Program @@ -8,11 +9,6 @@ def convert_atom_to_bytes(castable: Any) -> Optional[bytes]: return Program.to(castable).atom -def looks_like_clvm_object(o: Any) -> bool: - d = dir(o) - return "atom" in d and "pair" in d - - def validate_program(program): validate_stack = [program] while validate_stack: @@ -21,8 +17,8 @@ def validate_program(program): if v.pair: assert isinstance(v.pair, tuple) v1, v2 = v.pair - assert looks_like_clvm_object(v1) - assert looks_like_clvm_object(v2) + assert isinstance(v1, CLVMStorage) + assert isinstance(v2, CLVMStorage) s1, s2 = v.as_pair() validate_stack.append(s1) validate_stack.append(s2) @@ -67,7 +63,7 @@ def test_cast_1(self): validate_program(t1) def test_wrap_program(self): - # it's a bit of a layer violation that CLVMObject unwraps Program, but we + # it's a bit of a layer violation that CLVMStorage unwraps Program, but we # rely on that in a fair number of places for now. We should probably # work towards phasing that out o = Program.to(Program.to(1)) @@ -76,7 +72,7 @@ def test_wrap_program(self): def test_arbitrary_underlying_tree(self): # Program provides a view on top of a tree of arbitrary types, as long as - # those types implement the CLVMObject protocol. This is an example of + # those types implement the CLVMStorage protocol. This is an example of # a tree that's generated class GeneratedTree: @@ -87,6 +83,7 @@ def __init__(self, depth, val): assert depth >= 0 self.depth = depth self.val = val + self._cached_sha256_treehash = None @property def atom(self) -> Optional[bytes]: @@ -128,18 +125,26 @@ class dummy: pass obj = dummy() - obj.pair = None obj.atom = None + obj.pair = None + obj._cached_sha256_treehash = None print(dir(obj)) - assert looks_like_clvm_object(obj) + assert isinstance(obj, CLVMStorage) obj = dummy() obj.pair = None - assert not looks_like_clvm_object(obj) + obj._cached_sha256_treehash = None + assert not isinstance(obj, CLVMStorage) obj = dummy() obj.atom = None - assert not looks_like_clvm_object(obj) + obj._cached_sha256_treehash = None + assert not isinstance(obj, CLVMStorage) + + obj = dummy() + obj.atom = None + obj.pair = None + assert not isinstance(obj, CLVMStorage) def test_list_conversions(self): a = Program.to([1, 2, 3]) From 48b2b74b9bf2af27d8fe16f2d43c5df6157865be Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Thu, 19 Jan 2023 16:07:22 -0800 Subject: [PATCH 15/45] Rename `base`. --- wheel/python/clvm_rs/base.py | 27 --------------------------- wheel/python/clvm_rs/casts.py | 2 +- wheel/python/clvm_rs/clvm_tree.py | 2 +- wheel/python/clvm_rs/deser.py | 2 +- wheel/python/clvm_rs/program.py | 2 +- wheel/python/clvm_rs/tree_hash.py | 2 +- wheel/python/tests/test_program.py | 2 +- wheel/python/tests/test_to_program.py | 2 +- 8 files changed, 7 insertions(+), 34 deletions(-) delete mode 100644 wheel/python/clvm_rs/base.py diff --git a/wheel/python/clvm_rs/base.py b/wheel/python/clvm_rs/base.py deleted file mode 100644 index fa0c33fd..00000000 --- a/wheel/python/clvm_rs/base.py +++ /dev/null @@ -1,27 +0,0 @@ -from typing import Optional, Protocol, Tuple, runtime_checkable - - -@runtime_checkable -class CLVMStorage(Protocol): - atom: Optional[bytes] - _cached_sha256_treehash: Optional[bytes] - - @property - def pair(self) -> Optional[Tuple["CLVMStorage", "CLVMStorage"]]: - ... - - # optional fields used to speed implementations: - - # `_cached_sha256_treehash: Optional[bytes]` is used by `sha256_treehash` - # `_cached_serialization: bytes` is used by `sexp_to_byte_iterator` to speed up serialization - - -@runtime_checkable -class CLVMStorageFactory(Protocol): - @classmethod - def new_atom(cls, v: bytes) -> "CLVMStorage": - ... - - @classmethod - def new_pair(cls, p1: "CLVMStorage", p2: "CLVMStorage") -> "CLVMStorage": - ... diff --git a/wheel/python/clvm_rs/casts.py b/wheel/python/clvm_rs/casts.py index 31cc21fb..076aae92 100644 --- a/wheel/python/clvm_rs/casts.py +++ b/wheel/python/clvm_rs/casts.py @@ -1,6 +1,6 @@ from typing import Any, Callable, List, Optional, SupportsBytes, Tuple, Union -from .base import CLVMStorage +from .clvm_storage import CLVMStorage AtomCastableType = Union[ bytes, diff --git a/wheel/python/clvm_rs/clvm_tree.py b/wheel/python/clvm_rs/clvm_tree.py index 4099d448..8dceebc2 100644 --- a/wheel/python/clvm_rs/clvm_tree.py +++ b/wheel/python/clvm_rs/clvm_tree.py @@ -1,4 +1,4 @@ -from .base import CLVMStorage +from .clvm_storage import CLVMStorage from .deser import deserialize_as_tuples from typing import List, Optional, Tuple diff --git a/wheel/python/clvm_rs/deser.py b/wheel/python/clvm_rs/deser.py index ac1b7b24..9ea9ea38 100644 --- a/wheel/python/clvm_rs/deser.py +++ b/wheel/python/clvm_rs/deser.py @@ -1,6 +1,6 @@ from typing import Callable, List, Optional, Tuple, Union -from .base import CLVMStorage +from .clvm_storage import CLVMStorage from .tree_hash import shatree_atom, shatree_pair try: diff --git a/wheel/python/clvm_rs/program.py b/wheel/python/clvm_rs/program.py index 69909a26..f0c797e4 100644 --- a/wheel/python/clvm_rs/program.py +++ b/wheel/python/clvm_rs/program.py @@ -1,7 +1,7 @@ from __future__ import annotations from typing import Dict, Iterator, List, Tuple, Optional, Any -from .base import CLVMStorage +from .clvm_storage import CLVMStorage from .casts import to_clvm_object, int_from_bytes, int_to_bytes from .EvalError import EvalError from clvm_rs.clvm_rs import run_serialized_program diff --git a/wheel/python/clvm_rs/tree_hash.py b/wheel/python/clvm_rs/tree_hash.py index 2d38d49b..7fd82ac6 100644 --- a/wheel/python/clvm_rs/tree_hash.py +++ b/wheel/python/clvm_rs/tree_hash.py @@ -9,7 +9,7 @@ from hashlib import sha256 from typing import List -from clvm_rs.base import CLVMStorage +from clvm_rs.clvm_storage import CLVMStorage bytes32 = bytes diff --git a/wheel/python/tests/test_program.py b/wheel/python/tests/test_program.py index 26f13d35..ac6cbd66 100644 --- a/wheel/python/tests/test_program.py +++ b/wheel/python/tests/test_program.py @@ -1,6 +1,6 @@ from unittest import TestCase -from clvm_rs.base import CLVMStorage +from clvm_rs.clvm_storage import CLVMStorage from clvm_rs.program import Program from clvm_rs.EvalError import EvalError diff --git a/wheel/python/tests/test_to_program.py b/wheel/python/tests/test_to_program.py index 1d1cffba..11974136 100644 --- a/wheel/python/tests/test_to_program.py +++ b/wheel/python/tests/test_to_program.py @@ -1,7 +1,7 @@ import unittest from typing import Optional, Tuple, Any -from clvm_rs.base import CLVMStorage +from clvm_rs.clvm_storage import CLVMStorage from clvm_rs.program import Program From 23ab08eb83e4ac63b65b4d7f36985fc58f58b72b Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Thu, 19 Jan 2023 16:09:24 -0800 Subject: [PATCH 16/45] Rename to `eval_error.py`. --- wheel/python/clvm_rs/{EvalError.py => eval_error.py} | 0 wheel/python/clvm_rs/program.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename wheel/python/clvm_rs/{EvalError.py => eval_error.py} (100%) diff --git a/wheel/python/clvm_rs/EvalError.py b/wheel/python/clvm_rs/eval_error.py similarity index 100% rename from wheel/python/clvm_rs/EvalError.py rename to wheel/python/clvm_rs/eval_error.py diff --git a/wheel/python/clvm_rs/program.py b/wheel/python/clvm_rs/program.py index f0c797e4..4512a93f 100644 --- a/wheel/python/clvm_rs/program.py +++ b/wheel/python/clvm_rs/program.py @@ -3,7 +3,7 @@ from .clvm_storage import CLVMStorage from .casts import to_clvm_object, int_from_bytes, int_to_bytes -from .EvalError import EvalError +from .eval_error import EvalError from clvm_rs.clvm_rs import run_serialized_program from clvm_rs.serialize import sexp_from_stream, sexp_to_stream, sexp_to_bytes from clvm_rs.tree_hash import sha256_treehash From a15c2c1a818121c234b0866c3f252f4c373b8ca9 Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Thu, 19 Jan 2023 16:13:07 -0800 Subject: [PATCH 17/45] Rename modules. --- wheel/python/clvm_rs/clvm_tree.py | 2 +- wheel/python/clvm_rs/{deser.py => de.py} | 0 wheel/python/clvm_rs/program.py | 12 ++++++------ wheel/python/clvm_rs/{serialize.py => ser.py} | 0 4 files changed, 7 insertions(+), 7 deletions(-) rename wheel/python/clvm_rs/{deser.py => de.py} (100%) rename wheel/python/clvm_rs/{serialize.py => ser.py} (100%) diff --git a/wheel/python/clvm_rs/clvm_tree.py b/wheel/python/clvm_rs/clvm_tree.py index 8dceebc2..6c894589 100644 --- a/wheel/python/clvm_rs/clvm_tree.py +++ b/wheel/python/clvm_rs/clvm_tree.py @@ -1,5 +1,5 @@ from .clvm_storage import CLVMStorage -from .deser import deserialize_as_tuples +from .de import deserialize_as_tuples from typing import List, Optional, Tuple diff --git a/wheel/python/clvm_rs/deser.py b/wheel/python/clvm_rs/de.py similarity index 100% rename from wheel/python/clvm_rs/deser.py rename to wheel/python/clvm_rs/de.py diff --git a/wheel/python/clvm_rs/program.py b/wheel/python/clvm_rs/program.py index 4512a93f..5e4f5ad0 100644 --- a/wheel/python/clvm_rs/program.py +++ b/wheel/python/clvm_rs/program.py @@ -1,15 +1,15 @@ from __future__ import annotations from typing import Dict, Iterator, List, Tuple, Optional, Any -from .clvm_storage import CLVMStorage +from .bytes32 import bytes32 from .casts import to_clvm_object, int_from_bytes, int_to_bytes -from .eval_error import EvalError -from clvm_rs.clvm_rs import run_serialized_program -from clvm_rs.serialize import sexp_from_stream, sexp_to_stream, sexp_to_bytes -from clvm_rs.tree_hash import sha256_treehash +from .clvm_rs import run_serialized_program +from .clvm_storage import CLVMStorage from .clvm_tree import CLVMTree -from .bytes32 import bytes32 from .keywords import NULL, ONE, Q_KW, A_KW, C_KW +from .eval_error import EvalError +from .ser import sexp_from_stream, sexp_to_stream, sexp_to_bytes +from .tree_hash import sha256_treehash MAX_COST = 0x7FFFFFFFFFFFFFFF diff --git a/wheel/python/clvm_rs/serialize.py b/wheel/python/clvm_rs/ser.py similarity index 100% rename from wheel/python/clvm_rs/serialize.py rename to wheel/python/clvm_rs/ser.py From 31af14631f1020b62312aaf9fa44e0139b9d2666 Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Thu, 19 Jan 2023 17:05:33 -0800 Subject: [PATCH 18/45] Add comments, improve implementations. --- wheel/python/clvm_rs/casts.py | 115 +++++++++++++++----------- wheel/python/clvm_rs/ser.py | 47 ++++++++--- wheel/python/tests/test_program.py | 2 +- wheel/python/tests/test_serialize.py | 2 +- wheel/python/tests/test_to_program.py | 8 -- 5 files changed, 103 insertions(+), 71 deletions(-) diff --git a/wheel/python/clvm_rs/casts.py b/wheel/python/clvm_rs/casts.py index 076aae92..35b74046 100644 --- a/wheel/python/clvm_rs/casts.py +++ b/wheel/python/clvm_rs/casts.py @@ -1,3 +1,7 @@ +""" +Some utilities to cast python types to and from clvm. +""" + from typing import Any, Callable, List, Optional, SupportsBytes, Tuple, Union from .clvm_storage import CLVMStorage @@ -11,6 +15,8 @@ ] +# as of January 2023, mypy does not like this recursive definition + CastableType = Union[ AtomCastableType, List["CastableType"], @@ -19,7 +25,13 @@ ] +NULL_BLOB = b"" + + def int_from_bytes(blob): + """ + Convert a bytes blob encoded as a clvm int to a python int. + """ size = len(blob) if size == 0: return 0 @@ -27,6 +39,9 @@ def int_from_bytes(blob): def int_to_bytes(v) -> bytes: + """ + Convert a python int to a blob that encodes as this integer in clvm. + """ byte_count = (v.bit_length() + 8) >> 3 if v == 0: return b"" @@ -38,11 +53,11 @@ def int_to_bytes(v) -> bytes: return r -NULL = b"" - - def to_atom_type(v: AtomCastableType) -> bytes: - + """ + Convert an `AtomCastableType` to `bytes`. This for use with the + convenience function `Program.to`. + """ if isinstance(v, bytes): return v if isinstance(v, str): @@ -52,24 +67,37 @@ def to_atom_type(v: AtomCastableType) -> bytes: if isinstance(v, (memoryview, SupportsBytes)): return bytes(v) if v is None: - return NULL + return NULL_BLOB raise ValueError("can't cast %s (%s) to bytes" % (type(v), v)) def to_clvm_object( - v: CastableType, + castable: CastableType, to_atom_f: Callable[[bytes], CLVMStorage], to_pair_f: Callable[[CLVMStorage, CLVMStorage], CLVMStorage], ): - stack: List[CastableType] = [v] - ops: List[Tuple[int, Optional[CastableType]]] = [(0, None)] # convert + """ + Convert a python object to clvm object. + + This works on nested tuples and lists of potentially unlimited depth. + It is non-recursive, so nesting depth is not limited by the call stack. + """ + to_convert: List[CastableType] = [castable] + did_convert: List[CLVMStorage] = [] + ops: List[int] = [0] + + # operations: + # 0: pop `to_convert` and convert if possible, storing result on `did_convert`, + # or subdivide task, pushing multiple things on `to_convert` (and new ops) + # 1: pop two items from `did_convert` and cons them, pushing result to `did_convert` + # 2: same as 1 but cons in opposite order. Necessary for converting lists while len(ops) > 0: - op, target = ops.pop() + op = ops.pop() # convert value if op == 0: - v = stack.pop() + v = to_convert.pop() if isinstance(v, CLVMStorage): if v.pair is None: atom = v.atom @@ -77,59 +105,50 @@ def to_clvm_object( new_obj = to_atom_f(to_atom_type(atom)) else: new_obj = to_pair_f(v.pair[0], v.pair[1]) - stack.append(new_obj) + did_convert.append(new_obj) continue if isinstance(v, tuple): if len(v) != 2: raise ValueError("can't cast tuple of size %d" % len(v)) left, right = v - target = len(stack) ll_right = isinstance(right, CLVMStorage) ll_left = isinstance(left, CLVMStorage) if ll_right and ll_left: - stack.append(to_pair_f(left, right)) + did_convert.append(to_pair_f(left, right)) else: - ops.append((3, None)) # cons - stack.append(right) - ops.append((0, None)) # convert - ops.append((2, None)) # roll - stack.append(left) - ops.append((0, None)) # convert + ops.append(1) # cons + to_convert.append(left) + ops.append(0) # convert + to_convert.append(right) + ops.append(0) # convert continue if isinstance(v, list): - target = len(stack) - stack.append(to_atom_f(NULL)) for _ in v: - stack.append(_) - ops.append((1, target)) # prepend list - # we only need to convert if it's not already the right - # type - if not isinstance(_, CLVMStorage): - ops.append((0, None)) # convert - continue - stack.append(to_atom_f(to_atom_type(v))) - continue + ops.append(2) # rcons - if op == 1: # prepend list - left = stack.pop() - assert isinstance(target, int) - right = stack[target] - stack[target] = to_pair_f(left, right) + # add and convert the null terminator + to_convert.append(to_atom_f(NULL_BLOB)) + ops.append(0) # convert + + for _ in reversed(v): + to_convert.append(_) + ops.append(0) # convert + continue + did_convert.append(to_atom_f(to_atom_type(v))) continue - if op == 2: # roll - p1 = stack.pop() - p2 = stack.pop() - stack.append(p1) - stack.append(p2) + if op == 1: # cons + left = did_convert.pop() + right = did_convert.pop() + obj = to_pair_f(left, right) + did_convert.append(obj) continue - if op == 3: # cons - right = stack.pop() - left = stack.pop() + if op == 2: # rcons + right = did_convert.pop() + left = did_convert.pop() obj = to_pair_f(left, right) - stack.append(obj) + did_convert.append(obj) continue - # there's exactly one item left at this point - assert len(stack) == 1 - # stack[0] implements the clvm object protocol - return stack[0] + # there's exactly one item left at this point + assert len(did_convert) == 1 + return did_convert[0] diff --git a/wheel/python/clvm_rs/ser.py b/wheel/python/clvm_rs/ser.py index b6033ef2..967f7588 100644 --- a/wheel/python/clvm_rs/ser.py +++ b/wheel/python/clvm_rs/ser.py @@ -1,3 +1,6 @@ +""" +Serialize clvm. + # decoding: # read a byte # if it's 0xfe, it's nil (which might be same as 0) @@ -10,13 +13,21 @@ # 0xe0-0xef is 3 bytes (`and` of first byte with 0xf) # 0xf0-0xf7 is 4 bytes (`and` of first byte with 0x7) # 0xf7-0xfb is 5 bytes (`and` of first byte with 0x3) +""" + +from typing import BinaryIO, Iterator, List + +from .clvm_storage import CLVMStorage MAX_SINGLE_BYTE = 0x7F CONS_BOX_MARKER = 0xFF -def sexp_to_byte_iterator(sexp): +def sexp_to_byte_iterator(sexp: CLVMStorage) -> Iterator[bytes]: + """ + Yields bytes that serialize the given clvm object. Non-recursive + """ todo_stack = [sexp] while todo_stack: sexp = todo_stack.pop() @@ -30,10 +41,15 @@ def sexp_to_byte_iterator(sexp): todo_stack.append(pair[1]) todo_stack.append(pair[0]) else: - yield from atom_to_byte_iterator(sexp.atom) + atom = sexp.atom + assert atom is not None + yield from atom_to_byte_iterator(atom) -def atom_to_byte_iterator(as_atom): +def atom_to_byte_iterator(as_atom: bytes) -> Iterator[bytes]: + """ + Yield the serialization for a given blob (as a clvm atom). + """ size = len(as_atom) if size == 0: yield b"\x80" @@ -68,18 +84,21 @@ def atom_to_byte_iterator(as_atom): ] ) else: - raise ValueError("sexp too long %s" % as_atom) + raise ValueError("sexp too long %r" % as_atom) yield size_blob yield as_atom -def sexp_to_stream(sexp, f): +def sexp_to_stream(sexp: CLVMStorage, f: BinaryIO) -> None: + """ + Serialize to a file. + """ for b in sexp_to_byte_iterator(sexp): f.write(b) -def sexp_to_bytes(sexp) -> bytes: +def sexp_to_bytes(sexp: CLVMStorage) -> bytes: b = bytearray() for _ in sexp_to_byte_iterator(sexp): b.extend(_) @@ -99,15 +118,17 @@ def _op_read_sexp(op_stack, val_stack, f, new_pair_f, new_atom_f): val_stack.append(_atom_from_stream(f, b, new_atom_f)) -def _op_cons(op_stack, val_stack, f, new_pair_f, new_atom_f): +def _op_cons( + op_stack, val_stack: List[CLVMStorage], f: BinaryIO, new_pair_f, new_atom_f +): right = val_stack.pop() left = val_stack.pop() val_stack.append(new_pair_f(left, right)) -def sexp_from_stream(f, new_pair_f, new_atom_f): +def sexp_from_stream(f: BinaryIO, new_pair_f, new_atom_f): op_stack = [_op_read_sexp] - val_stack = [] + val_stack: List[CLVMStorage] = [] while op_stack: func = op_stack.pop() @@ -115,7 +136,7 @@ def sexp_from_stream(f, new_pair_f, new_atom_f): return val_stack.pop() -def _atom_from_stream(f, b, new_atom_f): +def _atom_from_stream(f: BinaryIO, b: int, new_atom_f): if b == 0x80: return new_atom_f(b"") if b <= MAX_SINGLE_BYTE: @@ -128,10 +149,10 @@ def _atom_from_stream(f, b, new_atom_f): bit_mask >>= 1 size_blob = bytes([b]) if bit_count > 1: - b = f.read(bit_count - 1) - if len(b) != bit_count - 1: + blob = f.read(bit_count - 1) + if len(blob) != bit_count - 1: raise ValueError("bad encoding") - size_blob += b + size_blob += blob size = int.from_bytes(size_blob, "big") if size >= 0x400000000: raise ValueError("blob too large") diff --git a/wheel/python/tests/test_program.py b/wheel/python/tests/test_program.py index ac6cbd66..8774a2b3 100644 --- a/wheel/python/tests/test_program.py +++ b/wheel/python/tests/test_program.py @@ -2,7 +2,7 @@ from clvm_rs.clvm_storage import CLVMStorage from clvm_rs.program import Program -from clvm_rs.EvalError import EvalError +from clvm_rs.eval_error import EvalError from clvm_rs.keywords import A_KW, C_KW, Q_KW diff --git a/wheel/python/tests/test_serialize.py b/wheel/python/tests/test_serialize.py index 3acee12c..7363ce0e 100644 --- a/wheel/python/tests/test_serialize.py +++ b/wheel/python/tests/test_serialize.py @@ -2,7 +2,7 @@ import unittest from clvm_rs.program import Program -from clvm_rs.serialize import atom_to_byte_iterator +from clvm_rs.ser import atom_to_byte_iterator TEXT = b"the quick brown fox jumps over the lazy dogs" diff --git a/wheel/python/tests/test_to_program.py b/wheel/python/tests/test_to_program.py index 11974136..49daa7ae 100644 --- a/wheel/python/tests/test_to_program.py +++ b/wheel/python/tests/test_to_program.py @@ -127,23 +127,15 @@ class dummy: obj = dummy() obj.atom = None obj.pair = None - obj._cached_sha256_treehash = None print(dir(obj)) assert isinstance(obj, CLVMStorage) obj = dummy() obj.pair = None - obj._cached_sha256_treehash = None assert not isinstance(obj, CLVMStorage) obj = dummy() obj.atom = None - obj._cached_sha256_treehash = None - assert not isinstance(obj, CLVMStorage) - - obj = dummy() - obj.atom = None - obj.pair = None assert not isinstance(obj, CLVMStorage) def test_list_conversions(self): From 0107e8a1ebd41326405daf9d6070ccc700d8a3c0 Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Fri, 20 Jan 2023 14:49:26 -0800 Subject: [PATCH 19/45] Remove `keywords.py`. --- wheel/python/clvm_rs/curry_and_treehash.py | 2 +- wheel/python/clvm_rs/keywords.py | 6 ------ wheel/python/clvm_rs/program.py | 6 +++--- wheel/python/tests/test_program.py | 3 +-- wheel/python/tests/test_to_program.py | 2 ++ 5 files changed, 7 insertions(+), 12 deletions(-) delete mode 100644 wheel/python/clvm_rs/keywords.py diff --git a/wheel/python/clvm_rs/curry_and_treehash.py b/wheel/python/clvm_rs/curry_and_treehash.py index d53f9c2f..681b601a 100644 --- a/wheel/python/clvm_rs/curry_and_treehash.py +++ b/wheel/python/clvm_rs/curry_and_treehash.py @@ -1,7 +1,7 @@ from typing import List from .bytes32 import bytes32 -from .keywords import NULL, ONE, Q_KW, A_KW, C_KW +from .chia_dialect import NULL, ONE, Q_KW, A_KW, C_KW from .tree_hash import shatree_atom, shatree_pair diff --git a/wheel/python/clvm_rs/keywords.py b/wheel/python/clvm_rs/keywords.py deleted file mode 100644 index f246b1e3..00000000 --- a/wheel/python/clvm_rs/keywords.py +++ /dev/null @@ -1,6 +0,0 @@ -NULL = bytes.fromhex("") -ONE = bytes.fromhex("01") -TWO = bytes.fromhex("02") -Q_KW = bytes.fromhex("01") -A_KW = bytes.fromhex("02") -C_KW = bytes.fromhex("04") diff --git a/wheel/python/clvm_rs/program.py b/wheel/python/clvm_rs/program.py index 5e4f5ad0..ba8e907a 100644 --- a/wheel/python/clvm_rs/program.py +++ b/wheel/python/clvm_rs/program.py @@ -3,10 +3,10 @@ from .bytes32 import bytes32 from .casts import to_clvm_object, int_from_bytes, int_to_bytes +from .chia_dialect import NULL, ONE, Q_KW, A_KW, C_KW from .clvm_rs import run_serialized_program from .clvm_storage import CLVMStorage from .clvm_tree import CLVMTree -from .keywords import NULL, ONE, Q_KW, A_KW, C_KW from .eval_error import EvalError from .ser import sexp_from_stream, sexp_to_stream, sexp_to_bytes from .tree_hash import sha256_treehash @@ -23,10 +23,10 @@ class Program(CLVMStorage): # serialization/deserialization @classmethod - def parse(cls, f) -> Program: + def parse(cls, f: BinaryIO) -> Program: return sexp_from_stream(f, cls.new_pair, cls.new_atom) - def stream(self, f): + def stream(self, f: BinaryIO) -> None: sexp_to_stream(self, f) @classmethod diff --git a/wheel/python/tests/test_program.py b/wheel/python/tests/test_program.py index 8774a2b3..a5d8ca46 100644 --- a/wheel/python/tests/test_program.py +++ b/wheel/python/tests/test_program.py @@ -1,11 +1,10 @@ from unittest import TestCase +from clvm_rs.chia_dialect import A_KW, C_KW, Q_KW from clvm_rs.clvm_storage import CLVMStorage from clvm_rs.program import Program from clvm_rs.eval_error import EvalError -from clvm_rs.keywords import A_KW, C_KW, Q_KW - class TestProgram(TestCase): def test_at(self): diff --git a/wheel/python/tests/test_to_program.py b/wheel/python/tests/test_to_program.py index 49daa7ae..4965f045 100644 --- a/wheel/python/tests/test_to_program.py +++ b/wheel/python/tests/test_to_program.py @@ -149,6 +149,8 @@ def test_string_conversions(self): def test_int_conversions(self): a = Program.to(1337) assert a.as_atom() == bytes([0x5, 0x39]) + a = Program.to(-128) + assert a.as_atom() == bytes([0x80]) def test_none_conversions(self): a = Program.to(None) From f2162a0b9da2926fad7a258ba07c3609895959c4 Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Mon, 23 Jan 2023 17:31:59 -0800 Subject: [PATCH 20/45] Various improvements to python. --- wheel/python/clvm_rs/__init__.py | 2 ++ wheel/python/clvm_rs/eval_error.py | 2 +- wheel/python/clvm_rs/program.py | 11 ++++++----- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/wheel/python/clvm_rs/__init__.py b/wheel/python/clvm_rs/__init__.py index e69de29b..fcd136c9 100644 --- a/wheel/python/clvm_rs/__init__.py +++ b/wheel/python/clvm_rs/__init__.py @@ -0,0 +1,2 @@ +from .eval_error import EvalError +from .program import Program diff --git a/wheel/python/clvm_rs/eval_error.py b/wheel/python/clvm_rs/eval_error.py index f71f912a..5174e324 100644 --- a/wheel/python/clvm_rs/eval_error.py +++ b/wheel/python/clvm_rs/eval_error.py @@ -1,4 +1,4 @@ -class EvalError(Exception): +class EvalError(ValueError): def __init__(self, message: str, sexp): super().__init__(message) self._sexp = sexp diff --git a/wheel/python/clvm_rs/program.py b/wheel/python/clvm_rs/program.py index ba8e907a..e88cfd5e 100644 --- a/wheel/python/clvm_rs/program.py +++ b/wheel/python/clvm_rs/program.py @@ -56,7 +56,7 @@ def __int__(self) -> int: return self.as_int() def __hash__(self): - return self.tree_hash().__hash__() + return id(self) @classmethod def int_from_bytes(cls, b: bytes) -> int: @@ -240,10 +240,11 @@ def tree_hash(self) -> bytes32: def run_with_cost(self, args, max_cost: int = MAX_COST) -> Tuple[int, "Program"]: prog_bytes = bytes(self) args_bytes = bytes(self.to(args)) - cost, r = run_serialized_program(prog_bytes, args_bytes, max_cost, 0) - r = self.wrap(r) - if isinstance(cost, str): - raise EvalError(cost, r) + try: + cost, r = run_serialized_program(prog_bytes, args_bytes, max_cost, 0) + r = self.wrap(r) + except ValueError as ve: + raise EvalError(ve.args[0], self.wrap(ve.args[1])) return cost, r def run(self, args) -> "Program": From d796ee4e0a8931b546677e69b0751a034f25d5c3 Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Mon, 23 Jan 2023 17:53:51 -0800 Subject: [PATCH 21/45] More tests. --- wheel/python/tests/test_to_program.py | 37 +++++++++++++++++++++++---- 1 file changed, 32 insertions(+), 5 deletions(-) diff --git a/wheel/python/tests/test_to_program.py b/wheel/python/tests/test_to_program.py index 4965f045..ab132915 100644 --- a/wheel/python/tests/test_to_program.py +++ b/wheel/python/tests/test_to_program.py @@ -1,6 +1,6 @@ import unittest -from typing import Optional, Tuple, Any +from typing import Optional, Tuple, Any, Union from clvm_rs.clvm_storage import CLVMStorage from clvm_rs.program import Program @@ -147,10 +147,37 @@ def test_string_conversions(self): assert a.as_atom() == "foobar".encode() def test_int_conversions(self): - a = Program.to(1337) - assert a.as_atom() == bytes([0x5, 0x39]) - a = Program.to(-128) - assert a.as_atom() == bytes([0x80]) + def check(v: int, h: Union[str, list]): + a = Program.to(v) + b = bytes.fromhex(h) if isinstance(h, str) else bytes(h) + assert a.as_atom() == b + # note that this compares to the atom, not the serialization of that atom + # so 16384 codes as 0x4000, not 0x824000 + + check(1337, "0539") + check(-128, "80") + check(0, "") + check(1, "01") + check(-1, "ff") + + for v in range(1, 0x80): + check(v, [v]) + + for v in range(0x80, 0xFF): + check(v, [0, v]) + + for v in range(128): + check(-v - 1, [255 - v]) + + check(127, "7f") + check(128, "0080") + check(256, "0100") + check(-256, "ff00") + check(16384, "4000") + check(32767, "7fff") + check(32768, "008000") + check(-32768, "8000") + check(-32769, "ff7fff") def test_none_conversions(self): a = Program.to(None) From fbff98ca86af4e5b3389ce32a8ccd9964c7ad140 Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Mon, 23 Jan 2023 18:15:03 -0800 Subject: [PATCH 22/45] Improvements to `uncurry`, tests, coverage. --- wheel/python/clvm_rs/program.py | 14 ++++++++------ wheel/python/tests/test_program.py | 1 + wheel/python/tests/test_to_program.py | 15 +++++++++++++++ 3 files changed, 24 insertions(+), 6 deletions(-) diff --git a/wheel/python/clvm_rs/program.py b/wheel/python/clvm_rs/program.py index e88cfd5e..61de6523 100644 --- a/wheel/python/clvm_rs/program.py +++ b/wheel/python/clvm_rs/program.py @@ -290,20 +290,22 @@ def curry(self, *args) -> "Program": def uncurry(self) -> Tuple[Program, Optional[List[Program]]]: if self.at("f") != A_KW or self.at("rff") != Q_KW or self.at("rrr") != NULL: return self, None + # since "rff" is not None, neither is "rfr" uncurried_function = self.at("rfr") - if uncurried_function is None: - return self, None + assert uncurried_function is not None core_items = [] + + # since "rrr" is not None, neither is rrf core = self.at("rrf") while core != ONE: - if core is None: - return self, None + assert core is not None if core.at("f") != C_KW or core.at("rff") != Q_KW or core.at("rrr") != NULL: return self, None + # since "rff" is not None, neither is "rfr" new_item = core.at("rfr") - if new_item is None: - return self, None + assert new_item is not None core_items.append(new_item) + # since "rrr" is not None, neither is rrf core = core.at("rrf") return uncurried_function, core_items diff --git a/wheel/python/tests/test_program.py b/wheel/python/tests/test_program.py index a5d8ca46..7aadb66d 100644 --- a/wheel/python/tests/test_program.py +++ b/wheel/python/tests/test_program.py @@ -21,6 +21,7 @@ def test_at(self): self.assertRaises(ValueError, lambda: p.at("q")) self.assertEqual(None, p.at("ff")) + self.assertEqual(None, p.at("ffr")) def test_replace(self): p1 = Program.to([100, 200, 300]) diff --git a/wheel/python/tests/test_to_program.py b/wheel/python/tests/test_to_program.py index ab132915..70dceb2c 100644 --- a/wheel/python/tests/test_to_program.py +++ b/wheel/python/tests/test_to_program.py @@ -179,6 +179,21 @@ def check(v: int, h: Union[str, list]): check(-32768, "8000") check(-32769, "ff7fff") + def test_int_round_trip(self): + def check(n): + p = Program.to(n) + assert int(p) == n + assert p.int_from_bytes(p.atom) == n + assert Program.int_to_bytes(n) == p.atom + + for n in range(0, 256): + check(n) + check(-n) + + for n in range(0, 65536, 97): + check(n) + check(-n) + def test_none_conversions(self): a = Program.to(None) assert a.as_atom() == b"" From 3172fda501bbddcab8a57b8ea92a6dce4a437ab6 Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Tue, 24 Jan 2023 10:59:16 -0800 Subject: [PATCH 23/45] Refactor. --- wheel/python/clvm_rs/__init__.py | 2 +- wheel/python/clvm_rs/at.py | 41 +++++++++++++++++++++++++ wheel/python/clvm_rs/program.py | 51 ++++---------------------------- wheel/python/clvm_rs/replace.py | 37 +++++++++++++++++++++++ 4 files changed, 84 insertions(+), 47 deletions(-) create mode 100644 wheel/python/clvm_rs/at.py create mode 100644 wheel/python/clvm_rs/replace.py diff --git a/wheel/python/clvm_rs/__init__.py b/wheel/python/clvm_rs/__init__.py index fcd136c9..b9c3d677 100644 --- a/wheel/python/clvm_rs/__init__.py +++ b/wheel/python/clvm_rs/__init__.py @@ -1,2 +1,2 @@ from .eval_error import EvalError -from .program import Program +from .program import Program \ No newline at end of file diff --git a/wheel/python/clvm_rs/at.py b/wheel/python/clvm_rs/at.py new file mode 100644 index 00000000..4f34f5d5 --- /dev/null +++ b/wheel/python/clvm_rs/at.py @@ -0,0 +1,41 @@ +from __future__ import annotations +from typing import Optional + +from .clvm_storage import CLVMStorage + + +def at(obj: CLVMStorage, position: str) -> Optional[CLVMStorage]: + """ + Take a string of `f` and `r` characters and follow that path. + + Example: + + ``` + p1 = Program.to([10, 20, 30, [15, 17], 40, 50]) + assert Program.to(17) == at(p1, "rrrfrf") + ``` + + Returns `None` if an atom is hit at some intermediate node. + + ``` + p1 = Program.to(10) + assert None == at(p1, "rr") + ``` + + """ + v: Optional[CLVMStorage] = obj + for c in position.lower(): + pair = v.pair + if pair is None: + return None + if c == "f": + v = pair[0] + elif c == "r": + v = pair[1] + else: + raise ValueError( + f"`at` got illegal character `{c}`. Only `f` & `r` allowed" + ) + if v is None: + break + return v diff --git a/wheel/python/clvm_rs/program.py b/wheel/python/clvm_rs/program.py index 61de6523..40d665d1 100644 --- a/wheel/python/clvm_rs/program.py +++ b/wheel/python/clvm_rs/program.py @@ -1,6 +1,7 @@ from __future__ import annotations from typing import Dict, Iterator, List, Tuple, Optional, Any +from .at import at from .bytes32 import bytes32 from .casts import to_clvm_object, int_from_bytes, int_to_bytes from .chia_dialect import NULL, ONE, Q_KW, A_KW, C_KW @@ -8,8 +9,10 @@ from .clvm_storage import CLVMStorage from .clvm_tree import CLVMTree from .eval_error import EvalError +from .replace import replace from .ser import sexp_from_stream, sexp_to_stream, sexp_to_bytes from .tree_hash import sha256_treehash +from .uncurry import uncurry MAX_COST = 0x7FFFFFFFFFFFFFFF @@ -195,19 +198,7 @@ def at(self, position: str) -> Optional["Program"]: ``` """ - v: Optional[Program] = self - for c in position.lower(): - if v is None: - return v - if c == "f": - v = v.first() - elif c == "r": - v = v.rest() - else: - raise ValueError( - f"`at` got illegal character `{c}`. Only `f` & `r` allowed" - ) - return v + return at(self, position) def replace(self, **kwargs) -> "Program": """ @@ -232,7 +223,7 @@ def replace(self, **kwargs) -> "Program": Note that `Program` objects are immutable. This function returns a new object; the original is left as-is. """ - return _replace(self, **kwargs) + return replace(self, **kwargs) def tree_hash(self) -> bytes32: return sha256_treehash(self) @@ -347,35 +338,3 @@ def as_atom_list(self) -> List[bytes]: NULL_PROGRAM = Program.fromhex("80") - - -def _replace(program: Program, **kwargs) -> Program: - # if `kwargs == {}` then `return program` unchanged - if len(kwargs) == 0: - return program - - if "" in kwargs: - if len(kwargs) > 1: - raise ValueError("conflicting paths") - return kwargs[""] - - # we've confirmed that no `kwargs` is the empty string. - # Now split `kwargs` into two groups: those - # that start with `f` and those that start with `r` - - args_by_prefix: Dict[str, Dict[str, Program]] = dict(f={}, r={}) - for k, v in kwargs.items(): - c = k[0] - if c not in "fr": - raise ValueError(f"bad path containing {c}: must only contain `f` and `r`") - args_by_prefix[c][k[1:]] = program.to(v) - - pair = program.pair - if pair is None: - raise ValueError("path into atom") - - # recurse down the tree - new_f = _replace(pair[0], **args_by_prefix.get("f", {})) - new_r = _replace(pair[1], **args_by_prefix.get("r", {})) - - return program.new_pair(new_f, new_r) diff --git a/wheel/python/clvm_rs/replace.py b/wheel/python/clvm_rs/replace.py new file mode 100644 index 00000000..f9be8622 --- /dev/null +++ b/wheel/python/clvm_rs/replace.py @@ -0,0 +1,37 @@ +from __future__ import annotations +from typing import Dict + +from .clvm_storage import CLVMStorage + + +def replace(program: CLVMStorage, **kwargs) -> CLVMStorage: + # if `kwargs == {}` then `return program` unchanged + if len(kwargs) == 0: + return program + + if "" in kwargs: + if len(kwargs) > 1: + raise ValueError("conflicting paths") + return kwargs[""] + + # we've confirmed that no `kwargs` is the empty string. + # Now split `kwargs` into two groups: those + # that start with `f` and those that start with `r` + + args_by_prefix: Dict[str, Dict[str, CLVMStorage]] = dict(f={}, r={}) + for k, v in kwargs.items(): + c = k[0] + if c not in "fr": + msg = f"bad path containing {c}: must only contain `f` and `r`" + raise ValueError(msg) + args_by_prefix[c][k[1:]] = program.to(v) + + pair = program.pair + if pair is None: + raise ValueError("path into atom") + + # recurse down the tree + new_f = replace(pair[0], **args_by_prefix.get("f", {})) + new_r = replace(pair[1], **args_by_prefix.get("r", {})) + + return program.new_pair(new_f, new_r) From f80a0fc8099ea859a291cf9efae6d34876ae803e Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Tue, 24 Jan 2023 16:59:05 -0800 Subject: [PATCH 24/45] More refactor. --- wheel/python/clvm_rs/at.py | 4 +- wheel/python/clvm_rs/casts.py | 2 +- wheel/python/clvm_rs/chia_dialect.py | 40 ++++++ wheel/python/clvm_rs/clvm_storage.py | 26 ++++ wheel/python/clvm_rs/curry_and_treehash.py | 129 ++++++++++-------- wheel/python/clvm_rs/de.py | 3 +- wheel/python/clvm_rs/program.py | 41 ++---- wheel/python/clvm_rs/replace.py | 7 +- wheel/python/clvm_rs/tree_hash.py | 105 ++++++++------ wheel/python/clvm_rs/uncurry.py | 51 +++++++ wheel/python/tests/test_curry_and_treehash.py | 2 + wheel/python/tests/test_program.py | 4 +- 12 files changed, 277 insertions(+), 137 deletions(-) create mode 100644 wheel/python/clvm_rs/chia_dialect.py create mode 100644 wheel/python/clvm_rs/clvm_storage.py create mode 100644 wheel/python/clvm_rs/uncurry.py diff --git a/wheel/python/clvm_rs/at.py b/wheel/python/clvm_rs/at.py index 4f34f5d5..dcfa1370 100644 --- a/wheel/python/clvm_rs/at.py +++ b/wheel/python/clvm_rs/at.py @@ -25,6 +25,8 @@ def at(obj: CLVMStorage, position: str) -> Optional[CLVMStorage]: """ v: Optional[CLVMStorage] = obj for c in position.lower(): + if v is None: + break pair = v.pair if pair is None: return None @@ -36,6 +38,4 @@ def at(obj: CLVMStorage, position: str) -> Optional[CLVMStorage]: raise ValueError( f"`at` got illegal character `{c}`. Only `f` & `r` allowed" ) - if v is None: - break return v diff --git a/wheel/python/clvm_rs/casts.py b/wheel/python/clvm_rs/casts.py index 35b74046..7b762e3e 100644 --- a/wheel/python/clvm_rs/casts.py +++ b/wheel/python/clvm_rs/casts.py @@ -2,7 +2,7 @@ Some utilities to cast python types to and from clvm. """ -from typing import Any, Callable, List, Optional, SupportsBytes, Tuple, Union +from typing import Callable, List, SupportsBytes, Tuple, Union from .clvm_storage import CLVMStorage diff --git a/wheel/python/clvm_rs/chia_dialect.py b/wheel/python/clvm_rs/chia_dialect.py new file mode 100644 index 00000000..b90ab5fa --- /dev/null +++ b/wheel/python/clvm_rs/chia_dialect.py @@ -0,0 +1,40 @@ +from dataclasses import dataclass +from typing import List + + +@dataclass +class Dialect: + KEYWORDS: List[str] + + NULL: bytes + ONE: bytes + TWO: bytes + Q_KW: bytes + A_KW: bytes + C_KW: bytes + + +chia_dialect = Dialect( + ( + # core opcodes 0x01-x08 + ". q a i c f r l x " + # opcodes on atoms as strings 0x09-0x0f + "= >s sha256 substr strlen concat . " + # opcodes on atoms as ints 0x10-0x17 + "+ - * / divmod > ash lsh " + # opcodes on atoms as vectors of bools 0x18-0x1c + "logand logior logxor lognot . " + # opcodes for bls 1381 0x1d-0x1f + "point_add pubkey_for_exp . " + # bool opcodes 0x20-0x23 + "not any all . " + # misc 0x24 + "softfork " + ).split(), + NULL=bytes.fromhex(""), + ONE=bytes.fromhex("01"), + TWO=bytes.fromhex("02"), + Q_KW=bytes.fromhex("01"), + A_KW=bytes.fromhex("02"), + C_KW=bytes.fromhex("04"), +) diff --git a/wheel/python/clvm_rs/clvm_storage.py b/wheel/python/clvm_rs/clvm_storage.py new file mode 100644 index 00000000..8965019f --- /dev/null +++ b/wheel/python/clvm_rs/clvm_storage.py @@ -0,0 +1,26 @@ +from typing import Optional, Protocol, Tuple, runtime_checkable + + +@runtime_checkable +class CLVMStorage(Protocol): + atom: Optional[bytes] + + @property + def pair(self) -> Optional[Tuple["CLVMStorage", "CLVMStorage"]]: + ... + + # optional fields used to speed implementations: + + # `_cached_sha256_treehash: Optional[bytes]` is used by `sha256_treehash` + # `_cached_serialization: bytes` is used by `sexp_to_byte_iterator` to speed up serialization + + +@runtime_checkable +class CLVMStorageFactory(Protocol): + @classmethod + def new_atom(cls, v: bytes) -> "CLVMStorage": + ... + + @classmethod + def new_pair(cls, p1: "CLVMStorage", p2: "CLVMStorage") -> "CLVMStorage": + ... diff --git a/wheel/python/clvm_rs/curry_and_treehash.py b/wheel/python/clvm_rs/curry_and_treehash.py index 681b601a..29f0e738 100644 --- a/wheel/python/clvm_rs/curry_and_treehash.py +++ b/wheel/python/clvm_rs/curry_and_treehash.py @@ -1,59 +1,76 @@ from typing import List from .bytes32 import bytes32 -from .chia_dialect import NULL, ONE, Q_KW, A_KW, C_KW -from .tree_hash import shatree_atom, shatree_pair - - -Q_KW_TREEHASH = shatree_atom(Q_KW) -A_KW_TREEHASH = shatree_atom(A_KW) -C_KW_TREEHASH = shatree_atom(C_KW) -ONE_TREEHASH = shatree_atom(ONE) -NULL_TREEHASH = shatree_atom(NULL) - - -# The environment `E = (F . R)` recursively expands out to -# `(c . ((q . F) . EXPANSION(R)))` if R is not 0 -# `1` if R is 0 - - -def curried_values_tree_hash(arguments: List[bytes32]) -> bytes32: - if len(arguments) == 0: - return ONE_TREEHASH - - inner_curried_values = curried_values_tree_hash(arguments[1:]) - - return shatree_pair( - C_KW_TREEHASH, - shatree_pair( - shatree_pair(Q_KW_TREEHASH, arguments[0]), - shatree_pair(inner_curried_values, NULL_TREEHASH), - ), - ) - - -# The curry pattern is `(a . ((q . F) . (E . 0)))` == `(a (q . F) E) -# where `F` is the `mod` and `E` is the curried environment - - -def curry_and_treehash( - hash_of_quoted_mod_hash: bytes32, *hashed_arguments: bytes32 -) -> bytes32: - """ - `hash_of_quoted_mod_hash` : tree hash of `(q . MOD)` where `MOD` - is template to be curried - `arguments` : tree hashes of arguments to be curried - """ - - curried_values = curried_values_tree_hash(list(hashed_arguments)) - return shatree_pair( - A_KW_TREEHASH, - shatree_pair( - hash_of_quoted_mod_hash, - shatree_pair(curried_values, NULL_TREEHASH), - ), - ) - - -def calculate_hash_of_quoted_mod_hash(mod_hash: bytes32) -> bytes32: - return shatree_pair(Q_KW_TREEHASH, mod_hash) +from .chia_dialect import Dialect, chia_dialect +from .tree_hash import shatree_pair, Treehash, CHIA_TREEHASHER + + +ONE = bytes.fromhex("01") + + +class CurryTreehasher: + q_hw_treehash: bytes + a_hw_treehash: bytes + c_hw_treehash: bytes + atom_prefix_treehash: bytes + pair_prefix_treehash: bytes + + def __init__(self, dialect: Dialect, tree_hash: Treehash): + self.dialect = dialect + self.tree_hash = tree_hash + + self.q_kw_treehash = self.tree_hash.shatree_atom(dialect.Q_KW) + self.c_kw_treehash = self.tree_hash.shatree_atom(dialect.C_KW) + self.a_kw_treehash = self.tree_hash.shatree_atom(dialect.A_KW) + self.null_treehash = self.tree_hash.shatree_atom(dialect.NULL) + self.one_treehash = self.tree_hash.shatree_atom(ONE) + + # The environment `E = (F . R)` recursively expands out to + # `(c . ((q . F) . EXPANSION(R)))` if R is not 0 + # `1` if R is 0 + + def curried_values_tree_hash(self, arguments: List[bytes32]) -> bytes32: + if len(arguments) == 0: + return self.one_treehash + + inner_curried_values = self.curried_values_tree_hash(arguments[1:]) + + return self.tree_hash.shatree_pair( + self.c_kw_treehash, + self.tree_hash.shatree_pair( + self.tree_hash.shatree_pair(self.q_kw_treehash, arguments[0]), + self.tree_hash.shatree_pair(inner_curried_values, self.null_treehash), + ), + ) + + # The curry pattern is `(a . ((q . F) . (E . 0)))` == `(a (q . F) E) + # where `F` is the `mod` and `E` is the curried environment + + def curry_and_treehash( + self, hash_of_quoted_mod_hash: bytes32, *hashed_arguments: bytes32 + ) -> bytes32: + """ + `hash_of_quoted_mod_hash` : tree hash of `(q . MOD)` where `MOD` + is template to be curried + `arguments` : tree hashes of arguments to be curried + """ + + curried_values = self.curried_values_tree_hash(list(hashed_arguments)) + return shatree_pair( + self.a_kw_treehash, + shatree_pair( + hash_of_quoted_mod_hash, + shatree_pair(curried_values, self.null_treehash), + ), + ) + + def calculate_hash_of_quoted_mod_hash(self, mod_hash: bytes32) -> bytes32: + return shatree_pair(self.q_kw_treehash, mod_hash) + + +CHIA_CURRY_TREEHASHER = CurryTreehasher(chia_dialect, CHIA_TREEHASHER) + +curry_and_treehash = CHIA_CURRY_TREEHASHER.curry_and_treehash +calculate_hash_of_quoted_mod_hash = ( + CHIA_CURRY_TREEHASHER.calculate_hash_of_quoted_mod_hash +) diff --git a/wheel/python/clvm_rs/de.py b/wheel/python/clvm_rs/de.py index 9ea9ea38..0b2a2ec5 100644 --- a/wheel/python/clvm_rs/de.py +++ b/wheel/python/clvm_rs/de.py @@ -1,6 +1,5 @@ -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple -from .clvm_storage import CLVMStorage from .tree_hash import shatree_atom, shatree_pair try: diff --git a/wheel/python/clvm_rs/program.py b/wheel/python/clvm_rs/program.py index 40d665d1..ba0a9cb5 100644 --- a/wheel/python/clvm_rs/program.py +++ b/wheel/python/clvm_rs/program.py @@ -1,17 +1,17 @@ from __future__ import annotations -from typing import Dict, Iterator, List, Tuple, Optional, Any +from typing import Iterator, List, Tuple, Optional, Any, BinaryIO from .at import at from .bytes32 import bytes32 from .casts import to_clvm_object, int_from_bytes, int_to_bytes -from .chia_dialect import NULL, ONE, Q_KW, A_KW, C_KW +from .chia_dialect import Dialect, chia_dialect from .clvm_rs import run_serialized_program from .clvm_storage import CLVMStorage from .clvm_tree import CLVMTree from .eval_error import EvalError from .replace import replace from .ser import sexp_from_stream, sexp_to_stream, sexp_to_bytes -from .tree_hash import sha256_treehash +from .tree_hash import CHIA_TREEHASHER from .uncurry import uncurry @@ -23,6 +23,8 @@ class Program(CLVMStorage): A thin wrapper around s-expression data intended to be invoked with "eval". """ + dialect: Dialect = chia_dialect + # serialization/deserialization @classmethod @@ -198,7 +200,7 @@ def at(self, position: str) -> Optional["Program"]: ``` """ - return at(self, position) + return self.to(at(self, position)) def replace(self, **kwargs) -> "Program": """ @@ -223,10 +225,10 @@ def replace(self, **kwargs) -> "Program": Note that `Program` objects are immutable. This function returns a new object; the original is left as-is. """ - return replace(self, **kwargs) + return self.to(replace(self, **kwargs)) def tree_hash(self) -> bytes32: - return sha256_treehash(self) + return CHIA_TREEHASHER.sha256_treehash(self) def run_with_cost(self, args, max_cost: int = MAX_COST) -> Tuple[int, "Program"]: prog_bytes = bytes(self) @@ -264,8 +266,8 @@ def run(self, args) -> "Program": def curry(self, *args) -> "Program": fixed_args: Any = 1 for arg in reversed(args): - fixed_args = [4, (1, arg), fixed_args] - return self.to([2, (1, self), fixed_args]) + fixed_args = [self.dialect.C_KW, (self.dialect.Q_KW, arg), fixed_args] + return self.to([self.dialect.A_KW, (self.dialect.Q_KW, self), fixed_args]) """ uncurry the given program @@ -279,26 +281,9 @@ def curry(self, *args) -> "Program": """ def uncurry(self) -> Tuple[Program, Optional[List[Program]]]: - if self.at("f") != A_KW or self.at("rff") != Q_KW or self.at("rrr") != NULL: - return self, None - # since "rff" is not None, neither is "rfr" - uncurried_function = self.at("rfr") - assert uncurried_function is not None - core_items = [] - - # since "rrr" is not None, neither is rrf - core = self.at("rrf") - while core != ONE: - assert core is not None - if core.at("f") != C_KW or core.at("rff") != Q_KW or core.at("rrr") != NULL: - return self, None - # since "rff" is not None, neither is "rfr" - new_item = core.at("rfr") - assert new_item is not None - core_items.append(new_item) - # since "rrr" is not None, neither is rrf - core = core.at("rrf") - return uncurried_function, core_items + mod, args = uncurry(self.dialect, self) + p_args = args if args is None else [self.to(_) for _ in args] + return self.to(mod), p_args def as_int(self) -> int: return int_from_bytes(self.as_atom()) diff --git a/wheel/python/clvm_rs/replace.py b/wheel/python/clvm_rs/replace.py index f9be8622..5afb4c42 100644 --- a/wheel/python/clvm_rs/replace.py +++ b/wheel/python/clvm_rs/replace.py @@ -1,10 +1,11 @@ from __future__ import annotations from typing import Dict +from .casts import CastableType from .clvm_storage import CLVMStorage -def replace(program: CLVMStorage, **kwargs) -> CLVMStorage: +def replace(program: CLVMStorage, **kwargs) -> CastableType: # if `kwargs == {}` then `return program` unchanged if len(kwargs) == 0: return program @@ -24,7 +25,7 @@ def replace(program: CLVMStorage, **kwargs) -> CLVMStorage: if c not in "fr": msg = f"bad path containing {c}: must only contain `f` and `r`" raise ValueError(msg) - args_by_prefix[c][k[1:]] = program.to(v) + args_by_prefix[c][k[1:]] = v pair = program.pair if pair is None: @@ -34,4 +35,4 @@ def replace(program: CLVMStorage, **kwargs) -> CLVMStorage: new_f = replace(pair[0], **args_by_prefix.get("f", {})) new_r = replace(pair[1], **args_by_prefix.get("r", {})) - return program.new_pair(new_f, new_r) + return (new_f, new_r) diff --git a/wheel/python/clvm_rs/tree_hash.py b/wheel/python/clvm_rs/tree_hash.py index 7fd82ac6..cba589ce 100644 --- a/wheel/python/clvm_rs/tree_hash.py +++ b/wheel/python/clvm_rs/tree_hash.py @@ -8,61 +8,78 @@ from hashlib import sha256 from typing import List +from weakref import WeakKeyDictionary -from clvm_rs.clvm_storage import CLVMStorage +from .chia_dialect import Dialect, chia_dialect +from .clvm_storage import CLVMStorage bytes32 = bytes -ONE = bytes.fromhex("01") -TWO = bytes.fromhex("02") +class Treehash: + atom_prefix: bytes + pair_prefix: bytes + def __init__(self, dialect: Dialect, atom_prefix: bytes, pair_prefix: bytes): + self.dialect = dialect + self.atom_prefix = atom_prefix + self.pair_prefix = pair_prefix + self.hash_cache: WeakKeyDictionary[CLVMStorage, bytes32] = WeakKeyDictionary() + self.cache_hits = 0 -def shatree_atom(atom: bytes) -> bytes32: - s = sha256() - s.update(ONE) - s.update(atom) - return bytes32(s.digest()) + def shatree_atom(self, atom: bytes) -> bytes32: + s = sha256() + s.update(self.atom_prefix) + s.update(atom) + return bytes32(s.digest()) + def shatree_pair(self, left_hash: bytes32, right_hash: bytes32) -> bytes32: + s = sha256() + s.update(self.pair_prefix) + s.update(left_hash) + s.update(right_hash) + return bytes32(s.digest()) -def shatree_pair(left_hash: bytes32, right_hash: bytes32) -> bytes32: - s = sha256() - s.update(TWO) - s.update(left_hash) - s.update(right_hash) - return bytes32(s.digest()) + def sha256_treehash(self, sexp: CLVMStorage) -> bytes32: + def handle_sexp(sexp_stack, hash_stack, op_stack) -> None: + sexp = sexp_stack.pop() + r = getattr(sexp, "_cached_sha256_treehash", None) + if r is None: + r = self.hash_cache.get(sexp) + if r is not None: + self.cache_hits += 1 + hash_stack.append(r) + return + elif sexp.pair: + p0, p1 = sexp.pair + sexp_stack.append(p0) + sexp_stack.append(p1) + op_stack.append(handle_pair) + op_stack.append(handle_sexp) + op_stack.append(handle_sexp) + else: + r = shatree_atom(sexp.atom) + hash_stack.append(r) + sexp._cached_sha256_treehash = r + self.hash_cache[sexp] = r - -def sha256_treehash(sexp: CLVMStorage) -> bytes32: - def handle_sexp(sexp_stack, hash_stack, op_stack) -> None: - sexp = sexp_stack.pop() - r = getattr(sexp, "_cached_sha256_treehash", None) - if r is not None: - hash_stack.append(r) - return - elif sexp.pair: - p0, p1 = sexp.pair - sexp_stack.append(p0) - sexp_stack.append(p1) - op_stack.append(handle_pair) - op_stack.append(handle_sexp) - op_stack.append(handle_sexp) - else: - r = shatree_atom(sexp.atom) + def handle_pair(sexp_stack, hash_stack, op_stack) -> None: + p0 = hash_stack.pop() + p1 = hash_stack.pop() + r = shatree_pair(p0, p1) hash_stack.append(r) sexp._cached_sha256_treehash = r + self.hash_cache[sexp] = r + + sexp_stack = [sexp] + op_stack = [handle_sexp] + hash_stack: List[bytes32] = [] + while len(op_stack) > 0: + op = op_stack.pop() + op(sexp_stack, hash_stack, op_stack) + return hash_stack[0] - def handle_pair(sexp_stack, hash_stack, op_stack) -> None: - p0 = hash_stack.pop() - p1 = hash_stack.pop() - r = shatree_pair(p0, p1) - hash_stack.append(r) - sexp._cached_sha256_treehash = r - sexp_stack = [sexp] - op_stack = [handle_sexp] - hash_stack: List[bytes32] = [] - while len(op_stack) > 0: - op = op_stack.pop() - op(sexp_stack, hash_stack, op_stack) - return hash_stack[0] +CHIA_TREEHASHER = Treehash(chia_dialect, bytes.fromhex("01"), bytes.fromhex("02")) +shatree_atom = CHIA_TREEHASHER.shatree_atom +shatree_pair = CHIA_TREEHASHER.shatree_pair diff --git a/wheel/python/clvm_rs/uncurry.py b/wheel/python/clvm_rs/uncurry.py new file mode 100644 index 00000000..5f0cf693 --- /dev/null +++ b/wheel/python/clvm_rs/uncurry.py @@ -0,0 +1,51 @@ +from __future__ import annotations +from typing import List, Tuple, Optional + +from .at import at +from .chia_dialect import Dialect +from .clvm_storage import CLVMStorage + + +""" +uncurry the given program + +returns `mod, [arg1, arg2, ...]` + +if the program is not a valid curry, return `sexp, NULL` + +This distinguishes it from the case of a valid curry of 0 arguments +(which is rather pointless but possible), which returns `sexp, []` +""" + + +def uncurry( + dialect: Dialect, sexp: CLVMStorage +) -> Tuple[CLVMStorage, Optional[List[CLVMStorage]]]: + if ( + at(sexp, "f") != dialect.A_KW + or at(sexp, "rff") != dialect.Q_KW + or at(sexp, "rrr") != dialect.NULL + ): + return sexp, None + # since "rff" is not None, neither is "rfr" + uncurried_function = at(sexp, "rfr") + assert uncurried_function is not None + core_items = [] + + # since "rrr" is not None, neither is rrf + core = at(sexp, "rrf") + while core != dialect.ONE: + assert core is not None + if ( + at(core, "f") != dialect.C_KW + or at(core, "rff") != dialect.Q_KW + or at(core, "rrr") != dialect.NULL + ): + return sexp, None + # since "rff" is not None, neither is "rfr" + new_item = at(core, "rfr") + assert new_item is not None + core_items.append(new_item) + # since "rrr" is not None, neither is rrf + core = at(core, "rrf") + return uncurried_function, core_items diff --git a/wheel/python/tests/test_curry_and_treehash.py b/wheel/python/tests/test_curry_and_treehash.py index 17ad15bb..7b30a95d 100644 --- a/wheel/python/tests/test_curry_and_treehash.py +++ b/wheel/python/tests/test_curry_and_treehash.py @@ -13,6 +13,8 @@ def test_curry_and_treehash() -> None: # we don't really care what `arbitrary_mod` is. We just need some code quoted_mod_hash = calculate_hash_of_quoted_mod_hash(arbitrary_mod_hash) + exp_hash = "9f487f9078d4b215e0cbe2cbdd21215ad6ed8e894ae00d616751e0efdccb25a9" + assert quoted_mod_hash == bytes.fromhex(exp_hash) for v in range(500): args = [v, v * v, v * v * v] diff --git a/wheel/python/tests/test_program.py b/wheel/python/tests/test_program.py index 7aadb66d..b909fa68 100644 --- a/wheel/python/tests/test_program.py +++ b/wheel/python/tests/test_program.py @@ -1,10 +1,12 @@ from unittest import TestCase -from clvm_rs.chia_dialect import A_KW, C_KW, Q_KW +from clvm_rs.chia_dialect import chia_dialect from clvm_rs.clvm_storage import CLVMStorage from clvm_rs.program import Program from clvm_rs.eval_error import EvalError +A_KW, C_KW, Q_KW = [getattr(chia_dialect, _) for _ in "A_KW C_KW Q_KW".split()] + class TestProgram(TestCase): def test_at(self): From 81e54ef0fe6dbbb7ac549a565b916aa90a793e7f Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Wed, 25 Jan 2023 15:58:22 -0800 Subject: [PATCH 25/45] Handle end of stream properly. --- src/serde/de_tree.rs | 25 +++++++++++++++++++++---- wheel/python/clvm_rs/chia_dialect.py | 2 +- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/src/serde/de_tree.rs b/src/serde/de_tree.rs index 486c63c5..636824eb 100644 --- a/src/serde/de_tree.rs +++ b/src/serde/de_tree.rs @@ -64,8 +64,26 @@ fn tree_hash_for_byte(b: u8, calculate_tree_hashes: bool) -> Option<[u8; 32]> { } } -fn skip_bytes(f: &mut R, skip_size: u64) -> Result { - copy(&mut f.by_ref().take(skip_size), &mut sink()) +pub fn copy_exactly( + reader: &mut R, + writer: &mut W, + expected_size: u64, +) -> Result<()> { + let mut reader = reader.by_ref().take(expected_size); + + let count = copy(&mut reader, writer)?; + if count < expected_size { + Err(Error::new( + std::io::ErrorKind::UnexpectedEof, + "copy terminated early", + )) + } else { + Ok(()) + } +} + +fn skip_bytes(f: &mut R, skip_size: u64) -> Result<()> { + copy_exactly(f, &mut sink(), skip_size) } fn skip_or_sha_bytes( @@ -74,11 +92,10 @@ fn skip_or_sha_bytes( calculate_tree_hashes: bool, ) -> Result> { if calculate_tree_hashes { - let mut f = &mut f.by_ref().take(skip_size); let mut h = Sha256::new(); h.update([1]); let mut w = ShaWrapper(h); - copy(&mut f, &mut w)?; + copy_exactly(f, &mut w, skip_size)?; let r: [u8; 32] = w.0.finalize() .as_slice() diff --git a/wheel/python/clvm_rs/chia_dialect.py b/wheel/python/clvm_rs/chia_dialect.py index b90ab5fa..7ae8e77f 100644 --- a/wheel/python/clvm_rs/chia_dialect.py +++ b/wheel/python/clvm_rs/chia_dialect.py @@ -14,7 +14,7 @@ class Dialect: C_KW: bytes -chia_dialect = Dialect( +CHIA_DIALECT = Dialect( ( # core opcodes 0x01-x08 ". q a i c f r l x " From 3cfbd56c28b09eb495c6c05164ec991d5d1d1030 Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Wed, 25 Jan 2023 16:32:51 -0800 Subject: [PATCH 26/45] More refactor. --- wheel/python/clvm_rs/curry_and_treehash.py | 107 ++++++++++++++---- wheel/python/clvm_rs/de.py | 5 +- wheel/python/clvm_rs/program.py | 16 +-- wheel/python/clvm_rs/tree_hash.py | 10 +- wheel/python/clvm_rs/uncurry.py | 51 --------- wheel/python/tests/test_curry_and_treehash.py | 8 +- wheel/python/tests/test_program.py | 4 +- wheel/python/tests/test_serialize.py | 13 ++- 8 files changed, 116 insertions(+), 98 deletions(-) delete mode 100644 wheel/python/clvm_rs/uncurry.py diff --git a/wheel/python/clvm_rs/curry_and_treehash.py b/wheel/python/clvm_rs/curry_and_treehash.py index 29f0e738..b8d0c266 100644 --- a/wheel/python/clvm_rs/curry_and_treehash.py +++ b/wheel/python/clvm_rs/curry_and_treehash.py @@ -1,8 +1,10 @@ -from typing import List +from typing import Any, List, Optional, Tuple +from .at import at from .bytes32 import bytes32 -from .chia_dialect import Dialect, chia_dialect -from .tree_hash import shatree_pair, Treehash, CHIA_TREEHASHER +from .chia_dialect import Dialect, CHIA_DIALECT +from .clvm_storage import CLVMStorage +from .tree_hash import shatree_pair, shatree_atom ONE = bytes.fromhex("01") @@ -15,15 +17,14 @@ class CurryTreehasher: atom_prefix_treehash: bytes pair_prefix_treehash: bytes - def __init__(self, dialect: Dialect, tree_hash: Treehash): + def __init__(self, dialect: Dialect): self.dialect = dialect - self.tree_hash = tree_hash - self.q_kw_treehash = self.tree_hash.shatree_atom(dialect.Q_KW) - self.c_kw_treehash = self.tree_hash.shatree_atom(dialect.C_KW) - self.a_kw_treehash = self.tree_hash.shatree_atom(dialect.A_KW) - self.null_treehash = self.tree_hash.shatree_atom(dialect.NULL) - self.one_treehash = self.tree_hash.shatree_atom(ONE) + self.q_kw_treehash = shatree_atom(dialect.Q_KW) + self.c_kw_treehash = shatree_atom(dialect.C_KW) + self.a_kw_treehash = shatree_atom(dialect.A_KW) + self.null_treehash = shatree_atom(dialect.NULL) + self.one_treehash = shatree_atom(ONE) # The environment `E = (F . R)` recursively expands out to # `(c . ((q . F) . EXPANSION(R)))` if R is not 0 @@ -35,11 +36,11 @@ def curried_values_tree_hash(self, arguments: List[bytes32]) -> bytes32: inner_curried_values = self.curried_values_tree_hash(arguments[1:]) - return self.tree_hash.shatree_pair( + return shatree_pair( self.c_kw_treehash, - self.tree_hash.shatree_pair( - self.tree_hash.shatree_pair(self.q_kw_treehash, arguments[0]), - self.tree_hash.shatree_pair(inner_curried_values, self.null_treehash), + shatree_pair( + shatree_pair(self.q_kw_treehash, arguments[0]), + shatree_pair(inner_curried_values, self.null_treehash), ), ) @@ -67,10 +68,74 @@ def curry_and_treehash( def calculate_hash_of_quoted_mod_hash(self, mod_hash: bytes32) -> bytes32: return shatree_pair(self.q_kw_treehash, mod_hash) - -CHIA_CURRY_TREEHASHER = CurryTreehasher(chia_dialect, CHIA_TREEHASHER) - -curry_and_treehash = CHIA_CURRY_TREEHASHER.curry_and_treehash -calculate_hash_of_quoted_mod_hash = ( - CHIA_CURRY_TREEHASHER.calculate_hash_of_quoted_mod_hash -) + """ + Replicates the curry function from clvm_tools, taking advantage of *args + being a list. We iterate through args in reverse building the code to + create a clvm list. + + Given arguments to a function addressable by the '1' reference in clvm + + fixed_args = 1 + + Each arg is prepended as fixed_args = (c (q . arg) fixed_args) + + The resulting argument list is interpreted with apply (2) + + (2 (1 . self) rest) + + Resulting in a function which places its own arguments after those + curried in in the form of a proper list. + """ + + def curry(self, mod, *args) -> Any: + fixed_args: Any = 1 + for arg in reversed(args): + fixed_args = [self.dialect.C_KW, (self.dialect.Q_KW, arg), fixed_args] + return [self.dialect.A_KW, (self.dialect.Q_KW, mod), fixed_args] + + """ + uncurry the given program + + returns `mod, [arg1, arg2, ...]` + + if the program is not a valid curry, return `sexp, NULL` + + This distinguishes it from the case of a valid curry of 0 arguments + (which is rather pointless but possible), which returns `sexp, []` + """ + + def uncurry( + self, sexp: CLVMStorage + ) -> Tuple[CLVMStorage, Optional[List[CLVMStorage]]]: + dialect = self.dialect + if ( + at(sexp, "f") != dialect.A_KW + or at(sexp, "rff") != dialect.Q_KW + or at(sexp, "rrr") != dialect.NULL + ): + return sexp, None + # since "rff" is not None, neither is "rfr" + uncurried_function = at(sexp, "rfr") + assert uncurried_function is not None + core_items = [] + + # since "rrr" is not None, neither is rrf + core = at(sexp, "rrf") + while core != dialect.ONE: + assert core is not None + if ( + at(core, "f") != dialect.C_KW + or at(core, "rff") != dialect.Q_KW + or at(core, "rrr") != dialect.NULL + ): + return sexp, None + # since "rff" is not None, neither is "rfr" + new_item = at(core, "rfr") + assert new_item is not None + core_items.append(new_item) + # since "rrr" is not None, neither is rrf + core = at(core, "rrf") + return uncurried_function, core_items + + +CHIA_CURRY_TREEHASHER = CurryTreehasher(CHIA_DIALECT) diff --git a/wheel/python/clvm_rs/de.py b/wheel/python/clvm_rs/de.py index 0b2a2ec5..aeb8bdb6 100644 --- a/wheel/python/clvm_rs/de.py +++ b/wheel/python/clvm_rs/de.py @@ -24,7 +24,10 @@ def deserialize_as_tuples( ) -> Tuple[List[Tuple[int, int, int]], List[Optional[bytes]]]: if deserialize_as_tree: - tree, hashes = deserialize_as_tree(blob, calculate_tree_hash) + try: + tree, hashes = deserialize_as_tree(blob, calculate_tree_hash) + except OSError as ex: + raise ValueError(ex) if not calculate_tree_hash: hashes = [None] * len(tree) return tree, hashes diff --git a/wheel/python/clvm_rs/program.py b/wheel/python/clvm_rs/program.py index ba0a9cb5..c46fd778 100644 --- a/wheel/python/clvm_rs/program.py +++ b/wheel/python/clvm_rs/program.py @@ -4,15 +4,14 @@ from .at import at from .bytes32 import bytes32 from .casts import to_clvm_object, int_from_bytes, int_to_bytes -from .chia_dialect import Dialect, chia_dialect from .clvm_rs import run_serialized_program from .clvm_storage import CLVMStorage from .clvm_tree import CLVMTree +from .curry_and_treehash import CurryTreehasher, CHIA_CURRY_TREEHASHER from .eval_error import EvalError from .replace import replace from .ser import sexp_from_stream, sexp_to_stream, sexp_to_bytes -from .tree_hash import CHIA_TREEHASHER -from .uncurry import uncurry +from .tree_hash import sha256_treehash MAX_COST = 0x7FFFFFFFFFFFFFFF @@ -23,7 +22,7 @@ class Program(CLVMStorage): A thin wrapper around s-expression data intended to be invoked with "eval". """ - dialect: Dialect = chia_dialect + curry_treehasher: CurryTreehasher = CHIA_CURRY_TREEHASHER # serialization/deserialization @@ -228,7 +227,7 @@ def replace(self, **kwargs) -> "Program": return self.to(replace(self, **kwargs)) def tree_hash(self) -> bytes32: - return CHIA_TREEHASHER.sha256_treehash(self) + return sha256_treehash(self) def run_with_cost(self, args, max_cost: int = MAX_COST) -> Tuple[int, "Program"]: prog_bytes = bytes(self) @@ -264,10 +263,7 @@ def run(self, args) -> "Program": """ def curry(self, *args) -> "Program": - fixed_args: Any = 1 - for arg in reversed(args): - fixed_args = [self.dialect.C_KW, (self.dialect.Q_KW, arg), fixed_args] - return self.to([self.dialect.A_KW, (self.dialect.Q_KW, self), fixed_args]) + return self.to(self.curry_treehasher.curry(self, *args)) """ uncurry the given program @@ -281,7 +277,7 @@ def curry(self, *args) -> "Program": """ def uncurry(self) -> Tuple[Program, Optional[List[Program]]]: - mod, args = uncurry(self.dialect, self) + mod, args = self.curry_treehasher.uncurry(self) p_args = args if args is None else [self.to(_) for _ in args] return self.to(mod), p_args diff --git a/wheel/python/clvm_rs/tree_hash.py b/wheel/python/clvm_rs/tree_hash.py index cba589ce..b26257cb 100644 --- a/wheel/python/clvm_rs/tree_hash.py +++ b/wheel/python/clvm_rs/tree_hash.py @@ -10,18 +10,16 @@ from typing import List from weakref import WeakKeyDictionary -from .chia_dialect import Dialect, chia_dialect from .clvm_storage import CLVMStorage bytes32 = bytes -class Treehash: +class Treehasher: atom_prefix: bytes pair_prefix: bytes - def __init__(self, dialect: Dialect, atom_prefix: bytes, pair_prefix: bytes): - self.dialect = dialect + def __init__(self, atom_prefix: bytes, pair_prefix: bytes): self.atom_prefix = atom_prefix self.pair_prefix = pair_prefix self.hash_cache: WeakKeyDictionary[CLVMStorage, bytes32] = WeakKeyDictionary() @@ -80,6 +78,8 @@ def handle_pair(sexp_stack, hash_stack, op_stack) -> None: return hash_stack[0] -CHIA_TREEHASHER = Treehash(chia_dialect, bytes.fromhex("01"), bytes.fromhex("02")) +CHIA_TREEHASHER = Treehasher(bytes.fromhex("01"), bytes.fromhex("02")) + +sha256_treehash = CHIA_TREEHASHER.sha256_treehash shatree_atom = CHIA_TREEHASHER.shatree_atom shatree_pair = CHIA_TREEHASHER.shatree_pair diff --git a/wheel/python/clvm_rs/uncurry.py b/wheel/python/clvm_rs/uncurry.py deleted file mode 100644 index 5f0cf693..00000000 --- a/wheel/python/clvm_rs/uncurry.py +++ /dev/null @@ -1,51 +0,0 @@ -from __future__ import annotations -from typing import List, Tuple, Optional - -from .at import at -from .chia_dialect import Dialect -from .clvm_storage import CLVMStorage - - -""" -uncurry the given program - -returns `mod, [arg1, arg2, ...]` - -if the program is not a valid curry, return `sexp, NULL` - -This distinguishes it from the case of a valid curry of 0 arguments -(which is rather pointless but possible), which returns `sexp, []` -""" - - -def uncurry( - dialect: Dialect, sexp: CLVMStorage -) -> Tuple[CLVMStorage, Optional[List[CLVMStorage]]]: - if ( - at(sexp, "f") != dialect.A_KW - or at(sexp, "rff") != dialect.Q_KW - or at(sexp, "rrr") != dialect.NULL - ): - return sexp, None - # since "rff" is not None, neither is "rfr" - uncurried_function = at(sexp, "rfr") - assert uncurried_function is not None - core_items = [] - - # since "rrr" is not None, neither is rrf - core = at(sexp, "rrf") - while core != dialect.ONE: - assert core is not None - if ( - at(core, "f") != dialect.C_KW - or at(core, "rff") != dialect.Q_KW - or at(core, "rrr") != dialect.NULL - ): - return sexp, None - # since "rff" is not None, neither is "rfr" - new_item = at(core, "rfr") - assert new_item is not None - core_items.append(new_item) - # since "rrr" is not None, neither is rrf - core = at(core, "rrf") - return uncurried_function, core_items diff --git a/wheel/python/tests/test_curry_and_treehash.py b/wheel/python/tests/test_curry_and_treehash.py index 7b30a95d..6b651011 100644 --- a/wheel/python/tests/test_curry_and_treehash.py +++ b/wheel/python/tests/test_curry_and_treehash.py @@ -1,7 +1,9 @@ from clvm_rs.program import Program -from clvm_rs.curry_and_treehash import ( - calculate_hash_of_quoted_mod_hash, - curry_and_treehash, +from clvm_rs.curry_and_treehash import CHIA_CURRY_TREEHASHER + +curry_and_treehash = CHIA_CURRY_TREEHASHER.curry_and_treehash +calculate_hash_of_quoted_mod_hash = ( + CHIA_CURRY_TREEHASHER.calculate_hash_of_quoted_mod_hash ) diff --git a/wheel/python/tests/test_program.py b/wheel/python/tests/test_program.py index b909fa68..6e94983e 100644 --- a/wheel/python/tests/test_program.py +++ b/wheel/python/tests/test_program.py @@ -1,11 +1,11 @@ from unittest import TestCase -from clvm_rs.chia_dialect import chia_dialect +from clvm_rs.chia_dialect import CHIA_DIALECT from clvm_rs.clvm_storage import CLVMStorage from clvm_rs.program import Program from clvm_rs.eval_error import EvalError -A_KW, C_KW, Q_KW = [getattr(chia_dialect, _) for _ in "A_KW C_KW Q_KW".split()] +A_KW, C_KW, Q_KW = [getattr(CHIA_DIALECT, _) for _ in "A_KW C_KW Q_KW".split()] class TestProgram(TestCase): diff --git a/wheel/python/tests/test_serialize.py b/wheel/python/tests/test_serialize.py index 7363ce0e..c55d0f8b 100644 --- a/wheel/python/tests/test_serialize.py +++ b/wheel/python/tests/test_serialize.py @@ -129,12 +129,15 @@ def test_deserialize_large_blob(self): Program.parse(InfiniteStream(bytes_in)) def test_repr_clvm_tree(self): - o = Program.fromhex("ff8085") + with self.assertRaises(ValueError): + Program.fromhex("ff8085") + + o = Program.fromhex("ff808185") self.assertEqual(repr(o._unwrapped_pair[0]), "") - self.assertEqual(repr(o._unwrapped_pair[1]), "") + self.assertEqual(repr(o._unwrapped_pair[1]), "") def test_bad_blob(self): - self.assertRaises(OSError, lambda: Program.fromhex("ff")) + self.assertRaises(ValueError, lambda: Program.fromhex("ff")) def test_large_atom(self): s = "foo" * 100 @@ -144,5 +147,5 @@ def test_large_atom(self): self.assertEqual(p, p1) def test_too_large_atom(self): - self.assertRaises(OSError, lambda: Program.fromhex("fc")) - self.assertRaises(OSError, lambda: Program.fromhex("fc8000000000")) + self.assertRaises(ValueError, lambda: Program.fromhex("fc")) + self.assertRaises(ValueError, lambda: Program.fromhex("fc8000000000")) From 4b8fe5bca2230cfde15f397f44ee952903bc2daa Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Wed, 25 Jan 2023 16:46:33 -0800 Subject: [PATCH 27/45] Fix name --- wheel/python/clvm_rs/program.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/wheel/python/clvm_rs/program.py b/wheel/python/clvm_rs/program.py index c46fd778..d086e561 100644 --- a/wheel/python/clvm_rs/program.py +++ b/wheel/python/clvm_rs/program.py @@ -4,7 +4,7 @@ from .at import at from .bytes32 import bytes32 from .casts import to_clvm_object, int_from_bytes, int_to_bytes -from .clvm_rs import run_serialized_program +from .clvm_rs import run_serialized_chia_program from .clvm_storage import CLVMStorage from .clvm_tree import CLVMTree from .curry_and_treehash import CurryTreehasher, CHIA_CURRY_TREEHASHER @@ -233,7 +233,7 @@ def run_with_cost(self, args, max_cost: int = MAX_COST) -> Tuple[int, "Program"] prog_bytes = bytes(self) args_bytes = bytes(self.to(args)) try: - cost, r = run_serialized_program(prog_bytes, args_bytes, max_cost, 0) + cost, r = run_serialized_chia_program(prog_bytes, args_bytes, max_cost, 0) r = self.wrap(r) except ValueError as ve: raise EvalError(ve.args[0], self.wrap(ve.args[1])) From 0acda17e7a5ac221d95d219a99adb419552aa1a5 Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Wed, 25 Jan 2023 17:06:23 -0800 Subject: [PATCH 28/45] refactor --- wheel/python/clvm_rs/ser.py | 50 +++++++++++++++++++------------------ 1 file changed, 26 insertions(+), 24 deletions(-) diff --git a/wheel/python/clvm_rs/ser.py b/wheel/python/clvm_rs/ser.py index 967f7588..13d3579d 100644 --- a/wheel/python/clvm_rs/ser.py +++ b/wheel/python/clvm_rs/ser.py @@ -46,26 +46,16 @@ def sexp_to_byte_iterator(sexp: CLVMStorage) -> Iterator[bytes]: yield from atom_to_byte_iterator(atom) -def atom_to_byte_iterator(as_atom: bytes) -> Iterator[bytes]: - """ - Yield the serialization for a given blob (as a clvm atom). - """ - size = len(as_atom) - if size == 0: - yield b"\x80" - return - if size == 1: - if as_atom[0] <= MAX_SINGLE_BYTE: - yield as_atom - return +def size_blob_for_size(blob: int) -> bytes: + size = len(blob) if size < 0x40: - size_blob = bytes([0x80 | size]) - elif size < 0x2000: - size_blob = bytes([0xC0 | (size >> 8), (size >> 0) & 0xFF]) - elif size < 0x100000: - size_blob = bytes([0xE0 | (size >> 16), (size >> 8) & 0xFF, (size >> 0) & 0xFF]) - elif size < 0x8000000: - size_blob = bytes( + return bytes([0x80 | size]) + if size < 0x2000: + return bytes([0xC0 | (size >> 8), (size >> 0) & 0xFF]) + if size < 0x100000: + return bytes([0xE0 | (size >> 16), (size >> 8) & 0xFF, (size >> 0) & 0xFF]) + if size < 0x8000000: + return bytes( [ 0xF0 | (size >> 24), (size >> 16) & 0xFF, @@ -73,8 +63,8 @@ def atom_to_byte_iterator(as_atom: bytes) -> Iterator[bytes]: (size >> 0) & 0xFF, ] ) - elif size < 0x400000000: - size_blob = bytes( + if size < 0x400000000: + return bytes( [ 0xF8 | (size >> 32), (size >> 24) & 0xFF, @@ -83,10 +73,22 @@ def atom_to_byte_iterator(as_atom: bytes) -> Iterator[bytes]: (size >> 0) & 0xFF, ] ) - else: - raise ValueError("sexp too long %r" % as_atom) + raise ValueError("sexp too long %r" % blob) + - yield size_blob +def atom_to_byte_iterator(as_atom: bytes) -> Iterator[bytes]: + """ + Yield the serialization for a given blob (as a clvm atom). + """ + size = len(as_atom) + if size == 0: + yield b"\x80" + return + if size == 1: + if as_atom[0] <= MAX_SINGLE_BYTE: + yield as_atom + return + yield size_blob_for_size(as_atom) yield as_atom From 67dc587138db702cca451fc15b9ad11ad44d607a Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Wed, 25 Jan 2023 22:31:38 -0800 Subject: [PATCH 29/45] fix benchmarks --- benchmark/run-benchmark.py | 20 ++++++-------------- wheel/python/clvm_rs/__init__.py | 2 +- 2 files changed, 7 insertions(+), 15 deletions(-) diff --git a/benchmark/run-benchmark.py b/benchmark/run-benchmark.py index 36fa80e9..b29dd645 100755 --- a/benchmark/run-benchmark.py +++ b/benchmark/run-benchmark.py @@ -6,9 +6,7 @@ import os import time import random -from clvm import KEYWORD_FROM_ATOM, KEYWORD_TO_ATOM -from clvm.operators import OP_REWRITE -from clvm_rs import run_serialized_chia_program +from clvm_rs import Program from colorama import init, Fore, Style init() @@ -184,12 +182,6 @@ def need_update(file_path, mtime): test_runs = {} test_costs = {} -native_opcode_names_by_opcode = dict( - ("op_%s" % OP_REWRITE.get(k, k), op) - for op, k in KEYWORD_FROM_ATOM.items() - if k not in "qa." -) - print('benchmarking...') for n in range(5): if "-v" in sys.argv: @@ -209,13 +201,13 @@ def need_update(file_path, mtime): else: if "-v" in sys.argv: print(fn) - program_data = bytes.fromhex(open(fn, 'r').read()) - env_data = bytes.fromhex(open(env_fn, 'r').read()) + program = Program.fromhex(open(fn, 'r').read()) + env = Program.fromhex(open(env_fn, 'r').read()) time_start = time.perf_counter() - cost, result = run_serialized_chia_program( - program_data, - env_data, + cost, result = program.run( + program, + env, max_cost, flags, ) diff --git a/wheel/python/clvm_rs/__init__.py b/wheel/python/clvm_rs/__init__.py index b9c3d677..fcd136c9 100644 --- a/wheel/python/clvm_rs/__init__.py +++ b/wheel/python/clvm_rs/__init__.py @@ -1,2 +1,2 @@ from .eval_error import EvalError -from .program import Program \ No newline at end of file +from .program import Program From edde719a58aa1ab6bf4f5edb826f3e18152f28b8 Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Wed, 25 Jan 2023 22:46:01 -0800 Subject: [PATCH 30/45] fix benchmark --- benchmark/run-benchmark.py | 2 +- wheel/python/clvm_rs/program.py | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/benchmark/run-benchmark.py b/benchmark/run-benchmark.py index b29dd645..620050f5 100755 --- a/benchmark/run-benchmark.py +++ b/benchmark/run-benchmark.py @@ -205,7 +205,7 @@ def need_update(file_path, mtime): env = Program.fromhex(open(env_fn, 'r').read()) time_start = time.perf_counter() - cost, result = program.run( + cost, result = program.run_with_cost( program, env, max_cost, diff --git a/wheel/python/clvm_rs/program.py b/wheel/python/clvm_rs/program.py index d086e561..c7b7c02a 100644 --- a/wheel/python/clvm_rs/program.py +++ b/wheel/python/clvm_rs/program.py @@ -229,11 +229,15 @@ def replace(self, **kwargs) -> "Program": def tree_hash(self) -> bytes32: return sha256_treehash(self) - def run_with_cost(self, args, max_cost: int = MAX_COST) -> Tuple[int, "Program"]: + def run_with_cost( + self, args, max_cost: int = MAX_COST, flags: int = 0 + ) -> Tuple[int, "Program"]: prog_bytes = bytes(self) args_bytes = bytes(self.to(args)) try: - cost, r = run_serialized_chia_program(prog_bytes, args_bytes, max_cost, 0) + cost, r = run_serialized_chia_program( + prog_bytes, args_bytes, max_cost, flags + ) r = self.wrap(r) except ValueError as ve: raise EvalError(ve.args[0], self.wrap(ve.args[1])) From 0c25e04ab307fedab76541948c0e8730138ae011 Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Wed, 25 Jan 2023 23:14:41 -0800 Subject: [PATCH 31/45] Support py37 --- wheel/python/clvm_rs/casts.py | 8 ++++---- wheel/python/clvm_rs/clvm_storage.py | 21 +++++++++++---------- wheel/python/tests/test_to_program.py | 16 ++++++++++------ 3 files changed, 25 insertions(+), 20 deletions(-) diff --git a/wheel/python/clvm_rs/casts.py b/wheel/python/clvm_rs/casts.py index 7b762e3e..455693fc 100644 --- a/wheel/python/clvm_rs/casts.py +++ b/wheel/python/clvm_rs/casts.py @@ -4,7 +4,7 @@ from typing import Callable, List, SupportsBytes, Tuple, Union -from .clvm_storage import CLVMStorage +from .clvm_storage import CLVMStorage, is_clvm_storage AtomCastableType = Union[ bytes, @@ -98,7 +98,7 @@ def to_clvm_object( # convert value if op == 0: v = to_convert.pop() - if isinstance(v, CLVMStorage): + if is_clvm_storage(v): if v.pair is None: atom = v.atom assert atom is not None @@ -111,8 +111,8 @@ def to_clvm_object( if len(v) != 2: raise ValueError("can't cast tuple of size %d" % len(v)) left, right = v - ll_right = isinstance(right, CLVMStorage) - ll_left = isinstance(left, CLVMStorage) + ll_right = is_clvm_storage(right) + ll_left = is_clvm_storage(left) if ll_right and ll_left: did_convert.append(to_pair_f(left, right)) else: diff --git a/wheel/python/clvm_rs/clvm_storage.py b/wheel/python/clvm_rs/clvm_storage.py index 8965019f..1385089b 100644 --- a/wheel/python/clvm_rs/clvm_storage.py +++ b/wheel/python/clvm_rs/clvm_storage.py @@ -1,4 +1,12 @@ -from typing import Optional, Protocol, Tuple, runtime_checkable +from typing import Optional, Tuple + +# we support py3.7 which doesn't yet have typing.Protocol + +try: + from typing import Protocol, runtime_checkable +except ImportError: + Protocol = object + runtime_checkable = lambda arg: arg @runtime_checkable @@ -15,12 +23,5 @@ def pair(self) -> Optional[Tuple["CLVMStorage", "CLVMStorage"]]: # `_cached_serialization: bytes` is used by `sexp_to_byte_iterator` to speed up serialization -@runtime_checkable -class CLVMStorageFactory(Protocol): - @classmethod - def new_atom(cls, v: bytes) -> "CLVMStorage": - ... - - @classmethod - def new_pair(cls, p1: "CLVMStorage", p2: "CLVMStorage") -> "CLVMStorage": - ... +def is_clvm_storage(obj): + return hasattr(obj, "atom") and hasattr(obj, "pair") diff --git a/wheel/python/tests/test_to_program.py b/wheel/python/tests/test_to_program.py index 70dceb2c..ed6ddeef 100644 --- a/wheel/python/tests/test_to_program.py +++ b/wheel/python/tests/test_to_program.py @@ -1,7 +1,7 @@ import unittest from typing import Optional, Tuple, Any, Union -from clvm_rs.clvm_storage import CLVMStorage +from clvm_rs.clvm_storage import CLVMStorage, is_clvm_storage from clvm_rs.program import Program @@ -17,8 +17,8 @@ def validate_program(program): if v.pair: assert isinstance(v.pair, tuple) v1, v2 = v.pair - assert isinstance(v1, CLVMStorage) - assert isinstance(v2, CLVMStorage) + assert is_clvm_storage(v1) + assert is_clvm_storage(v2) s1, s2 = v.as_pair() validate_stack.append(s1) validate_stack.append(s2) @@ -101,6 +101,10 @@ def pair(self) -> Optional[Tuple[Any, Any]]: GeneratedTree(new_depth, self.val + 2**new_depth), ) + @classmethod + def isinstance(cls, obj): + return isinstance(obj, cls) + tree = Program.to(GeneratedTree(5, 0)) assert ( print_leaves(tree) @@ -128,15 +132,15 @@ class dummy: obj.atom = None obj.pair = None print(dir(obj)) - assert isinstance(obj, CLVMStorage) + assert is_clvm_storage(obj) obj = dummy() obj.pair = None - assert not isinstance(obj, CLVMStorage) + assert not is_clvm_storage(obj) obj = dummy() obj.atom = None - assert not isinstance(obj, CLVMStorage) + assert not is_clvm_storage(obj) def test_list_conversions(self): a = Program.to([1, 2, 3]) From 4f5c9dcb9e747dee32342731a0066d5fe9413e71 Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Wed, 25 Jan 2023 23:35:57 -0800 Subject: [PATCH 32/45] Use api --- benchmark/run-benchmark.py | 1 - tests/run.py | 15 +++++++-------- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/benchmark/run-benchmark.py b/benchmark/run-benchmark.py index 620050f5..03a7717f 100755 --- a/benchmark/run-benchmark.py +++ b/benchmark/run-benchmark.py @@ -206,7 +206,6 @@ def need_update(file_path, mtime): time_start = time.perf_counter() cost, result = program.run_with_cost( - program, env, max_cost, flags, diff --git a/tests/run.py b/tests/run.py index dd231ee7..7e01384d 100755 --- a/tests/run.py +++ b/tests/run.py @@ -1,24 +1,23 @@ #!/usr/bin/env python3 -from clvm_rs import run_serialized_chia_program +from clvm_rs import Program def run_clvm(fn, env=None): - program_data = bytes.fromhex(open(fn, 'r').read()) + program = Program.fromhex(open(fn, 'r').read()) if env is not None: - env_data = bytes.fromhex(open(env, 'r').read()) + env = Program.fromhex(open(env, 'r').read()) else: - env_data = bytes.fromhex("ff80") + env = Program.fromhex("ff80") # constants from the main chia blockchain: # https://github.com/Chia-Network/chia-blockchain/blob/main/chia/consensus/default_constants.py max_cost = 11000000000 cost_per_byte = 12000 - max_cost -= (len(program_data) + len(env_data)) * cost_per_byte - return run_serialized_chia_program( - program_data, - env_data, + max_cost -= (len(bytes(program)) + len(bytes(env))) * cost_per_byte + return program.run_with_cost( + env, max_cost, 0, ) From 3bc775d65ac416cd0f840f37dc8e2e1b6da59c27 Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Fri, 27 Jan 2023 14:42:40 -0800 Subject: [PATCH 33/45] fix comments --- wheel/python/clvm_rs/program.py | 45 +++++++++++++++------------------ 1 file changed, 21 insertions(+), 24 deletions(-) diff --git a/wheel/python/clvm_rs/program.py b/wheel/python/clvm_rs/program.py index c7b7c02a..cbe178e8 100644 --- a/wheel/python/clvm_rs/program.py +++ b/wheel/python/clvm_rs/program.py @@ -197,7 +197,6 @@ def at(self, position: str) -> Optional["Program"]: p1 = Program.to(10) assert None == at(p1, "rr") ``` - """ return self.to(at(self, position)) @@ -247,40 +246,38 @@ def run(self, args) -> "Program": cost, r = self.run_with_cost(args, MAX_COST) return r - """ - Replicates the curry function from clvm_tools, taking advantage of *args - being a list. We iterate through args in reverse building the code to - create a clvm list. - - Given arguments to a function addressable by the '1' reference in clvm + def curry(self, *args) -> "Program": + """ + Replicates the curry function from clvm_tools, taking advantage of *args + being a list. We iterate through args in reverse building the code to + create a clvm list. - fixed_args = 1 + Given arguments to a function addressable by the '1' reference in clvm - Each arg is prepended as fixed_args = (c (q . arg) fixed_args) + fixed_args = 1 - The resulting argument list is interpreted with apply (2) + Each arg is prepended as fixed_args = (c (q . arg) fixed_args) - (2 (1 . self) rest) + The resulting argument list is interpreted with apply (2) - Resulting in a function which places its own arguments after those - curried in in the form of a proper list. - """ + (2 (1 . self) rest) - def curry(self, *args) -> "Program": + Resulting in a function which places its own arguments after those + curried in in the form of a proper list. + """ return self.to(self.curry_treehasher.curry(self, *args)) - """ - uncurry the given program + def uncurry(self) -> Tuple[Program, Optional[List[Program]]]: + """ + uncurry the given program - returns `mod, [arg1, arg2, ...]` + returns `mod, [arg1, arg2, ...]` - if the program is not a valid curry, return `self, NULL` + if the program is not a valid curry, return `self, NULL` - This distinguishes it from the case of a valid curry of 0 arguments - (which is rather pointless but possible), which returns `self, []` - """ - - def uncurry(self) -> Tuple[Program, Optional[List[Program]]]: + This distinguishes it from the case of a valid curry of 0 arguments + (which is rather pointless but possible), which returns `self, []` + """ mod, args = self.curry_treehasher.uncurry(self) p_args = args if args is None else [self.to(_) for _ in args] return self.to(mod), p_args From 0ffa864383d7c75e36bf29d324d3493d5b9fa531 Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Tue, 31 Jan 2023 16:16:58 -0800 Subject: [PATCH 34/45] lint --- wheel/python/clvm_rs/__init__.py | 4 +- wheel/python/clvm_rs/casts.py | 6 ++- wheel/python/clvm_rs/clvm_rs.pyi | 23 +++++++++ wheel/python/clvm_rs/clvm_storage.py | 8 ++-- wheel/python/clvm_rs/clvm_tree.py | 11 +++-- wheel/python/clvm_rs/curry_and_treehash.py | 4 ++ wheel/python/clvm_rs/de.py | 34 +++++++------- wheel/python/clvm_rs/program.py | 47 +++++++++++++------ wheel/python/clvm_rs/ser.py | 6 +-- wheel/python/clvm_rs/tree_hash.py | 14 ++---- wheel/python/tests/test_as_python.py | 12 ++++- wheel/python/tests/test_curry_and_treehash.py | 10 ++++ wheel/python/tests/test_program.py | 5 +- wheel/python/tests/test_to_program.py | 12 +++-- 14 files changed, 132 insertions(+), 64 deletions(-) create mode 100644 wheel/python/clvm_rs/clvm_rs.pyi diff --git a/wheel/python/clvm_rs/__init__.py b/wheel/python/clvm_rs/__init__.py index fcd136c9..1648893c 100644 --- a/wheel/python/clvm_rs/__init__.py +++ b/wheel/python/clvm_rs/__init__.py @@ -1,2 +1,2 @@ -from .eval_error import EvalError -from .program import Program +from .eval_error import EvalError # noqa: F401 +from .program import Program # noqa: F401 diff --git a/wheel/python/clvm_rs/casts.py b/wheel/python/clvm_rs/casts.py index 455693fc..960947d8 100644 --- a/wheel/python/clvm_rs/casts.py +++ b/wheel/python/clvm_rs/casts.py @@ -2,7 +2,7 @@ Some utilities to cast python types to and from clvm. """ -from typing import Callable, List, SupportsBytes, Tuple, Union +from typing import Callable, List, SupportsBytes, Tuple, Union, cast from .clvm_storage import CLVMStorage, is_clvm_storage @@ -99,6 +99,7 @@ def to_clvm_object( if op == 0: v = to_convert.pop() if is_clvm_storage(v): + v = cast(CLVMStorage, v) if v.pair is None: atom = v.atom assert atom is not None @@ -114,6 +115,8 @@ def to_clvm_object( ll_right = is_clvm_storage(right) ll_left = is_clvm_storage(left) if ll_right and ll_left: + left = cast(CLVMStorage, left) + right = cast(CLVMStorage, right) did_convert.append(to_pair_f(left, right)) else: ops.append(1) # cons @@ -134,6 +137,7 @@ def to_clvm_object( to_convert.append(_) ops.append(0) # convert continue + v = cast(AtomCastableType, v) did_convert.append(to_atom_f(to_atom_type(v))) continue if op == 1: # cons diff --git a/wheel/python/clvm_rs/clvm_rs.pyi b/wheel/python/clvm_rs/clvm_rs.pyi new file mode 100644 index 00000000..166a1adf --- /dev/null +++ b/wheel/python/clvm_rs/clvm_rs.pyi @@ -0,0 +1,23 @@ +from typing import List, Optional, Tuple + +from .clvm_storage import CLVMStorage + +def run_serialized_chia_program( + program: bytes, environment: bytes, max_cost: int, flags: int +) -> Tuple[int, CLVMStorage]: ... +def deserialize_as_tree( + blob: bytes, calculate_tree_hashes: bool +) -> Tuple[List[Tuple[int, int, int]], Optional[List[bytes]]]: ... +def serialized_length(blob: bytes) -> int: ... + +NO_NEG_DIV: int +NO_UNKNOWN_OPS: int +LIMIT_HEAP: int +LIMIT_STACK: int +MEMPOOL_MODE: int + +class LazyNode(CLVMStorage): + atom: Optional[bytes] + + @property + def pair(self) -> Optional[Tuple[CLVMStorage, CLVMStorage]]: ... diff --git a/wheel/python/clvm_rs/clvm_storage.py b/wheel/python/clvm_rs/clvm_storage.py index 1385089b..65e63507 100644 --- a/wheel/python/clvm_rs/clvm_storage.py +++ b/wheel/python/clvm_rs/clvm_storage.py @@ -1,15 +1,13 @@ -from typing import Optional, Tuple +from typing import Optional, Tuple, _SpecialForm, cast # we support py3.7 which doesn't yet have typing.Protocol try: - from typing import Protocol, runtime_checkable + from typing import Protocol except ImportError: - Protocol = object - runtime_checkable = lambda arg: arg + Protocol = cast(_SpecialForm, object) -@runtime_checkable class CLVMStorage(Protocol): atom: Optional[bytes] diff --git a/wheel/python/clvm_rs/clvm_tree.py b/wheel/python/clvm_rs/clvm_tree.py index 6c894589..bca89a55 100644 --- a/wheel/python/clvm_rs/clvm_tree.py +++ b/wheel/python/clvm_rs/clvm_tree.py @@ -1,7 +1,7 @@ from .clvm_storage import CLVMStorage from .de import deserialize_as_tuples -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union class CLVMTree(CLVMStorage): @@ -48,21 +48,22 @@ def from_bytes(cls, blob: bytes, calculate_tree_hash: bool = True) -> "CLVMTree" def __init__( self, - blob: bytes, + blob: Union[memoryview, bytes], int_tuples: List[Tuple[int, int, int]], - tree_hashes: List[Optional[bytes]], + tree_hashes: Optional[List[bytes]], index: int, ): self.blob = blob self.int_tuples = int_tuples self.tree_hashes = tree_hashes self.index = index - self._cached_sha256_treehash = self.tree_hashes[index] + if self.tree_hashes: + self._cached_sha256_treehash = self.tree_hashes[index] start, end, atom_offset = self.int_tuples[self.index] if self.blob[start] == 0xFF: self.atom = None else: - self.atom = bytes(self.blob[start + atom_offset : end]) + self.atom = bytes(self.blob[start + atom_offset:end]) self._pair = None @property diff --git a/wheel/python/clvm_rs/curry_and_treehash.py b/wheel/python/clvm_rs/curry_and_treehash.py index b8d0c266..1c67993e 100644 --- a/wheel/python/clvm_rs/curry_and_treehash.py +++ b/wheel/python/clvm_rs/curry_and_treehash.py @@ -56,6 +56,10 @@ def curry_and_treehash( `arguments` : tree hashes of arguments to be curried """ + for arg in hashed_arguments: + if not isinstance(arg, bytes) or len(arg) != 32: + raise ValueError(f"arguments must be bytes of len 32: {arg.hex()}") + curried_values = self.curried_values_tree_hash(list(hashed_arguments)) return shatree_pair( self.a_kw_treehash, diff --git a/wheel/python/clvm_rs/de.py b/wheel/python/clvm_rs/de.py index aeb8bdb6..fcd9b5b6 100644 --- a/wheel/python/clvm_rs/de.py +++ b/wheel/python/clvm_rs/de.py @@ -2,6 +2,10 @@ from .tree_hash import shatree_atom, shatree_pair +deserialize_as_tree: Optional[ + Callable[[bytes, bool], Tuple[List[Tuple[int, int, int]], Optional[List[bytes]]]] +] + try: from clvm_rs.clvm_rs import deserialize_as_tree except ImportError: @@ -21,15 +25,13 @@ def deserialize_as_tuples( blob: bytes, cursor: int, calculate_tree_hash: bool -) -> Tuple[List[Tuple[int, int, int]], List[Optional[bytes]]]: +) -> Tuple[List[Tuple[int, int, int]], Optional[List[bytes]]]: if deserialize_as_tree: try: tree, hashes = deserialize_as_tree(blob, calculate_tree_hash) except OSError as ex: raise ValueError(ex) - if not calculate_tree_hash: - hashes = [None] * len(tree) return tree, hashes def save_cursor( @@ -41,17 +43,14 @@ def save_cursor( ) -> int: blob_index = obj_list[index][0] assert blob[blob_index] == 0xFF - left_hash = tree_hash_list[index + 1] - hash_index = obj_list[index][2] - right_hash = tree_hash_list[hash_index] - tree_hash_list[index] = None - if calculate_tree_hash: - assert left_hash is not None - assert right_hash is not None - tree_hash_list[index] = shatree_pair(left_hash, right_hash) v0 = obj_list[index][0] v2 = obj_list[index][2] obj_list[index] = (v0, cursor, v2) + if calculate_tree_hash: + left_hash = tree_hash_list[index + 1] + hash_index = obj_list[index][2] + right_hash = tree_hash_list[hash_index] + tree_hash_list[index] = shatree_pair(left_hash, right_hash) return cursor def save_index( @@ -73,28 +72,29 @@ def parse_obj( if blob[cursor] == CONS_BOX_MARKER: index = len(obj_list) - tree_hash_list.append(None) obj_list.append((cursor, 0, 0)) op_stack.append(lambda *args: save_cursor(index, *args)) op_stack.append(parse_obj) op_stack.append(lambda *args: save_index(index, *args)) op_stack.append(parse_obj) + if calculate_tree_hash: + tree_hash_list.append(b"") return cursor + 1 atom_offset, new_cursor = _atom_size_from_cursor(blob, cursor) my_hash = None if calculate_tree_hash: - my_hash = shatree_atom(blob[cursor + atom_offset : new_cursor]) - tree_hash_list.append(my_hash) + my_hash = shatree_atom(blob[cursor + atom_offset:new_cursor]) + tree_hash_list.append(my_hash) obj_list.append((cursor, new_cursor, atom_offset)) return new_cursor obj_list: List[Triple] = [] - tree_hash_list: List[Optional[bytes]] = [] + tree_hash_list: List[bytes] = [] op_stack: List[DeserOp] = [parse_obj] while op_stack: f = op_stack.pop() cursor = f(blob, cursor, obj_list, op_stack) - return obj_list, tree_hash_list + return obj_list, tree_hash_list if calculate_tree_hash else None def _atom_size_from_cursor(blob, cursor) -> Tuple[int, int]: @@ -112,7 +112,7 @@ def _atom_size_from_cursor(blob, cursor) -> Tuple[int, int]: bit_mask >>= 1 size_blob = bytes([b]) if bit_count > 1: - size_blob += blob[cursor + 1 : cursor + bit_count] + size_blob += blob[cursor + 1:cursor + bit_count] size = int.from_bytes(size_blob, "big") new_cursor = cursor + size + bit_count if new_cursor > len(blob): diff --git a/wheel/python/clvm_rs/program.py b/wheel/python/clvm_rs/program.py index cbe178e8..af1b7f0e 100644 --- a/wheel/python/clvm_rs/program.py +++ b/wheel/python/clvm_rs/program.py @@ -19,7 +19,7 @@ class Program(CLVMStorage): """ - A thin wrapper around s-expression data intended to be invoked with "eval". + A wrapper around `CLVMStorage` providing many convenience functions. """ curry_treehasher: CurryTreehasher = CHIA_CURRY_TREEHASHER @@ -234,10 +234,10 @@ def run_with_cost( prog_bytes = bytes(self) args_bytes = bytes(self.to(args)) try: - cost, r = run_serialized_chia_program( + cost, lazy_node = run_serialized_chia_program( prog_bytes, args_bytes, max_cost, flags ) - r = self.wrap(r) + r = self.wrap(lazy_node) except ValueError as ve: raise EvalError(ve.args[0], self.wrap(ve.args[1])) return cost, r @@ -260,7 +260,7 @@ def curry(self, *args) -> "Program": The resulting argument list is interpreted with apply (2) - (2 (1 . self) rest) + (a (q . self) rest) Resulting in a function which places its own arguments after those curried in in the form of a proper list. @@ -282,8 +282,32 @@ def uncurry(self) -> Tuple[Program, Optional[List[Program]]]: p_args = args if args is None else [self.to(_) for _ in args] return self.to(mod), p_args + def curry_hash(self, *args: bytes32) -> bytes32: + """ + Return a puzzle hash that would be created if you curried this puzzle + with arguments that have the given hashes. + + In other words, + + ``` + c1 = self.curry(arg1, arg2, arg3).tree_hash() + c2 = self.curry_hash(arg1.tree_hash(), arg2.tree_hash(), arg3.tree_hash()) + assert c1 == c2 # they will be the same + ``` + + This looks useless to the unitiated, but sometimes you'll need a puzzle + hash where you don't actually know the contents of a clvm subtree -- just its + hash. This lets you calculate the puzzle hash with hidden information. + """ + curry_treehasher = self.curry_treehasher + quoted_mod_hash = curry_treehasher.calculate_hash_of_quoted_mod_hash(self.tree_hash()) + return curry_treehasher.curry_and_treehash(quoted_mod_hash, *args) + def as_int(self) -> int: - return int_from_bytes(self.as_atom()) + v = self.as_atom() + if v is None: + raise ValueError("can't cast pair to int") + return int_from_bytes(v) def as_iter(self) -> Iterator[Program]: v = self @@ -293,11 +317,8 @@ def as_iter(self) -> Iterator[Program]: def as_atom_iter(self) -> Iterator[bytes]: """ - Pretend `self` is a list of atoms. Yield the corresponding atoms. - - At each step, we always assume a node to be an atom or a pair. - If the assumption is wrong, we exit early. This way we never fail - and always return SOMETHING. + Pretend `self` is a list of atoms. Yield the corresponding atoms + up until this assumption is wrong. """ obj = self while obj.pair is not None: @@ -310,11 +331,7 @@ def as_atom_iter(self) -> Iterator[bytes]: def as_atom_list(self) -> List[bytes]: """ Pretend `self` is a list of atoms. Return the corresponding - python list of atoms. - - At each step, we always assume a node to be an atom or a pair. - If the assumption is wrong, we exit early. This way we never fail - and always return SOMETHING. + python list of atoms up until this assumption is wrong. """ return list(self.as_atom_iter()) diff --git a/wheel/python/clvm_rs/ser.py b/wheel/python/clvm_rs/ser.py index 13d3579d..654c099f 100644 --- a/wheel/python/clvm_rs/ser.py +++ b/wheel/python/clvm_rs/ser.py @@ -46,7 +46,7 @@ def sexp_to_byte_iterator(sexp: CLVMStorage) -> Iterator[bytes]: yield from atom_to_byte_iterator(atom) -def size_blob_for_size(blob: int) -> bytes: +def size_blob_for_blob(blob: bytes) -> bytes: size = len(blob) if size < 0x40: return bytes([0x80 | size]) @@ -73,7 +73,7 @@ def size_blob_for_size(blob: int) -> bytes: (size >> 0) & 0xFF, ] ) - raise ValueError("sexp too long %r" % blob) + raise ValueError("blob too long %r" % blob) def atom_to_byte_iterator(as_atom: bytes) -> Iterator[bytes]: @@ -88,7 +88,7 @@ def atom_to_byte_iterator(as_atom: bytes) -> Iterator[bytes]: if as_atom[0] <= MAX_SINGLE_BYTE: yield as_atom return - yield size_blob_for_size(as_atom) + yield size_blob_for_blob(as_atom) yield as_atom diff --git a/wheel/python/clvm_rs/tree_hash.py b/wheel/python/clvm_rs/tree_hash.py index b26257cb..cfbf526e 100644 --- a/wheel/python/clvm_rs/tree_hash.py +++ b/wheel/python/clvm_rs/tree_hash.py @@ -8,12 +8,10 @@ from hashlib import sha256 from typing import List -from weakref import WeakKeyDictionary +from .bytes32 import bytes32 from .clvm_storage import CLVMStorage -bytes32 = bytes - class Treehasher: atom_prefix: bytes @@ -22,7 +20,6 @@ class Treehasher: def __init__(self, atom_prefix: bytes, pair_prefix: bytes): self.atom_prefix = atom_prefix self.pair_prefix = pair_prefix - self.hash_cache: WeakKeyDictionary[CLVMStorage, bytes32] = WeakKeyDictionary() self.cache_hits = 0 def shatree_atom(self, atom: bytes) -> bytes32: @@ -42,8 +39,6 @@ def sha256_treehash(self, sexp: CLVMStorage) -> bytes32: def handle_sexp(sexp_stack, hash_stack, op_stack) -> None: sexp = sexp_stack.pop() r = getattr(sexp, "_cached_sha256_treehash", None) - if r is None: - r = self.hash_cache.get(sexp) if r is not None: self.cache_hits += 1 hash_stack.append(r) @@ -59,15 +54,16 @@ def handle_sexp(sexp_stack, hash_stack, op_stack) -> None: r = shatree_atom(sexp.atom) hash_stack.append(r) sexp._cached_sha256_treehash = r - self.hash_cache[sexp] = r def handle_pair(sexp_stack, hash_stack, op_stack) -> None: p0 = hash_stack.pop() p1 = hash_stack.pop() r = shatree_pair(p0, p1) hash_stack.append(r) - sexp._cached_sha256_treehash = r - self.hash_cache[sexp] = r + try: + setattr(sexp, "_cached_sha256_treehash", r) + except AttributeError: + pass sexp_stack = [sexp] op_stack = [handle_sexp] diff --git a/wheel/python/tests/test_as_python.py b/wheel/python/tests/test_as_python.py index fe38ca30..b33a4672 100644 --- a/wheel/python/tests/test_as_python.py +++ b/wheel/python/tests/test_as_python.py @@ -2,7 +2,13 @@ from clvm_rs.program import Program -from blspy import G1Element + +class castable_to_bytes: + def __init__(self, blob): + self.blob = blob + + def __bytes__(self): + return self.blob class dummy_class: @@ -72,11 +78,13 @@ def test_list_of_one(self): self.assertEqual(v.pair[1].atom, b"") def test_g1element(self): + # we don't import `G1Element` here, but just replicate + # its most important property: a `__bytes__` method b = fh( "b3b8ac537f4fd6bde9b26221d49b54b17a506be147347dae5" "d081c0a6572b611d8484e338f3432971a9823976c6a232b" ) - v = Program.to(G1Element(b)) + v = Program.to(castable_to_bytes(b)) self.assertEqual(v.atom, b) def test_complex(self): diff --git a/wheel/python/tests/test_curry_and_treehash.py b/wheel/python/tests/test_curry_and_treehash.py index 6b651011..18cf2f47 100644 --- a/wheel/python/tests/test_curry_and_treehash.py +++ b/wheel/python/tests/test_curry_and_treehash.py @@ -1,3 +1,5 @@ +import pytest + from clvm_rs.program import Program from clvm_rs.curry_and_treehash import CHIA_CURRY_TREEHASHER @@ -26,3 +28,11 @@ def test_curry_and_treehash() -> None: hashed_args = [Program.to(_).tree_hash() for _ in args] puzzle_hash_via_f = curry_and_treehash(quoted_mod_hash, *hashed_args) assert puzzle_hash_via_curry == puzzle_hash_via_f + puzzle_hash_via_m = arbitrary_mod.curry_hash(*hashed_args) + assert puzzle_hash_via_curry == puzzle_hash_via_m + + +def test_bad_parameter() -> None: + arbitrary_mod = Program.fromhex("ff10ff02ff0580") # `(+ 2 5)` + with pytest.raises(ValueError): + arbitrary_mod.curry_hash(b"foo") diff --git a/wheel/python/tests/test_program.py b/wheel/python/tests/test_program.py index 6e94983e..92c006fa 100644 --- a/wheel/python/tests/test_program.py +++ b/wheel/python/tests/test_program.py @@ -1,7 +1,6 @@ from unittest import TestCase from clvm_rs.chia_dialect import CHIA_DIALECT -from clvm_rs.clvm_storage import CLVMStorage from clvm_rs.program import Program from clvm_rs.eval_error import EvalError @@ -74,6 +73,10 @@ def test_run_exception(self): self.assertEqual(err.args, ("clvm raise",)) self.assertEqual(err._sexp, ["foo", "bar"]) + def test_hash(self): + p1 = Program.fromhex("80") + assert hash(p1) == id(p1) + def check_idempotency(p, *args): curried = p.curry(*args) diff --git a/wheel/python/tests/test_to_program.py b/wheel/python/tests/test_to_program.py index ed6ddeef..4ce6c618 100644 --- a/wheel/python/tests/test_to_program.py +++ b/wheel/python/tests/test_to_program.py @@ -1,7 +1,7 @@ import unittest from typing import Optional, Tuple, Any, Union -from clvm_rs.clvm_storage import CLVMStorage, is_clvm_storage +from clvm_rs.clvm_storage import is_clvm_storage from clvm_rs.program import Program @@ -34,7 +34,9 @@ def print_leaves(tree: Program) -> str: return "%d " % a[0] ret = "" - for i in tree.as_pair(): + pair = tree.as_pair() + assert pair is not None + for i in pair: ret += print_leaves(i) return ret @@ -47,8 +49,10 @@ def print_tree(tree: Program) -> str: return "() " return "%d " % a[0] + pair = tree.as_pair() + assert pair is not None ret = "(" - for i in tree.as_pair(): + for i in pair: ret += print_tree(i) ret += ")" return ret @@ -69,7 +73,7 @@ def test_wrap_program(self): o = Program.to(Program.to(1)) assert o.atom == bytes([1]) - def test_arbitrary_underlying_tree(self): + def test_arbitrary_underlying_tree(self) -> None: # Program provides a view on top of a tree of arbitrary types, as long as # those types implement the CLVMStorage protocol. This is an example of From c1724d71e5fc49b71139e08eb1b17974e469a72f Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Fri, 3 Feb 2023 13:53:00 -0800 Subject: [PATCH 35/45] Tests pass, coverage seems good, benchmarks seem good. --- wheel/python/benchmarks/deserialization.py | 195 +++++++++++------- wheel/python/clvm_rs/at.py | 5 +- wheel/python/clvm_rs/casts.py | 8 +- wheel/python/clvm_rs/clvm_tree.py | 10 +- wheel/python/clvm_rs/program.py | 31 ++- wheel/python/clvm_rs/tree_hash.py | 53 +++-- wheel/python/tests/test_as_python.py | 2 + wheel/python/tests/test_curry_and_treehash.py | 7 +- wheel/python/tests/test_program.py | 5 +- wheel/python/tests/test_to_program.py | 112 +++++++++- 10 files changed, 302 insertions(+), 126 deletions(-) diff --git a/wheel/python/benchmarks/deserialization.py b/wheel/python/benchmarks/deserialization.py index 3884b777..410e6a0c 100644 --- a/wheel/python/benchmarks/deserialization.py +++ b/wheel/python/benchmarks/deserialization.py @@ -1,91 +1,128 @@ import io +import pathlib import time from clvm_rs.program import Program - from clvm_rs.clvm_rs import serialized_length -def bench(f, name: str): +def bench(f, name: str, allow_slow=False): + r, t = bench_w_speed(f, name) + if not allow_slow and t > 0.01: + print("*** TOO SLOW") + print() + return r + + +def bench_w_speed(f, name: str): start = time.time() r = f() end = time.time() d = end - start print(f"{name}: {d:1.4f} s") - print() - return r - - -sha_prog = Program.fromhex("ff0bff0180") - -print(sha_prog.run("food")) -# breakpoint() - - -obj = bench(lambda: Program.parse(open("block-2500014.compressed.bin", "rb")), "obj = Program.parse(open([file]))") -bench(lambda: bytes(obj), "bytes(obj)") - -obj1 = bench( - lambda: Program.from_bytes(open("block-2500014.compressed.bin", "rb").read()), - "obj = Program.from_bytes([blob])", -) -bench(lambda: bytes(obj1), "bytes(obj)") - -cost, output = bench(lambda: obj.run_with_cost(0), "run") - -print(f"cost = {cost}") -blob = bench(lambda: print(f"output = {len(bytes(output))}"), "serialize LazyNode") -blob = bench(lambda: bytes(output), "serialize LazyNode again") - -bench(lambda: print(output.tree_hash().hex()), "print run tree hash LazyNode") -bench(lambda: print(output.tree_hash().hex()), "print run tree hash again LazyNode") - -des_output = bench( - lambda: Program.from_bytes(blob), "from_bytes output (with tree hashing)" -) -bench(lambda: print(des_output.tree_hash().hex()), "print from_bytes tree hash") -bench(lambda: print(des_output.tree_hash().hex()), "print from_bytes tree hash again") - -bench(lambda: print(serialized_length(blob)), "print serialized_length") - -des_output = bench( - lambda: Program.from_bytes(blob, calculate_tree_hash=False), - "from_bytes output (with no tree hashing)", -) -bench(lambda: print(des_output.tree_hash().hex()), "print from_bytes tree hash") -bench(lambda: print(des_output.tree_hash().hex()), "print from_bytes tree hash again") - -reparsed_output = bench(lambda: Program.parse(io.BytesIO(blob)), "reparse output") -bench(lambda: print(reparsed_output.tree_hash().hex()), "print reparsed tree hash") -bench(lambda: print(reparsed_output.tree_hash().hex()), "print reparsed tree hash again") - - -foo = Program.to("foo") -o0 = Program.to((foo, obj)) -o1 = Program.to((foo, obj1)) - - -def compare(): - assert o0 == o1 - - -bench(compare, "compare") - -bench(lambda: bytes(o0), "to_bytes o0") -bench(lambda: bytes(o1), "to_bytes o1") - -bench(lambda: print(o0.tree_hash().hex()), "o0 tree_hash") -bench(lambda: print(o0.tree_hash().hex()), "o0 tree_hash (again)") - -bench(lambda: print(o1.tree_hash().hex()), "o1 tree_hash") -bench(lambda: print(o1.tree_hash().hex()), "o1 tree_hash (again)") - -o2 = Program.to((foo, output)) - -bench(lambda: print(o2.tree_hash().hex()), "o2 tree_hash") -bench(lambda: print(o2.tree_hash().hex()), "o2 tree_hash (again)") - -# start = time.time() -# obj1 = sexp_from_stream(io.BytesIO(out), SExp.to, allow_backrefs=True) -# end = time.time() -# print(end-start) + return r, d + + +def benchmark(): + block_path = pathlib.Path(__file__).parent / "block-2500014.compressed.bin" + obj = bench( + lambda: Program.parse(open(block_path, "rb")), + "obj = Program.parse(open([block_blob]))", + ) + bench(lambda: bytes(obj), "bytes(obj)") + + block_blob = open(block_path, "rb").read() + obj1 = bench( + lambda: Program.from_bytes(block_blob), + "obj = Program.from_bytes([block_blob])", + ) + bench(lambda: bytes(obj1), "bytes(obj)") + + cost, output = bench(lambda: obj.run_with_cost(0), "run", allow_slow=True) + print(f"cost = {cost}") + result_blob = bench( + lambda: bytes(output), + "serialize LazyNode", + allow_slow=True, + ) + print(f"output = {len(result_blob)}"), + + result_blob_2 = bench(lambda: bytes(output), "serialize LazyNode again") + assert result_blob == result_blob_2 + + bench( + lambda: print(output.tree_hash().hex()), + "tree hash LazyNode", + allow_slow=True, + ) + bench(lambda: print(output.tree_hash().hex()), "tree hash again LazyNode") + + des_output = bench( + lambda: Program.from_bytes(result_blob), + "from_bytes (with tree hashing)", + allow_slow=True, + ) + bench(lambda: des_output.tree_hash(), "from_bytes (with tree hashing) tree hash") + bench( + lambda: des_output.tree_hash(), "from_bytes (with tree hashing) tree hash again" + ) + + bench(lambda: serialized_length(result_blob), "serialized_length") + + des_output = bench( + lambda: Program.from_bytes(result_blob, calculate_tree_hash=False), + "from_bytes output (without tree hashing)", + allow_slow=True, + ) + bench( + lambda: des_output.tree_hash(), + "from_bytes (without tree hashing) tree hash", + allow_slow=True, + ) + bench( + lambda: des_output.tree_hash(), + "from_bytes (without tree hashing) tree hash again", + ) + + reparsed_output = bench( + lambda: Program.parse(io.BytesIO(result_blob)), + "reparse output", + allow_slow=True, + ) + bench(lambda: reparsed_output.tree_hash(), "reparsed tree hash", allow_slow=True) + bench( + lambda: reparsed_output.tree_hash(), + "reparsed tree hash again", + ) + + foo = Program.to("foo") + o0 = Program.to((foo, obj)) + o1 = Program.to((foo, obj1)) + + def compare(): + assert o0 == o1 + + bench(compare, "compare") + + bench(lambda: bytes(o0), "to_bytes o0") + bench(lambda: bytes(o1), "to_bytes o1") + + bench(lambda: print(o0.tree_hash().hex()), "o0 tree_hash") + bench(lambda: print(o0.tree_hash().hex()), "o0 tree_hash (again)") + + bench(lambda: print(o1.tree_hash().hex()), "o1 tree_hash") + bench(lambda: print(o1.tree_hash().hex()), "o1 tree_hash (again)") + + o2 = Program.to((foo, output)) + + bench(lambda: print(o2.tree_hash().hex()), "o2 tree_hash") + bench(lambda: print(o2.tree_hash().hex()), "o2 tree_hash (again)") + + # start = time.time() + # obj1 = sexp_from_stream(io.BytesIO(out), SExp.to, allow_backrefs=False) + # end = time.time() + # print(end-start) + + +if __name__ == "__main__": + benchmark() diff --git a/wheel/python/clvm_rs/at.py b/wheel/python/clvm_rs/at.py index dcfa1370..bd82b5e1 100644 --- a/wheel/python/clvm_rs/at.py +++ b/wheel/python/clvm_rs/at.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Optional +from typing import Optional, cast from .clvm_storage import CLVMStorage @@ -25,8 +25,7 @@ def at(obj: CLVMStorage, position: str) -> Optional[CLVMStorage]: """ v: Optional[CLVMStorage] = obj for c in position.lower(): - if v is None: - break + v = cast(CLVMStorage, v) pair = v.pair if pair is None: return None diff --git a/wheel/python/clvm_rs/casts.py b/wheel/python/clvm_rs/casts.py index 960947d8..778c04dc 100644 --- a/wheel/python/clvm_rs/casts.py +++ b/wheel/python/clvm_rs/casts.py @@ -100,13 +100,7 @@ def to_clvm_object( v = to_convert.pop() if is_clvm_storage(v): v = cast(CLVMStorage, v) - if v.pair is None: - atom = v.atom - assert atom is not None - new_obj = to_atom_f(to_atom_type(atom)) - else: - new_obj = to_pair_f(v.pair[0], v.pair[1]) - did_convert.append(new_obj) + did_convert.append(v) continue if isinstance(v, tuple): if len(v) != 2: diff --git a/wheel/python/clvm_rs/clvm_tree.py b/wheel/python/clvm_rs/clvm_tree.py index bca89a55..a2f85994 100644 --- a/wheel/python/clvm_rs/clvm_tree.py +++ b/wheel/python/clvm_rs/clvm_tree.py @@ -72,12 +72,10 @@ def pair(self) -> Optional[Tuple["CLVMStorage", "CLVMStorage"]]: tuples, tree_hashes = self.int_tuples, self.tree_hashes start, end, right_index = tuples[self.index] # if `self.blob[start]` is 0xff, it's a pair - if self.blob[start] == 0xFF: - left = self.__class__(self.blob, tuples, tree_hashes, self.index + 1) - right = self.__class__(self.blob, tuples, tree_hashes, right_index) - self._pair = (left, right) - else: - self._pair = None + assert self.blob[start] == 0xFF + left = self.__class__(self.blob, tuples, tree_hashes, self.index + 1) + right = self.__class__(self.blob, tuples, tree_hashes, right_index) + self._pair = (left, right) return self._pair @property diff --git a/wheel/python/clvm_rs/program.py b/wheel/python/clvm_rs/program.py index af1b7f0e..81fd97eb 100644 --- a/wheel/python/clvm_rs/program.py +++ b/wheel/python/clvm_rs/program.py @@ -23,6 +23,7 @@ class Program(CLVMStorage): """ curry_treehasher: CurryTreehasher = CHIA_CURRY_TREEHASHER + _cached_serialization: Optional[bytes] # serialization/deserialization @@ -44,7 +45,9 @@ def from_bytes(cls, blob: bytes, calculate_tree_hash: bool = True) -> Program: def from_bytes_with_cursor( cls, blob: bytes, cursor: int, calculate_tree_hash: bool = True ) -> Tuple[Program, int]: - tree = CLVMTree.from_bytes(blob, calculate_tree_hash=calculate_tree_hash) + tree = CLVMTree.from_bytes( + blob[cursor:], calculate_tree_hash=calculate_tree_hash + ) obj = cls.wrap(tree) new_cursor = len(bytes(tree)) + cursor return obj, new_cursor @@ -54,7 +57,11 @@ def fromhex(cls, hexstr: str) -> Program: return cls.from_bytes(bytes.fromhex(hexstr)) def __bytes__(self) -> bytes: - return sexp_to_bytes(self) + if self._cached_serialization is None: + self._cached_serialization = sexp_to_bytes(self) + if not isinstance(self._cached_serialization, bytes): + self._cached_serialization = bytes(self._cached_serialization) + return self._cached_serialization def __int__(self) -> int: return self.as_int() @@ -75,7 +82,9 @@ def int_to_bytes(cls, i: int) -> bytes: def __init__(self): self.atom = b"" self._pair = None + self._unwrapped = self self._unwrapped_pair = None + self._cached_serialization = None self._cached_sha256_treehash = None @property @@ -96,7 +105,10 @@ def wrap(cls, v: CLVMStorage) -> Program: o = cls() o.atom = v.atom o._pair = None + o._unwrapped = v o._unwrapped_pair = v.pair + o._cached_serialization = getattr(v, "_cached_serialization", None) + o._cached_sha256_treehash = getattr(v, "_cached_sha256_treehash", None) return o # new object creation on the python heap @@ -124,7 +136,10 @@ def null(cls) -> Program: # display def __str__(self) -> str: - return bytes(self).hex() + s = bytes(self).hex() + if len(s) > 76: + s = f"{s[:70]}...{s[-6:]}" + return s def __repr__(self) -> str: return f"{self.__class__.__name__}({str(self)})" @@ -226,7 +241,11 @@ def replace(self, **kwargs) -> "Program": return self.to(replace(self, **kwargs)) def tree_hash(self) -> bytes32: - return sha256_treehash(self) + # we operate on the unwrapped version to prevent the re-wrapping that + # happens on each invocation of `Program.pair` whenever possible + if self._cached_sha256_treehash is None: + self._cached_sha256_treehash = sha256_treehash(self._unwrapped) + return self._cached_sha256_treehash def run_with_cost( self, args, max_cost: int = MAX_COST, flags: int = 0 @@ -300,7 +319,9 @@ def curry_hash(self, *args: bytes32) -> bytes32: hash. This lets you calculate the puzzle hash with hidden information. """ curry_treehasher = self.curry_treehasher - quoted_mod_hash = curry_treehasher.calculate_hash_of_quoted_mod_hash(self.tree_hash()) + quoted_mod_hash = curry_treehasher.calculate_hash_of_quoted_mod_hash( + self.tree_hash() + ) return curry_treehasher.curry_and_treehash(quoted_mod_hash, *args) def as_int(self) -> int: diff --git a/wheel/python/clvm_rs/tree_hash.py b/wheel/python/clvm_rs/tree_hash.py index cfbf526e..c7a85d1d 100644 --- a/wheel/python/clvm_rs/tree_hash.py +++ b/wheel/python/clvm_rs/tree_hash.py @@ -14,8 +14,18 @@ class Treehasher: + """ + `Treehasher` performs the standard sha256tree hashing in a non-recursive + way so that extremely large objects don't blow out the python stack. + + We also force a `_cached_sha256_treehash` into the hashed sub-objects + whenever possible so that taking the hash of the same sub-tree is + more efficient in future. + """ + atom_prefix: bytes pair_prefix: bytes + cache_hits: int def __init__(self, atom_prefix: bytes, pair_prefix: bytes): self.atom_prefix = atom_prefix @@ -35,42 +45,47 @@ def shatree_pair(self, left_hash: bytes32, right_hash: bytes32) -> bytes32: s.update(right_hash) return bytes32(s.digest()) - def sha256_treehash(self, sexp: CLVMStorage) -> bytes32: - def handle_sexp(sexp_stack, hash_stack, op_stack) -> None: - sexp = sexp_stack.pop() - r = getattr(sexp, "_cached_sha256_treehash", None) + def sha256_treehash(self, clvm_storage: CLVMStorage) -> bytes32: + def handle_obj(obj_stack, hash_stack, op_stack) -> None: + obj = obj_stack.pop() + r = getattr(obj, "_cached_sha256_treehash", None) if r is not None: self.cache_hits += 1 hash_stack.append(r) return - elif sexp.pair: - p0, p1 = sexp.pair - sexp_stack.append(p0) - sexp_stack.append(p1) - op_stack.append(handle_pair) - op_stack.append(handle_sexp) - op_stack.append(handle_sexp) - else: - r = shatree_atom(sexp.atom) + elif obj.atom is not None: + r = shatree_atom(obj.atom) hash_stack.append(r) - sexp._cached_sha256_treehash = r + try: + setattr(obj, "_cached_sha256_treehash", r) + except AttributeError: + pass + else: + p0, p1 = obj.pair + obj_stack.append(obj) + obj_stack.append(p0) + obj_stack.append(p1) + op_stack.append(handle_pair) + op_stack.append(handle_obj) + op_stack.append(handle_obj) - def handle_pair(sexp_stack, hash_stack, op_stack) -> None: + def handle_pair(obj_stack, hash_stack, op_stack) -> None: p0 = hash_stack.pop() p1 = hash_stack.pop() r = shatree_pair(p0, p1) hash_stack.append(r) + obj = obj_stack.pop() try: - setattr(sexp, "_cached_sha256_treehash", r) + setattr(obj, "_cached_sha256_treehash", r) except AttributeError: pass - sexp_stack = [sexp] - op_stack = [handle_sexp] + obj_stack = [clvm_storage] + op_stack = [handle_obj] hash_stack: List[bytes32] = [] while len(op_stack) > 0: op = op_stack.pop() - op(sexp_stack, hash_stack, op_stack) + op(obj_stack, hash_stack, op_stack) return hash_stack[0] diff --git a/wheel/python/tests/test_as_python.py b/wheel/python/tests/test_as_python.py index b33a4672..528cbb03 100644 --- a/wheel/python/tests/test_as_python.py +++ b/wheel/python/tests/test_as_python.py @@ -125,6 +125,8 @@ def test_as_int(self): self.assertEqual(Program.to(fh("ff")).as_int(), -1) self.assertEqual(Program.to(fh("0080")).as_int(), 128) self.assertEqual(Program.to(fh("00ff")).as_int(), 255) + with self.assertRaises(ValueError): + Program.fromhex("ff8080").as_int() def test_string(self): self.assertEqual(Program.to("foobar").as_atom(), b"foobar") diff --git a/wheel/python/tests/test_curry_and_treehash.py b/wheel/python/tests/test_curry_and_treehash.py index 18cf2f47..b88c0c4a 100644 --- a/wheel/python/tests/test_curry_and_treehash.py +++ b/wheel/python/tests/test_curry_and_treehash.py @@ -1,7 +1,12 @@ +from dataclasses import dataclass +from typing import Optional, Tuple + import pytest -from clvm_rs.program import Program +from clvm_rs import Program +from clvm_rs.clvm_storage import CLVMStorage from clvm_rs.curry_and_treehash import CHIA_CURRY_TREEHASHER +from clvm_rs.tree_hash import sha256_treehash curry_and_treehash = CHIA_CURRY_TREEHASHER.curry_and_treehash calculate_hash_of_quoted_mod_hash = ( diff --git a/wheel/python/tests/test_program.py b/wheel/python/tests/test_program.py index 92c006fa..a83e8d1b 100644 --- a/wheel/python/tests/test_program.py +++ b/wheel/python/tests/test_program.py @@ -77,6 +77,10 @@ def test_hash(self): p1 = Program.fromhex("80") assert hash(p1) == id(p1) + def test_long_repr(self): + p1 = Program.fromhex(f"c062{'61'*98}") + assert repr(p1) == f"Program(c062{'61'*33}...616161)" + def check_idempotency(p, *args): curried = p.curry(*args) @@ -134,7 +138,6 @@ def test_uncurry(): plus = Program.fromhex("ff02ffff01ff10ff02ff0580ffff04ffff0101ff018080") prog = Program.fromhex("ff10ff02ff0580") # `(+ 2 5)` args = Program.fromhex("01") # `1` - assert plus.uncurry() == (prog, [args]) diff --git a/wheel/python/tests/test_to_program.py b/wheel/python/tests/test_to_program.py index 4ce6c618..3da1ad49 100644 --- a/wheel/python/tests/test_to_program.py +++ b/wheel/python/tests/test_to_program.py @@ -1,10 +1,42 @@ import unittest from typing import Optional, Tuple, Any, Union -from clvm_rs.clvm_storage import is_clvm_storage +from clvm_rs.clvm_storage import CLVMStorage, is_clvm_storage from clvm_rs.program import Program +class SimpleStorage(CLVMStorage): + """ + A simple implementation of `CLVMStorage`. + """ + + atom: Optional[bytes] + + def __init__(self, atom, pair): + self.atom = atom + self._pair = pair + + @property + def pair(self) -> Optional[Tuple["CLVMStorage", "CLVMStorage"]]: + return self._pair + + +class Uncachable(SimpleStorage): + """ + This object does not allow `_cached_sha256_treehash` or `_cached_serialization` + to be stored. + """ + + def get_th(self): + return None + + def set_th(self, v): + raise AttributeError("can't set property") + + _cached_sha256_treehash = property(get_th, set_th) + _cached_serialization = property(get_th, set_th) + + def convert_atom_to_bytes(castable: Any) -> Optional[bytes]: return Program.to(castable).atom @@ -105,10 +137,6 @@ def pair(self) -> Optional[Tuple[Any, Any]]: GeneratedTree(new_depth, self.val + 2**new_depth), ) - @classmethod - def isinstance(cls, obj): - return isinstance(obj, cls) - tree = Program.to(GeneratedTree(5, 0)) assert ( print_leaves(tree) @@ -237,3 +265,77 @@ def test_to_nil(self): self.assertEqual(Program.to([]), 0) self.assertEqual(Program.to(0), 0) self.assertEqual(Program.to(b""), 0) + + def test_tree_hash_caching(self): + o = SimpleStorage(b"foo", None) + eh = "0080b50a51ecd0ccfaaa4d49dba866fe58724f18445d30202bafb03e21eef6cb" + p = Program.to(o) + self.assertEqual(p.tree_hash().hex(), eh) + self.assertEqual(p._cached_sha256_treehash.hex(), eh) + self.assertEqual(o._cached_sha256_treehash.hex(), eh) + + o2 = SimpleStorage(None, (o, o)) + eh2 = "4a40c538671ef10c8d956e5dd3625e167c8adfb666c943f67f91ea58fd7a302c" + p2 = Program.to(o2) + self.assertEqual(p2.tree_hash().hex(), eh2) + self.assertEqual(p2._cached_sha256_treehash.hex(), eh2) + self.assertEqual(o._cached_sha256_treehash.hex(), eh) + self.assertEqual(o2._cached_sha256_treehash.hex(), eh2) + + p2p = Program.to((p, p)) + self.assertEqual(p2p.tree_hash().hex(), eh2) + self.assertEqual(p2p._cached_sha256_treehash.hex(), eh2) + self.assertEqual(p._cached_sha256_treehash.hex(), eh) + self.assertEqual(p2._cached_sha256_treehash.hex(), eh2) + + o3 = SimpleStorage(None, (o2, o2)) + eh3 = "280df61ed70cac1ec3cf9811c15f75e6698516b0354252960a62fa31240e4970" + p3 = Program.to(o3) + self.assertEqual(p3.tree_hash().hex(), eh3) + self.assertEqual(p3._cached_sha256_treehash.hex(), eh3) + self.assertEqual(o._cached_sha256_treehash.hex(), eh) + self.assertEqual(o2._cached_sha256_treehash.hex(), eh2) + self.assertEqual(o3._cached_sha256_treehash.hex(), eh3) + + p3p = Program.to((p2, p2)) + self.assertEqual(p3p.tree_hash().hex(), eh3) + self.assertEqual(p3p._cached_sha256_treehash.hex(), eh3) + self.assertEqual(p._cached_sha256_treehash.hex(), eh) + self.assertEqual(p2._cached_sha256_treehash.hex(), eh2) + + def test_tree_hash_no_caching(self): + o = Uncachable(b"foo", None) + eh = "0080b50a51ecd0ccfaaa4d49dba866fe58724f18445d30202bafb03e21eef6cb" + p = Program.to(o) + self.assertEqual(p.tree_hash().hex(), eh) + self.assertEqual(p._cached_sha256_treehash.hex(), eh) + self.assertEqual(o._cached_sha256_treehash, None) + + o2 = Uncachable(None, (o, o)) + eh2 = "4a40c538671ef10c8d956e5dd3625e167c8adfb666c943f67f91ea58fd7a302c" + p2 = Program.to(o2) + self.assertEqual(p2.tree_hash().hex(), eh2) + self.assertEqual(p2._cached_sha256_treehash.hex(), eh2) + self.assertEqual(o._cached_sha256_treehash, None) + self.assertEqual(o2._cached_sha256_treehash, None) + + p2p = Program.to((p, p)) + self.assertEqual(p2p.tree_hash().hex(), eh2) + self.assertEqual(p2p._cached_sha256_treehash.hex(), eh2) + self.assertEqual(p._cached_sha256_treehash.hex(), eh) + self.assertEqual(p2._cached_sha256_treehash.hex(), eh2) + + o3 = Uncachable(None, (o2, o2)) + eh3 = "280df61ed70cac1ec3cf9811c15f75e6698516b0354252960a62fa31240e4970" + p3 = Program.to(o3) + self.assertEqual(p3.tree_hash().hex(), eh3) + self.assertEqual(p3._cached_sha256_treehash.hex(), eh3) + self.assertEqual(o._cached_sha256_treehash, None) + self.assertEqual(o2._cached_sha256_treehash, None) + self.assertEqual(o3._cached_sha256_treehash, None) + + p3p = Program.to((p2, p2)) + self.assertEqual(p3p.tree_hash().hex(), eh3) + self.assertEqual(p3p._cached_sha256_treehash.hex(), eh3) + self.assertEqual(p._cached_sha256_treehash.hex(), eh) + self.assertEqual(p2._cached_sha256_treehash.hex(), eh2) From 8f7abca5e73f502e7267576d847be74605feb2c9 Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Mon, 6 Feb 2023 18:07:34 -0800 Subject: [PATCH 36/45] Speed up `parse` and `__eq__`. --- wheel/python/clvm_rs/program.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/wheel/python/clvm_rs/program.py b/wheel/python/clvm_rs/program.py index 81fd97eb..cb6e5f54 100644 --- a/wheel/python/clvm_rs/program.py +++ b/wheel/python/clvm_rs/program.py @@ -145,22 +145,7 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}({str(self)})" def __eq__(self, other) -> bool: - stack: List[Tuple[CLVMStorage, CLVMStorage]] = [(self, Program.to(other))] - while stack: - p1, p2 = stack.pop() - if p1.atom is None: - if p2.atom is not None: - return False - pair_1 = p1.pair - pair_2 = p2.pair - assert pair_1 is not None - assert pair_2 is not None - stack.append((pair_1[1], pair_2[1])) - stack.append((pair_1[0], pair_2[0])) - else: - if p1.atom != p2.atom: - return False - return True + return self.tree_hash() == other.tree_hash() def __ne__(self, other) -> bool: return not self.__eq__(other) From 2b7f9cfeef50818e53e250f59529e1262cd6bfc7 Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Mon, 6 Feb 2023 21:59:58 -0800 Subject: [PATCH 37/45] Improve benchmarking code. --- wheel/python/benchmarks/deserialization.py | 46 +++++++++++++--------- 1 file changed, 27 insertions(+), 19 deletions(-) diff --git a/wheel/python/benchmarks/deserialization.py b/wheel/python/benchmarks/deserialization.py index 410e6a0c..d2d4902a 100644 --- a/wheel/python/benchmarks/deserialization.py +++ b/wheel/python/benchmarks/deserialization.py @@ -59,53 +59,67 @@ def benchmark(): des_output = bench( lambda: Program.from_bytes(result_blob), - "from_bytes (with tree hashing)", + "from_bytes with tree hashing (fbwth)", allow_slow=True, ) - bench(lambda: des_output.tree_hash(), "from_bytes (with tree hashing) tree hash") + bench(lambda: des_output.tree_hash(), "tree hash (fbwth)") bench( - lambda: des_output.tree_hash(), "from_bytes (with tree hashing) tree hash again" + lambda: des_output.tree_hash(), "tree hash again (fbwth)" ) bench(lambda: serialized_length(result_blob), "serialized_length") des_output = bench( lambda: Program.from_bytes(result_blob, calculate_tree_hash=False), - "from_bytes output (without tree hashing)", + "from_bytes without tree hashing (fbwoth)", allow_slow=True, ) bench( lambda: des_output.tree_hash(), - "from_bytes (without tree hashing) tree hash", + "tree hash (fbwoth)", allow_slow=True, ) bench( lambda: des_output.tree_hash(), - "from_bytes (without tree hashing) tree hash again", + "tree hash (fbwoth) again", ) reparsed_output = bench( lambda: Program.parse(io.BytesIO(result_blob)), - "reparse output", + "parse with tree hashing (pwth)", allow_slow=True, ) - bench(lambda: reparsed_output.tree_hash(), "reparsed tree hash", allow_slow=True) + bench(lambda: reparsed_output.tree_hash(), "tree hash (pwth)") bench( lambda: reparsed_output.tree_hash(), - "reparsed tree hash again", + "tree hash again (pwth)", + ) + + reparsed_output = bench( + lambda: Program.parse(io.BytesIO(result_blob), calculate_tree_hash=False), + "parse without treehashing (pwowt)", + allow_slow=True, + ) + bench( + lambda: reparsed_output.tree_hash(), "tree hash (pwowt)", allow_slow=True + ) + bench( + lambda: reparsed_output.tree_hash(), + "tree hash again (pwowt)", ) foo = Program.to("foo") o0 = Program.to((foo, obj)) o1 = Program.to((foo, obj1)) + o2 = Program.to((foo, output)) def compare(): assert o0 == o1 - bench(compare, "compare") + bench(compare, "compare constructed") - bench(lambda: bytes(o0), "to_bytes o0") - bench(lambda: bytes(o1), "to_bytes o1") + bench(lambda: bytes(o0), "to_bytes constructed o0") + bench(lambda: bytes(o1), "to_bytes constructed o1") bench(lambda: print(o0.tree_hash().hex()), "o0 tree_hash") bench(lambda: print(o0.tree_hash().hex()), "o0 tree_hash (again)") @@ -113,16 +127,10 @@ def compare(): bench(lambda: print(o1.tree_hash().hex()), "o1 tree_hash") bench(lambda: print(o1.tree_hash().hex()), "o1 tree_hash (again)") - o2 = Program.to((foo, output)) - + bench(lambda: bytes(o2), "to_bytes constructed o2") bench(lambda: print(o2.tree_hash().hex()), "o2 tree_hash") bench(lambda: print(o2.tree_hash().hex()), "o2 tree_hash (again)") - # start = time.time() - # obj1 = sexp_from_stream(io.BytesIO(out), SExp.to, allow_backrefs=False) - # end = time.time() - # print(end-start) - if __name__ == "__main__": benchmark() From b6905ea377f149d2ac1c12a6dae2a3430c074552 Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Tue, 7 Feb 2023 12:45:25 -0800 Subject: [PATCH 38/45] Factor out `copy_exact` and `skip_bytes`. --- src/serde/de_tree.rs | 25 ++----------------------- src/serde/mod.rs | 1 + src/serde/tools.rs | 3 ++- src/serde/utils.rs | 24 ++++++++++++++++++++++++ wheel/python/clvm_rs/program.py | 2 +- 5 files changed, 30 insertions(+), 25 deletions(-) create mode 100644 src/serde/utils.rs diff --git a/src/serde/de_tree.rs b/src/serde/de_tree.rs index 636824eb..fc80ce30 100644 --- a/src/serde/de_tree.rs +++ b/src/serde/de_tree.rs @@ -1,11 +1,12 @@ use std::convert::TryInto; -use std::io::{copy, sink, Error, Read, Result, Write}; +use std::io::{Error, Read, Result, Write}; use sha2::Digest; use crate::sha2::Sha256; use super::parse_atom::decode_size_with_offset; +use super::utils::{copy_exactly, skip_bytes}; const MAX_SINGLE_BYTE: u8 = 0x7f; const CONS_BOX_MARKER: u8 = 0xff; @@ -64,28 +65,6 @@ fn tree_hash_for_byte(b: u8, calculate_tree_hashes: bool) -> Option<[u8; 32]> { } } -pub fn copy_exactly( - reader: &mut R, - writer: &mut W, - expected_size: u64, -) -> Result<()> { - let mut reader = reader.by_ref().take(expected_size); - - let count = copy(&mut reader, writer)?; - if count < expected_size { - Err(Error::new( - std::io::ErrorKind::UnexpectedEof, - "copy terminated early", - )) - } else { - Ok(()) - } -} - -fn skip_bytes(f: &mut R, skip_size: u64) -> Result<()> { - copy_exactly(f, &mut sink(), skip_size) -} - fn skip_or_sha_bytes( f: &mut R, skip_size: u64, diff --git a/src/serde/mod.rs b/src/serde/mod.rs index d6ff0f8c..4fd8262b 100644 --- a/src/serde/mod.rs +++ b/src/serde/mod.rs @@ -9,6 +9,7 @@ mod read_cache_lookup; mod ser; mod ser_br; mod tools; +mod utils; mod write_atom; #[cfg(test)] diff --git a/src/serde/tools.rs b/src/serde/tools.rs index 5a00f2a3..eb34a22b 100644 --- a/src/serde/tools.rs +++ b/src/serde/tools.rs @@ -1,8 +1,9 @@ use std::io; -use std::io::{Cursor, Read, Seek, SeekFrom}; +use std::io::{Cursor, Read}; use super::errors::bad_encoding; use super::parse_atom::decode_size; +use super::utils::skip_bytes; const MAX_SINGLE_BYTE: u8 = 0x7f; const CONS_BOX_MARKER: u8 = 0xff; diff --git a/src/serde/utils.rs b/src/serde/utils.rs new file mode 100644 index 00000000..954d6e09 --- /dev/null +++ b/src/serde/utils.rs @@ -0,0 +1,24 @@ +use std::io; +use std::io::{copy, sink, Error, Read, Write}; + +pub fn copy_exactly( + reader: &mut R, + writer: &mut W, + expected_size: u64, +) -> io::Result<()> { + let mut reader = reader.by_ref().take(expected_size); + + let count = copy(&mut reader, writer)?; + if count < expected_size { + Err(Error::new( + std::io::ErrorKind::UnexpectedEof, + "copy terminated early", + )) + } else { + Ok(()) + } +} + +pub fn skip_bytes(f: &mut R, skip_size: u64) -> io::Result<()> { + copy_exactly(f, &mut sink(), skip_size) +} diff --git a/wheel/python/clvm_rs/program.py b/wheel/python/clvm_rs/program.py index cb6e5f54..44f172c1 100644 --- a/wheel/python/clvm_rs/program.py +++ b/wheel/python/clvm_rs/program.py @@ -145,7 +145,7 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}({str(self)})" def __eq__(self, other) -> bool: - return self.tree_hash() == other.tree_hash() + return self.tree_hash() == self.to(other).tree_hash() def __ne__(self, other) -> bool: return not self.__eq__(other) From c6af427ad0babd6c20b9cc6ea87062cd96ec21ef Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Tue, 7 Feb 2023 22:23:09 -0800 Subject: [PATCH 39/45] Add `skip_clvm_object` api. --- src/serde/mod.rs | 2 +- src/serde/tools.rs | 60 +++++++++++++++++----------------------------- wheel/src/api.rs | 22 ++++++++++++++++- 3 files changed, 44 insertions(+), 40 deletions(-) diff --git a/src/serde/mod.rs b/src/serde/mod.rs index 4fd8262b..1b9b8c6b 100644 --- a/src/serde/mod.rs +++ b/src/serde/mod.rs @@ -20,4 +20,4 @@ pub use de_br::node_from_bytes_backrefs; pub use de_tree::{parse_triples, ParsedTriple}; pub use ser::node_to_bytes; pub use ser_br::node_to_bytes_backrefs; -pub use tools::{serialized_length_from_bytes, tree_hash_from_stream}; +pub use tools::{parse_through_clvm_object, serialized_length_from_bytes, tree_hash_from_stream}; diff --git a/src/serde/tools.rs b/src/serde/tools.rs index eb34a22b..5585e4e1 100644 --- a/src/serde/tools.rs +++ b/src/serde/tools.rs @@ -16,43 +16,27 @@ enum ParseOp { pub fn serialized_length_from_bytes(b: &[u8]) -> io::Result { let mut f = Cursor::new(b); - let mut ops = vec![ParseOp::SExp]; + parse_through_clvm_object(&mut f)?; + Ok(f.position()) +} + +pub fn parse_through_clvm_object(f: &mut R) -> io::Result<()> { + let mut to_parse_count = 1; let mut b = [0; 1]; loop { - let op = ops.pop(); - if op.is_none() { + if to_parse_count < 1 { break; + }; + f.read_exact(&mut b)?; + if b[0] == CONS_BOX_MARKER { + to_parse_count += 2; + } else if b[0] != 0x80 && b[0] > MAX_SINGLE_BYTE { + let blob_size = decode_size(f, b[0])?; + skip_bytes(f, blob_size)?; } - match op.unwrap() { - ParseOp::SExp => { - f.read_exact(&mut b)?; - if b[0] == CONS_BOX_MARKER { - // since all we're doing is to determing the length of the - // serialized buffer, we don't need to do anything about - // "cons". So we skip pushing it to lower the pressure on - // the op stack - //ops.push(ParseOp::Cons); - ops.push(ParseOp::SExp); - ops.push(ParseOp::SExp); - } else if b[0] == 0x80 || b[0] <= MAX_SINGLE_BYTE { - // This one byte we just read was the whole atom. - // or the - // special case of NIL - } else { - let blob_size = decode_size(&mut f, b[0])?; - f.seek(SeekFrom::Current(blob_size as i64))?; - if (f.get_ref().len() as u64) < f.position() { - return Err(bad_encoding()); - } - } - } - ParseOp::Cons => { - // cons. No need to construct any structure here. Just keep - // going - } - } + to_parse_count -= 1; } - Ok(f.position()) + Ok(()) } use crate::sha2::{Digest, Sha256}; @@ -240,16 +224,16 @@ fn test_serialized_length_from_bytes() { ); let e = serialized_length_from_bytes(&[0x8f, 0xff]).unwrap_err(); - assert_eq!(e.kind(), bad_encoding().kind()); - assert_eq!(e.to_string(), "bad encoding"); + assert_eq!(e.kind(), io::ErrorKind::UnexpectedEof); + assert_eq!(e.to_string(), "copy terminated early"); let e = serialized_length_from_bytes(&[0b11001111, 0xff]).unwrap_err(); - assert_eq!(e.kind(), bad_encoding().kind()); - assert_eq!(e.to_string(), "bad encoding"); + assert_eq!(e.kind(), io::ErrorKind::UnexpectedEof); + assert_eq!(e.to_string(), "copy terminated early"); let e = serialized_length_from_bytes(&[0b11001111, 0xff, 0, 0]).unwrap_err(); - assert_eq!(e.kind(), bad_encoding().kind()); - assert_eq!(e.to_string(), "bad encoding"); + assert_eq!(e.kind(), io::ErrorKind::UnexpectedEof); + assert_eq!(e.to_string(), "copy terminated early"); assert_eq!( serialized_length_from_bytes(&[0x8f, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]).unwrap(), diff --git a/wheel/src/api.rs b/wheel/src/api.rs index dc46ac92..727d30e4 100644 --- a/wheel/src/api.rs +++ b/wheel/src/api.rs @@ -7,17 +7,36 @@ use clvmr::chia_dialect::ChiaDialect; use clvmr::cost::Cost; use clvmr::reduction::Response; use clvmr::run_program::run_program; -use clvmr::serde::{node_from_bytes, parse_triples, serialized_length_from_bytes, ParsedTriple}; +use clvmr::serde::{ + node_from_bytes, parse_through_clvm_object, parse_triples, serialized_length_from_bytes, + ParsedTriple, +}; use clvmr::{LIMIT_HEAP, LIMIT_STACK, MEMPOOL_MODE, NO_UNKNOWN_OPS}; use pyo3::prelude::*; use pyo3::types::{PyBytes, PyTuple}; use pyo3::wrap_pyfunction; +struct ReadPyAny<'py>(&'py PyAny); + +impl<'py> std::io::Read for ReadPyAny<'py> { + fn read(&mut self, b: &mut [u8]) -> std::result::Result { + let r: Vec = self.0.call1((b.len(),))?.extract()?; + let (p0, _p1) = b.split_at_mut(r.len()); + p0.copy_from_slice(&r); + Ok(r.len()) + } +} + #[pyfunction] pub fn serialized_length(program: &[u8]) -> PyResult { Ok(serialized_length_from_bytes(program)?) } +#[pyfunction] +pub fn skip_clvm_object(obj: &PyAny) -> PyResult<()> { + Ok(parse_through_clvm_object(&mut ReadPyAny(obj.getattr("read")?))?) +} + #[pyfunction] pub fn run_serialized_chia_program( py: Python, @@ -75,6 +94,7 @@ fn deserialize_as_tree( fn clvm_rs(_py: Python, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(run_serialized_chia_program, m)?)?; m.add_function(wrap_pyfunction!(serialized_length, m)?)?; + m.add_function(wrap_pyfunction!(skip_clvm_object, m)?)?; m.add_function(wrap_pyfunction!(deserialize_as_tree, m)?)?; m.add("NO_UNKNOWN_OPS", NO_UNKNOWN_OPS)?; From 61146eaa11f28835c1a903fcb7b7077b37d05ede Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Tue, 7 Feb 2023 22:49:02 -0800 Subject: [PATCH 40/45] Add more `parse` tests. --- wheel/python/tests/test_serialize.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/wheel/python/tests/test_serialize.py b/wheel/python/tests/test_serialize.py index c55d0f8b..46ca0f1e 100644 --- a/wheel/python/tests/test_serialize.py +++ b/wheel/python/tests/test_serialize.py @@ -99,6 +99,8 @@ def test_very_deep_tree(self): def test_deserialize_empty(self): bytes_in = b"" + with self.assertRaises(ValueError): + Program.from_bytes(bytes_in) with self.assertRaises(ValueError): Program.parse(io.BytesIO(bytes_in)) @@ -106,6 +108,8 @@ def test_deserialize_truncated_size(self): # fe means the total number of bytes in the length-prefix is 7 # one for each bit set. 5 bytes is too few bytes_in = b"\xfe " + with self.assertRaises(ValueError): + Program.from_bytes(bytes_in) with self.assertRaises(ValueError): Program.parse(io.BytesIO(bytes_in)) @@ -114,6 +118,8 @@ def test_deserialize_truncated_blob(self): # the blob itself is truncated though, it's less than 63 bytes bytes_in = b"\xbf " + with self.assertRaises(ValueError): + Program.from_bytes(bytes_in) with self.assertRaises(ValueError): Program.parse(io.BytesIO(bytes_in)) @@ -138,6 +144,7 @@ def test_repr_clvm_tree(self): def test_bad_blob(self): self.assertRaises(ValueError, lambda: Program.fromhex("ff")) + self.assertRaises(ValueError, lambda: Program.parse(io.BytesIO(bytes.fromhex("ff")))) def test_large_atom(self): s = "foo" * 100 From 3a0183ad1c492b350499f9fa0ed1806c8aa24528 Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Wed, 8 Feb 2023 15:02:07 -0800 Subject: [PATCH 41/45] Use old error messages. --- src/serde/tools.rs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/serde/tools.rs b/src/serde/tools.rs index 5585e4e1..336b6888 100644 --- a/src/serde/tools.rs +++ b/src/serde/tools.rs @@ -16,7 +16,7 @@ enum ParseOp { pub fn serialized_length_from_bytes(b: &[u8]) -> io::Result { let mut f = Cursor::new(b); - parse_through_clvm_object(&mut f)?; + parse_through_clvm_object(&mut f).map_err(|_e| bad_encoding())?; Ok(f.position()) } @@ -224,16 +224,16 @@ fn test_serialized_length_from_bytes() { ); let e = serialized_length_from_bytes(&[0x8f, 0xff]).unwrap_err(); - assert_eq!(e.kind(), io::ErrorKind::UnexpectedEof); - assert_eq!(e.to_string(), "copy terminated early"); + assert_eq!(e.kind(), bad_encoding().kind()); + assert_eq!(e.to_string(), "bad encoding"); let e = serialized_length_from_bytes(&[0b11001111, 0xff]).unwrap_err(); - assert_eq!(e.kind(), io::ErrorKind::UnexpectedEof); - assert_eq!(e.to_string(), "copy terminated early"); + assert_eq!(e.kind(), bad_encoding().kind()); + assert_eq!(e.to_string(), "bad encoding"); let e = serialized_length_from_bytes(&[0b11001111, 0xff, 0, 0]).unwrap_err(); - assert_eq!(e.kind(), io::ErrorKind::UnexpectedEof); - assert_eq!(e.to_string(), "copy terminated early"); + assert_eq!(e.kind(), bad_encoding().kind()); + assert_eq!(e.to_string(), "bad encoding"); assert_eq!( serialized_length_from_bytes(&[0x8f, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]).unwrap(), From ac05bc90282e556278025b134de661f9d22ebeaf Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Fri, 10 Feb 2023 15:01:14 -0800 Subject: [PATCH 42/45] Remove a bunch of `.as_*` methods. --- wheel/python/clvm_rs/program.py | 43 +++------------- wheel/python/tests/test_as_python.py | 73 +++++---------------------- wheel/python/tests/test_program.py | 2 +- wheel/python/tests/test_to_program.py | 18 +++---- 4 files changed, 31 insertions(+), 105 deletions(-) diff --git a/wheel/python/clvm_rs/program.py b/wheel/python/clvm_rs/program.py index 44f172c1..7897c655 100644 --- a/wheel/python/clvm_rs/program.py +++ b/wheel/python/clvm_rs/program.py @@ -64,7 +64,10 @@ def __bytes__(self) -> bytes: return self._cached_serialization def __int__(self) -> int: - return self.as_int() + v = self.as_int() + if v is None: + raise ValueError("can't cast pair to int") + return v def __hash__(self): return id(self) @@ -160,12 +163,6 @@ def rest(self) -> Optional[Program]: return self.pair[1] return None - def as_pair(self) -> Optional[Tuple[Program, Program]]: - return self.pair - - def as_atom(self) -> Optional[bytes]: - return self.atom - def listp(self) -> bool: return self.pair is not None @@ -309,37 +306,11 @@ def curry_hash(self, *args: bytes32) -> bytes32: ) return curry_treehasher.curry_and_treehash(quoted_mod_hash, *args) - def as_int(self) -> int: - v = self.as_atom() + def as_int(self) -> Optional[int]: + v = self.atom if v is None: - raise ValueError("can't cast pair to int") + return v return int_from_bytes(v) - def as_iter(self) -> Iterator[Program]: - v = self - while v.pair: - yield v.pair[0] - v = v.pair[1] - - def as_atom_iter(self) -> Iterator[bytes]: - """ - Pretend `self` is a list of atoms. Yield the corresponding atoms - up until this assumption is wrong. - """ - obj = self - while obj.pair is not None: - left, obj = obj.pair - atom = left.atom - if atom is None: - break - yield atom - - def as_atom_list(self) -> List[bytes]: - """ - Pretend `self` is a list of atoms. Return the corresponding - python list of atoms up until this assumption is wrong. - """ - return list(self.as_atom_iter()) - NULL_PROGRAM = Program.fromhex("80") diff --git a/wheel/python/tests/test_as_python.py b/wheel/python/tests/test_as_python.py index 528cbb03..bf29c081 100644 --- a/wheel/python/tests/test_as_python.py +++ b/wheel/python/tests/test_as_python.py @@ -28,30 +28,6 @@ def gen_tree(depth: int) -> Program: class AsPythonTest(unittest.TestCase): - def check_as_atom_list(self, p): - v = Program.to(p) - p1 = v.as_atom_list() - self.assertEqual(p, p1) - - def test_null(self): - self.check_as_atom_list([]) - - def test_single_bytes(self): - for _ in range(256): - self.check_as_atom_list([bytes([_])]) - - def test_short_lists(self): - self.check_as_atom_list([]) - for _ in range(256): - for size in range(1, 5): - self.check_as_atom_list([bytes([_])] * size) - - def test_non_list(self): - v = Program.to([1, 2, 3, (5, 6), 7]) - p1 = v.as_atom_list() - expected = [Program.to(_).atom for _ in [1, 2, 3]] - self.assertEqual(p1, expected) - def test_int(self): v = Program.to(42) self.assertEqual(v.atom, bytes([42])) @@ -72,8 +48,8 @@ def test_list_of_one(self): v = Program.to([1]) self.assertEqual(type(v.pair[0]), Program) self.assertEqual(type(v.pair[1]), Program) - self.assertEqual(type(v.as_pair()[0]), Program) - self.assertEqual(type(v.as_pair()[1]), Program) + self.assertEqual(type(v.pair[0]), Program) + self.assertEqual(type(v.pair[1]), Program) self.assertEqual(v.pair[0].atom, b"\x01") self.assertEqual(v.pair[1].atom, b"") @@ -87,12 +63,6 @@ def test_g1element(self): v = Program.to(castable_to_bytes(b)) self.assertEqual(v.atom, b) - def test_complex(self): - self.check_as_atom_list([b"foo"]) - self.check_as_atom_list([b"2", b"1"]) - self.check_as_atom_list([b"", b"2", b"1"]) - self.check_as_atom_list([b"", b"1", b"2", b"30", b"40", b"90", b"600"]) - def test_listp(self): self.assertEqual(Program.to(42).listp(), False) self.assertEqual(Program.to(b"").listp(), False) @@ -125,11 +95,12 @@ def test_as_int(self): self.assertEqual(Program.to(fh("ff")).as_int(), -1) self.assertEqual(Program.to(fh("0080")).as_int(), 128) self.assertEqual(Program.to(fh("00ff")).as_int(), 255) + self.assertEqual(Program.fromhex("ff8080").as_int(), None) with self.assertRaises(ValueError): - Program.fromhex("ff8080").as_int() + int(Program.fromhex("ff8080")) def test_string(self): - self.assertEqual(Program.to("foobar").as_atom(), b"foobar") + self.assertEqual(Program.to("foobar").atom, b"foobar") def test_deep_recursion(self): d = b"2" @@ -137,11 +108,11 @@ def test_deep_recursion(self): d = [d] v = Program.to(d) for i in range(1000): - self.assertEqual(v.as_pair()[1].as_atom(), Program.null()) - v = v.as_pair()[0] + self.assertEqual(v.pair[1].atom, Program.null()) + v = v.pair[0] d = d[0] - self.assertEqual(v.as_atom(), b"2") + self.assertEqual(v.atom, b"2") self.assertEqual(d, b"2") def test_long_linked_list(self): @@ -150,21 +121,21 @@ def test_long_linked_list(self): d = (b"2", d) v = Program.to(d) for i in range(1000): - self.assertEqual(v.as_pair()[0].as_atom(), d[0]) - v = v.as_pair()[1] + self.assertEqual(v.pair[0].atom, d[0]) + v = v.pair[1] d = d[1] - self.assertEqual(v.as_atom(), b"") + self.assertEqual(v.atom, b"") self.assertEqual(d, b"") def test_long_list(self): d = [1337] * 1000 v = Program.to(d) for i in range(1000): - self.assertEqual(v.as_pair()[0].as_int(), d[i]) - v = v.as_pair()[1] + self.assertEqual(v.pair[0].as_int(), d[i]) + v = v.pair[1] - self.assertEqual(v.as_atom(), b"") + self.assertEqual(v.atom, b"") def test_invalid_tuple(self): with self.assertRaises(ValueError): @@ -190,22 +161,6 @@ def test_rest(self): val = Program.to((42, val)) self.assertEqual(val.rest(), Program.to(1)) - def test_as_iter(self): - val = list(Program.to((1, (2, (3, (4, b""))))).as_iter()) - self.assertEqual(val, [1, 2, 3, 4]) - - val = list(Program.to(b"").as_iter()) - self.assertEqual(val, []) - - val = list(Program.to((1, b"")).as_iter()) - self.assertEqual(val, [1]) - - # these fail because the lists are not null-terminated - self.assertEqual(list(Program.to(1).as_iter()), []) - self.assertEqual( - list(Program.to((1, (2, (3, (4, 5))))).as_iter()), [1, 2, 3, 4] - ) - def test_eq(self): val = Program.to(1) diff --git a/wheel/python/tests/test_program.py b/wheel/python/tests/test_program.py index a83e8d1b..a91e9bdc 100644 --- a/wheel/python/tests/test_program.py +++ b/wheel/python/tests/test_program.py @@ -51,7 +51,7 @@ def test_first_rest(self): self.assertEqual(p.first(), 4) self.assertEqual(p.rest(), [5]) p = Program.to(4) - self.assertEqual(p.as_pair(), None) + self.assertEqual(p.pair, None) self.assertEqual(p.first(), None) self.assertEqual(p.rest(), None) diff --git a/wheel/python/tests/test_to_program.py b/wheel/python/tests/test_to_program.py index 3da1ad49..db07e45f 100644 --- a/wheel/python/tests/test_to_program.py +++ b/wheel/python/tests/test_to_program.py @@ -51,7 +51,7 @@ def validate_program(program): v1, v2 = v.pair assert is_clvm_storage(v1) assert is_clvm_storage(v2) - s1, s2 = v.as_pair() + s1, s2 = v.pair validate_stack.append(s1) validate_stack.append(s2) else: @@ -59,14 +59,14 @@ def validate_program(program): def print_leaves(tree: Program) -> str: - a = tree.as_atom() + a = tree.atom if a is not None: if len(a) == 0: return "() " return "%d " % a[0] ret = "" - pair = tree.as_pair() + pair = tree.pair assert pair is not None for i in pair: ret += print_leaves(i) @@ -75,13 +75,13 @@ def print_leaves(tree: Program) -> str: def print_tree(tree: Program) -> str: - a = tree.as_atom() + a = tree.atom if a is not None: if len(a) == 0: return "() " return "%d " % a[0] - pair = tree.as_pair() + pair = tree.pair assert pair is not None ret = "(" for i in pair: @@ -180,13 +180,13 @@ def test_list_conversions(self): def test_string_conversions(self): a = Program.to("foobar") - assert a.as_atom() == "foobar".encode() + assert a.atom == "foobar".encode() def test_int_conversions(self): def check(v: int, h: Union[str, list]): a = Program.to(v) b = bytes.fromhex(h) if isinstance(h, str) else bytes(h) - assert a.as_atom() == b + assert a.atom == b # note that this compares to the atom, not the serialization of that atom # so 16384 codes as 0x4000, not 0x824000 @@ -232,11 +232,11 @@ def check(n): def test_none_conversions(self): a = Program.to(None) - assert a.as_atom() == b"" + assert a.atom == b"" def test_empty_list_conversions(self): a = Program.to([]) - assert a.as_atom() == b"" + assert a.atom == b"" def test_eager_conversion(self): with self.assertRaises(ValueError): From 818ee69ee51ae3f16823b007980883aee381ced7 Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Fri, 10 Feb 2023 15:51:20 -0800 Subject: [PATCH 43/45] Don't cast `None`. Add `.at_many`. --- wheel/python/clvm_rs/casts.py | 3 --- wheel/python/clvm_rs/program.py | 28 +++++++++++++++++++++++++-- wheel/python/tests/test_as_python.py | 16 ++++++--------- wheel/python/tests/test_program.py | 9 +++++++-- wheel/python/tests/test_serialize.py | 4 ++-- wheel/python/tests/test_to_program.py | 5 ----- 6 files changed, 41 insertions(+), 24 deletions(-) diff --git a/wheel/python/clvm_rs/casts.py b/wheel/python/clvm_rs/casts.py index 778c04dc..80d78c85 100644 --- a/wheel/python/clvm_rs/casts.py +++ b/wheel/python/clvm_rs/casts.py @@ -11,7 +11,6 @@ str, int, SupportsBytes, - None, ] @@ -66,8 +65,6 @@ def to_atom_type(v: AtomCastableType) -> bytes: return int_to_bytes(v) if isinstance(v, (memoryview, SupportsBytes)): return bytes(v) - if v is None: - return NULL_BLOB raise ValueError("can't cast %s (%s) to bytes" % (type(v), v)) diff --git a/wheel/python/clvm_rs/program.py b/wheel/python/clvm_rs/program.py index 7897c655..fc330bf7 100644 --- a/wheel/python/clvm_rs/program.py +++ b/wheel/python/clvm_rs/program.py @@ -148,7 +148,12 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}({str(self)})" def __eq__(self, other) -> bool: - return self.tree_hash() == self.to(other).tree_hash() + try: + other_obj = self.to(other) + except ValueError: + # cast failure + return False + return self.tree_hash() == other_obj.tree_hash() def __ne__(self, other) -> bool: return not self.__eq__(other) @@ -195,7 +200,26 @@ def at(self, position: str) -> Optional["Program"]: assert None == at(p1, "rr") ``` """ - return self.to(at(self, position)) + r = at(self, position) + if r is None: + return r + return self.to(r) + + def at_many(self, *positions: str) -> List[Optional["Program"]]: + """ + Call `.at` multiple times. + + Why? So you can write + + ` + if p.at_many("f", "rf", "rfrf") == [5, 10, 15]: + ` + instead of + ` + if [p.at("f"), p.at("rf"), p.at("rfrf")] == [5, 10, 15]: + ` + """ + return [self.at(_) for _ in positions] def replace(self, **kwargs) -> "Program": """ diff --git a/wheel/python/tests/test_as_python.py b/wheel/python/tests/test_as_python.py index bf29c081..3d252223 100644 --- a/wheel/python/tests/test_as_python.py +++ b/wheel/python/tests/test_as_python.py @@ -36,10 +36,6 @@ def test_int(self): self.assertEqual(v.atom, bytes([])) self.assertEqual(v.as_int(), 0) - def test_none(self): - v = Program.to(None) - self.assertEqual(v.atom, b"") - def test_empty_list(self): v = Program.to([]) self.assertEqual(v.atom, b"") @@ -164,14 +160,14 @@ def test_rest(self): def test_eq(self): val = Program.to(1) - self.assertTrue(val == 1) - self.assertFalse(val == 2) + self.assertEqual(val, 1) + self.assertNotEqual(val, 2) # mismatching types - self.assertFalse(val == [1]) - self.assertFalse(val == [1, 2]) - self.assertFalse(val == (1, 2)) - self.assertRaises(ValueError, lambda: val == (dummy_class, dummy_class)) + self.assertNotEqual(val, [1]) + self.assertNotEqual(val, [1, 2]) + self.assertNotEqual(val, (1, 2)) + self.assertNotEqual(val, (dummy_class, dummy_class)) def test_eq_tree(self): val1 = gen_tree(2) diff --git a/wheel/python/tests/test_program.py b/wheel/python/tests/test_program.py index a91e9bdc..5facf5b7 100644 --- a/wheel/python/tests/test_program.py +++ b/wheel/python/tests/test_program.py @@ -24,6 +24,11 @@ def test_at(self): self.assertEqual(None, p.at("ff")) self.assertEqual(None, p.at("ffr")) + def test_at_many(self): + p = Program.to([10, 20, 30, [15, 17], 40, 50]) + self.assertEqual(p.at_many("f", "rrrfrf"), [10, 17]) + self.assertEqual(p.at_many("fff", "rrrfff"), [None, None]) + def test_replace(self): p1 = Program.to([100, 200, 300]) self.assertEqual(p1.replace(f=105), Program.to([105, 200, 300])) @@ -152,11 +157,11 @@ def test_uncurry_not_pair(): # the second item in the list is expected to be a pair, with a qoute # `(a 1 (c (q . 1) (q . 1)))` plus = Program.fromhex("ff02ff01ffff04ffff0101ffff01018080") - assert plus.uncurry() == (plus, Program.to(0)) + assert plus.uncurry() == (plus, None) def test_uncurry_args_garbage(): # there's garbage at the end of the args list # `(a (q . 1) (c (q . 1) (q . 1) (q . 4919)))` plus = Program.fromhex("ff02ffff0101ffff04ffff0101ffff0101ffff018213378080") - assert plus.uncurry() == (plus, Program.to(0)) + assert plus.uncurry() == (plus, None) diff --git a/wheel/python/tests/test_serialize.py b/wheel/python/tests/test_serialize.py index 46ca0f1e..07747d3a 100644 --- a/wheel/python/tests/test_serialize.py +++ b/wheel/python/tests/test_serialize.py @@ -65,8 +65,8 @@ def test_short_lists(self): self.check_serde([_] * size) def test_cons_box(self): - self.check_serde((None, None)) - self.check_serde((None, [1, 2, 30, 40, 600, (None, 18)])) + self.check_serde((0, 0)) + self.check_serde((0, [1, 2, 30, 40, 600, ([], 18)])) self.check_serde((100, (TEXT, (30, (50, (90, (TEXT, TEXT + TEXT))))))) def test_long_blobs(self): diff --git a/wheel/python/tests/test_to_program.py b/wheel/python/tests/test_to_program.py index db07e45f..c234eb65 100644 --- a/wheel/python/tests/test_to_program.py +++ b/wheel/python/tests/test_to_program.py @@ -230,10 +230,6 @@ def check(n): check(n) check(-n) - def test_none_conversions(self): - a = Program.to(None) - assert a.atom == b"" - def test_empty_list_conversions(self): a = Program.to([]) assert a.atom == b"" @@ -251,7 +247,6 @@ def test_convert_atom(self): assert convert_atom_to_bytes("") == b"" assert convert_atom_to_bytes(b"foobar") == b"foobar" - assert convert_atom_to_bytes(None) == b"" assert convert_atom_to_bytes([]) == b"" assert convert_atom_to_bytes([1, 2, 3]) is None From c9ece37a63942a348069e025bc27eff5d27d6e84 Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Fri, 10 Feb 2023 15:51:20 -0800 Subject: [PATCH 44/45] Don't cast `None`. Add `.at_many`. --- wheel/python/clvm_rs/program.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wheel/python/clvm_rs/program.py b/wheel/python/clvm_rs/program.py index fc330bf7..a6b5f895 100644 --- a/wheel/python/clvm_rs/program.py +++ b/wheel/python/clvm_rs/program.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Iterator, List, Tuple, Optional, Any, BinaryIO +from typing import List, Tuple, Optional, Any, BinaryIO from .at import at from .bytes32 import bytes32 From 17196e726b073681dc2c2df14fcff4b5cf6e527b Mon Sep 17 00:00:00 2001 From: Richard Kiss Date: Wed, 15 Feb 2023 18:50:20 -0800 Subject: [PATCH 45/45] Add back `.as_iter()`. --- wheel/python/clvm_rs/program.py | 6 ++++++ wheel/python/tests/test_as_python.py | 16 ++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/wheel/python/clvm_rs/program.py b/wheel/python/clvm_rs/program.py index a6b5f895..7b887f4b 100644 --- a/wheel/python/clvm_rs/program.py +++ b/wheel/python/clvm_rs/program.py @@ -336,5 +336,11 @@ def as_int(self) -> Optional[int]: return v return int_from_bytes(v) + def as_iter(self) -> Iterator[Program]: + v = self + while v.pair: + yield v.pair[0] + v = v.pair[1] + NULL_PROGRAM = Program.fromhex("80") diff --git a/wheel/python/tests/test_as_python.py b/wheel/python/tests/test_as_python.py index 3d252223..219b65af 100644 --- a/wheel/python/tests/test_as_python.py +++ b/wheel/python/tests/test_as_python.py @@ -157,6 +157,22 @@ def test_rest(self): val = Program.to((42, val)) self.assertEqual(val.rest(), Program.to(1)) + def test_as_iter(self): + val = list(Program.to((1, (2, (3, (4, b""))))).as_iter()) + self.assertEqual(val, [1, 2, 3, 4]) + + val = list(Program.to(b"").as_iter()) + self.assertEqual(val, []) + + val = list(Program.to((1, b"")).as_iter()) + self.assertEqual(val, [1]) + + # these fail because the lists are not null-terminated + self.assertEqual(list(Program.to(1).as_iter()), []) + self.assertEqual( + list(Program.to((1, (2, (3, (4, 5))))).as_iter()), [1, 2, 3, 4] + ) + def test_eq(self): val = Program.to(1)