From 7ad605f44fc3889ad159ee551e7e864f71248dc1 Mon Sep 17 00:00:00 2001 From: Justin Li Date: Fri, 11 Oct 2024 14:47:43 -0400 Subject: [PATCH] pair: make hashing deterministic regardless of orders (#76) --- pkg/pair/pair.go | 22 +++++++++----- pkg/pair/pair_test.go | 67 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 7 deletions(-) diff --git a/pkg/pair/pair.go b/pkg/pair/pair.go index c5c6c12..3dc7caa 100644 --- a/pkg/pair/pair.go +++ b/pkg/pair/pair.go @@ -28,8 +28,7 @@ var ( // PrivateKey represents a PAIR private key. type PrivateKey struct { - // h is the hash function used to hash the data - h hash.Hash + mode PAIRMode // salt for h salt []byte @@ -41,11 +40,12 @@ type PrivateKey struct { // New instantiates a new private key with the given salt and scalar. // It expects the scalar to be base64 encoded. func (p PAIRMode) New(salt []byte, scalar []byte) (*PrivateKey, error) { - pk := new(PrivateKey) + pk := &PrivateKey{ + mode: p, + } switch p { case PAIRSHA256Ristretto255: - pk.h = crypto.SHA256.New() if len(salt) != sha256SaltSize { return nil, ErrInvalidSaltSize } @@ -64,11 +64,19 @@ func (p PAIRMode) New(salt []byte, scalar []byte) (*PrivateKey, error) { // hash hashes the data using the private key's hash function with the salt. func (pk *PrivateKey) hash(data []byte) []byte { + var h hash.Hash + + switch pk.mode { + case PAIRSHA256Ristretto255: + h = crypto.SHA256.New() + default: + } + // salt the hash function - pk.h.Write(pk.salt) + h.Write(pk.salt) // hash the data - pk.h.Write(data) - return pk.h.Sum(nil) + h.Write(data) + return h.Sum(nil) } // Encrypt first hashes the data with a salted hash function, diff --git a/pkg/pair/pair_test.go b/pkg/pair/pair_test.go index 1542e0d..a45f375 100644 --- a/pkg/pair/pair_test.go +++ b/pkg/pair/pair_test.go @@ -62,6 +62,73 @@ func TestPAIR(t *testing.T) { } } +func TestDeterministicEncryption(t *testing.T) { + var ( + salt = make([]byte, sha256SaltSize) + scalar = ristretto255.NewScalar() + ) + + if _, err := rand.Read(salt); err != nil { + t.Fatal(err) + } + + // sha512 produces a 64-byte psuedo-uniformized data + src := sha512.Sum512(salt) + scalar.FromUniformBytes(src[:]) + sk, err := scalar.MarshalText() + if err != nil { + t.Fatalf("failed to marshal the scalar: %s", err.Error()) + } + + // Create a new PAIR instance + pairID := PAIRSHA256Ristretto255 + + pair, err := pairID.New(salt, sk) + if err != nil { + t.Fatalf("failed to instantiate a new PAIR instance: %s", err.Error()) + } + + var ( + id1 = []byte("alice@hello.com") + id2 = []byte("bob@hello.com") + ) + + // Encrypt the data + ciphertext1, err := pair.Encrypt(id1) + if err != nil { + t.Fatalf("failed to encrypt the data: %s", err.Error()) + } + + ciphertext2, err := pair.Encrypt(id2) + if err != nil { + t.Fatalf("failed to encrypt the data: %s", err.Error()) + } + + // reverse order + reversePair, err := pairID.New(salt, sk) + if err != nil { + t.Fatalf("failed to instantiate a new PAIR instance: %s", err.Error()) + } + + ciphertext3, err := reversePair.Encrypt(id2) + if err != nil { + t.Fatalf("failed to encrypt the data: %s", err.Error()) + } + + ciphertext4, err := reversePair.Encrypt(id1) + if err != nil { + t.Fatalf("failed to encrypt the data: %s", err.Error()) + } + + if strings.Compare(string(ciphertext1), string(ciphertext4)) != 0 { + t.Fatalf("want: %s, got: %s", string(ciphertext1), string(ciphertext4)) + } + + if strings.Compare(string(ciphertext2), string(ciphertext3)) != 0 { + t.Fatalf("want: %s, got: %s", string(ciphertext2), string(ciphertext3)) + } +} + func genData(n int) [][]byte { data := make([][]byte, n) for i := 0; i < n; i++ {