Skip to content

Commit

Permalink
feat(state): Improve errors for validation failures
Browse files Browse the repository at this point in the history
  • Loading branch information
aholstenson committed Jan 14, 2024
1 parent bc53bd0 commit 99bde72
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 60 deletions.
34 changes: 31 additions & 3 deletions internal/api/state/v1alpha1/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"go.opentelemetry.io/otel/propagation"
"go.uber.org/fx"
"go.uber.org/zap"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
)

Expand Down Expand Up @@ -48,7 +50,14 @@ func (s *StateServiceServer) EnsureStore(ctx context.Context, req *statev1alpha1
err := s.state.EnsureStore(ctx, &state.StoreConfig{
Name: req.Store,
})
if err != nil {

if errors.Is(err, context.Canceled) {
return nil, status.Error(codes.Canceled, "context canceled")
} else if errors.Is(err, context.DeadlineExceeded) {
return nil, status.Error(codes.DeadlineExceeded, "timed out")
} else if state.IsValidationError(err) {
return nil, status.Error(codes.InvalidArgument, err.Error())
} else if err != nil {
return nil, err
}

Expand All @@ -62,6 +71,12 @@ func (s *StateServiceServer) Get(ctx context.Context, req *statev1alpha1.GetRequ
return &statev1alpha1.GetResponse{
Revision: 0,
}, nil
} else if errors.Is(err, context.Canceled) {
return nil, status.Error(codes.Canceled, "context canceled")
} else if errors.Is(err, context.DeadlineExceeded) {
return nil, status.Error(codes.DeadlineExceeded, "timed out")
} else if state.IsValidationError(err) {
return nil, status.Error(codes.InvalidArgument, err.Error())
} else if err != nil {
return nil, err
}
Expand All @@ -84,7 +99,13 @@ func (s *StateServiceServer) Set(ctx context.Context, req *statev1alpha1.SetRequ
revision, err = s.state.Set(ctx, req.Store, req.Key, req.Value)
}

if err != nil {
if errors.Is(err, context.Canceled) {
return nil, status.Error(codes.Canceled, "context canceled")
} else if errors.Is(err, context.DeadlineExceeded) {
return nil, status.Error(codes.DeadlineExceeded, "timed out")
} else if state.IsValidationError(err) {
return nil, status.Error(codes.InvalidArgument, err.Error())
} else if err != nil {
return nil, err
}

Expand All @@ -100,7 +121,14 @@ func (s *StateServiceServer) Delete(ctx context.Context, req *statev1alpha1.Dele
} else {
err = s.state.Delete(ctx, req.Store, req.Key)
}
if err != nil {

if errors.Is(err, context.Canceled) {
return nil, status.Error(codes.Canceled, "context canceled")
} else if errors.Is(err, context.DeadlineExceeded) {
return nil, status.Error(codes.DeadlineExceeded, "timed out")
} else if state.IsValidationError(err) {
return nil, status.Error(codes.InvalidArgument, err.Error())
} else if err != nil {
return nil, err
}

Expand Down
25 changes: 18 additions & 7 deletions internal/state/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,8 @@ package state

import "errors"

// ErrStoreRequired is returned when a store is not provided.
var ErrStoreRequired = errors.New("store required")

// ErrStoreNotFound is returned when a store is not found.
var ErrStoreNotFound = errors.New("store not found")

// ErrKeyRequired is returned when a key is not provided.
var ErrKeyRequired = errors.New("missing key")
var ErrStoreNotFound = &validationError{err: "store not found"}

// ErrKeyNotFound is returned when a key is not found in a store.
var ErrKeyNotFound = errors.New("key not found")
Expand All @@ -21,3 +15,20 @@ var ErrKeyAlreadyExists = errors.New("key already exists")
// ErrRevisionMismatch is returned when a revision does not match the current
// revision of a key.
var ErrRevisionMismatch = errors.New("revision mismatch")

type validationError struct {
err string
}

func (e *validationError) Error() string {
return e.err
}

func newValidationError(err string) error {
return &validationError{err: err}
}

func IsValidationError(err error) bool {
_, ok := err.(*validationError)
return ok
}
75 changes: 47 additions & 28 deletions internal/state/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package state

import (
"context"
"strings"
"time"

"github.com/cockroachdb/errors"
Expand Down Expand Up @@ -82,13 +81,12 @@ func (m *Manager) EnsureStore(ctx context.Context, config *StoreConfig) error {
)
defer span.End()

storeName := strings.TrimSpace(config.Name)
if storeName == "" {
span.SetStatus(codes.Error, "store required")
return errors.WithStack(ErrStoreRequired)
if !IsValidStoreName(config.Name) {
span.SetStatus(codes.Error, "invalid store name")
return newValidationError("invalid store name: " + config.Name)
}

_, err := m.stores.Get(ctx, storeName)
_, err := m.stores.Get(ctx, config.Name)
if errors.Is(err, ErrStoreNotFound) {
_, err = m.js.CreateKeyValue(ctx, jetstream.KeyValueConfig{
Bucket: config.Name,
Expand Down Expand Up @@ -124,17 +122,10 @@ func (m *Manager) Get(ctx context.Context, store string, key string) (*Entry, er
)
defer span.End()

store = strings.TrimSpace(store)
key = strings.TrimSpace(key)

if store == "" {
span.SetStatus(codes.Error, "store required")
return nil, errors.WithStack(ErrStoreRequired)
}

if key == "" {
span.SetStatus(codes.Error, "key required")
return nil, errors.WithStack(ErrKeyRequired)
err := validatePreconditions(store, key)
if err != nil {
span.SetStatus(codes.Error, err.Error())
return nil, err
}

bucket, err := m.stores.Get(ctx, store)
Expand Down Expand Up @@ -187,17 +178,10 @@ func (m *Manager) Create(ctx context.Context, store string, key string, value *a
)
defer span.End()

store = strings.TrimSpace(store)
key = strings.TrimSpace(key)

if store == "" {
span.SetStatus(codes.Error, "store required")
return 0, errors.WithStack(ErrStoreRequired)
}

if key == "" {
span.SetStatus(codes.Error, "key required")
return 0, errors.WithStack(ErrKeyRequired)
err := validatePreconditions(store, key)
if err != nil {
span.SetStatus(codes.Error, err.Error())
return 0, err
}

bucket, err := m.stores.Get(ctx, store)
Expand Down Expand Up @@ -244,6 +228,12 @@ func (m *Manager) Set(ctx context.Context, store string, key string, value *anyp
)
defer span.End()

err := validatePreconditions(store, key)
if err != nil {
span.SetStatus(codes.Error, err.Error())
return 0, err
}

bucket, err := m.stores.Get(ctx, store)
if err != nil {
span.RecordError(err)
Expand Down Expand Up @@ -286,6 +276,12 @@ func (m *Manager) Update(ctx context.Context, store string, key string, value *a
)
defer span.End()

err := validatePreconditions(store, key)
if err != nil {
span.SetStatus(codes.Error, err.Error())
return 0, err
}

bucket, err := m.stores.Get(ctx, store)
if err != nil {
span.RecordError(err)
Expand Down Expand Up @@ -339,6 +335,12 @@ func (m *Manager) Delete(ctx context.Context, store string, key string) error {
)
defer span.End()

err := validatePreconditions(store, key)
if err != nil {
span.SetStatus(codes.Error, err.Error())
return err
}

bucket, err := m.stores.Get(ctx, store)
if err != nil {
span.RecordError(err)
Expand Down Expand Up @@ -373,6 +375,12 @@ func (m *Manager) DeleteWithRevision(ctx context.Context, store string, key stri
)
defer span.End()

err := validatePreconditions(store, key)
if err != nil {
span.SetStatus(codes.Error, err.Error())
return err
}

bucket, err := m.stores.Get(ctx, store)
if err != nil {
span.RecordError(err)
Expand All @@ -390,3 +398,14 @@ func (m *Manager) DeleteWithRevision(ctx context.Context, store string, key stri
span.SetStatus(codes.Ok, "")
return nil
}

func validatePreconditions(store string, key string) error {
if !IsValidStoreName(store) {
return newValidationError("invalid store name: " + store)
}

if !IsValidKey(key) {
return newValidationError("invalid key: " + key)
}
return nil
}
14 changes: 14 additions & 0 deletions internal/state/names.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package state

import "windshift/service/internal/events"

// IsValidStoreName checks if the store name is valid.
func IsValidStoreName(name string) bool {
return events.IsValidStreamName(name)
}

// IsValidKey checks if the key name is valid. The NATS documentation says
// that keys follow the same rules as subjects so we delegate to events.IsValidSubject.
func IsValidKey(name string) bool {
return events.IsValidSubject(name, false)
}
Loading

0 comments on commit 99bde72

Please sign in to comment.