From 43e95ce63a65e9de5843a6e0b994b3b3e7da618c Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Mon, 19 Jun 2023 19:18:49 -0700 Subject: [PATCH] Release 0.1.1 (#19) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Rollback to pse halo2 and halo2wrong for first release (#5) * feat: move `Accumulator` to `accumulator.rs` * feat: update due to halo2 * feat: upgrade to use branch `feature/generic-instructions` of `halo2wrong` * refactor: rollback to `{halo2,halo2_wrong}` without challenge API and cleanup dependencies * chore: rename statement to instance and auxliary to witness * chore: use `finalize` instead of `code` * feat: add `Code::deployment` and `EvmLoader::deployment_code`; add example `evm-verifier-codegen` * fix: typo * feat: reduce generated evm verifier size; rename to `evm-verifier` and add another example `evm-verifier-with-accumulator` * fix: due to `halo2wrong` * feat: reorganize mods and traits * fix: allow empty `values` in `sum_*` and move them under `ScalarLoader` * ci: use `--all-features` for `cargo test` * fix: use same strategy for aggregation testing * fix: simplify trait `PlonkVerifier` again * fix: move system specified transcript under mod `system` * feat: add `quotient_poly` info in `Protocol` * feat: implement linearization for circom integration * feat: re-export loader's dependency for consumer * refactor: for circom's integration * tmp: pin `revm` to rev * fix: remove parentheses * fix: upgrade for multi-phase halo2 * feat: improve error reporting * chore: rename crate to snake case * feat: add `Domain` as an input of `PolynomialCommitmentScheme::read_proof` * refactor: for further integration * feat: generalize to suppoer also ipa and add builder fns to `system::halo2::Config` * feat: add `KzgDecider` for simple evm verifier * refactor: split `AccumulationScheme` and `AccumulatorEncoding` * refactor: split `PolynomialCommitmentScheme` and `MultiOpenScheme` * fix: not need sealed actually * fix: `chunk_size` should be `LIMBS` when recovering accumulator * feat: add `Expression::DistributePowers` to avoid stack overflow * fix: update and pin foundry * fix: move testing circuits under `system/halo2` * fix: allow accumulate single accumulator * feat: remove all patch and make less depending `halo2wrong` * Generalized `Halo2Loader` (#12) * feat: generalize `Protocol` for further usage * feat: add `EccInstruction::{fixed_base_msm,variable_base_msm,sum_with_const}` * chore: move `rand_chacha` as dev dependency * General refactor for further integration (#13) * feat: remove dev-dependency `foundry` and vendor necessary part of it * refactor: simplify traits and remove unused stuff * refactor: much less clone * feat: generalized `AccumulatorEncoding` for `EccInstructions` * feat: implement ipa pcs and accumulation (#14) * ci: add `svm-rs` and install `solc@0.8.17` in job `test` (#16) * Update `EvmLoader` to generate Yul code instead of bytecode (#15) * Update `EvmLoader` to generate Yul instead of bytecode * feat: simplify * feat: Add missing end_gas_metering impl Co-authored-by: Han Co-authored-by: Han * fix: pin all `revm` dependencies (#18) * fix: looser trait bound on impl `CostEstimation` for `Plonk` (#20) * Restructure for more kind of verifier (#21) * feat: restructure to monorepo and expand the project scope to be generic (s)nark verifier * feat: reorganize mods and traits for further new features * refactor: simplify trait bounds * chore: use hyphen case for crate name (`snark_verifier` -> `snark-verifier`) * docs: add `#![deny(missing_docs)]` and simple documents * refactor: remove redudant check `validate_ec_point` (still doesn not support identity) * feat: expand more things and fix typos Co-authored-by: Chih Cheng Liang Co-authored-by: Carlos Pérez <37264926+CPerezz@users.noreply.github.com> * fix: rustdoc warnings * chore: update dependencies (#24) * chore: update `halo2` and `halo2wrong` version (#25) * fix: enable `util::hash::poseidon` only when `feature = loader_halo2` (#27) * feat: working update to halo2-lib v0.3.0 * feat: update zkevm bench * feat: update recursion example * feat: switch poseidon native implementation to Scroll's audited version * fix: invert determinant only once in Cramer's rule * chore: fix doc * chore * chore: forgot to update halo2-base dependency in snark-verifier-sdk * Minor update (#8) * feat(sdk): remove duplicate code in `RangeWithInstanceCircuitBuilder::synthesize` * feat(sdk): Proof caching when using feature 'halo2-pse' * chore: sync with halo2-lib * chore: switch to halo2-lib release-0.3.0 branch * Moved `RangeWithInstanceCircuitBuilder` to halo2-lib (#9) * chore: sync with halo2-lib * fix: clippy * chore: fix halo2-base branch in sdk * feat: update to halo2-lib new types (#10) * feat: add `assert` for non-empty accumulators in `decide_all` (#11) * feat: use `zip_eq` for `Polynomial` add/sub (#12) * fix: git CI turn off all features * fix: `rotate_scalar` misbehaves on `i32::MIN` (#13) Should never actually be callable with such a large negative rotation * fix: cleanup code quality (#14) * fix: `split_by_ascii_whitespace` (#15) * fix: `batch_invert_and_mul` do not allow zeros (#16) * feat: verify proof in release mode (#17) Verify proof before caching it as extra safety * fix: add better error messages/docs for catching empty inputs (#18) * chore: add Cargo.lock * chore: update Cargo dependencies * feat: fix versions and tags for dependencies --------- Co-authored-by: Han Co-authored-by: DoHoon Kim <59155248+DoHoonKim8@users.noreply.github.com> Co-authored-by: Chih Cheng Liang Co-authored-by: Carlos Pérez <37264926+CPerezz@users.noreply.github.com> Co-authored-by: dante <45801863+alexander-camuto@users.noreply.github.com> Co-authored-by: Jonathan Wang --- .github/workflows/ci.yaml | 54 ++ .gitignore | 1 - Cargo.toml | 4 - README.md | 4 +- snark-verifier-sdk/Cargo.toml | 28 +- snark-verifier-sdk/benches/standard_plonk.rs | 83 +- snark-verifier-sdk/benches/zkevm.rs | 151 ++-- .../benches/zkevm_plus_state.rs | 152 ++-- snark-verifier-sdk/configs/bench_zkevm.config | 1 - snark-verifier-sdk/configs/bench_zkevm.json | 7 + .../configs/bench_zkevm_plus_state.json | 7 + .../configs/example_evm_accumulator.config | 1 - .../configs/example_evm_accumulator.json | 10 + .../configs/verify_circuit.config | 1 - snark-verifier-sdk/src/evm.rs | 65 +- snark-verifier-sdk/src/halo2.rs | 108 +-- snark-verifier-sdk/src/halo2/aggregation.rs | 641 +++++++--------- snark-verifier-sdk/src/lib.rs | 68 +- snark-verifier/Cargo.toml | 34 +- .../configs/example_evm_accumulator.config | 1 - .../configs/example_evm_accumulator.json | 7 + snark-verifier/configs/example_recursion.json | 7 + snark-verifier/configs/verify_circuit.config | 1 - .../examples/evm-verifier-with-accumulator.rs | 500 ++++++------ snark-verifier/examples/evm-verifier.rs | 22 +- snark-verifier/examples/recursion.rs | 567 +++++++------- snark-verifier/src/cost.rs | 44 +- snark-verifier/src/lib.rs | 57 +- snark-verifier/src/loader.rs | 77 +- snark-verifier/src/loader/evm.rs | 10 +- snark-verifier/src/loader/evm/code.rs | 11 +- snark-verifier/src/loader/evm/loader.rs | 124 ++- snark-verifier/src/loader/evm/test.rs | 3 +- snark-verifier/src/loader/evm/test/tui.rs | 18 +- snark-verifier/src/loader/evm/util.rs | 56 +- .../src/loader/evm/util/executor.rs | 33 +- snark-verifier/src/loader/halo2.rs | 46 +- snark-verifier/src/loader/halo2/loader.rs | 384 ++++------ snark-verifier/src/loader/halo2/shim.rs | 597 +++------------ snark-verifier/src/loader/halo2/test.rs | 62 -- snark-verifier/src/loader/native.rs | 20 +- snark-verifier/src/pcs.rs | 129 ++-- snark-verifier/src/pcs/ipa.rs | 108 ++- snark-verifier/src/pcs/ipa/accumulation.rs | 53 +- snark-verifier/src/pcs/ipa/accumulator.rs | 4 + snark-verifier/src/pcs/ipa/decider.rs | 46 +- snark-verifier/src/pcs/ipa/multiopen.rs | 2 +- snark-verifier/src/pcs/ipa/multiopen/bgh19.rs | 162 ++-- snark-verifier/src/pcs/kzg.rs | 25 +- snark-verifier/src/pcs/kzg/accumulation.rs | 77 +- snark-verifier/src/pcs/kzg/accumulator.rs | 161 ++-- snark-verifier/src/pcs/kzg/decider.rs | 92 ++- .../src/pcs/kzg/multiopen/bdfg21.rs | 172 +++-- snark-verifier/src/pcs/kzg/multiopen/gwc19.rs | 53 +- snark-verifier/src/system.rs | 2 + snark-verifier/src/system/halo2.rs | 44 +- .../src/system/halo2/aggregation.rs | 718 ------------------ snark-verifier/src/system/halo2/strategy.rs | 14 +- snark-verifier/src/system/halo2/test.rs | 215 ------ .../src/system/halo2/test/circuit.rs | 2 - .../src/system/halo2/test/circuit/maingate.rs | 111 --- .../src/system/halo2/test/circuit/standard.rs | 146 ---- snark-verifier/src/system/halo2/test/ipa.rs | 143 ---- .../src/system/halo2/test/ipa/native.rs | 59 -- snark-verifier/src/system/halo2/test/kzg.rs | 106 --- .../src/system/halo2/test/kzg/evm.rs | 137 ---- .../src/system/halo2/test/kzg/halo2.rs | 618 --------------- .../src/system/halo2/test/kzg/native.rs | 70 -- snark-verifier/src/system/halo2/transcript.rs | 4 +- .../src/system/halo2/transcript/evm.rs | 25 +- .../src/system/halo2/transcript/halo2.rs | 143 ++-- snark-verifier/src/util.rs | 5 +- snark-verifier/src/util/arithmetic.rs | 109 ++- snark-verifier/src/util/hash.rs | 6 +- snark-verifier/src/util/hash/poseidon.rs | 374 ++++++++- .../src/util/hash/poseidon/tests.rs | 85 +++ snark-verifier/src/util/msm.rs | 59 +- snark-verifier/src/util/poly.rs | 35 +- snark-verifier/src/util/transcript.rs | 16 + snark-verifier/src/verifier.rs | 48 +- snark-verifier/src/verifier/plonk.rs | 441 +++-------- snark-verifier/src/verifier/plonk/proof.rs | 319 ++++++++ .../src/{util => verifier/plonk}/protocol.rs | 146 +++- 83 files changed, 3597 insertions(+), 5754 deletions(-) create mode 100644 .github/workflows/ci.yaml delete mode 100644 snark-verifier-sdk/configs/bench_zkevm.config create mode 100644 snark-verifier-sdk/configs/bench_zkevm.json create mode 100644 snark-verifier-sdk/configs/bench_zkevm_plus_state.json delete mode 100644 snark-verifier-sdk/configs/example_evm_accumulator.config create mode 100644 snark-verifier-sdk/configs/example_evm_accumulator.json delete mode 100644 snark-verifier-sdk/configs/verify_circuit.config delete mode 100644 snark-verifier/configs/example_evm_accumulator.config create mode 100644 snark-verifier/configs/example_evm_accumulator.json create mode 100644 snark-verifier/configs/example_recursion.json delete mode 100644 snark-verifier/configs/verify_circuit.config delete mode 100644 snark-verifier/src/loader/halo2/test.rs delete mode 100644 snark-verifier/src/system/halo2/aggregation.rs delete mode 100644 snark-verifier/src/system/halo2/test.rs delete mode 100644 snark-verifier/src/system/halo2/test/circuit.rs delete mode 100644 snark-verifier/src/system/halo2/test/circuit/maingate.rs delete mode 100644 snark-verifier/src/system/halo2/test/circuit/standard.rs delete mode 100644 snark-verifier/src/system/halo2/test/ipa.rs delete mode 100644 snark-verifier/src/system/halo2/test/ipa/native.rs delete mode 100644 snark-verifier/src/system/halo2/test/kzg.rs delete mode 100644 snark-verifier/src/system/halo2/test/kzg/evm.rs delete mode 100644 snark-verifier/src/system/halo2/test/kzg/halo2.rs delete mode 100644 snark-verifier/src/system/halo2/test/kzg/native.rs create mode 100644 snark-verifier/src/util/hash/poseidon/tests.rs create mode 100644 snark-verifier/src/verifier/plonk/proof.rs rename snark-verifier/src/{util => verifier/plonk}/protocol.rs (68%) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml new file mode 100644 index 00000000..634eed77 --- /dev/null +++ b/.github/workflows/ci.yaml @@ -0,0 +1,54 @@ +name: CI + +on: + pull_request: + push: + branches: + - main + +jobs: + test: + name: Test + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + + - name: Install toolchain + uses: actions-rs/toolchain@v1 + with: + profile: minimal + override: false + + - uses: Swatinem/rust-cache@v1 + with: + cache-on-failure: true + + - name: Install solc + run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.17 && solc --version + + - name: Run test + run: cargo test --all -- --nocapture + + + lint: + name: Lint + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + + - name: Install toolchain + uses: actions-rs/toolchain@v1 + with: + profile: minimal + override: false + components: rustfmt, clippy + + - uses: Swatinem/rust-cache@v1 + with: + cache-on-failure: true + + - name: Run fmt + run: cargo fmt --all -- --check + + - name: Run clippy + run: cargo clippy --all --all-targets -- -D warnings diff --git a/.gitignore b/.gitignore index 0175c775..829691c6 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,3 @@ /target testdata -Cargo.lock diff --git a/Cargo.toml b/Cargo.toml index 7b3ec409..552212b3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,7 +33,3 @@ incremental = false [profile.flamegraph] inherits = "release" debug = true - -# patch until PR https://github.com/privacy-scaling-explorations/halo2/pull/111 is merged -[patch."https://github.com/privacy-scaling-explorations/halo2.git"] -halo2_proofs = { git = "https://github.com/axiom-crypto/halo2.git", branch = "feat/serde-raw" } \ No newline at end of file diff --git a/README.md b/README.md index bcd16c74..db401c9f 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,3 @@ -# PLONK Verifier +# SNARK Verifier -Generic PLONK verifier. +Generic (S)NARK verifier. diff --git a/snark-verifier-sdk/Cargo.toml b/snark-verifier-sdk/Cargo.toml index 222e6eb9..4901528b 100644 --- a/snark-verifier-sdk/Cargo.toml +++ b/snark-verifier-sdk/Cargo.toml @@ -1,10 +1,10 @@ [package] name = "snark-verifier-sdk" -version = "0.0.1" +version = "0.1.1" edition = "2021" [dependencies] -itertools = "0.10.3" +itertools = "0.10.5" lazy_static = "1.4.0" num-bigint = "0.4.3" num-integer = "0.1.45" @@ -14,24 +14,24 @@ rand_chacha = "0.3.1" hex = "0.4" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" +serde_with = { version = "2.2", optional = true } bincode = "1.3.3" ark-std = { version = "0.3.0", features = ["print-trace"], optional = true } - -halo2-base = { git = "https://github.com/axiom-crypto/halo2-lib.git", tag = "v0.2.2", default-features = false } +halo2-base = { git = "https://github.com/axiom-crypto/halo2-lib.git", tag = "v0.3.0", default-features = false } snark-verifier = { path = "../snark-verifier", default-features = false } # loader_evm -ethereum-types = { version = "0.14", default-features = false, features = ["std"], optional = true } +ethereum-types = { version = "=0.14.1", default-features = false, features = ["std"], optional = true } # sha3 = { version = "0.10", optional = true } # revm = { version = "2.3.1", optional = true } # bytes = { version = "1.2", optional = true } # rlp = { version = "0.5", default-features = false, features = ["std"], optional = true } # zkevm benchmarks -zkevm-circuits = { git = "https://github.com/jonathanpwang/zkevm-circuits.git", branch = "bench-12-04", features = ["test"], optional = true } -bus-mapping = { git = "https://github.com/jonathanpwang/zkevm-circuits.git", branch = "bench-12-04", optional = true } -eth-types = { git = "https://github.com/jonathanpwang/zkevm-circuits.git", branch = "bench-12-04", optional = true } -mock = { git = "https://github.com/jonathanpwang/zkevm-circuits.git", branch = "bench-12-04", optional = true } +zkevm-circuits = { git = "https://github.com/privacy-scaling-explorations/zkevm-circuits.git", rev = "f834e61", features = ["test"], optional = true } +bus-mapping = { git = "https://github.com/privacy-scaling-explorations/zkevm-circuits.git", rev = "f834e61", optional = true } +eth-types = { git = "https://github.com/privacy-scaling-explorations/zkevm-circuits.git", rev = "f834e61", optional = true } +mock = { git = "https://github.com/privacy-scaling-explorations/zkevm-circuits.git", rev = "f834e61", optional = true } [dev-dependencies] ark-std = { version = "0.3.0", features = ["print-trace"] } @@ -44,13 +44,13 @@ crossterm = { version = "0.25" } tui = { version = "0.19", default-features = false, features = ["crossterm"] } [features] -default = ["loader_halo2", "loader_evm", "halo2-axiom", "halo2-base/jemallocator"] +default = ["loader_halo2", "loader_evm", "halo2-axiom", "halo2-base/jemallocator", "display"] display = ["snark-verifier/display", "dep:ark-std"] loader_evm = ["snark-verifier/loader_evm", "dep:ethereum-types"] loader_halo2 = ["snark-verifier/loader_halo2"] parallel = ["snark-verifier/parallel"] # EXACTLY one of halo2-pse / halo2-axiom should always be turned on; not sure how to enforce this with Cargo -halo2-pse = ["snark-verifier/halo2-pse"] +halo2-pse = ["snark-verifier/halo2-pse", "dep:serde_with"] halo2-axiom = ["snark-verifier/halo2-axiom"] zkevm = ["dep:zkevm-circuits", "dep:bus-mapping", "dep:mock", "dep:eth-types"] @@ -62,10 +62,10 @@ harness = false [[bench]] name = "zkevm" -required-features = ["loader_halo2", "zkevm", "halo2-pse", "halo2-base/jemallocator"] +required-features = ["loader_halo2", "loader_evm", "zkevm", "halo2-pse"] harness = false [[bench]] name = "zkevm_plus_state" -required-features = ["loader_halo2", "zkevm", "halo2-pse", "halo2-base/jemallocator"] -harness = false \ No newline at end of file +required-features = ["loader_halo2", "loader_evm", "zkevm", "halo2-pse"] +harness = false diff --git a/snark-verifier-sdk/benches/standard_plonk.rs b/snark-verifier-sdk/benches/standard_plonk.rs index e19776e9..70d600ea 100644 --- a/snark-verifier-sdk/benches/standard_plonk.rs +++ b/snark-verifier-sdk/benches/standard_plonk.rs @@ -1,7 +1,10 @@ use criterion::{criterion_group, criterion_main}; use criterion::{BenchmarkId, Criterion}; +use halo2_base::gates::builder::CircuitBuilderStage; +use halo2_base::utils::fs::gen_srs; use pprof::criterion::{Output, PProfProfiler}; - +use rand::rngs::OsRng; +use std::path::Path; use ark_std::{end_timer, start_timer}; use halo2_base::halo2_proofs; use halo2_proofs::halo2curves as halo2_curves; @@ -9,15 +12,14 @@ use halo2_proofs::{ halo2curves::bn256::Bn256, poly::{commitment::Params, kzg::commitment::ParamsKZG}, }; -use rand::rngs::OsRng; -use rand::SeedableRng; -use rand_chacha::ChaCha20Rng; -use snark_verifier_sdk::CircuitExt; +use snark_verifier_sdk::evm::{evm_verify, gen_evm_proof_shplonk, gen_evm_verifier_shplonk}; +use snark_verifier_sdk::halo2::aggregation::AggregationConfigParams; use snark_verifier_sdk::{ gen_pk, halo2::{aggregation::AggregationCircuit, gen_proof_shplonk, gen_snark_shplonk}, Snark, }; +use snark_verifier_sdk::{CircuitExt, SHPLONK}; mod application { use super::halo2_curves::bn256::Fr; @@ -145,9 +147,9 @@ mod application { } #[cfg(feature = "halo2-axiom")] { - region.assign_advice(config.a, 0, Value::known(self.0))?; + region.assign_advice(config.a, 0, Value::known(self.0)); region.assign_fixed(config.q_a, 0, -Fr::one()); - region.assign_advice(config.a, 1, Value::known(-Fr::from(5u64)))?; + region.assign_advice(config.a, 1, Value::known(-Fr::from(5u64))); for (idx, column) in (1..).zip([ config.q_a, config.q_b, @@ -158,7 +160,7 @@ mod application { region.assign_fixed(column, 1, Fr::from(idx as u64)); } - let a = region.assign_advice(config.a, 2, Value::known(Fr::one()))?; + let a = region.assign_advice(config.a, 2, Value::known(Fr::one())); a.copy_advice(&mut region, config.b, 3); a.copy_advice(&mut region, config.c, 4); } @@ -173,42 +175,69 @@ mod application { fn gen_application_snark(params: &ParamsKZG) -> Snark { let circuit = application::StandardPlonk::rand(OsRng); - let pk = gen_pk(params, &circuit, None); - gen_snark_shplonk(params, &pk, circuit, &mut OsRng, None::<&str>) + let pk = gen_pk(params, &circuit, Some(Path::new("app.pk"))); + gen_snark_shplonk(params, &pk, circuit, Some(Path::new("app.snark"))) } fn bench(c: &mut Criterion) { - std::env::set_var("VERIFY_CONFIG", "./configs/example_evm_accumulator.config"); - let k = 21; - let params = halo2_base::utils::fs::gen_srs(k); - let params_app = { - let mut params = params.clone(); - params.downsize(8); - params - }; + let path = "./configs/example_evm_accumulator.json"; + let params_app = gen_srs(8); let snarks = [(); 3].map(|_| gen_application_snark(¶ms_app)); + let agg_config = AggregationConfigParams::from_path(path); + let params = gen_srs(agg_config.degree); + let lookup_bits = params.k() as usize - 1; - let start1 = start_timer!(|| "Create aggregation circuit"); - let mut rng = ChaCha20Rng::from_entropy(); - let agg_circuit = AggregationCircuit::new(¶ms, snarks, &mut rng); - end_timer!(start1); + let agg_circuit = AggregationCircuit::keygen::(¶ms, snarks.clone()); - let pk = gen_pk(¶ms, &agg_circuit, None); + let start0 = start_timer!(|| "gen vk & pk"); + let pk = gen_pk(¶ms, &agg_circuit, Some(Path::new("agg.pk"))); + end_timer!(start0); + let break_points = agg_circuit.break_points(); let mut group = c.benchmark_group("plonk-prover"); group.sample_size(10); group.bench_with_input( - BenchmarkId::new("standard-plonk-agg", k), - &(¶ms, &pk, &agg_circuit), - |b, &(params, pk, agg_circuit)| { + BenchmarkId::new("standard-plonk-agg", params.k()), + &(¶ms, &pk, &break_points, &snarks), + |b, &(params, pk, break_points, snarks)| { b.iter(|| { + let agg_circuit = AggregationCircuit::new::( + CircuitBuilderStage::Prover, + Some(break_points.clone()), + lookup_bits, + params, + snarks.clone(), + ); let instances = agg_circuit.instances(); - gen_proof_shplonk(params, pk, agg_circuit.clone(), instances, &mut rng, None) + gen_proof_shplonk(params, pk, agg_circuit, instances, None) }) }, ); group.finish(); + + #[cfg(feature = "loader_evm")] + { + // do one more time to verify + let agg_circuit = AggregationCircuit::new::( + CircuitBuilderStage::Prover, + Some(break_points), + lookup_bits, + ¶ms, + snarks.clone(), + ); + let num_instances = agg_circuit.num_instance(); + let instances = agg_circuit.instances(); + let proof = gen_evm_proof_shplonk(¶ms, &pk, agg_circuit, instances.clone()); + + let deployment_code = gen_evm_verifier_shplonk::( + ¶ms, + pk.get_vk(), + num_instances, + None, + ); + evm_verify(deployment_code, instances, proof); + } } criterion_group! { diff --git a/snark-verifier-sdk/benches/zkevm.rs b/snark-verifier-sdk/benches/zkevm.rs index f4c3f3fc..e4ab7e16 100644 --- a/snark-verifier-sdk/benches/zkevm.rs +++ b/snark-verifier-sdk/benches/zkevm.rs @@ -1,6 +1,6 @@ use ark_std::{end_timer, start_timer}; -use halo2_base::halo2_proofs; use halo2_base::utils::fs::gen_srs; +use halo2_base::{gates::builder::CircuitBuilderStage, halo2_proofs}; use halo2_proofs::halo2curves::bn256::Fr; use rand::SeedableRng; use rand_chacha::ChaCha20Rng; @@ -13,10 +13,12 @@ use snark_verifier_sdk::{ }, gen_pk, halo2::{ - aggregation::load_verify_circuit_degree, aggregation::AggregationCircuit, gen_proof_gwc, - gen_proof_shplonk, gen_snark_gwc, gen_snark_shplonk, PoseidonTranscript, POSEIDON_SPEC, + aggregation::load_verify_circuit_degree, + aggregation::{AggregationCircuit, AggregationConfigParams}, + gen_proof_gwc, gen_proof_shplonk, gen_snark_gwc, gen_snark_shplonk, PoseidonTranscript, + POSEIDON_SPEC, }, - CircuitExt, + CircuitExt, GWC, SHPLONK, }; use std::env::{set_var, var}; use std::path::Path; @@ -51,10 +53,6 @@ pub mod zkevm { } fn bench(c: &mut Criterion) { - let mut rng = ChaCha20Rng::from_entropy(); - let mut transcript = - PoseidonTranscript::::from_spec(vec![], POSEIDON_SPEC.clone()); - // === create zkevm evm circuit snark === let k: u32 = var("DEGREE") .unwrap_or_else(|_| { @@ -66,45 +64,88 @@ fn bench(c: &mut Criterion) { let circuit = zkevm::test_circuit(); let params_app = gen_srs(k); let pk = gen_pk(¶ms_app, &circuit, Some(Path::new("data/zkevm_evm.pkey"))); - let snark = gen_snark_gwc( - ¶ms_app, - &pk, - circuit, - &mut transcript, - &mut rng, - Some(Path::new("data/zkevm_evm.snark")), - ); + let snark = gen_snark_shplonk(¶ms_app, &pk, circuit, None::<&str>); let snarks = [snark]; // === finished zkevm evm circuit === // === now to do aggregation === - set_var("VERIFY_CONFIG", "./configs/bench_zkevm.config"); - let k = load_verify_circuit_degree(); + let path = "./configs/bench_zkevm.json"; + let agg_config = AggregationConfigParams::from_path(path); + let k = agg_config.degree; + let lookup_bits = k as usize - 1; let params = gen_srs(k); - let start1 = start_timer!(|| "Create aggregation circuit"); - let agg_circuit = AggregationCircuit::new(¶ms, snarks, &mut transcript, &mut rng); - end_timer!(start1); + let agg_circuit = AggregationCircuit::keygen::(¶ms, snarks.clone()); + let start1 = start_timer!(|| "gen vk & pk"); let pk = gen_pk(¶ms, &agg_circuit, None); + end_timer!(start1); + let break_points = agg_circuit.break_points(); + + #[cfg(feature = "loader_evm")] + { + let start2 = start_timer!(|| "Create EVM SHPLONK proof"); + let agg_circuit = AggregationCircuit::new::( + CircuitBuilderStage::Prover, + Some(break_points.clone()), + lookup_bits, + ¶ms, + snarks.clone(), + ); + let instances = agg_circuit.instances(); + let num_instances = agg_circuit.num_instance(); + + let proof = gen_evm_proof_shplonk(¶ms, &pk, agg_circuit, instances.clone()); + end_timer!(start2); + let deployment_code = gen_evm_verifier_shplonk::( + ¶ms, + pk.get_vk(), + num_instances.clone(), + None, + ); + + evm_verify(deployment_code, instances.clone(), proof); + + let start2 = start_timer!(|| "Create EVM GWC proof"); + let agg_circuit = AggregationCircuit::new::( + // note this is still SHPLONK because it refers to how the evm circuit's snark was generated, NOT how the aggregation proof is going to be generated + CircuitBuilderStage::Prover, + Some(break_points.clone()), + lookup_bits, + ¶ms, + snarks.clone(), + ); + let proof = gen_evm_proof_gwc(¶ms, &pk, agg_circuit, instances.clone()); + end_timer!(start2); + + let deployment_code = gen_evm_verifier_gwc::( + ¶ms, + pk.get_vk(), + num_instances.clone(), + None, + ); + + evm_verify(deployment_code, instances, proof); + } + + // run benches let mut group = c.benchmark_group("shplonk-proof"); group.sample_size(10); group.bench_with_input( BenchmarkId::new("zkevm-evm-agg", k), - &(¶ms, &pk, &agg_circuit), - |b, &(params, pk, agg_circuit)| { + &(¶ms, &pk, &break_points, &snarks), + |b, &(params, pk, break_points, snarks)| { b.iter(|| { - let instances = agg_circuit.instances(); - gen_proof_shplonk( + let agg_circuit = AggregationCircuit::new::( + CircuitBuilderStage::Prover, + Some(break_points.clone()), + lookup_bits, params, - pk, - agg_circuit.clone(), - instances, - &mut transcript, - &mut rng, - None, + snarks.clone(), ); + let instances = agg_circuit.instances(); + gen_proof_shplonk(params, pk, agg_circuit, instances, None) }) }, ); @@ -114,51 +155,23 @@ fn bench(c: &mut Criterion) { group.sample_size(10); group.bench_with_input( BenchmarkId::new("zkevm-evm-agg", k), - &(¶ms, &pk, &agg_circuit), - |b, &(params, pk, agg_circuit)| { + &(¶ms, &pk, &break_points, &snarks), + |b, &(params, pk, break_points, snarks)| { b.iter(|| { - let instances = agg_circuit.instances(); - gen_proof_gwc( + // note that the generic here remains SHPLONK because it reflects the multi-open scheme for the previous snark (= the zkevm snark) + let agg_circuit = AggregationCircuit::new::( + CircuitBuilderStage::Prover, + Some(break_points.clone()), + lookup_bits, params, - pk, - agg_circuit.clone(), - instances, - &mut transcript, - &mut rng, - None, + snarks.clone(), ); + let instances = agg_circuit.instances(); + gen_proof_gwc(params, pk, agg_circuit, instances, None) }) }, ); group.finish(); - - #[cfg(feature = "loader_evm")] - { - let deployment_code = - gen_evm_verifier_shplonk::(¶ms, pk.get_vk(), &(), None::<&str>); - - let start2 = start_timer!(|| "Create EVM SHPLONK proof"); - let proof = gen_evm_proof_shplonk( - ¶ms, - &pk, - agg_circuit.clone(), - agg_circuit.instances(), - &mut rng, - ); - end_timer!(start2); - - evm_verify(deployment_code, agg_circuit.instances(), proof); - - let deployment_code = - gen_evm_verifier_shplonk::(¶ms, pk.get_vk(), &(), None::<&str>); - - let start2 = start_timer!(|| "Create EVM GWC proof"); - let proof = - gen_evm_proof_gwc(¶ms, &pk, agg_circuit.clone(), agg_circuit.instances(), &mut rng); - end_timer!(start2); - - evm_verify(deployment_code, agg_circuit.instances(), proof); - } } criterion_group!(benches, bench); diff --git a/snark-verifier-sdk/benches/zkevm_plus_state.rs b/snark-verifier-sdk/benches/zkevm_plus_state.rs index 840f8581..a0aabf02 100644 --- a/snark-verifier-sdk/benches/zkevm_plus_state.rs +++ b/snark-verifier-sdk/benches/zkevm_plus_state.rs @@ -16,7 +16,7 @@ use snark_verifier_sdk::{ aggregation::load_verify_circuit_degree, aggregation::AggregationCircuit, gen_proof_gwc, gen_proof_shplonk, gen_snark_gwc, gen_snark_shplonk, PoseidonTranscript, POSEIDON_SPEC, }, - CircuitExt, + CircuitExt, SHPLONK, }; use std::env::{set_var, var}; use std::path::Path; @@ -76,57 +76,93 @@ fn bench(c: &mut Criterion) { let params_app = gen_srs(k); let evm_snark = { let pk = gen_pk(¶ms_app, &evm_circuit, Some(Path::new("data/zkevm_evm.pkey"))); - gen_snark_gwc( - ¶ms_app, - &pk, - evm_circuit, - &mut transcript, - &mut rng, - Some(Path::new("data/zkevm_evm.snark")), - ) + gen_snark_shplonk(¶ms_app, &pk, evm_circuit, None::<&str>) }; let state_snark = { let pk = gen_pk(¶ms_app, &state_circuit, Some(Path::new("data/zkevm_state.pkey"))); - gen_snark_shplonk( - ¶ms_app, - &pk, - state_circuit, - &mut transcript, - &mut rng, - Some(Path::new("data/zkevm_state.snark")), - ) + gen_snark_shplonk(¶ms_app, &pk, state_circuit, None::<&str>) }; let snarks = [evm_snark, state_snark]; // === finished zkevm evm circuit === // === now to do aggregation === - set_var("VERIFY_CONFIG", "./configs/bench_zkevm_plus_state.config"); - let k = load_verify_circuit_degree(); + let path = "./configs/bench_zkevm_plus_state.json"; + // everything below exact same as in zkevm bench + let agg_config = AggregationConfigParams::from_path(path); + let k = agg_config.degree; + let lookup_bits = k as usize - 1; let params = gen_srs(k); - let start1 = start_timer!(|| "Create aggregation circuit"); - let agg_circuit = AggregationCircuit::new(¶ms, snarks, &mut transcript, &mut rng); - end_timer!(start1); + let agg_circuit = AggregationCircuit::keygen::(¶ms, snarks.clone()); + let start1 = start_timer!(|| "gen vk & pk"); let pk = gen_pk(¶ms, &agg_circuit, None); + end_timer!(start1); + let break_points = agg_circuit.break_points(); + + #[cfg(feature = "loader_evm")] + { + let start2 = start_timer!(|| "Create EVM SHPLONK proof"); + let agg_circuit = AggregationCircuit::new::( + CircuitBuilderStage::Prover, + Some(break_points.clone()), + lookup_bits, + ¶ms, + snarks.clone(), + ); + let instances = agg_circuit.instances(); + let num_instances = agg_circuit.num_instance(); + + let proof = gen_evm_proof_shplonk(¶ms, &pk, agg_circuit, instances.clone()); + end_timer!(start2); + let deployment_code = gen_evm_verifier_shplonk::( + ¶ms, + pk.get_vk(), + num_instances.clone(), + None, + ); + + evm_verify(deployment_code, instances.clone(), proof); + + let start2 = start_timer!(|| "Create EVM GWC proof"); + let agg_circuit = AggregationCircuit::new::( + CircuitBuilderStage::Prover, + Some(break_points.clone()), + lookup_bits, + ¶ms, + snarks.clone(), + ); + let proof = gen_evm_proof_gwc(¶ms, &pk, agg_circuit, instances.clone()); + end_timer!(start2); + + let deployment_code = gen_evm_verifier_gwc::( + ¶ms, + pk.get_vk(), + num_instances.clone(), + None, + ); + + evm_verify(deployment_code, instances, proof); + } + + // run benches let mut group = c.benchmark_group("shplonk-proof"); group.sample_size(10); group.bench_with_input( - BenchmarkId::new("zkevm-evm-state-agg", k), - &(¶ms, &pk, &agg_circuit), - |b, &(params, pk, agg_circuit)| { + BenchmarkId::new("zkevm-evm-agg", k), + &(¶ms, &pk, &break_points, &snarks), + |b, &(params, pk, break_points, snarks)| { b.iter(|| { - let instances = agg_circuit.instances(); - gen_proof_shplonk( + let agg_circuit = AggregationCircuit::new::( + CircuitBuilderStage::Prover, + Some(break_points.clone()), + lookup_bits, params, - pk, - agg_circuit.clone(), - instances, - &mut transcript, - &mut rng, - None, + snarks.clone(), ); + let instances = agg_circuit.instances(); + gen_proof_shplonk(params, pk, agg_circuit, instances, None) }) }, ); @@ -135,52 +171,24 @@ fn bench(c: &mut Criterion) { let mut group = c.benchmark_group("gwc-proof"); group.sample_size(10); group.bench_with_input( - BenchmarkId::new("zkevm-evm-state-agg", k), - &(¶ms, &pk, &agg_circuit), - |b, &(params, pk, agg_circuit)| { + BenchmarkId::new("zkevm-evm-agg", k), + &(¶ms, &pk, &break_points, &snarks), + |b, &(params, pk, break_points, snarks)| { b.iter(|| { - let instances = agg_circuit.instances(); - gen_proof_gwc( + // note that the generic here remains SHPLONK because it reflects the multi-open scheme for the previous snark (= the zkevm snark) + let agg_circuit = AggregationCircuit::new::( + CircuitBuilderStage::Prover, + Some(break_points.clone()), + lookup_bits, params, - pk, - agg_circuit.clone(), - instances, - &mut transcript, - &mut rng, - None, + snarks.clone(), ); + let instances = agg_circuit.instances(); + gen_proof_gwc(params, pk, agg_circuit, instances, None) }) }, ); group.finish(); - - #[cfg(feature = "loader_evm")] - { - let deployment_code = - gen_evm_verifier_shplonk::(¶ms, pk.get_vk(), &(), None::<&str>); - - let start2 = start_timer!(|| "Create EVM SHPLONK proof"); - let proof = gen_evm_proof_shplonk( - ¶ms, - &pk, - agg_circuit.clone(), - agg_circuit.instances(), - &mut rng, - ); - end_timer!(start2); - - evm_verify(deployment_code, agg_circuit.instances(), proof); - - let deployment_code = - gen_evm_verifier_shplonk::(¶ms, pk.get_vk(), &(), None::<&str>); - - let start2 = start_timer!(|| "Create EVM GWC proof"); - let proof = - gen_evm_proof_gwc(¶ms, &pk, agg_circuit.clone(), agg_circuit.instances(), &mut rng); - end_timer!(start2); - - evm_verify(deployment_code, agg_circuit.instances(), proof); - } } criterion_group!(benches, bench); diff --git a/snark-verifier-sdk/configs/bench_zkevm.config b/snark-verifier-sdk/configs/bench_zkevm.config deleted file mode 100644 index 5156de4f..00000000 --- a/snark-verifier-sdk/configs/bench_zkevm.config +++ /dev/null @@ -1 +0,0 @@ -{"strategy":"Simple","degree":23,"num_advice":[5],"num_lookup_advice":[1],"num_fixed":1,"lookup_bits":22,"limb_bits":88,"num_limbs":3} diff --git a/snark-verifier-sdk/configs/bench_zkevm.json b/snark-verifier-sdk/configs/bench_zkevm.json new file mode 100644 index 00000000..0d2a05c4 --- /dev/null +++ b/snark-verifier-sdk/configs/bench_zkevm.json @@ -0,0 +1,7 @@ +{ + "degree": 23, + "num_advice": 5, + "num_lookup_advice": 1, + "num_fixed": 1, + "lookup_bits": 22 +} diff --git a/snark-verifier-sdk/configs/bench_zkevm_plus_state.json b/snark-verifier-sdk/configs/bench_zkevm_plus_state.json new file mode 100644 index 00000000..03412bff --- /dev/null +++ b/snark-verifier-sdk/configs/bench_zkevm_plus_state.json @@ -0,0 +1,7 @@ +{ + "degree": 24, + "num_advice": 5, + "num_lookup_advice": 1, + "num_fixed": 1, + "lookup_bits": 22 +} diff --git a/snark-verifier-sdk/configs/example_evm_accumulator.config b/snark-verifier-sdk/configs/example_evm_accumulator.config deleted file mode 100644 index 156c402a..00000000 --- a/snark-verifier-sdk/configs/example_evm_accumulator.config +++ /dev/null @@ -1 +0,0 @@ -{"strategy":"Simple","degree":21,"num_advice":[5],"num_lookup_advice":[1],"num_fixed":1,"lookup_bits":20,"limb_bits":88,"num_limbs":3} diff --git a/snark-verifier-sdk/configs/example_evm_accumulator.json b/snark-verifier-sdk/configs/example_evm_accumulator.json new file mode 100644 index 00000000..38282d4b --- /dev/null +++ b/snark-verifier-sdk/configs/example_evm_accumulator.json @@ -0,0 +1,10 @@ +{ + "strategy": "Simple", + "degree": 21, + "num_advice": 5, + "num_lookup_advice": 1, + "num_fixed": 1, + "lookup_bits": 20, + "limb_bits": 88, + "num_limbs": 3 +} diff --git a/snark-verifier-sdk/configs/verify_circuit.config b/snark-verifier-sdk/configs/verify_circuit.config deleted file mode 100644 index 90aff847..00000000 --- a/snark-verifier-sdk/configs/verify_circuit.config +++ /dev/null @@ -1 +0,0 @@ -{"strategy":"Simple","degree":21,"num_advice":[4],"num_lookup_advice":[1],"num_fixed":1,"lookup_bits":20,"limb_bits":88,"num_limbs":3} diff --git a/snark-verifier-sdk/src/evm.rs b/snark-verifier-sdk/src/evm.rs index 456a9bb4..f423d327 100644 --- a/snark-verifier-sdk/src/evm.rs +++ b/snark-verifier-sdk/src/evm.rs @@ -1,13 +1,14 @@ -use super::{CircuitExt, Plonk}; +use crate::{GWC, SHPLONK}; + +use super::{CircuitExt, PlonkVerifier}; #[cfg(feature = "display")] use ark_std::{end_timer, start_timer}; use ethereum_types::Address; use halo2_base::halo2_proofs::{ - dev::MockProver, halo2curves::bn256::{Bn256, Fq, Fr, G1Affine}, plonk::{create_proof, verify_proof, Circuit, ProvingKey, VerifyingKey}, poly::{ - commitment::{Params, ParamsProver, Prover, Verifier}, + commitment::{ParamsProver, Prover, Verifier}, kzg::{ commitment::{KZGCommitmentScheme, ParamsKZG}, msm::DualMSM, @@ -19,16 +20,16 @@ use halo2_base::halo2_proofs::{ transcript::{TranscriptReadBuffer, TranscriptWriterBuffer}, }; use itertools::Itertools; -use rand::Rng; +use rand::{rngs::StdRng, SeedableRng}; pub use snark_verifier::loader::evm::encode_calldata; use snark_verifier::{ loader::evm::{compile_yul, EvmLoader, ExecutorBuilder}, pcs::{ - kzg::{Bdfg21, Gwc19, Kzg, KzgAccumulator, KzgDecidingKey, KzgSuccinctVerifyingKey}, - Decider, MultiOpenScheme, PolynomialCommitmentScheme, + kzg::{KzgAccumulator, KzgAsVerifyingKey, KzgDecidingKey, KzgSuccinctVerifyingKey}, + AccumulationDecider, AccumulationScheme, PolynomialCommitmentScheme, }, system::halo2::{compile, transcript::evm::EvmTranscript, Config}, - verifier::PlonkVerifier, + verifier::SnarkVerifier, }; use std::{fs, io, path::Path, rc::Rc}; @@ -38,7 +39,6 @@ pub fn gen_evm_proof<'params, C, P, V>( pk: &'params ProvingKey, circuit: C, instances: Vec>, - rng: &mut (impl Rng + Send), ) -> Vec where C: Circuit, @@ -50,15 +50,11 @@ where MSMAccumulator = DualMSM<'params, Bn256>, >, { - #[cfg(debug_assertions)] - { - MockProver::run(params.k(), &circuit, instances.clone()).unwrap().assert_satisfied(); - } - let instances = instances.iter().map(|instances| instances.as_slice()).collect_vec(); #[cfg(feature = "display")] let proof_time = start_timer!(|| "Create EVM proof"); + let rng = StdRng::from_entropy(); let proof = { let mut transcript = TranscriptWriterBuffer::<_, G1Affine, _>::init(Vec::new()); create_proof::, P, _, _, EvmTranscript<_, _, _, _>, _>( @@ -98,9 +94,8 @@ pub fn gen_evm_proof_gwc<'params, C: Circuit>( pk: &'params ProvingKey, circuit: C, instances: Vec>, - rng: &mut (impl Rng + Send), ) -> Vec { - gen_evm_proof::, VerifierGWC<_>>(params, pk, circuit, instances, rng) + gen_evm_proof::, VerifierGWC<_>>(params, pk, circuit, instances) } pub fn gen_evm_proof_shplonk<'params, C: Circuit>( @@ -108,12 +103,23 @@ pub fn gen_evm_proof_shplonk<'params, C: Circuit>( pk: &'params ProvingKey, circuit: C, instances: Vec>, - rng: &mut (impl Rng + Send), ) -> Vec { - gen_evm_proof::, VerifierSHPLONK<_>>(params, pk, circuit, instances, rng) + gen_evm_proof::, VerifierSHPLONK<_>>(params, pk, circuit, instances) } -pub fn gen_evm_verifier( +pub trait EvmKzgAccumulationScheme = PolynomialCommitmentScheme< + G1Affine, + Rc, + VerifyingKey = KzgSuccinctVerifyingKey, + Output = KzgAccumulator>, + > + AccumulationScheme< + G1Affine, + Rc, + VerifyingKey = KzgAsVerifyingKey, + Accumulator = KzgAccumulator>, + > + AccumulationDecider, DecidingKey = KzgDecidingKey>; + +pub fn gen_evm_verifier( params: &ParamsKZG, vk: &VerifyingKey, num_instance: Vec, @@ -121,18 +127,8 @@ pub fn gen_evm_verifier( ) -> Vec where C: CircuitExt, - PCS: PolynomialCommitmentScheme< - G1Affine, - Rc, - Accumulator = KzgAccumulator>, - > + MultiOpenScheme< - G1Affine, - Rc, - SuccinctVerifyingKey = KzgSuccinctVerifyingKey, - > + Decider, DecidingKey = KzgDecidingKey>, + AS: EvmKzgAccumulationScheme, { - let svk = params.get_g()[0].into(); - let dk = (params.g2(), params.s_g2()).into(); let protocol = compile( params, vk, @@ -140,14 +136,17 @@ where .with_num_instance(num_instance.clone()) .with_accumulator_indices(C::accumulator_indices()), ); + // deciding key + let dk = (params.get_g()[0], params.g2(), params.s_g2()).into(); let loader = EvmLoader::new::(); let protocol = protocol.loaded(&loader); let mut transcript = EvmTranscript::<_, Rc, _, _>::new(&loader); let instances = transcript.load_instances(num_instance); - let proof = Plonk::::read_proof(&svk, &protocol, &instances, &mut transcript); - Plonk::::verify(&svk, &dk, &protocol, &instances, &proof); + let proof = + PlonkVerifier::::read_proof(&dk, &protocol, &instances, &mut transcript).unwrap(); + PlonkVerifier::::verify(&dk, &protocol, &instances, &proof).unwrap(); let yul_code = loader.yul_code(); let byte_code = compile_yul(&yul_code); @@ -164,7 +163,7 @@ pub fn gen_evm_verifier_gwc>( num_instance: Vec, path: Option<&Path>, ) -> Vec { - gen_evm_verifier::>(params, vk, num_instance, path) + gen_evm_verifier::(params, vk, num_instance, path) } pub fn gen_evm_verifier_shplonk>( @@ -173,7 +172,7 @@ pub fn gen_evm_verifier_shplonk>( num_instance: Vec, path: Option<&Path>, ) -> Vec { - gen_evm_verifier::>(params, vk, num_instance, path) + gen_evm_verifier::(params, vk, num_instance, path) } pub fn evm_verify(deployment_code: Vec, instances: Vec>, proof: Vec) { diff --git a/snark-verifier-sdk/src/halo2.rs b/snark-verifier-sdk/src/halo2.rs index a4526911..b0710230 100644 --- a/snark-verifier-sdk/src/halo2.rs +++ b/snark-verifier-sdk/src/halo2.rs @@ -1,10 +1,9 @@ -use super::{read_instances, write_instances, CircuitExt, Snark, SnarkWitness}; +use super::{read_instances, write_instances, CircuitExt, PlonkSuccinctVerifier, Snark}; #[cfg(feature = "display")] use ark_std::{end_timer, start_timer}; use halo2_base::halo2_proofs; use halo2_proofs::{ circuit::Layouter, - dev::MockProver, halo2curves::{ bn256::{Bn256, Fr, G1Affine}, group::ff::Field, @@ -14,7 +13,7 @@ use halo2_proofs::{ VerifyingKey, }, poly::{ - commitment::{Params, ParamsProver, Prover, Verifier}, + commitment::{ParamsProver, Prover, Verifier}, kzg::{ commitment::{KZGCommitmentScheme, ParamsKZG}, msm::DualMSM, @@ -26,15 +25,18 @@ use halo2_proofs::{ }; use itertools::Itertools; use lazy_static::lazy_static; -use rand::Rng; +use rand::{rngs::StdRng, SeedableRng}; +pub use snark_verifier::util::hash::OptimizedPoseidonSpec; use snark_verifier::{ cost::CostEstimation, loader::native::NativeLoader, - pcs::{self, MultiOpenScheme}, + pcs::{ + kzg::{KzgAccumulator, KzgAsVerifyingKey, KzgSuccinctVerifyingKey}, + AccumulationScheme, PolynomialCommitmentScheme, Query, + }, system::halo2::{compile, Config}, util::transcript::TranscriptWrite, - verifier::PlonkProof, - PoseidonSpec, + verifier::plonk::PlonkProof, }; use std::{ fs::{self, File}, @@ -45,10 +47,13 @@ use std::{ pub mod aggregation; // Poseidon parameters -const T: usize = 5; -const RATE: usize = 4; -const R_F: usize = 8; -const R_P: usize = 60; +// We use the same ones Scroll uses for security: https://github.com/scroll-tech/poseidon-circuit/blob/714f50c7572a4ff6f2b1fa51a9604a99cd7b6c71/src/poseidon/primitives/bn256/fp.rs +// Verify generated constants: https://github.com/scroll-tech/poseidon-circuit/blob/714f50c7572a4ff6f2b1fa51a9604a99cd7b6c71/src/poseidon/primitives/bn256/mod.rs#L65 +const T: usize = 3; +const RATE: usize = 2; +const R_F: usize = 8; // https://github.com/scroll-tech/poseidon-circuit/blob/714f50c7572a4ff6f2b1fa51a9604a99cd7b6c71/src/poseidon/primitives/p128pow5t3.rs#L26 +const R_P: usize = 57; // https://github.com/scroll-tech/poseidon-circuit/blob/714f50c7572a4ff6f2b1fa51a9604a99cd7b6c71/src/poseidon/primitives/bn256/mod.rs#L8 +const SECURE_MDS: usize = 0; pub type PoseidonTranscript = snark_verifier::system::halo2::transcript::halo2::PoseidonTranscript< @@ -62,7 +67,8 @@ pub type PoseidonTranscript = >; lazy_static! { - pub static ref POSEIDON_SPEC: PoseidonSpec = PoseidonSpec::new(R_F, R_P); + pub static ref POSEIDON_SPEC: OptimizedPoseidonSpec = + OptimizedPoseidonSpec::new::(); } /// Generates a native proof using either SHPLONK or GWC proving method. Uses Poseidon for Fiat-Shamir. @@ -74,7 +80,6 @@ pub fn gen_proof<'params, C, P, V>( pk: &ProvingKey, circuit: C, instances: Vec>, - rng: &mut (impl Rng + Send), path: Option<(&Path, &Path)>, ) -> Vec where @@ -87,11 +92,6 @@ where MSMAccumulator = DualMSM<'params, Bn256>, >, { - #[cfg(debug_assertions)] - { - MockProver::run(params.k(), &circuit, instances.clone()).unwrap().assert_satisfied(); - } - if let Some((instance_path, proof_path)) = path { let cached_instances = read_instances(instance_path); if matches!(cached_instances, Ok(tmp) if tmp == instances) && proof_path.exists() { @@ -113,6 +113,7 @@ where let mut transcript = PoseidonTranscript::>::from_spec(vec![], POSEIDON_SPEC.clone()); + let rng = StdRng::from_entropy(); create_proof::<_, P, _, _, _, _>(params, pk, &[circuit], &[&instances], rng, &mut transcript) .unwrap(); let proof = transcript.finalize(); @@ -120,13 +121,10 @@ where #[cfg(feature = "display")] end_timer!(proof_time); - if let Some((instance_path, proof_path)) = path { - write_instances(&instances, instance_path); - fs::write(proof_path, &proof).unwrap(); - } - - debug_assert!({ - let mut transcript_read = PoseidonTranscript::::new(proof.as_slice()); + // validate proof before caching + assert!({ + let mut transcript_read = + PoseidonTranscript::::from_spec(&proof[..], POSEIDON_SPEC.clone()); VerificationStrategy::<_, V>::finalize( verify_proof::<_, V, _, _, _>( params.verifier_params(), @@ -139,6 +137,11 @@ where ) }); + if let Some((instance_path, proof_path)) = path { + write_instances(&instances, instance_path); + fs::write(proof_path, &proof).unwrap(); + } + proof } @@ -150,10 +153,9 @@ pub fn gen_proof_gwc>( pk: &ProvingKey, circuit: C, instances: Vec>, - rng: &mut (impl Rng + Send), path: Option<(&Path, &Path)>, ) -> Vec { - gen_proof::, VerifierGWC<_>>(params, pk, circuit, instances, rng, path) + gen_proof::, VerifierGWC<_>>(params, pk, circuit, instances, path) } /// Generates a native proof using SHPLONK multi-open scheme. Uses Poseidon for Fiat-Shamir. @@ -164,10 +166,9 @@ pub fn gen_proof_shplonk>( pk: &ProvingKey, circuit: C, instances: Vec>, - rng: &mut (impl Rng + Send), path: Option<(&Path, &Path)>, ) -> Vec { - gen_proof::, VerifierSHPLONK<_>>(params, pk, circuit, instances, rng, path) + gen_proof::, VerifierSHPLONK<_>>(params, pk, circuit, instances, path) } /// Generates a SNARK using either SHPLONK or GWC multi-open scheme. Uses Poseidon for Fiat-Shamir. @@ -178,7 +179,6 @@ pub fn gen_snark<'params, ConcreteCircuit, P, V>( params: &'params ParamsKZG, pk: &ProvingKey, circuit: ConcreteCircuit, - rng: &mut (impl Rng + Send), path: Option>, ) -> Snark where @@ -191,6 +191,7 @@ where MSMAccumulator = DualMSM<'params, Bn256>, >, { + #[cfg(feature = "halo2-axiom")] if let Some(path) = &path { if let Ok(snark) = read_snark(path) { return snark; @@ -205,10 +206,21 @@ where ); let instances = circuit.instances(); - let proof = - gen_proof::(params, pk, circuit, instances.clone(), rng, None); + #[cfg(feature = "halo2-axiom")] + let proof = gen_proof::(params, pk, circuit, instances.clone(), None); + // If we can't serialize the entire snark, at least serialize the proof + #[cfg(not(feature = "halo2-axiom"))] + let proof = { + let path = path.map(|path| { + let path = path.as_ref().to_str().unwrap(); + (format!("{path}.instances"), format!("{path}.proof")) + }); + let paths = path.as_ref().map(|path| (Path::new(&path.0), Path::new(&path.1))); + gen_proof::(params, pk, circuit, instances.clone(), paths) + }; let snark = Snark::new(protocol, instances, proof); + #[cfg(feature = "halo2-axiom")] if let Some(path) = &path { let f = File::create(path).unwrap(); #[cfg(feature = "display")] @@ -217,6 +229,7 @@ where #[cfg(feature = "display")] end_timer!(write_time); } + #[allow(clippy::let_and_return)] snark } @@ -228,10 +241,9 @@ pub fn gen_snark_gwc>( params: &ParamsKZG, pk: &ProvingKey, circuit: ConcreteCircuit, - rng: &mut (impl Rng + Send), path: Option>, ) -> Snark { - gen_snark::, VerifierGWC<_>>(params, pk, circuit, rng, path) + gen_snark::, VerifierGWC<_>>(params, pk, circuit, path) } /// Generates a SNARK using SHPLONK multi-open scheme. Uses Poseidon for Fiat-Shamir. @@ -242,31 +254,39 @@ pub fn gen_snark_shplonk>( params: &ParamsKZG, pk: &ProvingKey, circuit: ConcreteCircuit, - rng: &mut (impl Rng + Send), path: Option>, ) -> Snark { - gen_snark::, VerifierSHPLONK<_>>( - params, pk, circuit, rng, path, - ) + gen_snark::, VerifierSHPLONK<_>>(params, pk, circuit, path) } /// Tries to deserialize a SNARK from the specified `path` using `bincode`. /// /// WARNING: The user must keep track of whether the SNARK was generated using the GWC or SHPLONK multi-open scheme. +#[cfg(feature = "halo2-axiom")] pub fn read_snark(path: impl AsRef) -> Result { let f = File::open(path).map_err(Box::::from)?; bincode::deserialize_from(f) } -pub fn gen_dummy_snark( +// copied from snark_verifier --example recursion +pub fn gen_dummy_snark( params: &ParamsKZG, vk: Option<&VerifyingKey>, num_instance: Vec, ) -> Snark where ConcreteCircuit: CircuitExt, - MOS: MultiOpenScheme - + CostEstimation>>, + AS: PolynomialCommitmentScheme< + G1Affine, + NativeLoader, + VerifyingKey = KzgSuccinctVerifyingKey, + Output = KzgAccumulator, + > + AccumulationScheme< + G1Affine, + NativeLoader, + Accumulator = KzgAccumulator, + VerifyingKey = KzgAsVerifyingKey, + > + CostEstimation>>, { struct CsProxy(PhantomData<(F, C)>); @@ -314,7 +334,7 @@ where ); let instances = num_instance.into_iter().map(|n| vec![Fr::default(); n]).collect(); let proof = { - let mut transcript = PoseidonTranscript::::new(Vec::new()); + let mut transcript = PoseidonTranscript::::new::(Vec::new()); for _ in 0..protocol .num_witness .iter() @@ -326,8 +346,8 @@ where for _ in 0..protocol.evaluations.len() { transcript.write_scalar(Fr::default()).unwrap(); } - let queries = PlonkProof::::empty_queries(&protocol); - for _ in 0..MOS::estimate_cost(&queries).num_commitment { + let queries = PlonkProof::::empty_queries(&protocol); + for _ in 0..AS::estimate_cost(&queries).num_commitment { transcript.write_ec_point(G1Affine::default()).unwrap(); } transcript.finalize() diff --git a/snark-verifier-sdk/src/halo2/aggregation.rs b/snark-verifier-sdk/src/halo2/aggregation.rs index 4748c930..b5dc148d 100644 --- a/snark-verifier-sdk/src/halo2/aggregation.rs +++ b/snark-verifier-sdk/src/halo2/aggregation.rs @@ -1,71 +1,55 @@ -#![allow(clippy::clone_on_copy)] -use crate::{Plonk, BITS, LIMBS}; -#[cfg(feature = "display")] -use ark_std::{end_timer, start_timer}; +use super::PlonkSuccinctVerifier; +use crate::{BITS, LIMBS}; use halo2_base::{ + gates::{ + builder::{ + CircuitBuilderStage, FlexGateConfigParams, GateThreadBuilder, + MultiPhaseThreadBreakPoints, RangeCircuitBuilder, RangeWithInstanceCircuitBuilder, + RangeWithInstanceConfig, + }, + RangeChip, + }, halo2_proofs::{ - circuit::{Layouter, SimpleFloorPlanner, Value}, - halo2curves::bn256::{Bn256, Fq, Fr, G1Affine}, - plonk::{self, Circuit, Column, ConstraintSystem, Instance, Selector}, - poly::{commitment::ParamsProver, kzg::commitment::ParamsKZG}, + circuit::{Layouter, SimpleFloorPlanner}, + halo2curves::bn256::{Bn256, Fr, G1Affine}, + plonk::{self, Circuit, ConstraintSystem, Selector}, + poly::{ + commitment::{Params, ParamsProver}, + kzg::commitment::ParamsKZG, + }, }, - utils::value_to_option, + utils::ScalarField, AssignedValue, }; -use halo2_base::{Context, ContextParams}; use itertools::Itertools; -use rand::Rng; +use rand::{rngs::StdRng, SeedableRng}; +use serde::{Deserialize, Serialize}; +#[cfg(debug_assertions)] +use snark_verifier::util::arithmetic::fe_to_limbs; use snark_verifier::{ loader::{ self, - halo2::{ - halo2_ecc::{self, ecc::EccChip}, - EccInstructions, - }, + halo2::halo2_ecc::{self, bn254::FpChip}, native::NativeLoader, }, pcs::{ - kzg::{Bdfg21, Kzg, KzgAccumulator, KzgAs, KzgSuccinctVerifyingKey}, - AccumulationScheme, AccumulationSchemeProver, MultiOpenScheme, PolynomialCommitmentScheme, + kzg::{KzgAccumulator, KzgAsProvingKey, KzgAsVerifyingKey, KzgSuccinctVerifyingKey}, + AccumulationScheme, AccumulationSchemeProver, PolynomialCommitmentScheme, }, - util::arithmetic::fe_to_limbs, - verifier::PlonkVerifier, + verifier::SnarkVerifier, +}; +use std::{ + env::{set_var, var}, + fs::File, + path::Path, + rc::Rc, }; -use std::{fs::File, rc::Rc}; -use super::{CircuitExt, PoseidonTranscript, Snark, SnarkWitness, POSEIDON_SPEC}; +use super::{CircuitExt, PoseidonTranscript, Snark, POSEIDON_SPEC}; pub type Svk = KzgSuccinctVerifyingKey; -pub type BaseFieldEccChip = halo2_ecc::ecc::BaseFieldEccChip; -pub type Halo2Loader<'a> = loader::halo2::Halo2Loader<'a, G1Affine, BaseFieldEccChip>; -pub type Shplonk = Plonk>; - -pub fn load_verify_circuit_degree() -> u32 { - let path = std::env::var("VERIFY_CONFIG") - .unwrap_or_else(|_| "./configs/verify_circuit.config".to_string()); - let params: AggregationConfigParams = serde_json::from_reader( - File::open(path.as_str()).unwrap_or_else(|_| panic!("{path} does not exist")), - ) - .unwrap(); - params.degree -} - -pub fn flatten_accumulator<'b, 'a: 'b>( - accumulator: KzgAccumulator>>, -) -> Vec> { - let KzgAccumulator { lhs, rhs } = accumulator; - let lhs = lhs.into_assigned(); - let rhs = rhs.into_assigned(); - - lhs.x - .truncation - .limbs - .into_iter() - .chain(lhs.y.truncation.limbs.into_iter()) - .chain(rhs.x.truncation.limbs.into_iter()) - .chain(rhs.y.truncation.limbs.into_iter()) - .collect() -} +pub type BaseFieldEccChip<'chip> = halo2_ecc::ecc::BaseFieldEccChip<'chip, G1Affine>; +pub type Halo2Loader<'chip> = loader::halo2::Halo2Loader>; #[allow(clippy::type_complexity)] /// Core function used in `synthesize` to aggregate multiple `snarks`. @@ -73,23 +57,30 @@ pub fn flatten_accumulator<'b, 'a: 'b>( /// Returns the assigned instances of previous snarks and the new final pair that needs to be verified in a pairing check. /// For each previous snark, we concatenate all instances into a single vector. We return a vector of vectors, /// one vector per snark, for convenience. -pub fn aggregate<'a, PCS>( - svk: &PCS::SuccinctVerifyingKey, +/// +/// # Assumptions +/// * `snarks` is not empty +pub fn aggregate<'a, AS>( + svk: &Svk, loader: &Rc>, - snarks: &[SnarkWitness], - as_proof: Value<&'_ [u8]>, -) -> ( - Vec>::AssignedScalar>>, - KzgAccumulator>>, -) + snarks: &[Snark], + as_proof: &[u8], +) -> (Vec>>, KzgAccumulator>>) where - PCS: PolynomialCommitmentScheme< + AS: PolynomialCommitmentScheme< + G1Affine, + Rc>, + VerifyingKey = Svk, + Output = KzgAccumulator>>, + > + AccumulationScheme< G1Affine, Rc>, Accumulator = KzgAccumulator>>, - > + MultiOpenScheme>>, + VerifyingKey = KzgAsVerifyingKey, + >, { - let assign_instances = |instances: &[Vec>]| { + assert!(!snarks.is_empty(), "trying to aggregate 0 snarks"); + let assign_instances = |instances: &[Vec]| { instances .iter() .map(|instances| { @@ -98,11 +89,11 @@ where .collect_vec() }; - // TODO pre-allocate capacity better let mut previous_instances = Vec::with_capacity(snarks.len()); - let mut transcript = PoseidonTranscript::>, _>::from_spec( + // to avoid re-loading the spec each time, we create one transcript and clear the stream + let mut transcript = PoseidonTranscript::>, &[u8]>::from_spec( loader, - Value::unknown(), + &[], POSEIDON_SPEC.clone(), ); @@ -110,14 +101,20 @@ where .iter() .flat_map(|snark| { let protocol = snark.protocol.loaded(loader); - // TODO use 1d vector let instances = assign_instances(&snark.instances); // read the transcript and perform Fiat-Shamir // run through verification computation and produce the final pair `succinct` transcript.new_stream(snark.proof()); - let proof = Plonk::::read_proof(svk, &protocol, &instances, &mut transcript); - let accumulator = Plonk::::succinct_verify(svk, &protocol, &instances, &proof); + let proof = PlonkSuccinctVerifier::::read_proof( + svk, + &protocol, + &instances, + &mut transcript, + ) + .unwrap(); + let accumulator = + PlonkSuccinctVerifier::::verify(svk, &protocol, &instances, &proof).unwrap(); previous_instances.push( instances.into_iter().flatten().map(|scalar| scalar.into_assigned()).collect(), @@ -129,9 +126,14 @@ where let accumulator = if accumulators.len() > 1 { transcript.new_stream(as_proof); - let proof = - KzgAs::::read_proof(&Default::default(), &accumulators, &mut transcript).unwrap(); - KzgAs::::verify(&Default::default(), &accumulators, &proof).unwrap() + let proof = >::read_proof( + &Default::default(), + &accumulators, + &mut transcript, + ) + .unwrap(); + >::verify(&Default::default(), &accumulators, &proof) + .unwrap() } else { accumulators.pop().unwrap() }; @@ -139,297 +141,283 @@ where (previous_instances, accumulator) } -#[derive(serde::Serialize, serde::Deserialize)] +/// Same as `FlexGateConfigParams` except we assume a single Phase and default 'Vertical' strategy. +/// Also adds `lookup_bits` field. +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct AggregationConfigParams { - pub strategy: halo2_ecc::fields::fp::FpStrategy, pub degree: u32, - pub num_advice: Vec, - pub num_lookup_advice: Vec, + pub num_advice: usize, + pub num_lookup_advice: usize, pub num_fixed: usize, pub lookup_bits: usize, - pub limb_bits: usize, - pub num_limbs: usize, } -#[derive(Clone, Debug)] -pub struct AggregationConfig { - pub base_field_config: halo2_ecc::fields::fp::FpConfig, - pub instance: Column, -} - -impl AggregationConfig { - pub fn configure(meta: &mut ConstraintSystem, params: AggregationConfigParams) -> Self { - assert!( - params.limb_bits == BITS && params.num_limbs == LIMBS, - "For now we fix limb_bits = {}, otherwise change code", - BITS - ); - let base_field_config = halo2_ecc::fields::fp::FpConfig::configure( - meta, - params.strategy, - ¶ms.num_advice, - ¶ms.num_lookup_advice, - params.num_fixed, - params.lookup_bits, - BITS, - LIMBS, - halo2_base::utils::modulus::(), - 0, - params.degree as usize, - ); - - let instance = meta.instance_column(); - meta.enable_equality(instance); - - Self { base_field_config, instance } - } - - pub fn range(&self) -> &halo2_base::gates::range::RangeConfig { - &self.base_field_config.range - } - - pub fn gate(&self) -> &halo2_base::gates::flex_gate::FlexGateConfig { - &self.base_field_config.range.gate - } - - pub fn ecc_chip(&self) -> halo2_ecc::ecc::BaseFieldEccChip { - EccChip::construct(self.base_field_config.clone()) +impl AggregationConfigParams { + pub fn from_path(path: impl AsRef) -> Self { + serde_json::from_reader(File::open(path).expect("Aggregation config path does not exist")) + .unwrap() } } -/// Aggregation circuit that does not re-expose any public inputs from aggregated snarks -/// -/// This is mostly a reference implementation. In practice one will probably need to re-implement the circuit for one's particular use case with specific instance logic. -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct AggregationCircuit { - svk: Svk, - snarks: Vec, - instances: Vec, - as_proof: Value>, + pub inner: RangeWithInstanceCircuitBuilder, + // the public instances from previous snarks that were aggregated, now collected as PRIVATE assigned values + // the user can optionally append these to `inner.assigned_instances` to expose them + pub previous_instances: Vec>>, + // accumulation scheme proof, private input + pub as_proof: Vec, // not sure this needs to be stored, keeping for now } +// trait just so we can have a generic that is either SHPLONK or GWC +pub trait Halo2KzgAccumulationScheme<'a> = PolynomialCommitmentScheme< + G1Affine, + Rc>, + VerifyingKey = Svk, + Output = KzgAccumulator>>, + > + AccumulationScheme< + G1Affine, + Rc>, + Accumulator = KzgAccumulator>>, + VerifyingKey = KzgAsVerifyingKey, + > + PolynomialCommitmentScheme< + G1Affine, + NativeLoader, + VerifyingKey = Svk, + Output = KzgAccumulator, + > + AccumulationScheme< + G1Affine, + NativeLoader, + Accumulator = KzgAccumulator, + VerifyingKey = KzgAsVerifyingKey, + > + AccumulationSchemeProver>; + impl AggregationCircuit { - pub fn new( + /// Given snarks, this creates a circuit and runs the `GateThreadBuilder` to verify all the snarks. + /// By default, the returned circuit has public instances equal to the limbs of the pair of elliptic curve points, referred to as the `accumulator`, that need to be verified in a final pairing check. + /// + /// The user can optionally modify the circuit after calling this function to add more instances to `assigned_instances` to expose. + /// + /// Warning: will fail silently if `snarks` were created using a different multi-open scheme than `AS` + /// where `AS` can be either [`crate::SHPLONK`] or [`crate::GWC`] (for original PLONK multi-open scheme) + pub fn new( + stage: CircuitBuilderStage, + break_points: Option, + lookup_bits: usize, params: &ParamsKZG, snarks: impl IntoIterator, - rng: impl Rng + Send, - ) -> Self { - let svk = params.get_g()[0].into(); + ) -> Self + where + AS: for<'a> Halo2KzgAccumulationScheme<'a>, + { + let svk: Svk = params.get_g()[0].into(); let snarks = snarks.into_iter().collect_vec(); - // TODO: this is all redundant calculation to get the public output - // Halo2 should just be able to expose public output to instance column directly let mut transcript_read = PoseidonTranscript::::from_spec(&[], POSEIDON_SPEC.clone()); + // TODO: the snarks can probably store these accumulators let accumulators = snarks .iter() .flat_map(|snark| { - transcript_read.new_stream(snark.proof.as_slice()); - let proof = Shplonk::read_proof( + transcript_read.new_stream(snark.proof()); + let proof = PlonkSuccinctVerifier::::read_proof( &svk, &snark.protocol, &snark.instances, &mut transcript_read, - ); - Shplonk::succinct_verify(&svk, &snark.protocol, &snark.instances, &proof) + ) + .unwrap(); + PlonkSuccinctVerifier::::verify(&svk, &snark.protocol, &snark.instances, &proof) + .unwrap() }) .collect_vec(); - let (accumulator, as_proof) = { + let (_accumulator, as_proof) = { let mut transcript_write = PoseidonTranscript::>::from_spec( vec![], POSEIDON_SPEC.clone(), ); - // We always use SHPLONK for accumulation scheme when aggregating proofs - let accumulator = KzgAs::>::create_proof( - &Default::default(), - &accumulators, - &mut transcript_write, - rng, - ) - .unwrap(); + let rng = StdRng::from_entropy(); + let accumulator = + AS::create_proof(&Default::default(), &accumulators, &mut transcript_write, rng) + .unwrap(); (accumulator, transcript_write.finalize()) }; - let KzgAccumulator { lhs, rhs } = accumulator; - let instances = [lhs.x, lhs.y, rhs.x, rhs.y].map(fe_to_limbs::<_, _, LIMBS, BITS>).concat(); + // create thread builder and run aggregation witness gen + let builder = match stage { + CircuitBuilderStage::Mock => GateThreadBuilder::mock(), + CircuitBuilderStage::Prover => GateThreadBuilder::prover(), + CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), + }; + // create halo2loader + let range = RangeChip::::default(lookup_bits); + let fp_chip = FpChip::::new(&range, BITS, LIMBS); + let ecc_chip = BaseFieldEccChip::new(&fp_chip); + let loader = Halo2Loader::new(ecc_chip, builder); + + let (previous_instances, accumulator) = + aggregate::(&svk, &loader, &snarks, as_proof.as_slice()); + let lhs = accumulator.lhs.assigned(); + let rhs = accumulator.rhs.assigned(); + let assigned_instances = lhs + .x() + .limbs() + .iter() + .chain(lhs.y().limbs().iter()) + .chain(rhs.x().limbs().iter()) + .chain(rhs.y().limbs().iter()) + .copied() + .collect_vec(); - Self { - svk, - snarks: snarks.into_iter().map_into().collect(), - instances, - as_proof: Value::known(as_proof), + #[cfg(debug_assertions)] + { + let KzgAccumulator { lhs, rhs } = _accumulator; + let instances = + [lhs.x, lhs.y, rhs.x, rhs.y].map(fe_to_limbs::<_, Fr, LIMBS, BITS>).concat(); + for (lhs, rhs) in instances.iter().zip(assigned_instances.iter()) { + assert_eq!(lhs, rhs.value()); + } } + + let builder = loader.take_ctx(); + let circuit = match stage { + CircuitBuilderStage::Mock => RangeCircuitBuilder::mock(builder), + CircuitBuilderStage::Keygen => RangeCircuitBuilder::keygen(builder), + CircuitBuilderStage::Prover => { + RangeCircuitBuilder::prover(builder, break_points.unwrap()) + } + }; + let inner = RangeWithInstanceCircuitBuilder::new(circuit, assigned_instances); + Self { inner, previous_instances, as_proof } } - pub fn instance(&self) -> Vec { - self.instances.clone() + pub fn public( + stage: CircuitBuilderStage, + break_points: Option, + lookup_bits: usize, + params: &ParamsKZG, + snarks: impl IntoIterator, + has_prev_accumulator: bool, + ) -> Self + where + AS: for<'a> Halo2KzgAccumulationScheme<'a>, + { + let mut private = Self::new::(stage, break_points, lookup_bits, params, snarks); + private.expose_previous_instances(has_prev_accumulator); + private + } + + // this function is for convenience + /// `params` should be the universal trusted setup to be used for the aggregation circuit, not the one used to generate the previous snarks, although we assume both use the same generator g[0] + pub fn keygen(params: &ParamsKZG, snarks: impl IntoIterator) -> Self + where + AS: for<'a> Halo2KzgAccumulationScheme<'a>, + { + let lookup_bits = params.k() as usize - 1; // almost always we just use the max lookup bits possible, which is k - 1 because of blinding factors + let circuit = + Self::new::(CircuitBuilderStage::Keygen, None, lookup_bits, params, snarks); + circuit.config(params.k(), Some(10)); + set_var("LOOKUP_BITS", lookup_bits.to_string()); + circuit + } + + // this function is for convenience + pub fn prover( + params: &ParamsKZG, + snarks: impl IntoIterator, + break_points: MultiPhaseThreadBreakPoints, + ) -> Self + where + AS: for<'a> Halo2KzgAccumulationScheme<'a>, + { + let lookup_bits: usize = var("LOOKUP_BITS").expect("LOOKUP_BITS not set").parse().unwrap(); + let circuit = Self::new::( + CircuitBuilderStage::Prover, + Some(break_points), + lookup_bits, + params, + snarks, + ); + let minimum_rows = var("MINIMUM_ROWS").map(|s| s.parse().unwrap_or(10)).unwrap_or(10); + circuit.config(params.k(), Some(minimum_rows)); + set_var("LOOKUP_BITS", lookup_bits.to_string()); + circuit } - pub fn succinct_verifying_key(&self) -> &Svk { - &self.svk + /// Re-expose the previous public instances of aggregated snarks again. + /// If `hash_prev_accumulator` is true, then we assume all aggregated snarks were themselves + /// aggregation snarks, and we exclude the old accumulators from the public input. + pub fn expose_previous_instances(&mut self, has_prev_accumulator: bool) { + let start = (has_prev_accumulator as usize) * 4 * LIMBS; + for prev in self.previous_instances.iter() { + self.inner.assigned_instances.extend_from_slice(&prev[start..]); + } + } + + pub fn as_proof(&self) -> &[u8] { + &self.as_proof[..] + } + + pub fn config(&self, k: u32, minimum_rows: Option) -> FlexGateConfigParams { + self.inner.config(k, minimum_rows) + } + + pub fn break_points(&self) -> MultiPhaseThreadBreakPoints { + self.inner.break_points() } - pub fn snarks(&self) -> &[SnarkWitness] { - &self.snarks + pub fn instance_count(&self) -> usize { + self.inner.instance_count() } - pub fn as_proof(&self) -> Value<&[u8]> { - self.as_proof.as_ref().map(Vec::as_slice) + pub fn instance(&self) -> Vec { + self.inner.instance() } } -impl CircuitExt for AggregationCircuit { +impl CircuitExt for RangeWithInstanceCircuitBuilder { fn num_instance(&self) -> Vec { - // [..lhs, ..rhs] - vec![4 * LIMBS] + vec![self.instance_count()] } - fn instances(&self) -> Vec> { + fn instances(&self) -> Vec> { vec![self.instance()] } - fn accumulator_indices() -> Option> { - Some((0..4 * LIMBS).map(|idx| (0, idx)).collect()) - } - fn selectors(config: &Self::Config) -> Vec { - config.gate().basic_gates[0].iter().map(|gate| gate.q_enable).collect() + config.range.gate.basic_gates[0].iter().map(|gate| gate.q_enable).collect() } } impl Circuit for AggregationCircuit { - type Config = AggregationConfig; + type Config = RangeWithInstanceConfig; type FloorPlanner = SimpleFloorPlanner; fn without_witnesses(&self) -> Self { - Self { - svk: self.svk, - snarks: self.snarks.iter().map(SnarkWitness::without_witnesses).collect(), - instances: Vec::new(), - as_proof: Value::unknown(), - } + unimplemented!() } - fn configure(meta: &mut plonk::ConstraintSystem) -> Self::Config { - let path = std::env::var("VERIFY_CONFIG") - .unwrap_or_else(|_| "configs/verify_circuit.config".to_owned()); - let params: AggregationConfigParams = serde_json::from_reader( - File::open(path.as_str()).unwrap_or_else(|_| panic!("{path:?} does not exist")), - ) - .unwrap(); - - AggregationConfig::configure(meta, params) + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + RangeWithInstanceCircuitBuilder::configure(meta) } fn synthesize( &self, config: Self::Config, - mut layouter: impl Layouter, + layouter: impl Layouter, ) -> Result<(), plonk::Error> { - #[cfg(feature = "display")] - let witness_time = start_timer!(|| "synthesize | Aggregation Circuit"); - config.range().load_lookup_table(&mut layouter).expect("load range lookup table"); - let mut first_pass = halo2_base::SKIP_FIRST_PASS; - let mut instances = vec![]; - layouter - .assign_region( - || "", - |region| { - if first_pass { - first_pass = false; - return Ok(()); - } - let ctx = Context::new( - region, - ContextParams { - max_rows: config.gate().max_rows, - num_context_ids: 1, - fixed_columns: config.gate().constants.clone(), - }, - ); - - let ecc_chip = config.ecc_chip(); - let loader = Halo2Loader::new(ecc_chip, ctx); - let (_, acc) = aggregate::>( - &self.svk, - &loader, - &self.snarks, - self.as_proof(), - ); - - instances.extend( - flatten_accumulator(acc).iter().map(|assigned| assigned.cell().clone()), - ); - - config.range().finalize(&mut loader.ctx_mut()); - #[cfg(feature = "display")] - loader.ctx_mut().print_stats(&["Range"]); - Ok(()) - }, - ) - .unwrap(); - - // Expose instances - for (i, cell) in instances.into_iter().enumerate() { - layouter.constrain_instance(cell, config.instance, i); - } - #[cfg(feature = "display")] - end_timer!(witness_time); - Ok(()) - } -} - -/// This circuit takes multiple SNARKs and passes through all of their instances except the old accumulators. -/// -/// * If `has_prev_accumulator = true`, we assume all SNARKs are of aggregation circuits with old accumulators -/// only in the first instance column. -/// * Otherwise if `has_prev_accumulator = false`, then all previous instances are passed through. -#[derive(Clone)] -pub struct PublicAggregationCircuit { - pub aggregation: AggregationCircuit, - pub has_prev_accumulator: bool, -} - -impl PublicAggregationCircuit { - pub fn new( - params: &ParamsKZG, - snarks: Vec, - has_prev_accumulator: bool, - rng: &mut (impl Rng + Send), - ) -> Self { - Self { aggregation: AggregationCircuit::new(params, snarks, rng), has_prev_accumulator } + self.inner.synthesize(config, layouter) } } -impl CircuitExt for PublicAggregationCircuit { +impl CircuitExt for AggregationCircuit { fn num_instance(&self) -> Vec { - let prev_num = self - .aggregation - .snarks - .iter() - .map(|snark| snark.instances.iter().map(|instance| instance.len()).sum::()) - .sum::() - - self.aggregation.snarks.len() * 4 * LIMBS * usize::from(self.has_prev_accumulator); - vec![4 * LIMBS + prev_num] + self.inner.num_instance() } fn instances(&self) -> Vec> { - let start_idx = 4 * LIMBS * usize::from(self.has_prev_accumulator); - let instance = self - .aggregation - .instances - .iter() - .cloned() - .chain(self.aggregation.snarks.iter().flat_map(|snark| { - snark.instances.iter().enumerate().flat_map(|(i, instance)| { - instance[usize::from(i == 0) * start_idx..] - .iter() - .map(|v| value_to_option(*v).unwrap()) - }) - })) - .collect_vec(); - vec![instance] + self.inner.instances() } fn accumulator_indices() -> Option> { @@ -437,83 +425,16 @@ impl CircuitExt for PublicAggregationCircuit { } fn selectors(config: &Self::Config) -> Vec { - AggregationCircuit::selectors(config) + RangeWithInstanceCircuitBuilder::selectors(config) } } -impl Circuit for PublicAggregationCircuit { - type Config = AggregationConfig; - type FloorPlanner = SimpleFloorPlanner; - - fn without_witnesses(&self) -> Self { - Self { - aggregation: self.aggregation.without_witnesses(), - has_prev_accumulator: self.has_prev_accumulator, - } - } - - fn configure(meta: &mut plonk::ConstraintSystem) -> Self::Config { - AggregationCircuit::configure(meta) - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), plonk::Error> { - #[cfg(feature = "display")] - let witness_time = start_timer!(|| { "synthesize | EVM verifier" }); - config.range().load_lookup_table(&mut layouter).expect("load range lookup table"); - let mut first_pass = halo2_base::SKIP_FIRST_PASS; - let mut instances = vec![]; - layouter - .assign_region( - || "", - |region| { - if first_pass { - first_pass = false; - return Ok(()); - } - let ctx = Context::new( - region, - ContextParams { - max_rows: config.gate().max_rows, - num_context_ids: 1, - fixed_columns: config.gate().constants.clone(), - }, - ); - - let ecc_chip = config.ecc_chip(); - let loader = Halo2Loader::new(ecc_chip, ctx); - let (prev_instances, acc) = aggregate::>( - &self.aggregation.svk, - &loader, - &self.aggregation.snarks, - self.aggregation.as_proof(), - ); - - // accumulator - instances.extend(flatten_accumulator(acc).iter().map(|a| a.cell().clone())); - // prev instances except accumulators - let start_idx = 4 * LIMBS * usize::from(self.has_prev_accumulator); - for prev_instance in prev_instances { - instances - .extend(prev_instance[start_idx..].iter().map(|a| a.cell().clone())); - } - - config.range().finalize(&mut loader.ctx_mut()); - #[cfg(feature = "display")] - loader.ctx_mut().print_stats(&["Range"]); - Ok(()) - }, - ) - .unwrap(); - // Expose instances - for (i, cell) in instances.into_iter().enumerate() { - layouter.constrain_instance(cell, config.instance, i); - } - #[cfg(feature = "display")] - end_timer!(witness_time); - Ok(()) - } +pub fn load_verify_circuit_degree() -> u32 { + let path = std::env::var("VERIFY_CONFIG") + .unwrap_or_else(|_| "./configs/verify_circuit.config".to_string()); + let params: AggregationConfigParams = serde_json::from_reader( + File::open(path.as_str()).unwrap_or_else(|_| panic!("{path} does not exist")), + ) + .unwrap(); + params.degree } diff --git a/snark-verifier-sdk/src/lib.rs b/snark-verifier-sdk/src/lib.rs index c2342866..9a5833d6 100644 --- a/snark-verifier-sdk/src/lib.rs +++ b/snark-verifier-sdk/src/lib.rs @@ -1,9 +1,9 @@ #![feature(associated_type_defaults)] +#![feature(trait_alias)] #[cfg(feature = "display")] use ark_std::{end_timer, start_timer}; -use halo2_base::halo2_proofs; +use halo2_base::halo2_proofs::{self}; use halo2_proofs::{ - circuit::Value, halo2curves::{ bn256::{Bn256, Fr, G1Affine}, group::ff::Field, @@ -15,7 +15,10 @@ use halo2_proofs::{ use itertools::Itertools; use serde::{Deserialize, Serialize}; pub use snark_verifier::loader::native::NativeLoader; -use snark_verifier::{pcs::kzg::LimbsEncoding, verifier, Protocol}; +use snark_verifier::{ + pcs::kzg::{Bdfg21, Gwc19, KzgAs, LimbsEncoding}, + verifier::{self, plonk::PlonkProtocol}, +}; use std::{ fs::{self, File}, io::{self, BufReader, BufWriter}, @@ -30,58 +33,29 @@ pub mod halo2; pub const LIMBS: usize = 3; pub const BITS: usize = 88; -/// PCS be either `Kzg` or `Kzg` -pub type Plonk = verifier::Plonk>; +/// AS stands for accumulation scheme. +/// AS can be either `Kzg` (the original PLONK KZG multi-open) or `Kzg` (SHPLONK) +pub type PlonkVerifier = verifier::plonk::PlonkVerifier>; +pub type PlonkSuccinctVerifier = + verifier::plonk::PlonkSuccinctVerifier>; +pub type SHPLONK = KzgAs; +pub type GWC = KzgAs; -#[derive(Clone, Debug, Serialize, Deserialize)] +#[derive(Clone, Debug)] +#[cfg_attr(feature = "halo2-axiom", derive(Serialize, Deserialize))] pub struct Snark { - pub protocol: Protocol, + pub protocol: PlonkProtocol, pub instances: Vec>, pub proof: Vec, } impl Snark { - pub fn new(protocol: Protocol, instances: Vec>, proof: Vec) -> Self { + pub fn new(protocol: PlonkProtocol, instances: Vec>, proof: Vec) -> Self { Self { protocol, instances, proof } } -} - -impl From for SnarkWitness { - fn from(snark: Snark) -> Self { - Self { - protocol: snark.protocol, - instances: snark - .instances - .into_iter() - .map(|instances| instances.into_iter().map(Value::known).collect_vec()) - .collect(), - proof: Value::known(snark.proof), - } - } -} - -#[derive(Clone, Debug)] -pub struct SnarkWitness { - pub protocol: Protocol, - pub instances: Vec>>, - pub proof: Value>, -} - -impl SnarkWitness { - pub fn without_witnesses(&self) -> Self { - SnarkWitness { - protocol: self.protocol.clone(), - instances: self - .instances - .iter() - .map(|instances| vec![Value::unknown(); instances.len()]) - .collect(), - proof: Value::unknown(), - } - } - pub fn proof(&self) -> Value<&[u8]> { - self.proof.as_ref().map(Vec::as_slice) + pub fn proof(&self) -> &[u8] { + &self.proof[..] } } @@ -194,7 +168,7 @@ mod zkevm { fn instances(&self) -> Vec> { vec![] } - fn num_instance() -> Vec { + fn num_instance(&self) -> Vec { vec![] } } @@ -203,7 +177,7 @@ mod zkevm { fn instances(&self) -> Vec> { vec![] } - fn num_instance() -> Vec { + fn num_instance(&self) -> Vec { vec![] } } diff --git a/snark-verifier/Cargo.toml b/snark-verifier/Cargo.toml index 6ac397a9..74bc99fe 100644 --- a/snark-verifier/Cargo.toml +++ b/snark-verifier/Cargo.toml @@ -1,37 +1,35 @@ [package] name = "snark-verifier" -version = "0.1.0" +version = "0.1.1" edition = "2021" [dependencies] -itertools = "0.10.3" +itertools = "0.10.5" lazy_static = "1.4.0" num-bigint = "0.4.3" num-integer = "0.1.45" num-traits = "0.2.15" hex = "0.4" rand = "0.8" -rustc-hash = "1.1.0" serde = { version = "1.0", features = ["derive"] } # Use halo2-base as non-optional dependency because it re-exports halo2_proofs, halo2curves, and poseidon, using different repos based on feature flag "halo2-axiom" or "halo2-pse" -halo2-base = { git = "https://github.com/axiom-crypto/halo2-lib.git", tag = "v0.2.2", default-features = false } -# This poseidon is identical to PSE (for now) but uses axiom's halo2curves; otherwise would require patching -poseidon-axiom = { git = "https://github.com/axiom-crypto/halo2.git", tag = "v2023_01_17", package = "poseidon", optional = true } -poseidon= { git = "https://github.com/privacy-scaling-explorations/poseidon", optional = true } +halo2-base = { git = "https://github.com/axiom-crypto/halo2-lib.git", tag = "v0.3.0", default-features = false } +# This is Scroll's audited poseidon circuit. We only use it for the Native Poseidon spec. We do not use the halo2 circuit at all (and it wouldn't even work because the halo2_proofs tag is not compatbile). +poseidon-circuit = { git = "https://github.com/scroll-tech/poseidon-circuit.git", rev = "50015b7" } # parallel -rayon = { version = "1.5.3", optional = true } +rayon = { version = "1.7.0", optional = true } # loader_evm -ethereum-types = { version = "0.14", default-features = false, features = ["std"], optional = true } -sha3 = { version = "0.10", optional = true } -revm = { version = "2.3.1", optional = true } -bytes = { version = "1.2", optional = true } -rlp = { version = "0.5", default-features = false, features = ["std"], optional = true } +sha3 = { version = "=0.10.8", optional = true } +bytes = { version = "=1.4.0", default-features = false, optional = true } +primitive-types = { version = "=0.12.1", default-features = false, features = ["std"], optional = true } +rlp = { version = "=0.5.2", default-features = false, features = ["std"], optional = true } +revm = { version = "=2.3.1", optional = true } # loader_halo2 -halo2-ecc = { git = "https://github.com/axiom-crypto/halo2-lib.git", tag = "v0.2.2", default-features = false, optional = true } +halo2-ecc = { git = "https://github.com/axiom-crypto/halo2-lib.git", tag = "v0.3.0", default-features = false, optional = true } [dev-dependencies] ark-std = { version = "0.3.0", features = ["print-trace"] } @@ -44,14 +42,14 @@ crossterm = { version = "0.25" } tui = { version = "0.19", default-features = false, features = ["crossterm"] } [features] -default = ["loader_evm", "loader_halo2", "halo2-axiom"] +default = ["loader_evm", "loader_halo2", "halo2-axiom", "display"] display = ["halo2-base/display", "halo2-ecc?/display"] -loader_evm = ["dep:ethereum-types", "dep:sha3", "dep:revm", "dep:bytes", "dep:rlp"] +loader_evm = ["dep:primitive-types", "dep:sha3", "dep:revm", "dep:bytes", "dep:rlp"] loader_halo2 = ["halo2-ecc"] parallel = ["dep:rayon"] # EXACTLY one of halo2-pse / halo2-axiom should always be turned on; not sure how to enforce this with Cargo -halo2-pse = ["halo2-base/halo2-pse", "halo2-ecc?/halo2-pse", "poseidon"] -halo2-axiom = ["halo2-base/halo2-axiom", "halo2-ecc?/halo2-axiom", "poseidon-axiom"] +halo2-pse = ["halo2-base/halo2-pse", "halo2-ecc?/halo2-pse"] +halo2-axiom = ["halo2-base/halo2-axiom", "halo2-ecc?/halo2-axiom"] [[example]] name = "evm-verifier" diff --git a/snark-verifier/configs/example_evm_accumulator.config b/snark-verifier/configs/example_evm_accumulator.config deleted file mode 100644 index fcda49a0..00000000 --- a/snark-verifier/configs/example_evm_accumulator.config +++ /dev/null @@ -1 +0,0 @@ -{"strategy":"Simple","degree":21,"num_advice":5,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":20,"limb_bits":88,"num_limbs":3} \ No newline at end of file diff --git a/snark-verifier/configs/example_evm_accumulator.json b/snark-verifier/configs/example_evm_accumulator.json new file mode 100644 index 00000000..c7c63783 --- /dev/null +++ b/snark-verifier/configs/example_evm_accumulator.json @@ -0,0 +1,7 @@ +{ + "degree": 21, + "num_advice": 5, + "num_lookup_advice": 1, + "num_fixed": 1, + "lookup_bits": 20 +} diff --git a/snark-verifier/configs/example_recursion.json b/snark-verifier/configs/example_recursion.json new file mode 100644 index 00000000..986e925e --- /dev/null +++ b/snark-verifier/configs/example_recursion.json @@ -0,0 +1,7 @@ +{ + "degree": 21, + "num_advice": 4, + "num_lookup_advice": 1, + "num_fixed": 1, + "lookup_bits": 20 +} diff --git a/snark-verifier/configs/verify_circuit.config b/snark-verifier/configs/verify_circuit.config deleted file mode 100644 index e65b2b52..00000000 --- a/snark-verifier/configs/verify_circuit.config +++ /dev/null @@ -1 +0,0 @@ -{"strategy":"Simple","degree":21,"num_advice":4,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":20,"limb_bits":88,"num_limbs":3} \ No newline at end of file diff --git a/snark-verifier/examples/evm-verifier-with-accumulator.rs b/snark-verifier/examples/evm-verifier-with-accumulator.rs index be936611..dd537880 100644 --- a/snark-verifier/examples/evm-verifier-with-accumulator.rs +++ b/snark-verifier/examples/evm-verifier-with-accumulator.rs @@ -1,11 +1,11 @@ -use ethereum_types::Address; -use halo2_base::halo2_proofs; +use aggregation::{AggregationCircuit, AggregationConfigParams}; +use halo2_base::{gates::builder::CircuitBuilderStage, halo2_proofs, utils::fs::gen_srs}; use halo2_proofs::{ dev::MockProver, halo2curves::bn256::{Bn256, Fq, Fr, G1Affine}, plonk::{create_proof, keygen_pk, keygen_vk, verify_proof, Circuit, ProvingKey, VerifyingKey}, poly::{ - commitment::{Params, ParamsProver}, + commitment::ParamsProver, kzg::{ commitment::{KZGCommitmentScheme, ParamsKZG}, multiopen::{ProverGWC, VerifierGWC}, @@ -19,21 +19,21 @@ use itertools::Itertools; use rand::rngs::OsRng; use snark_verifier::{ loader::{ - evm::{self, encode_calldata, EvmLoader, ExecutorBuilder}, + evm::{self, encode_calldata, Address, EvmLoader, ExecutorBuilder}, native::NativeLoader, }, - pcs::kzg::{Gwc19, Kzg, KzgAs, LimbsEncoding}, + pcs::kzg::{Gwc19, KzgAs, LimbsEncoding}, system::halo2::{compile, transcript::evm::EvmTranscript, Config}, - verifier::{self, PlonkVerifier}, + verifier::{self, SnarkVerifier}, }; -use std::{io::Cursor, rc::Rc}; +use std::{env::set_var, fs::File, io::Cursor, rc::Rc}; const LIMBS: usize = 3; const BITS: usize = 88; -type Pcs = Kzg; -type As = KzgAs; -type Plonk = verifier::Plonk>; +type As = KzgAs; +type PlonkSuccinctVerifier = verifier::plonk::PlonkSuccinctVerifier>; +type PlonkVerifier = verifier::plonk::PlonkVerifier>; mod application { use super::halo2_proofs::{ @@ -161,18 +161,14 @@ mod application { } #[cfg(feature = "halo2-axiom")] { - region.assign_advice( - config.a, - 0, - Value::known(Assigned::Trivial(self.0)), - )?; + region.assign_advice(config.a, 0, Value::known(Assigned::Trivial(self.0))); region.assign_fixed(config.q_a, 0, Assigned::Trivial(-Fr::one())); region.assign_advice( config.a, 1, Value::known(Assigned::Trivial(-Fr::from(5u64))), - )?; + ); for (idx, column) in (1..).zip([ config.q_a, config.q_b, @@ -187,7 +183,7 @@ mod application { config.a, 2, Value::known(Assigned::Trivial(Fr::one())), - )?; + ); a.copy_advice(&mut region, config.b, 3); a.copy_advice(&mut region, config.c, 4); } @@ -200,16 +196,26 @@ mod application { } mod aggregation { + use crate::PlonkSuccinctVerifier; + use super::halo2_proofs::{ - circuit::{Cell, Layouter, SimpleFloorPlanner, Value}, - plonk::{self, Circuit, Column, ConstraintSystem, Instance}, - poly::{commitment::ParamsProver, kzg::commitment::ParamsKZG}, + circuit::{Layouter, SimpleFloorPlanner}, + plonk::{self, Circuit, Column, Instance}, + }; + use super::{As, BITS, LIMBS}; + use super::{Fr, G1Affine}; + use halo2_base::{ + gates::{ + builder::{ + assign_threads_in, CircuitBuilderStage, FlexGateConfigParams, GateThreadBuilder, + MultiPhaseThreadBreakPoints, RangeCircuitBuilder, + }, + range::RangeConfig, + RangeChip, + }, + AssignedValue, SKIP_FIRST_PASS, }; - use super::{As, Plonk, BITS, LIMBS}; - use super::{Bn256, Fq, Fr, G1Affine}; - use ark_std::{end_timer, start_timer}; - use halo2_base::{Context, ContextParams}; - use halo2_ecc::ecc::EccChip; + use halo2_ecc::bn254::FpChip; use itertools::Itertools; use rand::rngs::OsRng; use snark_verifier::{ @@ -220,80 +226,52 @@ mod aggregation { }, system, util::arithmetic::fe_to_limbs, - verifier::PlonkVerifier, - Protocol, + verifier::{plonk::PlonkProtocol, SnarkVerifier}, }; - use std::{fs::File, rc::Rc}; + use std::{collections::HashMap, rc::Rc}; - const T: usize = 5; - const RATE: usize = 4; + const T: usize = 3; + const RATE: usize = 2; const R_F: usize = 8; - const R_P: usize = 60; + const R_P: usize = 57; + const SECURE_MDS: usize = 0; type Svk = KzgSuccinctVerifyingKey; - type BaseFieldEccChip = halo2_ecc::ecc::BaseFieldEccChip; - type Halo2Loader<'a> = loader::halo2::Halo2Loader<'a, G1Affine, BaseFieldEccChip>; + type BaseFieldEccChip<'chip> = halo2_ecc::ecc::BaseFieldEccChip<'chip, G1Affine>; + type Halo2Loader<'chip> = loader::halo2::Halo2Loader>; pub type PoseidonTranscript = system::halo2::transcript::halo2::PoseidonTranscript; + #[derive(Clone)] pub struct Snark { - protocol: Protocol, + protocol: PlonkProtocol, instances: Vec>, proof: Vec, } impl Snark { - pub fn new(protocol: Protocol, instances: Vec>, proof: Vec) -> Self { + pub fn new( + protocol: PlonkProtocol, + instances: Vec>, + proof: Vec, + ) -> Self { Self { protocol, instances, proof } } } - impl From for SnarkWitness { - fn from(snark: Snark) -> Self { - Self { - protocol: snark.protocol, - instances: snark - .instances - .into_iter() - .map(|instances| instances.into_iter().map(Value::known).collect_vec()) - .collect(), - proof: Value::known(snark.proof), - } - } - } - - #[derive(Clone)] - pub struct SnarkWitness { - protocol: Protocol, - instances: Vec>>, - proof: Value>, - } - - impl SnarkWitness { - fn without_witnesses(&self) -> Self { - SnarkWitness { - protocol: self.protocol.clone(), - instances: self - .instances - .iter() - .map(|instances| vec![Value::unknown(); instances.len()]) - .collect(), - proof: Value::unknown(), - } - } - - fn proof(&self) -> Value<&[u8]> { - self.proof.as_ref().map(Vec::as_slice) + impl Snark { + fn proof(&self) -> &[u8] { + self.proof.as_slice() } } pub fn aggregate<'a>( svk: &Svk, loader: &Rc>, - snarks: &[SnarkWitness], - as_proof: Value<&'_ [u8]>, + snarks: &[Snark], + as_proof: &[u8], ) -> KzgAccumulator>> { - let assign_instances = |instances: &[Vec>]| { + let assign_instances = |instances: &[Vec]| { instances .iter() .map(|instances| { @@ -308,122 +286,134 @@ mod aggregation { let protocol = snark.protocol.loaded(loader); let instances = assign_instances(&snark.instances); let mut transcript = - PoseidonTranscript::, _>::new(loader, snark.proof()); - let proof = Plonk::read_proof(svk, &protocol, &instances, &mut transcript); - Plonk::succinct_verify(svk, &protocol, &instances, &proof) + PoseidonTranscript::, _>::new::<0>(loader, snark.proof()); + let proof = + PlonkSuccinctVerifier::read_proof(svk, &protocol, &instances, &mut transcript) + .unwrap(); + PlonkSuccinctVerifier::verify(svk, &protocol, &instances, &proof).unwrap() }) .collect_vec(); - let acccumulator = { - let mut transcript = PoseidonTranscript::, _>::new(loader, as_proof); - let proof = - As::read_proof(&Default::default(), &accumulators, &mut transcript).unwrap(); - As::verify(&Default::default(), &accumulators, &proof).unwrap() - }; - - acccumulator + let mut transcript = + PoseidonTranscript::, _>::new::(loader, as_proof); + let proof = As::read_proof(&Default::default(), &accumulators, &mut transcript).unwrap(); + As::verify(&Default::default(), &accumulators, &proof).unwrap() } #[derive(serde::Serialize, serde::Deserialize)] pub struct AggregationConfigParams { - pub strategy: halo2_ecc::fields::fp::FpStrategy, pub degree: u32, pub num_advice: usize, pub num_lookup_advice: usize, pub num_fixed: usize, pub lookup_bits: usize, - pub limb_bits: usize, - pub num_limbs: usize, } #[derive(Clone)] pub struct AggregationConfig { - pub base_field_config: halo2_ecc::fields::fp::FpConfig, + pub range: RangeConfig, pub instance: Column, } - impl AggregationConfig { - pub fn configure(meta: &mut ConstraintSystem, params: AggregationConfigParams) -> Self { - assert!( - params.limb_bits == BITS && params.num_limbs == LIMBS, - "For now we fix limb_bits = {}, otherwise change code", - BITS - ); - let base_field_config = halo2_ecc::fields::fp::FpConfig::configure( - meta, - params.strategy, - &[params.num_advice], - &[params.num_lookup_advice], - params.num_fixed, - params.lookup_bits, - params.limb_bits, - params.num_limbs, - halo2_base::utils::modulus::(), - 0, - params.degree as usize, - ); - - let instance = meta.instance_column(); - meta.enable_equality(instance); - - Self { base_field_config, instance } - } - - pub fn range(&self) -> &halo2_base::gates::range::RangeConfig { - &self.base_field_config.range - } - - pub fn ecc_chip(&self) -> halo2_ecc::ecc::BaseFieldEccChip { - EccChip::construct(self.base_field_config.clone()) - } - } - - #[derive(Clone)] + #[derive(Clone, Debug)] pub struct AggregationCircuit { - svk: Svk, - snarks: Vec, - instances: Vec, - as_proof: Value>, + pub circuit: RangeCircuitBuilder, + pub as_proof: Vec, + pub assigned_instances: Vec>, } impl AggregationCircuit { - pub fn new(params: &ParamsKZG, snarks: impl IntoIterator) -> Self { - let svk = params.get_g()[0].into(); + pub fn new( + stage: CircuitBuilderStage, + break_points: Option, + lookup_bits: usize, + params_g0: G1Affine, + snarks: impl IntoIterator, + ) -> Self { + let svk: Svk = params_g0.into(); let snarks = snarks.into_iter().collect_vec(); + // verify the snarks natively to get public instances let accumulators = snarks .iter() .flat_map(|snark| { - let mut transcript = - PoseidonTranscript::::new(snark.proof.as_slice()); - let proof = - Plonk::read_proof(&svk, &snark.protocol, &snark.instances, &mut transcript); - Plonk::succinct_verify(&svk, &snark.protocol, &snark.instances, &proof) + let mut transcript = PoseidonTranscript::::new::( + snark.proof.as_slice(), + ); + let proof = PlonkSuccinctVerifier::read_proof( + &svk, + &snark.protocol, + &snark.instances, + &mut transcript, + ) + .unwrap(); + PlonkSuccinctVerifier::verify(&svk, &snark.protocol, &snark.instances, &proof) + .unwrap() }) .collect_vec(); - let (accumulator, as_proof) = { - let mut transcript = PoseidonTranscript::::new(Vec::new()); + let (_accumulator, as_proof) = { + let mut transcript = + PoseidonTranscript::::new::(Vec::new()); let accumulator = As::create_proof(&Default::default(), &accumulators, &mut transcript, OsRng) .unwrap(); (accumulator, transcript.finalize()) }; - let KzgAccumulator { lhs, rhs } = accumulator; - let instances = - [lhs.x, lhs.y, rhs.x, rhs.y].map(fe_to_limbs::<_, _, LIMBS, BITS>).concat(); + // create thread builder and run aggregation witness gen + let builder = match stage { + CircuitBuilderStage::Mock => GateThreadBuilder::mock(), + CircuitBuilderStage::Prover => GateThreadBuilder::prover(), + CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), + }; + // create halo2loader + let range = RangeChip::::default(lookup_bits); + let fp_chip = FpChip::::new(&range, BITS, LIMBS); + let ecc_chip = BaseFieldEccChip::new(&fp_chip); + let loader = Halo2Loader::new(ecc_chip, builder); + + let KzgAccumulator { lhs, rhs } = + aggregate(&svk, &loader, &snarks, as_proof.as_slice()); + let lhs = lhs.assigned(); + let rhs = rhs.assigned(); + let assigned_instances = lhs + .x() + .limbs() + .iter() + .chain(lhs.y().limbs().iter()) + .chain(rhs.x().limbs().iter()) + .chain(rhs.y().limbs().iter()) + .copied() + .collect_vec(); - Self { - svk, - snarks: snarks.into_iter().map_into().collect(), - instances, - as_proof: Value::known(as_proof), + #[cfg(debug_assertions)] + { + let KzgAccumulator { lhs, rhs } = _accumulator; + let instances = + [lhs.x, lhs.y, rhs.x, rhs.y].map(fe_to_limbs::<_, Fr, LIMBS, BITS>).concat(); + for (lhs, rhs) in instances.iter().zip(assigned_instances.iter()) { + assert_eq!(lhs, rhs.value()); + } } + + let builder = loader.take_ctx(); + let circuit = match stage { + CircuitBuilderStage::Mock => RangeCircuitBuilder::mock(builder), + CircuitBuilderStage::Keygen => RangeCircuitBuilder::keygen(builder), + CircuitBuilderStage::Prover => { + RangeCircuitBuilder::prover(builder, break_points.unwrap()) + } + }; + Self { circuit, as_proof, assigned_instances } + } + + pub fn config(&self, k: u32, minimum_rows: Option) -> FlexGateConfigParams { + self.circuit.0.builder.borrow().config(k as usize, minimum_rows) } - pub fn as_proof(&self) -> Value<&[u8]> { - self.as_proof.as_ref().map(Vec::as_slice) + pub fn break_points(&self) -> MultiPhaseThreadBreakPoints { + self.circuit.0.break_points.borrow().clone() } pub fn num_instance() -> Vec { @@ -432,7 +422,7 @@ mod aggregation { } pub fn instances(&self) -> Vec> { - vec![self.instances.clone()] + vec![self.assigned_instances.iter().map(|v| *v.value()).collect_vec()] } pub fn accumulator_indices() -> Vec<(usize, usize)> { @@ -445,23 +435,14 @@ mod aggregation { type FloorPlanner = SimpleFloorPlanner; fn without_witnesses(&self) -> Self { - Self { - svk: self.svk, - snarks: self.snarks.iter().map(SnarkWitness::without_witnesses).collect(), - instances: Vec::new(), - as_proof: Value::unknown(), - } + unimplemented!() } fn configure(meta: &mut plonk::ConstraintSystem) -> Self::Config { - let path = std::env::var("VERIFY_CONFIG").unwrap(); - let params: AggregationConfigParams = serde_json::from_reader( - File::open(path.as_str()) - .unwrap_or_else(|err| panic!("Path {path} does not exist: {err:?}")), - ) - .unwrap(); - - AggregationConfig::configure(meta, params) + let range = RangeCircuitBuilder::configure(meta); + let instance = meta.instance_column(); + meta.enable_equality(instance); + AggregationConfig { range, instance } } fn synthesize( @@ -469,63 +450,76 @@ mod aggregation { config: Self::Config, mut layouter: impl Layouter, ) -> Result<(), plonk::Error> { - config.range().load_lookup_table(&mut layouter)?; - let max_rows = config.range().gate.max_rows; - - let mut first_pass = halo2_base::SKIP_FIRST_PASS; // assume using simple floor planner - let mut assigned_instances: Option> = None; - layouter.assign_region( - || "", - |region| { - if first_pass { - first_pass = false; - return Ok(()); - } - let witness_time = start_timer!(|| "Witness Collection"); - let ctx = Context::new( - region, - ContextParams { - max_rows, - num_context_ids: 1, - fixed_columns: config.base_field_config.range.gate.constants.clone(), - }, - ); - - let ecc_chip = config.ecc_chip(); - let loader = Halo2Loader::new(ecc_chip, ctx); - let KzgAccumulator { lhs, rhs } = - aggregate(&self.svk, &loader, &self.snarks, self.as_proof()); - - let lhs = lhs.assigned(); - let rhs = rhs.assigned(); - - config.base_field_config.finalize(&mut loader.ctx_mut()); - #[cfg(feature = "display")] - println!("Total advice cells: {}", loader.ctx().total_advice); - #[cfg(feature = "display")] - println!("Advice columns used: {}", loader.ctx().advice_alloc[0].0 + 1); - - let instances: Vec<_> = lhs - .x - .truncation - .limbs - .iter() - .chain(lhs.y.truncation.limbs.iter()) - .chain(rhs.x.truncation.limbs.iter()) - .chain(rhs.y.truncation.limbs.iter()) - .map(|assigned| assigned.cell().clone()) - .collect(); - assigned_instances = Some(instances); - end_timer!(witness_time); - Ok(()) - }, - )?; - - // Expose instances - // TODO: use less instances by following Scroll's strategy of keeping only last bit of y coordinate - let mut layouter = layouter.namespace(|| "expose"); - for (i, cell) in assigned_instances.unwrap().into_iter().enumerate() { - layouter.constrain_instance(cell, config.instance, i); + // copied from RangeCircuitBuilder::synthesize but with extra logic to expose public instances + let range = config.range; + let circuit = &self.circuit.0; + range.load_lookup_table(&mut layouter).expect("load lookup table should not fail"); + + // we later `take` the builder, so we need to save this value + let witness_gen_only = circuit.builder.borrow().witness_gen_only(); + let mut assigned_advices = HashMap::new(); + + let mut first_pass = SKIP_FIRST_PASS; + layouter + .assign_region( + || "AggregationCircuit", + |mut region| { + if first_pass { + first_pass = false; + return Ok(()); + } + // only support FirstPhase in this Builder because getting challenge value requires more specialized witness generation during synthesize + if !witness_gen_only { + // clone the builder so we can re-use the circuit for both vk and pk gen + let builder = circuit.builder.borrow(); + let assignments = builder.assign_all( + &range.gate, + &range.lookup_advice, + &range.q_lookup, + &mut region, + Default::default(), + ); + *circuit.break_points.borrow_mut() = assignments.break_points; + assigned_advices = assignments.assigned_advices; + } else { + #[cfg(feature = "display")] + let start0 = std::time::Instant::now(); + let builder = circuit.builder.take(); + let break_points = circuit.break_points.take(); + for (phase, (threads, break_points)) in builder + .threads + .into_iter() + .zip(break_points.into_iter()) + .enumerate() + .take(1) + { + assign_threads_in( + phase, + threads, + &range.gate, + &range.lookup_advice[phase], + &mut region, + break_points, + ); + } + #[cfg(feature = "display")] + println!("assign threads in {:?}", start0.elapsed()); + } + Ok(()) + }, + ) + .unwrap(); + + if !witness_gen_only { + // expose public instances + let mut layouter = layouter.namespace(|| "expose"); + for (i, instance) in self.assigned_instances.iter().enumerate() { + let cell = instance.cell.unwrap(); + let (cell, _) = assigned_advices + .get(&(cell.context_id, cell.offset)) + .expect("instance not assigned"); + layouter.constrain_instance(*cell, config.instance, i); + } } Ok(()) } @@ -534,7 +528,10 @@ mod aggregation { fn gen_pk>(params: &ParamsKZG, circuit: &C) -> ProvingKey { let vk = keygen_vk(params, circuit).unwrap(); - keygen_pk(params, vk, circuit).unwrap() + println!("finished vk"); + let pk = keygen_pk(params, vk, circuit).unwrap(); + println!("finished pk"); + pk } fn gen_proof< @@ -548,8 +545,6 @@ fn gen_proof< circuit: C, instances: Vec>, ) -> Vec { - MockProver::run(params.k(), &circuit, instances.clone()).unwrap().assert_satisfied(); - let instances = instances.iter().map(|instances| instances.as_slice()).collect_vec(); let proof = { let mut transcript = TW::init(Vec::new()); @@ -608,8 +603,6 @@ fn gen_aggregation_evm_verifier( num_instance: Vec, accumulator_indices: Vec<(usize, usize)>, ) -> Vec { - let svk = params.get_g()[0].into(); - let dk = (params.g2(), params.s_g2()).into(); let protocol = compile( params, vk, @@ -617,14 +610,15 @@ fn gen_aggregation_evm_verifier( .with_num_instance(num_instance.clone()) .with_accumulator_indices(Some(accumulator_indices)), ); + let vk = (params.get_g()[0], params.g2(), params.s_g2()).into(); let loader = EvmLoader::new::(); let protocol = protocol.loaded(&loader); let mut transcript = EvmTranscript::<_, Rc, _, _>::new(&loader); let instances = transcript.load_instances(num_instance); - let proof = Plonk::read_proof(&svk, &protocol, &instances, &mut transcript); - Plonk::verify(&svk, &dk, &protocol, &instances, &proof); + let proof = PlonkVerifier::read_proof(&vk, &protocol, &instances, &mut transcript).unwrap(); + PlonkVerifier::verify(&vk, &protocol, &instances, &proof).unwrap(); evm::compile_yul(&loader.yul_code()) } @@ -646,16 +640,33 @@ fn evm_verify(deployment_code: Vec, instances: Vec>, proof: Vec) } fn main() { - std::env::set_var("VERIFY_CONFIG", "./configs/example_evm_accumulator.config"); - let params = halo2_base::utils::fs::gen_srs(21); - let params_app = { - let mut params = params.clone(); - params.downsize(8); - params - }; + let params_app = gen_srs(8); let snarks = [(); 3].map(|_| gen_application_snark(¶ms_app)); - let agg_circuit = aggregation::AggregationCircuit::new(¶ms, snarks); + + let path = "./configs/example_evm_accumulator.json"; + let agg_config: AggregationConfigParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + let agg_circuit = AggregationCircuit::new( + CircuitBuilderStage::Mock, + None, + agg_config.lookup_bits, + params_app.get_g()[0], + snarks.clone(), + ); + agg_circuit.config(agg_config.degree, Some(6)); + set_var("LOOKUP_BITS", agg_config.lookup_bits.to_string()); + #[cfg(debug_assertions)] + { + MockProver::run(agg_config.degree, &agg_circuit, agg_circuit.instances()) + .unwrap() + .assert_satisfied(); + println!("mock prover passed"); + } + + let params = gen_srs(agg_config.degree); let pk = gen_pk(¶ms, &agg_circuit); let deployment_code = gen_aggregation_evm_verifier( ¶ms, @@ -664,11 +675,22 @@ fn main() { aggregation::AggregationCircuit::accumulator_indices(), ); + let break_points = agg_circuit.break_points(); + drop(agg_circuit); + + let agg_circuit = AggregationCircuit::new( + CircuitBuilderStage::Prover, + Some(break_points), + agg_config.lookup_bits, + params_app.get_g()[0], + snarks, + ); + let instances = agg_circuit.instances(); let proof = gen_proof::<_, _, EvmTranscript, EvmTranscript>( ¶ms, &pk, - agg_circuit.clone(), - agg_circuit.instances(), + agg_circuit, + instances.clone(), ); - evm_verify(deployment_code, agg_circuit.instances(), proof); + evm_verify(deployment_code, instances, proof); } diff --git a/snark-verifier/examples/evm-verifier.rs b/snark-verifier/examples/evm-verifier.rs index e206b5e0..5b2aa802 100644 --- a/snark-verifier/examples/evm-verifier.rs +++ b/snark-verifier/examples/evm-verifier.rs @@ -1,4 +1,3 @@ -use ethereum_types::Address; use halo2_base::halo2_proofs; use halo2_proofs::{ circuit::{Layouter, SimpleFloorPlanner, Value}, @@ -22,14 +21,14 @@ use halo2_proofs::{ use itertools::Itertools; use rand::{rngs::OsRng, RngCore}; use snark_verifier::{ - loader::evm::{self, encode_calldata, EvmLoader, ExecutorBuilder}, - pcs::kzg::{Gwc19, Kzg}, + loader::evm::{self, encode_calldata, Address, EvmLoader, ExecutorBuilder}, + pcs::kzg::{Gwc19, KzgAs}, system::halo2::{compile, transcript::evm::EvmTranscript, Config}, - verifier::{self, PlonkVerifier}, + verifier::{self, SnarkVerifier}, }; use std::rc::Rc; -type Plonk = verifier::Plonk>; +type PlonkVerifier = verifier::plonk::PlonkVerifier>; #[derive(Clone, Copy)] struct StandardPlonkConfig { @@ -140,14 +139,14 @@ impl Circuit for StandardPlonk { } #[cfg(feature = "halo2-axiom")] { - region.assign_advice(config.a, 0, Value::known(Assigned::Trivial(self.0)))?; + region.assign_advice(config.a, 0, Value::known(Assigned::Trivial(self.0))); region.assign_fixed(config.q_a, 0, Assigned::Trivial(-Fr::one())); region.assign_advice( config.a, 1, Value::known(Assigned::Trivial(-Fr::from(5u64))), - )?; + ); for (idx, column) in (1..).zip([ config.q_a, config.q_b, @@ -162,7 +161,7 @@ impl Circuit for StandardPlonk { config.a, 2, Value::known(Assigned::Trivial(Fr::one())), - )?; + ); a.copy_advice(&mut region, config.b, 3); a.copy_advice(&mut region, config.c, 4); } @@ -227,17 +226,16 @@ fn gen_evm_verifier( vk: &VerifyingKey, num_instance: Vec, ) -> Vec { - let svk = params.get_g()[0].into(); - let dk = (params.g2(), params.s_g2()).into(); let protocol = compile(params, vk, Config::kzg().with_num_instance(num_instance.clone())); + let vk = (params.get_g()[0], params.g2(), params.s_g2()).into(); let loader = EvmLoader::new::(); let protocol = protocol.loaded(&loader); let mut transcript = EvmTranscript::<_, Rc, _, _>::new(&loader); let instances = transcript.load_instances(num_instance); - let proof = Plonk::read_proof(&svk, &protocol, &instances, &mut transcript); - Plonk::verify(&svk, &dk, &protocol, &instances, &proof); + let proof = PlonkVerifier::read_proof(&vk, &protocol, &instances, &mut transcript).unwrap(); + PlonkVerifier::verify(&vk, &protocol, &instances, &proof).unwrap(); evm::compile_yul(&loader.yul_code()) } diff --git a/snark-verifier/examples/recursion.rs b/snark-verifier/examples/recursion.rs index abc0c808..5829e1b7 100644 --- a/snark-verifier/examples/recursion.rs +++ b/snark-verifier/examples/recursion.rs @@ -2,13 +2,14 @@ use ark_std::{end_timer, start_timer}; use common::*; -use halo2_base::halo2_proofs; +use halo2_base::gates::flex_gate::GateStrategy; use halo2_base::utils::fs::gen_srs; +use halo2_base::{gates::builder::FlexGateConfigParams, halo2_proofs}; use halo2_proofs::{ - circuit::{Layouter, SimpleFloorPlanner, Value}, + circuit::{Layouter, SimpleFloorPlanner}, dev::MockProver, halo2curves::{ - bn256::{Bn256, Fq, Fr, G1Affine}, + bn256::{Bn256, Fr, G1Affine}, group::ff::Field, FieldExt, }, @@ -31,7 +32,7 @@ use rand_chacha::rand_core::OsRng; use snark_verifier::{ loader::{self, native::NativeLoader, Loader, ScalarLoader}, pcs::{ - kzg::{Gwc19, Kzg, KzgAccumulator, KzgAs, KzgSuccinctVerifyingKey, LimbsEncoding}, + kzg::{Gwc19, KzgAccumulator, KzgAs, KzgSuccinctVerifyingKey, LimbsEncoding}, AccumulationScheme, AccumulationSchemeProver, }, system::halo2::{self, compile, Config}, @@ -39,24 +40,28 @@ use snark_verifier::{ arithmetic::{fe_to_fe, fe_to_limbs}, hash, }, - verifier::{self, PlonkProof, PlonkVerifier}, - Protocol, + verifier::{ + self, + plonk::{PlonkProof, PlonkProtocol}, + SnarkVerifier, + }, }; -use std::{fs, iter, marker::PhantomData, rc::Rc}; +use std::{env::set_var, fs, iter, marker::PhantomData, rc::Rc}; use crate::recursion::AggregationConfigParams; const LIMBS: usize = 3; const BITS: usize = 88; -const T: usize = 5; -const RATE: usize = 4; +const T: usize = 3; +const RATE: usize = 2; const R_F: usize = 8; -const R_P: usize = 60; +const R_P: usize = 57; +const SECURE_MDS: usize = 0; -type Pcs = Kzg; type Svk = KzgSuccinctVerifyingKey; -type As = KzgAs; -type Plonk = verifier::Plonk>; +type As = KzgAs; +type PlonkVerifier = verifier::plonk::PlonkVerifier>; +type PlonkSuccinctVerifier = verifier::plonk::PlonkSuccinctVerifier>; type Poseidon = hash::Poseidon; type PoseidonTranscript = halo2::transcript::halo2::PoseidonTranscript; @@ -70,59 +75,30 @@ mod common { loader: &L, inputs: &[L::LoadedScalar], ) -> L::LoadedScalar { - let mut hasher = Poseidon::new(loader, R_F, R_P); + // warning: generating a new spec is time intensive, use lazy_static in production + let mut hasher = Poseidon::new::(loader); hasher.update(inputs); hasher.squeeze() } + #[derive(Clone)] pub struct Snark { - pub protocol: Protocol, + pub protocol: PlonkProtocol, pub instances: Vec>, pub proof: Vec, } impl Snark { - pub fn new(protocol: Protocol, instances: Vec>, proof: Vec) -> Self { + pub fn new( + protocol: PlonkProtocol, + instances: Vec>, + proof: Vec, + ) -> Self { Self { protocol, instances, proof } } - } - - impl From for SnarkWitness { - fn from(snark: Snark) -> Self { - Self { - protocol: snark.protocol, - instances: snark - .instances - .into_iter() - .map(|instances| instances.into_iter().map(Value::known).collect_vec()) - .collect(), - proof: Value::known(snark.proof), - } - } - } - - #[derive(Clone)] - pub struct SnarkWitness { - pub protocol: Protocol, - pub instances: Vec>>, - pub proof: Value>, - } - - impl SnarkWitness { - pub fn without_witnesses(&self) -> Self { - SnarkWitness { - protocol: self.protocol.clone(), - instances: self - .instances - .iter() - .map(|instances| vec![Value::unknown(); instances.len()]) - .collect(), - proof: Value::unknown(), - } - } - pub fn proof(&self) -> Value<&[u8]> { - self.proof.as_ref().map(Vec::as_slice) + pub fn proof(&self) -> &[u8] { + &self.proof[..] } } @@ -160,7 +136,8 @@ mod common { let instances = instances.iter().map(Vec::as_slice).collect_vec(); let proof = { - let mut transcript = PoseidonTranscript::::new(Vec::new()); + let mut transcript = + PoseidonTranscript::::new::(Vec::new()); create_proof::<_, ProverGWC<_>, _, _, _, _>( params, pk, @@ -174,7 +151,8 @@ mod common { }; let accept = { - let mut transcript = PoseidonTranscript::::new(proof.as_slice()); + let mut transcript = + PoseidonTranscript::::new::(proof.as_slice()); VerificationStrategy::<_, VerifierGWC<_>>::finalize( verify_proof::<_, VerifierGWC<_>, _, _, _>( params.verifier_params(), @@ -263,7 +241,8 @@ mod common { .map(|n| iter::repeat_with(|| Fr::random(OsRng)).take(n).collect()) .collect(); let proof = { - let mut transcript = PoseidonTranscript::::new(Vec::new()); + let mut transcript = + PoseidonTranscript::::new::(Vec::new()); for _ in 0..protocol .num_witness .iter() @@ -275,8 +254,8 @@ mod common { for _ in 0..protocol.evaluations.len() { transcript.write_scalar(Fr::random(OsRng)).unwrap(); } - let queries = PlonkProof::::empty_queries(&protocol); - for _ in 0..Pcs::estimate_cost(&queries).num_commitment { + let queries = PlonkProof::::empty_queries(&protocol); + for _ in 0..As::estimate_cost(&queries).num_commitment { transcript.write_ec_point(G1Affine::random(OsRng)).unwrap(); } transcript.finalize() @@ -344,19 +323,24 @@ mod application { } mod recursion { - use std::fs::File; + use std::{collections::HashMap, env::var}; use halo2_base::{ - gates::GateInstructions, AssignedValue, Context, ContextParams, QuantumCell::Existing, + gates::{ + builder::{GateThreadBuilder, RangeCircuitBuilder}, + range::RangeConfig, + GateInstructions, RangeChip, RangeInstructions, + }, + AssignedValue, }; - use halo2_ecc::ecc::EccChip; + use halo2_ecc::{bn254::FpChip, ecc::EcPoint}; use halo2_proofs::plonk::{Column, Instance}; use snark_verifier::loader::halo2::{EccInstructions, IntegerInstructions}; use super::*; - type BaseFieldEccChip = halo2_ecc::ecc::BaseFieldEccChip; - type Halo2Loader<'a> = loader::halo2::Halo2Loader<'a, G1Affine, BaseFieldEccChip>; + type BaseFieldEccChip<'chip> = halo2_ecc::ecc::BaseFieldEccChip<'chip, G1Affine>; + type Halo2Loader<'chip> = loader::halo2::Halo2Loader>; pub trait StateTransition { type Input; @@ -369,9 +353,9 @@ mod recursion { fn succinct_verify<'a>( svk: &Svk, loader: &Rc>, - snark: &SnarkWitness, - preprocessed_digest: Option>, - ) -> (Vec>>, Vec>>>) { + snark: &Snark, + preprocessed_digest: Option>, + ) -> (Vec>>, Vec>>>) { let protocol = if let Some(preprocessed_digest) = preprocessed_digest { let preprocessed_digest = loader.scalar_from_assigned(preprocessed_digest); let protocol = snark.protocol.loaded_preprocessed_as_witness(loader); @@ -381,11 +365,11 @@ mod recursion { .flat_map(|preprocessed| { let assigned = preprocessed.assigned(); [assigned.x(), assigned.y()] - .map(|coordinate| loader.scalar_from_assigned(coordinate.native().clone())) + .map(|coordinate| loader.scalar_from_assigned(*coordinate.native())) }) .chain(protocol.transcript_initial_state.clone()) .collect_vec(); - loader.assert_eq("", &poseidon(loader, &inputs), &preprocessed_digest).unwrap(); + loader.assert_eq("", &poseidon(loader, &inputs), &preprocessed_digest); protocol } else { snark.protocol.loaded(loader) @@ -398,9 +382,12 @@ mod recursion { instances.iter().map(|instance| loader.assign_scalar(*instance)).collect_vec() }) .collect_vec(); - let mut transcript = PoseidonTranscript::, _>::new(loader, snark.proof()); - let proof = Plonk::read_proof(svk, &protocol, &instances, &mut transcript); - let accumulators = Plonk::succinct_verify(svk, &protocol, &instances, &proof); + let mut transcript = + PoseidonTranscript::, _>::new::(loader, snark.proof()); + let proof = + PlonkSuccinctVerifier::read_proof(svk, &protocol, &instances, &mut transcript).unwrap(); + let accumulators = + PlonkSuccinctVerifier::verify(svk, &protocol, &instances, &proof).unwrap(); ( instances @@ -415,14 +402,21 @@ mod recursion { fn select_accumulator<'a>( loader: &Rc>, - condition: &AssignedValue<'a, Fr>, + condition: &AssignedValue, lhs: &KzgAccumulator>>, rhs: &KzgAccumulator>>, ) -> Result>>, Error> { let [lhs, rhs]: [_; 2] = [lhs.lhs.assigned(), lhs.rhs.assigned()] .iter() .zip([rhs.lhs.assigned(), rhs.rhs.assigned()].iter()) - .map(|(lhs, rhs)| loader.ecc_chip().select(&mut loader.ctx_mut(), lhs, rhs, condition)) + .map(|(lhs, rhs)| { + loader.ecc_chip().select( + loader.ctx_mut().main(0), + EcPoint::clone(&lhs), + EcPoint::clone(&rhs), + *condition, + ) + }) .collect::>() .try_into() .unwrap(); @@ -435,80 +429,41 @@ mod recursion { fn accumulate<'a>( loader: &Rc>, accumulators: Vec>>>, - as_proof: Value<&'_ [u8]>, + as_proof: &[u8], ) -> KzgAccumulator>> { - let mut transcript = PoseidonTranscript::, _>::new(loader, as_proof); + let mut transcript = + PoseidonTranscript::, _>::new::(loader, as_proof); let proof = As::read_proof(&Default::default(), &accumulators, &mut transcript).unwrap(); As::verify(&Default::default(), &accumulators, &proof).unwrap() } #[derive(serde::Serialize, serde::Deserialize)] pub struct AggregationConfigParams { - pub strategy: halo2_ecc::fields::fp::FpStrategy, pub degree: u32, pub num_advice: usize, pub num_lookup_advice: usize, pub num_fixed: usize, pub lookup_bits: usize, - pub limb_bits: usize, - pub num_limbs: usize, } #[derive(Clone)] pub struct RecursionConfig { - pub base_field_config: halo2_ecc::fields::fp::FpConfig, + pub range: RangeConfig, pub instance: Column, } - impl RecursionConfig { - pub fn configure(meta: &mut ConstraintSystem, params: AggregationConfigParams) -> Self { - assert!( - params.limb_bits == BITS && params.num_limbs == LIMBS, - "For now we fix limb_bits = {}, otherwise change code", - BITS - ); - let base_field_config = halo2_ecc::fields::fp::FpConfig::configure( - meta, - params.strategy, - &[params.num_advice], - &[params.num_lookup_advice], - params.num_fixed, - params.lookup_bits, - params.limb_bits, - params.num_limbs, - halo2_base::utils::modulus::(), - 0, - params.degree as usize, - ); - - let instance = meta.instance_column(); - meta.enable_equality(instance); - - Self { base_field_config, instance } - } - - pub fn gate(&self) -> &halo2_base::gates::flex_gate::FlexGateConfig { - &self.base_field_config.range.gate - } - - pub fn range(&self) -> &halo2_base::gates::range::RangeConfig { - &self.base_field_config.range - } - - pub fn ecc_chip(&self) -> halo2_ecc::ecc::BaseFieldEccChip { - EccChip::construct(self.base_field_config.clone()) - } - } - #[derive(Clone)] pub struct RecursionCircuit { svk: Svk, default_accumulator: KzgAccumulator, - app: SnarkWitness, - previous: SnarkWitness, + app: Snark, + previous: Snark, round: usize, instances: Vec, - as_proof: Value>, + as_proof: Vec, + + inner: RangeCircuitBuilder, + assigned_instances: Vec>, } impl RecursionCircuit { @@ -529,11 +484,18 @@ mod recursion { let default_accumulator = KzgAccumulator::new(params.get_g()[1], params.get_g()[0]); let succinct_verify = |snark: &Snark| { - let mut transcript = - PoseidonTranscript::::new(snark.proof.as_slice()); - let proof = - Plonk::read_proof(&svk, &snark.protocol, &snark.instances, &mut transcript); - Plonk::succinct_verify(&svk, &snark.protocol, &snark.instances, &proof) + let mut transcript = PoseidonTranscript::::new::( + snark.proof.as_slice(), + ); + let proof = PlonkSuccinctVerifier::read_proof( + &svk, + &snark.protocol, + &snark.instances, + &mut transcript, + ) + .unwrap(); + PlonkSuccinctVerifier::verify(&svk, &snark.protocol, &snark.instances, &proof) + .unwrap() }; let accumulators = iter::empty() @@ -545,7 +507,8 @@ mod recursion { .collect_vec(); let (accumulator, as_proof) = { - let mut transcript = PoseidonTranscript::::new(Vec::new()); + let mut transcript = + PoseidonTranscript::::new::(Vec::new()); let accumulator = As::create_proof(&Default::default(), &accumulators, &mut transcript, OsRng) .unwrap(); @@ -570,15 +533,111 @@ mod recursion { .chain([preprocessed_digest, initial_state, state, Fr::from(round as u64)]) .collect(); - Self { + let builder = GateThreadBuilder::mock(); + let inner = RangeCircuitBuilder::mock(builder); + let mut circuit = Self { svk, default_accumulator, - app: app.into(), - previous: previous.into(), + app: app, + previous: previous, round, instances, - as_proof: Value::known(as_proof), + as_proof, + inner, + assigned_instances: vec![], + }; + circuit.build(); + circuit + } + + fn build(&mut self) { + let lookup_bits = var("LOOKUP_BITS").unwrap().parse().unwrap(); + let range = RangeChip::::default(lookup_bits); + let main_gate = range.gate(); + let mut builder = GateThreadBuilder::mock(); + let ctx = &mut builder; + let [preprocessed_digest, initial_state, state, round] = [ + self.instances[Self::PREPROCESSED_DIGEST_ROW], + self.instances[Self::INITIAL_STATE_ROW], + self.instances[Self::STATE_ROW], + self.instances[Self::ROUND_ROW], + ] + .map(|instance| main_gate.assign_integer(ctx, instance)); + let first_round = main_gate.is_zero(ctx.main(0), round); + let not_first_round = main_gate.not(ctx.main(0), first_round); + + let fp_chip = FpChip::::new(&range, BITS, LIMBS); + let ecc_chip = BaseFieldEccChip::new(&fp_chip); + let loader = Halo2Loader::new(ecc_chip, builder); + let (mut app_instances, app_accumulators) = + succinct_verify(&self.svk, &loader, &self.app, None); + let (mut previous_instances, previous_accumulators) = + succinct_verify(&self.svk, &loader, &self.previous, Some(preprocessed_digest)); + + let default_accmulator = self.load_default_accumulator(&loader).unwrap(); + let previous_accumulators = previous_accumulators + .iter() + .map(|previous_accumulator| { + select_accumulator( + &loader, + &first_round, + &default_accmulator, + previous_accumulator, + ) + .unwrap() + }) + .collect::>(); + + let KzgAccumulator { lhs, rhs } = accumulate( + &loader, + [app_accumulators, previous_accumulators].concat(), + self.as_proof(), + ); + + let lhs = lhs.into_assigned(); + let rhs = rhs.into_assigned(); + let app_instances = app_instances.pop().unwrap(); + let previous_instances = previous_instances.pop().unwrap(); + + let mut builder = loader.take_ctx(); + let ctx = builder.main(0); + for (lhs, rhs) in [ + // Propagate preprocessed_digest + ( + &main_gate.mul(ctx, preprocessed_digest, not_first_round), + &previous_instances[Self::PREPROCESSED_DIGEST_ROW], + ), + // Propagate initial_state + ( + &main_gate.mul(ctx, initial_state, not_first_round), + &previous_instances[Self::INITIAL_STATE_ROW], + ), + // Verify initial_state is same as the first application snark + ( + &main_gate.mul(ctx, initial_state, first_round), + &main_gate.mul(ctx, app_instances[0], first_round), + ), + // Verify current state is same as the current application snark + (&state, &app_instances[1]), + // Verify previous state is same as the current application snark + ( + &main_gate.mul(ctx, app_instances[0], not_first_round), + &previous_instances[Self::STATE_ROW], + ), + // Verify round is increased by 1 when not at first round + (&round, &main_gate.add(ctx, not_first_round, previous_instances[Self::ROUND_ROW])), + ] { + ctx.constrain_equal(lhs, rhs); } + *self.inner.0.builder.borrow_mut() = builder; + + self.assigned_instances.extend( + [lhs.x(), lhs.y(), rhs.x(), rhs.y()] + .into_iter() + .flat_map(|coordinate| coordinate.limbs()) + .chain([preprocessed_digest, initial_state, state, round].iter()) + .copied(), + ); } fn initial_snark(params: &ParamsKZG, vk: Option<&VerifyingKey>) -> Snark { @@ -592,8 +651,8 @@ mod recursion { snark } - fn as_proof(&self) -> Value<&[u8]> { - self.as_proof.as_ref().map(Vec::as_slice) + fn as_proof(&self) -> &[u8] { + &self.as_proof[..] } fn load_default_accumulator<'a>( @@ -603,7 +662,7 @@ mod recursion { let [lhs, rhs] = [self.default_accumulator.lhs, self.default_accumulator.rhs].map(|default| { let assigned = - loader.ecc_chip().assign_constant(&mut loader.ctx_mut(), default).unwrap(); + loader.ecc_chip().assign_constant(&mut loader.ctx_mut(), default); loader.ec_point_from_assigned(assigned) }); Ok(KzgAccumulator::new(lhs, rhs)) @@ -615,26 +674,14 @@ mod recursion { type FloorPlanner = SimpleFloorPlanner; fn without_witnesses(&self) -> Self { - Self { - svk: self.svk, - default_accumulator: self.default_accumulator.clone(), - app: self.app.without_witnesses(), - previous: self.previous.without_witnesses(), - round: self.round, - instances: self.instances.clone(), - as_proof: Value::unknown(), - } + unimplemented!() } fn configure(meta: &mut plonk::ConstraintSystem) -> Self::Config { - let path = std::env::var("VERIFY_CONFIG") - .unwrap_or_else(|_| "configs/verify_circuit.config".to_owned()); - let params: AggregationConfigParams = serde_json::from_reader( - File::open(path.as_str()).unwrap_or_else(|err| panic!("{err:?}")), - ) - .unwrap(); - - RecursionConfig::configure(meta, params) + let range = RangeCircuitBuilder::configure(meta); + let instance = meta.instance_column(); + meta.enable_equality(instance); + RecursionConfig { range, instance } } fn synthesize( @@ -642,152 +689,45 @@ mod recursion { config: Self::Config, mut layouter: impl Layouter, ) -> Result<(), Error> { - config.range().load_lookup_table(&mut layouter)?; - let max_rows = config.range().gate.max_rows; - let main_gate = config.gate(); + let range = config.range; + range.load_lookup_table(&mut layouter).expect("load lookup table should not fail"); + let circuit = &self.inner.0; + let mut assigned_advices = HashMap::new(); + // POC so will only do mock prover and not real prover let mut first_pass = halo2_base::SKIP_FIRST_PASS; // assume using simple floor planner - let mut assigned_instances = Vec::new(); - layouter.assign_region( - || "", - |region| { - if first_pass { - first_pass = false; - return Ok(()); - } - let mut ctx = Context::new( - region, - ContextParams { - max_rows, - num_context_ids: 1, - fixed_columns: config.base_field_config.range.gate.constants.clone(), - }, - ); - - let [preprocessed_digest, initial_state, state, round] = [ - self.instances[Self::PREPROCESSED_DIGEST_ROW], - self.instances[Self::INITIAL_STATE_ROW], - self.instances[Self::STATE_ROW], - self.instances[Self::ROUND_ROW], - ] - .map(|instance| { - main_gate.assign_integer(&mut ctx, Value::known(instance)).unwrap() - }); - let first_round = main_gate.is_zero(&mut ctx, &round); - let not_first_round = main_gate.not(&mut ctx, Existing(&first_round)); - - let loader = Halo2Loader::new(config.ecc_chip(), ctx); - let (mut app_instances, app_accumulators) = - succinct_verify(&self.svk, &loader, &self.app, None); - let (mut previous_instances, previous_accumulators) = succinct_verify( - &self.svk, - &loader, - &self.previous, - Some(preprocessed_digest.clone()), - ); - - let default_accmulator = self.load_default_accumulator(&loader)?; - let previous_accumulators = previous_accumulators - .iter() - .map(|previous_accumulator| { - select_accumulator( - &loader, - &first_round, - &default_accmulator, - previous_accumulator, - ) - }) - .collect::, Error>>()?; - - let KzgAccumulator { lhs, rhs } = accumulate( - &loader, - [app_accumulators, previous_accumulators].concat(), - self.as_proof(), - ); - - let lhs = lhs.into_assigned(); - let rhs = rhs.into_assigned(); - let app_instances = app_instances.pop().unwrap(); - let previous_instances = previous_instances.pop().unwrap(); - - let mut ctx = loader.ctx_mut(); - for (lhs, rhs) in [ - // Propagate preprocessed_digest - ( - &main_gate.mul( - &mut ctx, - Existing(&preprocessed_digest), - Existing(¬_first_round), - ), - &previous_instances[Self::PREPROCESSED_DIGEST_ROW], - ), - // Propagate initial_state - ( - &main_gate.mul( - &mut ctx, - Existing(&initial_state), - Existing(¬_first_round), - ), - &previous_instances[Self::INITIAL_STATE_ROW], - ), - // Verify initial_state is same as the first application snark - ( - &main_gate.mul( - &mut ctx, - Existing(&initial_state), - Existing(&first_round), - ), - &main_gate.mul( - &mut ctx, - Existing(&app_instances[0]), - Existing(&first_round), - ), - ), - // Verify current state is same as the current application snark - (&state, &app_instances[1]), - // Verify previous state is same as the current application snark - ( - &main_gate.mul( - &mut ctx, - Existing(&app_instances[0]), - Existing(¬_first_round), - ), - &previous_instances[Self::STATE_ROW], - ), - // Verify round is increased by 1 when not at first round - ( - &round, - &main_gate.add( - &mut ctx, - Existing(¬_first_round), - Existing(&previous_instances[Self::ROUND_ROW]), - ), - ), - ] { - ctx.region.constrain_equal(lhs.cell(), rhs.cell()); - } - - // IMPORTANT: - config.base_field_config.finalize(&mut ctx); - #[cfg(feature = "display")] - dbg!(ctx.total_advice); - #[cfg(feature = "display")] - println!("Advice columns used: {}", ctx.advice_alloc[0][0].0 + 1); - - assigned_instances.extend( - [lhs.x(), lhs.y(), rhs.x(), rhs.y()] - .into_iter() - .flat_map(|coordinate| coordinate.limbs()) - .chain([preprocessed_digest, initial_state, state, round].iter()) - .map(|assigned| assigned.cell()), - ); - Ok(()) - }, - )?; - - assert_eq!(assigned_instances.len(), 4 * LIMBS + 4); - for (row, limb) in assigned_instances.into_iter().enumerate() { - layouter.constrain_instance(limb, config.instance, row); + layouter + .assign_region( + || "Recursion Circuit", + |mut region| { + if first_pass { + first_pass = false; + return Ok(()); + } + // clone the builder so we can re-use the circuit for both vk and pk gen + let builder = circuit.builder.borrow(); + let assignments = builder.assign_all( + &range.gate, + &range.lookup_advice, + &range.q_lookup, + &mut region, + Default::default(), + ); + *circuit.break_points.borrow_mut() = assignments.break_points; + assigned_advices = assignments.assigned_advices; + Ok(()) + }, + ) + .unwrap(); + + // expose public instances + let mut layouter = layouter.namespace(|| "expose"); + for (i, instance) in self.assigned_instances.iter().enumerate() { + let cell = instance.cell.unwrap(); + let (cell, _) = assigned_advices + .get(&(cell.context_id, cell.offset)) + .expect("instance not assigned"); + layouter.constrain_instance(*cell, config.instance, i); } Ok(()) @@ -809,10 +749,7 @@ mod recursion { } fn selectors(config: &Self::Config) -> Vec { - config.base_field_config.range.gate.basic_gates[0] - .iter() - .map(|gate| gate.q_enable) - .collect() + config.range.gate.basic_gates[0].iter().map(|gate| gate.q_enable).collect() } } @@ -829,6 +766,9 @@ mod recursion { Fr::zero(), 0, ); + // we cannot auto-configure the circuit because dummy_snark must know the configuration beforehand + // uncomment the following line only in development to test and print out the optimal configuration ahead of time + // recursion.inner.0.builder.borrow().config(recursion_params.k() as usize, Some(10)); gen_pk(recursion_params, &recursion) } @@ -867,9 +807,18 @@ mod recursion { fn main() { let app_params = gen_srs(3); let recursion_config: AggregationConfigParams = - serde_json::from_reader(fs::File::open("configs/verify_circuit.config").unwrap()).unwrap(); + serde_json::from_reader(fs::File::open("configs/example_recursion.json").unwrap()).unwrap(); let k = recursion_config.degree; let recursion_params = gen_srs(k); + let flex_gate_config = FlexGateConfigParams { + strategy: GateStrategy::Vertical, + k: k as usize, + num_advice_per_phase: vec![recursion_config.num_advice], + num_lookup_advice_per_phase: vec![recursion_config.num_lookup_advice], + num_fixed: recursion_config.num_fixed, + }; + set_var("FLEX_GATE_CONFIG_PARAMS", serde_json::to_string(&flex_gate_config).unwrap()); + set_var("LOOKUP_BITS", recursion_config.lookup_bits.to_string()); let app_pk = gen_pk(&app_params, &application::Square::default()); @@ -894,12 +843,14 @@ fn main() { end_timer!(pf_time); assert_eq!(final_state, Fr::from(2u64).pow(&[1 << num_round, 0, 0, 0])); - let accept = { - let svk = recursion_params.get_g()[0].into(); - let dk = (recursion_params.g2(), recursion_params.s_g2()).into(); - let mut transcript = PoseidonTranscript::::new(snark.proof.as_slice()); - let proof = Plonk::read_proof(&svk, &snark.protocol, &snark.instances, &mut transcript); - Plonk::verify(&svk, &dk, &snark.protocol, &snark.instances, &proof) + { + let dk = + (recursion_params.get_g()[0], recursion_params.g2(), recursion_params.s_g2()).into(); + let mut transcript = + PoseidonTranscript::::new::(snark.proof.as_slice()); + let proof = + PlonkVerifier::read_proof(&dk, &snark.protocol, &snark.instances, &mut transcript) + .unwrap(); + PlonkVerifier::verify(&dk, &snark.protocol, &snark.instances, &proof).unwrap() }; - assert!(accept) } diff --git a/snark-verifier/src/cost.rs b/snark-verifier/src/cost.rs index b085aed8..46bc6145 100644 --- a/snark-verifier/src/cost.rs +++ b/snark-verifier/src/cost.rs @@ -1,44 +1,40 @@ +//! Cost estimation. + use std::ops::Add; -#[derive(Debug, Clone, PartialEq, Eq)] +/// Cost of verification. +#[derive(Debug, Default, Clone, PartialEq, Eq)] pub struct Cost { + /// Number of instances. pub num_instance: usize, + /// Number of commitments in proof. pub num_commitment: usize, + /// Number of evaluations in proof. pub num_evaluation: usize, + /// Number of scalar multiplications to perform. pub num_msm: usize, -} - -impl Cost { - pub fn new( - num_instance: usize, - num_commitment: usize, - num_evaluation: usize, - num_msm: usize, - ) -> Self { - Self { - num_instance, - num_commitment, - num_evaluation, - num_msm, - } - } + /// Number of pairings to perform. + pub num_pairing: usize, } impl Add for Cost { type Output = Cost; - fn add(self, rhs: Cost) -> Self::Output { - Cost::new( - self.num_instance + rhs.num_instance, - self.num_commitment + rhs.num_commitment, - self.num_evaluation + rhs.num_evaluation, - self.num_msm + rhs.num_msm, - ) + fn add(mut self, rhs: Cost) -> Self::Output { + self.num_instance += rhs.num_instance; + self.num_commitment += rhs.num_commitment; + self.num_evaluation += rhs.num_evaluation; + self.num_msm += rhs.num_msm; + self.num_pairing += rhs.num_pairing; + self } } +/// For estimating cost of a verifier. pub trait CostEstimation { + /// Input for [`CostEstimation::estimate_cost`]. type Input; + /// Estimate cost of verifier given the input. fn estimate_cost(input: &Self::Input) -> Cost; } diff --git a/snark-verifier/src/lib.rs b/snark-verifier/src/lib.rs index 1976def7..e9866167 100644 --- a/snark-verifier/src/lib.rs +++ b/snark-verifier/src/lib.rs @@ -1,6 +1,7 @@ -#![allow(clippy::type_complexity)] -#![allow(clippy::too_many_arguments)] -#![allow(clippy::upper_case_acronyms)] +//! Generic (S)NARK verifier. + +#![allow(clippy::type_complexity, clippy::too_many_arguments, clippy::upper_case_acronyms)] +#![deny(missing_debug_implementations, missing_docs, unsafe_code, rustdoc::all)] pub mod cost; pub mod loader; @@ -11,54 +12,16 @@ pub mod verifier; pub(crate) use halo2_base::halo2_proofs; pub(crate) use halo2_proofs::halo2curves as halo2_curves; -#[cfg(feature = "halo2-pse")] -pub(crate) use poseidon; -#[cfg(feature = "halo2-axiom")] -pub(crate) use poseidon_axiom as poseidon; - -pub use poseidon::Spec as PoseidonSpec; -use serde::{Deserialize, Serialize}; +/// Error that could happen while verification. #[derive(Clone, Debug)] pub enum Error { + /// Instances that don't match the amount specified in protocol. InvalidInstances, - InvalidLinearization, - InvalidQuery(util::protocol::Query), - InvalidChallenge(usize), + /// Protocol that is unreasonable for a verifier. + InvalidProtocol(String), + /// Assertion failure while verification. AssertionFailure(String), + /// Transcript error. Transcript(std::io::ErrorKind, String), } - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct Protocol -where - C: util::arithmetic::CurveAffine, - L: loader::Loader, -{ - // Common description - #[serde(bound( - serialize = "C::Scalar: Serialize", - deserialize = "C::Scalar: Deserialize<'de>" - ))] - pub domain: util::arithmetic::Domain, - #[serde(bound( - serialize = "L::LoadedEcPoint: Serialize", - deserialize = "L::LoadedEcPoint: Deserialize<'de>" - ))] - pub preprocessed: Vec, - pub num_instance: Vec, - pub num_witness: Vec, - pub num_challenge: Vec, - pub evaluations: Vec, - pub queries: Vec, - pub quotient: util::protocol::QuotientPolynomial, - // Minor customization - #[serde(bound( - serialize = "L::LoadedScalar: Serialize", - deserialize = "L::LoadedScalar: Deserialize<'de>" - ))] - pub transcript_initial_state: Option, - pub instance_committing_key: Option>, - pub linearization: Option, - pub accumulator_indices: Vec>, -} diff --git a/snark-verifier/src/loader.rs b/snark-verifier/src/loader.rs index 297390d0..77a8f54b 100644 --- a/snark-verifier/src/loader.rs +++ b/snark-verifier/src/loader.rs @@ -1,39 +1,51 @@ -use crate::{ - util::{ - arithmetic::{CurveAffine, FieldOps, PrimeField}, - Itertools, - }, - Error, +//! Abstraction of field element and elliptic curve point for generic verifier +//! implementation. + +use crate::util::{ + arithmetic::{CurveAffine, FieldOps, PrimeField}, + Itertools, }; use std::{borrow::Cow, fmt::Debug, iter, ops::Deref}; +/// Native (cpu) loader pub mod native; #[cfg(feature = "loader_evm")] +/// EVM loader pub mod evm; #[cfg(feature = "loader_halo2")] +/// Halo2 loader pub mod halo2; +/// Loaded elliptic curve point. pub trait LoadedEcPoint: Clone + Debug + PartialEq { + /// [`Loader`]. type Loader: Loader; + /// Returns [`Loader`]. fn loader(&self) -> &Self::Loader; } +/// Loaded field element. pub trait LoadedScalar: Clone + Debug + PartialEq + FieldOps { + /// [`Loader`]. type Loader: ScalarLoader; + /// Returns [`Loader`]. fn loader(&self) -> &Self::Loader; + /// Returns square. fn square(&self) -> Self { self.clone() * self } + /// Returns inverse if any. fn invert(&self) -> Option { FieldOps::invert(self) } + /// Returns power to exponent. fn pow_const(&self, mut exp: u64) -> Self { assert!(exp > 0); @@ -55,6 +67,7 @@ pub trait LoadedScalar: Clone + Debug + PartialEq + FieldOps { acc } + /// Returns powers up to exponent `n-1`. fn powers(&self, n: usize) -> Vec { iter::once(self.loader().load_one()) .chain( @@ -65,26 +78,33 @@ pub trait LoadedScalar: Clone + Debug + PartialEq + FieldOps { } } +/// Elliptic curve point loader. pub trait EcPointLoader { + /// [`LoadedEcPoint`]. type LoadedEcPoint: LoadedEcPoint; + /// Load a constant elliptic curve point. fn ec_point_load_const(&self, value: &C) -> Self::LoadedEcPoint; + /// Load `identity` as constant. fn ec_point_load_zero(&self) -> Self::LoadedEcPoint { self.ec_point_load_const(&C::identity()) } + /// Load `generator` as constant. fn ec_point_load_one(&self) -> Self::LoadedEcPoint { self.ec_point_load_const(&C::generator()) } + /// Assert lhs and rhs elliptic curve points are equal. fn ec_point_assert_eq( &self, annotation: &str, lhs: &Self::LoadedEcPoint, rhs: &Self::LoadedEcPoint, - ) -> Result<(), Error>; + ); + /// Perform multi-scalar multiplication. fn multi_scalar_multiplication( pairs: &[(&Self::LoadedScalar, &Self::LoadedEcPoint)], ) -> Self::LoadedEcPoint @@ -92,26 +112,28 @@ pub trait EcPointLoader { Self: ScalarLoader; } +/// Field element loader. pub trait ScalarLoader { + /// [`LoadedScalar`]. type LoadedScalar: LoadedScalar; + /// Load a constant field element. fn load_const(&self, value: &F) -> Self::LoadedScalar; + /// Load `zero` as constant. fn load_zero(&self) -> Self::LoadedScalar { self.load_const(&F::zero()) } + /// Load `one` as constant. fn load_one(&self) -> Self::LoadedScalar { self.load_const(&F::one()) } - fn assert_eq( - &self, - annotation: &str, - lhs: &Self::LoadedScalar, - rhs: &Self::LoadedScalar, - ) -> Result<(), Error>; + /// Assert lhs and rhs field elements are equal. + fn assert_eq(&self, annotation: &str, lhs: &Self::LoadedScalar, rhs: &Self::LoadedScalar); + /// Sum field elements with coefficients and constant. fn sum_with_coeff_and_const( &self, values: &[(F, &Self::LoadedScalar)], @@ -140,6 +162,7 @@ pub trait ScalarLoader { .into_owned() } + /// Sum product of field elements with coefficients and constant. fn sum_products_with_coeff_and_const( &self, values: &[(F, &Self::LoadedScalar, &Self::LoadedScalar)], @@ -151,11 +174,7 @@ pub trait ScalarLoader { let loader = values.first().unwrap().1.loader(); iter::empty() - .chain(if constant == F::zero() { - None - } else { - Some(loader.load_const(&constant)) - }) + .chain(if constant == F::zero() { None } else { Some(loader.load_const(&constant)) }) .chain(values.iter().map(|&(coeff, lhs, rhs)| { if coeff == F::one() { lhs.clone() * rhs @@ -167,10 +186,12 @@ pub trait ScalarLoader { .unwrap() } + /// Sum field elements with coefficients. fn sum_with_coeff(&self, values: &[(F, &Self::LoadedScalar)]) -> Self::LoadedScalar { self.sum_with_coeff_and_const(values, F::zero()) } + /// Sum field elements and constant. fn sum_with_const(&self, values: &[&Self::LoadedScalar], constant: F) -> Self::LoadedScalar { self.sum_with_coeff_and_const( &values.iter().map(|&value| (F::one(), value)).collect_vec(), @@ -178,10 +199,12 @@ pub trait ScalarLoader { ) } + /// Sum field elements. fn sum(&self, values: &[&Self::LoadedScalar]) -> Self::LoadedScalar { self.sum_with_const(values, F::zero()) } + /// Sum product of field elements with coefficients. fn sum_products_with_coeff( &self, values: &[(F, &Self::LoadedScalar, &Self::LoadedScalar)], @@ -189,20 +212,19 @@ pub trait ScalarLoader { self.sum_products_with_coeff_and_const(values, F::zero()) } + /// Sum product of field elements and constant. fn sum_products_with_const( &self, values: &[(&Self::LoadedScalar, &Self::LoadedScalar)], constant: F, ) -> Self::LoadedScalar { self.sum_products_with_coeff_and_const( - &values - .iter() - .map(|&(lhs, rhs)| (F::one(), lhs, rhs)) - .collect_vec(), + &values.iter().map(|&(lhs, rhs)| (F::one(), lhs, rhs)).collect_vec(), constant, ) } + /// Sum product of field elements. fn sum_products( &self, values: &[(&Self::LoadedScalar, &Self::LoadedScalar)], @@ -210,12 +232,12 @@ pub trait ScalarLoader { self.sum_products_with_const(values, F::zero()) } + /// Product of field elements. fn product(&self, values: &[&Self::LoadedScalar]) -> Self::LoadedScalar { - values - .iter() - .fold(self.load_one(), |acc, value| acc * *value) + values.iter().fold(self.load_one(), |acc, value| acc * *value) } + /// Batch invert field elements. fn batch_invert<'a>(values: impl IntoIterator) where Self::LoadedScalar: 'a, @@ -226,10 +248,13 @@ pub trait ScalarLoader { } } +/// [`EcPointLoader`] and [`ScalarLoader`] with some helper methods. pub trait Loader: EcPointLoader + ScalarLoader + Clone + Debug { - fn start_cost_metering(&self, _: &str) {} + /// Start cost metering with an `identifier`. + fn start_cost_metering(&self, _identifier: &str) {} + /// End latest started cost metering. fn end_cost_metering(&self) {} } diff --git a/snark-verifier/src/loader/evm.rs b/snark-verifier/src/loader/evm.rs index 263da0e2..e942b4a3 100644 --- a/snark-verifier/src/loader/evm.rs +++ b/snark-verifier/src/loader/evm.rs @@ -1,17 +1,17 @@ +//! `Loader` implementation for generating yul code as EVM verifier. + mod code; pub(crate) mod loader; -mod util; +pub(crate) mod util; #[cfg(test)] mod test; pub use loader::{EcPoint, EvmLoader, Scalar}; pub use util::{ - compile_yul, encode_calldata, estimate_gas, fe_to_u256, modulus, u256_to_fe, ExecutorBuilder, - MemoryChunk, + compile_yul, encode_calldata, estimate_gas, fe_to_u256, modulus, u256_to_fe, Address, + ExecutorBuilder, H256, U256, U512, }; -pub use ethereum_types::U256; - #[cfg(test)] pub use test::execute; diff --git a/snark-verifier/src/loader/evm/code.rs b/snark-verifier/src/loader/evm/code.rs index 840d1e67..2fec71d2 100644 --- a/snark-verifier/src/loader/evm/code.rs +++ b/snark-verifier/src/loader/evm/code.rs @@ -44,20 +44,13 @@ impl YulCode { let y_lt_p:bool := lt(y, {base_modulus}) valid := and(x_lt_p, y_lt_p) }} - {{ - let x_is_zero:bool := eq(x, 0) - let y_is_zero:bool := eq(y, 0) - let x_or_y_is_zero:bool := or(x_is_zero, y_is_zero) - let x_and_y_is_not_zero:bool := not(x_or_y_is_zero) - valid := and(x_and_y_is_not_zero, valid) - }} {{ let y_square := mulmod(y, y, {base_modulus}) let x_square := mulmod(x, x, {base_modulus}) let x_cube := mulmod(x_square, x, {base_modulus}) let x_cube_plus_3 := addmod(x_cube, 3, {base_modulus}) - let y_square_eq_x_cube_plus_3:bool := eq(x_cube_plus_3, y_square) - valid := and(y_square_eq_x_cube_plus_3, valid) + let is_affine:bool := eq(x_cube_plus_3, y_square) + valid := and(valid, is_affine) }} }} {} diff --git a/snark-verifier/src/loader/evm/loader.rs b/snark-verifier/src/loader/evm/loader.rs index db15c8d7..98ca5ca4 100644 --- a/snark-verifier/src/loader/evm/loader.rs +++ b/snark-verifier/src/loader/evm/loader.rs @@ -2,7 +2,7 @@ use crate::{ loader::{ evm::{ code::{Precompiled, YulCode}, - fe_to_u256, modulus, u256_to_fe, + fe_to_u256, modulus, u256_to_fe, U256, U512, }, EcPointLoader, LoadedEcPoint, LoadedScalar, Loader, ScalarLoader, }, @@ -10,9 +10,7 @@ use crate::{ arithmetic::{CurveAffine, FieldOps, PrimeField}, Itertools, }, - Error, }; -use ethereum_types::{U256, U512}; use hex; use std::{ cell::RefCell, @@ -40,7 +38,7 @@ impl PartialEq for Value { impl Value { fn identifier(&self) -> String { - match &self { + match self { Value::Constant(_) | Value::Memory(_) => format!("{self:?}"), Value::Negated(value) => format!("-({value:?})"), Value::Sum(lhs, rhs) => format!("({lhs:?} + {rhs:?})"), @@ -49,6 +47,7 @@ impl Value { } } +/// `Loader` implementation for generating yul code as EVM verifier. #[derive(Clone, Debug)] pub struct EvmLoader { base_modulus: U256, @@ -67,6 +66,7 @@ fn hex_encode_u256(value: &U256) -> String { } impl EvmLoader { + /// Initialize a [`EvmLoader`] with base and scalar field. pub fn new() -> Rc where Base: PrimeField, @@ -87,18 +87,19 @@ impl EvmLoader { }) } + /// Returns generated yul code. pub fn yul_code(self: &Rc) -> String { let code = " if not(success) { revert(0, 0) } return(0, 0)" .to_string(); self.code.borrow_mut().runtime_append(code); - self.code.borrow().code( - hex_encode_u256(&self.base_modulus), - hex_encode_u256(&self.scalar_modulus), - ) + self.code + .borrow() + .code(hex_encode_u256(&self.base_modulus), hex_encode_u256(&self.scalar_modulus)) } + /// Allocates memory chunk with given `size` and returns pointer. pub fn allocate(self: &Rc, size: usize) -> usize { let ptr = *self.ptr.borrow(); *self.ptr.borrow_mut() += size; @@ -138,6 +139,7 @@ impl EvmLoader { } } + /// Calldata load a field element. pub fn calldataload_scalar(self: &Rc, offset: usize) -> Scalar { let ptr = self.allocate(0x20); let code = format!("mstore({ptr:#x}, mod(calldataload({offset:#x}), f_q))"); @@ -145,6 +147,8 @@ impl EvmLoader { self.scalar(Value::Memory(ptr)) } + /// Calldata load an elliptic curve point and validate it's on affine plane. + /// Note that identity will cause the verification to fail. pub fn calldataload_ec_point(self: &Rc, offset: usize) -> EcPoint { let x_ptr = self.allocate(0x40); let y_ptr = x_ptr + 0x20; @@ -165,6 +169,7 @@ impl EvmLoader { self.ec_point(Value::Memory(x_ptr)) } + /// Decode an elliptic curve point from limbs. pub fn ec_point_from_limbs( self: &Rc, x_limbs: [&Scalar; LIMBS], @@ -210,10 +215,7 @@ impl EvmLoader { } pub(crate) fn scalar(self: &Rc, value: Value) -> Scalar { - let value = if matches!( - value, - Value::Constant(_) | Value::Memory(_) | Value::Negated(_) - ) { + let value = if matches!(value, Value::Constant(_) | Value::Memory(_) | Value::Negated(_)) { value } else { let identifier = value.identifier(); @@ -221,32 +223,23 @@ impl EvmLoader { let ptr = if let Some(ptr) = some_ptr { ptr } else { - let v = self.push(&Scalar { - loader: self.clone(), - value, - }); + let v = self.push(&Scalar { loader: self.clone(), value }); let ptr = self.allocate(0x20); - self.code - .borrow_mut() - .runtime_append(format!("mstore({ptr:#x}, {v})")); + self.code.borrow_mut().runtime_append(format!("mstore({ptr:#x}, {v})")); self.cache.borrow_mut().insert(identifier, ptr); ptr }; Value::Memory(ptr) }; - Scalar { - loader: self.clone(), - value, - } + Scalar { loader: self.clone(), value } } fn ec_point(self: &Rc, value: Value<(U256, U256)>) -> EcPoint { - EcPoint { - loader: self.clone(), - value, - } + EcPoint { loader: self.clone(), value } } + /// Performs `KECCAK256` on `memory[ptr..ptr+len]` and returns pointer of + /// hash. pub fn keccak256(self: &Rc, ptr: usize, len: usize) -> usize { let hash_ptr = self.allocate(0x20); let code = format!("mstore({hash_ptr:#x}, keccak256({ptr:#x}, {len}))"); @@ -254,19 +247,20 @@ impl EvmLoader { hash_ptr } + /// Copies a field element into given `ptr`. pub fn copy_scalar(self: &Rc, scalar: &Scalar, ptr: usize) { let scalar = self.push(scalar); - self.code - .borrow_mut() - .runtime_append(format!("mstore({ptr:#x}, {scalar})")); + self.code.borrow_mut().runtime_append(format!("mstore({ptr:#x}, {scalar})")); } + /// Allocates a new field element and copies the given value into it. pub fn dup_scalar(self: &Rc, scalar: &Scalar) -> Scalar { let ptr = self.allocate(0x20); self.copy_scalar(scalar, ptr); self.scalar(Value::Memory(ptr)) } + /// Allocates a new elliptic curve point and copies the given value into it. pub fn dup_ec_point(self: &Rc, value: &EcPoint) -> EcPoint { let ptr = self.allocate(0x40); match value.value { @@ -340,6 +334,7 @@ impl EvmLoader { self.ec_point(Value::Memory(rd_ptr)) } + /// Performs pairing. pub fn pairing( self: &Rc, lhs: &EcPoint, @@ -392,10 +387,7 @@ impl EvmLoader { return self.scalar(Value::Constant(out.try_into().unwrap())); } - self.scalar(Value::Sum( - Box::new(lhs.value.clone()), - Box::new(rhs.value.clone()), - )) + self.scalar(Value::Sum(Box::new(lhs.value.clone()), Box::new(rhs.value.clone()))) } fn sub(self: &Rc, lhs: &Scalar, rhs: &Scalar) -> Scalar { @@ -415,10 +407,7 @@ impl EvmLoader { return self.scalar(Value::Constant(out.try_into().unwrap())); } - self.scalar(Value::Product( - Box::new(lhs.value.clone()), - Box::new(rhs.value.clone()), - )) + self.scalar(Value::Product(Box::new(lhs.value.clone()), Box::new(rhs.value.clone()))) } fn neg(self: &Rc, scalar: &Scalar) -> Scalar { @@ -433,28 +422,25 @@ impl EvmLoader { #[cfg(test)] impl EvmLoader { fn start_gas_metering(self: &Rc, identifier: &str) { - self.gas_metering_ids - .borrow_mut() - .push(identifier.to_string()); + self.gas_metering_ids.borrow_mut().push(identifier.to_string()); let code = format!("let {identifier} := gas()"); self.code.borrow_mut().runtime_append(code); } fn end_gas_metering(self: &Rc) { - let code = format!( - "log1(0, 0, sub({}, gas()))", - self.gas_metering_ids.borrow().last().unwrap() - ); + let code = + format!("log1(0, 0, sub({}, gas()))", self.gas_metering_ids.borrow().last().unwrap()); self.code.borrow_mut().runtime_append(code); } pub fn print_gas_metering(self: &Rc, costs: Vec) { for (identifier, cost) in self.gas_metering_ids.borrow().iter().zip(costs) { - println!("{}: {}", identifier, cost); + println!("{identifier}: {cost}"); } } } +/// Elliptic curve point. #[derive(Clone)] pub struct EcPoint { loader: Rc, @@ -480,9 +466,7 @@ impl EcPoint { impl Debug for EcPoint { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("EcPoint") - .field("value", &self.value) - .finish() + f.debug_struct("EcPoint").field("value", &self.value).finish() } } @@ -504,6 +488,7 @@ where } } +/// Field element. #[derive(Clone)] pub struct Scalar { loader: Rc, @@ -526,21 +511,14 @@ impl Scalar { pub(crate) fn ptr(&self) -> usize { match self.value { Value::Memory(ptr) => ptr, - _ => *self - .loader - .cache - .borrow() - .get(&self.value.identifier()) - .unwrap(), + _ => *self.loader.cache.borrow().get(&self.value.identifier()).unwrap(), } } } impl Debug for Scalar { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Scalar") - .field("value", &self.value) - .finish() + f.debug_struct("Scalar").field("value", &self.value).finish() } } @@ -670,7 +648,7 @@ where self.ec_point(Value::Constant((x, y))) } - fn ec_point_assert_eq(&self, _: &str, _: &EcPoint, _: &EcPoint) -> Result<(), Error> { + fn ec_point_assert_eq(&self, _: &str, _: &EcPoint, _: &EcPoint) { unimplemented!() } @@ -685,7 +663,7 @@ where _ => ec_point.loader.ec_point_scalar_mul(ec_point, scalar), }) .reduce(|acc, ec_point| acc.loader.ec_point_add(&acc, &ec_point)) - .unwrap() + .expect("pairs should not be empty") } } @@ -696,7 +674,7 @@ impl> ScalarLoader for Rc { self.scalar(Value::Constant(fe_to_u256(*value))) } - fn assert_eq(&self, _: &str, _: &Scalar, _: &Scalar) -> Result<(), Error> { + fn assert_eq(&self, _: &str, _: &Scalar, _: &Scalar) { unimplemented!() } @@ -709,9 +687,9 @@ impl> ScalarLoader for Rc { assert_ne!(*coeff, F::zero()); match (*coeff == F::one(), &value.value) { (true, _) => self.push(value), - (false, Value::Constant(value)) => self.push(&self.scalar(Value::Constant( - fe_to_u256(*coeff * u256_to_fe::(*value)), - ))), + (false, Value::Constant(value)) => self.push( + &self.scalar(Value::Constant(fe_to_u256(*coeff * u256_to_fe::(*value)))), + ), (false, _) => { let value = self.push(value); let coeff = self.push(&self.scalar(Value::Constant(fe_to_u256(*coeff)))); @@ -765,9 +743,10 @@ impl> ScalarLoader for Rc { (_, value @ Value::Memory(_), Value::Constant(constant)) | (_, Value::Constant(constant), value @ Value::Memory(_)) => { let v1 = self.push(&self.scalar(value.clone())); - let v2 = self.push(&self.scalar(Value::Constant(fe_to_u256( - *coeff * u256_to_fe::(*constant), - )))); + let v2 = + self.push(&self.scalar(Value::Constant(fe_to_u256( + *coeff * u256_to_fe::(*constant), + )))); format!("mulmod({v1}, {v2}, f_q)") } (true, _, _) => { @@ -858,14 +837,9 @@ impl> ScalarLoader for Rc { let v " ); - for (value, product) in values.iter().rev().zip( - products - .iter() - .rev() - .skip(1) - .map(Some) - .chain(iter::once(None)), - ) { + for (value, product) in + values.iter().rev().zip(products.iter().rev().skip(1).map(Some).chain(iter::once(None))) + { if let Some(product) = product { let val_ptr = value.ptr(); let prod_ptr = product.ptr(); diff --git a/snark-verifier/src/loader/evm/test.rs b/snark-verifier/src/loader/evm/test.rs index e6f3703e..e3467408 100644 --- a/snark-verifier/src/loader/evm/test.rs +++ b/snark-verifier/src/loader/evm/test.rs @@ -1,8 +1,7 @@ use crate::{ - loader::evm::{test::tui::Tui, util::ExecutorBuilder}, + loader::evm::{test::tui::Tui, Address, ExecutorBuilder, U256}, util::Itertools, }; -use ethereum_types::{Address, U256}; use std::env::var_os; mod tui; diff --git a/snark-verifier/src/loader/evm/test/tui.rs b/snark-verifier/src/loader/evm/test/tui.rs index c0c4d7f8..328082c7 100644 --- a/snark-verifier/src/loader/evm/test/tui.rs +++ b/snark-verifier/src/loader/evm/test/tui.rs @@ -1,6 +1,9 @@ //! Copied and modified from https://github.com/foundry-rs/foundry/blob/master/ui/src/lib.rs -use crate::loader::evm::util::executor::{CallKind, DebugStep}; +use crate::loader::evm::{ + util::executor::{CallKind, DebugStep}, + Address, +}; use crossterm::{ event::{ self, DisableMouseCapture, EnableMouseCapture, Event, KeyCode, KeyEvent, KeyModifiers, @@ -9,7 +12,6 @@ use crossterm::{ execute, terminal::{disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen}, }; -use ethereum_types::Address; use revm::opcode; use std::{ cmp::{max, min}, @@ -572,10 +574,10 @@ impl Tui { let line_number_format = if line_number == current_step { let step: &DebugStep = &debug_steps[line_number]; - format!("{:0>max_pc_len$x}|▶", step.pc, max_pc_len = max_pc_len) + format!("{:0>max_pc_len$x}|▶", step.pc) } else if line_number < debug_steps.len() { let step: &DebugStep = &debug_steps[line_number]; - format!("{:0>max_pc_len$x}| ", step.pc, max_pc_len = max_pc_len) + format!("{:0>max_pc_len$x}| ", step.pc) } else { "END CALL".to_string() }; @@ -634,7 +636,7 @@ impl Tui { .map(|i| stack_item.byte(i)) .map(|byte| { Span::styled( - format!("{:02x} ", byte), + format!("{byte:02x} "), if affected.is_some() { Style::default().fg(Color::Cyan) } else if byte == 0 { @@ -655,7 +657,7 @@ impl Tui { } let mut spans = vec![Span::styled( - format!("{:0min_len$}| ", i, min_len = min_len), + format!("{i:0min_len$}| "), Style::default().fg(Color::White), )]; spans.extend(words); @@ -727,7 +729,7 @@ impl Tui { .iter() .map(|byte| { Span::styled( - format!("{:02x} ", byte), + format!("{byte:02x} "), if let (Some(w), Some(color)) = (word, color) { if i == w { Style::default().fg(color) @@ -746,7 +748,7 @@ impl Tui { .collect(); let mut spans = vec![Span::styled( - format!("{:0min_len$x}| ", i * 32, min_len = min_len), + format!("{:0min_len$x}| ", i * 32), Style::default().fg(Color::White), )]; spans.extend(words); diff --git a/snark-verifier/src/loader/evm/util.rs b/snark-verifier/src/loader/evm/util.rs index a7df5209..a84df4c3 100644 --- a/snark-verifier/src/loader/evm/util.rs +++ b/snark-verifier/src/loader/evm/util.rs @@ -2,54 +2,55 @@ use crate::{ cost::Cost, util::{arithmetic::PrimeField, Itertools}, }; -use ethereum_types::U256; use std::{ io::Write, iter, process::{Command, Stdio}, }; +pub use primitive_types::{H160 as Address, H256, U256, U512}; + pub(crate) mod executor; pub use executor::ExecutorBuilder; +/// Memory chunk in EVM. +#[derive(Debug)] pub struct MemoryChunk { ptr: usize, len: usize, } impl MemoryChunk { - pub fn new(ptr: usize) -> Self { + pub(crate) fn new(ptr: usize) -> Self { Self { ptr, len: 0 } } - pub fn ptr(&self) -> usize { + pub(crate) fn ptr(&self) -> usize { self.ptr } - pub fn len(&self) -> usize { + pub(crate) fn len(&self) -> usize { self.len } - pub fn is_empty(&self) -> bool { - self.len == 0 - } - - pub fn end(&self) -> usize { + pub(crate) fn end(&self) -> usize { self.ptr + self.len } - pub fn reset(&mut self, ptr: usize) { + pub(crate) fn reset(&mut self, ptr: usize) { self.ptr = ptr; self.len = 0; } - pub fn extend(&mut self, size: usize) { + pub(crate) fn extend(&mut self, size: usize) { self.len += size; } } -// Assume fields implements traits in crate `ff` always have little-endian representation. +/// Convert a [`PrimeField`] into a [`U256`]. +/// Assuming fields that implement traits in crate `ff` always have +/// little-endian representation. pub fn fe_to_u256(f: F) -> U256 where F: PrimeField, @@ -57,6 +58,7 @@ where U256::from_little_endian(f.to_repr().as_ref()) } +/// Convert a [`U256`] into a [`PrimeField`]. pub fn u256_to_fe(value: U256) -> F where F: PrimeField, @@ -67,6 +69,7 @@ where F::from_repr(repr).unwrap() } +/// Returns modulus of [`PrimeField`] as [`U256`]. pub fn modulus() -> U256 where F: PrimeField, @@ -74,6 +77,7 @@ where U256::from_little_endian((-F::one()).to_repr().as_ref()) + 1 } +/// Encode instances and proof into calldata. pub fn encode_calldata(instances: &[Vec], proof: &[u8]) -> Vec where F: PrimeField, @@ -89,16 +93,18 @@ where .collect() } +/// Estimate gas cost with given [`Cost`]. pub fn estimate_gas(cost: Cost) -> usize { let proof_size = cost.num_commitment * 64 + (cost.num_evaluation + cost.num_instance) * 32; let intrinsic_cost = 21000; let calldata_cost = (proof_size as f64 * 15.25).ceil() as usize; - let ec_operation_cost = 113100 + (cost.num_msm - 2) * 6350; + let ec_operation_cost = (45100 + cost.num_pairing * 34000) + (cost.num_msm - 2) * 6350; intrinsic_cost + calldata_cost + ec_operation_cost } +/// Compile given yul `code` into deployment bytecode. pub fn compile_yul(code: &str) -> Vec { let mut cmd = Command::new("solc") .stdin(Stdio::piped()) @@ -108,13 +114,10 @@ pub fn compile_yul(code: &str) -> Vec { .arg("-") .spawn() .unwrap(); - cmd.stdin - .take() - .unwrap() - .write_all(code.as_bytes()) - .unwrap(); + cmd.stdin.take().unwrap().write_all(code.as_bytes()).unwrap(); let output = cmd.wait_with_output().unwrap().stdout; let binary = *split_by_ascii_whitespace(&output).last().unwrap(); + assert!(!binary.is_empty()); hex::decode(binary).unwrap() } @@ -130,5 +133,22 @@ fn split_by_ascii_whitespace(bytes: &[u8]) -> Vec<&[u8]> { start = Some(idx); } } + if let Some(last) = start { + split.push(&bytes[last..]); + } split } + +#[test] +fn test_split_by_ascii_whitespace_1() { + let bytes = b" \x01 \x02 \x03"; + let split = split_by_ascii_whitespace(bytes); + assert_eq!(split, [b"\x01", b"\x02", b"\x03"]); +} + +#[test] +fn test_split_by_ascii_whitespace_2() { + let bytes = b"123456789abc"; + let split = split_by_ascii_whitespace(bytes); + assert_eq!(split, [b"123456789abc"]); +} diff --git a/snark-verifier/src/loader/evm/util/executor.rs b/snark-verifier/src/loader/evm/util/executor.rs index ec9695e0..a7697a0e 100644 --- a/snark-verifier/src/loader/evm/util/executor.rs +++ b/snark-verifier/src/loader/evm/util/executor.rs @@ -1,7 +1,8 @@ -//! Copied and modified from https://github.com/foundry-rs/foundry/blob/master/evm/src/executor/mod.rs +//! Copied and modified from +//! +use crate::loader::evm::{Address, H256, U256}; use bytes::Bytes; -use ethereum_types::{Address, H256, U256, U64}; use revm::{ evm_inner, opcode, spec_opcode_gas, Account, BlockEnv, CallInputs, CallScheme, CreateInputs, CreateScheme, Database, DatabaseCommit, EVMData, Env, ExecutionResult, Gas, GasInspector, @@ -54,7 +55,7 @@ fn get_create2_address_from_hash( ] .concat(); - let hash = keccak256(&bytes); + let hash = keccak256(bytes); let mut bytes = [0u8; 20]; bytes.copy_from_slice(&hash[12..]); @@ -77,14 +78,6 @@ pub struct Log { pub address: Address, pub topics: Vec, pub data: Bytes, - pub block_hash: Option, - pub block_number: Option, - pub transaction_hash: Option, - pub transaction_index: Option, - pub log_index: Option, - pub transaction_log_index: Option, - pub log_type: Option, - pub removed: Option, } #[derive(Clone, Debug, Default)] @@ -98,7 +91,6 @@ impl Inspector for LogCollector { address: *address, topics: topics.to_vec(), data: data.clone(), - ..Default::default() }); } @@ -425,7 +417,6 @@ impl Inspector for Debugger { } } -#[macro_export] macro_rules! call_inspectors { ($id:ident, [ $($inspector:expr),+ ], $call:block) => { $({ @@ -678,16 +669,28 @@ impl Inspector for InspectorStack { } } +/// Call result. +#[derive(Debug)] pub struct RawCallResult { + /// Exit reason pub exit_reason: Return, + /// If the call is reverted or not. pub reverted: bool, + /// Returndata pub result: Bytes, + /// Gas used pub gas_used: u64, + /// Gas refunded pub gas_refunded: u64, + /// Logs emitted during the call pub logs: Vec, + /// Debug information if any pub debug: Option, + /// State changes if any pub state_changeset: Option>, + /// Environment pub env: Env, + /// Output pub out: TransactOut, } @@ -703,6 +706,7 @@ pub struct DeployResult { pub env: Env, } +/// Executor builder. #[derive(Debug, Default)] pub struct ExecutorBuilder { debugger: bool, @@ -710,16 +714,19 @@ pub struct ExecutorBuilder { } impl ExecutorBuilder { + /// Set `debugger`. pub fn set_debugger(mut self, enable: bool) -> Self { self.debugger = enable; self } + /// Set `gas_limit`. pub fn with_gas_limit(mut self, gas_limit: U256) -> Self { self.gas_limit = Some(gas_limit); self } + /// Initialize an `Executor`. pub fn build(self) -> Executor { Executor::new(self.debugger, self.gas_limit.unwrap_or(U256::MAX)) } diff --git a/snark-verifier/src/loader/halo2.rs b/snark-verifier/src/loader/halo2.rs index 0e84d506..4d37c5ce 100644 --- a/snark-verifier/src/loader/halo2.rs +++ b/snark-verifier/src/loader/halo2.rs @@ -1,15 +1,10 @@ -use crate::halo2_proofs::circuit; -use crate::{util::arithmetic::CurveAffine, Protocol}; -use std::rc::Rc; +//! `Loader` implementation for generating verifier in [`halo2_proofs`] circuit. pub(crate) mod loader; mod shim; -#[cfg(test)] -pub(crate) mod test; - pub use loader::{EcPoint, Halo2Loader, Scalar}; -pub use shim::{Context, EccInstructions, IntegerInstructions}; +pub use shim::{EccInstructions, IntegerInstructions}; pub use util::Valuetools; pub use halo2_ecc; @@ -17,7 +12,10 @@ pub use halo2_ecc; mod util { use crate::halo2_proofs::circuit::Value; + /// Helper methods when dealing with iterator of [`Value`]. pub trait Valuetools: Iterator> { + /// Fold zipped values into accumulator, returns `Value::unknown()` if + /// any is `Value::unknown()`. fn fold_zipped(self, init: B, mut f: F) -> Value where Self: Sized, @@ -31,37 +29,3 @@ mod util { impl>> Valuetools for I {} } - -impl Protocol -where - C: CurveAffine, -{ - pub fn loaded_preprocessed_as_witness<'a, EccChip: EccInstructions<'a, C>>( - &self, - loader: &Rc>, - ) -> Protocol>> { - let preprocessed = self - .preprocessed - .iter() - .map(|preprocessed| loader.assign_ec_point(circuit::Value::known(*preprocessed))) - .collect(); - let transcript_initial_state = - self.transcript_initial_state.as_ref().map(|transcript_initial_state| { - loader.assign_scalar(circuit::Value::known(*transcript_initial_state)) - }); - Protocol { - domain: self.domain.clone(), - preprocessed, - num_instance: self.num_instance.clone(), - num_witness: self.num_witness.clone(), - num_challenge: self.num_challenge.clone(), - evaluations: self.evaluations.clone(), - queries: self.queries.clone(), - quotient: self.quotient.clone(), - transcript_initial_state, - instance_committing_key: self.instance_committing_key.clone(), - linearization: self.linearization, - accumulator_indices: self.accumulator_indices.clone(), - } - } -} diff --git a/snark-verifier/src/loader/halo2/loader.rs b/snark-verifier/src/loader/halo2/loader.rs index 24b9df6e..31be9841 100644 --- a/snark-verifier/src/loader/halo2/loader.rs +++ b/snark-verifier/src/loader/halo2/loader.rs @@ -1,4 +1,3 @@ -use crate::halo2_proofs::circuit; use crate::{ loader::{ halo2::shim::{EccInstructions, IntegerInstructions}, @@ -17,18 +16,22 @@ use std::{ rc::Rc, }; +/// `Loader` implementation for generating verifier in [`halo2_proofs`] circuit. #[derive(Debug)] -pub struct Halo2Loader<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> { +pub struct Halo2Loader> { ecc_chip: RefCell, ctx: RefCell, num_scalar: RefCell, num_ec_point: RefCell, _marker: PhantomData, #[cfg(test)] + #[allow(dead_code)] row_meterings: RefCell>, } -impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Halo2Loader<'a, C, EccChip> { +impl> Halo2Loader { + /// Initialize a [`Halo2Loader`] with given [`EccInstructions`] and + /// [`EccInstructions::Context`]. pub fn new(ecc_chip: EccChip, ctx: EccChip::Context) -> Rc { Rc::new(Self { ecc_chip: RefCell::new(ecc_chip), @@ -41,77 +44,82 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Halo2Loader<'a, C, Ecc }) } + /// Into [`EccInstructions::Context`]. pub fn into_ctx(self) -> EccChip::Context { self.ctx.into_inner() } + /// Takes [`EccInstructions::Context`] from the [`RefCell`], leaving with Default value + pub fn take_ctx(&self) -> EccChip::Context { + self.ctx.take() + } + + /// Returns reference of [`EccInstructions`]. pub fn ecc_chip(&self) -> Ref { self.ecc_chip.borrow() } + /// Returns reference of [`EccInstructions::ScalarChip`]. pub fn scalar_chip(&self) -> Ref { Ref::map(self.ecc_chip(), |ecc_chip| ecc_chip.scalar_chip()) } + /// Returns reference of [`EccInstructions::Context`]. pub fn ctx(&self) -> Ref { self.ctx.borrow() } + /// Returns mutable reference of [`EccInstructions::Context`]. pub fn ctx_mut(&self) -> RefMut<'_, EccChip::Context> { self.ctx.borrow_mut() } fn assign_const_scalar(self: &Rc, constant: C::Scalar) -> EccChip::AssignedScalar { - self.scalar_chip().assign_constant(&mut self.ctx_mut(), constant).unwrap() + self.scalar_chip().assign_constant(&mut self.ctx_mut(), constant) } - pub fn assign_scalar( - self: &Rc, - scalar: circuit::Value, - ) -> Scalar<'a, C, EccChip> { - let assigned = self.scalar_chip().assign_integer(&mut self.ctx_mut(), scalar).unwrap(); + /// Assign a field element witness. + pub fn assign_scalar(self: &Rc, scalar: C::Scalar) -> Scalar { + let assigned = self.scalar_chip().assign_integer(&mut self.ctx_mut(), scalar); self.scalar_from_assigned(assigned) } + /// Returns [`Scalar`] with assigned field element. pub fn scalar_from_assigned( self: &Rc, assigned: EccChip::AssignedScalar, - ) -> Scalar<'a, C, EccChip> { + ) -> Scalar { self.scalar(Value::Assigned(assigned)) } fn scalar( self: &Rc, value: Value, - ) -> Scalar<'a, C, EccChip> { + ) -> Scalar { let index = *self.num_scalar.borrow(); *self.num_scalar.borrow_mut() += 1; Scalar { loader: self.clone(), index, value: value.into() } } fn assign_const_ec_point(self: &Rc, constant: C) -> EccChip::AssignedEcPoint { - self.ecc_chip().assign_constant(&mut self.ctx_mut(), constant).unwrap() + self.ecc_chip().assign_constant(&mut self.ctx_mut(), constant) } - pub fn assign_ec_point( - self: &Rc, - ec_point: circuit::Value, - ) -> EcPoint<'a, C, EccChip> { - let assigned = self.ecc_chip().assign_point(&mut self.ctx_mut(), ec_point).unwrap(); + /// Assign an elliptic curve point witness. + pub fn assign_ec_point(self: &Rc, ec_point: C) -> EcPoint { + let assigned = self.ecc_chip().assign_point(&mut self.ctx_mut(), ec_point); self.ec_point_from_assigned(assigned) } + /// Returns [`EcPoint`] with assigned elliptic curve point. pub fn ec_point_from_assigned( self: &Rc, assigned: EccChip::AssignedEcPoint, - ) -> EcPoint<'a, C, EccChip> { + ) -> EcPoint { self.ec_point(Value::Assigned(assigned)) } - fn ec_point( - self: &Rc, - value: Value, - ) -> EcPoint<'a, C, EccChip> { + fn ec_point(self: &Rc, value: Value) -> EcPoint { let index = *self.num_ec_point.borrow(); *self.num_ec_point.borrow_mut() += 1; EcPoint { loader: self.clone(), index, value: value.into() } @@ -119,149 +127,109 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Halo2Loader<'a, C, Ecc fn add( self: &Rc, - lhs: &Scalar<'a, C, EccChip>, - rhs: &Scalar<'a, C, EccChip>, - ) -> Scalar<'a, C, EccChip> { + lhs: &Scalar, + rhs: &Scalar, + ) -> Scalar { let output = match (lhs.value().deref(), rhs.value().deref()) { (Value::Constant(lhs), Value::Constant(rhs)) => Value::Constant(*lhs + rhs), (Value::Assigned(assigned), Value::Constant(constant)) - | (Value::Constant(constant), Value::Assigned(assigned)) => self - .scalar_chip() - .sum_with_coeff_and_const( + | (Value::Constant(constant), Value::Assigned(assigned)) => { + Value::Assigned(self.scalar_chip().sum_with_coeff_and_const( &mut self.ctx_mut(), &[(C::Scalar::one(), assigned)], *constant, - ) - .map(Value::Assigned) - .unwrap(), - (Value::Assigned(lhs), Value::Assigned(rhs)) => self - .scalar_chip() - .sum_with_coeff_and_const( + )) + } + (Value::Assigned(lhs), Value::Assigned(rhs)) => { + Value::Assigned(self.scalar_chip().sum_with_coeff_and_const( &mut self.ctx_mut(), &[(C::Scalar::one(), lhs), (C::Scalar::one(), rhs)], C::Scalar::zero(), - ) - .map(Value::Assigned) - .unwrap(), + )) + } }; self.scalar(output) } fn sub( self: &Rc, - lhs: &Scalar<'a, C, EccChip>, - rhs: &Scalar<'a, C, EccChip>, - ) -> Scalar<'a, C, EccChip> { + lhs: &Scalar, + rhs: &Scalar, + ) -> Scalar { let output = match (lhs.value().deref(), rhs.value().deref()) { (Value::Constant(lhs), Value::Constant(rhs)) => Value::Constant(*lhs - rhs), - (Value::Constant(constant), Value::Assigned(assigned)) => self - .scalar_chip() - .sum_with_coeff_and_const( + (Value::Constant(constant), Value::Assigned(assigned)) => { + Value::Assigned(self.scalar_chip().sum_with_coeff_and_const( &mut self.ctx_mut(), &[(-C::Scalar::one(), assigned)], *constant, - ) - .map(Value::Assigned) - .unwrap(), - (Value::Assigned(assigned), Value::Constant(constant)) => self - .scalar_chip() - .sum_with_coeff_and_const( + )) + } + (Value::Assigned(assigned), Value::Constant(constant)) => { + Value::Assigned(self.scalar_chip().sum_with_coeff_and_const( &mut self.ctx_mut(), &[(C::Scalar::one(), assigned)], -*constant, - ) - .map(Value::Assigned) - .unwrap(), - (Value::Assigned(lhs), Value::Assigned(rhs)) => { - IntegerInstructions::sub(self.scalar_chip().deref(), &mut self.ctx_mut(), lhs, rhs) - .map(Value::Assigned) - .unwrap() + )) } + (Value::Assigned(lhs), Value::Assigned(rhs)) => Value::Assigned( + IntegerInstructions::sub(self.scalar_chip().deref(), &mut self.ctx_mut(), lhs, rhs), + ), }; self.scalar(output) } fn mul( self: &Rc, - lhs: &Scalar<'a, C, EccChip>, - rhs: &Scalar<'a, C, EccChip>, - ) -> Scalar<'a, C, EccChip> { + lhs: &Scalar, + rhs: &Scalar, + ) -> Scalar { let output = match (lhs.value().deref(), rhs.value().deref()) { (Value::Constant(lhs), Value::Constant(rhs)) => Value::Constant(*lhs * rhs), (Value::Assigned(assigned), Value::Constant(constant)) - | (Value::Constant(constant), Value::Assigned(assigned)) => self - .scalar_chip() - .sum_with_coeff_and_const( + | (Value::Constant(constant), Value::Assigned(assigned)) => { + Value::Assigned(self.scalar_chip().sum_with_coeff_and_const( &mut self.ctx_mut(), &[(*constant, assigned)], C::Scalar::zero(), - ) - .map(Value::Assigned) - .unwrap(), - (Value::Assigned(lhs), Value::Assigned(rhs)) => self - .scalar_chip() - .sum_products_with_coeff_and_const( + )) + } + (Value::Assigned(lhs), Value::Assigned(rhs)) => { + Value::Assigned(self.scalar_chip().sum_products_with_coeff_and_const( &mut self.ctx_mut(), &[(C::Scalar::one(), lhs, rhs)], C::Scalar::zero(), - ) - .map(Value::Assigned) - .unwrap(), + )) + } }; self.scalar(output) } - fn neg(self: &Rc, scalar: &Scalar<'a, C, EccChip>) -> Scalar<'a, C, EccChip> { + fn neg(self: &Rc, scalar: &Scalar) -> Scalar { let output = match scalar.value().deref() { Value::Constant(constant) => Value::Constant(constant.neg()), - Value::Assigned(assigned) => { - IntegerInstructions::neg(self.scalar_chip().deref(), &mut self.ctx_mut(), assigned) - .map(Value::Assigned) - .unwrap() - } + Value::Assigned(assigned) => Value::Assigned(IntegerInstructions::neg( + self.scalar_chip().deref(), + &mut self.ctx_mut(), + assigned, + )), }; self.scalar(output) } - fn invert(self: &Rc, scalar: &Scalar<'a, C, EccChip>) -> Scalar<'a, C, EccChip> { + fn invert(self: &Rc, scalar: &Scalar) -> Scalar { let output = match scalar.value().deref() { Value::Constant(constant) => Value::Constant(Field::invert(constant).unwrap()), - Value::Assigned(assigned) => Value::Assigned( - IntegerInstructions::invert( - self.scalar_chip().deref(), - &mut self.ctx_mut(), - assigned, - ) - .unwrap(), - ), + Value::Assigned(assigned) => Value::Assigned(IntegerInstructions::invert( + self.scalar_chip().deref(), + &mut self.ctx_mut(), + assigned, + )), }; self.scalar(output) } } -#[cfg(test)] -impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Halo2Loader<'a, C, EccChip> { - fn start_row_metering(self: &Rc, identifier: &str) { - use crate::loader::halo2::shim::Context; - - self.row_meterings.borrow_mut().push((identifier.to_string(), self.ctx().offset())) - } - - fn end_row_metering(self: &Rc) { - use crate::loader::halo2::shim::Context; - - let mut row_meterings = self.row_meterings.borrow_mut(); - let (_, row) = row_meterings.last_mut().unwrap(); - *row = self.ctx().offset() - *row; - } - - pub fn print_row_metering(self: &Rc) { - for (identifier, cost) in self.row_meterings.borrow().iter() { - println!("{identifier}: {cost}"); - } - } -} - #[derive(Clone, Debug)] pub enum Value { Constant(T), @@ -287,18 +255,21 @@ impl Value { } } +/// Field element #[derive(Clone)] -pub struct Scalar<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> { - loader: Rc>, +pub struct Scalar> { + loader: Rc>, index: usize, value: RefCell>, } -impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Scalar<'a, C, EccChip> { - pub fn loader(&self) -> &Rc> { +impl> Scalar { + /// Returns reference of [`Rc`] + pub fn loader(&self) -> &Rc> { &self.loader } + /// Returns reference of [`EccInstructions::AssignedScalar`]. pub fn into_assigned(self) -> EccChip::AssignedScalar { match self.value.into_inner() { Value::Constant(constant) => self.loader.assign_const_scalar(constant), @@ -306,6 +277,7 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Scalar<'a, C, EccChip> } } + /// Returns reference of [`EccInstructions::AssignedScalar`]. pub fn assigned(&self) -> Ref { if let Some(constant) = self.maybe_const() { *self.value.borrow_mut() = Value::Assigned(self.loader.assign_const_scalar(constant)) @@ -322,35 +294,33 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Scalar<'a, C, EccChip> } } -impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> PartialEq for Scalar<'a, C, EccChip> { +impl> PartialEq for Scalar { fn eq(&self, other: &Self) -> bool { self.index == other.index } } -impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> LoadedScalar - for Scalar<'a, C, EccChip> -{ - type Loader = Rc>; +impl> LoadedScalar for Scalar { + type Loader = Rc>; fn loader(&self) -> &Self::Loader { &self.loader } } -impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Debug for Scalar<'a, C, EccChip> { +impl> Debug for Scalar { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Scalar").field("value", &self.value).finish() } } -impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> FieldOps for Scalar<'a, C, EccChip> { +impl> FieldOps for Scalar { fn invert(&self) -> Option { Some(self.loader.invert(self)) } } -impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Add for Scalar<'a, C, EccChip> { +impl> Add for Scalar { type Output = Self; fn add(self, rhs: Self) -> Self::Output { @@ -358,7 +328,7 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Add for Scalar<'a, C, } } -impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Sub for Scalar<'a, C, EccChip> { +impl> Sub for Scalar { type Output = Self; fn sub(self, rhs: Self) -> Self::Output { @@ -366,7 +336,7 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Sub for Scalar<'a, C, } } -impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Mul for Scalar<'a, C, EccChip> { +impl> Mul for Scalar { type Output = Self; fn mul(self, rhs: Self) -> Self::Output { @@ -374,7 +344,7 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Mul for Scalar<'a, C, } } -impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Neg for Scalar<'a, C, EccChip> { +impl> Neg for Scalar { type Output = Self; fn neg(self) -> Self::Output { @@ -382,9 +352,7 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Neg for Scalar<'a, C, } } -impl<'a, 'b, C: CurveAffine, EccChip: EccInstructions<'a, C>> Add<&'b Self> - for Scalar<'a, C, EccChip> -{ +impl<'b, C: CurveAffine, EccChip: EccInstructions> Add<&'b Self> for Scalar { type Output = Self; fn add(self, rhs: &'b Self) -> Self::Output { @@ -392,9 +360,7 @@ impl<'a, 'b, C: CurveAffine, EccChip: EccInstructions<'a, C>> Add<&'b Self> } } -impl<'a, 'b, C: CurveAffine, EccChip: EccInstructions<'a, C>> Sub<&'b Self> - for Scalar<'a, C, EccChip> -{ +impl<'b, C: CurveAffine, EccChip: EccInstructions> Sub<&'b Self> for Scalar { type Output = Self; fn sub(self, rhs: &'b Self) -> Self::Output { @@ -402,9 +368,7 @@ impl<'a, 'b, C: CurveAffine, EccChip: EccInstructions<'a, C>> Sub<&'b Self> } } -impl<'a, 'b, C: CurveAffine, EccChip: EccInstructions<'a, C>> Mul<&'b Self> - for Scalar<'a, C, EccChip> -{ +impl<'b, C: CurveAffine, EccChip: EccInstructions> Mul<&'b Self> for Scalar { type Output = Self; fn mul(self, rhs: &'b Self) -> Self::Output { @@ -412,56 +376,52 @@ impl<'a, 'b, C: CurveAffine, EccChip: EccInstructions<'a, C>> Mul<&'b Self> } } -impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> AddAssign for Scalar<'a, C, EccChip> { +impl> AddAssign for Scalar { fn add_assign(&mut self, rhs: Self) { *self = Halo2Loader::add(&self.loader, self, &rhs) } } -impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> SubAssign for Scalar<'a, C, EccChip> { +impl> SubAssign for Scalar { fn sub_assign(&mut self, rhs: Self) { *self = Halo2Loader::sub(&self.loader, self, &rhs) } } -impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> MulAssign for Scalar<'a, C, EccChip> { +impl> MulAssign for Scalar { fn mul_assign(&mut self, rhs: Self) { *self = Halo2Loader::mul(&self.loader, self, &rhs) } } -impl<'a, 'b, C: CurveAffine, EccChip: EccInstructions<'a, C>> AddAssign<&'b Self> - for Scalar<'a, C, EccChip> -{ +impl<'b, C: CurveAffine, EccChip: EccInstructions> AddAssign<&'b Self> for Scalar { fn add_assign(&mut self, rhs: &'b Self) { *self = Halo2Loader::add(&self.loader, self, rhs) } } -impl<'a, 'b, C: CurveAffine, EccChip: EccInstructions<'a, C>> SubAssign<&'b Self> - for Scalar<'a, C, EccChip> -{ +impl<'b, C: CurveAffine, EccChip: EccInstructions> SubAssign<&'b Self> for Scalar { fn sub_assign(&mut self, rhs: &'b Self) { *self = Halo2Loader::sub(&self.loader, self, rhs) } } -impl<'a, 'b, C: CurveAffine, EccChip: EccInstructions<'a, C>> MulAssign<&'b Self> - for Scalar<'a, C, EccChip> -{ +impl<'b, C: CurveAffine, EccChip: EccInstructions> MulAssign<&'b Self> for Scalar { fn mul_assign(&mut self, rhs: &'b Self) { *self = Halo2Loader::mul(&self.loader, self, rhs) } } +/// Elliptic curve point #[derive(Clone)] -pub struct EcPoint<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> { - loader: Rc>, +pub struct EcPoint> { + loader: Rc>, index: usize, value: RefCell>, } -impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> EcPoint<'a, C, EccChip> { +impl> EcPoint { + /// Into [`EccInstructions::AssignedEcPoint`]. pub fn into_assigned(self) -> EccChip::AssignedEcPoint { match self.value.into_inner() { Value::Constant(constant) => self.loader.assign_const_ec_point(constant), @@ -469,6 +429,7 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> EcPoint<'a, C, EccChip } } + /// Returns reference of [`EccInstructions::AssignedEcPoint`]. pub fn assigned(&self) -> Ref { if let Some(constant) = self.maybe_const() { *self.value.borrow_mut() = Value::Assigned(self.loader.assign_const_ec_point(constant)) @@ -485,110 +446,97 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> EcPoint<'a, C, EccChip } } -impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> PartialEq for EcPoint<'a, C, EccChip> { +impl> PartialEq for EcPoint { fn eq(&self, other: &Self) -> bool { self.index == other.index } } -impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> LoadedEcPoint - for EcPoint<'a, C, EccChip> -{ - type Loader = Rc>; +impl> LoadedEcPoint for EcPoint { + type Loader = Rc>; fn loader(&self) -> &Self::Loader { &self.loader } } -impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Debug for EcPoint<'a, C, EccChip> { +impl> Debug for EcPoint { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("EcPoint").field("index", &self.index).field("value", &self.value).finish() } } -impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> ScalarLoader - for Rc> +impl> ScalarLoader + for Rc> { - type LoadedScalar = Scalar<'a, C, EccChip>; + type LoadedScalar = Scalar; - fn load_const(&self, value: &C::Scalar) -> Scalar<'a, C, EccChip> { + fn load_const(&self, value: &C::Scalar) -> Scalar { self.scalar(Value::Constant(*value)) } - fn assert_eq( - &self, - annotation: &str, - lhs: &Scalar<'a, C, EccChip>, - rhs: &Scalar<'a, C, EccChip>, - ) -> Result<(), crate::Error> { - self.scalar_chip() - .assert_equal(&mut self.ctx_mut(), &lhs.assigned(), &rhs.assigned()) - .map_err(|_| crate::Error::AssertionFailure(annotation.to_string())) + fn assert_eq(&self, _annotation: &str, lhs: &Scalar, rhs: &Scalar) { + self.scalar_chip().assert_equal(&mut self.ctx_mut(), &lhs.assigned(), &rhs.assigned()); } fn sum_with_coeff_and_const( &self, - values: &[(C::Scalar, &Scalar<'a, C, EccChip>)], + values: &[(C::Scalar, &Scalar)], constant: C::Scalar, - ) -> Scalar<'a, C, EccChip> { + ) -> Scalar { let values = values.iter().map(|(coeff, value)| (*coeff, value.assigned())).collect_vec(); - self.scalar(Value::Assigned( - self.scalar_chip() - .sum_with_coeff_and_const(&mut self.ctx_mut(), &values, constant) - .unwrap(), - )) + self.scalar(Value::Assigned(self.scalar_chip().sum_with_coeff_and_const( + &mut self.ctx_mut(), + &values, + constant, + ))) } fn sum_products_with_coeff_and_const( &self, - values: &[(C::Scalar, &Scalar<'a, C, EccChip>, &Scalar<'a, C, EccChip>)], + values: &[(C::Scalar, &Scalar, &Scalar)], constant: C::Scalar, - ) -> Scalar<'a, C, EccChip> { + ) -> Scalar { let values = values .iter() .map(|(coeff, lhs, rhs)| (*coeff, lhs.assigned(), rhs.assigned())) .collect_vec(); - self.scalar(Value::Assigned( - self.scalar_chip() - .sum_products_with_coeff_and_const(&mut self.ctx_mut(), &values, constant) - .unwrap(), - )) + self.scalar(Value::Assigned(self.scalar_chip().sum_products_with_coeff_and_const( + &mut self.ctx_mut(), + &values, + constant, + ))) } } -impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> EcPointLoader - for Rc> -{ - type LoadedEcPoint = EcPoint<'a, C, EccChip>; +impl> EcPointLoader for Rc> { + type LoadedEcPoint = EcPoint; - fn ec_point_load_const(&self, ec_point: &C) -> EcPoint<'a, C, EccChip> { + fn ec_point_load_const(&self, ec_point: &C) -> EcPoint { self.ec_point(Value::Constant(*ec_point)) } fn ec_point_assert_eq( &self, - annotation: &str, - lhs: &EcPoint<'a, C, EccChip>, - rhs: &EcPoint<'a, C, EccChip>, - ) -> Result<(), crate::Error> { + _annotation: &str, + lhs: &EcPoint, + rhs: &EcPoint, + ) { if let (Value::Constant(lhs), Value::Constant(rhs)) = (lhs.value().deref(), rhs.value().deref()) { assert_eq!(lhs, rhs); - Ok(()) } else { let lhs = lhs.assigned(); let rhs = rhs.assigned(); - self.ecc_chip() - .assert_equal(&mut self.ctx_mut(), lhs.deref(), rhs.deref()) - .map_err(|_| crate::Error::AssertionFailure(annotation.to_string())) + self.ecc_chip().assert_equal(&mut self.ctx_mut(), lhs.deref(), rhs.deref()); } } fn multi_scalar_multiplication( - pairs: &[(&>::LoadedScalar, &EcPoint<'a, C, EccChip>)], - ) -> EcPoint<'a, C, EccChip> { + pairs: &[(&>::LoadedScalar, &EcPoint)], + ) -> EcPoint { + assert!(!pairs.is_empty(), "multi_scalar_multiplication: pairs is empty"); let loader = &pairs[0].0.loader; let (constant, fixed_base, variable_base_non_scaled, variable_base_scaled) = @@ -625,11 +573,7 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> EcPointLoader .into_iter() .map(|(scalar, base)| (scalar.assigned(), base)) .collect_vec(); - loader - .ecc_chip - .borrow_mut() - .fixed_base_msm(&mut loader.ctx_mut(), &fixed_base) - .unwrap() + loader.ecc_chip.borrow_mut().fixed_base_msm(&mut loader.ctx_mut(), &fixed_base) }) .map(RefCell::new); let variable_base_msm = (!variable_base_scaled.is_empty()) @@ -642,37 +586,21 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> EcPointLoader .ecc_chip .borrow_mut() .variable_base_msm(&mut loader.ctx_mut(), &variable_base_scaled) - .unwrap() }) .map(RefCell::new); - let output = loader - .ecc_chip() - .sum_with_const( - &mut loader.ctx_mut(), - &variable_base_non_scaled - .into_iter() - .map(EcPoint::assigned) - .chain(fixed_base_msm.as_ref().map(RefCell::borrow)) - .chain(variable_base_msm.as_ref().map(RefCell::borrow)) - .collect_vec(), - constant, - ) - .unwrap(); + let output = loader.ecc_chip().sum_with_const( + &mut loader.ctx_mut(), + &variable_base_non_scaled + .into_iter() + .map(EcPoint::assigned) + .chain(fixed_base_msm.as_ref().map(RefCell::borrow)) + .chain(variable_base_msm.as_ref().map(RefCell::borrow)) + .collect_vec(), + constant, + ); loader.ec_point_from_assigned(output) } } -impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Loader - for Rc> -{ - #[cfg(test)] - fn start_cost_metering(&self, identifier: &str) { - self.start_row_metering(identifier) - } - - #[cfg(test)] - fn end_cost_metering(&self) { - self.end_row_metering() - } -} +impl> Loader for Rc> {} diff --git a/snark-verifier/src/loader/halo2/shim.rs b/snark-verifier/src/loader/halo2/shim.rs index 588e9482..790c9e22 100644 --- a/snark-verifier/src/loader/halo2/shim.rs +++ b/snark-verifier/src/loader/halo2/shim.rs @@ -1,40 +1,34 @@ -use crate::halo2_proofs::{ - circuit::{Cell, Value}, - plonk::Error, -}; use crate::util::arithmetic::{CurveAffine, FieldExt}; use std::{fmt::Debug, ops::Deref}; -pub trait Context: Debug { - fn constrain_equal(&mut self, lhs: Cell, rhs: Cell) -> Result<(), Error>; - - fn offset(&self) -> usize; -} - -pub trait IntegerInstructions<'a, F: FieldExt>: Clone + Debug { - type Context: Context; +/// Instructions to handle field element operations. +pub trait IntegerInstructions: Clone + Debug { + /// Context (either enhanced `region` or some kind of builder). + type Context: Debug; + /// Assigned cell. type AssignedCell: Clone + Debug; + /// Assigned integer. type AssignedInteger: Clone + Debug; + /// Assign an integer witness. fn assign_integer( &self, ctx: &mut Self::Context, - integer: Value, - ) -> Result; + integer: F, // witness + ) -> Self::AssignedInteger; - fn assign_constant( - &self, - ctx: &mut Self::Context, - integer: F, - ) -> Result; + /// Assign an integer constant. + fn assign_constant(&self, ctx: &mut Self::Context, integer: F) -> Self::AssignedInteger; + /// Sum integers with coefficients and constant. fn sum_with_coeff_and_const( &self, ctx: &mut Self::Context, values: &[(F::Scalar, impl Deref)], constant: F::Scalar, - ) -> Result; + ) -> Self::AssignedInteger; + /// Sum product of integers with coefficients and constant. fn sum_products_with_coeff_and_const( &self, ctx: &mut Self::Context, @@ -44,75 +38,78 @@ pub trait IntegerInstructions<'a, F: FieldExt>: Clone + Debug { impl Deref, )], constant: F::Scalar, - ) -> Result; + ) -> Self::AssignedInteger; + /// Returns `lhs - rhs`. fn sub( &self, ctx: &mut Self::Context, lhs: &Self::AssignedInteger, rhs: &Self::AssignedInteger, - ) -> Result; + ) -> Self::AssignedInteger; - fn neg( - &self, - ctx: &mut Self::Context, - value: &Self::AssignedInteger, - ) -> Result; + /// Returns `-value`. + fn neg(&self, ctx: &mut Self::Context, value: &Self::AssignedInteger) -> Self::AssignedInteger; + /// Returns `1/value`. fn invert( &self, ctx: &mut Self::Context, value: &Self::AssignedInteger, - ) -> Result; + ) -> Self::AssignedInteger; + /// Enforce `lhs` and `rhs` are equal. fn assert_equal( &self, ctx: &mut Self::Context, lhs: &Self::AssignedInteger, rhs: &Self::AssignedInteger, - ) -> Result<(), Error>; + ); } -pub trait EccInstructions<'a, C: CurveAffine>: Clone + Debug { - type Context: Context; +/// Instructions to handle elliptic curve point operations. +pub trait EccInstructions: Clone + Debug { + /// Context + type Context: Debug + Default; + /// [`IntegerInstructions`] to handle scalar field operation. type ScalarChip: IntegerInstructions< - 'a, C::Scalar, Context = Self::Context, AssignedCell = Self::AssignedCell, AssignedInteger = Self::AssignedScalar, >; + /// Assigned cell. type AssignedCell: Clone + Debug; + /// Assigned scalar field element. type AssignedScalar: Clone + Debug; + /// Assigned elliptic curve point. type AssignedEcPoint: Clone + Debug; + /// Returns reference of [`EccInstructions::ScalarChip`]. fn scalar_chip(&self) -> &Self::ScalarChip; - fn assign_constant( - &self, - ctx: &mut Self::Context, - ec_point: C, - ) -> Result; + /// Assign a elliptic curve point constant. + fn assign_constant(&self, ctx: &mut Self::Context, ec_point: C) -> Self::AssignedEcPoint; - fn assign_point( - &self, - ctx: &mut Self::Context, - ec_point: Value, - ) -> Result; + /// Assign a elliptic curve point witness. + fn assign_point(&self, ctx: &mut Self::Context, ec_point: C) -> Self::AssignedEcPoint; + /// Sum elliptic curve points and constant. fn sum_with_const( &self, ctx: &mut Self::Context, values: &[impl Deref], constant: C, - ) -> Result; + ) -> Self::AssignedEcPoint; + /// Perform fixed base multi-scalar multiplication. fn fixed_base_msm( &mut self, ctx: &mut Self::Context, pairs: &[(impl Deref, C)], - ) -> Result; + ) -> Self::AssignedEcPoint; + /// Perform variable base multi-scalar multiplication. fn variable_base_msm( &mut self, ctx: &mut Self::Context, @@ -120,76 +117,50 @@ pub trait EccInstructions<'a, C: CurveAffine>: Clone + Debug { impl Deref, impl Deref, )], - ) -> Result; + ) -> Self::AssignedEcPoint; + /// Enforce `lhs` and `rhs` are equal. fn assert_equal( &self, ctx: &mut Self::Context, lhs: &Self::AssignedEcPoint, rhs: &Self::AssignedEcPoint, - ) -> Result<(), Error>; + ); } mod halo2_lib { - use crate::halo2_proofs::{ - circuit::{Cell, Value}, - halo2curves::CurveAffineExt, - plonk::Error, - }; + use crate::halo2_proofs::halo2curves::CurveAffineExt; use crate::{ - loader::halo2::{Context, EccInstructions, IntegerInstructions}, - util::arithmetic::{CurveAffine, Field}, + loader::halo2::{EccInstructions, IntegerInstructions}, + util::arithmetic::CurveAffine, }; use halo2_base::{ self, - gates::{flex_gate::FlexGateConfig, GateInstructions, RangeInstructions}, - utils::PrimeField, + gates::{builder::GateThreadBuilder, GateChip, GateInstructions, RangeInstructions}, AssignedValue, - QuantumCell::{Constant, Existing, Witness}, + QuantumCell::{Constant, Existing}, }; + use halo2_ecc::bigint::ProperCrtUint; use halo2_ecc::{ - bigint::CRTInteger, - ecc::{fixed_base::FixedEcPoint, BaseFieldEccChip, EcPoint}, - fields::FieldChip, + ecc::{BaseFieldEccChip, EcPoint}, + fields::{FieldChip, PrimeField}, }; use std::ops::Deref; - type AssignedInteger<'v, C> = CRTInteger<'v, ::ScalarExt>; - type AssignedEcPoint<'v, C> = EcPoint<::ScalarExt, AssignedInteger<'v, C>>; - - impl<'a, F: PrimeField> Context for halo2_base::Context<'a, F> { - fn constrain_equal(&mut self, lhs: Cell, rhs: Cell) -> Result<(), Error> { - #[cfg(feature = "halo2-axiom")] - self.region.constrain_equal(&lhs, &rhs); - #[cfg(feature = "halo2-pse")] - self.region.constrain_equal(lhs, rhs); - Ok(()) - } - - fn offset(&self) -> usize { - unreachable!() - } - } + type AssignedInteger = ProperCrtUint<::ScalarExt>; + type AssignedEcPoint = EcPoint<::ScalarExt, AssignedInteger>; - impl<'a, F: PrimeField> IntegerInstructions<'a, F> for FlexGateConfig { - type Context = halo2_base::Context<'a, F>; - type AssignedCell = AssignedValue<'a, F>; - type AssignedInteger = AssignedValue<'a, F>; + impl IntegerInstructions for GateChip { + type Context = GateThreadBuilder; + type AssignedCell = AssignedValue; + type AssignedInteger = AssignedValue; - fn assign_integer( - &self, - ctx: &mut Self::Context, - integer: Value, - ) -> Result { - Ok(self.assign_region_last(ctx, vec![Witness(integer)], vec![])) + fn assign_integer(&self, ctx: &mut Self::Context, integer: F) -> Self::AssignedInteger { + ctx.main(0).load_witness(integer) } - fn assign_constant( - &self, - ctx: &mut Self::Context, - integer: F, - ) -> Result { - Ok(self.assign_region_last(ctx, vec![Constant(integer)], vec![])) + fn assign_constant(&self, ctx: &mut Self::Context, integer: F) -> Self::AssignedInteger { + ctx.main(0).load_constant(integer) } fn sum_with_coeff_and_const( @@ -197,16 +168,16 @@ mod halo2_lib { ctx: &mut Self::Context, values: &[(F::Scalar, impl Deref)], constant: F, - ) -> Result { + ) -> Self::AssignedInteger { let mut a = Vec::with_capacity(values.len() + 1); let mut b = Vec::with_capacity(values.len() + 1); if constant != F::zero() { a.push(Constant(constant)); b.push(Constant(F::one())); } - a.extend(values.iter().map(|(_, a)| Existing(a))); + a.extend(values.iter().map(|(_, a)| Existing(*a.deref()))); b.extend(values.iter().map(|(c, _)| Constant(*c))); - Ok(self.inner_product(ctx, a, b)) + self.inner_product(ctx.main(0), a, b) } fn sum_products_with_coeff_and_const( @@ -218,14 +189,14 @@ mod halo2_lib { impl Deref, )], constant: F, - ) -> Result { + ) -> Self::AssignedInteger { match values.len() { - 0 => self.assign_constant(ctx, constant), - _ => Ok(self.sum_products_with_coeff_and_var( - ctx, - values.iter().map(|(c, a, b)| (*c, Existing(a), Existing(b))), + 0 => ctx.main(0).load_constant(constant), + _ => self.sum_products_with_coeff_and_var( + ctx.main(0), + values.iter().map(|(c, a, b)| (*c, Existing(*a.deref()), Existing(*b.deref()))), Constant(constant), - )), + ), } } @@ -234,27 +205,23 @@ mod halo2_lib { ctx: &mut Self::Context, a: &Self::AssignedInteger, b: &Self::AssignedInteger, - ) -> Result { - Ok(GateInstructions::sub(self, ctx, Existing(a), Existing(b))) + ) -> Self::AssignedInteger { + GateInstructions::sub(self, ctx.main(0), Existing(*a), Existing(*b)) } - fn neg( - &self, - ctx: &mut Self::Context, - a: &Self::AssignedInteger, - ) -> Result { - Ok(GateInstructions::neg(self, ctx, Existing(a))) + fn neg(&self, ctx: &mut Self::Context, a: &Self::AssignedInteger) -> Self::AssignedInteger { + GateInstructions::neg(self, ctx.main(0), Existing(*a)) } fn invert( &self, ctx: &mut Self::Context, a: &Self::AssignedInteger, - ) -> Result { + ) -> Self::AssignedInteger { // make sure scalar != 0 - let is_zero = self.is_zero(ctx, a); - self.assert_is_const(ctx, &is_zero, F::zero()); - Ok(GateInstructions::div_unsafe(self, ctx, Constant(F::one()), Existing(a))) + let is_zero = self.is_zero(ctx.main(0), *a); + self.assert_is_const(ctx.main(0), &is_zero, &F::zero()); + GateInstructions::div_unsafe(self, ctx.main(0), Constant(F::one()), Existing(*a)) } fn assert_equal( @@ -262,54 +229,32 @@ mod halo2_lib { ctx: &mut Self::Context, a: &Self::AssignedInteger, b: &Self::AssignedInteger, - ) -> Result<(), Error> { - ctx.region.constrain_equal(a.cell(), b.cell()); - Ok(()) + ) { + ctx.main(0).constrain_equal(a, b); } } - impl<'a, C: CurveAffineExt> EccInstructions<'a, C> for BaseFieldEccChip + impl<'chip, C: CurveAffineExt> EccInstructions for BaseFieldEccChip<'chip, C> where C::ScalarExt: PrimeField, C::Base: PrimeField, { - type Context = halo2_base::Context<'a, C::Scalar>; - type ScalarChip = FlexGateConfig; - type AssignedCell = AssignedValue<'a, C::Scalar>; - type AssignedScalar = AssignedValue<'a, C::Scalar>; - type AssignedEcPoint = AssignedEcPoint<'a, C>; + type Context = GateThreadBuilder; + type ScalarChip = GateChip; + type AssignedCell = AssignedValue; + type AssignedScalar = AssignedValue; + type AssignedEcPoint = AssignedEcPoint; fn scalar_chip(&self) -> &Self::ScalarChip { self.field_chip.range().gate() } - fn assign_constant( - &self, - ctx: &mut Self::Context, - point: C, - ) -> Result { - let fixed = FixedEcPoint::::from_curve( - point, - self.field_chip.num_limbs, - self.field_chip.limb_bits, - ); - Ok(FixedEcPoint::assign( - fixed, - self.field_chip(), - ctx, - self.field_chip().native_modulus(), - )) + fn assign_constant(&self, ctx: &mut Self::Context, point: C) -> Self::AssignedEcPoint { + self.assign_constant_point(ctx.main(0), point) } - fn assign_point( - &self, - ctx: &mut Self::Context, - point: Value, - ) -> Result { - let assigned = self.assign_point(ctx, point); - let is_valid = self.is_on_curve_or_infinity::(ctx, &assigned); - self.field_chip.range.gate.assert_is_const(ctx, &is_valid, C::Scalar::one()); - Ok(assigned) + fn assign_point(&self, ctx: &mut Self::Context, point: C) -> Self::AssignedEcPoint { + self.assign_point(ctx.main(0), point) } fn sum_with_const( @@ -317,64 +262,62 @@ mod halo2_lib { ctx: &mut Self::Context, values: &[impl Deref], constant: C, - ) -> Result { + ) -> Self::AssignedEcPoint { let constant = if bool::from(constant.is_identity()) { None } else { - let constant = EccInstructions::::assign_constant(self, ctx, constant).unwrap(); + let constant = EccInstructions::assign_constant(self, ctx, constant); Some(constant) }; - Ok(self.sum::(ctx, constant.iter().chain(values.iter().map(Deref::deref)))) + self.sum::( + ctx.main(0), + constant.into_iter().chain(values.iter().map(|v| v.deref().clone())), + ) } fn variable_base_msm( &mut self, - ctx: &mut Self::Context, + builder: &mut Self::Context, pairs: &[( impl Deref, impl Deref, )], - ) -> Result { + ) -> Self::AssignedEcPoint { let (scalars, points): (Vec<_>, Vec<_>) = pairs .iter() - .map(|(scalar, point)| (vec![scalar.deref().clone()], point.deref().clone())) + .map(|(scalar, point)| (vec![*scalar.deref()], point.deref().clone())) .unzip(); - - Ok(BaseFieldEccChip::::variable_base_msm::( + BaseFieldEccChip::::variable_base_msm::( self, - ctx, + builder, &points, - &scalars, + scalars, C::Scalar::NUM_BITS as usize, - 4, // empirically clump factor of 4 seems to be best - )) + ) } fn fixed_base_msm( &mut self, - ctx: &mut Self::Context, + builder: &mut Self::Context, pairs: &[(impl Deref, C)], - ) -> Result { + ) -> Self::AssignedEcPoint { let (scalars, points): (Vec<_>, Vec<_>) = pairs .iter() .filter_map(|(scalar, point)| { if point.is_identity().into() { None } else { - Some((vec![scalar.deref().clone()], *point)) + Some((vec![*scalar.deref()], *point)) } }) .unzip(); - - Ok(BaseFieldEccChip::::fixed_base_msm::( + BaseFieldEccChip::::fixed_base_msm::( self, - ctx, + builder, &points, - &scalars, + scalars, C::Scalar::NUM_BITS as usize, - 0, - 4, - )) + ) } fn assert_equal( @@ -382,322 +325,8 @@ mod halo2_lib { ctx: &mut Self::Context, a: &Self::AssignedEcPoint, b: &Self::AssignedEcPoint, - ) -> Result<(), Error> { - self.assert_equal(ctx, a, b); - Ok(()) - } - } -} - -/* -mod halo2_wrong { - use crate::{ - loader::halo2::{Context, EccInstructions, IntegerInstructions}, - util::{ - arithmetic::{CurveAffine, FieldExt, Group}, - Itertools, - }, - }; - use halo2_proofs::{ - circuit::{AssignedCell, Cell, Value}, - plonk::Error, - }; - use halo2_wrong_ecc::{ - integer::rns::Common, - maingate::{ - CombinationOption, CombinationOptionCommon, MainGate, MainGateInstructions, RegionCtx, - Term, - }, - AssignedPoint, BaseFieldEccChip, - }; - use rand::rngs::OsRng; - use std::{iter, ops::Deref}; - - impl<'a, F: FieldExt> Context for RegionCtx<'a, F> { - fn constrain_equal(&mut self, lhs: Cell, rhs: Cell) -> Result<(), Error> { - self.constrain_equal(lhs, rhs) - } - - fn offset(&self) -> usize { - self.offset() - } - } - - impl<'a, F: FieldExt> IntegerInstructions<'a, F> for MainGate { - type Context = RegionCtx<'a, F>; - type AssignedCell = AssignedCell; - type AssignedInteger = AssignedCell; - - fn assign_integer( - &self, - ctx: &mut Self::Context, - integer: Value, - ) -> Result { - self.assign_value(ctx, integer) - } - - fn assign_constant( - &self, - ctx: &mut Self::Context, - integer: F, - ) -> Result { - MainGateInstructions::assign_constant(self, ctx, integer) - } - - fn sum_with_coeff_and_const( - &self, - ctx: &mut Self::Context, - values: &[(F, impl Deref)], - constant: F, - ) -> Result { - self.compose( - ctx, - &values - .iter() - .map(|(coeff, assigned)| Term::Assigned(assigned, *coeff)) - .collect_vec(), - constant, - ) - } - - fn sum_products_with_coeff_and_const( - &self, - ctx: &mut Self::Context, - values: &[( - F, - impl Deref, - impl Deref, - )], - constant: F, - ) -> Result { - match values.len() { - 0 => MainGateInstructions::assign_constant(self, ctx, constant), - 1 => { - let (scalar, lhs, rhs) = &values[0]; - let output = lhs - .value() - .zip(rhs.value()) - .map(|(lhs, rhs)| *scalar * lhs * rhs + constant); - - Ok(self - .apply( - ctx, - [ - Term::Zero, - Term::Zero, - Term::assigned_to_mul(lhs), - Term::assigned_to_mul(rhs), - Term::unassigned_to_sub(output), - ], - constant, - CombinationOption::OneLinerDoubleMul(*scalar), - )? - .swap_remove(4)) - } - _ => { - let (scalar, lhs, rhs) = &values[0]; - self.apply( - ctx, - [Term::assigned_to_mul(lhs), Term::assigned_to_mul(rhs)], - constant, - CombinationOptionCommon::CombineToNextScaleMul(-F::one(), *scalar).into(), - )?; - let acc = - Value::known(*scalar) * lhs.value() * rhs.value() + Value::known(constant); - let output = values.iter().skip(1).fold( - Ok::<_, Error>(acc), - |acc, (scalar, lhs, rhs)| { - acc.and_then(|acc| { - self.apply( - ctx, - [ - Term::assigned_to_mul(lhs), - Term::assigned_to_mul(rhs), - Term::Zero, - Term::Zero, - Term::Unassigned(acc, F::one()), - ], - F::zero(), - CombinationOptionCommon::CombineToNextScaleMul( - -F::one(), - *scalar, - ) - .into(), - )?; - Ok(acc + Value::known(*scalar) * lhs.value() * rhs.value()) - }) - }, - )?; - self.apply( - ctx, - [ - Term::Zero, - Term::Zero, - Term::Zero, - Term::Zero, - Term::Unassigned(output, F::zero()), - ], - F::zero(), - CombinationOptionCommon::OneLinerAdd.into(), - ) - .map(|mut outputs| outputs.swap_remove(4)) - } - } - } - - fn sub( - &self, - ctx: &mut Self::Context, - lhs: &Self::AssignedInteger, - rhs: &Self::AssignedInteger, - ) -> Result { - MainGateInstructions::sub(self, ctx, lhs, rhs) - } - - fn neg( - &self, - ctx: &mut Self::Context, - value: &Self::AssignedInteger, - ) -> Result { - MainGateInstructions::neg_with_constant(self, ctx, value, F::zero()) - } - - fn invert( - &self, - ctx: &mut Self::Context, - value: &Self::AssignedInteger, - ) -> Result { - MainGateInstructions::invert_unsafe(self, ctx, value) - } - - fn assert_equal( - &self, - ctx: &mut Self::Context, - lhs: &Self::AssignedInteger, - rhs: &Self::AssignedInteger, - ) -> Result<(), Error> { - let mut eq = true; - lhs.value().zip(rhs.value()).map(|(lhs, rhs)| { - eq &= lhs == rhs; - }); - MainGateInstructions::assert_equal(self, ctx, lhs, rhs) - .and(eq.then_some(()).ok_or(Error::Synthesis)) - } - } - - impl<'a, C: CurveAffine, const LIMBS: usize, const BITS: usize> EccInstructions<'a, C> - for BaseFieldEccChip - { - type Context = RegionCtx<'a, C::Scalar>; - type ScalarChip = MainGate; - type AssignedCell = AssignedCell; - type AssignedScalar = AssignedCell; - type AssignedEcPoint = AssignedPoint; - - fn scalar_chip(&self) -> &Self::ScalarChip { - self.main_gate() - } - - fn assign_constant( - &self, - ctx: &mut Self::Context, - ec_point: C, - ) -> Result { - self.assign_constant(ctx, ec_point) - } - - fn assign_point( - &self, - ctx: &mut Self::Context, - ec_point: Value, - ) -> Result { - self.assign_point(ctx, ec_point) - } - - fn sum_with_const( - &self, - ctx: &mut Self::Context, - values: &[impl Deref], - constant: C, - ) -> Result { - if values.is_empty() { - return self.assign_constant(ctx, constant); - } - - let constant = (!bool::from(constant.is_identity())) - .then(|| self.assign_constant(ctx, constant)) - .transpose()?; - let output = iter::empty() - .chain(constant) - .chain(values.iter().map(|value| value.deref().clone())) - .map(Ok) - .reduce(|acc, ec_point| self.add(ctx, &acc?, &ec_point?)) - .unwrap()?; - self.normalize(ctx, &output) - } - - fn fixed_base_msm( - &mut self, - ctx: &mut Self::Context, - pairs: &[(impl Deref, C)], - ) -> Result { - assert!(!pairs.is_empty()); - - // FIXME: Implement fixed base MSM in halo2_wrong - let pairs = pairs - .iter() - .filter(|(_, base)| !bool::from(base.is_identity())) - .map(|(scalar, base)| { - Ok::<_, Error>((scalar.deref().clone(), self.assign_constant(ctx, *base)?)) - }) - .collect::, _>>()?; - let pairs = pairs.iter().map(|(scalar, base)| (scalar, base)).collect_vec(); - self.variable_base_msm(ctx, &pairs) - } - - fn variable_base_msm( - &mut self, - ctx: &mut Self::Context, - pairs: &[( - impl Deref, - impl Deref, - )], - ) -> Result { - assert!(!pairs.is_empty()); - - const WINDOW_SIZE: usize = 3; - let pairs = pairs - .iter() - .map(|(scalar, base)| (base.deref().clone(), scalar.deref().clone())) - .collect_vec(); - let output = match self.mul_batch_1d_horizontal(ctx, pairs.clone(), WINDOW_SIZE) { - Err(_) => { - if self.assign_aux(ctx, WINDOW_SIZE, pairs.len()).is_err() { - let aux_generator = Value::known(C::Curve::random(OsRng).into()); - self.assign_aux_generator(ctx, aux_generator)?; - self.assign_aux(ctx, WINDOW_SIZE, pairs.len())?; - } - self.mul_batch_1d_horizontal(ctx, pairs, WINDOW_SIZE) - } - result => result, - }?; - self.normalize(ctx, &output) - } - - fn assert_equal( - &self, - ctx: &mut Self::Context, - lhs: &Self::AssignedEcPoint, - rhs: &Self::AssignedEcPoint, - ) -> Result<(), Error> { - let mut eq = true; - [(lhs.x(), rhs.x()), (lhs.y(), rhs.y())].map(|(lhs, rhs)| { - lhs.integer().zip(rhs.integer()).map(|(lhs, rhs)| { - eq &= lhs.value() == rhs.value(); - }); - }); - self.assert_equal(ctx, lhs, rhs).and(eq.then_some(()).ok_or(Error::Synthesis)) + ) { + self.assert_equal(ctx.main(0), a.clone(), b.clone()); } } } -*/ diff --git a/snark-verifier/src/loader/halo2/test.rs b/snark-verifier/src/loader/halo2/test.rs deleted file mode 100644 index 96de6747..00000000 --- a/snark-verifier/src/loader/halo2/test.rs +++ /dev/null @@ -1,62 +0,0 @@ -use crate::halo2_proofs::circuit::Value; -use crate::{ - util::{arithmetic::CurveAffine, Itertools}, - Protocol, -}; - -#[derive(Clone, Debug)] -pub struct Snark { - pub protocol: Protocol, - pub instances: Vec>, - pub proof: Vec, -} - -impl Snark { - pub fn new(protocol: Protocol, instances: Vec>, proof: Vec) -> Self { - assert_eq!( - protocol.num_instance, - instances.iter().map(|instances| instances.len()).collect_vec() - ); - Snark { protocol, instances, proof } - } -} - -#[derive(Clone, Debug)] -pub struct SnarkWitness { - pub protocol: Protocol, - pub instances: Vec>>, - pub proof: Value>, -} - -impl From> for SnarkWitness { - fn from(snark: Snark) -> Self { - Self { - protocol: snark.protocol, - instances: snark - .instances - .into_iter() - .map(|instances| instances.into_iter().map(Value::known).collect_vec()) - .collect(), - proof: Value::known(snark.proof), - } - } -} - -impl SnarkWitness { - pub fn new_without_witness(protocol: Protocol) -> Self { - let instances = protocol - .num_instance - .iter() - .map(|num_instance| vec![Value::unknown(); *num_instance]) - .collect(); - SnarkWitness { protocol, instances, proof: Value::unknown() } - } - - pub fn without_witnesses(&self) -> Self { - SnarkWitness::new_without_witness(self.protocol.clone()) - } - - pub fn proof(&self) -> Value<&[u8]> { - self.proof.as_ref().map(Vec::as_slice) - } -} diff --git a/snark-verifier/src/loader/native.rs b/snark-verifier/src/loader/native.rs index 6fce383a..783aaa89 100644 --- a/snark-verifier/src/loader/native.rs +++ b/snark-verifier/src/loader/native.rs @@ -1,3 +1,5 @@ +//! `Loader` implementation in native rust. + use crate::{ loader::{EcPointLoader, LoadedEcPoint, LoadedScalar, Loader, ScalarLoader}, util::arithmetic::{Curve, CurveAffine, FieldOps, PrimeField}, @@ -7,9 +9,12 @@ use lazy_static::lazy_static; use std::fmt::Debug; lazy_static! { + /// NativeLoader instance for [`LoadedEcPoint::loader`] and + /// [`LoadedScalar::loader`] referencing. pub static ref LOADER: NativeLoader = NativeLoader; } +/// `Loader` implementation in native rust. #[derive(Clone, Debug)] pub struct NativeLoader; @@ -47,10 +52,10 @@ impl EcPointLoader for NativeLoader { annotation: &str, lhs: &Self::LoadedEcPoint, rhs: &Self::LoadedEcPoint, - ) -> Result<(), Error> { + ) { lhs.eq(rhs) .then_some(()) - .ok_or_else(|| Error::AssertionFailure(annotation.to_string())) + .unwrap_or_else(|| panic!("{:?}", Error::AssertionFailure(annotation.to_string()))) } fn multi_scalar_multiplication( @@ -61,7 +66,7 @@ impl EcPointLoader for NativeLoader { .cloned() .map(|(scalar, base)| *base * scalar) .reduce(|acc, value| acc + value) - .unwrap() + .expect("pairs should not be empty") .to_affine() } } @@ -73,15 +78,10 @@ impl ScalarLoader for NativeLoader { *value } - fn assert_eq( - &self, - annotation: &str, - lhs: &Self::LoadedScalar, - rhs: &Self::LoadedScalar, - ) -> Result<(), Error> { + fn assert_eq(&self, annotation: &str, lhs: &Self::LoadedScalar, rhs: &Self::LoadedScalar) { lhs.eq(rhs) .then_some(()) - .ok_or_else(|| Error::AssertionFailure(annotation.to_string())) + .unwrap_or_else(|| panic!("{:?}", Error::AssertionFailure(annotation.to_string()))) } } diff --git a/snark-verifier/src/pcs.rs b/snark-verifier/src/pcs.rs index 73aa62ce..65b1325b 100644 --- a/snark-verifier/src/pcs.rs +++ b/snark-verifier/src/pcs.rs @@ -1,3 +1,5 @@ +//! Verifiers for polynomial commitment schemes. + use crate::{ loader::{native::NativeLoader, Loader}, util::{ @@ -8,128 +10,167 @@ use crate::{ Error, }; use rand::Rng; -use std::fmt::Debug; +use std::{fmt::Debug, marker::PhantomData}; -// pub mod ipa; +pub mod ipa; pub mod kzg; -pub trait PolynomialCommitmentScheme: Clone + Debug -where - C: CurveAffine, - L: Loader, -{ - type Accumulator: Clone + Debug; -} - +/// Query to an oracle. +/// It assumes all queries are based on the same point, but with some `shift`. #[derive(Clone, Debug)] pub struct Query { + /// Index of polynomial to query pub poly: usize, + /// Shift of the query point. pub shift: F, + /// Evaluation read from transcript. pub eval: T, } impl Query { + /// Initialize [`Query`] without evaluation. + pub fn new(poly: usize, shift: F) -> Self { + Self { poly, shift, eval: () } + } + + /// Returns [`Query`] with evaluation. pub fn with_evaluation(self, eval: T) -> Query { Query { poly: self.poly, shift: self.shift, eval } } } -pub trait MultiOpenScheme: PolynomialCommitmentScheme +/// Polynomial commitment scheme verifier. +pub trait PolynomialCommitmentScheme: Clone + Debug where C: CurveAffine, L: Loader, { - type SuccinctVerifyingKey: Clone + Debug; + /// Verifying key. + type VerifyingKey: Clone + Debug; + /// Structured proof read from transcript. type Proof: Clone + Debug; + /// Output of verification. + type Output: Clone + Debug; + /// Read [`PolynomialCommitmentScheme::Proof`] from transcript. fn read_proof( - svk: &Self::SuccinctVerifyingKey, + vk: &Self::VerifyingKey, queries: &[Query], transcript: &mut T, - ) -> Self::Proof + ) -> Result where T: TranscriptRead; - fn succinct_verify( - svk: &Self::SuccinctVerifyingKey, + /// Verify [`PolynomialCommitmentScheme::Proof`] and output [`PolynomialCommitmentScheme::Output`]. + fn verify( + vk: &Self::VerifyingKey, commitments: &[Msm], point: &L::LoadedScalar, queries: &[Query], proof: &Self::Proof, - ) -> Self::Accumulator; + ) -> Result; } -pub trait Decider: PolynomialCommitmentScheme +/// Accumulation scheme verifier. +pub trait AccumulationScheme where C: CurveAffine, L: Loader, { - type DecidingKey: Clone + Debug; - type Output: Clone + Debug; - - fn decide(dk: &Self::DecidingKey, accumulator: Self::Accumulator) -> Self::Output; - - fn decide_all(dk: &Self::DecidingKey, accumulators: Vec) -> Self::Output; -} - -pub trait AccumulationScheme: Clone + Debug -where - C: CurveAffine, - L: Loader, - PCS: PolynomialCommitmentScheme, -{ + /// Accumulator to be accumulated. + type Accumulator: Clone + Debug; + /// Verifying key. type VerifyingKey: Clone + Debug; + /// Structured proof read from transcript. type Proof: Clone + Debug; + /// Read a [`AccumulationScheme::Proof`] from transcript. fn read_proof( vk: &Self::VerifyingKey, - instances: &[PCS::Accumulator], + instances: &[Self::Accumulator], transcript: &mut T, ) -> Result where T: TranscriptRead; + /// Verify old [`AccumulationScheme::Accumulator`]s are accumulated properly + /// into a new one with the [`AccumulationScheme::Proof`], and returns the + /// new one as output. fn verify( vk: &Self::VerifyingKey, - instances: &[PCS::Accumulator], + instances: &[Self::Accumulator], proof: &Self::Proof, - ) -> Result; + ) -> Result; } -pub trait AccumulationSchemeProver: AccumulationScheme +/// Accumulation scheme decider. +/// When accumulation is going to end, the decider will perform the check if the +/// final accumulator is valid or not, where the check is usually much more +/// expensive than accumulation verification. +pub trait AccumulationDecider: AccumulationScheme where C: CurveAffine, - PCS: PolynomialCommitmentScheme, + L: Loader, { + /// Deciding key. The key for decider for perform the final accumulator + /// check. + type DecidingKey: Clone + Debug; + + /// Decide if a [`AccumulationScheme::Accumulator`] is valid. + fn decide(dk: &Self::DecidingKey, accumulator: Self::Accumulator) -> Result<(), Error>; + + /// Decide if all [`AccumulationScheme::Accumulator`]s are valid. + fn decide_all( + dk: &Self::DecidingKey, + accumulators: Vec, + ) -> Result<(), Error>; +} + +/// Accumulation scheme prover. +pub trait AccumulationSchemeProver: AccumulationScheme +where + C: CurveAffine, +{ + /// Proving key. type ProvingKey: Clone + Debug; + /// Create a proof that argues if old [`AccumulationScheme::Accumulator`]s + /// are properly accumulated into the new one, and returns the new one as + /// output. fn create_proof( pk: &Self::ProvingKey, - instances: &[PCS::Accumulator], + instances: &[Self::Accumulator], transcript: &mut T, rng: R, - ) -> Result + ) -> Result where T: TranscriptWrite, R: Rng; } -pub trait AccumulatorEncoding: Clone + Debug +/// Accumulator encoding. +pub trait AccumulatorEncoding: Clone + Debug where C: CurveAffine, L: Loader, - PCS: PolynomialCommitmentScheme, { - fn from_repr(repr: &[&L::LoadedScalar]) -> Result; + /// Accumulator to be encoded. + type Accumulator: Clone + Debug; + + /// Decode an [`AccumulatorEncoding::Accumulator`] from serveral + /// [`crate::loader::ScalarLoader::LoadedScalar`]s. + fn from_repr(repr: &[&L::LoadedScalar]) -> Result; } -impl AccumulatorEncoding for () +impl AccumulatorEncoding for PhantomData where C: CurveAffine, L: Loader, PCS: PolynomialCommitmentScheme, { - fn from_repr(_: &[&L::LoadedScalar]) -> Result { + type Accumulator = PCS::Output; + + fn from_repr(_: &[&L::LoadedScalar]) -> Result { unimplemented!() } } diff --git a/snark-verifier/src/pcs/ipa.rs b/snark-verifier/src/pcs/ipa.rs index a2b34824..6358e15d 100644 --- a/snark-verifier/src/pcs/ipa.rs +++ b/snark-verifier/src/pcs/ipa.rs @@ -1,6 +1,8 @@ +//! Inner product argument polynomial commitment scheme and accumulation scheme. +//! The notations are following . + use crate::{ loader::{native::NativeLoader, LoadedScalar, Loader, ScalarLoader}, - pcs::PolynomialCommitmentScheme, util::{ arithmetic::{ inner_product, powers, Curve, CurveAffine, Domain, Field, Fraction, PrimeField, @@ -24,24 +26,17 @@ mod multiopen; pub use accumulation::{IpaAs, IpaAsProof}; pub use accumulator::IpaAccumulator; pub use decider::IpaDecidingKey; -pub use multiopen::{Bgh19, Bgh19Proof, Bgh19SuccinctVerifyingKey}; +pub use multiopen::{Bgh19, Bgh19Proof}; +/// Inner product argument polynomial commitment scheme. #[derive(Clone, Debug)] -pub struct Ipa(PhantomData<(C, MOS)>); - -impl PolynomialCommitmentScheme for Ipa -where - C: CurveAffine, - L: Loader, - MOS: Clone + Debug, -{ - type Accumulator = IpaAccumulator; -} +pub struct Ipa(PhantomData); -impl Ipa +impl Ipa where C: CurveAffine, { + /// Create an inner product argument. pub fn create_proof( pk: &IpaProvingKey, p: &[C::Scalar], @@ -127,6 +122,7 @@ where Ok(IpaAccumulator::new(xi, bases[0])) } + /// Read [`IpaProof`] from transcript. pub fn read_proof>( svk: &IpaSuccinctVerifyingKey, transcript: &mut T, @@ -137,6 +133,7 @@ where IpaProof::read(svk, transcript) } + /// Perform the succinct check of the proof and returns [`IpaAccumulator`]. pub fn succinct_verify>( svk: &IpaSuccinctVerifyingKey, commitment: &Msm, @@ -151,11 +148,8 @@ where let h_prime = h * &proof.xi_0; let lhs = { - let c_prime = match ( - s.as_ref(), - proof.c_bar_alpha.as_ref(), - proof.omega_prime.as_ref(), - ) { + let c_prime = match (s.as_ref(), proof.c_bar_alpha.as_ref(), proof.omega_prime.as_ref()) + { (Some(s), Some((c_bar, alpha)), Some(omega_prime)) => { let s = Msm::::base(s); commitment.clone() + Msm::base(c_bar) * alpha - s * omega_prime @@ -180,37 +174,47 @@ where (u * &proof.c + h_prime * &v_prime).evaluate(None) }; - loader.ec_point_assert_eq("C_k == c[U] + v'[H']", &lhs, &rhs)?; + loader.ec_point_assert_eq("C_k == c[U] + v'[H']", &lhs, &rhs); Ok(IpaAccumulator::new(proof.xi(), proof.u.clone())) } } +/// Inner product argument proving key. #[derive(Clone, Debug)] pub struct IpaProvingKey { + /// Working domain. pub domain: Domain, + /// $\mathbb{G}$ pub g: Vec, + /// $H$ pub h: C, + /// $S$ pub s: Option, } impl IpaProvingKey { + /// Initialize an [`IpaProvingKey`]. pub fn new(domain: Domain, g: Vec, h: C, s: Option) -> Self { Self { domain, g, h, s } } + /// Returns if it supports zero-knowledge. pub fn zk(&self) -> bool { self.s.is_some() } + /// Returns [`IpaSuccinctVerifyingKey`]. pub fn svk(&self) -> IpaSuccinctVerifyingKey { - IpaSuccinctVerifyingKey::new(self.domain.clone(), self.h, self.s) + IpaSuccinctVerifyingKey::new(self.domain.clone(), self.g[0], self.h, self.s) } + /// Returns [`IpaDecidingKey`]. pub fn dk(&self) -> IpaDecidingKey { - IpaDecidingKey::new(self.g.clone()) + IpaDecidingKey::new(self.svk(), self.g.clone()) } + /// Commit a polynomial into with a randomizer if any. pub fn commit(&self, poly: &Polynomial, omega: Option) -> C { let mut c = multi_scalar_multiplication(&poly[..], &self.g); match (self.s, omega) { @@ -224,15 +228,13 @@ impl IpaProvingKey { impl IpaProvingKey { #[cfg(test)] - pub fn rand(k: usize, zk: bool, mut rng: R) -> Self { + pub(crate) fn rand(k: usize, zk: bool, mut rng: R) -> Self { use crate::util::arithmetic::{root_of_unity, Group}; let domain = Domain::new(k, root_of_unity(k)); let mut g = vec![C::default(); 1 << k]; C::Curve::batch_normalize( - &iter::repeat_with(|| C::Curve::random(&mut rng)) - .take(1 << k) - .collect_vec(), + &iter::repeat_with(|| C::Curve::random(&mut rng)).take(1 << k).collect_vec(), &mut g, ); let h = C::Curve::random(&mut rng).to_affine(); @@ -241,23 +243,32 @@ impl IpaProvingKey { } } +/// Inner product argument succinct verifying key. #[derive(Clone, Debug)] pub struct IpaSuccinctVerifyingKey { + /// Working domain. pub domain: Domain, + /// $G_0$ + pub g: C, + /// $H$ pub h: C, + /// $S$ pub s: Option, } impl IpaSuccinctVerifyingKey { - pub fn new(domain: Domain, h: C, s: Option) -> Self { - Self { domain, h, s } + /// Initialize an [`IpaSuccinctVerifyingKey`]. + pub fn new(domain: Domain, g: C, h: C, s: Option) -> Self { + Self { domain, g, h, s } } + /// Returns if it supports zero-knowledge. pub fn zk(&self) -> bool { self.s.is_some() } } +/// Inner product argument #[derive(Clone, Debug)] pub struct IpaProof where @@ -277,7 +288,7 @@ where C: CurveAffine, L: Loader, { - pub fn new( + fn new( c_bar_alpha: Option<(L::LoadedEcPoint, L::LoadedScalar)>, omega_prime: Option, xi_0: L::LoadedScalar, @@ -285,16 +296,10 @@ where u: L::LoadedEcPoint, c: L::LoadedScalar, ) -> Self { - Self { - c_bar_alpha, - omega_prime, - xi_0, - rounds, - u, - c, - } + Self { c_bar_alpha, omega_prime, xi_0, rounds, u, c } } + /// Read [`crate::pcs::AccumulationScheme::Proof`] from transcript. pub fn read(svk: &IpaSuccinctVerifyingKey, transcript: &mut T) -> Result where T: TranscriptRead, @@ -320,33 +325,25 @@ where .collect::, _>>()?; let u = transcript.read_ec_point()?; let c = transcript.read_scalar()?; - Ok(Self { - c_bar_alpha, - omega_prime, - xi_0, - rounds, - u, - c, - }) + Ok(Self { c_bar_alpha, omega_prime, xi_0, rounds, u, c }) } + /// Returns $\{\xi_0, \xi_1, ...\}$. pub fn xi(&self) -> Vec { self.rounds.iter().map(|round| round.xi.clone()).collect() } + /// Returns $\{\xi_0^{-1}, \xi_1^{-1}, ...\}$. pub fn xi_inv(&self) -> Vec { let mut xi_inv = self.xi().into_iter().map(Fraction::one_over).collect_vec(); L::batch_invert(xi_inv.iter_mut().filter_map(Fraction::denom_mut)); xi_inv.iter_mut().for_each(Fraction::evaluate); - xi_inv - .into_iter() - .map(|xi_inv| xi_inv.evaluated().clone()) - .collect() + xi_inv.into_iter().map(|xi_inv| xi_inv.evaluated().clone()).collect() } } #[derive(Clone, Debug)] -pub struct Round +struct Round where C: CurveAffine, L: Loader, @@ -361,12 +358,12 @@ where C: CurveAffine, L: Loader, { - pub fn new(l: L::LoadedEcPoint, r: L::LoadedEcPoint, xi: L::LoadedScalar) -> Self { + fn new(l: L::LoadedEcPoint, r: L::LoadedEcPoint, xi: L::LoadedScalar) -> Self { Self { l, r, xi } } } -pub fn h_eval>(xi: &[T], z: &T) -> T { +fn h_eval>(xi: &[T], z: &T) -> T { let loader = z.loader(); let one = loader.load_one(); loader.product( @@ -379,7 +376,7 @@ pub fn h_eval>(xi: &[T], z: &T) -> T { ) } -pub fn h_coeffs(xi: &[F], scalar: F) -> Vec { +fn h_coeffs(xi: &[F], scalar: F) -> Vec { assert!(!xi.is_empty()); let mut coeffs = vec![F::zero(); 1 << xi.len()]; @@ -402,7 +399,7 @@ mod test { use crate::{ pcs::{ ipa::{self, IpaProvingKey}, - Decider, + AccumulationDecider, }, util::{arithmetic::Field, msm::Msm, poly::Polynomial}, }; @@ -414,7 +411,8 @@ mod test { #[test] fn test_ipa() { - type Ipa = ipa::Ipa; + type Ipa = ipa::Ipa; + type IpaAs = ipa::IpaAs; let k = 10; let mut rng = OsRng; @@ -441,7 +439,7 @@ mod test { }; let dk = pk.dk(); - assert!(Ipa::decide(&dk, accumulator)); + assert!(IpaAs::decide(&dk, accumulator).is_ok()); } } } diff --git a/snark-verifier/src/pcs/ipa/accumulation.rs b/snark-verifier/src/pcs/ipa/accumulation.rs index 07f294de..56d61aa7 100644 --- a/snark-verifier/src/pcs/ipa/accumulation.rs +++ b/snark-verifier/src/pcs/ipa/accumulation.rs @@ -4,7 +4,7 @@ use crate::{ ipa::{ h_coeffs, h_eval, Ipa, IpaAccumulator, IpaProof, IpaProvingKey, IpaSuccinctVerifyingKey, }, - AccumulationScheme, AccumulationSchemeProver, PolynomialCommitmentScheme, + AccumulationScheme, AccumulationSchemeProver, }, util::{ arithmetic::{Curve, CurveAffine, Field}, @@ -16,23 +16,26 @@ use crate::{ Error, }; use rand::Rng; -use std::{array, iter, marker::PhantomData}; +use std::{array, fmt::Debug, iter, marker::PhantomData}; +/// Inner product argument accumulation scheme. The second generic `MOS` stands +/// for different kind of multi-open scheme. #[derive(Clone, Debug)] -pub struct IpaAs(PhantomData); +pub struct IpaAs(PhantomData<(C, MOS)>); -impl AccumulationScheme for IpaAs +impl AccumulationScheme for IpaAs where C: CurveAffine, L: Loader, - PCS: PolynomialCommitmentScheme>, + MOS: Clone + Debug, { + type Accumulator = IpaAccumulator; type VerifyingKey = IpaSuccinctVerifyingKey; - type Proof = IpaAsProof; + type Proof = IpaAsProof; fn read_proof( vk: &Self::VerifyingKey, - instances: &[PCS::Accumulator], + instances: &[Self::Accumulator], transcript: &mut T, ) -> Result where @@ -43,9 +46,9 @@ where fn verify( vk: &Self::VerifyingKey, - instances: &[PCS::Accumulator], + instances: &[Self::Accumulator], proof: &Self::Proof, - ) -> Result { + ) -> Result { let loader = proof.z.loader(); let s = vk.s.as_ref().map(|s| loader.ec_point_load_const(s)); @@ -66,34 +69,32 @@ where } let v = loader.sum_products(&powers_of_alpha.iter().zip(h.iter()).collect_vec()); - Ipa::::succinct_verify(vk, &c, &proof.z, &v, &proof.ipa) + Ipa::succinct_verify(vk, &c, &proof.z, &v, &proof.ipa) } } +/// Inner product argument accumulation scheme proof. #[derive(Clone, Debug)] -pub struct IpaAsProof +pub struct IpaAsProof where C: CurveAffine, L: Loader, - PCS: PolynomialCommitmentScheme>, { a_b_u: Option<(L::LoadedScalar, L::LoadedScalar, L::LoadedEcPoint)>, omega: Option, alpha: L::LoadedScalar, z: L::LoadedScalar, ipa: IpaProof, - _marker: PhantomData, } -impl IpaAsProof +impl IpaAsProof where C: CurveAffine, L: Loader, - PCS: PolynomialCommitmentScheme>, { fn read( vk: &IpaSuccinctVerifyingKey, - instances: &[PCS::Accumulator], + instances: &[IpaAccumulator], transcript: &mut T, ) -> Result where @@ -130,23 +131,23 @@ where let ipa = IpaProof::read(vk, transcript)?; - Ok(Self { a_b_u, omega, alpha, z, ipa, _marker: PhantomData }) + Ok(Self { a_b_u, omega, alpha, z, ipa }) } } -impl AccumulationSchemeProver for IpaAs +impl AccumulationSchemeProver for IpaAs where C: CurveAffine, - PCS: PolynomialCommitmentScheme>, + MOS: Clone + Debug, { type ProvingKey = IpaProvingKey; fn create_proof( pk: &Self::ProvingKey, - instances: &[PCS::Accumulator], + instances: &[IpaAccumulator], transcript: &mut T, mut rng: R, - ) -> Result + ) -> Result, Error> where T: TranscriptWrite, R: Rng, @@ -204,7 +205,7 @@ where .map(|(power_of_alpha, h)| h * power_of_alpha) .sum::>(); - Ipa::::create_proof(pk, &h.to_vec(), &z, omega.as_ref(), transcript, &mut rng) + Ipa::create_proof(pk, &h.to_vec(), &z, omega.as_ref(), transcript, &mut rng) } } @@ -217,7 +218,7 @@ mod test { use crate::{ pcs::{ ipa::{self, IpaProvingKey}, - AccumulationScheme, AccumulationSchemeProver, Decider, + AccumulationDecider, AccumulationScheme, AccumulationSchemeProver, }, util::{arithmetic::Field, msm::Msm, poly::Polynomial, Itertools}, }; @@ -226,8 +227,8 @@ mod test { #[test] fn test_ipa_as() { - type Ipa = ipa::Ipa; - type IpaAs = ipa::IpaAs; + type Ipa = ipa::Ipa; + type IpaAs = ipa::IpaAs; let k = 10; let zk = true; @@ -274,6 +275,6 @@ mod test { }; let dk = pk.dk(); - assert!(Ipa::decide(&dk, accumulator)); + assert!(IpaAs::decide(&dk, accumulator).is_ok()); } } diff --git a/snark-verifier/src/pcs/ipa/accumulator.rs b/snark-verifier/src/pcs/ipa/accumulator.rs index 27d9d5c7..fc8d9fd7 100644 --- a/snark-verifier/src/pcs/ipa/accumulator.rs +++ b/snark-verifier/src/pcs/ipa/accumulator.rs @@ -1,12 +1,15 @@ use crate::{loader::Loader, util::arithmetic::CurveAffine}; +/// Inner product argument accumulator. #[derive(Clone, Debug)] pub struct IpaAccumulator where C: CurveAffine, L: Loader, { + /// $\xi$. pub xi: Vec, + /// $U$. pub u: L::LoadedEcPoint, } @@ -15,6 +18,7 @@ where C: CurveAffine, L: Loader, { + /// Initialize a [`IpaAccumulator`]. pub fn new(xi: Vec, u: L::LoadedEcPoint) -> Self { Self { xi, u } } diff --git a/snark-verifier/src/pcs/ipa/decider.rs b/snark-verifier/src/pcs/ipa/decider.rs index 2cf8c6cc..5235a857 100644 --- a/snark-verifier/src/pcs/ipa/decider.rs +++ b/snark-verifier/src/pcs/ipa/decider.rs @@ -1,17 +1,23 @@ +use crate::{pcs::ipa::IpaSuccinctVerifyingKey, util::arithmetic::CurveAffine}; + +/// Inner product argument deciding key. #[derive(Clone, Debug)] -pub struct IpaDecidingKey { - pub g: Vec, +pub struct IpaDecidingKey { + svk: IpaSuccinctVerifyingKey, + /// Committing key. + g: Vec, } -impl IpaDecidingKey { - pub fn new(g: Vec) -> Self { - Self { g } +impl IpaDecidingKey { + /// Initialize an [`IpaDecidingKey`]. + pub fn new(svk: IpaSuccinctVerifyingKey, g: Vec) -> Self { + Self { svk, g } } } -impl From> for IpaDecidingKey { - fn from(g: Vec) -> IpaDecidingKey { - IpaDecidingKey::new(g) +impl AsRef> for IpaDecidingKey { + fn as_ref(&self) -> &IpaSuccinctVerifyingKey { + &self.svk } } @@ -19,39 +25,45 @@ mod native { use crate::{ loader::native::NativeLoader, pcs::{ - ipa::{h_coeffs, Ipa, IpaAccumulator, IpaDecidingKey}, - Decider, + ipa::{h_coeffs, IpaAccumulator, IpaAs, IpaDecidingKey}, + AccumulationDecider, }, util::{ arithmetic::{Curve, CurveAffine, Field}, msm::multi_scalar_multiplication, + Itertools, }, + Error, }; use std::fmt::Debug; - impl Decider for Ipa + impl AccumulationDecider for IpaAs where C: CurveAffine, MOS: Clone + Debug, { type DecidingKey = IpaDecidingKey; - type Output = bool; fn decide( dk: &Self::DecidingKey, IpaAccumulator { u, xi }: IpaAccumulator, - ) -> bool { + ) -> Result<(), Error> { let h = h_coeffs(&xi, C::Scalar::one()); - u == multi_scalar_multiplication(&h, &dk.g).to_affine() + (u == multi_scalar_multiplication(&h, &dk.g).to_affine()) + .then_some(()) + .ok_or_else(|| Error::AssertionFailure("U == commit(G, h)".to_string())) } fn decide_all( dk: &Self::DecidingKey, accumulators: Vec>, - ) -> bool { - !accumulators + ) -> Result<(), Error> { + assert!(!accumulators.is_empty()); + accumulators .into_iter() - .any(|accumulator| !Self::decide(dk, accumulator)) + .map(|accumulator| Self::decide(dk, accumulator)) + .try_collect::<_, Vec<_>, _>()?; + Ok(()) } } } diff --git a/snark-verifier/src/pcs/ipa/multiopen.rs b/snark-verifier/src/pcs/ipa/multiopen.rs index 9f685e76..99b0a565 100644 --- a/snark-verifier/src/pcs/ipa/multiopen.rs +++ b/snark-verifier/src/pcs/ipa/multiopen.rs @@ -1,3 +1,3 @@ mod bgh19; -pub use bgh19::{Bgh19, Bgh19Proof, Bgh19SuccinctVerifyingKey}; +pub use bgh19::{Bgh19, Bgh19Proof}; diff --git a/snark-verifier/src/pcs/ipa/multiopen/bgh19.rs b/snark-verifier/src/pcs/ipa/multiopen/bgh19.rs index 29d291ad..cae77a5f 100644 --- a/snark-verifier/src/pcs/ipa/multiopen/bgh19.rs +++ b/snark-verifier/src/pcs/ipa/multiopen/bgh19.rs @@ -1,11 +1,11 @@ use crate::{ loader::{LoadedScalar, Loader, ScalarLoader}, pcs::{ - ipa::{Ipa, IpaProof, IpaSuccinctVerifyingKey, Round}, - MultiOpenScheme, Query, + ipa::{Ipa, IpaAccumulator, IpaAs, IpaProof, IpaSuccinctVerifyingKey, Round}, + PolynomialCommitmentScheme, Query, }, util::{ - arithmetic::{ilog2, CurveAffine, Domain, FieldExt, Fraction}, + arithmetic::{CurveAffine, FieldExt, Fraction}, msm::Msm, transcript::TranscriptRead, Itertools, @@ -18,19 +18,23 @@ use std::{ marker::PhantomData, }; +/// Verifier of multi-open inner product argument. It is for the implementation +/// in [`halo2_proofs`], which is previously +/// . #[derive(Clone, Debug)] pub struct Bgh19; -impl MultiOpenScheme for Ipa +impl PolynomialCommitmentScheme for IpaAs where C: CurveAffine, L: Loader, { - type SuccinctVerifyingKey = Bgh19SuccinctVerifyingKey; + type VerifyingKey = IpaSuccinctVerifyingKey; type Proof = Bgh19Proof; + type Output = IpaAccumulator; fn read_proof( - svk: &Self::SuccinctVerifyingKey, + svk: &Self::VerifyingKey, queries: &[Query], transcript: &mut T, ) -> Result @@ -40,13 +44,13 @@ where Bgh19Proof::read(svk, queries, transcript) } - fn succinct_verify( - svk: &Self::SuccinctVerifyingKey, + fn verify( + svk: &Self::VerifyingKey, commitments: &[Msm], x: &L::LoadedScalar, queries: &[Query], proof: &Self::Proof, - ) -> Result { + ) -> Result { let loader = x.loader(); let g = loader.ec_point_load_const(&svk.g); @@ -55,9 +59,8 @@ where let p = { let coeffs = query_set_coeffs(&sets, x, &proof.x_3); - let powers_of_x_1 = proof - .x_1 - .powers(sets.iter().map(|set| set.polys.len()).max().unwrap()); + let powers_of_x_1 = + proof.x_1.powers(sets.iter().map(|set| set.polys.len()).max().unwrap()); let f_eval = { let powers_of_x_2 = proof.x_2.powers(sets.len()); let f_evals = sets @@ -87,25 +90,11 @@ where }; // IPA - Ipa::::succinct_verify(&svk.ipa, &p, &proof.x_3, &loader.load_zero(), &proof.ipa) - } -} - -#[derive(Clone, Debug)] -pub struct Bgh19SuccinctVerifyingKey { - g: C, - ipa: IpaSuccinctVerifyingKey, -} - -impl Bgh19SuccinctVerifyingKey { - pub fn new(domain: Domain, g: C, w: C, u: C) -> Self { - Self { - g, - ipa: IpaSuccinctVerifyingKey::new(domain, u, Some(w)), - } + Ipa::succinct_verify(svk, &p, &proof.x_3, &loader.load_zero(), &proof.ipa) } } +/// Structured proof of [`Bgh19`]. #[derive(Clone, Debug)] pub struct Bgh19Proof where @@ -129,7 +118,7 @@ where L: Loader, { fn read>( - svk: &Bgh19SuccinctVerifyingKey, + svk: &IpaSuccinctVerifyingKey, queries: &[Query], transcript: &mut T, ) -> Result { @@ -151,7 +140,7 @@ where transcript.squeeze_challenge(), )) }) - .take(svk.ipa.domain.k) + .take(svk.domain.k) .collect::, _>>()?; let c = transcript.read_scalar()?; let blind = transcript.read_scalar()?; @@ -173,13 +162,9 @@ where F: FieldExt, T: Clone, { - let poly_shifts = queries.iter().fold( - Vec::<(usize, Vec, Vec<&T>)>::new(), - |mut poly_shifts, query| { - if let Some(pos) = poly_shifts - .iter() - .position(|(poly, _, _)| *poly == query.poly) - { + let poly_shifts = + queries.iter().fold(Vec::<(usize, Vec, Vec<&T>)>::new(), |mut poly_shifts, query| { + if let Some(pos) = poly_shifts.iter().position(|(poly, _, _)| *poly == query.poly) { let (_, shifts, evals) = &mut poly_shifts[pos]; if !shifts.contains(&query.shift) { shifts.push(query.shift); @@ -189,39 +174,31 @@ where poly_shifts.push((query.poly, vec![query.shift], vec![&query.eval])); } poly_shifts - }, - ); - - poly_shifts.into_iter().fold( - Vec::>::new(), - |mut sets, (poly, shifts, evals)| { - if let Some(pos) = sets.iter().position(|set| { - BTreeSet::from_iter(set.shifts.iter()) == BTreeSet::from_iter(shifts.iter()) - }) { - let set = &mut sets[pos]; - if !set.polys.contains(&poly) { - set.polys.push(poly); - set.evals.push( - set.shifts - .iter() - .map(|lhs| { - let idx = shifts.iter().position(|rhs| lhs == rhs).unwrap(); - evals[idx] - }) - .collect(), - ); - } - } else { - let set = QuerySet { - shifts, - polys: vec![poly], - evals: vec![evals], - }; - sets.push(set); + }); + + poly_shifts.into_iter().fold(Vec::>::new(), |mut sets, (poly, shifts, evals)| { + if let Some(pos) = sets.iter().position(|set| { + BTreeSet::from_iter(set.shifts.iter()) == BTreeSet::from_iter(shifts.iter()) + }) { + let set = &mut sets[pos]; + if !set.polys.contains(&poly) { + set.polys.push(poly); + set.evals.push( + set.shifts + .iter() + .map(|lhs| { + let idx = shifts.iter().position(|rhs| lhs == rhs).unwrap(); + evals[idx] + }) + .collect(), + ); } - sets - }, - ) + } else { + let set = QuerySet { shifts, polys: vec![poly], evals: vec![evals] }; + sets.push(set); + } + sets + }) } fn query_set_coeffs(sets: &[QuerySet], x: &T, x_3: &T) -> Vec> @@ -230,15 +207,9 @@ where T: LoadedScalar, { let loader = x.loader(); - let superset = sets - .iter() - .flat_map(|set| set.shifts.clone()) - .sorted() - .dedup(); + let superset = sets.iter().flat_map(|set| set.shifts.clone()).sorted().dedup(); - let size = 2.max( - ilog2((sets.iter().map(|set| set.shifts.len()).max().unwrap() - 1).next_power_of_two()) + 1, - ); + let size = sets.iter().map(|set| set.shifts.len()).chain(Some(2)).max().unwrap(); let powers_of_x = x.powers(size); let x_3_minus_x_shift_i = BTreeMap::from_iter( superset.map(|shift| (shift, x_3.clone() - x.clone() * loader.load_const(&shift))), @@ -337,39 +308,23 @@ where .collect_vec(); let x = &powers_of_x[1].clone(); - let x_pow_k_minus_one = { - let k_minus_one = shifts.len() - 1; - powers_of_x - .iter() - .enumerate() - .skip(1) - .filter_map(|(i, power_of_x)| { - (k_minus_one & (1 << i) == 1).then(|| power_of_x.clone()) - }) - .reduce(|acc, value| acc * value) - .unwrap_or_else(|| loader.load_one()) - }; + let x_pow_k_minus_one = &powers_of_x[shifts.len() - 1]; let barycentric_weights = shifts .iter() .zip(normalized_ell_primes.iter()) .map(|(shift, normalized_ell_prime)| { loader.sum_products_with_coeff(&[ - (*normalized_ell_prime, &x_pow_k_minus_one, x_3), - (-(*normalized_ell_prime * shift), &x_pow_k_minus_one, x), + (*normalized_ell_prime, x_pow_k_minus_one, x_3), + (-(*normalized_ell_prime * shift), x_pow_k_minus_one, x), ]) }) .map(Fraction::one_over) .collect_vec(); - let f_eval_coeff = Fraction::one_over( - loader.product( - &shifts - .iter() - .map(|shift| x_3_minus_x_shift_i.get(shift).unwrap()) - .collect_vec(), - ), - ); + let f_eval_coeff = Fraction::one_over(loader.product( + &shifts.iter().map(|shift| x_3_minus_x_shift_i.get(shift).unwrap()).collect_vec(), + )); Self { eval_coeffs: barycentric_weights, @@ -396,13 +351,8 @@ where .for_each(Fraction::evaluate); let loader = self.f_eval_coeff.evaluated().loader(); - let barycentric_weights_sum = loader.sum( - &self - .eval_coeffs - .iter() - .map(Fraction::evaluated) - .collect_vec(), - ); + let barycentric_weights_sum = + loader.sum(&self.eval_coeffs.iter().map(Fraction::evaluated).collect_vec()); self.r_eval_coeff = Some(Fraction::one_over(barycentric_weights_sum)); return vec![self.r_eval_coeff.as_mut().unwrap().denom_mut().unwrap()]; diff --git a/snark-verifier/src/pcs/kzg.rs b/snark-verifier/src/pcs/kzg.rs index 056589a8..8f416ee3 100644 --- a/snark-verifier/src/pcs/kzg.rs +++ b/snark-verifier/src/pcs/kzg.rs @@ -1,9 +1,7 @@ -use crate::{ - loader::Loader, - pcs::PolynomialCommitmentScheme, - util::arithmetic::{CurveAffine, MultiMillerLoop}, -}; -use std::{fmt::Debug, marker::PhantomData}; +//! [KZG]() +//! polynomial commitment scheme and accumulation scheme. + +use crate::util::arithmetic::CurveAffine; mod accumulation; mod accumulator; @@ -18,24 +16,15 @@ pub use multiopen::{Bdfg21, Bdfg21Proof, Gwc19, Gwc19Proof}; #[cfg(feature = "loader_halo2")] pub use accumulator::LimbsEncodingInstructions; -#[derive(Clone, Debug)] -pub struct Kzg(PhantomData<(M, MOS)>); - -impl PolynomialCommitmentScheme for Kzg -where - M: MultiMillerLoop, - L: Loader, - MOS: Clone + Debug, -{ - type Accumulator = KzgAccumulator; -} - +/// KZG succinct verifying key. #[derive(Clone, Copy, Debug)] pub struct KzgSuccinctVerifyingKey { + /// Generator. pub g: C, } impl KzgSuccinctVerifyingKey { + /// Initialize a [`KzgSuccinctVerifyingKey`]. pub fn new(g: C) -> Self { Self { g } } diff --git a/snark-verifier/src/pcs/kzg/accumulation.rs b/snark-verifier/src/pcs/kzg/accumulation.rs index f4bd9783..1f901568 100644 --- a/snark-verifier/src/pcs/kzg/accumulation.rs +++ b/snark-verifier/src/pcs/kzg/accumulation.rs @@ -1,47 +1,47 @@ use crate::{ loader::{native::NativeLoader, LoadedScalar, Loader}, - pcs::{ - kzg::KzgAccumulator, AccumulationScheme, AccumulationSchemeProver, - PolynomialCommitmentScheme, - }, + pcs::{kzg::KzgAccumulator, AccumulationScheme, AccumulationSchemeProver}, util::{ - arithmetic::{Curve, CurveAffine, Field}, + arithmetic::{Curve, CurveAffine, Field, MultiMillerLoop}, msm::Msm, transcript::{TranscriptRead, TranscriptWrite}, }, Error, }; use rand::Rng; -use std::marker::PhantomData; +use std::{fmt::Debug, marker::PhantomData}; +/// KZG accumulation scheme. The second generic `MOS` stands for different kind +/// of multi-open scheme. #[derive(Clone, Debug)] -pub struct KzgAs(PhantomData); +pub struct KzgAs(PhantomData<(M, MOS)>); -impl AccumulationScheme for KzgAs +impl AccumulationScheme for KzgAs where - C: CurveAffine, - L: Loader, - PCS: PolynomialCommitmentScheme>, + M: MultiMillerLoop, + L: Loader, + MOS: Clone + Debug, { + type Accumulator = KzgAccumulator; type VerifyingKey = KzgAsVerifyingKey; - type Proof = KzgAsProof; + type Proof = KzgAsProof; fn read_proof( vk: &Self::VerifyingKey, - instances: &[PCS::Accumulator], + instances: &[Self::Accumulator], transcript: &mut T, ) -> Result where - T: TranscriptRead, + T: TranscriptRead, { KzgAsProof::read(vk, instances, transcript) } fn verify( _: &Self::VerifyingKey, - instances: &[PCS::Accumulator], + instances: &[Self::Accumulator], proof: &Self::Proof, - ) -> Result { + ) -> Result { let (lhs, rhs) = instances .iter() .map(|accumulator| (&accumulator.lhs, &accumulator.rhs)) @@ -53,7 +53,7 @@ where bases .into_iter() .zip(powers_of_r.iter()) - .map(|(base, r)| Msm::::base(base) * r) + .map(|(base, r)| Msm::::base(base) * r) .sum::>() .evaluate(None) }); @@ -62,53 +62,57 @@ where } } +/// KZG accumulation scheme proving key. #[derive(Clone, Copy, Debug, Default)] pub struct KzgAsProvingKey(pub Option<(C, C)>); impl KzgAsProvingKey { + /// Initialize a [`KzgAsProvingKey`]. pub fn new(g: Option<(C, C)>) -> Self { Self(g) } + /// Returns if it supports zero-knowledge or not. pub fn zk(&self) -> bool { self.0.is_some() } + /// Returns [`KzgAsVerifyingKey`]. pub fn vk(&self) -> KzgAsVerifyingKey { KzgAsVerifyingKey(self.zk()) } } +/// KZG accumulation scheme verifying key. #[derive(Clone, Copy, Debug, Default)] pub struct KzgAsVerifyingKey(bool); impl KzgAsVerifyingKey { + /// Returns if it supports zero-knowledge or not. pub fn zk(&self) -> bool { self.0 } } +/// KZG accumulation scheme proof. #[derive(Clone, Debug)] -pub struct KzgAsProof +pub struct KzgAsProof where C: CurveAffine, L: Loader, - PCS: PolynomialCommitmentScheme>, { blind: Option<(L::LoadedEcPoint, L::LoadedEcPoint)>, r: L::LoadedScalar, - _marker: PhantomData, } -impl KzgAsProof +impl KzgAsProof where C: CurveAffine, L: Loader, - PCS: PolynomialCommitmentScheme>, { fn read( vk: &KzgAsVerifyingKey, - instances: &[PCS::Accumulator], + instances: &[KzgAccumulator], transcript: &mut T, ) -> Result where @@ -117,35 +121,36 @@ where assert!(!instances.is_empty()); for accumulator in instances { - transcript.common_ec_point(&accumulator.lhs).unwrap(); - transcript.common_ec_point(&accumulator.rhs).unwrap(); + transcript.common_ec_point(&accumulator.lhs)?; + transcript.common_ec_point(&accumulator.rhs)?; } let blind = vk .zk() - .then(|| (transcript.read_ec_point().unwrap(), transcript.read_ec_point().unwrap())); + .then(|| Ok((transcript.read_ec_point()?, transcript.read_ec_point()?))) + .transpose()?; let r = transcript.squeeze_challenge(); - Ok(Self { blind, r, _marker: PhantomData }) + Ok(Self { blind, r }) } } -impl AccumulationSchemeProver for KzgAs +impl AccumulationSchemeProver for KzgAs where - C: CurveAffine, - PCS: PolynomialCommitmentScheme>, + M: MultiMillerLoop, + MOS: Clone + Debug, { - type ProvingKey = KzgAsProvingKey; + type ProvingKey = KzgAsProvingKey; fn create_proof( pk: &Self::ProvingKey, - instances: &[PCS::Accumulator], + instances: &[KzgAccumulator], transcript: &mut T, rng: R, - ) -> Result + ) -> Result, Error> where - T: TranscriptWrite, + T: TranscriptWrite, R: Rng, { assert!(!instances.is_empty()); @@ -158,7 +163,7 @@ where let blind = pk .zk() .then(|| { - let s = C::Scalar::random(rng); + let s = M::Scalar::random(rng); let (g, s_g) = pk.0.unwrap(); let lhs = (s_g * s).to_affine(); let rhs = (g * s).to_affine(); @@ -181,7 +186,7 @@ where let [lhs, rhs] = [lhs, rhs].map(|msms| { msms.iter() .zip(powers_of_r.iter()) - .map(|(msm, power_of_r)| Msm::::base(msm) * power_of_r) + .map(|(msm, power_of_r)| Msm::::base(msm) * power_of_r) .sum::>() .evaluate(None) }); diff --git a/snark-verifier/src/pcs/kzg/accumulator.rs b/snark-verifier/src/pcs/kzg/accumulator.rs index efc28cd8..82d1454b 100644 --- a/snark-verifier/src/pcs/kzg/accumulator.rs +++ b/snark-verifier/src/pcs/kzg/accumulator.rs @@ -1,13 +1,16 @@ use crate::{loader::Loader, util::arithmetic::CurveAffine}; use std::fmt::Debug; +/// KZG accumulator, containing lhs G1 and rhs G1 of pairing. #[derive(Clone, Debug)] pub struct KzgAccumulator where C: CurveAffine, L: Loader, { + /// Left-hand side G1 of pairing. pub lhs: L::LoadedEcPoint, + /// Right-hand side G1 of pairing. pub rhs: L::LoadedEcPoint, } @@ -16,6 +19,7 @@ where C: CurveAffine, L: Loader, { + /// Initialize a [`KzgAccumulator`]. pub fn new(lhs: L::LoadedEcPoint, rhs: L::LoadedEcPoint) -> Self { Self { lhs, rhs } } @@ -34,7 +38,7 @@ mod native { loader::native::NativeLoader, pcs::{ kzg::{KzgAccumulator, LimbsEncoding}, - AccumulatorEncoding, PolynomialCommitmentScheme, + AccumulatorEncoding, }, util::{ arithmetic::{fe_from_limbs, CurveAffine}, @@ -43,17 +47,14 @@ mod native { Error, }; - impl AccumulatorEncoding + impl AccumulatorEncoding for LimbsEncoding where C: CurveAffine, - PCS: PolynomialCommitmentScheme< - C, - NativeLoader, - Accumulator = KzgAccumulator, - >, { - fn from_repr(limbs: &[&C::Scalar]) -> Result { + type Accumulator = KzgAccumulator; + + fn from_repr(limbs: &[&C::Scalar]) -> Result { assert_eq!(limbs.len(), 4 * LIMBS); let [lhs_x, lhs_y, rhs_x, rhs_y]: [_; 4] = limbs @@ -83,7 +84,7 @@ mod evm { loader::evm::{EvmLoader, Scalar}, pcs::{ kzg::{KzgAccumulator, LimbsEncoding}, - AccumulatorEncoding, PolynomialCommitmentScheme, + AccumulatorEncoding, }, util::{ arithmetic::{CurveAffine, PrimeField}, @@ -93,18 +94,15 @@ mod evm { }; use std::rc::Rc; - impl AccumulatorEncoding, PCS> + impl AccumulatorEncoding> for LimbsEncoding where C: CurveAffine, C::Scalar: PrimeField, - PCS: PolynomialCommitmentScheme< - C, - Rc, - Accumulator = KzgAccumulator>, - >, { - fn from_repr(limbs: &[&Scalar]) -> Result { + type Accumulator = KzgAccumulator>; + + fn from_repr(limbs: &[&Scalar]) -> Result { assert_eq!(limbs.len(), 4 * LIMBS); let loader = limbs[0].loader(); @@ -131,12 +129,11 @@ pub use halo2::LimbsEncodingInstructions; #[cfg(feature = "loader_halo2")] mod halo2 { - use crate::halo2_proofs::{circuit::Value, plonk}; use crate::{ - loader::halo2::{EccInstructions, Halo2Loader, Scalar, Valuetools}, + loader::halo2::{EccInstructions, Halo2Loader, Scalar}, pcs::{ kzg::{KzgAccumulator, LimbsEncoding}, - AccumulatorEncoding, PolynomialCommitmentScheme, + AccumulatorEncoding, }, util::{ arithmetic::{fe_from_limbs, CurveAffine}, @@ -147,64 +144,56 @@ mod halo2 { use std::{iter, ops::Deref, rc::Rc}; fn ec_point_from_limbs( - limbs: &[Value<&C::Scalar>], - ) -> Value { + limbs: &[&C::Scalar], + ) -> C { assert_eq!(limbs.len(), 2 * LIMBS); let [x, y] = [&limbs[..LIMBS], &limbs[LIMBS..]].map(|limbs| { - limbs - .iter() - .cloned() - .fold_zipped(Vec::new(), |mut acc, limb| { - acc.push(*limb); - acc - }) - .map(|limbs| fe_from_limbs::<_, _, LIMBS, BITS>(limbs.try_into().unwrap())) + fe_from_limbs::<_, _, LIMBS, BITS>( + limbs.iter().map(|limb| **limb).collect_vec().try_into().unwrap(), + ) }); - x.zip(y).map(|(x, y)| C::from_xy(x, y).unwrap()) + C::from_xy(x, y).unwrap() } - pub trait LimbsEncodingInstructions<'a, C: CurveAffine, const LIMBS: usize, const BITS: usize>: - EccInstructions<'a, C> + /// Instructions to encode/decode a elliptic curve point into/from limbs. + pub trait LimbsEncodingInstructions: + EccInstructions { + /// Decode and assign an elliptic curve point from limbs. fn assign_ec_point_from_limbs( &self, ctx: &mut Self::Context, limbs: &[impl Deref], - ) -> Result; + ) -> Self::AssignedEcPoint; + /// Encode an elliptic curve point into limbs. fn assign_ec_point_to_limbs( &self, ctx: &mut Self::Context, ec_point: impl Deref, - ) -> Result, plonk::Error>; + ) -> Vec; } - impl<'a, C, PCS, EccChip, const LIMBS: usize, const BITS: usize> - AccumulatorEncoding>, PCS> for LimbsEncoding + impl + AccumulatorEncoding>> for LimbsEncoding where C: CurveAffine, - PCS: PolynomialCommitmentScheme< - C, - Rc>, - Accumulator = KzgAccumulator>>, - >, - EccChip: LimbsEncodingInstructions<'a, C, LIMBS, BITS>, + EccChip: LimbsEncodingInstructions, { - fn from_repr(limbs: &[&Scalar<'a, C, EccChip>]) -> Result { + type Accumulator = KzgAccumulator>>; + + fn from_repr(limbs: &[&Scalar]) -> Result { assert_eq!(limbs.len(), 4 * LIMBS); let loader = limbs[0].loader(); let [lhs, rhs] = [&limbs[..2 * LIMBS], &limbs[2 * LIMBS..]].map(|limbs| { - let assigned = loader - .ecc_chip() - .assign_ec_point_from_limbs( - &mut loader.ctx_mut(), - &limbs.iter().map(|limb| limb.assigned()).collect_vec(), - ) - .unwrap(); + let assigned = loader.ecc_chip().assign_ec_point_from_limbs( + &mut loader.ctx_mut(), + &limbs.iter().map(|limb| limb.assigned()).collect_vec(), + ); loader.ec_point_from_assigned(assigned) }); @@ -214,11 +203,11 @@ mod halo2 { mod halo2_lib { use super::*; - use halo2_base::{halo2_proofs::halo2curves::CurveAffineExt, utils::PrimeField}; - use halo2_ecc::ecc::BaseFieldEccChip; + use halo2_base::halo2_proofs::halo2curves::CurveAffineExt; + use halo2_ecc::{ecc::BaseFieldEccChip, fields::PrimeField}; - impl<'a, C, const LIMBS: usize, const BITS: usize> - LimbsEncodingInstructions<'a, C, LIMBS, BITS> for BaseFieldEccChip + impl<'chip, C, const LIMBS: usize, const BITS: usize> + LimbsEncodingInstructions for BaseFieldEccChip<'chip, C> where C: CurveAffineExt, C::ScalarExt: PrimeField, @@ -228,11 +217,11 @@ mod halo2 { &self, ctx: &mut Self::Context, limbs: &[impl Deref], - ) -> Result { + ) -> Self::AssignedEcPoint { assert_eq!(limbs.len(), 2 * LIMBS); let ec_point = self.assign_point::( - ctx, + ctx.main(0), ec_point_from_limbs::<_, LIMBS, BITS>( &limbs.iter().map(|limb| limb.value()).collect_vec(), ), @@ -242,71 +231,23 @@ mod halo2 { .iter() .zip_eq(iter::empty().chain(ec_point.x().limbs()).chain(ec_point.y().limbs())) { - ctx.region.constrain_equal(src.cell(), dst.cell()); - } - - Ok(ec_point) - } - - fn assign_ec_point_to_limbs( - &self, - _: &mut Self::Context, - ec_point: impl Deref, - ) -> Result, plonk::Error> { - Ok(iter::empty() - .chain(ec_point.x().limbs()) - .chain(ec_point.y().limbs()) - .cloned() - .collect()) - } - } - } - - /* - mod halo2_wrong { - use super::*; - use halo2_wrong_ecc::BaseFieldEccChip; - - impl<'a, C: CurveAffine, const LIMBS: usize, const BITS: usize> - LimbsEncodingInstructions<'a, C, LIMBS, BITS> for BaseFieldEccChip - { - fn assign_ec_point_from_limbs( - &self, - ctx: &mut Self::Context, - limbs: &[impl Deref], - ) -> Result { - assert_eq!(limbs.len(), 2 * LIMBS); - - let ec_point = self.assign_point( - ctx, - ec_point_from_limbs::<_, LIMBS, BITS>( - &limbs.iter().map(|limb| limb.value()).collect_vec(), - ), - )?; - - for (src, dst) in limbs - .iter() - .zip_eq(iter::empty().chain(ec_point.x().limbs()).chain(ec_point.y().limbs())) - { - ctx.constrain_equal(src.cell(), dst.as_ref().cell())?; + ctx.main(0).constrain_equal(src, dst); } - Ok(ec_point) + ec_point } fn assign_ec_point_to_limbs( &self, _: &mut Self::Context, ec_point: impl Deref, - ) -> Result, plonk::Error> { - Ok(iter::empty() + ) -> Vec { + iter::empty() .chain(ec_point.x().limbs()) .chain(ec_point.y().limbs()) - .map(|limb| limb.as_ref()) - .cloned() - .collect()) + .copied() + .collect() } } } - */ } diff --git a/snark-verifier/src/pcs/kzg/decider.rs b/snark-verifier/src/pcs/kzg/decider.rs index baabda6c..59f1afbf 100644 --- a/snark-verifier/src/pcs/kzg/decider.rs +++ b/snark-verifier/src/pcs/kzg/decider.rs @@ -1,26 +1,37 @@ -use crate::util::arithmetic::MultiMillerLoop; +use crate::{pcs::kzg::KzgSuccinctVerifyingKey, util::arithmetic::MultiMillerLoop}; use std::marker::PhantomData; +/// KZG deciding key. #[derive(Debug, Clone, Copy)] pub struct KzgDecidingKey { - pub g2: M::G2Affine, - pub s_g2: M::G2Affine, + svk: KzgSuccinctVerifyingKey, + /// Generator on G2. + g2: M::G2Affine, + /// Generator to the trusted-setup secret on G2. + s_g2: M::G2Affine, _marker: PhantomData, } impl KzgDecidingKey { - pub fn new(g2: M::G2Affine, s_g2: M::G2Affine) -> Self { - Self { - g2, - s_g2, - _marker: PhantomData, - } + /// Initialize a [`KzgDecidingKey`] + pub fn new( + svk: impl Into>, + g2: M::G2Affine, + s_g2: M::G2Affine, + ) -> Self { + Self { svk: svk.into(), g2, s_g2, _marker: PhantomData } + } +} + +impl From<(M::G1Affine, M::G2Affine, M::G2Affine)> for KzgDecidingKey { + fn from((g1, g2, s_g2): (M::G1Affine, M::G2Affine, M::G2Affine)) -> KzgDecidingKey { + KzgDecidingKey::new(g1, g2, s_g2) } } -impl From<(M::G2Affine, M::G2Affine)> for KzgDecidingKey { - fn from((g2, s_g2): (M::G2Affine, M::G2Affine)) -> KzgDecidingKey { - KzgDecidingKey::new(g2, s_g2) +impl AsRef> for KzgDecidingKey { + fn as_ref(&self) -> &KzgSuccinctVerifyingKey { + &self.svk } } @@ -28,47 +39,44 @@ mod native { use crate::{ loader::native::NativeLoader, pcs::{ - kzg::{Kzg, KzgAccumulator, KzgDecidingKey}, - Decider, + kzg::{KzgAccumulator, KzgAs, KzgDecidingKey}, + AccumulationDecider, }, - util::arithmetic::{Group, MillerLoopResult, MultiMillerLoop}, + util::{ + arithmetic::{Group, MillerLoopResult, MultiMillerLoop}, + Itertools, + }, + Error, }; use std::fmt::Debug; - impl Decider for Kzg + impl AccumulationDecider for KzgAs where M: MultiMillerLoop, MOS: Clone + Debug, { type DecidingKey = KzgDecidingKey; - type Output = bool; fn decide( dk: &Self::DecidingKey, KzgAccumulator { lhs, rhs }: KzgAccumulator, - ) -> bool { + ) -> Result<(), Error> { let terms = [(&lhs, &dk.g2.into()), (&rhs, &(-dk.s_g2).into())]; - M::multi_miller_loop(&terms) - .final_exponentiation() - .is_identity() - .into() + bool::from(M::multi_miller_loop(&terms).final_exponentiation().is_identity()) + .then_some(()) + .ok_or_else(|| Error::AssertionFailure("e(lhs, g2)·e(rhs, -s_g2) == O".to_string())) } fn decide_all( dk: &Self::DecidingKey, accumulators: Vec>, - ) -> bool { - !accumulators + ) -> Result<(), Error> { + assert!(!accumulators.is_empty()); + accumulators .into_iter() - //.enumerate() - .any(|accumulator| { - /*let decide = Self::decide(dk, accumulator); - if !decide { - panic!("{i}"); - } - !decide*/ - !Self::decide(dk, accumulator) - }) + .map(|accumulator| Self::decide(dk, accumulator)) + .try_collect::<_, Vec<_>, _>()?; + Ok(()) } } } @@ -77,34 +85,33 @@ mod native { mod evm { use crate::{ loader::{ - evm::{loader::Value, EvmLoader}, + evm::{loader::Value, EvmLoader, U256}, LoadedScalar, }, pcs::{ - kzg::{Kzg, KzgAccumulator, KzgDecidingKey}, - Decider, + kzg::{KzgAccumulator, KzgAs, KzgDecidingKey}, + AccumulationDecider, }, util::{ arithmetic::{CurveAffine, MultiMillerLoop, PrimeField}, msm::Msm, }, + Error, }; - use ethereum_types::U256; use std::{fmt::Debug, rc::Rc}; - impl Decider> for Kzg + impl AccumulationDecider> for KzgAs where M: MultiMillerLoop, M::Scalar: PrimeField, MOS: Clone + Debug, { type DecidingKey = KzgDecidingKey; - type Output = (); fn decide( dk: &Self::DecidingKey, KzgAccumulator { lhs, rhs }: KzgAccumulator>, - ) { + ) -> Result<(), Error> { let loader = lhs.loader(); let [g2, minus_s_g2] = [dk.g2, -dk.s_g2].map(|ec_point| { let coordinates = ec_point.coordinates().unwrap(); @@ -118,12 +125,13 @@ mod evm { ) }); loader.pairing(&lhs, g2, &rhs, minus_s_g2); + Ok(()) } fn decide_all( dk: &Self::DecidingKey, mut accumulators: Vec>>, - ) { + ) -> Result<(), Error> { assert!(!accumulators.is_empty()); let accumulator = if accumulators.len() == 1 { @@ -158,7 +166,7 @@ mod evm { KzgAccumulator::new(lhs, rhs) }; - Self::decide(dk, accumulator) + >>::decide(dk, accumulator) } } } diff --git a/snark-verifier/src/pcs/kzg/multiopen/bdfg21.rs b/snark-verifier/src/pcs/kzg/multiopen/bdfg21.rs index bc7b8131..3a448056 100644 --- a/snark-verifier/src/pcs/kzg/multiopen/bdfg21.rs +++ b/snark-verifier/src/pcs/kzg/multiopen/bdfg21.rs @@ -2,56 +2,62 @@ use crate::{ cost::{Cost, CostEstimation}, loader::{LoadedScalar, Loader, ScalarLoader}, pcs::{ - kzg::{Kzg, KzgAccumulator, KzgSuccinctVerifyingKey}, - MultiOpenScheme, Query, + kzg::{KzgAccumulator, KzgAs, KzgSuccinctVerifyingKey}, + PolynomialCommitmentScheme, Query, }, util::{ - arithmetic::{ilog2, CurveAffine, FieldExt, Fraction, MultiMillerLoop}, + arithmetic::{CurveAffine, FieldExt, Fraction, MultiMillerLoop}, msm::Msm, transcript::TranscriptRead, Itertools, }, + Error, }; use std::{ collections::{BTreeMap, BTreeSet}, marker::PhantomData, }; +/// Verifier of multi-open KZG. It is for the SHPLONK implementation +/// in [`halo2_proofs`]. +/// Notations are following . #[derive(Clone, Debug)] pub struct Bdfg21; -impl MultiOpenScheme for Kzg +impl PolynomialCommitmentScheme for KzgAs where M: MultiMillerLoop, L: Loader, { - type SuccinctVerifyingKey = KzgSuccinctVerifyingKey; + type VerifyingKey = KzgSuccinctVerifyingKey; type Proof = Bdfg21Proof; + type Output = KzgAccumulator; fn read_proof( _: &KzgSuccinctVerifyingKey, _: &[Query], transcript: &mut T, - ) -> Bdfg21Proof + ) -> Result, Error> where T: TranscriptRead, { Bdfg21Proof::read(transcript) } - fn succinct_verify( + fn verify( svk: &KzgSuccinctVerifyingKey, commitments: &[Msm], z: &L::LoadedScalar, queries: &[Query], proof: &Bdfg21Proof, - ) -> Self::Accumulator { + ) -> Result { let sets = query_sets(queries); let f = { let coeffs = query_set_coeffs(&sets, z, &proof.z_prime); - let powers_of_mu = - proof.mu.powers(Iterator::max(sets.iter().map(|set| set.polys.len())).unwrap()); + let powers_of_mu = proof + .mu + .powers(sets.iter().map(|set| set.polys.len()).max().unwrap()); let msms = sets .iter() .zip(coeffs.iter()) @@ -66,10 +72,14 @@ where let rhs = Msm::base(&proof.w_prime); let lhs = f + rhs.clone() * &proof.z_prime; - KzgAccumulator::new(lhs.evaluate(Some(svk.g)), rhs.evaluate(Some(svk.g))) + Ok(KzgAccumulator::new( + lhs.evaluate(Some(svk.g)), + rhs.evaluate(Some(svk.g)), + )) } } +/// Structured proof of [`Bdfg21`]. #[derive(Clone, Debug)] pub struct Bdfg21Proof where @@ -88,20 +98,30 @@ where C: CurveAffine, L: Loader, { - fn read>(transcript: &mut T) -> Self { + fn read>(transcript: &mut T) -> Result { let mu = transcript.squeeze_challenge(); let gamma = transcript.squeeze_challenge(); - let w = transcript.read_ec_point().unwrap(); + let w = transcript.read_ec_point()?; let z_prime = transcript.squeeze_challenge(); - let w_prime = transcript.read_ec_point().unwrap(); - Bdfg21Proof { mu, gamma, w, z_prime, w_prime } + let w_prime = transcript.read_ec_point()?; + Ok(Bdfg21Proof { + mu, + gamma, + w, + z_prime, + w_prime, + }) } } fn query_sets(queries: &[Query]) -> Vec> { - let poly_shifts = - queries.iter().fold(Vec::<(usize, Vec, Vec<&T>)>::new(), |mut poly_shifts, query| { - if let Some(pos) = poly_shifts.iter().position(|(poly, _, _)| *poly == query.poly) { + let poly_shifts = queries.iter().fold( + Vec::<(usize, Vec, Vec<&T>)>::new(), + |mut poly_shifts, query| { + if let Some(pos) = poly_shifts + .iter() + .position(|(poly, _, _)| *poly == query.poly) + { let (_, shifts, evals) = &mut poly_shifts[pos]; if !shifts.contains(&query.shift) { shifts.push(query.shift); @@ -111,31 +131,39 @@ fn query_sets(queries: &[Query]) -> Vec>::new(), |mut sets, (poly, shifts, evals)| { - if let Some(pos) = sets.iter().position(|set| { - BTreeSet::from_iter(set.shifts.iter()) == BTreeSet::from_iter(shifts.iter()) - }) { - let set = &mut sets[pos]; - if !set.polys.contains(&poly) { - set.polys.push(poly); - set.evals.push( - set.shifts - .iter() - .map(|lhs| { - let idx = shifts.iter().position(|rhs| lhs == rhs).unwrap(); - evals[idx] - }) - .collect(), - ); + }, + ); + + poly_shifts.into_iter().fold( + Vec::>::new(), + |mut sets, (poly, shifts, evals)| { + if let Some(pos) = sets.iter().position(|set| { + BTreeSet::from_iter(set.shifts.iter()) == BTreeSet::from_iter(shifts.iter()) + }) { + let set = &mut sets[pos]; + if !set.polys.contains(&poly) { + set.polys.push(poly); + set.evals.push( + set.shifts + .iter() + .map(|lhs| { + let idx = shifts.iter().position(|rhs| lhs == rhs).unwrap(); + evals[idx] + }) + .collect(), + ); + } + } else { + let set = QuerySet { + shifts, + polys: vec![poly], + evals: vec![evals], + }; + sets.push(set); } - } else { - let set = QuerySet { shifts, polys: vec![poly], evals: vec![evals] }; - sets.push(set); - } - sets - }) + sets + }, + ) } fn query_set_coeffs<'a, F: FieldExt, T: LoadedScalar>( @@ -145,15 +173,25 @@ fn query_set_coeffs<'a, F: FieldExt, T: LoadedScalar>( ) -> Vec> { let loader = z.loader(); - let superset = sets.iter().flat_map(|set| set.shifts.clone()).sorted().dedup(); + let superset = sets + .iter() + .flat_map(|set| set.shifts.clone()) + .sorted() + .dedup(); - let size = 2.max( - ilog2((sets.iter().map(|set| set.shifts.len()).max().unwrap() - 1).next_power_of_two()) + 1, - ); + let size = sets + .iter() + .map(|set| set.shifts.len()) + .chain(Some(2)) + .max() + .unwrap(); let powers_of_z = z.powers(size); - let z_prime_minus_z_shift_i = BTreeMap::from_iter( - superset.map(|shift| (shift, z_prime.clone() - z.clone() * loader.load_const(&shift))), - ); + let z_prime_minus_z_shift_i = BTreeMap::from_iter(superset.map(|shift| { + ( + shift, + z_prime.clone() - z.clone() * loader.load_const(&shift), + ) + })); let mut z_s_1 = None; let mut coeffs = sets @@ -259,30 +297,25 @@ where .collect_vec(); let z = &powers_of_z[1]; - let z_pow_k_minus_one = { - let k_minus_one = shifts.len() - 1; - powers_of_z - .iter() - .enumerate() - .skip(1) - .filter_map(|(i, power_of_z)| (k_minus_one & (1 << i) == 1).then_some(power_of_z)) - .fold(loader.load_one(), |acc, value| acc * value) - }; + let z_pow_k_minus_one = &powers_of_z[shifts.len() - 1]; let barycentric_weights = shifts .iter() .zip(normalized_ell_primes.iter()) .map(|(shift, normalized_ell_prime)| { loader.sum_products_with_coeff(&[ - (*normalized_ell_prime, &z_pow_k_minus_one, z_prime), - (-(*normalized_ell_prime * shift), &z_pow_k_minus_one, z), + (*normalized_ell_prime, z_pow_k_minus_one, z_prime), + (-(*normalized_ell_prime * shift), z_pow_k_minus_one, z), ]) }) .map(Fraction::one_over) .collect_vec(); let z_s = loader.product( - &shifts.iter().map(|shift| z_prime_minus_z_shift_i.get(shift).unwrap()).collect_vec(), + &shifts + .iter() + .map(|shift| z_prime_minus_z_shift_i.get(shift).unwrap()) + .collect_vec(), ); let z_s_1_over_z_s = z_s_1.clone().map(|z_s_1| Fraction::new(z_s_1, z_s.clone())); @@ -311,8 +344,13 @@ where .iter_mut() .chain(self.commitment_coeff.as_mut()) .for_each(Fraction::evaluate); - let barycentric_weights_sum = - loader.sum(&self.eval_coeffs.iter().map(Fraction::evaluated).collect_vec()); + let barycentric_weights_sum = loader.sum( + &self + .eval_coeffs + .iter() + .map(Fraction::evaluated) + .collect_vec(), + ); self.r_eval_coeff = Some(match self.commitment_coeff.as_ref() { Some(coeff) => Fraction::new(coeff.evaluated().clone(), barycentric_weights_sum), None => Fraction::one_over(barycentric_weights_sum), @@ -328,13 +366,17 @@ where } } -impl CostEstimation for Kzg +impl CostEstimation for KzgAs where M: MultiMillerLoop, { type Input = Vec>; fn estimate_cost(_: &Vec>) -> Cost { - Cost::new(0, 2, 0, 2) + Cost { + num_commitment: 2, + num_msm: 2, + ..Default::default() + } } } diff --git a/snark-verifier/src/pcs/kzg/multiopen/gwc19.rs b/snark-verifier/src/pcs/kzg/multiopen/gwc19.rs index 12496ad7..e5741163 100644 --- a/snark-verifier/src/pcs/kzg/multiopen/gwc19.rs +++ b/snark-verifier/src/pcs/kzg/multiopen/gwc19.rs @@ -2,8 +2,8 @@ use crate::{ cost::{Cost, CostEstimation}, loader::{LoadedScalar, Loader}, pcs::{ - kzg::{Kzg, KzgAccumulator, KzgSuccinctVerifyingKey}, - MultiOpenScheme, Query, + kzg::{KzgAccumulator, KzgAs, KzgSuccinctVerifyingKey}, + PolynomialCommitmentScheme, Query, }, util::{ arithmetic::{CurveAffine, MultiMillerLoop, PrimeField}, @@ -11,42 +11,48 @@ use crate::{ transcript::TranscriptRead, Itertools, }, + Error, }; +/// Verifier of multi-open KZG. It is for the GWC implementation +/// in [`halo2_proofs`]. +/// Notations are following . #[derive(Clone, Debug)] pub struct Gwc19; -impl MultiOpenScheme for Kzg +impl PolynomialCommitmentScheme for KzgAs where M: MultiMillerLoop, L: Loader, { - type SuccinctVerifyingKey = KzgSuccinctVerifyingKey; + type VerifyingKey = KzgSuccinctVerifyingKey; type Proof = Gwc19Proof; + type Output = KzgAccumulator; fn read_proof( - _: &Self::SuccinctVerifyingKey, + _: &Self::VerifyingKey, queries: &[Query], transcript: &mut T, - ) -> Self::Proof + ) -> Result where T: TranscriptRead, { Gwc19Proof::read(queries, transcript) } - fn succinct_verify( - svk: &Self::SuccinctVerifyingKey, + fn verify( + svk: &Self::VerifyingKey, commitments: &[Msm], z: &L::LoadedScalar, queries: &[Query], proof: &Self::Proof, - ) -> Self::Accumulator { + ) -> Result { let sets = query_sets(queries); let powers_of_u = &proof.u.powers(sets.len()); let f = { - let powers_of_v = - proof.v.powers(Iterator::max(sets.iter().map(|set| set.polys.len())).unwrap()); + let powers_of_v = proof + .v + .powers(sets.iter().map(|set| set.polys.len()).max().unwrap()); sets.iter() .map(|set| set.msm(commitments, &powers_of_v)) .zip(powers_of_u.iter()) @@ -61,15 +67,20 @@ where .zip(powers_of_u.iter()) .map(|(w, power_of_u)| Msm::base(w) * power_of_u) .collect_vec(); - let lhs = f + rhs.iter().zip(z_omegas).map(|(uw, z_omega)| uw.clone() * &z_omega).sum(); + let lhs = f + rhs + .iter() + .zip(z_omegas) + .map(|(uw, z_omega)| uw.clone() * &z_omega) + .sum(); - KzgAccumulator::new( + Ok(KzgAccumulator::new( lhs.evaluate(Some(svk.g)), rhs.into_iter().sum::>().evaluate(Some(svk.g)), - ) + )) } } +/// Structured proof of [`Gwc19`]. #[derive(Clone, Debug)] pub struct Gwc19Proof where @@ -86,14 +97,14 @@ where C: CurveAffine, L: Loader, { - fn read(queries: &[Query], transcript: &mut T) -> Self + fn read(queries: &[Query], transcript: &mut T) -> Result where T: TranscriptRead, { let v = transcript.squeeze_challenge(); - let ws = transcript.read_n_ec_points(query_sets(queries).len()).unwrap(); + let ws = transcript.read_n_ec_points(query_sets(queries).len())?; let u = transcript.squeeze_challenge(); - Gwc19Proof { v, ws, u } + Ok(Gwc19Proof { v, ws, u }) } } @@ -146,7 +157,7 @@ where }) } -impl CostEstimation for Kzg +impl CostEstimation for KzgAs where M: MultiMillerLoop, { @@ -154,6 +165,10 @@ where fn estimate_cost(queries: &Vec>) -> Cost { let num_w = query_sets(queries).len(); - Cost::new(0, num_w, 0, num_w) + Cost { + num_commitment: num_w, + num_msm: num_w, + ..Default::default() + } } } diff --git a/snark-verifier/src/system.rs b/snark-verifier/src/system.rs index edf79228..d6e74e73 100644 --- a/snark-verifier/src/system.rs +++ b/snark-verifier/src/system.rs @@ -1 +1,3 @@ +//! Proof systems `snark-verifier` supports + pub mod halo2; diff --git a/snark-verifier/src/system/halo2.rs b/snark-verifier/src/system/halo2.rs index 1ba7c1cc..98f4488c 100644 --- a/snark-verifier/src/system/halo2.rs +++ b/snark-verifier/src/system/halo2.rs @@ -1,3 +1,5 @@ +//! [`halo2_proofs`] proof system + use crate::halo2_proofs::{ plonk::{self, Any, ConstraintSystem, FirstPhase, SecondPhase, ThirdPhase, VerifyingKey}, poly::{self, commitment::Params}, @@ -6,62 +8,67 @@ use crate::halo2_proofs::{ use crate::{ util::{ arithmetic::{root_of_unity, CurveAffine, Domain, FieldExt, Rotation}, - protocol::{ - CommonPolynomial, Expression, InstanceCommittingKey, Query, QuotientPolynomial, - }, Itertools, }, - Protocol, + verifier::plonk::protocol::{ + CommonPolynomial, Expression, InstanceCommittingKey, PlonkProtocol, Query, + QuotientPolynomial, + }, }; use num_integer::Integer; use std::{io, iter, mem::size_of}; -// pub mod strategy; +pub mod strategy; pub mod transcript; -#[cfg(test)] -#[cfg(feature = "loader_halo2")] -pub(crate) mod test; - +/// Configuration for converting a [`VerifyingKey`] of [`halo2_proofs`] into +/// [`PlonkProtocol`]. #[derive(Clone, Debug, Default)] pub struct Config { - pub zk: bool, - pub query_instance: bool, - pub num_proof: usize, - pub num_instance: Vec, - pub accumulator_indices: Option>, + zk: bool, + query_instance: bool, + num_proof: usize, + num_instance: Vec, + accumulator_indices: Option>, } impl Config { + /// Returns [`Config`] with `query_instance` set to `false`. pub fn kzg() -> Self { Self { zk: true, query_instance: false, num_proof: 1, ..Default::default() } } + /// Returns [`Config`] with `query_instance` set to `true`. pub fn ipa() -> Self { Self { zk: true, query_instance: true, num_proof: 1, ..Default::default() } } + /// Set `zk` pub fn set_zk(mut self, zk: bool) -> Self { self.zk = zk; self } + /// Set `query_instance` pub fn set_query_instance(mut self, query_instance: bool) -> Self { self.query_instance = query_instance; self } + /// Set `num_proof` pub fn with_num_proof(mut self, num_proof: usize) -> Self { assert!(num_proof > 0); self.num_proof = num_proof; self } + /// Set `num_instance` pub fn with_num_instance(mut self, num_instance: Vec) -> Self { self.num_instance = num_instance; self } + /// Set `accumulator_indices` pub fn with_accumulator_indices( mut self, accumulator_indices: Option>, @@ -71,11 +78,12 @@ impl Config { } } +/// Convert a [`VerifyingKey`] of [`halo2_proofs`] into [`PlonkProtocol`]. pub fn compile<'a, C: CurveAffine, P: Params<'a, C>>( params: &P, vk: &VerifyingKey, config: Config, -) -> Protocol { +) -> PlonkProtocol { assert_eq!(vk.get_domain().k(), params.k()); let cs = vk.cs(); @@ -103,7 +111,7 @@ pub fn compile<'a, C: CurveAffine, P: Params<'a, C>>( .chain((0..num_proof).flat_map(move |t| polynomials.permutation_z_queries::(t))) .chain((0..num_proof).flat_map(move |t| polynomials.lookup_queries::(t))) .collect(); - + // `quotient_query()` is not needed in evaluations because the verifier can compute it itself from the other evaluations. let queries = (0..num_proof) .flat_map(|t| { iter::empty() @@ -123,7 +131,7 @@ pub fn compile<'a, C: CurveAffine, P: Params<'a, C>>( let instance_committing_key = query_instance.then(|| { instance_committing_key( params, - Iterator::max(polynomials.num_instance().into_iter()).unwrap_or_default(), + polynomials.num_instance().into_iter().max().unwrap_or_default(), ) }); @@ -131,7 +139,7 @@ pub fn compile<'a, C: CurveAffine, P: Params<'a, C>>( .map(|accumulator_indices| polynomials.accumulator_indices(accumulator_indices)) .unwrap_or_default(); - Protocol { + PlonkProtocol { domain, preprocessed, num_instance: polynomials.num_instance(), diff --git a/snark-verifier/src/system/halo2/aggregation.rs b/snark-verifier/src/system/halo2/aggregation.rs deleted file mode 100644 index a3b09c15..00000000 --- a/snark-verifier/src/system/halo2/aggregation.rs +++ /dev/null @@ -1,718 +0,0 @@ -use super::{BITS, LIMBS}; -use crate::{ - loader::{self, native::NativeLoader, Loader}, - pcs::{ - kzg::{ - Bdfg21, Kzg, KzgAccumulator, KzgAs, KzgAsProvingKey, KzgAsVerifyingKey, - KzgSuccinctVerifyingKey, LimbsEncoding, - }, - AccumulationScheme, AccumulationSchemeProver, - }, - system::{ - self, - halo2::{ - compile, read_or_create_srs, transcript::halo2::ChallengeScalar, Config, - Halo2VerifierCircuitConfig, Halo2VerifierCircuitConfigParams, - }, - }, - util::arithmetic::fe_to_limbs, - verifier::{self, PlonkVerifier}, - Protocol, -}; -use ark_std::{end_timer, start_timer}; -use halo2_base::AssignedValue; -pub use halo2_base::{ - utils::{biguint_to_fe, fe_to_biguint}, - Context, ContextParams, -}; -use halo2_curves::bn256::{Bn256, Fr, G1Affine}; -use halo2_proofs::{ - circuit::{Layouter, SimpleFloorPlanner, Value}, - plonk::{ - self, create_proof, keygen_pk, keygen_vk, verify_proof, Circuit, ProvingKey, VerifyingKey, - }, - poly::{ - commitment::{Params, ParamsProver}, - kzg::{ - commitment::{KZGCommitmentScheme, ParamsKZG}, - multiopen::{ProverSHPLONK, VerifierSHPLONK}, - strategy::AccumulatorStrategy, - }, - VerificationStrategy, - }, - transcript::{TranscriptReadBuffer, TranscriptWriterBuffer}, -}; -use itertools::Itertools; -use num_bigint::BigUint; -use num_traits::Num; -use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; -use std::{ - fs::File, - io::{BufReader, BufWriter, Cursor, Read, Write}, - path::Path, - rc::Rc, -}; - -pub const T: usize = 3; -pub const RATE: usize = 2; -pub const R_F: usize = 8; -pub const R_P: usize = 57; - -pub type Halo2Loader<'a, 'b> = loader::halo2::Halo2Loader<'a, 'b, G1Affine>; -pub type PoseidonTranscript = - system::halo2::transcript::halo2::PoseidonTranscript; - -pub type Pcs = Kzg; -pub type Svk = KzgSuccinctVerifyingKey; -pub type As = KzgAs; -pub type AsPk = KzgAsProvingKey; -pub type AsVk = KzgAsVerifyingKey; -pub type Plonk = verifier::Plonk>; - -pub struct Snark { - protocol: Protocol, - instances: Vec>, - proof: Vec, -} - -impl Snark { - pub fn new(protocol: Protocol, instances: Vec>, proof: Vec) -> Self { - Self { protocol, instances, proof } - } - pub fn protocol(&self) -> &Protocol { - &self.protocol - } - pub fn instances(&self) -> &[Vec] { - &self.instances - } - pub fn proof(&self) -> &[u8] { - &self.proof - } -} - -impl From for SnarkWitness { - fn from(snark: Snark) -> Self { - Self { - protocol: snark.protocol, - instances: snark - .instances - .into_iter() - .map(|instances| instances.into_iter().map(Value::known).collect_vec()) - .collect(), - proof: Value::known(snark.proof), - } - } -} - -#[derive(Clone)] -pub struct SnarkWitness { - protocol: Protocol, - instances: Vec>>, - proof: Value>, -} - -impl SnarkWitness { - pub fn without_witnesses(&self) -> Self { - SnarkWitness { - protocol: self.protocol.clone(), - instances: self - .instances - .iter() - .map(|instances| vec![Value::unknown(); instances.len()]) - .collect(), - proof: Value::unknown(), - } - } - - pub fn protocol(&self) -> &Protocol { - &self.protocol - } - - pub fn instances(&self) -> &[Vec>] { - &self.instances - } - - pub fn proof(&self) -> Value<&[u8]> { - self.proof.as_ref().map(Vec::as_slice) - } -} - -pub fn aggregate<'a, 'b>( - svk: &Svk, - loader: &Rc>, - snarks: &[SnarkWitness], - as_vk: &AsVk, - as_proof: Value<&'_ [u8]>, - expose_instances: bool, -) -> Vec> { - let assign_instances = |instances: &[Vec>]| { - instances - .iter() - .map(|instances| { - instances.iter().map(|instance| loader.assign_scalar(*instance)).collect_vec() - }) - .collect_vec() - }; - - let mut instances_to_expose = vec![]; - let mut accumulators = snarks - .iter() - .flat_map(|snark| { - let instances = assign_instances(&snark.instances); - if expose_instances { - instances_to_expose.extend( - instances - .iter() - .flat_map(|instance| instance.iter().map(|scalar| scalar.assigned())), - ); - } - let mut transcript = - PoseidonTranscript::, _, _>::new(loader, snark.proof()); - let proof = - Plonk::read_proof(svk, &snark.protocol, &instances, &mut transcript).unwrap(); - Plonk::succinct_verify(svk, &snark.protocol, &instances, &proof).unwrap() - }) - .collect_vec(); - - let KzgAccumulator { lhs, rhs } = if accumulators.len() > 1 { - let mut transcript = PoseidonTranscript::, _, _>::new(loader, as_proof); - let proof = As::read_proof(as_vk, &accumulators, &mut transcript).unwrap(); - As::verify(as_vk, &accumulators, &proof).unwrap() - } else { - accumulators.pop().unwrap() - }; - - let lhs = lhs.assigned(); - let rhs = rhs.assigned(); - - lhs.x - .truncation - .limbs - .iter() - .chain(lhs.y.truncation.limbs.iter()) - .chain(rhs.x.truncation.limbs.iter()) - .chain(rhs.y.truncation.limbs.iter()) - .chain(instances_to_expose.iter()) - .cloned() - .collect_vec() -} - -pub fn recursive_aggregate<'a, 'b>( - svk: &Svk, - loader: &Rc>, - snarks: &[SnarkWitness], - recursive_snark: &SnarkWitness, - as_vk: &AsVk, - as_proof: Value<&'_ [u8]>, - use_dummy: AssignedValue, -) -> (Vec>, Vec>>) { - let assign_instances = |instances: &[Vec>]| { - instances - .iter() - .map(|instances| { - instances.iter().map(|instance| loader.assign_scalar(*instance)).collect_vec() - }) - .collect_vec() - }; - - let mut assigned_instances = vec![]; - let mut accumulators = snarks - .iter() - .flat_map(|snark| { - let instances = assign_instances(&snark.instances); - assigned_instances.push( - instances - .iter() - .flat_map(|instance| instance.iter().map(|scalar| scalar.assigned())) - .collect_vec(), - ); - let mut transcript = - PoseidonTranscript::, _, _>::new(loader, snark.proof()); - let proof = - Plonk::read_proof(svk, &snark.protocol, &instances, &mut transcript).unwrap(); - Plonk::succinct_verify(svk, &snark.protocol, &instances, &proof).unwrap() - }) - .collect_vec(); - - let use_dummy = loader.scalar_from_assigned(use_dummy); - - let prev_instances = assign_instances(&recursive_snark.instances); - let mut accs = { - let mut transcript = - PoseidonTranscript::, _, _>::new(loader, recursive_snark.proof()); - let proof = - Plonk::read_proof(svk, &recursive_snark.protocol, &prev_instances, &mut transcript) - .unwrap(); - let mut accs = Plonk::succinct_verify_or_dummy( - svk, - &recursive_snark.protocol, - &prev_instances, - &proof, - &use_dummy, - ) - .unwrap(); - for acc in accs.iter_mut() { - (*acc).lhs = - loader.ec_point_select(&accumulators[0].lhs, &acc.lhs, &use_dummy).unwrap(); - (*acc).rhs = - loader.ec_point_select(&accumulators[0].rhs, &acc.rhs, &use_dummy).unwrap(); - } - accs - }; - accumulators.append(&mut accs); - - let KzgAccumulator { lhs, rhs } = { - let mut transcript = PoseidonTranscript::, _, _>::new(loader, as_proof); - let proof = As::read_proof(as_vk, &accumulators, &mut transcript).unwrap(); - As::verify(as_vk, &accumulators, &proof).unwrap() - }; - - let lhs = lhs.assigned(); - let rhs = rhs.assigned(); - - let mut new_instances = prev_instances - .iter() - .flat_map(|instance| instance.iter().map(|scalar| scalar.assigned())) - .collect_vec(); - for (i, acc_limb) in lhs - .x - .truncation - .limbs - .iter() - .chain(lhs.y.truncation.limbs.iter()) - .chain(rhs.x.truncation.limbs.iter()) - .chain(rhs.y.truncation.limbs.iter()) - .enumerate() - { - new_instances[i] = acc_limb.clone(); - } - (new_instances, assigned_instances) -} - -#[derive(Clone)] -pub struct AggregationCircuit { - svk: Svk, - snarks: Vec, - pub instances: Vec, - as_vk: AsVk, - as_proof: Value>, - expose_target_instances: bool, -} - -impl AggregationCircuit { - pub fn new( - params: &ParamsKZG, - snarks: impl IntoIterator, - expose_target_instances: bool, - ) -> Self { - let svk = params.get_g()[0].into(); - let snarks = snarks.into_iter().collect_vec(); - - let mut accumulators = snarks - .iter() - .flat_map(|snark| { - let mut transcript = - PoseidonTranscript::::new(snark.proof.as_slice()); - let proof = - Plonk::read_proof(&svk, &snark.protocol, &snark.instances, &mut transcript) - .unwrap(); - Plonk::succinct_verify(&svk, &snark.protocol, &snark.instances, &proof).unwrap() - }) - .collect_vec(); - - let as_pk = AsPk::new(Some((params.get_g()[0], params.get_g()[1]))); - let (accumulator, as_proof) = if accumulators.len() > 1 { - let mut transcript = PoseidonTranscript::::new(Vec::new()); - let accumulator = As::create_proof( - &as_pk, - &accumulators, - &mut transcript, - ChaCha20Rng::from_seed(Default::default()), - ) - .unwrap(); - (accumulator, Value::known(transcript.finalize())) - } else { - (accumulators.pop().unwrap(), Value::unknown()) - }; - - let KzgAccumulator { lhs, rhs } = accumulator; - let mut instances = - [lhs.x, lhs.y, rhs.x, rhs.y].map(fe_to_limbs::<_, _, LIMBS, BITS>).concat(); - if expose_target_instances { - instances.extend(snarks.iter().flat_map(|snark| snark.instances.iter().flatten())); - } - - Self { - svk, - snarks: snarks.into_iter().map_into().collect(), - instances, - as_vk: as_pk.vk(), - as_proof, - expose_target_instances, - } - } - - pub fn accumulator_indices() -> Vec<(usize, usize)> { - (0..4 * LIMBS).map(|idx| (0, idx)).collect() - } - - pub fn num_instance(&self) -> Vec { - dbg!(self.instances.len()); - vec![self.instances.len()] - } - - pub fn instances(&self) -> Vec> { - vec![self.instances.clone()] - } - - pub fn as_proof(&self) -> Value<&[u8]> { - self.as_proof.as_ref().map(Vec::as_slice) - } - - pub fn synthesize_proof( - &self, - config: Halo2VerifierCircuitConfig, - layouter: &mut impl Layouter, - instance_equalities: Vec<(usize, usize)>, - ) -> Result>, plonk::Error> { - config.base_field_config.load_lookup_table(layouter)?; - - // Need to trick layouter to skip first pass in get shape mode - let using_simple_floor_planner = true; - let mut first_pass = true; - let mut assigned_instances = None; - layouter.assign_region( - || "", - |region| { - if using_simple_floor_planner && first_pass { - first_pass = false; - return Ok(()); - } - let ctx = config.base_field_config.new_context(region); - - let loader = Halo2Loader::new(&config.base_field_config, ctx); - let instances = aggregate( - &self.svk, - &loader, - &self.snarks, - &self.as_vk, - self.as_proof(), - self.expose_target_instances, - ); - - for &(i, j) in &instance_equalities { - loader - .ctx_mut() - .region - .constrain_equal(instances[i].cell(), instances[j].cell())?; - } - // REQUIRED STEP - loader.finalize(); - assigned_instances = Some(instances); - Ok(()) - }, - )?; - Ok(assigned_instances.unwrap()) - } -} - -impl Circuit for AggregationCircuit { - type Config = Halo2VerifierCircuitConfig; - type FloorPlanner = SimpleFloorPlanner; - - fn without_witnesses(&self) -> Self { - Self { - svk: self.svk, - snarks: self.snarks.iter().map(SnarkWitness::without_witnesses).collect(), - instances: Vec::new(), - as_vk: self.as_vk, - as_proof: Value::unknown(), - expose_target_instances: self.expose_target_instances, - } - } - - fn configure(meta: &mut plonk::ConstraintSystem) -> Self::Config { - let path = std::env::var("VERIFY_CONFIG").expect("export VERIFY_CONFIG with config path"); - let params: Halo2VerifierCircuitConfigParams = serde_json::from_reader( - File::open(path.as_str()).expect(format!("{} file should exist", path).as_str()), - ) - .unwrap(); - - Halo2VerifierCircuitConfig::configure(meta, params) - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), plonk::Error> { - let config_instance = config.instance.clone(); - let assigned_instances = self.synthesize_proof(config, &mut layouter, vec![])?; - Ok({ - // TODO: use less instances by following Scroll's strategy of keeping only last bit of y coordinate - let mut layouter = layouter.namespace(|| "expose"); - for (i, assigned_instance) in assigned_instances.iter().enumerate() { - layouter.constrain_instance( - assigned_instance.cell().clone(), - config_instance, - i, - )?; - } - }) - } -} - -pub fn gen_srs(k: u32) -> ParamsKZG { - read_or_create_srs::(k, |k| { - ParamsKZG::::setup(k, ChaCha20Rng::from_seed(Default::default())) - }) -} - -pub fn gen_vk>( - params: &ParamsKZG, - circuit: &ConcreteCircuit, - name: &str, -) -> VerifyingKey { - let path = format!("./data/{}_{}.vkey", name, params.k()); - #[cfg(feature = "serialize")] - match File::open(path.as_str()) { - Ok(f) => { - let read_time = start_timer!(|| format!("Reading vkey from {}", path)); - let mut bufreader = BufReader::new(f); - let vk = VerifyingKey::read::<_, ConcreteCircuit>(&mut bufreader, params) - .expect("Reading vkey should not fail"); - end_timer!(read_time); - vk - } - Err(_) => { - let vk_time = start_timer!(|| "vkey"); - let vk = keygen_vk(params, circuit).unwrap(); - end_timer!(vk_time); - let mut f = BufWriter::new(File::create(path.as_str()).unwrap()); - println!("Writing vkey to {}", path); - vk.write(&mut f).unwrap(); - vk - } - } - #[cfg(not(feature = "serialize"))] - { - let vk_time = start_timer!(|| "vkey"); - let vk = keygen_vk(params, circuit).unwrap(); - end_timer!(vk_time); - vk - } -} - -pub fn gen_pk>( - params: &ParamsKZG, - circuit: &ConcreteCircuit, - name: &str, -) -> ProvingKey { - let path = format!("./data/{}_{}.pkey", name, params.k()); - #[cfg(feature = "serialize")] - match File::open(path.as_str()) { - Ok(f) => { - let read_time = start_timer!(|| format!("Reading pkey from {}", path)); - let mut bufreader = BufReader::new(f); - let pk = ProvingKey::read::<_, ConcreteCircuit>(&mut bufreader, params) - .expect("Reading pkey should not fail"); - end_timer!(read_time); - pk - } - Err(_) => { - let vk = gen_vk::(params, circuit, name); - let pk_time = start_timer!(|| "pkey"); - let pk = keygen_pk(params, vk, circuit).unwrap(); - end_timer!(pk_time); - let mut f = BufWriter::new(File::create(path.as_str()).unwrap()); - println!("Writing pkey to {}", path); - pk.write(&mut f).unwrap(); - pk - } - } - #[cfg(not(feature = "serialize"))] - { - let vk = gen_vk::(params, circuit, name); - let pk_time = start_timer!(|| "pkey"); - let pk = keygen_pk(params, vk, circuit).unwrap(); - end_timer!(pk_time); - pk - } -} - -pub fn read_bytes(path: &str) -> Vec { - let mut buf = vec![]; - let mut f = File::open(path).unwrap(); - f.read_to_end(&mut buf).unwrap(); - buf -} - -pub fn write_bytes(path: &str, buf: &Vec) { - let mut f = File::create(path).unwrap(); - f.write(buf).unwrap(); -} - -/// reads the instances for T::N_PROOFS circuits from file -pub fn read_instances(path: &str) -> Option>>> { - let f = File::open(path); - if let Err(_) = f { - return None; - } - let f = f.unwrap(); - let reader = BufReader::new(f); - let instances_str: Vec>> = serde_json::from_reader(reader).unwrap(); - let ret = instances_str - .into_iter() - .map(|circuit_instances| { - circuit_instances - .into_iter() - .map(|instance_column| { - instance_column - .iter() - .map(|str| { - biguint_to_fe::(&BigUint::from_str_radix(str.as_str(), 16).unwrap()) - }) - .collect_vec() - }) - .collect_vec() - }) - .collect_vec(); - Some(ret) -} - -pub fn write_instances(instances: &Vec>>, path: &str) { - let mut hex_strings = vec![]; - for circuit_instances in instances.iter() { - hex_strings.push( - circuit_instances - .iter() - .map(|instance_column| { - instance_column.iter().map(|x| fe_to_biguint(x).to_str_radix(16)).collect_vec() - }) - .collect_vec(), - ); - } - let f = BufWriter::new(File::create(path).unwrap()); - serde_json::to_writer(f, &hex_strings).unwrap(); -} - -pub trait TargetCircuit { - const N_PROOFS: usize; - - type Circuit: Circuit; - - fn name() -> String; -} - -// this is a toggle that should match the fork of halo2_proofs you are using -// the current default in PSE/main is `false`, before 2022_10_22 it was `true`: -// see https://github.com/privacy-scaling-explorations/halo2/pull/96/files -pub const KZG_QUERY_INSTANCE: bool = false; - -pub fn create_snark_shplonk( - params: &ParamsKZG, - circuits: Vec, - instances: Vec>>, // instances[i][j][..] is the i-th circuit's j-th instance column - accumulator_indices: Option>, -) -> Snark { - println!("CREATING SNARK FOR: {}", T::name()); - let config = if let Some(accumulator_indices) = accumulator_indices { - Config::kzg(KZG_QUERY_INSTANCE) - .set_zk(true) - .with_num_proof(T::N_PROOFS) - .with_accumulator_indices(accumulator_indices) - } else { - Config::kzg(KZG_QUERY_INSTANCE).set_zk(true).with_num_proof(T::N_PROOFS) - }; - - let pk = gen_pk(params, &circuits[0], T::name().as_str()); - // num_instance[i] is length of the i-th instance columns in circuit 0 (all circuits should have same shape of instances) - let num_instance = instances[0].iter().map(|instance_column| instance_column.len()).collect(); - let protocol = compile(params, pk.get_vk(), config.with_num_instance(num_instance)); - - // usual shenanigans to turn nested Vec into nested slice - let instances1: Vec> = instances - .iter() - .map(|instances| instances.iter().map(Vec::as_slice).collect_vec()) - .collect_vec(); - let instances2: Vec<&[&[Fr]]> = instances1.iter().map(Vec::as_slice).collect_vec(); - // TODO: need to cache the instances as well! - - let proof = { - let path = format!("./data/proof_{}_{}.dat", T::name(), params.k()); - let instance_path = format!("./data/instances_{}_{}.dat", T::name(), params.k()); - let cached_instances = read_instances::(instance_path.as_str()); - #[cfg(feature = "serialize")] - if cached_instances.is_some() - && Path::new(path.as_str()).exists() - && cached_instances.unwrap() == instances - { - let proof_time = start_timer!(|| "read proof"); - let mut file = File::open(path.as_str()).unwrap(); - let mut buf = vec![]; - file.read_to_end(&mut buf).unwrap(); - end_timer!(proof_time); - buf - } else { - let proof_time = start_timer!(|| "create proof"); - let mut transcript = PoseidonTranscript::, _>::init(Vec::new()); - create_proof::, ProverSHPLONK<_>, ChallengeScalar<_>, _, _, _>( - params, - &pk, - &circuits, - instances2.as_slice(), - &mut ChaCha20Rng::from_seed(Default::default()), - &mut transcript, - ) - .unwrap(); - let proof = transcript.finalize(); - let mut file = File::create(path.as_str()).unwrap(); - file.write_all(&proof).unwrap(); - write_instances(&instances, instance_path.as_str()); - end_timer!(proof_time); - proof - } - #[cfg(not(feature = "serialize"))] - { - let proof_time = start_timer!(|| "create proof"); - let mut transcript = PoseidonTranscript::, _>::init(Vec::new()); - create_proof::, ProverSHPLONK<_>, ChallengeScalar<_>, _, _, _>( - params, - &pk, - &circuits, - instances2.as_slice(), - &mut ChaCha20Rng::from_seed(Default::default()), - &mut transcript, - ) - .unwrap(); - let proof = transcript.finalize(); - end_timer!(proof_time); - proof - } - }; - - let verify_time = start_timer!(|| "verify proof"); - { - let verifier_params = params.verifier_params(); - let strategy = AccumulatorStrategy::new(verifier_params); - let mut transcript = - >, _> as TranscriptReadBuffer< - _, - _, - _, - >>::init(Cursor::new(proof.clone())); - assert!(VerificationStrategy::<_, VerifierSHPLONK<_>>::finalize( - verify_proof::<_, VerifierSHPLONK<_>, _, _, _>( - verifier_params, - pk.get_vk(), - strategy, - instances2.as_slice(), - &mut transcript, - ) - .unwrap() - )) - } - end_timer!(verify_time); - - Snark::new(protocol.clone(), instances.into_iter().flatten().collect_vec(), proof) -} diff --git a/snark-verifier/src/system/halo2/strategy.rs b/snark-verifier/src/system/halo2/strategy.rs index de66f8e3..a1523680 100644 --- a/snark-verifier/src/system/halo2/strategy.rs +++ b/snark-verifier/src/system/halo2/strategy.rs @@ -1,6 +1,9 @@ +//! Verifier strategy + pub mod ipa { - use crate::util::arithmetic::CurveAffine; - use halo2_proofs::{ + //! IPA verifier strategy + + use crate::halo2_proofs::{ plonk::Error, poly::{ commitment::MSM, @@ -13,7 +16,10 @@ pub mod ipa { VerificationStrategy, }, }; + use crate::util::arithmetic::CurveAffine; + /// Strategy that handles single proof and decide immediately, but also + /// returns `g` if the proof is valid. #[derive(Clone, Debug)] pub struct SingleStrategy<'a, C: CurveAffine> { msm: MSMIPA<'a, C>, @@ -25,9 +31,7 @@ pub mod ipa { type Output = C; fn new(params: &'a ParamsIPA) -> Self { - SingleStrategy { - msm: MSMIPA::new(params), - } + SingleStrategy { msm: MSMIPA::new(params) } } fn process( diff --git a/snark-verifier/src/system/halo2/test.rs b/snark-verifier/src/system/halo2/test.rs deleted file mode 100644 index 88a2ff26..00000000 --- a/snark-verifier/src/system/halo2/test.rs +++ /dev/null @@ -1,215 +0,0 @@ -#![allow(dead_code)] -#![allow(clippy::all)] -use crate::halo2_proofs::{ - dev::MockProver, - plonk::{create_proof, verify_proof, Circuit, ProvingKey}, - poly::{ - commitment::{CommitmentScheme, Params, ParamsProver, Prover, Verifier}, - VerificationStrategy, - }, - transcript::{EncodedChallenge, TranscriptReadBuffer, TranscriptWriterBuffer}, -}; -use crate::util::arithmetic::CurveAffine; -use rand_chacha::rand_core::RngCore; -use std::{fs, io::Cursor}; - -mod circuit; -// mod ipa; -mod kzg; - -pub use circuit::standard::StandardPlonk; - -pub fn read_or_create_srs<'a, C: CurveAffine, P: ParamsProver<'a, C>>( - dir: &str, - k: u32, - setup: impl Fn(u32) -> P, -) -> P { - let path = format!("{dir}/k-{k}.srs"); - match fs::File::open(path.as_str()) { - Ok(mut file) => P::read(&mut file).unwrap(), - Err(_) => { - fs::create_dir_all(dir).unwrap(); - let params = setup(k); - params.write(&mut fs::File::create(path).unwrap()).unwrap(); - params - } - } -} - -pub fn create_proof_checked<'a, S, C, P, V, VS, TW, TR, EC, R>( - params: &'a S::ParamsProver, - pk: &ProvingKey, - circuits: &[C], - instances: &[&[&[S::Scalar]]], - mut rng: R, - finalize: impl Fn(Vec, VS::Output) -> Vec, -) -> Vec -where - S: CommitmentScheme, - S::ParamsVerifier: 'a, - C: Circuit, - P: Prover<'a, S>, - V: Verifier<'a, S>, - VS: VerificationStrategy<'a, S, V>, - TW: TranscriptWriterBuffer, S::Curve, EC>, - TR: TranscriptReadBuffer>, S::Curve, EC>, - EC: EncodedChallenge, - R: RngCore + Send, -{ - for (circuit, instances) in circuits.iter().zip(instances.iter()) { - MockProver::run( - params.k(), - circuit, - instances.iter().map(|instance| instance.to_vec()).collect(), - ) - .unwrap() - .assert_satisfied(); - } - - let proof = { - let mut transcript = TW::init(Vec::new()); - create_proof::( - params, - pk, - circuits, - instances, - &mut rng, - &mut transcript, - ) - .unwrap(); - transcript.finalize() - }; - let output = { - let params = params.verifier_params(); - let strategy = VS::new(params); - let mut transcript = TR::init(Cursor::new(proof.clone())); - verify_proof(params, pk.get_vk(), strategy, instances, &mut transcript).unwrap() - }; - - finalize(proof, output) -} - -macro_rules! halo2_prepare { - ($dir:expr, $k:expr, $setup:expr, $config:expr, $create_circuit:expr) => {{ - use $crate::halo2_proofs::plonk::{keygen_pk, keygen_vk}; - use std::iter; - use $crate::{ - system::halo2::{compile, test::read_or_create_srs}, - util::{Itertools}, - }; - - let params = read_or_create_srs($dir, $k, $setup); - - let circuits = iter::repeat_with(|| $create_circuit) - .take($config.num_proof) - .collect_vec(); - - let pk = if $config.zk { - let vk = keygen_vk(¶ms, &circuits[0]).unwrap(); - let pk = keygen_pk(¶ms, vk, &circuits[0]).unwrap(); - pk - } else { - // TODO: Re-enable optional-zk when it's merged in pse/halo2. - unimplemented!() - }; - - let num_instance = circuits[0] - .instances() - .iter() - .map(|instances| instances.len()) - .collect(); - let protocol = compile( - ¶ms, - pk.get_vk(), - $config.with_num_instance(num_instance), - ); - - /* assert fails when fixed column is all 0s - assert_eq!( - protocol.preprocessed.len(), - protocol - .preprocessed - .iter() - .map( - |ec_point| <[u8; 32]>::try_from(ec_point.to_bytes().as_ref().to_vec()).unwrap() - ) - .unique() - .count() - ); - */ - - (params, pk, protocol, circuits) - }}; -} - -macro_rules! halo2_create_snark { - ( - $commitment_scheme:ty, - $prover:ty, - $verifier:ty, - $verification_strategy:ty, - $transcript_read:ty, - $transcript_write:ty, - $encoded_challenge:ty, - $finalize:expr, - $params:expr, - $pk:expr, - $protocol:expr, - $circuits:expr - ) => {{ - use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; - use $crate::{ - loader::halo2::test::Snark, system::halo2::test::create_proof_checked, util::Itertools, - }; - - let instances = $circuits.iter().map(|circuit| circuit.instances()).collect_vec(); - let proof = { - #[allow(clippy::needless_borrow)] - let instances = instances - .iter() - .map(|instances| instances.iter().map(Vec::as_slice).collect_vec()) - .collect_vec(); - let instances = instances.iter().map(Vec::as_slice).collect_vec(); - create_proof_checked::< - $commitment_scheme, - _, - $prover, - $verifier, - $verification_strategy, - $transcript_read, - $transcript_write, - $encoded_challenge, - _, - >( - $params, - $pk, - $circuits, - &instances, - &mut ChaCha20Rng::from_seed(Default::default()), - $finalize, - ) - }; - - Snark::new($protocol.clone(), instances.into_iter().flatten().collect_vec(), proof) - }}; -} - -macro_rules! halo2_native_verify { - ( - $plonk_verifier:ty, - $params:expr, - $protocol:expr, - $instances:expr, - $transcript:expr, - $svk:expr, - $dk:expr - ) => {{ - use $crate::halo2_proofs::poly::commitment::ParamsProver; - use $crate::verifier::PlonkVerifier; - - let proof = <$plonk_verifier>::read_proof($svk, $protocol, $instances, $transcript); - assert!(<$plonk_verifier>::verify($svk, $dk, $protocol, $instances, &proof)) - }}; -} - -pub(crate) use {halo2_create_snark, halo2_native_verify, halo2_prepare}; diff --git a/snark-verifier/src/system/halo2/test/circuit.rs b/snark-verifier/src/system/halo2/test/circuit.rs deleted file mode 100644 index ab713995..00000000 --- a/snark-verifier/src/system/halo2/test/circuit.rs +++ /dev/null @@ -1,2 +0,0 @@ -// pub mod maingate; -pub mod standard; diff --git a/snark-verifier/src/system/halo2/test/circuit/maingate.rs b/snark-verifier/src/system/halo2/test/circuit/maingate.rs deleted file mode 100644 index 82d63b5e..00000000 --- a/snark-verifier/src/system/halo2/test/circuit/maingate.rs +++ /dev/null @@ -1,111 +0,0 @@ -use crate::util::arithmetic::{CurveAffine, FieldExt}; -use halo2_proofs::{ - circuit::{floor_planner::V1, Layouter, Value}, - plonk::{Circuit, ConstraintSystem, Error}, -}; -use halo2_wrong_ecc::{ - maingate::{ - MainGate, MainGateConfig, MainGateInstructions, RangeChip, RangeConfig, RangeInstructions, - RegionCtx, - }, - BaseFieldEccChip, EccConfig, -}; -use rand::RngCore; - -#[derive(Clone)] -pub struct MainGateWithRangeConfig { - main_gate_config: MainGateConfig, - range_config: RangeConfig, -} - -impl MainGateWithRangeConfig { - pub fn configure( - meta: &mut ConstraintSystem, - composition_bits: Vec, - overflow_bits: Vec, - ) -> Self { - let main_gate_config = MainGate::::configure(meta); - let range_config = - RangeChip::::configure(meta, &main_gate_config, composition_bits, overflow_bits); - MainGateWithRangeConfig { - main_gate_config, - range_config, - } - } - - pub fn main_gate(&self) -> MainGate { - MainGate::new(self.main_gate_config.clone()) - } - - pub fn range_chip(&self) -> RangeChip { - RangeChip::new(self.range_config.clone()) - } - - pub fn ecc_chip( - &self, - ) -> BaseFieldEccChip { - BaseFieldEccChip::new(EccConfig::new( - self.range_config.clone(), - self.main_gate_config.clone(), - )) - } -} - -#[derive(Clone, Default)] -pub struct MainGateWithRange(Vec); - -impl MainGateWithRange { - pub fn new(inner: Vec) -> Self { - Self(inner) - } - - pub fn rand(mut rng: R) -> Self { - Self::new(vec![F::from(rng.next_u32() as u64)]) - } - - pub fn instances(&self) -> Vec> { - vec![self.0.clone()] - } -} - -impl Circuit for MainGateWithRange { - type Config = MainGateWithRangeConfig; - type FloorPlanner = V1; - - fn without_witnesses(&self) -> Self { - Self(vec![F::zero()]) - } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - MainGateWithRangeConfig::configure(meta, vec![8], vec![4, 7]) - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - let main_gate = config.main_gate(); - let range_chip = config.range_chip(); - range_chip.load_table(&mut layouter)?; - - let a = layouter.assign_region( - || "", - |region| { - let mut ctx = RegionCtx::new(region, 0); - range_chip.decompose(&mut ctx, Value::known(F::from(u64::MAX)), 8, 64)?; - range_chip.decompose(&mut ctx, Value::known(F::from(u32::MAX as u64)), 8, 39)?; - let a = range_chip.assign(&mut ctx, Value::known(self.0[0]), 8, 68)?; - let b = main_gate.sub_sub_with_constant(&mut ctx, &a, &a, &a, F::from(2))?; - let cond = main_gate.assign_bit(&mut ctx, Value::known(F::one()))?; - main_gate.select(&mut ctx, &a, &b, &cond)?; - - Ok(a) - }, - )?; - - main_gate.expose_public(layouter, a, 0)?; - - Ok(()) - } -} diff --git a/snark-verifier/src/system/halo2/test/circuit/standard.rs b/snark-verifier/src/system/halo2/test/circuit/standard.rs deleted file mode 100644 index bfa94df4..00000000 --- a/snark-verifier/src/system/halo2/test/circuit/standard.rs +++ /dev/null @@ -1,146 +0,0 @@ -use crate::halo2_proofs::{ - circuit::{Layouter, SimpleFloorPlanner, Value}, - plonk::{Advice, Circuit, Column, ConstraintSystem, Error, Fixed, Instance}, - poly::Rotation, -}; -use crate::util::arithmetic::FieldExt; -use halo2_base::halo2_proofs::plonk::Assigned; -use rand::RngCore; - -#[allow(dead_code)] -#[derive(Clone)] -pub struct StandardPlonkConfig { - a: Column, - b: Column, - c: Column, - q_a: Column, - q_b: Column, - q_c: Column, - q_ab: Column, - constant: Column, - instance: Column, -} - -impl StandardPlonkConfig { - pub fn configure(meta: &mut ConstraintSystem) -> Self { - let [a, b, c] = [(); 3].map(|_| meta.advice_column()); - let [q_a, q_b, q_c, q_ab, constant] = [(); 5].map(|_| meta.fixed_column()); - let instance = meta.instance_column(); - - [a, b, c].map(|column| meta.enable_equality(column)); - - meta.create_gate( - "q_a·a + q_b·b + q_c·c + q_ab·a·b + constant + instance = 0", - |meta| { - let [a, b, c] = [a, b, c].map(|column| meta.query_advice(column, Rotation::cur())); - let [q_a, q_b, q_c, q_ab, constant] = [q_a, q_b, q_c, q_ab, constant] - .map(|column| meta.query_fixed(column, Rotation::cur())); - let instance = meta.query_instance(instance, Rotation::cur()); - Some( - q_a * a.clone() - + q_b * b.clone() - + q_c * c - + q_ab * a * b - + constant - + instance, - ) - }, - ); - - StandardPlonkConfig { a, b, c, q_a, q_b, q_c, q_ab, constant, instance } - } -} - -#[derive(Clone, Default)] -pub struct StandardPlonk(F); - -impl StandardPlonk { - pub fn rand(mut rng: R) -> Self { - Self(F::from(rng.next_u32() as u64)) - } - - pub fn instances(&self) -> Vec> { - vec![vec![self.0]] - } -} - -impl Circuit for StandardPlonk { - type Config = StandardPlonkConfig; - type FloorPlanner = SimpleFloorPlanner; - - fn without_witnesses(&self) -> Self { - Self::default() - } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - meta.set_minimum_degree(4); - StandardPlonkConfig::configure(meta) - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - layouter.assign_region( - || "", - |mut region| { - #[cfg(feature = "halo2-pse")] - { - region.assign_advice(|| "", config.a, 0, || Value::known(self.0))?; - region.assign_fixed(|| "", config.q_a, 0, || Value::known(-F::one()))?; - - region.assign_advice(|| "", config.a, 1, || Value::known(-F::from(5u64)))?; - for (idx, column) in (1..).zip([ - config.q_a, - config.q_b, - config.q_c, - config.q_ab, - config.constant, - ]) { - region.assign_fixed( - || "", - column, - 1, - || Value::known(F::from(idx as u64)), - )?; - } - - let a = region.assign_advice(|| "", config.a, 2, || Value::known(F::one()))?; - a.copy_advice(|| "", &mut region, config.b, 3)?; - a.copy_advice(|| "", &mut region, config.c, 4)?; - } - #[cfg(feature = "halo2-axiom")] - { - region.assign_advice(config.a, 0, Value::known(Assigned::Trivial(self.0)))?; - region.assign_fixed(config.q_a, 0, Assigned::Trivial(-F::one())); - - region.assign_advice( - config.a, - 1, - Value::known(Assigned::Trivial(-F::from(5u64))), - )?; - for (idx, column) in (1..).zip([ - config.q_a, - config.q_b, - config.q_c, - config.q_ab, - config.constant, - ]) { - region.assign_fixed(column, 1, Assigned::Trivial(F::from(idx as u64))); - } - - let a = region.assign_advice( - config.a, - 2, - Value::known(Assigned::Trivial(F::one())), - )?; - a.copy_advice(&mut region, config.b, 3); - a.copy_advice(&mut region, config.c, 4); - } - - Ok(()) - }, - ) - } -} diff --git a/snark-verifier/src/system/halo2/test/ipa.rs b/snark-verifier/src/system/halo2/test/ipa.rs deleted file mode 100644 index 07fd6efd..00000000 --- a/snark-verifier/src/system/halo2/test/ipa.rs +++ /dev/null @@ -1,143 +0,0 @@ -use crate::util::arithmetic::CurveAffine; -use halo2_proofs::poly::{ - commitment::{Params, ParamsProver}, - ipa::commitment::ParamsIPA, -}; -use std::mem::size_of; - -mod native; - -pub const TESTDATA_DIR: &str = "./src/system/halo2/test/ipa/testdata"; - -pub fn setup(k: u32) -> ParamsIPA { - ParamsIPA::new(k) -} - -pub fn w_u() -> (C, C) { - let mut buf = Vec::new(); - setup::(1).write(&mut buf).unwrap(); - - let repr = C::Repr::default(); - let repr_len = repr.as_ref().len(); - let offset = size_of::() + 4 * repr_len; - - let [w, u] = [offset, offset + repr_len].map(|offset| { - let mut repr = C::Repr::default(); - repr.as_mut() - .copy_from_slice(&buf[offset..offset + repr_len]); - C::from_bytes(&repr).unwrap() - }); - - (w, u) -} - -macro_rules! halo2_ipa_config { - ($zk:expr, $num_proof:expr) => { - $crate::system::halo2::Config::ipa() - .set_zk($zk) - .with_num_proof($num_proof) - }; - ($zk:expr, $num_proof:expr, $accumulator_indices:expr) => { - $crate::system::halo2::Config::ipa() - .set_zk($zk) - .with_num_proof($num_proof) - .with_accumulator_indices($accumulator_indices) - }; -} - -macro_rules! halo2_ipa_prepare { - ($dir:expr, $curve:path, $k:expr, $config:expr, $create_circuit:expr) => {{ - use $crate::system::halo2::test::{halo2_prepare, ipa::setup}; - - halo2_prepare!($dir, $k, setup::<$curve>, $config, $create_circuit) - }}; - (pallas::Affine, $k:expr, $config:expr, $create_circuit:expr) => {{ - use halo2_curves::pasta::pallas; - use $crate::system::halo2::test::ipa::TESTDATA_DIR; - - halo2_ipa_prepare!( - &format!("{TESTDATA_DIR}/pallas"), - pallas::Affine, - $k, - $config, - $create_circuit - ) - }}; - (vesta::Affine, $k:expr, $config:expr, $create_circuit:expr) => {{ - use halo2_curves::pasta::vesta; - use $crate::system::halo2::test::ipa::TESTDATA_DIR; - - halo2_ipa_prepare!( - &format!("{TESTDATA_DIR}/vesta"), - vesta::Affine, - $k, - $config, - $create_circuit - ) - }}; -} - -macro_rules! halo2_ipa_create_snark { - ( - $prover:ty, - $verifier:ty, - $transcript_read:ty, - $transcript_write:ty, - $encoded_challenge:ty, - $params:expr, - $pk:expr, - $protocol:expr, - $circuits:expr - ) => {{ - use halo2_proofs::poly::ipa::commitment::IPACommitmentScheme; - use $crate::{ - system::halo2::{strategy::ipa::SingleStrategy, test::halo2_create_snark}, - util::arithmetic::GroupEncoding, - }; - - halo2_create_snark!( - IPACommitmentScheme<_>, - $prover, - $verifier, - SingleStrategy<_>, - $transcript_read, - $transcript_write, - $encoded_challenge, - |proof, g| { [proof, g.to_bytes().as_ref().to_vec()].concat() }, - $params, - $pk, - $protocol, - $circuits - ) - }}; -} - -macro_rules! halo2_ipa_native_verify { - ( - $plonk_verifier:ty, - $params:expr, - $protocol:expr, - $instances:expr, - $transcript:expr - ) => {{ - use $crate::{ - pcs::ipa::{Bgh19SuccinctVerifyingKey, IpaDecidingKey}, - system::halo2::test::{halo2_native_verify, ipa::w_u}, - }; - - let (w, u) = w_u(); - halo2_native_verify!( - $plonk_verifier, - $params, - $protocol, - $instances, - $transcript, - &Bgh19SuccinctVerifyingKey::new($protocol.domain.clone(), $params.get_g()[0], w, u), - &IpaDecidingKey::new($params.get_g().to_vec()) - ) - }}; -} - -pub(crate) use { - halo2_ipa_config, halo2_ipa_create_snark, halo2_ipa_native_verify, halo2_ipa_prepare, -}; diff --git a/snark-verifier/src/system/halo2/test/ipa/native.rs b/snark-verifier/src/system/halo2/test/ipa/native.rs deleted file mode 100644 index 7d9e09bb..00000000 --- a/snark-verifier/src/system/halo2/test/ipa/native.rs +++ /dev/null @@ -1,59 +0,0 @@ -use crate::{ - pcs::ipa::{Bgh19, Ipa}, - system::halo2::test::ipa::{ - halo2_ipa_config, halo2_ipa_create_snark, halo2_ipa_native_verify, halo2_ipa_prepare, - }, - system::halo2::test::StandardPlonk, - verifier::Plonk, -}; -use halo2_curves::pasta::pallas; -use halo2_proofs::{ - poly::ipa::multiopen::{ProverIPA, VerifierIPA}, - transcript::{Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer}, -}; -use paste::paste; -use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; - -macro_rules! test { - (@ $name:ident, $k:expr, $config:expr, $create_cirucit:expr, $prover:ty, $verifier:ty, $plonk_verifier:ty) => { - paste! { - #[test] - fn []() { - let (params, pk, protocol, circuits) = halo2_ipa_prepare!( - pallas::Affine, - $k, - $config, - $create_cirucit - ); - let snark = halo2_ipa_create_snark!( - $prover, - $verifier, - Blake2bWrite<_, _, _>, - Blake2bRead<_, _, _>, - Challenge255<_>, - ¶ms, - &pk, - &protocol, - &circuits - ); - halo2_ipa_native_verify!( - $plonk_verifier, - params, - &snark.protocol, - &snark.instances, - &mut Blake2bRead::<_, pallas::Affine, _>::init(snark.proof.as_slice()) - ); - } - } - }; - ($name:ident, $k:expr, $config:expr, $create_cirucit:expr) => { - test!(@ $name, $k, $config, $create_cirucit, ProverIPA, VerifierIPA, Plonk::>); - } -} - -test!( - zk_standard_plonk_rand, - 9, - halo2_ipa_config!(true, 1), - StandardPlonk::rand(ChaCha20Rng::from_seed(Default::default())) -); diff --git a/snark-verifier/src/system/halo2/test/kzg.rs b/snark-verifier/src/system/halo2/test/kzg.rs deleted file mode 100644 index 6cf145db..00000000 --- a/snark-verifier/src/system/halo2/test/kzg.rs +++ /dev/null @@ -1,106 +0,0 @@ -use crate::halo2_proofs::poly::kzg::commitment::ParamsKZG; -use crate::util::arithmetic::MultiMillerLoop; -use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; - -mod native; - -#[cfg(feature = "loader_evm")] -mod evm; - -#[cfg(feature = "loader_halo2")] -pub(crate) mod halo2; - -#[allow(dead_code)] -pub const TESTDATA_DIR: &str = "./src/system/halo2/test/data"; - -pub const LIMBS: usize = 3; -pub const BITS: usize = 88; - -pub fn setup(k: u32) -> ParamsKZG { - ParamsKZG::::setup(k, ChaCha20Rng::from_seed(Default::default())) -} - -macro_rules! halo2_kzg_config { - ($zk:expr, $num_proof:expr) => { - $crate::system::halo2::Config::kzg().set_zk($zk).with_num_proof($num_proof) - }; - ($zk:expr, $num_proof:expr, $accumulator_indices:expr) => { - $crate::system::halo2::Config::kzg() - .set_zk($zk) - .with_num_proof($num_proof) - .with_accumulator_indices($accumulator_indices) - }; -} - -macro_rules! halo2_kzg_prepare { - ($k:expr, $config:expr, $create_circuit:expr) => {{ - use $crate::halo2_curves::bn256::Bn256; - #[allow(unused_imports)] - use $crate::system::halo2::test::{ - halo2_prepare, - kzg::{setup, TESTDATA_DIR}, - }; - - halo2_prepare!(TESTDATA_DIR, $k, setup::, $config, $create_circuit) - }}; -} - -macro_rules! halo2_kzg_create_snark { - ( - $prover:ty, - $verifier:ty, - $transcript_read:ty, - $transcript_write:ty, - $encoded_challenge:ty, - $params:expr, - $pk:expr, - $protocol:expr, - $circuits:expr - ) => {{ - use $crate::halo2_proofs::poly::kzg::{ - commitment::KZGCommitmentScheme, strategy::SingleStrategy, - }; - use $crate::system::halo2::test::halo2_create_snark; - - halo2_create_snark!( - KZGCommitmentScheme<_>, - $prover, - $verifier, - SingleStrategy<_>, - $transcript_read, - $transcript_write, - $encoded_challenge, - |proof, _| proof, - $params, - $pk, - $protocol, - $circuits - ) - }}; -} - -macro_rules! halo2_kzg_native_verify { - ( - $plonk_verifier:ty, - $params:expr, - $protocol:expr, - $instances:expr, - $transcript:expr - ) => {{ - use $crate::system::halo2::test::halo2_native_verify; - - halo2_native_verify!( - $plonk_verifier, - $params, - $protocol, - $instances, - $transcript, - &$params.get_g()[0].into(), - &($params.g2(), $params.s_g2()).into() - ) - }}; -} - -pub(crate) use { - halo2_kzg_config, halo2_kzg_create_snark, halo2_kzg_native_verify, halo2_kzg_prepare, -}; diff --git a/snark-verifier/src/system/halo2/test/kzg/evm.rs b/snark-verifier/src/system/halo2/test/kzg/evm.rs deleted file mode 100644 index 80439205..00000000 --- a/snark-verifier/src/system/halo2/test/kzg/evm.rs +++ /dev/null @@ -1,137 +0,0 @@ -use crate::{halo2_curves, halo2_proofs}; -use crate::{ - loader::native::NativeLoader, - pcs::kzg::{Bdfg21, Gwc19, Kzg, LimbsEncoding}, - system::halo2::{ - test::{ - kzg::{ - self, halo2_kzg_config, halo2_kzg_create_snark, halo2_kzg_native_verify, - halo2_kzg_prepare, BITS, LIMBS, - }, - StandardPlonk, - }, - transcript::evm::{ChallengeEvm, EvmTranscript}, - }, - verifier::Plonk, -}; -use halo2_curves::bn256::{Bn256, G1Affine}; -use halo2_proofs::poly::kzg::multiopen::{ProverGWC, ProverSHPLONK, VerifierGWC, VerifierSHPLONK}; -use paste::paste; -use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; - -macro_rules! halo2_kzg_evm_verify { - ($plonk_verifier:ty, $params:expr, $protocol:expr, $instances:expr, $proof:expr) => {{ - use halo2_curves::bn256::{Bn256, Fq, Fr}; - use halo2_proofs::poly::commitment::ParamsProver; - use std::rc::Rc; - use $crate::{ - loader::evm::{compile_yul, encode_calldata, execute, EvmLoader}, - system::halo2::{ - test::kzg::{BITS, LIMBS}, - transcript::evm::EvmTranscript, - }, - util::Itertools, - verifier::PlonkVerifier, - }; - - let loader = EvmLoader::new::(); - let deployment_code = { - let svk = $params.get_g()[0].into(); - let dk = ($params.g2(), $params.s_g2()).into(); - let protocol = $protocol.loaded(&loader); - let mut transcript = EvmTranscript::<_, Rc, _, _>::new(&loader); - let instances = transcript - .load_instances($instances.iter().map(|instances| instances.len()).collect_vec()); - let proof = <$plonk_verifier>::read_proof(&svk, &protocol, &instances, &mut transcript); - <$plonk_verifier>::verify(&svk, &dk, &protocol, &instances, &proof); - - compile_yul(&loader.yul_code()) - }; - - let (accept, total_cost, costs) = - execute(deployment_code, encode_calldata($instances, &$proof)); - - loader.print_gas_metering(costs); - println!("Total gas cost: {}", total_cost); - - assert!(accept); - }}; -} - -macro_rules! test { - (@ $(#[$attr:meta],)* $prefix:ident, $name:ident, $k:expr, $config:expr, $create_circuit:expr, $prover:ty, $verifier:ty, $plonk_verifier:ty) => { - paste! { - $(#[$attr])* - fn []() { - let (params, pk, protocol, circuits) = halo2_kzg_prepare!( - $k, - $config, - $create_circuit - ); - let snark = halo2_kzg_create_snark!( - $prover, - $verifier, - EvmTranscript, - EvmTranscript, - ChallengeEvm<_>, - ¶ms, - &pk, - &protocol, - &circuits - ); - halo2_kzg_native_verify!( - $plonk_verifier, - params, - &snark.protocol, - &snark.instances, - &mut EvmTranscript::<_, NativeLoader, _, _>::new(snark.proof.as_slice()) - ); - halo2_kzg_evm_verify!( - $plonk_verifier, - params, - &snark.protocol, - &snark.instances, - snark.proof - ); - } - } - }; - ($name:ident, $k:expr, $config:expr, $create_circuit:expr) => { - test!(@ #[test], shplonk, $name, $k, $config, $create_circuit, ProverSHPLONK<_>, VerifierSHPLONK<_>, Plonk, LimbsEncoding>); - test!(@ #[test], plonk, $name, $k, $config, $create_circuit, ProverGWC<_>, VerifierGWC<_>, Plonk, LimbsEncoding>); - }; - ($(#[$attr:meta],)* $name:ident, $k:expr, $config:expr, $create_circuit:expr) => { - test!(@ #[test] $(,#[$attr])*, plonk, $name, $k, $config, $create_circuit, ProverGWC<_>, VerifierGWC<_>, Plonk, LimbsEncoding>); - }; -} - -test!( - zk_standard_plonk_rand, - 9, - halo2_kzg_config!(true, 1), - StandardPlonk::rand(ChaCha20Rng::from_seed(Default::default())) -); -/* -test!( - zk_main_gate_with_range_with_mock_kzg_accumulator, - 9, - halo2_kzg_config!(true, 1, (0..4 * LIMBS).map(|idx| (0, idx)).collect()), - main_gate_with_range_with_mock_kzg_accumulator::() -); -*/ -test!( - #[cfg(feature = "loader_halo2")], - #[ignore = "cause it requires 32GB memory to run"], - zk_accumulation_two_snark, - 22, - halo2_kzg_config!(true, 1, Some((0..4 * LIMBS).map(|idx| (0, idx)).collect())), - kzg::halo2::Accumulation::two_snark() -); -test!( - #[cfg(feature = "loader_halo2")], - #[ignore = "cause it requires 32GB memory to run"], - zk_accumulation_two_snark_with_accumulator, - 22, - halo2_kzg_config!(true, 1, Some((0..4 * LIMBS).map(|idx| (0, idx)).collect())), - kzg::halo2::Accumulation::two_snark_with_accumulator() -); diff --git a/snark-verifier/src/system/halo2/test/kzg/halo2.rs b/snark-verifier/src/system/halo2/test/kzg/halo2.rs deleted file mode 100644 index 9090a0a1..00000000 --- a/snark-verifier/src/system/halo2/test/kzg/halo2.rs +++ /dev/null @@ -1,618 +0,0 @@ -use crate::halo2_curves::bn256::{Bn256, Fq, Fr, G1Affine}; -use crate::halo2_proofs::{ - circuit::{Layouter, SimpleFloorPlanner, Value}, - plonk::{self, create_proof, verify_proof, Circuit, Column, ConstraintSystem, Instance}, - poly::{ - commitment::ParamsProver, - kzg::{ - commitment::{KZGCommitmentScheme, ParamsKZG}, - multiopen::{ProverSHPLONK, VerifierSHPLONK}, - strategy::SingleStrategy, - }, - }, - transcript::{ - Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer, TranscriptWriterBuffer, - }, -}; -use crate::{ - loader::{ - self, - halo2::test::{Snark, SnarkWitness}, - native::NativeLoader, - }, - pcs::{ - kzg::{ - Bdfg21, Kzg, KzgAccumulator, KzgAs, KzgAsProvingKey, KzgAsVerifyingKey, - KzgSuccinctVerifyingKey, LimbsEncoding, - }, - AccumulationScheme, AccumulationSchemeProver, - }, - system::halo2::{ - test::{ - kzg::{ - halo2_kzg_config, halo2_kzg_create_snark, halo2_kzg_native_verify, - halo2_kzg_prepare, BITS, LIMBS, - }, - StandardPlonk, - }, - transcript::halo2::{ChallengeScalar, PoseidonTranscript as GenericPoseidonTranscript}, - }, - util::{arithmetic::fe_to_limbs, Itertools}, - verifier::{self, PlonkVerifier}, -}; -use ark_std::{end_timer, start_timer}; -use halo2_base::{Context, ContextParams}; -use halo2_ecc::ecc::EccChip; -use halo2_ecc::fields::fp::FpConfig; -use paste::paste; -use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; -use serde::{Deserialize, Serialize}; -use std::fs::File; -use std::{ - io::{Cursor, Read, Write}, - rc::Rc, -}; - -const T: usize = 5; -const RATE: usize = 4; -const R_F: usize = 8; -const R_P: usize = 60; - -type BaseFieldEccChip = halo2_ecc::ecc::BaseFieldEccChip; -type Halo2Loader<'a> = loader::halo2::Halo2Loader<'a, G1Affine, BaseFieldEccChip>; -type PoseidonTranscript = GenericPoseidonTranscript; - -type Pcs = Kzg; -type Svk = KzgSuccinctVerifyingKey; -type As = KzgAs; -type AsPk = KzgAsProvingKey; -type AsVk = KzgAsVerifyingKey; -type Plonk = verifier::Plonk>; - -// for tuning the circuit -#[derive(Serialize, Deserialize)] -pub struct Halo2VerifierCircuitConfigParams { - pub strategy: halo2_ecc::fields::fp::FpStrategy, - pub degree: u32, - pub num_advice: usize, - pub num_lookup_advice: usize, - pub num_fixed: usize, - pub lookup_bits: usize, - pub limb_bits: usize, - pub num_limbs: usize, -} - -pub fn load_verify_circuit_degree() -> u32 { - let path = "./configs/verify_circuit.config"; - let params: Halo2VerifierCircuitConfigParams = - serde_json::from_reader(File::open(path).unwrap_or_else(|err| panic!("{err:?}"))).unwrap(); - params.degree -} - -#[derive(Clone)] -pub struct Halo2VerifierCircuitConfig { - pub base_field_config: halo2_ecc::fields::fp::FpConfig, - pub instance: Column, -} - -impl Halo2VerifierCircuitConfig { - pub fn configure( - meta: &mut ConstraintSystem, - params: Halo2VerifierCircuitConfigParams, - ) -> Self { - assert!( - params.limb_bits == BITS && params.num_limbs == LIMBS, - "For now we fix limb_bits = {}, otherwise change code", - BITS - ); - let base_field_config = halo2_ecc::fields::fp::FpConfig::configure( - meta, - params.strategy, - &[params.num_advice], - &[params.num_lookup_advice], - params.num_fixed, - params.lookup_bits, - params.limb_bits, - params.num_limbs, - halo2_base::utils::modulus::(), - 0, - params.degree as usize, - ); - - let instance = meta.instance_column(); - meta.enable_equality(instance); - - Self { base_field_config, instance } - } -} - -pub fn accumulate<'a>( - svk: &Svk, - loader: &Rc>, - snarks: &[SnarkWitness], - as_vk: &AsVk, - as_proof: Value<&'_ [u8]>, -) -> KzgAccumulator>> { - let assign_instances = |instances: &[Vec>]| { - instances - .iter() - .map(|instances| { - instances.iter().map(|instance| loader.assign_scalar(*instance)).collect_vec() - }) - .collect_vec() - }; - - let mut accumulators = snarks - .iter() - .flat_map(|snark| { - let protocol = snark.protocol.loaded(loader); - let instances = assign_instances(&snark.instances); - let mut transcript = - PoseidonTranscript::, _>::new(loader, snark.proof()); - let proof = Plonk::read_proof(svk, &protocol, &instances, &mut transcript); - Plonk::succinct_verify(svk, &protocol, &instances, &proof) - }) - .collect_vec(); - - let acccumulator = if accumulators.len() > 1 { - let mut transcript = PoseidonTranscript::, _>::new(loader, as_proof); - let proof = As::read_proof(as_vk, &accumulators, &mut transcript).unwrap(); - As::verify(as_vk, &accumulators, &proof).unwrap() - } else { - accumulators.pop().unwrap() - }; - - acccumulator -} - -pub struct Accumulation { - svk: Svk, - snarks: Vec>, - instances: Vec, - as_vk: AsVk, - as_proof: Value>, -} - -impl Accumulation { - pub fn accumulator_indices() -> Vec<(usize, usize)> { - (0..4 * LIMBS).map(|idx| (0, idx)).collect() - } - - pub fn new( - params: &ParamsKZG, - snarks: impl IntoIterator>, - ) -> Self { - let svk = params.get_g()[0].into(); - let snarks = snarks.into_iter().collect_vec(); - - let mut accumulators = snarks - .iter() - .flat_map(|snark| { - let mut transcript = - PoseidonTranscript::::new(snark.proof.as_slice()); - let proof = - Plonk::read_proof(&svk, &snark.protocol, &snark.instances, &mut transcript); - Plonk::succinct_verify(&svk, &snark.protocol, &snark.instances, &proof) - }) - .collect_vec(); - - let as_pk = AsPk::new(Some((params.get_g()[0], params.get_g()[1]))); - let (accumulator, as_proof) = if accumulators.len() > 1 { - let mut transcript = PoseidonTranscript::::new(Vec::new()); - let accumulator = As::create_proof( - &as_pk, - &accumulators, - &mut transcript, - ChaCha20Rng::from_seed(Default::default()), - ) - .unwrap(); - (accumulator, Value::known(transcript.finalize())) - } else { - (accumulators.pop().unwrap(), Value::unknown()) - }; - - let KzgAccumulator { lhs, rhs } = accumulator; - let instances = [lhs.x, lhs.y, rhs.x, rhs.y].map(fe_to_limbs::<_, _, LIMBS, BITS>).concat(); - - Self { - svk, - snarks: snarks.into_iter().map_into().collect(), - instances, - as_vk: as_pk.vk(), - as_proof, - } - } - - pub fn two_snark() -> Self { - let (params, snark1) = { - const K: u32 = 9; - let (params, pk, protocol, circuits) = halo2_kzg_prepare!( - K, - halo2_kzg_config!(true, 1), - StandardPlonk::<_>::rand(ChaCha20Rng::from_seed(Default::default())) - ); - let snark = halo2_kzg_create_snark!( - ProverSHPLONK<_>, - VerifierSHPLONK<_>, - PoseidonTranscript<_, _>, - PoseidonTranscript<_, _>, - ChallengeScalar<_>, - ¶ms, - &pk, - &protocol, - &circuits - ); - (params, snark) - }; - let snark2 = { - const K: u32 = 9; - let (params, pk, protocol, circuits) = halo2_kzg_prepare!( - K, - halo2_kzg_config!(true, 1), - StandardPlonk::<_>::rand(ChaCha20Rng::from_seed(Default::default())) - ); - halo2_kzg_create_snark!( - ProverSHPLONK<_>, - VerifierSHPLONK<_>, - PoseidonTranscript<_, _>, - PoseidonTranscript<_, _>, - ChallengeScalar<_>, - ¶ms, - &pk, - &protocol, - &circuits - ) - }; - Self::new(¶ms, [snark1, snark2]) - } - - pub fn two_snark_with_accumulator() -> Self { - let (params, pk, protocol, circuits) = { - const K: u32 = 22; - halo2_kzg_prepare!( - K, - halo2_kzg_config!(true, 2, Some(Self::accumulator_indices())), - Self::two_snark() - ) - }; - let snark = halo2_kzg_create_snark!( - ProverSHPLONK<_>, - VerifierSHPLONK<_>, - PoseidonTranscript<_, _>, - PoseidonTranscript<_, _>, - ChallengeScalar<_>, - ¶ms, - &pk, - &protocol, - &circuits - ); - Self::new(¶ms, [snark]) - } - - pub fn instances(&self) -> Vec> { - vec![self.instances.clone()] - } - - pub fn as_proof(&self) -> Value<&[u8]> { - self.as_proof.as_ref().map(Vec::as_slice) - } -} - -impl Circuit for Accumulation { - type Config = Halo2VerifierCircuitConfig; - type FloorPlanner = SimpleFloorPlanner; - - fn without_witnesses(&self) -> Self { - Self { - svk: self.svk, - snarks: self.snarks.iter().map(SnarkWitness::without_witnesses).collect(), - instances: Vec::new(), - as_vk: self.as_vk, - as_proof: Value::unknown(), - } - } - - fn configure(meta: &mut plonk::ConstraintSystem) -> Self::Config { - let path = "./configs/verify_circuit.config"; - let params_str = - std::fs::read_to_string(path).expect(format!("{} should exist", path).as_str()); - let params: Halo2VerifierCircuitConfigParams = - serde_json::from_str(params_str.as_str()).unwrap(); - - assert!( - params.limb_bits == BITS && params.num_limbs == LIMBS, - "For now we fix limb_bits = {}, otherwise change code", - BITS - ); - let base_field_config = FpConfig::configure( - meta, - params.strategy, - &[params.num_advice], - &[params.num_lookup_advice], - params.num_fixed, - params.lookup_bits, - params.limb_bits, - params.num_limbs, - halo2_base::utils::modulus::(), - 0, - params.degree as usize, - ); - - let instance = meta.instance_column(); - meta.enable_equality(instance); - - Self::Config { base_field_config, instance } - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), plonk::Error> { - let mut layouter = layouter.namespace(|| "aggregation"); - config.base_field_config.load_lookup_table(&mut layouter)?; - - // Need to trick layouter to skip first pass in get shape mode - let mut first_pass = halo2_base::SKIP_FIRST_PASS; - let mut assigned_instances = None; - layouter.assign_region( - || "", - |region| { - if first_pass { - first_pass = false; - return Ok(()); - } - let ctx = Context::new( - region, - ContextParams { - max_rows: config.base_field_config.range.gate.max_rows, - num_context_ids: 1, - fixed_columns: config.base_field_config.range.gate.constants.clone(), - }, - ); - - let loader = - Halo2Loader::new(EccChip::construct(config.base_field_config.clone()), ctx); - let KzgAccumulator { lhs, rhs } = - accumulate(&self.svk, &loader, &self.snarks, &self.as_vk, self.as_proof()); - - let lhs = lhs.assigned(); - let rhs = rhs.assigned(); - // REQUIRED STEP - config.base_field_config.finalize(&mut loader.ctx_mut()); - - let instances: Vec<_> = lhs - .x - .truncation - .limbs - .iter() - .chain(lhs.y.truncation.limbs.iter()) - .chain(rhs.x.truncation.limbs.iter()) - .chain(rhs.y.truncation.limbs.iter()) - .map(|assigned| assigned.cell().clone()) - .collect(); - assigned_instances = Some(instances); - - Ok(()) - }, - )?; - // TODO: use less instances by following Scroll's strategy of keeping only last bit of y coordinate - let mut layouter = layouter.namespace(|| "expose"); - for (i, cell) in assigned_instances.unwrap().into_iter().enumerate() { - layouter.constrain_instance(cell, config.instance, i); - } - Ok(()) - } -} - -macro_rules! test { - (@ $(#[$attr:meta],)* $name:ident, $k:expr, $config:expr, $create_circuit:expr) => { - paste! { - $(#[$attr])* - fn []() { - let (params, pk, protocol, circuits) = halo2_kzg_prepare!( - $k, - $config, - $create_circuit - ); - let snark = halo2_kzg_create_snark!( - ProverSHPLONK<_>, - VerifierSHPLONK<_>, - Blake2bWrite<_, _, _>, - Blake2bRead<_, _, _>, - Challenge255<_>, - ¶ms, - &pk, - &protocol, - &circuits - ); - halo2_kzg_native_verify!( - Plonk, - params, - &snark.protocol, - &snark.instances, - &mut Blake2bRead::<_, G1Affine, _>::init(snark.proof.as_slice()) - ); - } - } - }; - ($name:ident, $k:expr, $config:expr, $create_circuit:expr) => { - test!(@ #[test], $name, $k, $config, $create_circuit); - }; - ($(#[$attr:meta],)* $name:ident, $k:expr, $config:expr, $create_circuit:expr) => { - test!(@ #[test] $(,#[$attr])*, $name, $k, $config, $create_circuit); - }; -} - -test!( - // create aggregation circuit A that aggregates two simple snarks {B,C}, then verify proof of this aggregation circuit A - zk_aggregate_two_snarks, - 21, - halo2_kzg_config!(true, 1, Some(Accumulation::accumulator_indices())), - Accumulation::two_snark() -); -test!( - // create aggregation circuit A that aggregates two copies of same aggregation circuit B that aggregates two simple snarks {C, D}, then verifies proof of this aggregation circuit A - zk_aggregate_two_snarks_with_accumulator, - 22, // 22 = 21 + 1 since there are two copies of circuit B - halo2_kzg_config!(true, 1, Some(Accumulation::accumulator_indices())), - Accumulation::two_snark_with_accumulator() -); - -pub trait TargetCircuit: Circuit { - const TARGET_CIRCUIT_K: u32; - const PUBLIC_INPUT_SIZE: usize; - const N_PROOFS: usize; - const NAME: &'static str; - - fn default_circuit() -> Self; - fn instances(&self) -> Vec>; -} - -pub fn create_snark() -> (ParamsKZG, Snark) { - let (params, pk, protocol, circuits) = halo2_kzg_prepare!( - T::TARGET_CIRCUIT_K, - halo2_kzg_config!(true, T::N_PROOFS), - T::default_circuit() - ); - - let proof_time = start_timer!(|| "create proof"); - // usual shenanigans to turn nested Vec into nested slice - let instances0: Vec>> = - circuits.iter().map(|circuit| T::instances(circuit)).collect_vec(); - let instances1: Vec> = instances0 - .iter() - .map(|instances| instances.iter().map(Vec::as_slice).collect_vec()) - .collect_vec(); - let instances2: Vec<&[&[Fr]]> = instances1.iter().map(Vec::as_slice).collect_vec(); - // TODO: need to cache the instances as well! - - let proof = { - let path = format!("./data/proof_{}.data", T::NAME); - match std::fs::File::open(path.as_str()) { - Ok(mut file) => { - let mut buf = vec![]; - file.read_to_end(&mut buf).unwrap(); - buf - } - Err(_) => { - let mut transcript = PoseidonTranscript::>::init(Vec::new()); - create_proof::, ProverSHPLONK<_>, _, _, _, _>( - ¶ms, - &pk, - &circuits, - instances2.as_slice(), - &mut ChaCha20Rng::from_seed(Default::default()), - &mut transcript, - ) - .unwrap(); - let proof = transcript.finalize(); - let mut file = std::fs::File::create(path.as_str()) - .expect(format!("{:?} should exist", path).as_str()); - file.write_all(&proof).unwrap(); - proof - } - } - }; - end_timer!(proof_time); - - let verify_time = start_timer!(|| "verify proof"); - { - let verifier_params = params.verifier_params(); - let strategy = SingleStrategy::new(¶ms); - let mut transcript = - >> as TranscriptReadBuffer< - _, - _, - _, - >>::init(Cursor::new(proof.clone())); - verify_proof::<_, VerifierSHPLONK<_>, _, _, _>( - verifier_params, - pk.get_vk(), - strategy, - instances2.as_slice(), - &mut transcript, - ) - .unwrap() - } - end_timer!(verify_time); - - (params, Snark::new(protocol.clone(), instances0.into_iter().flatten().collect_vec(), proof)) -} - -/* -pub mod zkevm { - use super::*; - use zkevm_circuit_benchmarks::evm_circuit::TestCircuit as EvmCircuit; - use zkevm_circuits::evm_circuit::witness::RwMap; - use zkevm_circuits::state_circuit::StateCircuit; - - impl TargetCircuit for EvmCircuit { - const TARGET_CIRCUIT_K: u32 = 18; - const PUBLIC_INPUT_SIZE: usize = 0; // (Self::TARGET_CIRCUIT_K * 2) as usize; - const N_PROOFS: usize = 1; - const NAME: &'static str = "zkevm"; - - fn default_circuit() -> Self { - Self::default() - } - fn instances(&self) -> Vec> { - vec![] - } - } - - fn evm_verify_circuit() -> Accumulation { - let (params, evm_snark) = create_snark::>(); - println!("creating aggregation circuit"); - Accumulation::new(¶ms, [evm_snark]) - } - - test!( - bench_evm_circuit, - load_verify_circuit_degree(), - halo2_kzg_config!(true, 1, Accumulation::accumulator_indices()), - evm_verify_circuit() - ); - - impl TargetCircuit for StateCircuit { - const TARGET_CIRCUIT_K: u32 = 18; - const PUBLIC_INPUT_SIZE: usize = 0; //(Self::TARGET_CIRCUIT_K * 2) as usize; - const N_PROOFS: usize = 1; - const NAME: &'static str = "state-circuit"; - - fn default_circuit() -> Self { - StateCircuit::::new(Fr::default(), RwMap::default(), 1) - } - fn instances(&self) -> Vec> { - self.instance() - } - } - - fn state_verify_circuit() -> Accumulation { - let (params, snark) = create_snark::>(); - println!("creating aggregation circuit"); - Accumulation::new(¶ms, [snark]) - } - - test!( - bench_state_circuit, - load_verify_circuit_degree(), - halo2_kzg_config!(true, 1, Accumulation::accumulator_indices()), - state_verify_circuit() - ); - - fn evm_and_state_aggregation_circuit() -> Accumulation { - let (params, evm_snark) = create_snark::>(); - let (_, state_snark) = create_snark::>(); - println!("creating aggregation circuit"); - Accumulation::new(¶ms, [evm_snark, state_snark]) - } - - test!( - bench_evm_and_state, - load_verify_circuit_degree(), - halo2_kzg_config!(true, 1, Accumulation::accumulator_indices()), - evm_and_state_aggregation_circuit() - ); -} -*/ diff --git a/snark-verifier/src/system/halo2/test/kzg/native.rs b/snark-verifier/src/system/halo2/test/kzg/native.rs deleted file mode 100644 index 0801a317..00000000 --- a/snark-verifier/src/system/halo2/test/kzg/native.rs +++ /dev/null @@ -1,70 +0,0 @@ -use crate::halo2_curves::bn256::{Bn256, G1Affine}; -use crate::halo2_proofs::{ - poly::kzg::multiopen::{ProverGWC, ProverSHPLONK, VerifierGWC, VerifierSHPLONK}, - transcript::{Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer}, -}; -use crate::{ - pcs::kzg::{Bdfg21, Gwc19, Kzg, LimbsEncoding}, - system::halo2::test::{ - kzg::{ - halo2_kzg_config, halo2_kzg_create_snark, halo2_kzg_native_verify, halo2_kzg_prepare, - BITS, LIMBS, - }, - StandardPlonk, - }, - verifier::Plonk, -}; -use paste::paste; -use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; - -macro_rules! test { - (@ $prefix:ident, $name:ident, $k:expr, $config:expr, $create_cirucit:expr, $prover:ty, $verifier:ty, $plonk_verifier:ty) => { - paste! { - #[test] - fn []() { - let (params, pk, protocol, circuits) = halo2_kzg_prepare!( - $k, - $config, - $create_cirucit - ); - let snark = halo2_kzg_create_snark!( - $prover, - $verifier, - Blake2bWrite<_, _, _>, - Blake2bRead<_, _, _>, - Challenge255<_>, - ¶ms, - &pk, - &protocol, - &circuits - ); - halo2_kzg_native_verify!( - $plonk_verifier, - params, - &snark.protocol, - &snark.instances, - &mut Blake2bRead::<_, G1Affine, _>::init(snark.proof.as_slice()) - ); - } - } - }; - ($name:ident, $k:expr, $config:expr, $create_cirucit:expr) => { - test!(@ shplonk, $name, $k, $config, $create_cirucit, ProverSHPLONK<_>, VerifierSHPLONK<_>, Plonk, LimbsEncoding>); - test!(@ plonk, $name, $k, $config, $create_cirucit, ProverGWC<_>, VerifierGWC<_>, Plonk, LimbsEncoding>); - } -} - -test!( - zk_standard_plonk_rand, - 9, - halo2_kzg_config!(true, 2), - StandardPlonk::rand(ChaCha20Rng::from_seed(Default::default())) -); -/* -test!( - zk_main_gate_with_range_with_mock_kzg_accumulator, - 9, - halo2_kzg_config!(true, 2, (0..4 * LIMBS).map(|idx| (0, idx)).collect()), - main_gate_with_range_with_mock_kzg_accumulator::() -); -*/ diff --git a/snark-verifier/src/system/halo2/transcript.rs b/snark-verifier/src/system/halo2/transcript.rs index 8b97f968..9cfd6b89 100644 --- a/snark-verifier/src/system/halo2/transcript.rs +++ b/snark-verifier/src/system/halo2/transcript.rs @@ -1,4 +1,7 @@ +//! Transcripts implemented with both `halo2_proofs::transcript` and +//! `crate::util::transcript`. use crate::halo2_proofs; +use halo2_proofs::transcript::{Blake2bRead, Blake2bWrite, Challenge255}; use crate::{ loader::native::{self, NativeLoader}, util::{ @@ -7,7 +10,6 @@ use crate::{ }, Error, }; -use halo2_proofs::transcript::{Blake2bRead, Blake2bWrite, Challenge255}; use std::io::{Read, Write}; #[cfg(feature = "loader_evm")] diff --git a/snark-verifier/src/system/halo2/transcript/evm.rs b/snark-verifier/src/system/halo2/transcript/evm.rs index 909bb71d..c71c9e79 100644 --- a/snark-verifier/src/system/halo2/transcript/evm.rs +++ b/snark-verifier/src/system/halo2/transcript/evm.rs @@ -1,7 +1,9 @@ +//! Transcript for verifier on EVM. + use crate::halo2_proofs; use crate::{ loader::{ - evm::{loader::Value, u256_to_fe, EcPoint, EvmLoader, MemoryChunk, Scalar}, + evm::{loader::Value, u256_to_fe, util::MemoryChunk, EcPoint, EvmLoader, Scalar, U256}, native::{self, NativeLoader}, Loader, }, @@ -13,7 +15,6 @@ use crate::{ }, Error, }; -use ethereum_types::U256; use halo2_proofs::transcript::EncodedChallenge; use std::{ io::{self, Read, Write}, @@ -21,6 +22,9 @@ use std::{ marker::PhantomData, rc::Rc, }; + +/// Transcript for verifier on EVM using keccak256 as hasher. +#[derive(Debug)] pub struct EvmTranscript, S, B> { loader: L, stream: S, @@ -33,6 +37,8 @@ where C: CurveAffine, C::Scalar: PrimeField, { + /// Initialize [`EvmTranscript`] given [`Rc`] and pre-allocate an + /// u256 for `transcript_initial_state`. pub fn new(loader: &Rc) -> Self { let ptr = loader.allocate(0x20); assert_eq!(ptr, 0); @@ -41,6 +47,7 @@ where Self { loader: loader.clone(), stream: 0, buf, _marker: PhantomData } } + /// Load `num_instance` instances from calldata to memory. pub fn load_instances(&mut self, num_instance: Vec) -> Vec> { num_instance .into_iter() @@ -66,6 +73,8 @@ where &self.loader } + /// Does not allow the input to be a one-byte sequence, because the Transcript trait only supports writing scalars and elliptic curve points. + /// If the one-byte sequence [0x01] is a valid input to the transcript, the empty input [] will have the same transcript result as [0x01]. fn squeeze_challenge(&mut self) -> Scalar { let len = if self.buf.len() == 0x20 { assert_eq!(self.loader.ptr(), self.buf.end()); @@ -144,6 +153,8 @@ impl EvmTranscript> where C: CurveAffine, { + /// Initialize [`EvmTranscript`] given readable or writeable stream for + /// verifying or proving with [`NativeLoader`]. pub fn new(stream: S) -> Self { Self { loader: NativeLoader, stream, buf: Vec::new(), _marker: PhantomData } } @@ -173,10 +184,7 @@ where fn common_ec_point(&mut self, ec_point: &C) -> Result<(), Error> { let coordinates = Option::>::from(ec_point.coordinates()).ok_or_else(|| { - Error::Transcript( - io::ErrorKind::Other, - "Cannot write points at infinity to the transcript".to_string(), - ) + Error::Transcript(io::ErrorKind::Other, "Invalid elliptic curve point".to_string()) })?; [coordinates.x(), coordinates.y()].map(|coordinate| { @@ -239,15 +247,20 @@ where C: CurveAffine, S: Write, { + /// Returns mutable `stream`. pub fn stream_mut(&mut self) -> &mut S { &mut self.stream } + /// Finalize transcript and returns `stream`. pub fn finalize(self) -> S { self.stream } } +/// [`EncodedChallenge`] implemented for verifier on EVM, which use input in +/// big-endian as the challenge. +#[derive(Debug)] pub struct ChallengeEvm(C::Scalar) where C: CurveAffine, diff --git a/snark-verifier/src/system/halo2/transcript/halo2.rs b/snark-verifier/src/system/halo2/transcript/halo2.rs index 3c82d83b..86b1929c 100644 --- a/snark-verifier/src/system/halo2/transcript/halo2.rs +++ b/snark-verifier/src/system/halo2/transcript/halo2.rs @@ -1,3 +1,5 @@ +//! Transcript for verifier in [`halo2_proofs`] circuit. + use crate::halo2_proofs; use crate::{ loader::{ @@ -7,23 +9,24 @@ use crate::{ }, util::{ arithmetic::{fe_to_fe, CurveAffine, PrimeField}, - hash::Poseidon, + hash::{OptimizedPoseidonSpec, Poseidon}, transcript::{Transcript, TranscriptRead, TranscriptWrite}, Itertools, }, Error, }; -use halo2_proofs::{circuit::Value, transcript::EncodedChallenge}; +use halo2_proofs::transcript::EncodedChallenge; use std::{ io::{self, Read, Write}, rc::Rc, }; /// Encoding that encodes elliptic curve point into native field elements. -pub trait NativeEncoding<'a, C>: EccInstructions<'a, C> +pub trait NativeEncoding: EccInstructions where C: CurveAffine, { + /// Encode. fn encode( &self, ctx: &mut Self::Context, @@ -31,6 +34,10 @@ where ) -> Result, Error>; } +#[derive(Debug)] +/// Transcript for verifier in [`halo2_proofs`] circuit using poseidon hasher. +/// Currently It assumes the elliptic curve scalar field is same as native +/// field. pub struct PoseidonTranscript< C, L, @@ -48,55 +55,59 @@ pub struct PoseidonTranscript< buf: Poseidon>::LoadedScalar, T, RATE>, } -impl<'a, C, R, EccChip, const T: usize, const RATE: usize, const R_F: usize, const R_P: usize> - PoseidonTranscript>, Value, T, RATE, R_F, R_P> +impl + PoseidonTranscript>, R, T, RATE, R_F, R_P> where C: CurveAffine, R: Read, - EccChip: NativeEncoding<'a, C>, + EccChip: NativeEncoding, { - pub fn new(loader: &Rc>, stream: Value) -> Self { - let buf = Poseidon::new(loader, R_F, R_P); + /// Initialize [`PoseidonTranscript`] given readable or writeable stream for + /// verifying or proving with [`NativeLoader`]. + pub fn new(loader: &Rc>, stream: R) -> Self { + let buf = Poseidon::new::(loader); Self { loader: loader.clone(), stream, buf } } + /// Initialize [`PoseidonTranscript`] from a precomputed spec of round constants and MDS matrix because computing the constants is expensive. pub fn from_spec( - loader: &Rc>, - stream: Value, - spec: crate::poseidon::Spec, + loader: &Rc>, + stream: R, + spec: OptimizedPoseidonSpec, ) -> Self { let buf = Poseidon::from_spec(loader, spec); Self { loader: loader.clone(), stream, buf } } - pub fn new_stream(&mut self, stream: Value) { + /// Clear the buffer and set the stream to a new one. Effectively the same as starting from a new transcript. + pub fn new_stream(&mut self, stream: R) { self.buf.clear(); self.stream = stream; } } -impl<'a, C, R, EccChip, const T: usize, const RATE: usize, const R_F: usize, const R_P: usize> - Transcript>> - for PoseidonTranscript>, Value, T, RATE, R_F, R_P> +impl + Transcript>> + for PoseidonTranscript>, R, T, RATE, R_F, R_P> where C: CurveAffine, R: Read, - EccChip: NativeEncoding<'a, C>, + EccChip: NativeEncoding, { - fn loader(&self) -> &Rc> { + fn loader(&self) -> &Rc> { &self.loader } - fn squeeze_challenge(&mut self) -> Scalar<'a, C, EccChip> { + fn squeeze_challenge(&mut self) -> Scalar { self.buf.squeeze() } - fn common_scalar(&mut self, scalar: &Scalar<'a, C, EccChip>) -> Result<(), Error> { + fn common_scalar(&mut self, scalar: &Scalar) -> Result<(), Error> { self.buf.update(&[scalar.clone()]); Ok(()) } - fn common_ec_point(&mut self, ec_point: &EcPoint<'a, C, EccChip>) -> Result<(), Error> { + fn common_ec_point(&mut self, ec_point: &EcPoint) -> Result<(), Error> { let encoded = self .loader .ecc_chip() @@ -118,39 +129,31 @@ where } } -impl<'a, C, R, EccChip, const T: usize, const RATE: usize, const R_F: usize, const R_P: usize> - TranscriptRead>> - for PoseidonTranscript>, Value, T, RATE, R_F, R_P> +impl + TranscriptRead>> + for PoseidonTranscript>, R, T, RATE, R_F, R_P> where C: CurveAffine, R: Read, - EccChip: NativeEncoding<'a, C>, + EccChip: NativeEncoding, { - fn read_scalar(&mut self) -> Result, Error> { - let scalar = self.stream.as_mut().and_then(|stream| { + fn read_scalar(&mut self) -> Result, Error> { + let scalar = { let mut data = ::Repr::default(); - if stream.read_exact(data.as_mut()).is_err() { - return Value::unknown(); - } - Option::::from(C::Scalar::from_repr(data)) - .map(Value::known) - .unwrap_or_else(Value::unknown) - }); + self.stream.read_exact(data.as_mut()).unwrap(); + C::Scalar::from_repr(data).unwrap() + }; let scalar = self.loader.assign_scalar(scalar); self.common_scalar(&scalar)?; Ok(scalar) } - fn read_ec_point(&mut self) -> Result, Error> { - let ec_point = self.stream.as_mut().and_then(|stream| { + fn read_ec_point(&mut self) -> Result, Error> { + let ec_point = { let mut compressed = C::Repr::default(); - if stream.read_exact(compressed.as_mut()).is_err() { - return Value::unknown(); - } - Option::::from(C::from_bytes(&compressed)) - .map(Value::known) - .unwrap_or_else(Value::unknown) - }); + self.stream.read_exact(compressed.as_mut()).unwrap(); + C::from_bytes(&compressed).unwrap() + }; let ec_point = self.loader.assign_ec_point(ec_point); self.common_ec_point(&ec_point)?; Ok(ec_point) @@ -160,14 +163,22 @@ where impl PoseidonTranscript { - pub fn new(stream: S) -> Self { - Self { loader: NativeLoader, stream, buf: Poseidon::new(&NativeLoader, R_F, R_P) } + /// Initialize [`PoseidonTranscript`] given readable or writeable stream for + /// verifying or proving with [`NativeLoader`]. + pub fn new(stream: S) -> Self { + Self { + loader: NativeLoader, + stream, + buf: Poseidon::new::(&NativeLoader), + } } - pub fn from_spec(stream: S, spec: crate::poseidon::Spec) -> Self { + /// Initialize [`PoseidonTranscript`] from a precomputed spec of round constants and MDS matrix because computing the constants is expensive. + pub fn from_spec(stream: S, spec: OptimizedPoseidonSpec) -> Self { Self { loader: NativeLoader, stream, buf: Poseidon::from_spec(&NativeLoader, spec) } } + /// Clear the buffer and set the stream to a new one. Effectively the same as starting from a new transcript. pub fn new_stream(&mut self, stream: S) { self.buf.clear(); self.stream = stream; @@ -177,6 +188,7 @@ impl PoseidonTranscript, T, RATE, R_F, R_P> { + /// Clear the buffer and stream. pub fn clear(&mut self) { self.buf.clear(); self.stream.clear(); @@ -254,10 +266,12 @@ where C: CurveAffine, W: Write, { + /// Returns mutable `stream`. pub fn stream_mut(&mut self) -> &mut W { &mut self.stream } + /// Finalize transcript and returns `stream`. pub fn finalize(self) -> W { self.stream } @@ -289,6 +303,10 @@ where } } +/// [`EncodedChallenge`] implemented for verifier in [`halo2_proofs`] circuit. +/// Currently It assumes the elliptic curve scalar field is same as native +/// field. +#[derive(Debug)] pub struct ChallengeScalar(C::Scalar); impl EncodedChallenge for ChallengeScalar { @@ -360,7 +378,7 @@ where R: Read, { fn init(reader: R) -> Self { - Self::new(reader) + Self::new::<0>(reader) } } @@ -394,7 +412,7 @@ where W: Write, { fn init(writer: W) -> Self { - Self::new(writer) + Self::new::<0>(writer) } fn finalize(self) -> W { @@ -405,10 +423,9 @@ where mod halo2_lib { use crate::halo2_curves::CurveAffineExt; use crate::system::halo2::transcript::halo2::NativeEncoding; - use halo2_base::utils::PrimeField; - use halo2_ecc::ecc::BaseFieldEccChip; + use halo2_ecc::{ecc::BaseFieldEccChip, fields::PrimeField}; - impl<'a, C: CurveAffineExt> NativeEncoding<'a, C> for BaseFieldEccChip + impl<'chip, C: CurveAffineExt> NativeEncoding for BaseFieldEccChip<'chip, C> where C::Scalar: PrimeField, C::Base: PrimeField, @@ -418,31 +435,7 @@ mod halo2_lib { _: &mut Self::Context, ec_point: &Self::AssignedEcPoint, ) -> Result, crate::Error> { - Ok(vec![ec_point.x().native().clone(), ec_point.y().native().clone()]) - } - } -} - -/* -mod halo2_wrong { - use crate::system::halo2::transcript::halo2::NativeEncoding; - use halo2_curves::CurveAffine; - use halo2_proofs::circuit::AssignedCell; - use halo2_wrong_ecc::BaseFieldEccChip; - - impl<'a, C: CurveAffine, const LIMBS: usize, const BITS: usize> NativeEncoding<'a, C> - for BaseFieldEccChip - { - fn encode( - &self, - _: &mut Self::Context, - ec_point: &Self::AssignedEcPoint, - ) -> Result>, crate::Error> { - Ok(vec![ - ec_point.x().native().clone(), - ec_point.y().native().clone(), - ]) + Ok(vec![*ec_point.x().native(), *ec_point.y().native()]) } } } -*/ diff --git a/snark-verifier/src/util.rs b/snark-verifier/src/util.rs index b42db61c..508c0eeb 100644 --- a/snark-verifier/src/util.rs +++ b/snark-verifier/src/util.rs @@ -1,8 +1,9 @@ +//! Utilities. + pub mod arithmetic; pub mod hash; pub mod msm; pub mod poly; -pub mod protocol; pub mod transcript; pub(crate) use itertools::Itertools; @@ -10,6 +11,7 @@ pub(crate) use itertools::Itertools; #[cfg(feature = "parallel")] pub(crate) use rayon::current_num_threads; +/// Parallelly executing the function on the items of the given iterator. pub fn parallelize_iter(iter: I, f: F) where I: Send + Iterator, @@ -27,6 +29,7 @@ where iter.for_each(f); } +/// Parallelly executing the function on the given mutable slice. pub fn parallelize(v: &mut [T], f: F) where T: Send, diff --git a/snark-verifier/src/util/arithmetic.rs b/snark-verifier/src/util/arithmetic.rs index 2d24961e..97962e32 100644 --- a/snark-verifier/src/util/arithmetic.rs +++ b/snark-verifier/src/util/arithmetic.rs @@ -1,15 +1,7 @@ -use crate::util::Itertools; -use num_bigint::BigUint; -use num_traits::One; -use serde::{Deserialize, Serialize}; -use std::{ - cmp::Ordering, - fmt::Debug, - iter, mem, - ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, -}; +//! Arithmetic related re-exported traits and utilities. use crate::halo2_curves; +use crate::util::Itertools; pub use halo2_curves::{ group::{ ff::{BatchInvert, Field, PrimeField}, @@ -19,11 +11,22 @@ pub use halo2_curves::{ pairing::MillerLoopResult, Coordinates, CurveAffine, CurveExt, FieldExt, }; +use num_bigint::BigUint; +use num_traits::One; +use serde::{Deserialize, Serialize}; +use std::{ + cmp::Ordering, + fmt::Debug, + iter, mem, + ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, +}; +/// [`halo2_curves::pairing::MultiMillerLoop`] with [`std::fmt::Debug`]. pub trait MultiMillerLoop: halo2_curves::pairing::MultiMillerLoop + Debug {} impl MultiMillerLoop for M {} +/// Operations that could be done with field elements. pub trait FieldOps: Sized + Neg @@ -40,26 +43,29 @@ pub trait FieldOps: + for<'a> SubAssign<&'a Self> + for<'a> MulAssign<&'a Self> { + /// Returns multiplicative inversion if any. fn invert(&self) -> Option; } +/// Batch invert [`PrimeField`] elements and multiply all with given coefficient. pub fn batch_invert_and_mul(values: &mut [F], coeff: &F) { + if values.is_empty() { + return; + } let products = values .iter() - .filter(|value| !value.is_zero_vartime()) .scan(F::one(), |acc, value| { *acc *= value; Some(*acc) }) .collect_vec(); - let mut all_product_inv = products.last().unwrap().invert().unwrap() * coeff; + let mut all_product_inv = Option::::from(products.last().unwrap().invert()) + .expect("Attempted to batch invert an array containing zero") + * coeff; - for (value, product) in values - .iter_mut() - .rev() - .filter(|value| !value.is_zero_vartime()) - .zip(products.into_iter().rev().skip(1).chain(Some(F::one()))) + for (value, product) in + values.iter_mut().rev().zip(products.into_iter().rev().skip(1).chain(Some(F::one()))) { let mut inv = all_product_inv * product; mem::swap(value, &mut inv); @@ -67,10 +73,18 @@ pub fn batch_invert_and_mul(values: &mut [F], coeff: &F) { } } +/// Batch invert [`PrimeField`] elements. pub fn batch_invert(values: &mut [F]) { batch_invert_and_mul(values, &F::one()) } +/// Root of unity of 2^k-sized multiplicative subgroup of [`PrimeField`] by +/// repeatedly squaring the root of unity of the largest multiplicative +/// subgroup. +/// +/// # Panic +/// +/// If given `k` is greater than [`PrimeField::S`]. pub fn root_of_unity(k: usize) -> F { assert!(k <= F::S as usize); @@ -80,18 +94,22 @@ pub fn root_of_unity(k: usize) -> F { .unwrap() } +/// Rotation on a group. #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] pub struct Rotation(pub i32); impl Rotation { + /// No rotation pub fn cur() -> Self { Rotation(0) } + /// To previous element pub fn prev() -> Self { Rotation(-1) } + /// To next element pub fn next() -> Self { Rotation(1) } @@ -103,16 +121,23 @@ impl From for Rotation { } } +/// 2-adicity multiplicative domain #[derive(Clone, Debug, Serialize, Deserialize)] pub struct Domain { + /// Log size of the domain. pub k: usize, + /// Size of the domain. pub n: usize, + /// Inverse of `n`. pub n_inv: F, + /// Generator of the domain. pub gen: F, + /// Inverse of `gen`. pub gen_inv: F, } impl Domain { + /// Initialize a domain with specified generator. pub fn new(k: usize, gen: F) -> Self { let n = 1 << k; let n_inv = F::from(n as u64).invert().unwrap(); @@ -121,15 +146,17 @@ impl Domain { Self { k, n, n_inv, gen, gen_inv } } + /// Rotate an element to given `rotation`. pub fn rotate_scalar(&self, scalar: F, rotation: Rotation) -> F { match rotation.0.cmp(&0) { Ordering::Equal => scalar, Ordering::Greater => scalar * self.gen.pow_vartime([rotation.0 as u64]), - Ordering::Less => scalar * self.gen_inv.pow_vartime([(-rotation.0) as u64]), + Ordering::Less => scalar * self.gen_inv.pow_vartime([(-(rotation.0 as i64)) as u64]), } } } +/// Contains numerator and denominator for deferred evaluation. #[derive(Clone, Debug)] pub struct Fraction { numer: Option, @@ -139,14 +166,17 @@ pub struct Fraction { } impl Fraction { + /// Initialize an unevaluated fraction. pub fn new(numer: T, denom: T) -> Self { Self { numer: Some(numer), denom, eval: None, inv: false } } + /// Initialize an unevaluated fraction without numerator. pub fn one_over(denom: T) -> Self { Self { numer: None, denom, eval: None, inv: false } } + /// Returns denominator. pub fn denom(&self) -> Option<&T> { if !self.inv { Some(&self.denom) @@ -155,6 +185,8 @@ impl Fraction { } } + #[must_use = "To be inverted"] + /// Returns mutable denominator for doing inversion. pub fn denom_mut(&mut self) -> Option<&mut T> { if !self.inv { self.inv = true; @@ -166,18 +198,29 @@ impl Fraction { } impl Fraction { + /// Evaluate the fraction and cache the result. + /// + /// # Panic + /// + /// If `denom_mut` is not called before. pub fn evaluate(&mut self) { assert!(self.inv); - assert!(self.eval.is_none()); - - self.eval = Some( - self.numer - .take() - .map(|numer| numer * &self.denom) - .unwrap_or_else(|| self.denom.clone()), - ); + + if self.eval.is_none() { + self.eval = Some( + self.numer + .take() + .map(|numer| numer * &self.denom) + .unwrap_or_else(|| self.denom.clone()), + ); + } } + /// Returns cached fraction evaluation. + /// + /// # Panic + /// + /// If `evaluate` is not called before. pub fn evaluated(&self) -> &T { assert!(self.eval.is_some()); @@ -185,14 +228,12 @@ impl Fraction { } } -pub fn ilog2(value: usize) -> usize { - (usize::BITS - value.leading_zeros() - 1) as usize -} - +/// Modulus of a [`PrimeField`] pub fn modulus() -> BigUint { fe_to_big(-F::one()) + 1usize } +/// Convert a [`BigUint`] into a [`PrimeField`] . pub fn fe_from_big(big: BigUint) -> F { let bytes = big.to_bytes_le(); let mut repr = F::Repr::default(); @@ -201,14 +242,18 @@ pub fn fe_from_big(big: BigUint) -> F { F::from_repr(repr).unwrap() } +/// Convert a [`PrimeField`] into a [`BigUint`]. pub fn fe_to_big(fe: F) -> BigUint { BigUint::from_bytes_le(fe.to_repr().as_ref()) } +/// Convert a [`PrimeField`] into another [`PrimeField`]. pub fn fe_to_fe(fe: F1) -> F2 { fe_from_big(fe_to_big(fe) % modulus::()) } +/// Convert `LIMBS` limbs into a [`PrimeField`], assuming each limb contains at +/// most `BITS`. pub fn fe_from_limbs( limbs: [F1; LIMBS], ) -> F2 { @@ -223,6 +268,8 @@ pub fn fe_from_limbs( fe: F1, ) -> [F2; LIMBS] { @@ -237,10 +284,12 @@ pub fn fe_to_limbs(scalar: F) -> impl Iterator { iter::successors(Some(F::one()), move |power| Some(scalar * power)) } +/// Compute inner product of 2 slice of [`Field`]. pub fn inner_product(lhs: &[F], rhs: &[F]) -> F { lhs.iter() .zip_eq(rhs.iter()) diff --git a/snark-verifier/src/util/hash.rs b/snark-verifier/src/util/hash.rs index 17ede0b3..a8fe168c 100644 --- a/snark-verifier/src/util/hash.rs +++ b/snark-verifier/src/util/hash.rs @@ -1,6 +1,10 @@ +//! Hash algorithms. + +#[cfg(feature = "loader_halo2")] mod poseidon; -pub use crate::util::hash::poseidon::Poseidon; +#[cfg(feature = "loader_halo2")] +pub use crate::util::hash::poseidon::{OptimizedPoseidonSpec, Poseidon}; #[cfg(feature = "loader_evm")] pub use sha3::{Digest, Keccak256}; diff --git a/snark-verifier/src/util/hash/poseidon.rs b/snark-verifier/src/util/hash/poseidon.rs index fa7442f4..1ff06ab9 100644 --- a/snark-verifier/src/util/hash/poseidon.rs +++ b/snark-verifier/src/util/hash/poseidon.rs @@ -1,21 +1,346 @@ -use crate::poseidon::{self, SparseMDSMatrix, Spec}; +#![allow(clippy::needless_range_loop)] // for clarity of matrix operations use crate::{ loader::{LoadedScalar, ScalarLoader}, util::{arithmetic::FieldExt, Itertools}, }; +use poseidon_circuit::poseidon::primitives::Spec as PoseidonSpec; // trait use std::{iter, marker::PhantomData, mem}; -#[derive(Clone)] +#[cfg(test)] +mod tests; + +// struct so we can use PoseidonSpec trait to generate round constants and MDS matrix +#[derive(Debug)] +pub struct Poseidon128Pow5Gen< + F: FieldExt, + const T: usize, + const RATE: usize, + const R_F: usize, + const R_P: usize, + const SECURE_MDS: usize, +> { + _marker: PhantomData, +} + +impl< + F: FieldExt, + const T: usize, + const RATE: usize, + const R_F: usize, + const R_P: usize, + const SECURE_MDS: usize, + > PoseidonSpec for Poseidon128Pow5Gen +{ + fn full_rounds() -> usize { + R_F + } + + fn partial_rounds() -> usize { + R_P + } + + fn sbox(val: F) -> F { + val.pow_vartime([5]) + } + + // see "Avoiding insecure matrices" in Section 2.3 of https://eprint.iacr.org/2019/458.pdf + // most Specs used in practice have SECURE_MDS = 0 + fn secure_mds() -> usize { + SECURE_MDS + } +} + +// We use the optimized Poseidon implementation described in Supplementary Material Section B of https://eprint.iacr.org/2019/458.pdf +// This involves some further computation of optimized constants and sparse MDS matrices beyond what the Scroll PoseidonSpec generates +// The implementation below is adapted from https://github.com/privacy-scaling-explorations/poseidon + +/// `OptimizedPoseidonSpec` holds construction parameters as well as constants that are used in +/// permutation step. +#[derive(Debug, Clone)] +pub struct OptimizedPoseidonSpec { + pub(crate) r_f: usize, + pub(crate) mds_matrices: MDSMatrices, + pub(crate) constants: OptimizedConstants, +} + +/// `OptimizedConstants` has round constants that are added each round. While +/// full rounds has T sized constants there is a single constant for each +/// partial round +#[derive(Debug, Clone)] +pub struct OptimizedConstants { + pub(crate) start: Vec<[F; T]>, + pub(crate) partial: Vec, + pub(crate) end: Vec<[F; T]>, +} + +/// The type used to hold the MDS matrix +pub(crate) type Mds = [[F; T]; T]; + +/// `MDSMatrices` holds the MDS matrix as well as transition matrix which is +/// also called `pre_sparse_mds` and sparse matrices that enables us to reduce +/// number of multiplications in apply MDS step +#[derive(Debug, Clone)] +pub struct MDSMatrices { + pub(crate) mds: MDSMatrix, + pub(crate) pre_sparse_mds: MDSMatrix, + pub(crate) sparse_matrices: Vec>, +} + +/// `SparseMDSMatrix` are in `[row], [hat | identity]` form and used in linear +/// layer of partial rounds instead of the original MDS +#[derive(Debug, Clone)] +pub struct SparseMDSMatrix { + pub(crate) row: [F; T], + pub(crate) col_hat: [F; RATE], +} + +/// `MDSMatrix` is applied to `State` to achive linear layer of Poseidon +#[derive(Clone, Debug)] +pub struct MDSMatrix(pub(crate) Mds); + +impl MDSMatrix { + pub(crate) fn mul_vector(&self, v: &[F; T]) -> [F; T] { + let mut res = [F::zero(); T]; + for i in 0..T { + for j in 0..T { + res[i] += self.0[i][j] * v[j]; + } + } + res + } + + fn identity() -> Mds { + let mut mds = [[F::zero(); T]; T]; + for i in 0..T { + mds[i][i] = F::one(); + } + mds + } + + /// Multiplies two MDS matrices. Used in sparse matrix calculations + fn mul(&self, other: &Self) -> Self { + let mut res = [[F::zero(); T]; T]; + for i in 0..T { + for j in 0..T { + for k in 0..T { + res[i][j] += self.0[i][k] * other.0[k][j]; + } + } + } + Self(res) + } + + fn transpose(&self) -> Self { + let mut res = [[F::zero(); T]; T]; + for i in 0..T { + for j in 0..T { + res[i][j] = self.0[j][i]; + } + } + Self(res) + } + + fn determinant(m: [[F; N]; N]) -> F { + let mut res = F::one(); + let mut m = m; + for i in 0..N { + let mut pivot = i; + while m[pivot][i] == F::zero() { + pivot += 1; + assert!(pivot < N, "matrix is not invertible"); + } + if pivot != i { + res = -res; + m.swap(pivot, i); + } + res *= m[i][i]; + let inv = m[i][i].invert().unwrap(); + for j in i + 1..N { + let factor = m[j][i] * inv; + for k in i + 1..N { + m[j][k] -= m[i][k] * factor; + } + } + } + res + } + + /// See Section B in Supplementary Material https://eprint.iacr.org/2019/458.pdf + /// Factorises an MDS matrix `M` into `M'` and `M''` where `M = M' * M''`. + /// Resulted `M''` matrices are the sparse ones while `M'` will contribute + /// to the accumulator of the process + fn factorise(&self) -> (Self, SparseMDSMatrix) { + assert_eq!(RATE + 1, T); + // Given `(t-1 * t-1)` MDS matrix called `hat` constructs the `t * t` matrix in + // form `[[1 | 0], [0 | m]]`, ie `hat` is the right bottom sub-matrix + let prime = |hat: Mds| -> Self { + let mut prime = Self::identity(); + for (prime_row, hat_row) in prime.iter_mut().skip(1).zip(hat.iter()) { + for (el_prime, el_hat) in prime_row.iter_mut().skip(1).zip(hat_row.iter()) { + *el_prime = *el_hat; + } + } + Self(prime) + }; + + // Given `(t-1)` sized `w_hat` vector constructs the matrix in form + // `[[m_0_0 | m_0_i], [w_hat | identity]]` + let prime_prime = |w_hat: [F; RATE]| -> Mds { + let mut prime_prime = Self::identity(); + prime_prime[0] = self.0[0]; + for (row, w) in prime_prime.iter_mut().skip(1).zip(w_hat.iter()) { + row[0] = *w + } + prime_prime + }; + + let w = self.0.iter().skip(1).map(|row| row[0]).collect::>(); + // m_hat is the `(t-1 * t-1)` right bottom sub-matrix of m := self.0 + let mut m_hat = [[F::zero(); RATE]; RATE]; + for i in 0..RATE { + for j in 0..RATE { + m_hat[i][j] = self.0[i + 1][j + 1]; + } + } + // w_hat = m_hat^{-1} * w, where m_hat^{-1} is matrix inverse and * is matrix mult + // we avoid computing m_hat^{-1} explicitly by using Cramer's rule: https://en.wikipedia.org/wiki/Cramer%27s_rule + let mut w_hat = [F::zero(); RATE]; + let det = Self::determinant(m_hat); + let det_inv = Option::::from(det.invert()).expect("matrix is not invertible"); + for j in 0..RATE { + let mut m_hat_j = m_hat; + for i in 0..RATE { + m_hat_j[i][j] = w[i]; + } + w_hat[j] = Self::determinant(m_hat_j) * det_inv; + } + let m_prime = prime(m_hat); + let m_prime_prime = prime_prime(w_hat); + // row = first row of m_prime_prime.transpose() = first column of m_prime_prime + let row: [F; T] = + m_prime_prime.iter().map(|row| row[0]).collect::>().try_into().unwrap(); + // col_hat = first column of m_prime_prime.transpose() without first element = first row of m_prime_prime without first element + let col_hat: [F; RATE] = m_prime_prime[0][1..].try_into().unwrap(); + (m_prime, SparseMDSMatrix { row, col_hat }) + } +} + +impl OptimizedPoseidonSpec { + /// Generate new spec with specific number of full and partial rounds. `SECURE_MDS` is usually 0, but may need to be specified because insecure matrices may sometimes be generated + pub fn new() -> Self { + let (round_constants, mds, mds_inv) = + Poseidon128Pow5Gen::::constants(); + let mds = MDSMatrix(mds); + let inverse_mds = MDSMatrix(mds_inv); + + let constants = + Self::calculate_optimized_constants(R_F, R_P, round_constants, &inverse_mds); + let (sparse_matrices, pre_sparse_mds) = Self::calculate_sparse_matrices(R_P, &mds); + + Self { + r_f: R_F, + constants, + mds_matrices: MDSMatrices { mds, sparse_matrices, pre_sparse_mds }, + } + } + + fn calculate_optimized_constants( + r_f: usize, + r_p: usize, + constants: Vec<[F; T]>, + inverse_mds: &MDSMatrix, + ) -> OptimizedConstants { + let (number_of_rounds, r_f_half) = (r_f + r_p, r_f / 2); + assert_eq!(constants.len(), number_of_rounds); + + // Calculate optimized constants for first half of the full rounds + let mut constants_start: Vec<[F; T]> = vec![[F::zero(); T]; r_f_half]; + constants_start[0] = constants[0]; + for (optimized, constants) in + constants_start.iter_mut().skip(1).zip(constants.iter().skip(1)) + { + *optimized = inverse_mds.mul_vector(constants); + } + + // Calculate constants for partial rounds + let mut acc = constants[r_f_half + r_p]; + let mut constants_partial = vec![F::zero(); r_p]; + for (optimized, constants) in constants_partial + .iter_mut() + .rev() + .zip(constants.iter().skip(r_f_half).rev().skip(r_f_half)) + { + let mut tmp = inverse_mds.mul_vector(&acc); + *optimized = tmp[0]; + + tmp[0] = F::zero(); + for ((acc, tmp), constant) in acc.iter_mut().zip(tmp.into_iter()).zip(constants.iter()) + { + *acc = tmp + constant + } + } + constants_start.push(inverse_mds.mul_vector(&acc)); + + // Calculate optimized constants for ending half of the full rounds + let mut constants_end: Vec<[F; T]> = vec![[F::zero(); T]; r_f_half - 1]; + for (optimized, constants) in + constants_end.iter_mut().zip(constants.iter().skip(r_f_half + r_p + 1)) + { + *optimized = inverse_mds.mul_vector(constants); + } + + OptimizedConstants { + start: constants_start, + partial: constants_partial, + end: constants_end, + } + } + + fn calculate_sparse_matrices( + r_p: usize, + mds: &MDSMatrix, + ) -> (Vec>, MDSMatrix) { + let mds = mds.transpose(); + let mut acc = mds.clone(); + let mut sparse_matrices = (0..r_p) + .map(|_| { + let (m_prime, m_prime_prime) = acc.factorise(); + acc = mds.mul(&m_prime); + m_prime_prime + }) + .collect::>>(); + + sparse_matrices.reverse(); + (sparse_matrices, acc.transpose()) + } +} + +// ================ END OF CONSTRUCTION OF POSEIDON SPEC ==================== + +// now we get to actual trait based implementation of Poseidon permutation +// this works for any loader, where the two loaders used are NativeLoader (native rust) and Halo2Loader (ZK circuit) +#[derive(Clone, Debug)] struct State { inner: [L; T], _marker: PhantomData, } +// the transcript hash implementation is the one suggested in the original paper https://eprint.iacr.org/2019/458.pdf +// another reference implementation is https://github.com/privacy-scaling-explorations/halo2wrong/tree/master/transcript/src impl, const T: usize, const RATE: usize> State { fn new(inner: [L; T]) -> Self { Self { inner, _marker: PhantomData } } + fn default(loader: &L::Loader) -> Self { + let mut default_state = [F::zero(); T]; + // from Section 4.2 of https://eprint.iacr.org/2019/458.pdf + // • Variable-Input-Length Hashing. The capacity value is 2^64 + (o−1) where o the output length. + // for our transcript use cases, o = 1 + default_state[0] = F::from_u128(1u128 << 64); + Self::new(default_state.map(|state| loader.load_const(&state))) + } + fn loader(&self) -> &L::Loader { self.inner[0].loader() } @@ -52,6 +377,8 @@ impl, const T: usize, const RATE: usize> State, const T: usize, const RATE: usize> State) { self.inner = iter::once( self.loader() - .sum_with_coeff(&mds.row().iter().cloned().zip(self.inner.iter()).collect_vec()), + .sum_with_coeff(&mds.row.iter().cloned().zip(self.inner.iter()).collect_vec()), ) - .chain(mds.col_hat().iter().zip(self.inner.iter().skip(1)).map(|(coeff, state)| { + .chain(mds.col_hat.iter().zip(self.inner.iter().skip(1)).map(|(coeff, state)| { self.loader().sum_with_coeff(&[(*coeff, &self.inner[0]), (F::one(), state)]) })) .collect_vec() @@ -82,40 +409,49 @@ impl, const T: usize, const RATE: usize> State { - spec: Spec, + spec: OptimizedPoseidonSpec, default_state: State, state: State, buf: Vec, } impl, const T: usize, const RATE: usize> Poseidon { - pub fn new(loader: &L::Loader, r_f: usize, r_p: usize) -> Self { - let default_state = - State::new(poseidon::State::default().words().map(|state| loader.load_const(&state))); + /// Initialize a poseidon hasher. + /// Generates a new spec with specific number of full and partial rounds. `SECURE_MDS` is usually 0, but may need to be specified because insecure matrices may sometimes be generated + pub fn new( + loader: &L::Loader, + ) -> Self { + let default_state = State::default(loader); Self { - spec: Spec::new(r_f, r_p), + spec: OptimizedPoseidonSpec::new::(), state: default_state.clone(), default_state, buf: Vec::new(), } } - pub fn from_spec(loader: &L::Loader, spec: Spec) -> Self { - let default_state = - State::new(poseidon::State::default().words().map(|state| loader.load_const(&state))); + /// Initialize a poseidon hasher from an existing spec. + pub fn from_spec(loader: &L::Loader, spec: OptimizedPoseidonSpec) -> Self { + let default_state = State::default(loader); Self { spec, state: default_state.clone(), default_state, buf: Vec::new() } } + /// Reset state to default and clear the buffer. pub fn clear(&mut self) { self.state = self.default_state.clone(); self.buf.clear(); } + /// Store given `elements` into buffer. pub fn update(&mut self, elements: &[L]) { self.buf.extend_from_slice(elements); } + /// Consume buffer and perform permutation, then output second element of + /// state. pub fn squeeze(&mut self) -> L { let buf = mem::take(&mut self.buf); let exact = buf.len() % RATE == 0; @@ -131,13 +467,13 @@ impl, const T: usize, const RATE: usize> Poseido } fn permutation(&mut self, inputs: &[L]) { - let r_f = self.spec.r_f() / 2; - let mds = self.spec.mds_matrices().mds().rows(); - let pre_sparse_mds = self.spec.mds_matrices().pre_sparse_mds().rows(); - let sparse_matrices = self.spec.mds_matrices().sparse_matrices(); + let r_f = self.spec.r_f / 2; + let mds = self.spec.mds_matrices.mds.0; + let pre_sparse_mds = self.spec.mds_matrices.pre_sparse_mds.0; + let sparse_matrices = &self.spec.mds_matrices.sparse_matrices; // First half of the full rounds - let constants = self.spec.constants().start(); + let constants = &self.spec.constants.start; self.state.absorb_with_pre_constants(inputs, &constants[0]); for constants in constants.iter().skip(1).take(r_f - 1) { self.state.sbox_full(constants); @@ -147,14 +483,14 @@ impl, const T: usize, const RATE: usize> Poseido self.state.apply_mds(&pre_sparse_mds); // Partial rounds - let constants = self.spec.constants().partial(); + let constants = &self.spec.constants.partial; for (constant, sparse_mds) in constants.iter().zip(sparse_matrices.iter()) { self.state.sbox_part(constant); self.state.apply_sparse_mds(sparse_mds); } // Second half of the full rounds - let constants = self.spec.constants().end(); + let constants = &self.spec.constants.end; for constants in constants.iter() { self.state.sbox_full(constants); self.state.apply_mds(&mds); diff --git a/snark-verifier/src/util/hash/poseidon/tests.rs b/snark-verifier/src/util/hash/poseidon/tests.rs new file mode 100644 index 00000000..cf4712bc --- /dev/null +++ b/snark-verifier/src/util/hash/poseidon/tests.rs @@ -0,0 +1,85 @@ +use halo2_base::halo2_proofs::halo2curves::group::ff::PrimeField; + +use super::*; +use crate::{halo2_curves::bn256::Fr, loader::native::NativeLoader}; + +#[test] +fn test_mds() { + let spec = OptimizedPoseidonSpec::::new::<8, 57, 0>(); + + let mds = vec![ + vec![ + "7511745149465107256748700652201246547602992235352608707588321460060273774987", + "10370080108974718697676803824769673834027675643658433702224577712625900127200", + "19705173408229649878903981084052839426532978878058043055305024233888854471533", + ], + vec![ + "18732019378264290557468133440468564866454307626475683536618613112504878618481", + "20870176810702568768751421378473869562658540583882454726129544628203806653987", + "7266061498423634438633389053804536045105766754026813321943009179476902321146", + ], + vec![ + "9131299761947733513298312097611845208338517739621853568979632113419485819303", + "10595341252162738537912664445405114076324478519622938027420701542910180337937", + "11597556804922396090267472882856054602429588299176362916247939723151043581408", + ], + ]; + for (row1, row2) in mds.iter().zip_eq(spec.mds_matrices.mds.0.iter()) { + for (e1, e2) in row1.iter().zip_eq(row2.iter()) { + assert_eq!(Fr::from_str_vartime(e1).unwrap(), *e2); + } + } +} + +#[test] +fn test_poseidon_against_test_vectors() { + // https://extgit.iaik.tugraz.at/krypto/hadeshash/-/blob/master/code/test_vectors.txt + // poseidonperm_x5_254_3 + { + const R_F: usize = 8; + const R_P: usize = 57; + const T: usize = 3; + const RATE: usize = 2; + + let mut hasher = Poseidon::::new::(&NativeLoader); + + let state = vec![0u64, 1, 2].into_iter().map(Fr::from).collect::>(); + hasher.state = State::new(state.try_into().unwrap()); + hasher.permutation(&[(); RATE].map(|_| Fr::zero())); // avoid padding + let state_0 = hasher.state.inner; + let expected = vec![ + "7853200120776062878684798364095072458815029376092732009249414926327459813530", + "7142104613055408817911962100316808866448378443474503659992478482890339429929", + "6549537674122432311777789598043107870002137484850126429160507761192163713804", + ]; + for (word, expected) in state_0.into_iter().zip(expected.iter()) { + assert_eq!(word, Fr::from_str_vartime(expected).unwrap()); + } + } + + // https://extgit.iaik.tugraz.at/krypto/hadeshash/-/blob/master/code/test_vectors.txt + // poseidonperm_x5_254_5 + { + const R_F: usize = 8; + const R_P: usize = 60; + const T: usize = 5; + const RATE: usize = 4; + + let mut hasher = Poseidon::::new::(&NativeLoader); + + let state = vec![0u64, 1, 2, 3, 4].into_iter().map(Fr::from).collect::>(); + hasher.state = State::new(state.try_into().unwrap()); + hasher.permutation(&[(); RATE].map(|_| Fr::zero())); + let state_0 = hasher.state.inner; + let expected = vec![ + "18821383157269793795438455681495246036402687001665670618754263018637548127333", + "7817711165059374331357136443537800893307845083525445872661165200086166013245", + "16733335996448830230979566039396561240864200624113062088822991822580465420551", + "6644334865470350789317807668685953492649391266180911382577082600917830417726", + "3372108894677221197912083238087960099443657816445944159266857514496320565191", + ]; + for (word, expected) in state_0.into_iter().zip(expected.iter()) { + assert_eq!(word, Fr::from_str_vartime(expected).unwrap()); + } + } +} diff --git a/snark-verifier/src/util/msm.rs b/snark-verifier/src/util/msm.rs index 014a29e8..8d18cdf8 100644 --- a/snark-verifier/src/util/msm.rs +++ b/snark-verifier/src/util/msm.rs @@ -1,3 +1,5 @@ +//! Multi-scalar multiplication algorithm. + use crate::{ loader::{LoadedEcPoint, Loader}, util::{ @@ -14,6 +16,7 @@ use std::{ }; #[derive(Clone, Debug)] +/// Contains unevaluated multi-scalar multiplication. pub struct Msm<'a, C: CurveAffine, L: Loader> { constant: Option, scalars: Vec, @@ -26,11 +29,7 @@ where L: Loader, { fn default() -> Self { - Self { - constant: None, - scalars: Vec::new(), - bases: Vec::new(), - } + Self { constant: None, scalars: Vec::new(), bases: Vec::new() } } } @@ -39,20 +38,15 @@ where C: CurveAffine, L: Loader, { + /// Initialize with a constant. pub fn constant(constant: L::LoadedScalar) -> Self { - Msm { - constant: Some(constant), - ..Default::default() - } + Msm { constant: Some(constant), ..Default::default() } } + /// Initialize with a base. pub fn base<'b: 'a>(base: &'b L::LoadedEcPoint) -> Self { let one = base.loader().load_one(); - Msm { - scalars: vec![one], - bases: vec![base], - ..Default::default() - } + Msm { scalars: vec![one], bases: vec![base], ..Default::default() } } pub(crate) fn size(&self) -> usize { @@ -68,26 +62,21 @@ where self.bases.is_empty().then(|| self.constant.unwrap()) } + /// Evaluate multi-scalar multiplication. + /// + /// # Panic + /// + /// If given `gen` is `None` but there `constant` has some value. pub fn evaluate(self, gen: Option) -> L::LoadedEcPoint { - let gen = gen.map(|gen| { - self.bases - .first() - .unwrap() - .loader() - .ec_point_load_const(&gen) - }); + let gen = gen.map(|gen| self.bases.first().unwrap().loader().ec_point_load_const(&gen)); let pairs = iter::empty() - .chain( - self.constant - .as_ref() - .map(|constant| (constant, gen.as_ref().unwrap())), - ) + .chain(self.constant.as_ref().map(|constant| (constant, gen.as_ref().unwrap()))) .chain(self.scalars.iter().zip(self.bases.into_iter())) .collect_vec(); L::multi_scalar_multiplication(&pairs) } - pub fn scale(&mut self, factor: &L::LoadedScalar) { + fn scale(&mut self, factor: &L::LoadedScalar) { if let Some(constant) = self.constant.as_mut() { *constant *= factor; } @@ -96,7 +85,7 @@ where } } - pub fn push<'b: 'a>(&mut self, scalar: L::LoadedScalar, base: &'b L::LoadedEcPoint) { + fn push<'b: 'a>(&mut self, scalar: L::LoadedScalar, base: &'b L::LoadedEcPoint) { if let Some(pos) = self.bases.iter().position(|exist| exist.eq(&base)) { self.scalars[pos] += &scalar; } else { @@ -105,7 +94,7 @@ where } } - pub fn extend<'b: 'a>(&mut self, mut other: Msm<'b, C, L>) { + fn extend<'b: 'a>(&mut self, mut other: Msm<'b, C, L>) { match (self.constant.as_mut(), other.constant.as_ref()) { (Some(lhs), Some(rhs)) => *lhs += rhs, (None, Some(_)) => self.constant = other.constant.take(), @@ -293,7 +282,8 @@ fn multi_scalar_multiplication_serial( } } -// Copy from https://github.com/zcash/halo2/blob/main/halo2_proofs/src/arithmetic.rs +/// Multi-scalar multiplication algorithm copied from +/// . pub fn multi_scalar_multiplication(scalars: &[C::Scalar], bases: &[C]) -> C::Curve { assert_eq!(scalars.len(), bases.len()); @@ -311,17 +301,12 @@ pub fn multi_scalar_multiplication(scalars: &[C::Scalar], bases: let chunk_size = Integer::div_ceil(&scalars.len(), &num_threads); let mut results = vec![C::Curve::identity(); num_threads]; parallelize_iter( - scalars - .chunks(chunk_size) - .zip(bases.chunks(chunk_size)) - .zip(results.iter_mut()), + scalars.chunks(chunk_size).zip(bases.chunks(chunk_size)).zip(results.iter_mut()), |((scalars, bases), result)| { multi_scalar_multiplication_serial(scalars, bases, result); }, ); - results - .iter() - .fold(C::Curve::identity(), |acc, result| acc + result) + results.iter().fold(C::Curve::identity(), |acc, result| acc + result) } #[cfg(not(feature = "parallel"))] { diff --git a/snark-verifier/src/util/poly.rs b/snark-verifier/src/util/poly.rs index ea120b33..17a065f9 100644 --- a/snark-verifier/src/util/poly.rs +++ b/snark-verifier/src/util/poly.rs @@ -1,4 +1,7 @@ +//! Polynomial. + use crate::util::{arithmetic::Field, parallelize}; +use itertools::Itertools; use rand::Rng; use std::{ iter::{self, Sum}, @@ -9,46 +12,50 @@ use std::{ }; #[derive(Clone, Debug)] +/// Univariate polynomial. pub struct Polynomial(Vec); impl Polynomial { + /// Initialize an univariate polynomial. pub fn new(inner: Vec) -> Self { Self(inner) } + /// Returns `true` if the `Polynomial` contains no elements. pub fn is_empty(&self) -> bool { self.0.is_empty() } + /// Returns the length of the `Polynomial`. pub fn len(&self) -> usize { self.0.len() } + /// Returns an iterator of the `Polynomial`. pub fn iter(&self) -> impl Iterator { self.0.iter() } + /// Returns a mutable iterator of the `Polynomial`. pub fn iter_mut(&mut self) -> impl Iterator { self.0.iter_mut() } + /// Into vector of coefficients. pub fn to_vec(self) -> Vec { self.0 } } impl Polynomial { - pub fn rand(n: usize, mut rng: R) -> Self { + pub(crate) fn rand(n: usize, mut rng: R) -> Self { Self::new(iter::repeat_with(|| F::random(&mut rng)).take(n).collect()) } + /// Returns evaluation at given `x`. pub fn evaluate(&self, x: F) -> F { - let evaluate_serial = |coeffs: &[F]| { - coeffs - .iter() - .rev() - .fold(F::zero(), |acc, coeff| acc * x + coeff) - }; + let evaluate_serial = + |coeffs: &[F]| coeffs.iter().rev().fold(F::zero(), |acc, coeff| acc * x + coeff); #[cfg(feature = "parallel")] { @@ -63,10 +70,12 @@ impl Polynomial { let chunk_size = Integer::div_ceil(&self.len(), &num_threads); let mut results = vec![F::zero(); num_threads]; parallelize_iter( - results - .iter_mut() - .zip(self.0.chunks(chunk_size)) - .zip(powers(x.pow_vartime(&[chunk_size as u64, 0, 0, 0]))), + results.iter_mut().zip(self.0.chunks(chunk_size)).zip(powers(x.pow_vartime(&[ + chunk_size as u64, + 0, + 0, + 0, + ]))), |((result, coeffs), scalar)| *result = evaluate_serial(coeffs) * scalar, ); results.iter().fold(F::zero(), |acc, result| acc + result) @@ -81,7 +90,7 @@ impl<'a, F: Field> Add<&'a Polynomial> for Polynomial { fn add(mut self, rhs: &'a Polynomial) -> Polynomial { parallelize(&mut self.0, |(lhs, start)| { - for (lhs, rhs) in lhs.iter_mut().zip(rhs.0[start..].iter()) { + for (lhs, rhs) in lhs.iter_mut().zip_eq(rhs.0[start..].iter()) { *lhs += *rhs; } }); @@ -94,7 +103,7 @@ impl<'a, F: Field> Sub<&'a Polynomial> for Polynomial { fn sub(mut self, rhs: &'a Polynomial) -> Polynomial { parallelize(&mut self.0, |(lhs, start)| { - for (lhs, rhs) in lhs.iter_mut().zip(rhs.0[start..].iter()) { + for (lhs, rhs) in lhs.iter_mut().zip_eq(rhs.0[start..].iter()) { *lhs -= *rhs; } }); diff --git a/snark-verifier/src/util/transcript.rs b/snark-verifier/src/util/transcript.rs index 3337324d..b871b083 100644 --- a/snark-verifier/src/util/transcript.rs +++ b/snark-verifier/src/util/transcript.rs @@ -1,46 +1,62 @@ +//! Transcript traits. + use crate::{ loader::{native::NativeLoader, Loader}, {util::arithmetic::CurveAffine, Error}, }; +/// Common methods for prover and verifier. pub trait Transcript where C: CurveAffine, L: Loader, { + /// Returns [`Loader`]. fn loader(&self) -> &L; + /// Squeeze a challenge. fn squeeze_challenge(&mut self) -> L::LoadedScalar; + /// Squeeze `n` challenges. fn squeeze_n_challenges(&mut self, n: usize) -> Vec { (0..n).map(|_| self.squeeze_challenge()).collect() } + /// Update with an elliptic curve point. fn common_ec_point(&mut self, ec_point: &L::LoadedEcPoint) -> Result<(), Error>; + /// Update with a scalar. fn common_scalar(&mut self, scalar: &L::LoadedScalar) -> Result<(), Error>; } +/// Transcript for verifier. pub trait TranscriptRead: Transcript where C: CurveAffine, L: Loader, { + /// Read a scalar. fn read_scalar(&mut self) -> Result; + /// Read `n` scalar. fn read_n_scalars(&mut self, n: usize) -> Result, Error> { (0..n).map(|_| self.read_scalar()).collect() } + /// Read a elliptic curve point. fn read_ec_point(&mut self) -> Result; + /// Read `n` elliptic curve point. fn read_n_ec_points(&mut self, n: usize) -> Result, Error> { (0..n).map(|_| self.read_ec_point()).collect() } } +/// Transcript for prover. pub trait TranscriptWrite: Transcript { + /// Write a scalar. fn write_scalar(&mut self, scalar: C::Scalar) -> Result<(), Error>; + /// Write a elliptic curve point. fn write_ec_point(&mut self, ec_point: C) -> Result<(), Error>; } diff --git a/snark-verifier/src/verifier.rs b/snark-verifier/src/verifier.rs index e9ad2e1f..813065db 100644 --- a/snark-verifier/src/verifier.rs +++ b/snark-verifier/src/verifier.rs @@ -1,50 +1,44 @@ +//! Verifiers for (S)NARK. + use crate::{ loader::Loader, - pcs::{Decider, MultiOpenScheme}, util::{arithmetic::CurveAffine, transcript::TranscriptRead}, - Protocol, + Error, }; use std::fmt::Debug; -mod plonk; - -pub use plonk::{Plonk, PlonkProof}; +pub mod plonk; -pub trait PlonkVerifier +/// (S)NARK verifier for verifying a (S)NARK. +pub trait SnarkVerifier where C: CurveAffine, L: Loader, - MOS: MultiOpenScheme, { + /// Verifying key for subroutines if any. + type VerifyingKey: Clone + Debug; + /// Protocol specifying configuration of a (S)NARK. + type Protocol: Clone + Debug; + /// Structured proof read from transcript. type Proof: Clone + Debug; + /// Output of verification. + type Output: Clone + Debug; + /// Read [`SnarkVerifier::Proof`] from transcript. fn read_proof( - svk: &MOS::SuccinctVerifyingKey, - protocol: &Protocol, + vk: &Self::VerifyingKey, + protocol: &Self::Protocol, instances: &[Vec], transcript: &mut T, - ) -> Self::Proof + ) -> Result where T: TranscriptRead; - fn succinct_verify( - svk: &MOS::SuccinctVerifyingKey, - protocol: &Protocol, - instances: &[Vec], - proof: &Self::Proof, - ) -> Vec; - + /// Verify [`SnarkVerifier::Proof`] and output [`SnarkVerifier::Output`]. fn verify( - svk: &MOS::SuccinctVerifyingKey, - dk: &MOS::DecidingKey, - protocol: &Protocol, + vk: &Self::VerifyingKey, + protocol: &Self::Protocol, instances: &[Vec], proof: &Self::Proof, - ) -> MOS::Output - where - MOS: Decider, - { - let accumulators = Self::succinct_verify(svk, protocol, instances, proof); - MOS::decide_all(dk, accumulators) - } + ) -> Result; } diff --git a/snark-verifier/src/verifier/plonk.rs b/snark-verifier/src/verifier/plonk.rs index f42c912c..d5937ab8 100644 --- a/snark-verifier/src/verifier/plonk.rs +++ b/snark-verifier/src/verifier/plonk.rs @@ -1,57 +1,69 @@ +//! Verifiers for [PLONK], currently there are [`PlonkSuccinctVerifier`] and +//! [`PlonkVerifier`] implemented and both are implemented assuming the used +//! [`PolynomialCommitmentScheme`] has [atomic] or [split] accumulation scheme +//! ([`PlonkVerifier`] is just [`PlonkSuccinctVerifier`] plus doing accumulator +//! deciding then returns accept/reject as ouput). +//! +//! [PLONK]: https://eprint.iacr.org/2019/953 +//! [atomic]: https://eprint.iacr.org/2020/499 +//! [split]: https://eprint.iacr.org/2020/1618 + use crate::{ cost::{Cost, CostEstimation}, - loader::{native::NativeLoader, LoadedScalar, Loader}, - pcs::{self, AccumulatorEncoding, MultiOpenScheme}, - util::{ - arithmetic::{CurveAffine, Field, Rotation}, - msm::Msm, - protocol::{ - CommonPolynomial::Lagrange, CommonPolynomialEvaluation, LinearizationStrategy, Query, - }, - transcript::TranscriptRead, - Itertools, + loader::Loader, + pcs::{ + AccumulationDecider, AccumulationScheme, AccumulatorEncoding, PolynomialCommitmentScheme, + Query, }, - verifier::PlonkVerifier, - Error, Protocol, + util::{arithmetic::CurveAffine, transcript::TranscriptRead}, + verifier::{plonk::protocol::CommonPolynomialEvaluation, SnarkVerifier}, + Error, }; -use rustc_hash::FxHashMap; use std::{iter, marker::PhantomData}; -pub struct Plonk(PhantomData<(MOS, AE)>); +mod proof; +pub(crate) mod protocol; + +pub use proof::PlonkProof; +pub use protocol::PlonkProtocol; + +/// Verifier that verifies the cheap part of PLONK and ouput the accumulator. +#[derive(Debug)] +pub struct PlonkSuccinctVerifier>(PhantomData<(AS, AE)>); -impl PlonkVerifier for Plonk +impl SnarkVerifier for PlonkSuccinctVerifier where C: CurveAffine, L: Loader, - MOS: MultiOpenScheme, - AE: AccumulatorEncoding, + AS: AccumulationScheme + PolynomialCommitmentScheme, + AE: AccumulatorEncoding, { - type Proof = PlonkProof; + type VerifyingKey = >::VerifyingKey; + type Protocol = PlonkProtocol; + type Proof = PlonkProof; + type Output = Vec; fn read_proof( - svk: &MOS::SuccinctVerifyingKey, - protocol: &Protocol, + svk: &Self::VerifyingKey, + protocol: &Self::Protocol, instances: &[Vec], transcript: &mut T, - ) -> Self::Proof + ) -> Result where T: TranscriptRead, { PlonkProof::read::(svk, protocol, instances, transcript) } - fn succinct_verify( - svk: &MOS::SuccinctVerifyingKey, - protocol: &Protocol, + fn verify( + svk: &Self::VerifyingKey, + protocol: &Self::Protocol, instances: &[Vec], proof: &Self::Proof, - ) -> Vec { + ) -> Result { let common_poly_eval = { - let mut common_poly_eval = CommonPolynomialEvaluation::new( - &protocol.domain, - langranges(protocol, instances), - &proof.z, - ); + let mut common_poly_eval = + CommonPolynomialEvaluation::new(&protocol.domain, protocol.langranges(), &proof.z); L::batch_invert(common_poly_eval.denoms()); common_poly_eval.evaluate(); @@ -59,313 +71,80 @@ where common_poly_eval }; - let mut evaluations = proof.evaluations(protocol, instances, &common_poly_eval); - let commitments = proof.commitments(protocol, &common_poly_eval, &mut evaluations); + let mut evaluations = proof.evaluations(protocol, instances, &common_poly_eval)?; + let commitments = proof.commitments(protocol, &common_poly_eval, &mut evaluations)?; let queries = proof.queries(protocol, evaluations); - let accumulator = MOS::succinct_verify(svk, &commitments, &proof.z, &queries, &proof.pcs); + let accumulator = >::verify( + svk, + &commitments, + &proof.z, + &queries, + &proof.pcs, + )?; let accumulators = iter::empty() .chain(Some(accumulator)) .chain(proof.old_accumulators.iter().cloned()) .collect(); - accumulators + Ok(accumulators) } } -#[derive(Clone, Debug)] -pub struct PlonkProof -where - C: CurveAffine, - L: Loader, - MOS: MultiOpenScheme, -{ - pub committed_instances: Option>, - pub witnesses: Vec, - pub challenges: Vec, - pub quotients: Vec, - pub z: L::LoadedScalar, - pub evaluations: Vec, - pub pcs: MOS::Proof, - pub old_accumulators: Vec, -} +/// Verifier that first verifies the cheap part of PLONK, then decides +/// accumulator and returns accept/reject as ouput. +#[derive(Debug)] +pub struct PlonkVerifier>(PhantomData<(AS, AE)>); -impl PlonkProof +impl SnarkVerifier for PlonkVerifier where C: CurveAffine, L: Loader, - MOS: MultiOpenScheme, + AS: AccumulationDecider + PolynomialCommitmentScheme, + AS::DecidingKey: AsRef<>::VerifyingKey>, + AE: AccumulatorEncoding, { - pub fn read( - svk: &MOS::SuccinctVerifyingKey, - protocol: &Protocol, + type VerifyingKey = AS::DecidingKey; + type Protocol = PlonkProtocol; + type Proof = PlonkProof; + type Output = (); + + fn read_proof( + vk: &Self::VerifyingKey, + protocol: &Self::Protocol, instances: &[Vec], transcript: &mut T, - ) -> Self + ) -> Result where T: TranscriptRead, - AE: AccumulatorEncoding, { - if let Some(transcript_initial_state) = &protocol.transcript_initial_state { - transcript.common_scalar(transcript_initial_state).unwrap(); - } - - debug_assert_eq!( - protocol.num_instance, - instances.iter().map(|instances| instances.len()).collect_vec(), - "Invalid Instances" - ); - - let committed_instances = if let Some(ick) = &protocol.instance_committing_key { - let loader = transcript.loader(); - let bases = - ick.bases.iter().map(|value| loader.ec_point_load_const(value)).collect_vec(); - let constant = ick.constant.as_ref().map(|value| loader.ec_point_load_const(value)); - - let committed_instances = instances - .iter() - .map(|instances| { - instances - .iter() - .zip(bases.iter()) - .map(|(scalar, base)| Msm::::base(base) * scalar) - .chain(constant.as_ref().map(Msm::base)) - .sum::>() - .evaluate(None) - }) - .collect_vec(); - for committed_instance in committed_instances.iter() { - transcript.common_ec_point(committed_instance).unwrap(); - } - - Some(committed_instances) - } else { - for instances in instances.iter() { - for instance in instances.iter() { - transcript.common_scalar(instance).unwrap(); - } - } - - None - }; - - let (witnesses, challenges) = { - let (witnesses, challenges): (Vec<_>, Vec<_>) = protocol - .num_witness - .iter() - .zip(protocol.num_challenge.iter()) - .map(|(&n, &m)| { - (transcript.read_n_ec_points(n).unwrap(), transcript.squeeze_n_challenges(m)) - }) - .unzip(); - - ( - witnesses.into_iter().flatten().collect_vec(), - challenges.into_iter().flatten().collect_vec(), - ) - }; - - let quotients = transcript.read_n_ec_points(protocol.quotient.num_chunk()).unwrap(); - - let z = transcript.squeeze_challenge(); - let evaluations = transcript.read_n_scalars(protocol.evaluations.len()).unwrap(); - - let pcs = MOS::read_proof(svk, &Self::empty_queries(protocol), transcript); - - let old_accumulators = protocol - .accumulator_indices - .iter() - .map(|accumulator_indices| { - AE::from_repr( - &accumulator_indices.iter().map(|&(i, j)| &instances[i][j]).collect_vec(), - ) - .unwrap() - }) - .collect_vec(); - - Self { - committed_instances, - witnesses, - challenges, - quotients, - z, - evaluations, - pcs, - old_accumulators, - } - } - - pub fn empty_queries(protocol: &Protocol) -> Vec> { - protocol - .queries - .iter() - .map(|query| pcs::Query { - poly: query.poly, - shift: protocol.domain.rotate_scalar(C::Scalar::one(), query.rotation), - eval: (), - }) - .collect() - } - - fn queries( - &self, - protocol: &Protocol, - mut evaluations: FxHashMap, - ) -> Vec> { - Self::empty_queries(protocol) - .into_iter() - .zip(protocol.queries.iter().map(|query| evaluations.remove(query).unwrap())) - .map(|(query, eval)| query.with_evaluation(eval)) - .collect() + PlonkProof::read::(vk.as_ref(), protocol, instances, transcript) } - fn commitments<'a>( - &'a self, - protocol: &'a Protocol, - common_poly_eval: &CommonPolynomialEvaluation, - evaluations: &mut FxHashMap, - ) -> Vec> { - let loader = common_poly_eval.zn().loader(); - let mut commitments = iter::empty() - .chain(protocol.preprocessed.iter().map(Msm::base)) - .chain( - self.committed_instances - .as_ref() - .map(|committed_instances| { - committed_instances.iter().map(Msm::base).collect_vec() - }) - .unwrap_or_else(|| { - iter::repeat_with(Default::default) - .take(protocol.num_instance.len()) - .collect_vec() - }), - ) - .chain(self.witnesses.iter().map(Msm::base)) - .collect_vec(); - - let numerator = protocol.quotient.numerator.evaluate( - &|scalar| Msm::constant(loader.load_const(&scalar)), - &|poly| Msm::constant(common_poly_eval.get(poly).clone()), - &|query| { - evaluations - .get(&query) - .cloned() - .map(Msm::constant) - .or_else(|| { - (query.rotation == Rotation::cur()) - .then(|| commitments.get(query.poly).cloned()) - .flatten() - }) - .ok_or(Error::InvalidQuery(query)) - .unwrap() - }, - &|index| { - self.challenges - .get(index) - .cloned() - .map(Msm::constant) - .ok_or(Error::InvalidChallenge(index)) - .unwrap() - }, - &|a| -a, - &|a, b| a + b, - &|a, b| match (a.size(), b.size()) { - (0, _) => b * &a.try_into_constant().unwrap(), - (_, 0) => a * &b.try_into_constant().unwrap(), - (_, _) => panic!("{:?}", Error::InvalidLinearization), - }, - &|a, scalar| a * &loader.load_const(&scalar), - ); - - let quotient_query = Query::new( - protocol.preprocessed.len() + protocol.num_instance.len() + self.witnesses.len(), - Rotation::cur(), - ); - let quotient = common_poly_eval - .zn() - .pow_const(protocol.quotient.chunk_degree as u64) - .powers(self.quotients.len()) - .into_iter() - .zip(self.quotients.iter().map(Msm::base)) - .map(|(coeff, chunk)| chunk * &coeff) - .sum::>(); - match protocol.linearization { - Some(LinearizationStrategy::WithoutConstant) => { - let linearization_query = Query::new(quotient_query.poly + 1, Rotation::cur()); - let (msm, constant) = numerator.split(); - commitments.push(quotient); - commitments.push(msm); - evaluations.insert( - quotient_query, - (constant.unwrap_or_else(|| loader.load_zero()) - + evaluations.get(&linearization_query).unwrap()) - * common_poly_eval.zn_minus_one_inv(), - ); - } - Some(LinearizationStrategy::MinusVanishingTimesQuotient) => { - let (msm, constant) = - (numerator - quotient * common_poly_eval.zn_minus_one()).split(); - commitments.push(msm); - evaluations.insert(quotient_query, constant.unwrap_or_else(|| loader.load_zero())); - } - None => { - commitments.push(quotient); - evaluations.insert( - quotient_query, - numerator.try_into_constant().ok_or(Error::InvalidLinearization).unwrap() - * common_poly_eval.zn_minus_one_inv(), - ); - } - } - - commitments - } - - fn evaluations( - &self, - protocol: &Protocol, + fn verify( + vk: &Self::VerifyingKey, + protocol: &Self::Protocol, instances: &[Vec], - common_poly_eval: &CommonPolynomialEvaluation, - ) -> FxHashMap { - let loader = common_poly_eval.zn().loader(); - let instance_evals = protocol.instance_committing_key.is_none().then(|| { - let offset = protocol.preprocessed.len(); - let queries = { - let range = offset..offset + protocol.num_instance.len(); - protocol - .quotient - .numerator - .used_query() - .into_iter() - .filter(move |query| range.contains(&query.poly)) - }; - queries - .map(move |query| { - let instances = instances[query.poly - offset].iter(); - let l_i_minus_r = (-query.rotation.0..) - .map(|i_minus_r| common_poly_eval.get(Lagrange(i_minus_r))); - let eval = loader.sum_products(&instances.zip(l_i_minus_r).collect_vec()); - (query, eval) - }) - .collect_vec() - }); - - iter::empty() - .chain(instance_evals.into_iter().flatten()) - .chain(protocol.evaluations.iter().cloned().zip(self.evaluations.iter().cloned())) - .collect() + proof: &Self::Proof, + ) -> Result { + let accumulators = + PlonkSuccinctVerifier::::verify(vk.as_ref(), protocol, instances, proof)?; + AS::decide_all(vk, accumulators) } } -impl CostEstimation<(C, MOS)> for Plonk +impl CostEstimation<(C, L)> for PlonkSuccinctVerifier where C: CurveAffine, - MOS: MultiOpenScheme + CostEstimation>>, + L: Loader, + AS: AccumulationScheme + + PolynomialCommitmentScheme + + CostEstimation>>, { - type Input = Protocol; + type Input = PlonkProtocol; - fn estimate_cost(protocol: &Protocol) -> Cost { + fn estimate_cost(protocol: &PlonkProtocol) -> Cost { let plonk_cost = { let num_accumulator = protocol.accumulator_indices.len(); let num_instance = protocol.num_instance.iter().sum(); @@ -373,52 +152,28 @@ where protocol.num_witness.iter().sum::() + protocol.quotient.num_chunk(); let num_evaluation = protocol.evaluations.len(); let num_msm = protocol.preprocessed.len() + num_commitment + 1 + 2 * num_accumulator; - Cost::new(num_instance, num_commitment, num_evaluation, num_msm) + Cost { num_instance, num_commitment, num_evaluation, num_msm, ..Default::default() } }; let pcs_cost = { - let queries = PlonkProof::::empty_queries(protocol); - MOS::estimate_cost(&queries) + let queries = PlonkProof::::empty_queries(protocol); + AS::estimate_cost(&queries) }; plonk_cost + pcs_cost } } -fn langranges( - protocol: &Protocol, - instances: &[Vec], -) -> impl IntoIterator +impl CostEstimation<(C, L)> for PlonkVerifier where C: CurveAffine, L: Loader, + AS: AccumulationScheme + + PolynomialCommitmentScheme + + CostEstimation>>, { - let instance_eval_lagrange = protocol.instance_committing_key.is_none().then(|| { - let queries = { - let offset = protocol.preprocessed.len(); - let range = offset..offset + protocol.num_instance.len(); - protocol - .quotient - .numerator - .used_query() - .into_iter() - .filter(move |query| range.contains(&query.poly)) - }; - let (min_rotation, max_rotation) = queries.fold((0, 0), |(min, max), query| { - if query.rotation.0 < min { - (query.rotation.0, max) - } else if query.rotation.0 > max { - (min, query.rotation.0) - } else { - (min, max) - } - }); - let max_instance_len = - Iterator::max(instances.iter().map(|instance| instance.len())).unwrap_or_default(); - -max_rotation..max_instance_len as i32 + min_rotation.abs() - }); - protocol - .quotient - .numerator - .used_langrange() - .into_iter() - .chain(instance_eval_lagrange.into_iter().flatten()) + type Input = PlonkProtocol; + + fn estimate_cost(protocol: &PlonkProtocol) -> Cost { + PlonkSuccinctVerifier::::estimate_cost(protocol) + + Cost { num_pairing: 2, ..Default::default() } + } } diff --git a/snark-verifier/src/verifier/plonk/proof.rs b/snark-verifier/src/verifier/plonk/proof.rs new file mode 100644 index 00000000..7adba7ac --- /dev/null +++ b/snark-verifier/src/verifier/plonk/proof.rs @@ -0,0 +1,319 @@ +use crate::{ + loader::{LoadedScalar, Loader}, + pcs::{self, AccumulationScheme, AccumulatorEncoding, PolynomialCommitmentScheme}, + util::{ + arithmetic::{CurveAffine, Field, Rotation}, + msm::Msm, + transcript::TranscriptRead, + Itertools, + }, + verifier::plonk::protocol::{ + CommonPolynomial::Lagrange, CommonPolynomialEvaluation, LinearizationStrategy, + PlonkProtocol, Query, + }, + Error, +}; +use std::{collections::HashMap, iter}; + +/// Proof of PLONK with [`PolynomialCommitmentScheme`] that has +/// [`AccumulationScheme`]. +#[derive(Clone, Debug)] +pub struct PlonkProof +where + C: CurveAffine, + L: Loader, + AS: AccumulationScheme + PolynomialCommitmentScheme, +{ + /// Computed commitments of instance polynomials. + pub committed_instances: Option>, + /// Commitments of witness polynomials read from transcript. + pub witnesses: Vec, + /// Challenges squeezed from transcript. + pub challenges: Vec, + /// Quotient commitments read from transcript. + pub quotients: Vec, + /// Query point squeezed from transcript. + pub z: L::LoadedScalar, + /// Evaluations read from transcript. + pub evaluations: Vec, + /// Proof of [`PolynomialCommitmentScheme`]. + pub pcs: >::Proof, + /// Old [`AccumulationScheme::Accumulator`]s read from instnaces. + pub old_accumulators: Vec, +} + +impl PlonkProof +where + C: CurveAffine, + L: Loader, + AS: AccumulationScheme + PolynomialCommitmentScheme, +{ + /// Reads each part from transcript as [`PlonkProof`]. + pub fn read( + svk: &>::VerifyingKey, + protocol: &PlonkProtocol, + instances: &[Vec], + transcript: &mut T, + ) -> Result + where + T: TranscriptRead, + AE: AccumulatorEncoding, + { + if let Some(transcript_initial_state) = &protocol.transcript_initial_state { + transcript.common_scalar(transcript_initial_state)?; + } + + if protocol.num_instance != instances.iter().map(|instances| instances.len()).collect_vec() + { + return Err(Error::InvalidInstances); + } + + let committed_instances = if let Some(ick) = &protocol.instance_committing_key { + let loader = transcript.loader(); + let bases = + ick.bases.iter().map(|value| loader.ec_point_load_const(value)).collect_vec(); + let constant = ick.constant.as_ref().map(|value| loader.ec_point_load_const(value)); + + let committed_instances = instances + .iter() + .map(|instances| { + instances + .iter() + .zip(bases.iter()) + .map(|(scalar, base)| Msm::::base(base) * scalar) + .chain(constant.as_ref().map(Msm::base)) + .sum::>() + .evaluate(None) + }) + .collect_vec(); + for committed_instance in committed_instances.iter() { + transcript.common_ec_point(committed_instance)?; + } + + Some(committed_instances) + } else { + for instances in instances.iter() { + for instance in instances.iter() { + transcript.common_scalar(instance)?; + } + } + + None + }; + + let (witnesses, challenges) = { + let (witnesses, challenges) = protocol + .num_witness + .iter() + .zip(protocol.num_challenge.iter()) + .map(|(&n, &m)| { + Ok((transcript.read_n_ec_points(n)?, transcript.squeeze_n_challenges(m))) + }) + .collect::, Error>>()? + .into_iter() + .unzip::<_, _, Vec<_>, Vec<_>>(); + + ( + witnesses.into_iter().flatten().collect_vec(), + challenges.into_iter().flatten().collect_vec(), + ) + }; + + let quotients = transcript.read_n_ec_points(protocol.quotient.num_chunk())?; + + let z = transcript.squeeze_challenge(); + let evaluations = transcript.read_n_scalars(protocol.evaluations.len())?; + + let pcs = >::read_proof( + svk, + &Self::empty_queries(protocol), + transcript, + )?; + + let old_accumulators = protocol + .accumulator_indices + .iter() + .map(|accumulator_indices| { + AE::from_repr( + &accumulator_indices.iter().map(|&(i, j)| &instances[i][j]).collect_vec(), + ) + }) + .collect::, _>>()?; + + Ok(Self { + committed_instances, + witnesses, + challenges, + quotients, + z, + evaluations, + pcs, + old_accumulators, + }) + } + + /// Empty queries + pub fn empty_queries(protocol: &PlonkProtocol) -> Vec> { + protocol + .queries + .iter() + .map(|query| { + let shift = protocol.domain.rotate_scalar(C::Scalar::one(), query.rotation); + pcs::Query::new(query.poly, shift) + }) + .collect() + } + + pub(super) fn queries( + &self, + protocol: &PlonkProtocol, + mut evaluations: HashMap, + ) -> Vec> { + Self::empty_queries(protocol) + .into_iter() + .zip(protocol.queries.iter().map(|query| evaluations.remove(query).unwrap())) + .map(|(query, eval)| query.with_evaluation(eval)) + .collect() + } + + pub(super) fn commitments<'a>( + &'a self, + protocol: &'a PlonkProtocol, + common_poly_eval: &CommonPolynomialEvaluation, + evaluations: &mut HashMap, + ) -> Result>, Error> { + let loader = common_poly_eval.zn().loader(); + let mut commitments = iter::empty() + .chain(protocol.preprocessed.iter().map(Msm::base)) + .chain( + self.committed_instances + .as_ref() + .map(|committed_instances| { + committed_instances.iter().map(Msm::base).collect_vec() + }) + .unwrap_or_else(|| { + iter::repeat_with(Default::default) + .take(protocol.num_instance.len()) + .collect_vec() + }), + ) + .chain(self.witnesses.iter().map(Msm::base)) + .collect_vec(); + + let numerator = protocol.quotient.numerator.evaluate( + &|scalar| Ok(Msm::constant(loader.load_const(&scalar))), + &|poly| Ok(Msm::constant(common_poly_eval.get(poly).clone())), + &|query| { + evaluations + .get(&query) + .cloned() + .map(Msm::constant) + .or_else(|| { + (query.rotation == Rotation::cur()) + .then(|| commitments.get(query.poly).cloned()) + .flatten() + }) + .ok_or_else(|| Error::InvalidProtocol(format!("Missing query {query:?}"))) + }, + &|index| { + self.challenges + .get(index) + .cloned() + .map(Msm::constant) + .ok_or_else(|| Error::InvalidProtocol(format!("Missing challenge {index}"))) + }, + &|a| Ok(-a?), + &|a, b| Ok(a? + b?), + &|a, b| { + let (a, b) = (a?, b?); + match (a.size(), b.size()) { + (0, _) => Ok(b * &a.try_into_constant().unwrap()), + (_, 0) => Ok(a * &b.try_into_constant().unwrap()), + (_, _) => Err(Error::InvalidProtocol("Invalid linearization".to_string())), + } + }, + &|a, scalar| Ok(a? * &loader.load_const(&scalar)), + )?; + + let quotient_query = Query::new( + protocol.preprocessed.len() + protocol.num_instance.len() + self.witnesses.len(), + Rotation::cur(), + ); + let quotient = common_poly_eval + .zn() + .pow_const(protocol.quotient.chunk_degree as u64) + .powers(self.quotients.len()) + .into_iter() + .zip(self.quotients.iter().map(Msm::base)) + .map(|(coeff, chunk)| chunk * &coeff) + .sum::>(); + match protocol.linearization { + Some(LinearizationStrategy::WithoutConstant) => { + let linearization_query = Query::new(quotient_query.poly + 1, Rotation::cur()); + let (msm, constant) = numerator.split(); + commitments.push(quotient); + commitments.push(msm); + evaluations.insert( + quotient_query, + (constant.unwrap_or_else(|| loader.load_zero()) + + evaluations.get(&linearization_query).unwrap()) + * common_poly_eval.zn_minus_one_inv(), + ); + } + Some(LinearizationStrategy::MinusVanishingTimesQuotient) => { + let (msm, constant) = + (numerator - quotient * common_poly_eval.zn_minus_one()).split(); + commitments.push(msm); + evaluations.insert(quotient_query, constant.unwrap_or_else(|| loader.load_zero())); + } + None => { + commitments.push(quotient); + evaluations.insert( + quotient_query, + numerator.try_into_constant().ok_or_else(|| { + Error::InvalidProtocol("Invalid linearization".to_string()) + })? * common_poly_eval.zn_minus_one_inv(), + ); + } + } + + Ok(commitments) + } + + pub(super) fn evaluations( + &self, + protocol: &PlonkProtocol, + instances: &[Vec], + common_poly_eval: &CommonPolynomialEvaluation, + ) -> Result, Error> { + let loader = common_poly_eval.zn().loader(); + let instance_evals = protocol.instance_committing_key.is_none().then(|| { + let offset = protocol.preprocessed.len(); + let queries = { + let range = offset..offset + protocol.num_instance.len(); + protocol + .quotient + .numerator + .used_query() + .into_iter() + .filter(move |query| range.contains(&query.poly)) + }; + queries + .map(move |query| { + let instances = instances[query.poly - offset].iter(); + let l_i_minus_r = (-query.rotation.0..) + .map(|i_minus_r| common_poly_eval.get(Lagrange(i_minus_r))); + let eval = loader.sum_products(&instances.zip(l_i_minus_r).collect_vec()); + (query, eval) + }) + .collect_vec() + }); + + let evals = iter::empty() + .chain(instance_evals.into_iter().flatten()) + .chain(protocol.evaluations.iter().cloned().zip(self.evaluations.iter().cloned())) + .collect(); + + Ok(evals) + } +} diff --git a/snark-verifier/src/util/protocol.rs b/snark-verifier/src/verifier/plonk/protocol.rs similarity index 68% rename from snark-verifier/src/util/protocol.rs rename to snark-verifier/src/verifier/plonk/protocol.rs index a883a599..a3a84346 100644 --- a/snark-verifier/src/util/protocol.rs +++ b/snark-verifier/src/verifier/plonk/protocol.rs @@ -1,10 +1,9 @@ use crate::{ - loader::{LoadedScalar, Loader}, + loader::{native::NativeLoader, LoadedScalar, Loader}, util::{ arithmetic::{CurveAffine, Domain, Field, Fraction, Rotation}, Itertools, }, - Protocol, }; use num_integer::Integer; use num_traits::One; @@ -17,11 +16,94 @@ use std::{ ops::{Add, Mul, Neg, Sub}, }; -impl Protocol +/// Protocol specifying configuration of a PLONK. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct PlonkProtocol where C: CurveAffine, + L: Loader, { - pub fn loaded>(&self, loader: &L) -> Protocol { + #[serde(bound( + serialize = "C::Scalar: Serialize", + deserialize = "C::Scalar: Deserialize<'de>" + ))] + /// Working domain. + pub domain: Domain, + #[serde(bound( + serialize = "L::LoadedEcPoint: Serialize", + deserialize = "L::LoadedEcPoint: Deserialize<'de>" + ))] + /// Commitments of preprocessed polynomials. + pub preprocessed: Vec, + /// Number of instances in each instance polynomial. + pub num_instance: Vec, + /// Number of witness polynomials in each phase. + pub num_witness: Vec, + /// Number of challenges to squeeze from transcript after each phase. + pub num_challenge: Vec, + /// Evaluations to read from transcript. + pub evaluations: Vec, + /// [`crate::pcs::PolynomialCommitmentScheme`] queries to verify. + pub queries: Vec, + /// Structure of quotient polynomial. + pub quotient: QuotientPolynomial, + #[serde(bound( + serialize = "L::LoadedScalar: Serialize", + deserialize = "L::LoadedScalar: Deserialize<'de>" + ))] + /// Prover and verifier common initial state to write to transcript if any. + pub transcript_initial_state: Option, + /// Instance polynomials commiting key if any. + pub instance_committing_key: Option>, + /// Linearization strategy. + pub linearization: Option, + /// Indices (instance polynomial index, row) of encoded + /// [`crate::pcs::AccumulationScheme::Accumulator`]s. + pub accumulator_indices: Vec>, +} + +impl PlonkProtocol +where + C: CurveAffine, + L: Loader, +{ + pub(super) fn langranges(&self) -> impl IntoIterator { + let instance_eval_lagrange = self.instance_committing_key.is_none().then(|| { + let queries = { + let offset = self.preprocessed.len(); + let range = offset..offset + self.num_instance.len(); + self.quotient + .numerator + .used_query() + .into_iter() + .filter(move |query| range.contains(&query.poly)) + }; + let (min_rotation, max_rotation) = queries.fold((0, 0), |(min, max), query| { + if query.rotation.0 < min { + (query.rotation.0, max) + } else if query.rotation.0 > max { + (min, query.rotation.0) + } else { + (min, max) + } + }); + let max_instance_len = self.num_instance.iter().max().copied().unwrap_or_default(); + -max_rotation..max_instance_len as i32 + min_rotation.abs() + }); + self.quotient + .numerator + .used_langrange() + .into_iter() + .chain(instance_eval_lagrange.into_iter().flatten()) + } +} +impl PlonkProtocol +where + C: CurveAffine, +{ + /// Loaded `PlonkProtocol` with `preprocessed` and + /// `transcript_initial_state` loaded as constant. + pub fn loaded>(&self, loader: &L) -> PlonkProtocol { let preprocessed = self .preprocessed .iter() @@ -31,7 +113,7 @@ where .transcript_initial_state .as_ref() .map(|transcript_initial_state| loader.load_const(transcript_initial_state)); - Protocol { + PlonkProtocol { domain: self.domain.clone(), preprocessed, num_instance: self.num_instance.clone(), @@ -48,6 +130,53 @@ where } } +#[cfg(feature = "loader_halo2")] +mod halo2 { + use crate::{ + loader::halo2::{EccInstructions, Halo2Loader}, + util::arithmetic::CurveAffine, + verifier::plonk::PlonkProtocol, + }; + use std::rc::Rc; + + impl PlonkProtocol + where + C: CurveAffine, + { + /// Loaded `PlonkProtocol` with `preprocessed` and + /// `transcript_initial_state` loaded as witness, which is useful when + /// doing recursion. + pub fn loaded_preprocessed_as_witness>( + &self, + loader: &Rc>, + ) -> PlonkProtocol>> { + let preprocessed = self + .preprocessed + .iter() + .map(|preprocessed| loader.assign_ec_point(*preprocessed)) + .collect(); + let transcript_initial_state = self + .transcript_initial_state + .as_ref() + .map(|transcript_initial_state| loader.assign_scalar(*transcript_initial_state)); + PlonkProtocol { + domain: self.domain.clone(), + preprocessed, + num_instance: self.num_instance.clone(), + num_witness: self.num_witness.clone(), + num_challenge: self.num_challenge.clone(), + evaluations: self.evaluations.clone(), + queries: self.queries.clone(), + quotient: self.quotient.clone(), + transcript_initial_state, + instance_committing_key: self.instance_committing_key.clone(), + linearization: self.linearization, + accumulator_indices: self.accumulator_indices.clone(), + } + } + } +} + #[derive(Clone, Copy, Debug, Serialize, Deserialize)] pub enum CommonPolynomial { Identity, @@ -150,11 +279,14 @@ pub struct QuotientPolynomial { impl QuotientPolynomial { pub fn num_chunk(&self) -> usize { - Integer::div_ceil(&(self.numerator.degree() - 1), &self.chunk_degree) + Integer::div_ceil( + &(self.numerator.degree().checked_sub(1).unwrap_or_default()), + &self.chunk_degree, + ) } } -#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] pub struct Query { pub poly: usize, pub rotation: Rotation,