Skip to content

Commit

Permalink
Rejection sampling (#88)
Browse files Browse the repository at this point in the history
Implemented rejection sampling in the construction of ZK proofs + update generators to match rejection sampling apporach

Co-authored-by: Tore Frederiksen <[email protected]>
Co-authored-by: Weiwu Zhang <[email protected]>
  • Loading branch information
3 people authored Jan 6, 2021
1 parent 79cc879 commit 17bfa62
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 63 deletions.
108 changes: 64 additions & 44 deletions src/main/java/com/alphawallet/attestation/core/AttestationCrypto.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.security.MessageDigest;
import java.security.SecureRandom;
import java.security.Security;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.bouncycastle.asn1.sec.SECNamedCurves;
Expand Down Expand Up @@ -37,23 +38,34 @@ public class AttestationCrypto {
public static final BigInteger fieldSize = new BigInteger("21888242871839275222246405745257275088696311157297823662689037894645226208583");
// IMPORTANT: if another group is used then curveOrder should be the largest subgroup order
public static final BigInteger curveOrder = new BigInteger("21888242871839275222246405745257275088548364400416034343698204186575808495617");
// NOTE: Curve order for BN256 is 254 bit
public static final int curveOrderBitLength = curveOrder.bitLength(); // minus 1 since the bitcount includes an extra bit for sign since BigInteger is two's complement
public static final BigInteger cofactor = new BigInteger("1");
public static final ECCurve curve = new Fp(fieldSize, BigInteger.ZERO, new BigInteger("3"), curveOrder, cofactor);
// Generator for message part of Pedersen commitments generated deterministically from mapToInteger queried on 0 and mapped to the curve using try-and-increment
public static final ECPoint G = curve.createPoint(new BigInteger("12022136709705892117842496518378933837282529509560188557390124672992517127582"), new BigInteger("6765325636686621066142015726326349598074684595222800743368698766652936798612"));
public static final ECPoint G = curve.createPoint(new BigInteger("15729599519504045482191519010597390184315499143087863467258091083496429125073"), new BigInteger("1368880882406055711853124887741765079727455879193744504977106900552137574951"));
// Generator for randomness part of Pedersen commitments generated deterministically from mapToInteger queried on 1 to the curve using try-and-increment
public static final ECPoint H = curve.createPoint(new BigInteger("12263903704889727924109846582336855803381529831687633314439453294155493615168"), new BigInteger("1637819407897162978922461013726819811885734067940976901570219278871042378189"));
public static final ECPoint H = curve.createPoint(new BigInteger("10071451177251346351593122552258400731070307792115572537969044314339076126231"), new BigInteger("2894161621123416739138844080004799398680035544501805450971689609134516348045"));
private final SecureRandom rand;

public AttestationCrypto(SecureRandom rand) {
Security.addProvider(new BouncyCastleProvider());
this.rand = rand;
// Verify that fieldSize = 3 mod 4, otherwise the crypto won't work
if (!fieldSize.mod(new BigInteger("4")).equals(new BigInteger("3"))) {
throw new RuntimeException("The crypto will not work with this choice of curve");
if (!verifyCurveOrder(curveOrder)) {
throw new RuntimeException("Static values do not work with current implementation");
}
}

private boolean verifyCurveOrder(BigInteger curveOrder) {
// Verify that the curve order is less than 2^256 bits, which is required by mapToCurveMultiplier
// Specifically checking if it is larger than 2^curveOrderBitLength and that no bits at position curveOrderBitLength+1 or larger are set
if (curveOrder.compareTo(BigInteger.ONE.shiftLeft(curveOrderBitLength-1)) < 0 || curveOrder.shiftRight(curveOrderBitLength).compareTo(BigInteger.ZERO) > 0) {
System.err.println("Curve order is not 253 bits which is required by the current implementation");
return false;
}
return true;
}

/**
* Code shamelessly stolen from https://medium.com/@fixone/ecc-for-ethereum-on-android-7e35dc6624c9
* @param key
Expand Down Expand Up @@ -94,7 +106,7 @@ public AsymmetricCipherKeyPair constructECKeys() {
* @return
*/
public static byte[] makeCommitment(String identity, AttestationType type, BigInteger secret) {
BigInteger hashedIdentity = mapToInteger(type.ordinal(), identity);
BigInteger hashedIdentity = mapToCurveMultiplier(type, identity);
// Construct Pedersen commitment
ECPoint commitment = G.multiply(hashedIdentity).add(H.multiply(secret));
return commitment.getEncoded(false);
Expand All @@ -109,7 +121,7 @@ public static byte[] makeCommitment(String identity, AttestationType type, BigIn
* @return
*/
public static byte[] makeCommitment(String identity, AttestationType type, ECPoint hiding) {
BigInteger hashedIdentity = mapToInteger(type.ordinal(), identity);
BigInteger hashedIdentity = mapToCurveMultiplier(type, identity);
// Construct Pedersen commitment
ECPoint commitment = G.multiply(hashedIdentity).add(hiding);
return commitment.getEncoded(false);
Expand All @@ -125,11 +137,8 @@ public static byte[] makeCommitment(String identity, AttestationType type, ECPoi
public ProofOfExponent computeAttestationProof(BigInteger randomness) {
// Compute the random part of the commitment, i.e. H^randomness
ECPoint riddle = H.multiply(randomness);
BigInteger r = makeSecret();
ECPoint t = H.multiply(r);
BigInteger c = mapToInteger(makeArray(Arrays.asList(G, H, riddle, t))).mod(curveOrder);
BigInteger d = r.add(c.multiply(randomness)).mod(curveOrder);
return new ProofOfExponent(H, riddle.normalize(), t.normalize(), d);
List<ECPoint> challengeList = Arrays.asList(G, H, riddle);
return constructSchnorrPOK(riddle, randomness, challengeList);
}

/**
Expand All @@ -156,11 +165,29 @@ public ProofOfExponent computeEqualityProof(byte[] commitment1, byte[] commitmen
ECPoint comPoint2 = decodePoint(commitment2);
// Compute H*(randomness1-randomness2=commitment1-commitment2=G*msg+H*randomness1-G*msg+H*randomness2
ECPoint riddle = comPoint1.subtract(comPoint2);
BigInteger hiding = makeSecret();
ECPoint t = H.multiply(hiding);
// TODO ideally Bob's ethreum address should also be part of the challenge
BigInteger c = mapToInteger(makeArray(Arrays.asList(G, H, comPoint1, comPoint2, t))).mod(curveOrder);
BigInteger d = hiding.add(c.multiply(randomness1.subtract(randomness2))).mod(curveOrder);
BigInteger exponent = randomness1.subtract(randomness2).mod(curveOrder);
List<ECPoint> challengeList = Arrays.asList(G, H, comPoint1, comPoint2);
return constructSchnorrPOK(riddle, exponent, challengeList);
}

/**
* Constructs a Schnorr proof of knowledge of exponent of a riddle to base H.
* The challenge value used (c) is computed from the challengeList and the internal t value.
* The method uses rejection sampling to ensure that the t value is sampled s.t. the
* challenge will always be less than curveOrder.
*/
private ProofOfExponent constructSchnorrPOK(ECPoint riddle, BigInteger exponent, List<ECPoint> challengeList) {
ECPoint t;
BigInteger c, d;
// Use rejection sampling to sample a hiding value s.t. the random oracle challenge c computed from it is less than curveOrder
do {
BigInteger hiding = makeSecret();
t = H.multiply(hiding);
List<ECPoint> finalChallengeList = new ArrayList<>(challengeList);
finalChallengeList.add(t);
c = mapTo256BitInteger(makeArray(finalChallengeList));
d = hiding.add(c.multiply(exponent)).mod(curveOrder);
} while (c.compareTo(curveOrder) >= 0);
return new ProofOfExponent(H, riddle.normalize(), t.normalize(), d);
}

Expand All @@ -170,7 +197,7 @@ public ProofOfExponent computeEqualityProof(byte[] commitment1, byte[] commitmen
* @return True if the proof is OK and false otherwise
*/
public static boolean verifyAttestationRequestProof(ProofOfExponent pok) {
BigInteger c = mapToInteger(makeArray(Arrays.asList(G, pok.getBase(), pok.getRiddle(), pok.getPoint()))).mod(curveOrder);
BigInteger c = mapTo256BitInteger(makeArray(Arrays.asList(G, pok.getBase(), pok.getRiddle(), pok.getPoint())));
// Ensure that the right base has been used in the proof
if (!pok.getBase().equals(H)) {
return false;
Expand Down Expand Up @@ -199,7 +226,7 @@ public static boolean verifyEqualityProof(byte[] commitment1, byte[] commitment2
if (!pok.getBase().equals(H)) {
return false;
}
BigInteger c = mapToInteger(makeArray(Arrays.asList(G, pok.getBase(), comPoint1, comPoint2, pok.getPoint()))).mod(curveOrder);
BigInteger c = mapTo256BitInteger(makeArray(Arrays.asList(G, pok.getBase(), comPoint1, comPoint2, pok.getPoint())));
return verifyPok(pok, c);
}

Expand All @@ -213,7 +240,7 @@ public BigInteger makeSecret() {
return new BigInteger(256+128, rand).mod(curveOrder);
}

private static byte[] makeArray(List<ECPoint> points ) {
static byte[] makeArray(List<ECPoint> points ) {
try {
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
for (ECPoint current : points) {
Expand All @@ -228,44 +255,37 @@ private static byte[] makeArray(List<ECPoint> points ) {
}

/**
* Map a byte array into a Big Integer using an double execution of Keccak 256.
* @param value
* @return
* Map a byte array into a uniformly random 256 bit (positive) integer, stored as a Big Integer.
*/
private static BigInteger mapToInteger(byte[] value) {
static BigInteger mapTo256BitInteger(byte[] input) {
try {
MessageDigest KECCAK = new Keccak.Digest256();
KECCAK.reset();
KECCAK.update((byte) 0);
KECCAK.update(value);
byte[] hash0 = KECCAK.digest();
KECCAK.reset();
KECCAK.update((byte) 1);
KECCAK.update(value);
byte[] hash1 = KECCAK.digest();
byte[] res = new byte[32*2];
System.arraycopy(hash0, 0, res, 0, hash0.length);
System.arraycopy(hash1, 0, res, hash0.length, hash1.length);
// Note that we use double hashing to get a digest that is at least fieldSize or curve order
// + security parameter in length to avoid any potential bias
return new BigInteger(res);
// In case of failure we rehash using the old output
KECCAK.update(input);
byte[] digest = KECCAK.digest();
// Construct an positive BigInteger from the bytes
return new BigInteger(1, digest);
} catch (Exception e) {
throw new RuntimeException(e);
}
}

/**
*
* @param type
* @param identity
* @return
* Maps and identifier of a certain type to an integer deterministic, yet sampled from
* the uniformly random distribution between 0 and curveOrder -1.
* This is done using deterministic rejection sampling based on the input.
*/
public static BigInteger mapToInteger(int type, String identity) {
public static BigInteger mapToCurveMultiplier(AttestationType type, String identity) {
byte[] identityBytes = identity.trim().toLowerCase().getBytes(StandardCharsets.UTF_8);
ByteBuffer buf = ByteBuffer.allocate(4 + identityBytes.length);
buf.putInt(type);
buf.putInt(type.ordinal());
buf.put(identityBytes);
return mapToInteger(buf.array());
BigInteger sampledVal = new BigInteger(1, buf.array());
do {
sampledVal = mapTo256BitInteger(sampledVal.toByteArray());
} while (sampledVal.compareTo(curveOrder) >= 0);
return sampledVal;
}

public static ECPoint decodePoint(byte[] point) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
package com.alphawallet.attestation;
package com.alphawallet.attestation.core;

import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

import com.alphawallet.attestation.IdentifierAttestation.AttestationType;
import com.alphawallet.attestation.ProofOfExponent;
import com.alphawallet.attestation.core.AttestationCrypto;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.math.BigInteger;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
Expand All @@ -18,8 +22,10 @@
import org.bouncycastle.crypto.AsymmetricCipherKeyPair;
import org.bouncycastle.crypto.params.ECPublicKeyParameters;
import org.bouncycastle.math.ec.ECPoint;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestTemplate;

public class CryptoTest {
private AsymmetricCipherKeyPair subjectKeys;
Expand All @@ -43,6 +49,24 @@ public void setupCrypto() throws NoSuchAlgorithmException {
senderKeys = crypto.constructECKeys();
}

@Test
public void tooSmallCurveOrder() throws Exception {
Method verifyCurveOrder = AttestationCrypto.class.getDeclaredMethod("verifyCurveOrder", BigInteger.class);
verifyCurveOrder.setAccessible(true);
// Set 2^253-1
BigInteger smallCurveOrder = BigInteger.ONE.shiftLeft(253).subtract(BigInteger.ONE);
assertFalse((boolean) verifyCurveOrder.invoke(crypto, smallCurveOrder));
}

@Test
public void tooLargeCurveOrder() throws Exception {
Method verifyCurveOrder = AttestationCrypto.class.getDeclaredMethod("verifyCurveOrder", BigInteger.class);
verifyCurveOrder.setAccessible(true);
// Set the final curveOrder field to 2^254
BigInteger largeCurveOrder = BigInteger.ONE.shiftLeft(254);
assertFalse((boolean) verifyCurveOrder.invoke(crypto, largeCurveOrder));
}

@Test
public void testAddressFromKey() {
String key = AttestationCrypto.addressFromKey(subjectKeys.getPublic());
Expand Down Expand Up @@ -208,6 +232,18 @@ public void testEqualityProof() {
assertFalse(AttestationCrypto.verifyEqualityProof(com1, com2, pok2));
}

@Test
public void testRejectionSamplingInEqualityProof() {
for (int i = 1; i < 40; i++) {
byte[] com1 = AttestationCrypto.makeCommitment(ID+i, TYPE, SECRET1.add(BigInteger.valueOf(i)));
byte[] com2 = AttestationCrypto.makeCommitment(ID+i, TYPE, SECRET2.multiply(BigInteger.valueOf(i)));
ProofOfExponent pok = crypto.computeEqualityProof(com1, com2, SECRET1.add(BigInteger.valueOf(i)), SECRET2.multiply(BigInteger.valueOf(i)));
// Compute the c value used in the proof and for proof verification
BigInteger c = AttestationCrypto.mapTo256BitInteger(AttestationCrypto.makeArray(Arrays.asList(AttestationCrypto.G, pok.getBase(), AttestationCrypto.decodePoint(com1), AttestationCrypto.decodePoint(com2), pok.getPoint())));
assertTrue(c.compareTo(AttestationCrypto.curveOrder) < 0);
}
}

@Test
public void testMakeSecret() {
BigInteger sec = crypto.makeSecret();
Expand All @@ -219,29 +255,51 @@ public void testMakeSecret() {
}

@Test
public void testMapToInteger() {
BigInteger value = AttestationCrypto.mapToInteger(TYPE.ordinal(), ID);
public void testMapToCurveMultiplier() {
BigInteger value = AttestationCrypto.mapToCurveMultiplier(TYPE, ID);
// Sanity checks
assertFalse(value.equals(BigInteger.ZERO));
assertFalse(value.equals(BigInteger.ONE));
assertFalse(value.equals(AttestationCrypto.curveOrder));
assertFalse(value.equals(AttestationCrypto.fieldSize));
assertFalse(value.compareTo(AttestationCrypto.curveOrder) >= 0);
assertFalse(value.compareTo(AttestationCrypto.fieldSize) >= 0);
assertFalse(value.equals(AttestationCrypto.curveOrder.subtract(BigInteger.ONE)));
assertFalse(value.equals(AttestationCrypto.fieldSize.subtract(BigInteger.ONE)));
// This should hold with probability at least 1-2^-30
assertTrue(value.shiftRight(AttestationCrypto.curveOrderBitLength-30).compareTo(BigInteger.ZERO) > 0);

// Check consistency
BigInteger value2 = AttestationCrypto.mapToInteger(TYPE.ordinal(), ID);
BigInteger value2 = AttestationCrypto.mapToCurveMultiplier(TYPE, ID);
assertEquals(value, value2);

// Negative tests
value2 = AttestationCrypto.mapToInteger(TYPE.ordinal(), "test");
value2 = AttestationCrypto.mapToCurveMultiplier(TYPE, "test");
assertNotEquals(value, value2);
value2 = AttestationCrypto.mapToInteger(TYPE.ordinal(), ID + " 1");
value2 = AttestationCrypto.mapToCurveMultiplier(TYPE, ID + " 1");
assertNotEquals(value, value2);
value2 = AttestationCrypto.mapToInteger(AttestationType.PHONE.ordinal(), ID);
value2 = AttestationCrypto.mapToCurveMultiplier(AttestationType.PHONE, ID);
assertNotEquals(value, value2);
}

@Test
public void verifyLargeOutputOfMapToMultiplier() {
int counter = 0;
// Except with probability 2^-40 we should get at least one result that is curveOrderBitLength long,
// hence we ensure that the result of mapToCurveMultiplier is greater than 0 when shifting curveOrderBitLength to the right
for (int i = 0; i < 40; i++) {
BigInteger res = AttestationCrypto.mapToCurveMultiplier(TYPE, Integer.toString(i));
if (res.shiftRight(AttestationCrypto.curveOrderBitLength-1).compareTo(BigInteger.ZERO) > 0) {
counter++;
}
// This should hold with probability at least 1-2^-30
assertTrue(res.shiftRight(AttestationCrypto.curveOrderBitLength-30).compareTo(BigInteger.ZERO) > 0);
// Sanity check
assertFalse(res.equals(BigInteger.ZERO));
assertFalse(res.equals(BigInteger.ONE));
assertFalse(res.compareTo(AttestationCrypto.curveOrder) >= 0);
}
assertTrue(counter > 0);
}

@Test
public void testConstructAttRequestProof() throws NoSuchAlgorithmException{
SecureRandom rand2 = SecureRandom.getInstance("SHA1PRNG");
Expand Down Expand Up @@ -281,40 +339,45 @@ public void testDecodePoint() {

/**
* This test is here to show that we have nothing-up-our-sleeve in picking the generators
* @throws Exception
*/
@Test
public void computeGenerators() throws Exception {
public void computeGenerators() {
assertFalse(AttestationCrypto.G.add(AttestationCrypto.G).isInfinity());
assertFalse(AttestationCrypto.H.add(AttestationCrypto.H).isInfinity());

Method mapToInteger = AttestationCrypto.class.getDeclaredMethod("mapToInteger", byte[].class);
mapToInteger.setAccessible(true);

byte[] input = new byte[1];
input[0] = 0;
BigInteger gVal = (BigInteger) mapToInteger.invoke(crypto, input);
BigInteger gVal = rejectionSample(BigInteger.ZERO);
ECPoint g = computePoint(gVal);
assertEquals(AttestationCrypto.G, g);
// Check order
assertTrue(g.multiply(AttestationCrypto.curveOrder).isInfinity());
assertArrayEquals(g.multiply(AttestationCrypto.curveOrder.subtract(BigInteger.ONE)).normalize().getXCoord().getEncoded(), g.normalize().getXCoord().getEncoded());
input[0] = 1;
BigInteger hVal = (BigInteger) mapToInteger.invoke(crypto, input);

BigInteger hVal = rejectionSample(BigInteger.ONE);
ECPoint h = computePoint(hVal);
assertEquals(AttestationCrypto.H, h);
// Check order
assertTrue(h.multiply(AttestationCrypto.curveOrder).isInfinity());
assertArrayEquals(h.multiply(AttestationCrypto.curveOrder.subtract(BigInteger.ONE)).normalize().getXCoord().getEncoded(), h.normalize().getXCoord().getEncoded());
}

private BigInteger rejectionSample(BigInteger seed) {
do {
seed = AttestationCrypto.mapTo256BitInteger(seed.toByteArray());
} while (seed.compareTo(AttestationCrypto.curveOrder) >= 0);
return seed;
}

/**
* Compute a specific point on the curve (generator) based on x using the try-and-increment method
* https://eprint.iacr.org/2009/226.pdf
* @param x The x-coordinate for which we will compute y
* @return A corresponding y coordinate for x
*/
private static ECPoint computePoint(BigInteger x) {
// Verify that fieldSize = 3 mod 4, otherwise the crypto won't work
if (!AttestationCrypto.fieldSize.mod(new BigInteger("4")).equals(new BigInteger("3"))) {
throw new RuntimeException("The crypto will not work with this choice of curve");
}
x = x.mod(AttestationCrypto.fieldSize);
BigInteger ySquare, quadraticResidue;
ECPoint resPoint, referencePoint;
Expand Down

0 comments on commit 17bfa62

Please sign in to comment.