From 147a62e55adfb2e28b7a62abe86949e1a1f726ce Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Tue, 22 Oct 2024 19:42:10 +0000 Subject: [PATCH] feat: generalize arm64 mul for larger modulus --- ecc/bls12-377/fp/element_arm64.s | 2 +- ecc/bls12-377/fr/element_arm64.s | 2 +- ecc/bls12-381/fp/element_arm64.s | 2 +- ecc/bls12-381/fr/element_arm64.s | 2 +- ecc/bls24-315/fr/element_arm64.s | 2 +- ecc/bls24-317/fr/element_arm64.s | 2 +- ecc/bn254/fp/element_arm64.s | 2 +- ecc/bn254/fr/element_arm64.s | 2 +- ecc/bw6-633/fp/element_purego.go | 2 +- ecc/bw6-761/fp/element_purego.go | 2 +- ecc/bw6-761/fr/element_arm64.s | 2 +- ecc/stark-curve/fp/element_arm64.s | 2 +- ecc/stark-curve/fr/element_arm64.s | 2 +- field/asm/.gitignore | 7 +- field/asm/element_4w_arm64.s | 140 ++++----- field/asm/element_6w_arm64.s | 202 ++++++------- field/generator/asm/arm64/build.go | 4 +- field/generator/asm/arm64/element_ops.go | 267 ------------------ field/generator/config/field_config.go | 2 +- field/generator/generator_test.go | 3 + .../internal/templates/element/ops_asm.go | 6 + go.mod | 2 +- go.sum | 4 + internal/generator/main.go | 2 + 24 files changed, 210 insertions(+), 455 deletions(-) delete mode 100644 field/generator/asm/arm64/element_ops.go diff --git a/ecc/bls12-377/fp/element_arm64.s b/ecc/bls12-377/fp/element_arm64.s index 4c01eca83..2a3f7d0b2 100644 --- a/ecc/bls12-377/fp/element_arm64.s +++ b/ecc/bls12-377/fp/element_arm64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 17561434332277668166 +// We include the hash to force the Go compiler to recompile: 15397482240260640864 #include "../../../field/asm/element_6w_arm64.s" diff --git a/ecc/bls12-377/fr/element_arm64.s b/ecc/bls12-377/fr/element_arm64.s index 75bf9d9d1..5d328815a 100644 --- a/ecc/bls12-377/fr/element_arm64.s +++ b/ecc/bls12-377/fr/element_arm64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 17105046060840004046 +// We include the hash to force the Go compiler to recompile: 1501560133179981797 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bls12-381/fp/element_arm64.s b/ecc/bls12-381/fp/element_arm64.s index 4c01eca83..2a3f7d0b2 100644 --- a/ecc/bls12-381/fp/element_arm64.s +++ b/ecc/bls12-381/fp/element_arm64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 17561434332277668166 +// We include the hash to force the Go compiler to recompile: 15397482240260640864 #include "../../../field/asm/element_6w_arm64.s" diff --git a/ecc/bls12-381/fr/element_arm64.s b/ecc/bls12-381/fr/element_arm64.s index 75bf9d9d1..5d328815a 100644 --- a/ecc/bls12-381/fr/element_arm64.s +++ b/ecc/bls12-381/fr/element_arm64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 17105046060840004046 +// We include the hash to force the Go compiler to recompile: 1501560133179981797 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bls24-315/fr/element_arm64.s b/ecc/bls24-315/fr/element_arm64.s index 75bf9d9d1..5d328815a 100644 --- a/ecc/bls24-315/fr/element_arm64.s +++ b/ecc/bls24-315/fr/element_arm64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 17105046060840004046 +// We include the hash to force the Go compiler to recompile: 1501560133179981797 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bls24-317/fr/element_arm64.s b/ecc/bls24-317/fr/element_arm64.s index 75bf9d9d1..5d328815a 100644 --- a/ecc/bls24-317/fr/element_arm64.s +++ b/ecc/bls24-317/fr/element_arm64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 17105046060840004046 +// We include the hash to force the Go compiler to recompile: 1501560133179981797 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bn254/fp/element_arm64.s b/ecc/bn254/fp/element_arm64.s index 75bf9d9d1..5d328815a 100644 --- a/ecc/bn254/fp/element_arm64.s +++ b/ecc/bn254/fp/element_arm64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 17105046060840004046 +// We include the hash to force the Go compiler to recompile: 1501560133179981797 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bn254/fr/element_arm64.s b/ecc/bn254/fr/element_arm64.s index 75bf9d9d1..5d328815a 100644 --- a/ecc/bn254/fr/element_arm64.s +++ b/ecc/bn254/fr/element_arm64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 17105046060840004046 +// We include the hash to force the Go compiler to recompile: 1501560133179981797 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/bw6-633/fp/element_purego.go b/ecc/bw6-633/fp/element_purego.go index 811df961e..637ecd9d6 100644 --- a/ecc/bw6-633/fp/element_purego.go +++ b/ecc/bw6-633/fp/element_purego.go @@ -1,4 +1,4 @@ -//go:build purego || !amd64 +//go:build purego || (!amd64 && !arm64) // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bw6-761/fp/element_purego.go b/ecc/bw6-761/fp/element_purego.go index 128e16274..4338d90c2 100644 --- a/ecc/bw6-761/fp/element_purego.go +++ b/ecc/bw6-761/fp/element_purego.go @@ -1,4 +1,4 @@ -//go:build purego || !amd64 +//go:build purego || (!amd64 && !arm64) // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bw6-761/fr/element_arm64.s b/ecc/bw6-761/fr/element_arm64.s index 4c01eca83..2a3f7d0b2 100644 --- a/ecc/bw6-761/fr/element_arm64.s +++ b/ecc/bw6-761/fr/element_arm64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 17561434332277668166 +// We include the hash to force the Go compiler to recompile: 15397482240260640864 #include "../../../field/asm/element_6w_arm64.s" diff --git a/ecc/stark-curve/fp/element_arm64.s b/ecc/stark-curve/fp/element_arm64.s index 75bf9d9d1..5d328815a 100644 --- a/ecc/stark-curve/fp/element_arm64.s +++ b/ecc/stark-curve/fp/element_arm64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 17105046060840004046 +// We include the hash to force the Go compiler to recompile: 1501560133179981797 #include "../../../field/asm/element_4w_arm64.s" diff --git a/ecc/stark-curve/fr/element_arm64.s b/ecc/stark-curve/fr/element_arm64.s index 75bf9d9d1..5d328815a 100644 --- a/ecc/stark-curve/fr/element_arm64.s +++ b/ecc/stark-curve/fr/element_arm64.s @@ -16,6 +16,6 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -// We include the hash to force the Go compiler to recompile: 17105046060840004046 +// We include the hash to force the Go compiler to recompile: 1501560133179981797 #include "../../../field/asm/element_4w_arm64.s" diff --git a/field/asm/.gitignore b/field/asm/.gitignore index 7c22f7f93..d534769fc 100644 --- a/field/asm/.gitignore +++ b/field/asm/.gitignore @@ -3,4 +3,9 @@ element_2w_amd64.s element_3w_amd64.s element_7w_amd64.s element_8w_amd64.s -*.h \ No newline at end of file +*.h + +element_2w_arm64.s +element_3w_arm64.s +element_7w_arm64.s +element_8w_arm64.s \ No newline at end of file diff --git a/field/asm/element_4w_arm64.s b/field/asm/element_4w_arm64.s index 58eccf414..fce96e300 100644 --- a/field/asm/element_4w_arm64.s +++ b/field/asm/element_4w_arm64.s @@ -54,89 +54,89 @@ TEXT ·Butterfly(SB), NOFRAME|NOSPLIT, $0-16 // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 TEXT ·mul(SB), NOFRAME|NOSPLIT, $0-24 #define DIVSHIFT() \ - MUL R7, R17, R0 \ - ADDS R0, R11, R11 \ - MUL R8, R17, R0 \ - ADCS R0, R12, R12 \ - MUL R9, R17, R0 \ - ADCS R0, R13, R13 \ - MUL R10, R17, R0 \ - ADCS R0, R14, R14 \ - ADC R15, ZR, R15 \ - UMULH R7, R17, R0 \ - ADDS R0, R12, R11 \ - UMULH R8, R17, R0 \ - ADCS R0, R13, R12 \ - UMULH R9, R17, R0 \ - ADCS R0, R14, R13 \ - UMULH R10, R17, R0 \ - ADCS R0, R15, R14 \ + MUL R13, R12, R0 \ + ADDS R0, R6, R6 \ + MUL R14, R12, R0 \ + ADCS R0, R7, R7 \ + MUL R15, R12, R0 \ + ADCS R0, R8, R8 \ + MUL R16, R12, R0 \ + ADCS R0, R9, R9 \ + ADC R10, ZR, R10 \ + UMULH R13, R12, R0 \ + ADDS R0, R7, R6 \ + UMULH R14, R12, R0 \ + ADCS R0, R8, R7 \ + UMULH R15, R12, R0 \ + ADCS R0, R9, R8 \ + UMULH R16, R12, R0 \ + ADCS R0, R10, R9 \ #define MUL_WORD_N() \ - MUL R3, R2, R0 \ - ADDS R0, R11, R11 \ - MUL R11, R16, R17 \ - MUL R4, R2, R0 \ - ADCS R0, R12, R12 \ - MUL R5, R2, R0 \ - ADCS R0, R13, R13 \ - MUL R6, R2, R0 \ - ADCS R0, R14, R14 \ - ADC ZR, ZR, R15 \ - UMULH R3, R2, R0 \ - ADDS R0, R12, R12 \ - UMULH R4, R2, R0 \ - ADCS R0, R13, R13 \ - UMULH R5, R2, R0 \ - ADCS R0, R14, R14 \ - UMULH R6, R2, R0 \ - ADC R0, R15, R15 \ - DIVSHIFT() \ + MUL R2, R1, R0 \ + ADDS R0, R6, R6 \ + MUL R6, R11, R12 \ + MUL R3, R1, R0 \ + ADCS R0, R7, R7 \ + MUL R4, R1, R0 \ + ADCS R0, R8, R8 \ + MUL R5, R1, R0 \ + ADCS R0, R9, R9 \ + ADC ZR, ZR, R10 \ + UMULH R2, R1, R0 \ + ADDS R0, R7, R7 \ + UMULH R3, R1, R0 \ + ADCS R0, R8, R8 \ + UMULH R4, R1, R0 \ + ADCS R0, R9, R9 \ + UMULH R5, R1, R0 \ + ADC R0, R10, R10 \ + DIVSHIFT() \ #define MUL_WORD_0() \ - MUL R3, R2, R11 \ - MUL R4, R2, R12 \ - MUL R5, R2, R13 \ - MUL R6, R2, R14 \ - UMULH R3, R2, R0 \ - ADDS R0, R12, R12 \ - UMULH R4, R2, R0 \ - ADCS R0, R13, R13 \ - UMULH R5, R2, R0 \ - ADCS R0, R14, R14 \ - UMULH R6, R2, R0 \ - ADC R0, ZR, R15 \ - MUL R11, R16, R17 \ - DIVSHIFT() \ + MUL R2, R1, R6 \ + MUL R3, R1, R7 \ + MUL R4, R1, R8 \ + MUL R5, R1, R9 \ + UMULH R2, R1, R0 \ + ADDS R0, R7, R7 \ + UMULH R3, R1, R0 \ + ADCS R0, R8, R8 \ + UMULH R4, R1, R0 \ + ADCS R0, R9, R9 \ + UMULH R5, R1, R0 \ + ADC R0, ZR, R10 \ + MUL R6, R11, R12 \ + DIVSHIFT() \ - MOVD y+16(FP), R1 + MOVD y+16(FP), R17 MOVD x+8(FP), R0 - LDP 0(R0), (R3, R4) - LDP 16(R0), (R5, R6) - MOVD 0(R1), R2 - MOVD $const_qInvNeg, R16 - LDP ·qElement+0(SB), (R7, R8) - LDP ·qElement+16(SB), (R9, R10) + LDP 0(R0), (R2, R3) + LDP 16(R0), (R4, R5) + MOVD 0(R17), R1 + MOVD $const_qInvNeg, R11 + LDP ·qElement+0(SB), (R13, R14) + LDP ·qElement+16(SB), (R15, R16) MUL_WORD_0() - MOVD 8(R1), R2 + MOVD 8(R17), R1 MUL_WORD_N() - MOVD 16(R1), R2 + MOVD 16(R17), R1 MUL_WORD_N() - MOVD 24(R1), R2 + MOVD 24(R17), R1 MUL_WORD_N() // reduce if necessary - SUBS R7, R11, R7 - SBCS R8, R12, R8 - SBCS R9, R13, R9 - SBCS R10, R14, R10 + SUBS R13, R6, R13 + SBCS R14, R7, R14 + SBCS R15, R8, R15 + SBCS R16, R9, R16 MOVD res+0(FP), R0 - CSEL CS, R7, R11, R11 - CSEL CS, R8, R12, R12 - STP (R11, R12), 0(R0) - CSEL CS, R9, R13, R13 - CSEL CS, R10, R14, R14 - STP (R13, R14), 16(R0) + CSEL CS, R13, R6, R6 + CSEL CS, R14, R7, R7 + STP (R6, R7), 0(R0) + CSEL CS, R15, R8, R8 + CSEL CS, R16, R9, R9 + STP (R8, R9), 16(R0) RET // reduce(res *Element) diff --git a/field/asm/element_6w_arm64.s b/field/asm/element_6w_arm64.s index 7b4946b3f..7dbd7ecaf 100644 --- a/field/asm/element_6w_arm64.s +++ b/field/asm/element_6w_arm64.s @@ -71,122 +71,122 @@ TEXT ·Butterfly(SB), NOFRAME|NOSPLIT, $0-16 // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 TEXT ·mul(SB), NOFRAME|NOSPLIT, $0-24 #define DIVSHIFT() \ - MUL R9, R24, R0 \ - ADDS R0, R15, R15 \ - MUL R10, R24, R0 \ - ADCS R0, R16, R16 \ - MUL R11, R24, R0 \ - ADCS R0, R17, R17 \ - MUL R12, R24, R0 \ - ADCS R0, R19, R19 \ - MUL R13, R24, R0 \ - ADCS R0, R20, R20 \ - MUL R14, R24, R0 \ - ADCS R0, R21, R21 \ - ADC R22, ZR, R22 \ - UMULH R9, R24, R0 \ - ADDS R0, R16, R15 \ - UMULH R10, R24, R0 \ - ADCS R0, R17, R16 \ - UMULH R11, R24, R0 \ - ADCS R0, R19, R17 \ - UMULH R12, R24, R0 \ - ADCS R0, R20, R19 \ - UMULH R13, R24, R0 \ - ADCS R0, R21, R20 \ - UMULH R14, R24, R0 \ - ADCS R0, R22, R21 \ + MUL R17, R16, R0 \ + ADDS R0, R8, R8 \ + MUL R19, R16, R0 \ + ADCS R0, R9, R9 \ + MUL R20, R16, R0 \ + ADCS R0, R10, R10 \ + MUL R21, R16, R0 \ + ADCS R0, R11, R11 \ + MUL R22, R16, R0 \ + ADCS R0, R12, R12 \ + MUL R23, R16, R0 \ + ADCS R0, R13, R13 \ + ADC R14, ZR, R14 \ + UMULH R17, R16, R0 \ + ADDS R0, R9, R8 \ + UMULH R19, R16, R0 \ + ADCS R0, R10, R9 \ + UMULH R20, R16, R0 \ + ADCS R0, R11, R10 \ + UMULH R21, R16, R0 \ + ADCS R0, R12, R11 \ + UMULH R22, R16, R0 \ + ADCS R0, R13, R12 \ + UMULH R23, R16, R0 \ + ADCS R0, R14, R13 \ #define MUL_WORD_N() \ - MUL R3, R2, R0 \ - ADDS R0, R15, R15 \ - MUL R15, R23, R24 \ - MUL R4, R2, R0 \ - ADCS R0, R16, R16 \ - MUL R5, R2, R0 \ - ADCS R0, R17, R17 \ - MUL R6, R2, R0 \ - ADCS R0, R19, R19 \ - MUL R7, R2, R0 \ - ADCS R0, R20, R20 \ - MUL R8, R2, R0 \ - ADCS R0, R21, R21 \ - ADC ZR, ZR, R22 \ - UMULH R3, R2, R0 \ - ADDS R0, R16, R16 \ - UMULH R4, R2, R0 \ - ADCS R0, R17, R17 \ - UMULH R5, R2, R0 \ - ADCS R0, R19, R19 \ - UMULH R6, R2, R0 \ - ADCS R0, R20, R20 \ - UMULH R7, R2, R0 \ - ADCS R0, R21, R21 \ - UMULH R8, R2, R0 \ - ADC R0, R22, R22 \ - DIVSHIFT() \ + MUL R2, R1, R0 \ + ADDS R0, R8, R8 \ + MUL R8, R15, R16 \ + MUL R3, R1, R0 \ + ADCS R0, R9, R9 \ + MUL R4, R1, R0 \ + ADCS R0, R10, R10 \ + MUL R5, R1, R0 \ + ADCS R0, R11, R11 \ + MUL R6, R1, R0 \ + ADCS R0, R12, R12 \ + MUL R7, R1, R0 \ + ADCS R0, R13, R13 \ + ADC ZR, ZR, R14 \ + UMULH R2, R1, R0 \ + ADDS R0, R9, R9 \ + UMULH R3, R1, R0 \ + ADCS R0, R10, R10 \ + UMULH R4, R1, R0 \ + ADCS R0, R11, R11 \ + UMULH R5, R1, R0 \ + ADCS R0, R12, R12 \ + UMULH R6, R1, R0 \ + ADCS R0, R13, R13 \ + UMULH R7, R1, R0 \ + ADC R0, R14, R14 \ + DIVSHIFT() \ #define MUL_WORD_0() \ - MUL R3, R2, R15 \ - MUL R4, R2, R16 \ - MUL R5, R2, R17 \ - MUL R6, R2, R19 \ - MUL R7, R2, R20 \ - MUL R8, R2, R21 \ - UMULH R3, R2, R0 \ - ADDS R0, R16, R16 \ - UMULH R4, R2, R0 \ - ADCS R0, R17, R17 \ - UMULH R5, R2, R0 \ - ADCS R0, R19, R19 \ - UMULH R6, R2, R0 \ - ADCS R0, R20, R20 \ - UMULH R7, R2, R0 \ - ADCS R0, R21, R21 \ - UMULH R8, R2, R0 \ - ADC R0, ZR, R22 \ - MUL R15, R23, R24 \ - DIVSHIFT() \ + MUL R2, R1, R8 \ + MUL R3, R1, R9 \ + MUL R4, R1, R10 \ + MUL R5, R1, R11 \ + MUL R6, R1, R12 \ + MUL R7, R1, R13 \ + UMULH R2, R1, R0 \ + ADDS R0, R9, R9 \ + UMULH R3, R1, R0 \ + ADCS R0, R10, R10 \ + UMULH R4, R1, R0 \ + ADCS R0, R11, R11 \ + UMULH R5, R1, R0 \ + ADCS R0, R12, R12 \ + UMULH R6, R1, R0 \ + ADCS R0, R13, R13 \ + UMULH R7, R1, R0 \ + ADC R0, ZR, R14 \ + MUL R8, R15, R16 \ + DIVSHIFT() \ - MOVD y+16(FP), R1 + MOVD y+16(FP), R24 MOVD x+8(FP), R0 - LDP 0(R0), (R3, R4) - LDP 16(R0), (R5, R6) - LDP 32(R0), (R7, R8) - MOVD 0(R1), R2 - MOVD $const_qInvNeg, R23 - LDP ·qElement+0(SB), (R9, R10) - LDP ·qElement+16(SB), (R11, R12) - LDP ·qElement+32(SB), (R13, R14) + LDP 0(R0), (R2, R3) + LDP 16(R0), (R4, R5) + LDP 32(R0), (R6, R7) + MOVD 0(R24), R1 + MOVD $const_qInvNeg, R15 + LDP ·qElement+0(SB), (R17, R19) + LDP ·qElement+16(SB), (R20, R21) + LDP ·qElement+32(SB), (R22, R23) MUL_WORD_0() - MOVD 8(R1), R2 + MOVD 8(R24), R1 MUL_WORD_N() - MOVD 16(R1), R2 + MOVD 16(R24), R1 MUL_WORD_N() - MOVD 24(R1), R2 + MOVD 24(R24), R1 MUL_WORD_N() - MOVD 32(R1), R2 + MOVD 32(R24), R1 MUL_WORD_N() - MOVD 40(R1), R2 + MOVD 40(R24), R1 MUL_WORD_N() // reduce if necessary - SUBS R9, R15, R9 - SBCS R10, R16, R10 - SBCS R11, R17, R11 - SBCS R12, R19, R12 - SBCS R13, R20, R13 - SBCS R14, R21, R14 + SUBS R17, R8, R17 + SBCS R19, R9, R19 + SBCS R20, R10, R20 + SBCS R21, R11, R21 + SBCS R22, R12, R22 + SBCS R23, R13, R23 MOVD res+0(FP), R0 - CSEL CS, R9, R15, R15 - CSEL CS, R10, R16, R16 - STP (R15, R16), 0(R0) - CSEL CS, R11, R17, R17 - CSEL CS, R12, R19, R19 - STP (R17, R19), 16(R0) - CSEL CS, R13, R20, R20 - CSEL CS, R14, R21, R21 - STP (R20, R21), 32(R0) + CSEL CS, R17, R8, R8 + CSEL CS, R19, R9, R9 + STP (R8, R9), 0(R0) + CSEL CS, R20, R10, R10 + CSEL CS, R21, R11, R11 + STP (R10, R11), 16(R0) + CSEL CS, R22, R12, R12 + CSEL CS, R23, R13, R13 + STP (R12, R13), 32(R0) RET // reduce(res *Element) diff --git a/field/generator/asm/arm64/build.go b/field/generator/asm/arm64/build.go index cba7bde8b..08bc52fb1 100644 --- a/field/generator/asm/arm64/build.go +++ b/field/generator/asm/arm64/build.go @@ -107,7 +107,9 @@ func GenerateCommonASM(w io.Writer, nbWords int, hasVector bool) error { panic("NbWords must be even") } - f.generateButterfly() + if f.NbWords <= 6 { + f.generateButterfly() + } f.generateMul() f.generateReduce() diff --git a/field/generator/asm/arm64/element_ops.go b/field/generator/asm/arm64/element_ops.go deleted file mode 100644 index 354cd68d9..000000000 --- a/field/generator/asm/arm64/element_ops.go +++ /dev/null @@ -1,267 +0,0 @@ -// Copyright 2022 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package arm64 - -import ( - "github.com/consensys/bavard/arm64" -) - -func (f *FFArm64) generateButterfly() { - f.Comment("butterfly(a, b *Element)") - f.Comment("a, b = a+b, a-b") - registers := f.FnHeader("Butterfly", 0, 16) - defer f.AssertCleanStack(0, 0) - - // registers - a := registers.PopN(f.NbWords) - b := registers.PopN(f.NbWords) - r := registers.PopN(f.NbWords) - t := registers.PopN(f.NbWords) - aPtr := registers.Pop() - bPtr := registers.Pop() - - f.LDP("x+0(FP)", aPtr, bPtr) - f.load(aPtr, a) - f.load(bPtr, b) - - for i := 0; i < f.NbWords; i++ { - f.add0n(i)(a[i], b[i], r[i]) - } - - f.SUBS(b[0], a[0], b[0]) - for i := 1; i < f.NbWords; i++ { - f.SBCS(b[i], a[i], b[i]) - } - - for i := 0; i < f.NbWords; i++ { - if i%2 == 0 { - f.LDP(f.qAt(i), a[i], a[i+1]) - } - f.CSEL("CS", "ZR", a[i], t[i]) - } - f.Comment("add q if underflow, 0 if not") - for i := 0; i < f.NbWords; i++ { - f.add0n(i)(b[i], t[i], b[i]) - if i%2 == 1 { - f.STP(b[i-1], b[i], bPtr.At(i-1)) - } - } - - f.reduceAndStore(r, a, aPtr) - - f.RET() -} - -func (f *FFArm64) generateReduce() { - f.Comment("reduce(res *Element)") - registers := f.FnHeader("reduce", 0, 8) - defer f.AssertCleanStack(0, 0) - - // registers - t := registers.PopN(f.NbWords) - q := registers.PopN(f.NbWords) - rPtr := registers.Pop() - - for i := 0; i < f.NbWords; i += 2 { - f.LDP(f.qAt(i), q[i], q[i+1]) - } - - f.MOVD("res+0(FP)", rPtr) - f.load(rPtr, t) - f.reduceAndStore(t, q, rPtr) - - f.RET() -} - -func (f *FFArm64) generateMul() { - f.Comment("mul(res, x, y *Element)") - f.Comment("Algorithm 2 of Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS") - f.Comment("by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521") - registers := f.FnHeader("mul", 0, 24) - defer f.AssertCleanStack(0, 0) - - xPtr := registers.Pop() - yPtr := registers.Pop() - bi := registers.Pop() - a := registers.PopN(f.NbWords) - q := registers.PopN(f.NbWords) - t := registers.PopN(f.NbWords + 1) - - ax := xPtr - qInv0 := registers.Pop() - m := registers.Pop() - - divShift := f.Define("divShift", 0, func(args ...arm64.Register) { - // for j=0 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - - for j := 0; j < f.NbWords; j++ { - f.MUL(q[j], m, ax) - f.add0m(j)(ax, t[j], t[j]) - } - f.add0m(f.NbWords)(t[f.NbWords], "ZR", t[f.NbWords]) - - // propagate high bits - f.UMULH(q[0], m, ax) - for j := 1; j <= f.NbWords; j++ { - f.add1m(j, true)(ax, t[j], t[j-1]) - if j != f.NbWords { - f.UMULH(q[j], m, ax) - } - } - }) - - mulWordN := f.Define("MUL_WORD_N", 0, func(args ...arm64.Register) { - // for j=0 to N-1 - // (C,t[j]) := t[j] + a[j]*b[i] + C - - // lo bits - for j := 0; j < f.NbWords; j++ { - f.MUL(a[j], bi, ax) - f.add0m(j)(ax, t[j], t[j]) - - if j == 0 { - f.MUL(t[0], qInv0, m) - } - } - f.add0m(f.NbWords)("ZR", "ZR", t[f.NbWords]) - - // propagate high bits - f.UMULH(a[0], bi, ax) - for j := 1; j <= f.NbWords; j++ { - f.add1m(j)(ax, t[j], t[j]) - if j != f.NbWords { - f.UMULH(a[j], bi, ax) - } - } - divShift() - }) - - mulWord0 := f.Define("MUL_WORD_0", 0, func(args ...arm64.Register) { - // for j=0 to N-1 - // (C,t[j]) := t[j] + a[j]*b[i] + C - // lo bits - for j := 0; j < f.NbWords; j++ { - f.MUL(a[j], bi, t[j]) - } - - // propagate high bits - f.UMULH(a[0], bi, ax) - for j := 1; j < f.NbWords; j++ { - f.add1m(j)(ax, t[j], t[j]) - f.UMULH(a[j], bi, ax) - } - f.add1m(f.NbWords)(ax, "ZR", t[f.NbWords]) - f.MUL(t[0], qInv0, m) - divShift() - }) - - f.MOVD("y+16(FP)", yPtr) - f.MOVD("x+8(FP)", xPtr) - f.load(xPtr, a) - for i := 0; i < f.NbWords; i++ { - f.MOVD(yPtr.At(i), bi) - - if i == 0 { - // load qInv0 and q at first iteration. - f.MOVD(f.qInv0(), qInv0) - for i := 0; i < f.NbWords-1; i += 2 { - f.LDP(f.qAt(i), q[i], q[i+1]) - } - mulWord0() - } else { - mulWordN() - } - } - - f.Comment("reduce if necessary") - f.SUBS(q[0], t[0], q[0]) - for i := 1; i < f.NbWords; i++ { - f.SBCS(q[i], t[i], q[i]) - } - - f.MOVD("res+0(FP)", ax) - for i := 0; i < f.NbWords; i++ { - f.CSEL("CS", q[i], t[i], t[i]) - if i%2 == 1 { - f.STP(t[i-1], t[i], ax.At(i-1)) - } - } - - f.RET() -} - -func (f *FFArm64) load(zPtr arm64.Register, z []arm64.Register) { - for i := 0; i < f.NbWords-1; i += 2 { - f.LDP(zPtr.At(i), z[i], z[i+1]) - } -} - -// q must contain the modulus -// q is modified -// t = t mod q (t must be less than 2q) -// t is stored in zPtr -func (f *FFArm64) reduceAndStore(t, q []arm64.Register, zPtr arm64.Register) { - f.Comment("q = t - q") - f.SUBS(q[0], t[0], q[0]) - for i := 1; i < f.NbWords; i++ { - f.SBCS(q[i], t[i], q[i]) - } - - f.Comment("if no borrow, return q, else return t") - for i := 0; i < f.NbWords; i++ { - f.CSEL("CS", q[i], t[i], t[i]) - if i%2 == 1 { - f.STP(t[i-1], t[i], zPtr.At(i-1)) - } - } -} - -func (f *FFArm64) add0n(i int) func(op1, op2, dst interface{}, comment ...string) { - switch { - case i == 0: - return f.ADDS - case i == f.NbWordsLastIndex: - return f.ADC - default: - return f.ADCS - } -} - -func (f *FFArm64) add0m(i int) func(op1, op2, dst interface{}, comment ...string) { - switch { - case i == 0: - return f.ADDS - case i == f.NbWordsLastIndex+1: - return f.ADC - default: - return f.ADCS - } -} - -func (f *FFArm64) add1m(i int, dumb ...bool) func(op1, op2, dst interface{}, comment ...string) { - switch { - case i == 1: - return f.ADDS - case i == f.NbWordsLastIndex+1: - if len(dumb) == 1 && dumb[0] { - // odd, but it performs better on c8g instances. - return f.ADCS - } - return f.ADC - default: - return f.ADCS - } -} diff --git a/field/generator/config/field_config.go b/field/generator/config/field_config.go index 1efe7d8b3..f6e5513b4 100644 --- a/field/generator/config/field_config.go +++ b/field/generator/config/field_config.go @@ -267,7 +267,7 @@ func NewFieldConfig(packageName, elementName, modulus string, useAddChain bool) // asm code generation for moduli with more than 6 words can be optimized further F.GenerateOpsAMD64 = F.NoCarry && F.NbWords <= 12 && F.NbWords > 1 F.GenerateVectorOpsAMD64 = F.GenerateOpsAMD64 && F.NbWords == 4 && F.NbBits > 225 - F.GenerateOpsARM64 = F.GenerateOpsAMD64 && (F.NbWords == 6 || F.NbWords == 4) + F.GenerateOpsARM64 = F.GenerateOpsAMD64 && (F.NbWords%2 == 0) F.GenerateVectorOpsARM64 = false // setting Mu 2^288 / q diff --git a/field/generator/generator_test.go b/field/generator/generator_test.go index e490baa03..ee5b5fafc 100644 --- a/field/generator/generator_test.go +++ b/field/generator/generator_test.go @@ -84,6 +84,9 @@ func TestIntegration(t *testing.T) { assert.NoError(GenerateAMD64(7, asmDir, false)) assert.NoError(GenerateAMD64(8, asmDir, false)) + assert.NoError(GenerateARM64(2, asmDir, false)) + assert.NoError(GenerateARM64(8, asmDir, false)) + for elementName, modulus := range moduli { var fIntegration *field.FieldConfig // generate field diff --git a/field/generator/internal/templates/element/ops_asm.go b/field/generator/internal/templates/element/ops_asm.go index 271d7d116..52b2c0b2b 100644 --- a/field/generator/internal/templates/element/ops_asm.go +++ b/field/generator/internal/templates/element/ops_asm.go @@ -50,8 +50,14 @@ const OpsARM64 = ` // Butterfly sets // a = a + b (mod q) // b = a - b (mod q) +{{- if le .NbWords 6}} //go:noescape func Butterfly(a, b *{{.ElementName}}) +{{- else}} +func Butterfly(a, b *{{.ElementName}}) { + _butterflyGeneric(a, b) +} +{{- end}} //go:noescape func mul(res,x,y *{{.ElementName}}) diff --git a/go.mod b/go.mod index 4297d5054..c4dbc8bc8 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.22 require ( github.com/bits-and-blooms/bitset v1.14.2 - github.com/consensys/bavard v0.1.23-0.20241021201139-ab3fee069cde + github.com/consensys/bavard v0.1.23-0.20241022191302-a6fdcdb6e8f3 github.com/leanovate/gopter v0.2.11 github.com/mmcloughlin/addchain v0.4.0 github.com/spf13/cobra v1.8.1 diff --git a/go.sum b/go.sum index 40ca76c74..af73f869f 100644 --- a/go.sum +++ b/go.sum @@ -61,6 +61,10 @@ github.com/consensys/bavard v0.1.23-0.20241019150039-28659c2eb91c h1:sK5i7h6ZVAj github.com/consensys/bavard v0.1.23-0.20241019150039-28659c2eb91c/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= github.com/consensys/bavard v0.1.23-0.20241021201139-ab3fee069cde h1:KXywceL5kuPe9PAQHHBvt4Kki7/XqsW7ABJI9dn4zik= github.com/consensys/bavard v0.1.23-0.20241021201139-ab3fee069cde/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= +github.com/consensys/bavard v0.1.23-0.20241022191117-d73e50a886cc h1:NwWCvGXSPH8BYATHBdy7qTJ3NMoT1kWVAvuEPtvasqg= +github.com/consensys/bavard v0.1.23-0.20241022191117-d73e50a886cc/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= +github.com/consensys/bavard v0.1.23-0.20241022191302-a6fdcdb6e8f3 h1:8gPxbjhwhxXTakOXII32eLlAFLlYImoENa3uQ6iP+go= +github.com/consensys/bavard v0.1.23-0.20241022191302-a6fdcdb6e8f3/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= diff --git a/internal/generator/main.go b/internal/generator/main.go index 45fbb4fe2..c17605a41 100644 --- a/internal/generator/main.go +++ b/internal/generator/main.go @@ -60,6 +60,8 @@ func main() { assertNoError(generator.GenerateARM64(4, asmDirBuildPath, false)) assertNoError(generator.GenerateARM64(6, asmDirBuildPath, false)) + assertNoError(generator.GenerateARM64(10, asmDirBuildPath, false)) + assertNoError(generator.GenerateARM64(12, asmDirBuildPath, false)) var wg sync.WaitGroup for _, conf := range config.Curves {