Skip to content

Commit

Permalink
Feat/Hints keccak sign poseidon (keep-starknet-strange#436)
Browse files Browse the repository at this point in the history
* poseidon utils + setup ids without memory

* keccak hint temp

* unsafe keccak hint + unit tests

* unsafe keccak finalize + unit test

* split input/output + test

* splitNBytes + test

* split mid high + test

* integration test
  • Loading branch information
StringNick authored Mar 2, 2024
1 parent a92b023 commit 4b322e5
Show file tree
Hide file tree
Showing 16 changed files with 24,654 additions and 61 deletions.
17,121 changes: 17,121 additions & 0 deletions cairo_programs/keccak_compiled.json

Large diffs are not rendered by default.

6,507 changes: 6,507 additions & 0 deletions cairo_programs/keccak_copy_inputs.json

Large diffs are not rendered by default.

58 changes: 58 additions & 0 deletions src/hint_processor/builtin_hint_codes.zig
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,61 @@ pub const VM_ENTER_SCOPE = "vm_enter_scope()";
pub const VM_EXIT_SCOPE = "vm_exit_scope()";

pub const MEMCPY_ENTER_SCOPE = "vm_enter_scope({'n': ids.len})";
pub const NONDET_N_GREATER_THAN_10 = "memory[ap] = to_felt_or_relocatable(ids.n >= 10)";
pub const NONDET_N_GREATER_THAN_2 = "memory[ap] = to_felt_or_relocatable(ids.n >= 2)";

pub const UNSAFE_KECCAK =
\\from eth_hash.auto import keccak
\\
\\data, length = ids.data, ids.length
\\
\\if '__keccak_max_size' in globals():
\\ assert length <= __keccak_max_size, \
\\ f'unsafe_keccak() can only be used with length<={__keccak_max_size}. ' \
\\ f'Got: length={length}.'
\\
\\keccak_input = bytearray()
\\for word_i, byte_i in enumerate(range(0, length, 16)):
\\ word = memory[data + word_i]
\\ n_bytes = min(16, length - byte_i)
\\ assert 0 <= word < 2 ** (8 * n_bytes)
\\ keccak_input += word.to_bytes(n_bytes, 'big')
\\
\\hashed = keccak(keccak_input)
\\ids.high = int.from_bytes(hashed[:16], 'big')
\\ids.low = int.from_bytes(hashed[16:32], 'big')
;

pub const UNSAFE_KECCAK_FINALIZE =
\\from eth_hash.auto import keccak
\\keccak_input = bytearray()
\\n_elms = ids.keccak_state.end_ptr - ids.keccak_state.start_ptr
\\for word in memory.get_range(ids.keccak_state.start_ptr, n_elms):
\\ keccak_input += word.to_bytes(16, 'big')
\\hashed = keccak(keccak_input)
\\ids.high = int.from_bytes(hashed[:16], 'big')
\\ids.low = int.from_bytes(hashed[16:32], 'big')
;

pub const SPLIT_INPUT_3 = "ids.high3, ids.low3 = divmod(memory[ids.inputs + 3], 256)";
pub const SPLIT_INPUT_6 = "ids.high6, ids.low6 = divmod(memory[ids.inputs + 6], 256 ** 2)";
pub const SPLIT_INPUT_9 = "ids.high9, ids.low9 = divmod(memory[ids.inputs + 9], 256 ** 3)";
pub const SPLIT_INPUT_12 =
"ids.high12, ids.low12 = divmod(memory[ids.inputs + 12], 256 ** 4)";
pub const SPLIT_INPUT_15 =
"ids.high15, ids.low15 = divmod(memory[ids.inputs + 15], 256 ** 5)";

pub const SPLIT_OUTPUT_0 =
\\ids.output0_low = ids.output0 & ((1 << 128) - 1)
\\ids.output0_high = ids.output0 >> 128
;
pub const SPLIT_OUTPUT_1 =
\\ids.output1_low = ids.output1 & ((1 << 128) - 1)
\\ids.output1_high = ids.output1 >> 128
;

pub const SPLIT_N_BYTES = "ids.n_words_to_copy, ids.n_bytes_left = divmod(ids.n_bytes, ids.BYTES_IN_WORD)";
pub const SPLIT_OUTPUT_MID_LOW_HIGH =
\\tmp, ids.output1_low = divmod(ids.output1, 256 ** 7)
\\ids.output1_high, ids.output1_mid = divmod(tmp, 2 ** 128)
;
32 changes: 31 additions & 1 deletion src/hint_processor/hint_processor_def.zig
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,12 @@ const Relocatable = @import("../vm/memory/relocatable.zig").Relocatable;
const hint_codes = @import("builtin_hint_codes.zig");
const math_hints = @import("math_hints.zig");
const memcpy_hint_utils = @import("memcpy_hint_utils.zig");

const poseidon_utils = @import("poseidon_utils.zig");
const keccak_utils = @import("keccak_utils.zig");
const felt_bit_length = @import("felt_bit_length.zig");


const deserialize_utils = @import("../parser/deserialize_utils.zig");

const expect = std.testing.expect;
Expand Down Expand Up @@ -198,9 +202,35 @@ pub const CairoVMHintProcessor = struct {
try memcpy_hint_utils.exitScope(exec_scopes);
} else if (std.mem.eql(u8, hint_codes.MEMCPY_ENTER_SCOPE, hint_data.code)) {
try memcpy_hint_utils.memcpyEnterScope(allocator, vm, exec_scopes, hint_data.ids_data, hint_data.ap_tracking);
} else if (std.mem.eql(u8, hint_codes.NONDET_N_GREATER_THAN_10, hint_data.code)) {
try poseidon_utils.nGreaterThan10(allocator, vm, hint_data.ids_data, hint_data.ap_tracking);
} else if (std.mem.eql(u8, hint_codes.NONDET_N_GREATER_THAN_2, hint_data.code)) {
try poseidon_utils.nGreaterThan2(allocator, vm, hint_data.ids_data, hint_data.ap_tracking);
} else if (std.mem.eql(u8, hint_codes.UNSAFE_KECCAK, hint_data.code)) {
try keccak_utils.unsafeKeccak(allocator, vm, exec_scopes, hint_data.ids_data, hint_data.ap_tracking);
} else if (std.mem.eql(u8, hint_codes.UNSAFE_KECCAK_FINALIZE, hint_data.code)) {
try keccak_utils.unsafeKeccakFinalize(allocator, vm, hint_data.ids_data, hint_data.ap_tracking);
} else if (std.mem.eql(u8, hint_codes.SPLIT_INPUT_3, hint_data.code)) {
try keccak_utils.splitInput(allocator, vm, hint_data.ids_data, hint_data.ap_tracking, 3, 1);
} else if (std.mem.eql(u8, hint_codes.SPLIT_INPUT_6, hint_data.code)) {
try keccak_utils.splitInput(allocator, vm, hint_data.ids_data, hint_data.ap_tracking, 6, 2);
} else if (std.mem.eql(u8, hint_codes.SPLIT_INPUT_9, hint_data.code)) {
try keccak_utils.splitInput(allocator, vm, hint_data.ids_data, hint_data.ap_tracking, 9, 3);
} else if (std.mem.eql(u8, hint_codes.SPLIT_INPUT_12, hint_data.code)) {
try keccak_utils.splitInput(allocator, vm, hint_data.ids_data, hint_data.ap_tracking, 12, 4);
} else if (std.mem.eql(u8, hint_codes.SPLIT_INPUT_15, hint_data.code)) {
try keccak_utils.splitInput(allocator, vm, hint_data.ids_data, hint_data.ap_tracking, 15, 5);
} else if (std.mem.eql(u8, hint_codes.SPLIT_OUTPUT_0, hint_data.code)) {
try keccak_utils.splitOutput(allocator, vm, hint_data.ids_data, hint_data.ap_tracking, 0);
} else if (std.mem.eql(u8, hint_codes.SPLIT_OUTPUT_1, hint_data.code)) {
try keccak_utils.splitOutput(allocator, vm, hint_data.ids_data, hint_data.ap_tracking, 1);
} else if (std.mem.eql(u8, hint_codes.SPLIT_N_BYTES, hint_data.code)) {
try keccak_utils.splitNBytes(allocator, vm, hint_data.ids_data, hint_data.ap_tracking, constants);
} else if (std.mem.eql(u8, hint_codes.SPLIT_OUTPUT_MID_LOW_HIGH, hint_data.code)) {
try keccak_utils.splitOutputMidLowHigh(allocator, vm, hint_data.ids_data, hint_data.ap_tracking);
} else if (std.mem.eql(u8, hint_codes.GET_FELT_BIT_LENGTH, hint_data.code)) {
try felt_bit_length.getFeltBitLength(allocator, vm, hint_data.ids_data, hint_data.ap_tracking);
} else {}
}
}

// Executes the hint which's data is provided by a dynamic structure previously created by compile_hint
Expand Down
5 changes: 0 additions & 5 deletions src/hint_processor/hint_processor_utils.zig
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,6 @@ pub fn getIntegerFromReference(
else => {},
}

// std.debug.print(
// "totototot = {any}\n",
// .{computeAddrFromReference(hint_reference, ap_tracking, vm)},
// );

// Compute the memory address of the variable and retrieve the integer value from memory.
return if (computeAddrFromReference(hint_reference, ap_tracking, vm)) |var_addr|
vm.segments.memory.getFelt(var_addr) catch HintError.WrongIdentifierTypeInternal
Expand Down
Loading

0 comments on commit 4b322e5

Please sign in to comment.