diff --git a/crypto/crypto.go b/crypto/crypto.go index b2a4c8f7b..a592c4606 100644 --- a/crypto/crypto.go +++ b/crypto/crypto.go @@ -199,7 +199,7 @@ func (client *Crypto) Migrate() error { for _, keyNameVersion := range keys { var keyRef orm.KeyReference // find existing record, if it exists do nothing - err := tx.WithContext(ctx).Model(&orm.KeyReference{}).Where("key_name = ? and version = ?", keyNameVersion.KeyName, keyNameVersion.KeyName).First(&keyRef).Error + err := tx.WithContext(ctx).Model(&orm.KeyReference{}).Where("key_name = ? and version = ?", keyNameVersion.KeyName, keyNameVersion.Version).First(&keyRef).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { // create a new key reference diff --git a/crypto/crypto_test.go b/crypto/crypto_test.go index a5eed99de..76d8ed795 100644 --- a/crypto/crypto_test.go +++ b/crypto/crypto_test.go @@ -61,12 +61,11 @@ func TestCrypto_Exists(t *testing.T) { } func TestCrypto_Migrate(t *testing.T) { - backend := NewMemoryStorage() - db := orm.NewTestDatabase(t) - client := &Crypto{backend: backend, db: db} - + keypair, _ := spi.GenerateKeyPair() t.Run("ok - 1 key migrated", func(t *testing.T) { - keypair, _ := spi.GenerateKeyPair() + backend := NewMemoryStorage() + db := orm.NewTestDatabase(t) + client := &Crypto{backend: backend, db: db} err := backend.SavePrivateKey(context.Background(), "test", keypair) require.NoError(t, err) @@ -80,9 +79,26 @@ func TestCrypto_Migrate(t *testing.T) { t.Run("ok - already exists", func(t *testing.T) { err = client.Migrate() + assert.NoError(t, err) }) }) + t.Run("don't migrate new keys", func(t *testing.T) { + backend := NewMemoryStorage() + db := orm.NewTestDatabase(t) + client := &Crypto{backend: backend, db: db} + err := backend.SavePrivateKey(context.Background(), "some-uuid", keypair) + require.NoError(t, err) + + err = db.Save(&orm.KeyReference{KID: "vm-id", KeyName: "some-uuid", Version: "1"}).Error + require.NoError(t, err) + + err = client.Migrate() + require.NoError(t, err) + + keys := client.List(context.Background()) + require.Len(t, keys, 1) + }) } func TestCrypto_New(t *testing.T) {