Skip to content

Commit

Permalink
pair: make hashing deterministic regardless of orders (#76)
Browse files Browse the repository at this point in the history
  • Loading branch information
juanli16 authored Oct 11, 2024
1 parent f600d47 commit 7ad605f
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 7 deletions.
22 changes: 15 additions & 7 deletions pkg/pair/pair.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ var (

// PrivateKey represents a PAIR private key.
type PrivateKey struct {
// h is the hash function used to hash the data
h hash.Hash
mode PAIRMode

// salt for h
salt []byte
Expand All @@ -41,11 +40,12 @@ type PrivateKey struct {
// New instantiates a new private key with the given salt and scalar.
// It expects the scalar to be base64 encoded.
func (p PAIRMode) New(salt []byte, scalar []byte) (*PrivateKey, error) {
pk := new(PrivateKey)
pk := &PrivateKey{
mode: p,
}

switch p {
case PAIRSHA256Ristretto255:
pk.h = crypto.SHA256.New()
if len(salt) != sha256SaltSize {
return nil, ErrInvalidSaltSize
}
Expand All @@ -64,11 +64,19 @@ func (p PAIRMode) New(salt []byte, scalar []byte) (*PrivateKey, error) {

// hash hashes the data using the private key's hash function with the salt.
func (pk *PrivateKey) hash(data []byte) []byte {
var h hash.Hash

switch pk.mode {
case PAIRSHA256Ristretto255:
h = crypto.SHA256.New()
default:
}

// salt the hash function
pk.h.Write(pk.salt)
h.Write(pk.salt)
// hash the data
pk.h.Write(data)
return pk.h.Sum(nil)
h.Write(data)
return h.Sum(nil)
}

// Encrypt first hashes the data with a salted hash function,
Expand Down
67 changes: 67 additions & 0 deletions pkg/pair/pair_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,73 @@ func TestPAIR(t *testing.T) {
}
}

func TestDeterministicEncryption(t *testing.T) {
var (
salt = make([]byte, sha256SaltSize)
scalar = ristretto255.NewScalar()
)

if _, err := rand.Read(salt); err != nil {
t.Fatal(err)
}

// sha512 produces a 64-byte psuedo-uniformized data
src := sha512.Sum512(salt)
scalar.FromUniformBytes(src[:])
sk, err := scalar.MarshalText()
if err != nil {
t.Fatalf("failed to marshal the scalar: %s", err.Error())
}

// Create a new PAIR instance
pairID := PAIRSHA256Ristretto255

pair, err := pairID.New(salt, sk)
if err != nil {
t.Fatalf("failed to instantiate a new PAIR instance: %s", err.Error())
}

var (
id1 = []byte("[email protected]")
id2 = []byte("[email protected]")
)

// Encrypt the data
ciphertext1, err := pair.Encrypt(id1)
if err != nil {
t.Fatalf("failed to encrypt the data: %s", err.Error())
}

ciphertext2, err := pair.Encrypt(id2)
if err != nil {
t.Fatalf("failed to encrypt the data: %s", err.Error())
}

// reverse order
reversePair, err := pairID.New(salt, sk)
if err != nil {
t.Fatalf("failed to instantiate a new PAIR instance: %s", err.Error())
}

ciphertext3, err := reversePair.Encrypt(id2)
if err != nil {
t.Fatalf("failed to encrypt the data: %s", err.Error())
}

ciphertext4, err := reversePair.Encrypt(id1)
if err != nil {
t.Fatalf("failed to encrypt the data: %s", err.Error())
}

if strings.Compare(string(ciphertext1), string(ciphertext4)) != 0 {
t.Fatalf("want: %s, got: %s", string(ciphertext1), string(ciphertext4))
}

if strings.Compare(string(ciphertext2), string(ciphertext3)) != 0 {
t.Fatalf("want: %s, got: %s", string(ciphertext2), string(ciphertext3))
}
}

func genData(n int) [][]byte {
data := make([][]byte, n)
for i := 0; i < n; i++ {
Expand Down

0 comments on commit 7ad605f

Please sign in to comment.