From e217df94f12bb7863c50e3e72d40893a157d488a Mon Sep 17 00:00:00 2001 From: Lajos Koszti Date: Wed, 4 May 2022 20:28:57 +0200 Subject: [PATCH] make postgres queries composable into transaction Create sub functions which requires sql.Tx as parameters, so they will run queries in transaction. The queries can be combined and reused in functions That's a start to make possible to roll back if we detect that client can't read the secret (bad decrypt key for example) --- api/api_test.go | 38 +-- entries/entry_meta_test.go | 21 ++ storage/integration/integration_test.go | 16 +- storage/postgresql/postgresql_storage.go | 226 +++++++----------- storage/postgresql/postgresql_storage_test.go | 85 ++----- storage/secret/secret_storage.go | 17 -- storage/secret/secret_storage_test.go | 19 +- storage/storage.go | 6 - 8 files changed, 172 insertions(+), 256 deletions(-) create mode 100644 entries/entry_meta_test.go diff --git a/api/api_test.go b/api/api_test.go index a020141..8a69128 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -164,7 +164,7 @@ func TestCreateEntryForm(t *testing.T) { value := "Foo" connection := postgresql.NewStorage(testhelper.GetPSQLTestConn()) t.Cleanup(func() { - connection.Close() + defer connection.Close() }) data, multi, err := createMultipart(map[string]io.Reader{ @@ -233,13 +233,13 @@ func TestRequestPathsCreateEntry(t *testing.T) { {Name: "/ path", Path: "/", StatusCode: 200}, {Name: "Longer path", Path: "/other", StatusCode: 404}, } + connection := postgresql.NewStorage(testhelper.GetPSQLTestConn()) + t.Cleanup(func() { + connection.Close() + }) for _, testCase := range testCases { t.Run(testCase.Name, func(t *testing.T) { - connection := postgresql.NewStorage(testhelper.GetPSQLTestConn()) - t.Cleanup(func() { - connection.Close() - }) req := httptest.NewRequest("POST", fmt.Sprintf("http://example.com%s", testCase.Path), bytes.NewReader([]byte("ASDF"))) w := httptest.NewRecorder() NewSecretHandler(NewHandlerConfig(connection)).ServeHTTP(w, req) @@ -264,17 +264,17 @@ func TestGetEntry(t *testing.T) { { "first", "foo", - "3f356f6c-c8b1-4b48-8243-aa04d07b8873", + uuid.NewUUIDString(), }, } + connection := postgresql.NewStorage(testhelper.GetPSQLTestConn()) + t.Cleanup(func() { + connection.Close() + }) + for _, testCase := range testCases { t.Run(testCase.Name, func(t *testing.T) { - connection := postgresql.NewStorage(testhelper.GetPSQLTestConn()) - t.Cleanup(func() { - connection.Close() - }) - k := key.NewKey() if err := k.Generate(); err != nil { t.Error(err) @@ -303,13 +303,12 @@ func TestGetEntry(t *testing.T) { } }) } - } func TestGetEntryJSON(t *testing.T) { connection := postgresql.NewStorage(testhelper.GetPSQLTestConn()) t.Cleanup(func() { - connection.Close() + defer connection.Close() }) testCase := struct { Name string @@ -319,7 +318,7 @@ func TestGetEntryJSON(t *testing.T) { "first", "foo", - "3f356f6c-c8b1-4b48-8243-aa04d07b8873", + uuid.NewUUIDString(), } k := key.NewKey() @@ -335,7 +334,10 @@ func TestGetEntryJSON(t *testing.T) { } ctx := context.Background() - connection.Write(ctx, testCase.UUID, encryptedData, time.Second*10, 1) + if err := connection.Write(ctx, testCase.UUID, encryptedData, time.Second*10, 1); err != nil { + t.Error(err) + } + fmt.Println("Wrote", testCase.UUID) req := httptest.NewRequest("GET", fmt.Sprintf("http://example.com/%s/%s", testCase.UUID, hex.EncodeToString(rsakey)), nil) req.Header.Add("Accept", "application/json") @@ -343,6 +345,10 @@ func TestGetEntryJSON(t *testing.T) { NewSecretHandler(NewHandlerConfig(connection)).ServeHTTP(w, req) resp := w.Result() + fmt.Println(resp.Header) + if resp.StatusCode != 200 { + t.Errorf("non 200 http statuscode: %d", resp.StatusCode) + } var encode entries.SecretResponse err = json.NewDecoder(resp.Body).Decode(&encode) @@ -540,7 +546,7 @@ func FuzzSetAndGetEntry(f *testing.F) { } connection := postgresql.NewStorage(testhelper.GetPSQLTestConn()) f.Cleanup(func() { - connection.Close() + defer connection.Close() }) f.Fuzz(func(t *testing.T, testCase string) { diff --git a/entries/entry_meta_test.go b/entries/entry_meta_test.go new file mode 100644 index 0000000..c1656cf --- /dev/null +++ b/entries/entry_meta_test.go @@ -0,0 +1,21 @@ +package entries + +import ( + "testing" + "time" +) + +func Test_EntryMeta(t *testing.T) { + expire := time.Now() + + meta := EntryMeta{Expire: expire.Add(time.Second)} + + if meta.IsExpired() { + t.Error("entry meta should not be expired") + } + + meta = EntryMeta{Expire: expire.Add(-time.Second)} + if !meta.IsExpired() { + t.Error("entry meta should be expired") + } +} diff --git a/storage/integration/integration_test.go b/storage/integration/integration_test.go index f32fb35..847568e 100644 --- a/storage/integration/integration_test.go +++ b/storage/integration/integration_test.go @@ -15,20 +15,16 @@ import ( ) func TestStorages(t *testing.T) { - connection := postgresql.NewStorage(testhelper.GetPSQLTestConn()) + psqlStorage := postgresql.NewStorage(testhelper.GetPSQLTestConn()) t.Cleanup(func() { - connection.Close() + psqlStorage.Close() }) - psqlStorage := postgresql.NewPostgresCleanableStorage(connection) - storages := map[string]storage.Cleanable{ + storages := map[string]storage.Storage{ "Postgres": psqlStorage, - "Secret": secret.NewCleanableSecretStorage( - secret.NewSecretStorage( - psqlStorage, - dummy.NewEncrypter(), - ), + "Secret": secret.NewSecretStorage( psqlStorage, + dummy.NewEncrypter(), ), } @@ -54,6 +50,7 @@ func TestStorages(t *testing.T) { t.Errorf("Expected expire error but got %v", err) } }) + t.Run("Read", func(t *testing.T) { UUID := uuid.NewUUIDString() err := storage.Write(ctx, UUID, []byte("foo"), time.Second*-10, 1) @@ -72,6 +69,7 @@ func TestStorages(t *testing.T) { t.Errorf("Expected expire error but got %v", err) } }) + t.Run("Delete", func(t *testing.T) { UUID := uuid.NewUUIDString() err := storage.Write(ctx, UUID, []byte("foo"), time.Second*-10, 1) diff --git a/storage/postgresql/postgresql_storage.go b/storage/postgresql/postgresql_storage.go index fd3a4b7..7ac3ee8 100644 --- a/storage/postgresql/postgresql_storage.go +++ b/storage/postgresql/postgresql_storage.go @@ -4,7 +4,6 @@ import ( "context" "database/sql" "fmt" - "log" "strings" "time" @@ -27,15 +26,18 @@ func (s Storage) Close() error { // Write stores a new entry in database func (s Storage) Write(ctx context.Context, UUID string, entry []byte, expire time.Duration, remainingReads int) error { - now := time.Now() - k, err := key.NewGeneratedKey() + tx, err := s.db.BeginTx(ctx, nil) if err != nil { return err } - deleteKey := k.ToHex() - _, err = s.db.ExecContext(ctx, `INSERT INTO entries (uuid, data, created, expire, remaining_reads, delete_key) VALUES ($1, $2, $3, $4, $5, $6) RETURNING uuid, delete_key;`, UUID, entry, now, now.Add(expire), remainingReads, deleteKey) - return err + if err = s.write(tx, UUID, entry, expire, remainingReads); err != nil { + tx.Rollback() + return err + } + + tx.Commit() + return nil } // ReadMeta to get entry metadata (without the actual secret) @@ -47,7 +49,59 @@ func (s Storage) ReadMeta(ctx context.Context, UUID string) (*entries.EntryMeta, return nil, err } - row := tx.QueryRowContext(ctx, ` + meta, err := s.readMeta(tx, UUID) + if err != nil { + tx.Rollback() + if err == sql.ErrNoRows { + return nil, entries.ErrEntryNotFound + } + + return nil, err + } + + if meta.IsExpired() { + if err := s.setAccessed(tx, UUID); err != nil { + tx.Rollback() + return nil, err + } + + if err = tx.Commit(); err != nil { + return nil, err + } + + return nil, entries.ErrEntryExpired + } + + if err := tx.Commit(); err != nil { + return nil, err + } + + return meta, nil +} + +func (s Storage) write(tx *sql.Tx, UUID string, entry []byte, expire time.Duration, remainingReads int) error { + now := time.Now() + k, err := key.NewGeneratedKey() + if err != nil { + return err + } + deleteKey := k.ToHex() + + _, err = tx.Exec(`INSERT INTO entries (uuid, data, created, expire, remaining_reads, delete_key) VALUES ($1, $2, $3, $4, $5, $6) RETURNING uuid, delete_key;`, UUID, entry, now, now.Add(expire), remainingReads, deleteKey) + + return err +} + +func (s Storage) setAccessed(tx *sql.Tx, UUID string) error { + if _, err := tx.Exec("UPDATE entries SET accessed=$1 WHERE uuid=$2", time.Now(), UUID); err != nil { + return err + } + + return nil +} + +func (s Storage) readMeta(tx *sql.Tx, UUID string) (*entries.EntryMeta, error) { + row := tx.QueryRow(` SELECT created, accessed, @@ -58,6 +112,7 @@ func (s Storage) ReadMeta(ctx context.Context, UUID string) (*entries.EntryMeta, entries WHERE uuid=$1 + AND remaining_reads > 0 `, UUID) var created time.Time @@ -65,13 +120,9 @@ func (s Storage) ReadMeta(ctx context.Context, UUID string) (*entries.EntryMeta, var expireNullTime sql.NullTime var remainingReadsNullInt32 sql.NullInt32 var deleteKeyNullString sql.NullString - err = row.Scan(&created, &accessedNullTime, &expireNullTime, &remainingReadsNullInt32, &deleteKeyNullString) + err := row.Scan(&created, &accessedNullTime, &expireNullTime, &remainingReadsNullInt32, &deleteKeyNullString) if err != nil { - tx.Rollback() - if err == sql.ErrNoRows { - return nil, entries.ErrEntryNotFound - } return nil, err } @@ -102,62 +153,36 @@ func (s Storage) ReadMeta(ctx context.Context, UUID string) (*entries.EntryMeta, DeleteKey: deleteKey, } - if meta.IsExpired() { - _, err = tx.ExecContext(ctx, ` - UPDATE entries - SET data=$1, accessed=$2 - WHERE uuid=$3 - `, nil, time.Now(), UUID) - - if err != nil { - tx.Rollback() - return nil, err - } - err := tx.Commit() - if err != nil { - return nil, err - } - - return nil, entries.ErrEntryExpired - } - - err = tx.Commit() - - if err != nil { - return nil, err - } - return meta, nil } -// Get to get entry including the actual secret +func (s Storage) updateReadCount(tx *sql.Tx, UUID string) error { + _, err := tx.Exec("UPDATE entries SET remaining_reads = remaining_reads - 1 WHERE uuid=$1;", UUID) + return err +} + +// read to get entry including the actual secret // returns the data if the secret not expired yet // updates read count -func (s Storage) Get(ctx context.Context, UUID string) (*entries.Entry, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } +func (s Storage) read(tx *sql.Tx, UUID string) (*entries.Entry, error) { - row := tx.QueryRowContext(ctx, "SELECT data, created, accessed, expire FROM entries WHERE uuid=$1", UUID) + row := tx.QueryRow(`SELECT data, created, accessed, expire, remaining_reads FROM entries + WHERE uuid=$1 + AND remaining_reads > 0 + LIMIT 1`, UUID) var data []byte var created time.Time var accessedNullTime sql.NullTime var expireNullTime sql.NullTime - err = row.Scan(&data, &created, &accessedNullTime, &expireNullTime) - - if err != nil { - tx.Rollback() - if err == sql.ErrNoRows { - return nil, entries.ErrEntryNotFound - } - + var remainingReadsNullInt32 sql.NullInt32 + if err := row.Scan(&data, &created, &accessedNullTime, &expireNullTime, &remainingReadsNullInt32); err != nil { return nil, err } var accessed time.Time var expire time.Time + var maxReads int32 if accessedNullTime.Valid { accessed = accessedNullTime.Time @@ -165,33 +190,16 @@ func (s Storage) Get(ctx context.Context, UUID string) (*entries.Entry, error) { if expireNullTime.Valid { expire = expireNullTime.Time } + if remainingReadsNullInt32.Valid { + maxReads = remainingReadsNullInt32.Int32 + } meta := entries.EntryMeta{ UUID: UUID, Created: created, Accessed: accessed, Expire: expire, - } - - if meta.IsExpired() { - _, err = tx.ExecContext(ctx, "UPDATE entries SET data=$1, accessed=$2 WHERE uuid=$3", nil, time.Now(), UUID) - - if err != nil { - tx.Rollback() - return nil, err - } - err := tx.Commit() - if err != nil { - return nil, err - } - - return nil, entries.ErrEntryExpired - } - - err = tx.Commit() - - if err != nil { - return nil, err + MaxReads: maxReads, } return &entries.Entry{ @@ -210,13 +218,7 @@ func (s Storage) Read(ctx context.Context, UUID string) (*entries.Entry, error) return nil, err } - row := tx.QueryRowContext(ctx, "SELECT data, created, accessed, expire FROM entries WHERE uuid=$1", UUID) - - var data []byte - var created time.Time - var accessedNullTime sql.NullTime - var expireNullTime sql.NullTime - err = row.Scan(&data, &created, &accessedNullTime, &expireNullTime) + entry, err := s.read(tx, UUID) if err != nil { tx.Rollback() @@ -226,16 +228,15 @@ func (s Storage) Read(ctx context.Context, UUID string) (*entries.Entry, error) return nil, err } - queries := []string{ - "UPDATE entries SET remaining_reads = remaining_reads - 1 WHERE uuid=$1;", - "DELETE FROM entries WHERE uuid=$1 AND remaining_reads < 1;", + if entry.IsExpired() { + s.setAccessed(tx, UUID) + tx.Commit() + return nil, entries.ErrEntryExpired } - for _, query := range queries { - _, err = tx.ExecContext(ctx, query, UUID) - if err != nil { - tx.Rollback() - return nil, err - } + + if err := s.updateReadCount(tx, UUID); err != nil { + tx.Rollback() + return nil, err } err = tx.Commit() @@ -243,31 +244,7 @@ func (s Storage) Read(ctx context.Context, UUID string) (*entries.Entry, error) return nil, err } - var accessed time.Time - var expire time.Time - - if accessedNullTime.Valid { - accessed = accessedNullTime.Time - } - if expireNullTime.Valid { - expire = expireNullTime.Time - } - - meta := entries.EntryMeta{ - UUID: UUID, - Created: created, - Accessed: accessed, - Expire: expire, - } - - if meta.IsExpired() { - return nil, entries.ErrEntryExpired - } - - return &entries.Entry{ - EntryMeta: meta, - Data: data, - }, nil + return entry, nil } // Delete deletes the entry from the database @@ -346,29 +323,10 @@ func (s Storage) DeleteExpired(ctx context.Context) error { _, err = tx.ExecContext(ctx, "DELETE FROM entries WHERE expire < NOW() OR remaining_reads < 1;") if err != nil { + fmt.Println("DELETE ERRRO", err) tx.Rollback() return err } return tx.Commit() } - -// NewPostgresCleanableStorage Creates a cleanable psql storage instance -func NewPostgresCleanableStorage(s *Storage) *PostgresCleanableStorage { - return &PostgresCleanableStorage{s} -} - -// PostgresCleanableStorage extends the regular PostgresqlStorage with a Clean -// method to remove all entries -type PostgresCleanableStorage struct { - *Storage -} - -// Clean deletes all entries from the database -func (s PostgresCleanableStorage) Clean() { - _, err := s.db.Exec("TRUNCATE entries;") - - if err != nil { - log.Fatal(err) - } -} diff --git a/storage/postgresql/postgresql_storage_test.go b/storage/postgresql/postgresql_storage_test.go index 72d56c0..ad3e7e3 100644 --- a/storage/postgresql/postgresql_storage_test.go +++ b/storage/postgresql/postgresql_storage_test.go @@ -2,7 +2,6 @@ package postgresql import ( "context" - "database/sql" "testing" "time" @@ -10,39 +9,7 @@ import ( "github.com/Ajnasz/sekret.link/uuid" ) -// func TestPostgresqlStorageWriteGet(t *testing.T) { -// psqlConn := testhelper.GetPSQLTestConn() -// storage := NewStorage(psqlConn) -// t.Cleanup(func() { -// defer storage.Close() -// }) -// testCases := []string{ -// "foo", -// } - -// for _, testCase := range testCases { -// t.Run(testCase, func(t *testing.T) { - -// UUID := uuid.NewUUIDString() -// err := storage.Write(UUID, []byte("foo"), time.Second*10, 1) - -// if err != nil { -// t.Fatal(err) -// } -// res, err := storage.Get(UUID) -// if err != nil { -// t.Fatal(err) -// } - -// actual := string(res.Data) -// if actual != testCase { -// t.Errorf("expected: %s, actual: %s", testCase, actual) -// } -// }) -// } -// } - -func TestPostgresqlStorageWrite(t *testing.T) { +func Test_PostgresqlStorageWrite(t *testing.T) { psqlConn := testhelper.GetPSQLTestConn() storage := NewStorage(psqlConn) t.Cleanup(func() { @@ -50,38 +17,34 @@ func TestPostgresqlStorageWrite(t *testing.T) { }) testCases := []struct { - Name string - Secret string - Reads int - Remaining int - ExistanceErr error + Name string + Secret string + Reads int + Remaining int }{ { - Name: "Simple get", - Secret: "foo", - Reads: 1, - Remaining: 0, - ExistanceErr: sql.ErrNoRows, + Name: "Simple get", + Secret: "foo", + Reads: 1, + Remaining: 0, }, { - Name: "Exist get", - Secret: "bar", - Reads: 2, - Remaining: 1, - ExistanceErr: nil, + Name: "Exist get", + Secret: "bar", + Reads: 2, + Remaining: 1, }, { - Name: "Exist get 2", - Secret: "bar", - Reads: 3, - Remaining: 2, - ExistanceErr: nil, + Name: "Exist get 2", + Secret: "bar", + Reads: 3, + Remaining: 2, }, } for _, testCase := range testCases { t.Run(testCase.Name, func(t *testing.T) { - + t.Logf("%+v", testCase) UUID := uuid.NewUUIDString() ctx := context.Background() err := storage.Write(ctx, UUID, []byte(testCase.Secret), time.Second*10, testCase.Reads) @@ -96,16 +59,15 @@ func TestPostgresqlStorageWrite(t *testing.T) { actual := string(res.Data) if actual != testCase.Secret { - t.Errorf("expected: %s, actual: %s", testCase.Secret, actual) + t.Errorf("%s expected: %s, actual: %s", UUID, testCase.Secret, actual) } - var data []byte var remainingReads int - row := storage.db.QueryRow("SELECT data, remaining_reads FROM entries WHERE uuid=$1", UUID) - err = row.Scan(&data, &remainingReads) - if err != testCase.ExistanceErr { - t.Fatal(err) + row := storage.db.QueryRow("SELECT remaining_reads FROM entries WHERE uuid=$1", UUID) + err = row.Scan(&remainingReads) + if err != nil { + t.Fatalf("%s: %v", UUID, err) } if remainingReads != testCase.Remaining { @@ -121,7 +83,6 @@ func TestPostgresqlStorageVerifyDelete(t *testing.T) { t.Cleanup(func() { storage.Close() }) - defer storage.Close() testCases := []struct { UUID string Key string diff --git a/storage/secret/secret_storage.go b/storage/secret/secret_storage.go index 14a9361..1e4017d 100644 --- a/storage/secret/secret_storage.go +++ b/storage/secret/secret_storage.go @@ -94,20 +94,3 @@ func (s SecretStorage) Delete(ctx context.Context, UUID string) error { func (s SecretStorage) DeleteExpired(ctx context.Context) error { return s.internalStorage.DeleteExpired(ctx) } - -// NewCleanableSecretStorage Creates a cleanable secret storage -func NewCleanableSecretStorage(s *SecretStorage, internal storage.Cleanable) CleanableSecretStorage { - return CleanableSecretStorage{s, internal} -} - -// CleanableSecretStorage Storage which implements CleanableStorage interface, -// to allow to clean every entry from the underlying storage -type CleanableSecretStorage struct { - *SecretStorage - internalStorage storage.Cleanable -} - -// Clean Executes the clean call on the storage -func (s CleanableSecretStorage) Clean() { - s.internalStorage.Clean() -} diff --git a/storage/secret/secret_storage_test.go b/storage/secret/secret_storage_test.go index bd5c514..394fda9 100644 --- a/storage/secret/secret_storage_test.go +++ b/storage/secret/secret_storage_test.go @@ -14,20 +14,15 @@ import ( func TestSecretStorage(t *testing.T) { testData := "Lorem ipusm dolor sit amet" - connection := postgresql.NewStorage(testhelper.GetPSQLTestConn()) - t.Cleanup(func() { - connection.Close() - }) - psqlStorage := postgresql.PostgresCleanableStorage{connection} - storage := &CleanableSecretStorage{ - NewSecretStorage( - psqlStorage, - dummy.NewEncrypter(), - ), + psqlStorage := postgresql.NewStorage(testhelper.GetPSQLTestConn()) + storage := NewSecretStorage( psqlStorage, - } - // TODO defer storage.Close() + dummy.NewEncrypter(), + ) + t.Cleanup(func() { + storage.Close() + }) UUID := uuid.NewUUIDString() ctx := context.Background() err := storage.Write(ctx, UUID, []byte(testData), time.Second*10, 1) diff --git a/storage/storage.go b/storage/storage.go index c64fb16..bf633bd 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -37,12 +37,6 @@ type Storage interface { Writer } -// Cleanable Interface which enables to remove every entry from a storae -type Cleanable interface { - Storage - Clean() -} - // Verifyable an interface which extends the EntryStorage with a // VerifyDelete method type Verifyable interface {