Skip to content

Commit

Permalink
small refactoring (keep-starknet-strange#376)
Browse files Browse the repository at this point in the history
Co-authored-by: lanaivina <[email protected]>
  • Loading branch information
tcoratger and lana-shanghai authored Feb 16, 2024
1 parent 6ba8eeb commit 4c30b8e
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 116 deletions.
28 changes: 12 additions & 16 deletions src/math/crypto/curve/ec_point.zig
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,8 @@ pub const ProjectivePoint = struct {
}

pub fn doubleAssign(self: *Self) void {
if (self.infinity) {
if (self.infinity)
return;
}

// t=3x^2+az^2 with a=1 from stark curve
const t = Felt252.three().mul(self.x).mul(self.x).add(self.z.mul(self.z));
Expand Down Expand Up @@ -59,9 +58,8 @@ pub const ProjectivePoint = struct {
}

fn addAssign(self: *Self, rhs: ProjectivePoint) void {
if (rhs.infinity) {
if (rhs.infinity)
return;
}

if (self.infinity) {
self.* = rhs;
Expand Down Expand Up @@ -97,9 +95,8 @@ pub const ProjectivePoint = struct {
}

pub fn addAssignAffinePoint(self: *Self, rhs: AffinePoint) void {
if (rhs.infinity) {
if (rhs.infinity)
return;
}

if (self.infinity) {
self.* = .{
Expand Down Expand Up @@ -174,21 +171,21 @@ pub const AffinePoint = struct {
}

pub fn addAssign(self: *Self, rhs: *AffinePoint) void {
if (rhs.infinity) {
if (rhs.infinity)
return;
}

if (self.infinity) {
self.x = rhs.x;
self.y = rhs.y;
self.infinity = rhs.infinity;
self.* = .{ .x = rhs.x, .y = rhs.y, .infinity = rhs.infinity };
return;
}

if (self.x.equal(rhs.x)) {
if (self.y.equal(rhs.y.neg())) {
self.x = Felt252.zero();
self.y = Felt252.zero();
self.infinity = true;
self.* = .{
.x = Felt252.zero(),
.y = Felt252.zero(),
.infinity = true,
};
return;
}
self.doubleAssign();
Expand All @@ -204,9 +201,8 @@ pub const AffinePoint = struct {
}

pub fn doubleAssign(self: *Self) void {
if (self.infinity) {
if (self.infinity)
return;
}

// l = (3x^2+a)/2y with a=1 from stark curve
const lambda = Felt252.three().mul(self.x.mul(self.x)).add(Felt252.one()).mul(Felt252.two().mul(self.y).inv().?);
Expand Down
4 changes: 2 additions & 2 deletions src/math/fields/elliptic_curve.zig
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,9 @@ pub const ECPoint = struct {
pub fn ecDouble(self: *Self, alpha: Felt252) ECError!ECPoint {

// Assumes the point is given in affine form (x, y) and has y != 0.
if (self.y.equal(Felt252.zero())) {
if (self.y.equal(Felt252.zero()))
return ECError.YCoordinateIsZero;
}

const m = try self.ecDoubleSlope(alpha);
const x = m.pow(2).sub(self.x.mul(Felt252.two()));
const y = m.mul(self.x.sub(x)).sub(self.y);
Expand Down
21 changes: 10 additions & 11 deletions src/math/fields/helper.zig
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ pub fn tonelliShanks(n: u512, p: u512) struct { u512, u512, bool } {
}

if (s == 1) {
const result: u512 = powModulus(n, (p + 1) / 4, p);
const result = powModulus(n, (p + 1) / 4, p);
return .{ result, p - result, true };
}

Expand All @@ -56,10 +56,10 @@ pub fn tonelliShanks(n: u512, p: u512) struct { u512, u512, bool } {
z = z + 1;
}

var c: u512 = powModulus(z, q, p);
var t: u512 = powModulus(n, q, p);
var m: u512 = s;
var result: u512 = powModulus(n, (q + 1) >> 1, p);
var c = powModulus(z, q, p);
var t = powModulus(n, q, p);
var m = s;
var result = powModulus(n, (q + 1) >> 1, p);

while (t != 1) {
var i: u512 = 1;
Expand All @@ -69,7 +69,7 @@ pub fn tonelliShanks(n: u512, p: u512) struct { u512, u512, bool } {
z = multiplyModulus(z, z, p);
}

const b: u512 = powModulus(c, @as(u512, 1) << @intCast(m - i - 1), p);
const b = powModulus(c, @as(u512, 1) << @intCast(m - i - 1), p);
c = multiplyModulus(b, b, p);
t = multiplyModulus(t, c, p);
m = i;
Expand All @@ -93,9 +93,8 @@ pub fn extendedGCD(self: i256, other: i256) struct { gcd: i256, x: i256, y: i256
t[0] = t[0] - q * t[1];
}

if (r[1] >= 0) {
return .{ .gcd = r[1], .x = s[1], .y = t[1] };
} else {
return .{ .gcd = -r[1], .x = -s[1], .y = -t[1] };
}
return if (r[1] >= 0)
.{ .gcd = r[1], .x = s[1], .y = t[1] }
else
.{ .gcd = -r[1], .x = -s[1], .y = -t[1] };
}
16 changes: 7 additions & 9 deletions src/math/fields/stark_felt_252_gen_fp.zig
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,9 @@ inline fn cast(
const dest = @typeInfo(DestType).Int;
const source = @typeInfo(@TypeOf(target)).Int;
if (dest.bits < source.bits) {
const T = std.meta.Int(
source.signedness,
dest.bits,
);
return @bitCast(
@as(
T,
std.meta.Int(source.signedness, dest.bits),
@truncate(target),
),
);
Expand Down Expand Up @@ -4015,8 +4011,10 @@ pub fn divstep(
pub fn divstepPrecomp(out1: *[4]u64) void {
@setRuntimeSafety(mode == .Debug);

out1[0] = 0x20000001;
out1[1] = 0xfff6678000000000;
out1[2] = 0xfffff273ffffffff;
out1[3] = 0x7fffffbc0000010;
out1.* = [4]u64{
0x20000001,
0xfff6678000000000,
0xfffff273ffffffff,
0x7fffffbc0000010,
};
}
19 changes: 11 additions & 8 deletions src/vm/builtins/builtin_runner/signature.zig
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,16 @@ pub const SignatureBuiltinRunner = struct {
else => return result,
};

const pubkey = memory.getFelt(pubkey_message_addr[0]) catch if (cell_index == 1) return result else return MemoryError.PubKeyNonInt;
const msg = memory.getFelt(pubkey_message_addr[1]) catch if (cell_index == 0) return result else return MemoryError.MsgNonInt;
const pubkey = memory.getFelt(pubkey_message_addr[0]) catch
return if (cell_index == 1) result else MemoryError.PubKeyNonInt;
const msg = memory.getFelt(pubkey_message_addr[1]) catch
return if (cell_index == 0) result else MemoryError.MsgNonInt;

const signature = self.signatures.get(pubkey_message_addr[0]) catch return MemoryError.SignatureNotFound;

if (verify(pubkey, msg, signature.r, signature.s) catch return MemoryError.InvalidSignature) {
if (verify(pubkey, msg, signature.r, signature.s) catch
return MemoryError.InvalidSignature)
{
return result;
}

Expand Down Expand Up @@ -176,20 +180,19 @@ pub const SignatureBuiltinRunner = struct {
if (self.included) {
const stop_pointer_addr = pointer.subUint(1) catch return RunnerError.NoStopPointer;

const stop_pointer = segments.memory.getRelocatable(stop_pointer_addr) catch return RunnerError.NoStopPointer;
const stop_pointer = segments.memory.getRelocatable(stop_pointer_addr) catch
return RunnerError.NoStopPointer;

if (@as(i64, @intCast(self.base)) != stop_pointer.segment_index) {
if (self.base != stop_pointer.segment_index)
return RunnerError.InvalidStopPointerIndex;
}

const stop_ptr = stop_pointer.offset;
const num_instances = try self.getUsedInstances(segments);

const used = num_instances * @as(usize, @intCast(self.cells_per_instance));

if (stop_ptr != used) {
if (stop_ptr != used)
return RunnerError.InvalidStopPointer;
}

self.stop_ptr = stop_ptr;
return stop_pointer_addr;
Expand Down
123 changes: 53 additions & 70 deletions src/vm/decoding/decoder.zig
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,8 @@ pub fn decodeInstructions(encoded_instr: u64) !Instruction {
const OFF2_OFF: u64 = 32;
const OFFX_MASK: u64 = 0xFFFF;

if (encoded_instr & HIGH_BIT != 0) {
if (encoded_instr & HIGH_BIT != 0)
return CairoVMError.InstructionNonZeroHighBit;
}

// Grab offsets and convert them from little endian format.
const off0 = try decodeOffset(encoded_instr >> OFF0_OFF & OFFX_MASK);
const off1 = try decodeOffset(encoded_instr >> OFF1_OFF & OFFX_MASK);
const off2 = try decodeOffset(encoded_instr >> OFF2_OFF & OFFX_MASK);

// Grab flags
const flags = encoded_instr >> FLAGS_OFFSET;
Expand All @@ -63,82 +57,71 @@ pub fn decodeInstructions(encoded_instr: u64) !Instruction {
const ap_update_num = (flags & AP_UPDATE_MASK) >> AP_UPDATE_OFF;
const opcode_num = (flags & OPCODE_MASK) >> OPCODE_OFF;

// Match each flag to its corresponding enum value
const dst_register: Register = switch (dst_reg_num) {
1 => Register.FP,
else => Register.AP,
};

const op0_register: Register = switch (op0_reg_num) {
1 => Register.FP,
else => Register.AP,
};

const op1_addr = switch (op1_src_num) {
0 => Op1Src.Op0,
1 => Op1Src.Imm,
2 => Op1Src.FP,
4 => Op1Src.AP,
else => return CairoVMError.InvalidOp1Reg,
};

const pc_update = switch (pc_update_num) {
0 => PcUpdate.Regular,
1 => PcUpdate.Jump,
2 => PcUpdate.JumpRel,
4 => PcUpdate.Jnz,
const pc_update: PcUpdate = switch (pc_update_num) {
0 => .Regular,
1 => .Jump,
2 => .JumpRel,
4 => .Jnz,
else => return CairoVMError.InvalidPcUpdate,
};

const res = switch (res_logic_num) {
0 => if (pc_update == PcUpdate.Jnz) ResLogic.Unconstrained else ResLogic.Op1,
1 => ResLogic.Add,
2 => ResLogic.Mul,
else => return CairoVMError.InvalidResLogic,
};

const opcode = switch (opcode_num) {
0 => Opcode.NOp,
1 => Opcode.Call,
2 => Opcode.Ret,
4 => Opcode.AssertEq,
const opcode: Opcode = switch (opcode_num) {
0 => .NOp,
1 => .Call,
2 => .Ret,
4 => .AssertEq,
else => return CairoVMError.InvalidOpcode,
};

const ap_update = switch (ap_update_num) {
0 => if (opcode == Opcode.Call) ApUpdate.Add2 else ApUpdate.Regular,
1 => ApUpdate.Add,
2 => ApUpdate.Add1,
else => return CairoVMError.InvalidApUpdate,
};

const fp_update = switch (opcode) {
Opcode.Call => FpUpdate.APPlus2,
Opcode.Ret => FpUpdate.Dst,
else => FpUpdate.Regular,
};

return Instruction{
.off_0 = off0,
.off_1 = off1,
.off_2 = off2,
.dst_reg = dst_register,
.op_0_reg = op0_register,
.op_1_addr = op1_addr,
.res_logic = res,
return .{
.off_0 = try decodeOffset(encoded_instr >> OFF0_OFF & OFFX_MASK),
.off_1 = try decodeOffset(encoded_instr >> OFF1_OFF & OFFX_MASK),
.off_2 = try decodeOffset(encoded_instr >> OFF2_OFF & OFFX_MASK),
.dst_reg = switch (dst_reg_num) {
1 => .FP,
else => .AP,
},
.op_0_reg = switch (op0_reg_num) {
1 => .FP,
else => .AP,
},
.op_1_addr = switch (op1_src_num) {
0 => .Op0,
1 => .Imm,
2 => .FP,
4 => .AP,
else => return CairoVMError.InvalidOp1Reg,
},
.res_logic = switch (res_logic_num) {
0 => if (pc_update == .Jnz) .Unconstrained else .Op1,
1 => .Add,
2 => .Mul,
else => return CairoVMError.InvalidResLogic,
},
.pc_update = pc_update,
.ap_update = ap_update,
.fp_update = fp_update,
.ap_update = switch (ap_update_num) {
0 => if (opcode == .Call) .Add2 else .Regular,
1 => .Add,
2 => .Add1,
else => return CairoVMError.InvalidApUpdate,
},
.fp_update = switch (opcode) {
.Call => .APPlus2,
.Ret => .Dst,
else => .Regular,
},
.opcode = opcode,
};
}

pub fn decodeOffset(offset: u64) !i16 {
var vectorized_offset: [8]u8 = std.mem.toBytes(offset);
const offset_16b_encoded = std.mem.readInt(u16, vectorized_offset[0..2], std.builtin.Endian.little);
const complement_const: u16 = 0x8000;
const result = @subWithOverflow(offset_16b_encoded, complement_const);
return @as(i16, @bitCast(result[0]));
return @bitCast(
@subWithOverflow(
std.mem.readInt(u16, vectorized_offset[0..2], std.builtin.Endian.little),
0x8000,
)[0],
);
}

test "decodeInstructions: non-zero high bit" {
Expand Down

0 comments on commit 4c30b8e

Please sign in to comment.