diff --git a/api/process.go b/api/process.go index b896c39..c7f1869 100644 --- a/api/process.go +++ b/api/process.go @@ -65,12 +65,12 @@ func (a *API) newProcess(w http.ResponseWriter, r *http.Request) { ErrGenericInternalServerError.Withf("could not marshal ballot mode: %v", err).Write(w) return } - state, err := st.Initialize(p.CensusRoot, ballotmode, publicKey.Marshal()) - if err != nil { + + if err := st.Initialize(p.CensusRoot, ballotmode, publicKey.Marshal()); err != nil { ErrGenericInternalServerError.Withf("could not initialize state: %v", err).Write(w) return } - root, err := state.RootAsBigInt() + root, err := st.RootAsBigInt() if err != nil { ErrGenericInternalServerError.Withf("could not get state root: %v", err).Write(w) return diff --git a/circuits/statetransition/circuit.go b/circuits/statetransition/circuit.go new file mode 100644 index 0000000..4cf50c2 --- /dev/null +++ b/circuits/statetransition/circuit.go @@ -0,0 +1,144 @@ +package statetransition + +import ( + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/hash/mimc" + "github.com/vocdoni/gnark-crypto-primitives/elgamal" + "github.com/vocdoni/gnark-crypto-primitives/utils" +) + +const ( + // votes that were processed in AggregatedProof + VoteBatchSize = 10 +) + +type Circuit struct { + // --------------------------------------------------------------------------------------------- + // PUBLIC INPUTS + + // list of root hashes + RootHashBefore frontend.Variable `gnark:",public"` + RootHashAfter frontend.Variable `gnark:",public"` + NumNewVotes frontend.Variable `gnark:",public"` + NumOverwrites frontend.Variable `gnark:",public"` + + // --------------------------------------------------------------------------------------------- + // SECRET INPUTS + + AggregatedProof frontend.Variable // mock, this should be a zkProof + + ProcessID MerkleProof + CensusRoot MerkleProof + BallotMode MerkleProof + EncryptionKey MerkleProof + ResultsAdd MerkleTransition + ResultsSub MerkleTransition + Ballot [VoteBatchSize]MerkleTransition + Commitment [VoteBatchSize]MerkleTransition +} + +// Define declares the circuit's constraints +func (circuit Circuit) Define(api frontend.API) error { + hashFn := func(api frontend.API, data ...frontend.Variable) (frontend.Variable, error) { + h, err := mimc.NewMiMC(api) + if err != nil { + return 0, err + } + h.Write(data...) + return h.Sum(), nil + } + + circuit.VerifyAggregatedZKProof(api) + circuit.VerifyMerkleProofs(api, hashFn) + circuit.VerifyMerkleTransitions(api, hashFn) + circuit.VerifyBallots(api) + return nil +} + +func (circuit Circuit) VerifyAggregatedZKProof(api frontend.API) { + // all of the following values compose the preimage that is hashed + // to produce the public input needed to verify AggregatedProof. + // they are extracted from the MerkleProofs: + // ProcessID := circuit.ProcessID.Value + // CensusRoot := circuit.CensusRoot.Value + // BallotMode := circuit.BallotMode.Value + // EncryptionKey := circuit.EncryptionKey.Value + // Nullifiers := circuit.Ballot[i].NewKey + // Ballots := circuit.Ballot[i].NewValue + // Addressess := circuit.Commitment[i].NewKey + // Commitments := circuit.Commitment[i].NewValue + + api.Println("verify AggregatedZKProof mock:", circuit.AggregatedProof) // mock + + packedInputs := func() frontend.Variable { + for i, p := range []MerkleProof{ + circuit.ProcessID, + circuit.CensusRoot, + circuit.BallotMode, + circuit.EncryptionKey, + } { + api.Println("packInputs mock", i, p.Value) // mock + } + for i := range circuit.Ballot { + api.Println("packInputs mock nullifier", i, circuit.Ballot[i].NewKey) // mock + api.Println("packInputs mock ballot", i, circuit.Ballot[i].NewValue) // mock + } + for i := range circuit.Commitment { + api.Println("packInputs mock address", i, circuit.Commitment[i].NewKey) // mock + api.Println("packInputs mock commitment", i, circuit.Commitment[i].NewValue) // mock + } + return 1 // mock, should return hash of packed inputs + } + + api.AssertIsEqual(packedInputs(), 1) // TODO: mock, should actually verify AggregatedZKProof +} + +func (circuit Circuit) VerifyMerkleProofs(api frontend.API, hFn utils.Hasher) { + api.Println("verify ProcessID, CensusRoot, BallotMode and EncryptionKey belong to RootHashBefore") + circuit.ProcessID.VerifyProof(api, hFn, circuit.RootHashBefore) + circuit.CensusRoot.VerifyProof(api, hFn, circuit.RootHashBefore) + circuit.BallotMode.VerifyProof(api, hFn, circuit.RootHashBefore) + circuit.EncryptionKey.VerifyProof(api, hFn, circuit.RootHashBefore) +} + +func (circuit Circuit) VerifyMerkleTransitions(api frontend.API, hFn utils.Hasher) { + // verify chain of tree transitions, order here is fundamental. + api.Println("tree transition starts with RootHashBefore:", prettyHex(circuit.RootHashBefore)) + root := circuit.RootHashBefore + for i := range circuit.Ballot { + root = circuit.Ballot[i].Verify(api, hFn, root) + } + for i := range circuit.Commitment { + root = circuit.Commitment[i].Verify(api, hFn, root) + } + root = circuit.ResultsAdd.Verify(api, hFn, root) + root = circuit.ResultsSub.Verify(api, hFn, root) + api.Println("and final root is", prettyHex(root), "should be equal to RootHashAfter", prettyHex(circuit.RootHashAfter)) + api.AssertIsEqual(root, circuit.RootHashAfter) +} + +// VerifyBallots counts the ballots using homomorphic encrpytion +func (circuit Circuit) VerifyBallots(api frontend.API) { + ballotSum, overwrittenSum, zero := elgamal.NewCiphertext(), elgamal.NewCiphertext(), elgamal.NewCiphertext() + var ballotCount, overwrittenCount frontend.Variable = 0, 0 + + for _, b := range circuit.Ballot { + // TODO: check that Hash(NewCiphertext) matches b.NewValue + // and Hash(OldCiphertext) matches b.OldValue + ballotSum.Add(api, ballotSum, + elgamal.NewCiphertext().Select(api, b.IsInsertOrUpdate(api), &b.NewCiphertext, zero)) + + overwrittenSum.Add(api, overwrittenSum, + elgamal.NewCiphertext().Select(api, b.IsUpdate(api), &b.OldCiphertext, zero)) + + ballotCount = api.Add(ballotCount, api.Select(b.IsInsertOrUpdate(api), 1, 0)) + overwrittenCount = api.Add(overwrittenCount, api.Select(b.IsUpdate(api), 1, 0)) + } + + circuit.ResultsAdd.NewCiphertext.AssertIsEqual(api, + circuit.ResultsAdd.OldCiphertext.Add(api, &circuit.ResultsAdd.OldCiphertext, ballotSum)) + circuit.ResultsSub.NewCiphertext.AssertIsEqual(api, + circuit.ResultsSub.OldCiphertext.Add(api, &circuit.ResultsSub.OldCiphertext, overwrittenSum)) + api.AssertIsEqual(circuit.NumNewVotes, ballotCount) + api.AssertIsEqual(circuit.NumOverwrites, overwrittenCount) +} diff --git a/circuits/statetransition/circuit_test.go b/circuits/statetransition/circuit_test.go new file mode 100644 index 0000000..29d136e --- /dev/null +++ b/circuits/statetransition/circuit_test.go @@ -0,0 +1,274 @@ +package statetransition_test + +import ( + "encoding/hex" + "encoding/json" + "fmt" + "math/big" + "os" + "reflect" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/logger" + "github.com/consensys/gnark/std/hash/mimc" + "github.com/consensys/gnark/test" + "github.com/rs/zerolog" + "github.com/vocdoni/vocdoni-z-sandbox/circuits/statetransition" + "github.com/vocdoni/vocdoni-z-sandbox/state" + + "github.com/vocdoni/arbo" + "go.vocdoni.io/dvote/db/metadb" +) + +func TestCircuitCompile(t *testing.T) { + // enable log to see nbConstraints + logger.Set(zerolog.New(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: "15:04:05"}).With().Timestamp().Logger()) + + _, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &statetransition.Circuit{}) + if err != nil { + panic(err) + } +} + +func TestCircuitProve(t *testing.T) { + s, err := state.New(metadb.NewTest(t), + []byte{0xca, 0xfe, 0x00}) + if err != nil { + t.Fatal(err) + } + + if err := s.Initialize( + []byte{0xca, 0xfe, 0x01}, + []byte{0xca, 0xfe, 0x02}, + []byte{0xca, 0xfe, 0x03}, + ); err != nil { + t.Fatal(err) + } + + // first batch + if err := s.StartBatch(); err != nil { + t.Fatal(err) + } + if err := s.AddVote(state.NewVote(1, 10)); err != nil { // new vote 1 + t.Fatal(err) + } + if err := s.AddVote(state.NewVote(2, 20)); err != nil { // new vote 2 + t.Fatal(err) + } + witness, err := GenerateWitnesses(s) + if err != nil { + t.Fatal(err) + } + if err := s.EndBatch(); err != nil { // expected result: 16+17=33 + t.Fatal(err) + } + assert := test.NewAssert(t) + + assert.ProverSucceeded( + &statetransition.Circuit{}, + witness, + test.WithCurves(ecc.BN254), + test.WithBackends(backend.GROTH16)) + + debugLog(t, witness) + + // second batch + if err := s.StartBatch(); err != nil { + t.Fatal(err) + } + if err := s.AddVote(state.NewVote(1, 100)); err != nil { // overwrite vote 1 + t.Fatal(err) + } + if err := s.AddVote(state.NewVote(3, 30)); err != nil { // add vote 3 + t.Fatal(err) + } + if err := s.AddVote(state.NewVote(4, 30)); err != nil { // add vote 4 + t.Fatal(err) + } + witness, err = GenerateWitnesses(s) + if err != nil { + t.Fatal(err) + } + if err := s.EndBatch(); err != nil { + t.Fatal(err) + } + // expected results: + // ResultsAdd: 16+17+10+100 = 143 + // ResultsSub: 16 = 16 + // Final: 16+17-16+10+100 = 127 + assert.ProverSucceeded( + &statetransition.Circuit{}, + witness, + test.WithCurves(ecc.BN254), + test.WithBackends(backend.GROTH16)) + + debugLog(t, witness) +} + +func debugLog(t *testing.T, witness *statetransition.Circuit) { + t.Log("public: RootHashBefore", prettyHex(witness.RootHashBefore)) + t.Log("public: RootHashAfter", prettyHex(witness.RootHashAfter)) + t.Log("public: NumVotes", prettyHex(witness.NumNewVotes)) + t.Log("public: NumOverwrites", prettyHex(witness.NumOverwrites)) + for name, mt := range map[string]state.MerkleTransition{ + "ResultsAdd": witness.ResultsAdd, + "ResultsSub": witness.ResultsSub, + } { + t.Log(name, "transitioned", "(root", prettyHex(mt.OldRoot), "->", prettyHex(mt.NewRoot), ")", + "value", mt.OldValue, "->", mt.NewValue, + ) + t.Log(name, "elgamal.C1.X", mt.OldCiphertext.C1.X, "->", mt.NewCiphertext.C1.X) + t.Log(name, "elgamal.C1.Y", mt.OldCiphertext.C1.Y, "->", mt.NewCiphertext.C1.Y) + t.Log(name, "elgamal.C2.X", mt.OldCiphertext.C2.X, "->", mt.NewCiphertext.C2.X) + t.Log(name, "elgamal.C2.Y", mt.OldCiphertext.C2.Y, "->", mt.NewCiphertext.C2.Y) + } +} + +func debugWitness(witness statetransition.Circuit) { + js, _ := json.MarshalIndent(witness, "", " ") + fmt.Printf("\n\n%s\n\n", js) +} + +func prettyHex(v frontend.Variable) string { + type hasher interface { + HashCode() [16]byte + } + switch v := v.(type) { + case (*big.Int): + return hex.EncodeToString(arbo.BigIntToBytes(32, v)[:4]) + case int: + return fmt.Sprintf("%d", v) + case []byte: + return fmt.Sprintf("%x", v[:4]) + case hasher: + return fmt.Sprintf("%x", v.HashCode()) + default: + return fmt.Sprintf("(%v)=%+v", reflect.TypeOf(v), v) + } +} + +type CircuitBallots struct { + statetransition.Circuit +} + +func (circuit CircuitBallots) Define(api frontend.API) error { + circuit.VerifyBallots(api) + return nil +} + +func TestCircuitBallotsCompile(t *testing.T) { + // enable log to see nbConstraints + logger.Set(zerolog.New(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: "15:04:05"}).With().Timestamp().Logger()) + + _, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &CircuitBallots{}) + if err != nil { + panic(err) + } +} + +func TestCircuitBallotsProve(t *testing.T) { + s, err := state.New(metadb.NewTest(t), + []byte{0xca, 0xfe, 0x00}) + if err != nil { + t.Fatal(err) + } + + if err := s.Initialize( + []byte{0xca, 0xfe, 0x01}, + []byte{0xca, 0xfe, 0x02}, + []byte{0xca, 0xfe, 0x03}, + ); err != nil { + t.Fatal(err) + } + + if err := s.AddVote(state.NewVote(1, 10)); err != nil { // new vote 1 + t.Fatal(err) + } + + witness, err := GenerateWitnesses(s) + if err != nil { + t.Fatal(err) + } + + if err := s.EndBatch(); err != nil { // expected result: 16+17=33 + t.Fatal(err) + } + assert := test.NewAssert(t) + + assert.ProverSucceeded( + &CircuitBallots{}, + witness, + test.WithCurves(ecc.BN254), + test.WithBackends(backend.GROTH16)) +} + +type CircuitMerkleTransitions struct { + statetransition.Circuit +} + +func (circuit CircuitMerkleTransitions) Define(api frontend.API) error { + hashFn := func(api frontend.API, data ...frontend.Variable) (frontend.Variable, error) { + h, err := mimc.NewMiMC(api) + if err != nil { + return 0, err + } + h.Write(data...) + return h.Sum(), nil + } + + circuit.VerifyMerkleTransitions(api, hashFn) + return nil +} + +func TestCircuitMerkleTransitionsCompile(t *testing.T) { + // enable log to see nbConstraints + logger.Set(zerolog.New(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: "15:04:05"}).With().Timestamp().Logger()) + + _, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &CircuitMerkleTransitions{}) + if err != nil { + panic(err) + } +} + +func TestCircuitMerkleTransitionsProve(t *testing.T) { + s, err := state.New(metadb.NewTest(t), + []byte{0xca, 0xfe, 0x00}) + if err != nil { + t.Fatal(err) + } + + if err := s.Initialize( + []byte{0xca, 0xfe, 0x01}, + []byte{0xca, 0xfe, 0x02}, + []byte{0xca, 0xfe, 0x03}, + ); err != nil { + t.Fatal(err) + } + + if err := s.AddVote(state.NewVote(1, 10)); err != nil { + t.Fatal(err) + } + + witness, err := GenerateWitnesses(s) + if err != nil { + t.Fatal(err) + } + + if err := s.EndBatch(); err != nil { + t.Fatal(err) + } + + assert := test.NewAssert(t) + + assert.ProverSucceeded( + &CircuitMerkleTransitions{}, + witness, + test.WithCurves(ecc.BN254), + test.WithBackends(backend.GROTH16)) + + debugLog(t, witness) +} diff --git a/circuits/statetransition/state_test.go b/circuits/statetransition/state_test.go new file mode 100644 index 0000000..ade604c --- /dev/null +++ b/circuits/statetransition/state_test.go @@ -0,0 +1,98 @@ +package statetransition_test + +import ( + "fmt" + + "github.com/vocdoni/arbo" + "github.com/vocdoni/vocdoni-z-sandbox/circuits/statetransition" + "github.com/vocdoni/vocdoni-z-sandbox/crypto/ecc/curves" + "github.com/vocdoni/vocdoni-z-sandbox/state" +) + +var curve = curves.New(curves.CurveTypeBabyJubJubGnark) + +func GenerateWitnesses(o *state.State) (*statetransition.Circuit, error) { + var err error + witness := &statetransition.Circuit{} + + // TODO: mock, replace by actual AggregatedProof + witness.AggregatedProof = 0 + + // RootHashBefore + witness.RootHashBefore, err = o.RootAsBigInt() + if err != nil { + return nil, err + } + + // first get MerkleProofs, since they need to belong to RootHashBefore, i.e. before MerkleTransitions + if witness.ProcessID, err = o.GenMerkleProof(state.KeyProcessID); err != nil { + return nil, err + } + if witness.CensusRoot, err = o.GenMerkleProof(state.KeyCensusRoot); err != nil { + return nil, err + } + if witness.BallotMode, err = o.GenMerkleProof(state.KeyBallotMode); err != nil { + return nil, err + } + if witness.EncryptionKey, err = o.GenMerkleProof(state.KeyEncryptionKey); err != nil { + return nil, err + } + + // now build ordered chain of MerkleTransitions + + // add Ballots + for i := range witness.Ballot { + if i < len(o.votes) { + witness.Ballot[i], err = o.MerkleTransitionFromAddOrUpdate( + o.votes[i].nullifier, o.votes[i].elgamalBallot.Serialize()) + } else { + witness.Ballot[i], err = o.MerkleTransitionFromNoop() + } + if err != nil { + return nil, err + } + } + + // add Commitments + for i := range witness.Commitment { + if i < len(o.votes) { + witness.Commitment[i], err = o.MerkleTransitionFromAddOrUpdate( + o.votes[i].address, arbo.BigIntToBytes(32, &o.votes[i].commitment)) + } else { + witness.Commitment[i], err = o.MerkleTransitionFromNoop() + } + if err != nil { + return nil, err + } + } + + // update ResultsAdd + witness.ResultsAdd.OldCiphertext = o.resultsAdd.ToGnark() + witness.ResultsAdd.NewCiphertext = o.resultsAdd.Add(o.resultsAdd, o.ballotSum).ToGnark() + witness.ResultsAdd, err = o.MerkleTransitionFromAddOrUpdate( + KeyResultsAdd, o.resultsAdd.Serialize()) + if err != nil { + return nil, fmt.Errorf("ResultsAdd: %w", err) + } + + // update ResultsSub + witness.ResultsSub.OldCiphertext = o.resultsSub.ToGnark() + witness.ResultsSub.NewCiphertext = o.resultsSub.Add(o.resultsSub, o.overwriteSum).ToGnark() + witness.ResultsSub, err = o.MerkleTransitionFromAddOrUpdate( + KeyResultsSub, o.resultsSub.Serialize()) + if err != nil { + return nil, fmt.Errorf("ResultsSub: %w", err) + } + + // update stats + witness.NumNewVotes = o.ballotCount + witness.NumOverwrites = o.overwriteCount + + // RootHashAfter + witness.RootHashAfter, err = o.RootAsBigInt() + if err != nil { + return nil, err + } + + return witness, nil +} diff --git a/circuits/statetransition/util.go b/circuits/statetransition/util.go new file mode 100644 index 0000000..bb4dbbe --- /dev/null +++ b/circuits/statetransition/util.go @@ -0,0 +1,29 @@ +package statetransition + +import ( + "encoding/hex" + "fmt" + "math/big" + "reflect" + + "github.com/consensys/gnark/frontend" + "github.com/vocdoni/arbo" +) + +func prettyHex(v frontend.Variable) string { + type hasher interface { + HashCode() [16]byte + } + switch v := v.(type) { + case (*big.Int): + return hex.EncodeToString(arbo.BigIntToBytes(32, v)[:4]) + case int: + return fmt.Sprintf("%d", v) + case []byte: + return fmt.Sprintf("%x", v[:4]) + case hasher: + return fmt.Sprintf("%x", v.HashCode()) + default: + return fmt.Sprintf("(%v)=%+v", reflect.TypeOf(v), v) + } +} diff --git a/crypto/elgamal/ciphertext.go b/crypto/elgamal/ciphertext.go index 29ec458..a658cf1 100644 --- a/crypto/elgamal/ciphertext.go +++ b/crypto/elgamal/ciphertext.go @@ -6,9 +6,12 @@ import ( "fmt" "math/big" + "github.com/consensys/gnark/std/algebra/native/twistededwards" "github.com/vocdoni/arbo" + gelgamal "github.com/vocdoni/gnark-crypto-primitives/elgamal" "github.com/vocdoni/vocdoni-z-sandbox/crypto/ecc" "github.com/vocdoni/vocdoni-z-sandbox/crypto/ecc/curves" + "github.com/vocdoni/vocdoni-z-sandbox/crypto/ecc/format" ) // Ciphertext represents an ElGamal encrypted message with homomorphic properties. @@ -43,11 +46,9 @@ func (z *Ciphertext) Encrypt(message *big.Int, publicKey ecc.Point, k *big.Int) if err != nil { return nil, fmt.Errorf("elgamal encryption failed: %w", err) } - return &Ciphertext{ - CurveType: z.CurveType, - C1: c1, - C2: c2, - }, nil + z.C1 = c1 + z.C2 = c2 + return z, nil } // Add adds two Ciphertext and stores the result in z, which is also returned. @@ -58,17 +59,13 @@ func (z *Ciphertext) Add(x, y *Ciphertext) *Ciphertext { } // Serialize returns a slice of len 4*32 bytes, -// representing the C1.X, C1.Y, C2.X, C2.Y as little-endian. +// representing the C1.X, C1.Y, C2.X, C2.Y as little-endian, +// in reduced twisted edwards form. func (z *Ciphertext) Serialize() []byte { - x1, y1 := z.C1.Point() - x2, y2 := z.C2.Point() var buf bytes.Buffer - for _, bi := range []*big.Int{ - x1, - y1, - x2, - y2, - } { + c1x, c1y := format.FromTEtoRTE(z.C1.Point()) + c2x, c2y := format.FromTEtoRTE(z.C2.Point()) + for _, bi := range []*big.Int{c1x, c1y, c2x, c2y} { if _, err := buf.Write(arbo.BigIntToBytes(32, bi)); err != nil { panic(err) } @@ -76,33 +73,32 @@ func (z *Ciphertext) Serialize() []byte { return buf.Bytes() } -// Deserialize reconstructs a Ciphertext from a slice of bytes. -// The input must be of len 4*32 bytes, representing the C1.X, C1.Y, C2.X, C2.Y as little-endian. -func (z *Ciphertext) Deserialize(data []byte) error { +// Deserialize reconstructs an Ciphertext from a slice of bytes. +// The input must be of len 4*32 bytes (otherwise it panics), +// representing the C1.X, C1.Y, C2.X, C2.Y as little-endian, +// in reduced twisted edwards form. +func (z *Ciphertext) Deserialize(data []byte) { const fieldSize = 32 // Each field element is 32 bytes expectedLen := 4 * fieldSize // Validate the input length if len(data) != expectedLen { - return fmt.Errorf("invalid input length: got %d bytes, expected %d bytes", len(data), expectedLen) + panic(fmt.Errorf("invalid input length: got %d bytes, expected %d bytes", len(data), expectedLen)) } // Helper function to extract *big.Int from a 32-byte slice readBigInt := func(offset int) *big.Int { return arbo.BytesToBigInt(data[offset : offset+fieldSize]) } - // Deserialize each field - x1 := readBigInt(0 * fieldSize) - y1 := readBigInt(1 * fieldSize) - x2 := readBigInt(2 * fieldSize) - y2 := readBigInt(3 * fieldSize) - - // Set the points and store the returned points - z.C1 = z.C1.SetPoint(x1, y1) - z.C2 = z.C2.SetPoint(x2, y2) - - return nil + z.C1 = z.C1.SetPoint(format.FromRTEtoTE( + readBigInt(0*fieldSize), + readBigInt(1*fieldSize), + )) + z.C2 = z.C2.SetPoint(format.FromRTEtoTE( + readBigInt(2*fieldSize), + readBigInt(3*fieldSize), + )) } // Marshal converts Ciphertext to a byte slice. @@ -122,3 +118,14 @@ func (z *Ciphertext) String() string { } return fmt.Sprintf("{C1: %s, C2: %s}", z.C1.String(), z.C2.String()) } + +// ToGnark returns z as the struct used by gnark, +// with the points in reduced twisted edwards format +func (z *Ciphertext) ToGnark() gelgamal.Ciphertext { + c1x, c1y := format.FromTEtoRTE(z.C1.Point()) + c2x, c2y := format.FromTEtoRTE(z.C2.Point()) + return gelgamal.Ciphertext{ + C1: twistededwards.Point{X: c1x, Y: c1y}, + C2: twistededwards.Point{X: c2x, Y: c2y}, + } +} diff --git a/crypto/elgamal/ciphertext_test.go b/crypto/elgamal/ciphertext_test.go index 5014fc4..572c2f7 100644 --- a/crypto/elgamal/ciphertext_test.go +++ b/crypto/elgamal/ciphertext_test.go @@ -99,8 +99,7 @@ func TestCiphertext_SerializeDeserialize(t *testing.T) { // Test deserialization deserialized := NewCiphertext(curves.CurveTypeBN254) - err = deserialized.Deserialize(serialized) - c.Assert(err, qt.IsNil) + deserialized.Deserialize(serialized) // Compare points x1, y1 := encrypted.C1.Point() @@ -172,13 +171,14 @@ func TestCiphertext_String(t *testing.T) { c.Assert(str, qt.Matches, `\{C1: .+, C2: .+\}`) } -func TestCiphertext_DeserializeErrors(t *testing.T) { +func TestCiphertext_DeserializePanic(t *testing.T) { c := qt.New(t) cipher := NewCiphertext(curves.CurveTypeBN254) - // Test with invalid length - err := cipher.Deserialize(make([]byte, 127)) // Should be 128 - c.Assert(err, qt.Not(qt.IsNil)) - c.Assert(err.Error(), qt.Matches, "invalid input length.*") + // Test with invalid length, should panic + c.Assert(func() { + cipher.Deserialize(make([]byte, 127)) // Should be 128 + }, + qt.PanicMatches, "invalid input length.*") } diff --git a/crypto/elgamal/elgamal.go b/crypto/elgamal/elgamal.go index 72db45c..6c7af1e 100644 --- a/crypto/elgamal/elgamal.go +++ b/crypto/elgamal/elgamal.go @@ -41,6 +41,8 @@ func Encrypt(publicKey ecc.Point, msg *big.Int) (ecc.Point, ecc.Point, *big.Int, // EncryptWithK function encrypts a message using the public key provided as // elliptic curve point and the random k value provided. It returns the two // points that represent the encrypted message and error if any. +// +// TODO: remove error return, since it can never error func EncryptWithK(pubKey ecc.Point, msg, k *big.Int) (ecc.Point, ecc.Point, error) { order := pubKey.Order() // ensure the message is within the field diff --git a/go.mod b/go.mod index 2ccabdd..988b324 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,7 @@ require ( github.com/rs/zerolog v1.33.0 github.com/vocdoni/arbo v0.0.0-20241216103934-e64315269b49 github.com/vocdoni/circom2gnark v1.0.1-0.20241118090531-f24bf0de0e2f - github.com/vocdoni/gnark-crypto-primitives v0.0.2-0.20241204072449-cc4388ff8631 + github.com/vocdoni/gnark-crypto-primitives v0.0.2-0.20241216102457-ca039c1c39a8 go.vocdoni.io/dvote v1.10.2-0.20241024102542-c1ce6d744bc5 ) diff --git a/go.sum b/go.sum index 5ee9684..d3d0ab2 100644 --- a/go.sum +++ b/go.sum @@ -209,8 +209,12 @@ github.com/vocdoni/arbo v0.0.0-20241216103934-e64315269b49 h1:GMyepEuxLflqhdDHts github.com/vocdoni/arbo v0.0.0-20241216103934-e64315269b49/go.mod h1:wXxPP+5vkT5t54lrKz6bCXKIyv8aRplKq8uCFb2wgy4= github.com/vocdoni/circom2gnark v1.0.1-0.20241118090531-f24bf0de0e2f h1:iy2/GnPg5IdlkqslXUwGmqlqONDZSDnDu+1+h9LSDwM= github.com/vocdoni/circom2gnark v1.0.1-0.20241118090531-f24bf0de0e2f/go.mod h1:A1WU0hL7rO9oZlvp82you2uCc4T3/ySi1UNW6N6hBJs= +github.com/vocdoni/gnark-crypto-primitives v0.0.1 h1:RmxfYvHCFI1lnSPs07ZSTVNpoj0qrQSaNXwkh6vqcRg= +github.com/vocdoni/gnark-crypto-primitives v0.0.1/go.mod h1:0rxBVlF+OY7EnoDatCPMTxrY6FzxNkyc4c4sweIvTRQ= github.com/vocdoni/gnark-crypto-primitives v0.0.2-0.20241204072449-cc4388ff8631 h1:dQMpZvt7Z7UoCHlPLXSFRdKR/FCZhwmsOJ64EhshdsU= github.com/vocdoni/gnark-crypto-primitives v0.0.2-0.20241204072449-cc4388ff8631/go.mod h1:HyEIwSHyIqrtEKJUUybjvloMDrmWWqmGOhsEsKgAPso= +github.com/vocdoni/gnark-crypto-primitives v0.0.2-0.20241216102457-ca039c1c39a8 h1:G+qqPdXSC/3UO0qNo+epfonWBYfM89DDjxKkdkG3Qj8= +github.com/vocdoni/gnark-crypto-primitives v0.0.2-0.20241216102457-ca039c1c39a8/go.mod h1:HyEIwSHyIqrtEKJUUybjvloMDrmWWqmGOhsEsKgAPso= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= diff --git a/state/merkleproof.go b/state/merkleproof.go index 9744048..9a48c59 100644 --- a/state/merkleproof.go +++ b/state/merkleproof.go @@ -9,6 +9,7 @@ import ( "github.com/consensys/gnark/frontend" "github.com/vocdoni/arbo" garbo "github.com/vocdoni/gnark-crypto-primitives/tree/arbo" + "github.com/vocdoni/gnark-crypto-primitives/utils" encrypt "github.com/vocdoni/vocdoni-z-sandbox/crypto/elgamal" ) @@ -22,12 +23,12 @@ type ArboProof struct { Existence bool } -func GenArboProof(t *arbo.Tree, k []byte) (ArboProof, error) { - root, err := t.Root() +func (o *State) GenArboProof(k []byte) (ArboProof, error) { + root, err := o.tree.Root() if err != nil { return ArboProof{}, err } - leafK, leafV, packedSiblings, existence, err := t.GenProof(k) + leafK, leafV, packedSiblings, existence, err := o.tree.GenProof(k) if err != nil { return ArboProof{}, err } @@ -54,8 +55,8 @@ type MerkleProof struct { Fnc frontend.Variable // 0: inclusion, 1: non inclusion } -func GenMerkleProof(t *arbo.Tree, k []byte) (MerkleProof, error) { - p, err := GenArboProof(t, k) +func (o *State) GenMerkleProof(k []byte) (MerkleProof, error) { + p, err := o.GenArboProof(k) if err != nil { return MerkleProof{}, err } @@ -91,7 +92,7 @@ func padSiblings(unpackedSiblings [][]byte) [MaxLevels]frontend.Variable { // Verify uses garbo.CheckInclusionProof to verify that: // - mp.Root matches passed root // - Key + Value belong to Root -func (mp *MerkleProof) VerifyProof(api frontend.API, hFn garbo.Hash, root frontend.Variable) { +func (mp *MerkleProof) VerifyProof(api frontend.API, hFn utils.Hasher, root frontend.Variable) { api.AssertIsEqual(root, mp.Root) if err := garbo.CheckInclusionProof(api, hFn, mp.Key, mp.Value, mp.Root, mp.Siblings[:]); err != nil { @@ -159,21 +160,21 @@ func MerkleTransitionFromArboProofPair(before, after ArboProof) MerkleTransition // MerkleTransitionFromAddOrUpdate adds or updates a key in the tree, // and returns a MerkleTransition. -func MerkleTransitionFromAddOrUpdate(t *arbo.Tree, k []byte, v []byte) (MerkleTransition, error) { - mpBefore, err := GenArboProof(t, k) +func (o *State) MerkleTransitionFromAddOrUpdate(k []byte, v []byte) (MerkleTransition, error) { + mpBefore, err := o.GenArboProof(k) if err != nil { return MerkleTransition{}, err } - if _, _, err := t.Get(k); errors.Is(err, arbo.ErrKeyNotFound) { - if err := t.Add(k, v); err != nil { + if _, _, err := o.tree.Get(k); errors.Is(err, arbo.ErrKeyNotFound) { + if err := o.tree.Add(k, v); err != nil { return MerkleTransition{}, fmt.Errorf("add key failed: %w", err) } } else { - if err := t.Update(k, v); err != nil { + if err := o.tree.Update(k, v); err != nil { return MerkleTransition{}, fmt.Errorf("update key failed: %w", err) } } - mpAfter, err := GenArboProof(t, k) + mpAfter, err := o.GenArboProof(k) if err != nil { return MerkleTransition{}, err } diff --git a/state/state.go b/state/state.go index 86a5533..9f44cfd 100644 --- a/state/state.go +++ b/state/state.go @@ -5,6 +5,7 @@ import ( "github.com/vocdoni/arbo" "github.com/vocdoni/vocdoni-z-sandbox/crypto/ecc/curves" + "github.com/vocdoni/vocdoni-z-sandbox/crypto/elgamal" encrypt "github.com/vocdoni/vocdoni-z-sandbox/crypto/elgamal" "go.vocdoni.io/dvote/db" "go.vocdoni.io/dvote/db/prefixeddb" @@ -24,28 +25,6 @@ const ( // hashFunc is the hash function used in the state tree. var hashFunc = arbo.HashMiMC_BN254{} -func (o *State) oldVote(nullifier []byte) *encrypt.Ciphertext { - data, err := o.dbTx.Get(nullifier) - if err != nil { - panic(err) - } - v := &encrypt.Ciphertext{} - if err := v.Unmarshal(data); err != nil { - panic(err) - } - return v -} - -func (o *State) storeVote(nullifier []byte, vote *encrypt.Ciphertext) { - data, err := vote.Marshal() - if err != nil { - panic(err) - } - if err := o.dbTx.Set(nullifier, data); err != nil { - panic(err) - } -} - var ( KeyProcessID = []byte{0x00} KeyCensusRoot = []byte{0x01} @@ -64,7 +43,6 @@ type State struct { processID []byte db db.Database dbTx db.WriteTx - // Witnesses statetransition.Circuit // witnesses for the snark circuit resultsAdd *encrypt.Ciphertext resultsSub *encrypt.Ciphertext @@ -95,28 +73,28 @@ func New(db db.Database, processId []byte) (*State, error) { } // Initialize creates a new State, initialized with the passed parameters. -func (o *State) Initialize(censusRoot, ballotMode, encryptionKey []byte) (*State, error) { +// +// after Initialize, caller is expected to StartBatch, AddVote, EndBatch, StartBatch... +func (o *State) Initialize(censusRoot, ballotMode, encryptionKey []byte) error { if err := o.tree.Add(KeyProcessID, o.processID); err != nil { - return nil, err + return err } if err := o.tree.Add(KeyCensusRoot, censusRoot); err != nil { - return nil, err + return err } if err := o.tree.Add(KeyBallotMode, ballotMode); err != nil { - return nil, err + return err } if err := o.tree.Add(KeyEncryptionKey, encryptionKey); err != nil { - return nil, err + return err } if err := o.tree.Add(KeyResultsAdd, encrypt.NewCiphertext(CurveType).Serialize()); err != nil { - return nil, err + return err } if err := o.tree.Add(KeyResultsSub, encrypt.NewCiphertext(CurveType).Serialize()); err != nil { - return nil, err + return err } - o.resultsAdd = encrypt.NewCiphertext(CurveType) - o.resultsSub = encrypt.NewCiphertext(CurveType) - return o, nil + return nil } // Close the database, no more operations can be done after this. @@ -124,6 +102,38 @@ func (o *State) Close() error { return o.db.Close() } +// StartBatch resets counters and sums to zero, +// and creates a new write transaction in the db +func (o *State) StartBatch() error { + o.dbTx = o.db.WriteTx() + + { + _, v, err := o.tree.Get(KeyResultsAdd) + if err != nil { + return err + } + o.resultsAdd.Deserialize(v) + } + { + _, v, err := o.tree.Get(KeyResultsSub) + if err != nil { + return err + } + o.resultsSub.Deserialize(v) + } + + o.ballotSum = elgamal.NewCiphertext(CurveType) + o.overwriteSum = elgamal.NewCiphertext(CurveType) + o.ballotCount = 0 + o.overwriteCount = 0 + o.votes = []Vote{} + return nil +} + +func (o *State) EndBatch() error { + return o.dbTx.Commit() +} + func (o *State) RootAsBigInt() (*big.Int, error) { root, err := o.tree.Root() if err != nil { diff --git a/state/vote.go b/state/vote.go index 0f5a141..ede249e 100644 --- a/state/vote.go +++ b/state/vote.go @@ -2,7 +2,10 @@ package state import ( "fmt" + "math/big" + "github.com/vocdoni/arbo" + "github.com/vocdoni/vocdoni-z-sandbox/crypto/ecc/curves" "github.com/vocdoni/vocdoni-z-sandbox/crypto/elgamal" ) @@ -10,8 +13,33 @@ import ( type Vote struct { nullifier []byte elgamalBallot *elgamal.Ciphertext - // address []byte - // commitment big.Int + address []byte + commitment big.Int +} + +// NewVote creates a new vote +func NewVote(nullifier, amount uint64) Vote { + var v Vote + v.nullifier = arbo.BigIntToBytes(MaxKeyLen, + big.NewInt(int64(nullifier)+int64(KeyNullifiersOffset))) // mock + + // generate a public mocked key + publicKey, _, err := elgamal.GenerateKey(curves.New(CurveType)) + if err != nil { + panic(fmt.Errorf("error generating public key: %v", err)) + } + + c, err := elgamal.NewCiphertext(CurveType).Encrypt(big.NewInt(int64(amount)), publicKey, nil) + if err != nil { + panic(fmt.Errorf("error encrypting: %v", err)) + } + + v.elgamalBallot = c + + v.address = arbo.BigIntToBytes(MaxKeyLen, + big.NewInt(int64(nullifier)+int64(KeyAddressesOffset))) // mock + v.commitment.SetUint64(amount + 256) // mock + return v } // AddVote adds a vote to the state @@ -25,8 +53,10 @@ func (o *State) AddVote(v Vote) error { // if nullifier exists, it's a vote overwrite, need to count the overwritten vote // so it's later added to circuit.ResultsSub - if _, _, err := o.tree.Get(v.nullifier); err == nil { - o.overwriteSum.Add(o.overwriteSum, o.oldVote(v.nullifier)) + if _, value, err := o.tree.Get(v.nullifier); err == nil { + oldVote := elgamal.NewCiphertext(CurveType) + oldVote.Deserialize(value) + o.overwriteSum.Add(o.overwriteSum, oldVote) o.overwriteCount++ } @@ -34,7 +64,5 @@ func (o *State) AddVote(v Vote) error { o.ballotCount++ o.votes = append(o.votes, v) - - o.storeVote(v.nullifier, v.elgamalBallot) return nil }