From 16a31d2fd93c3fce4d4a4adc7e2526b45b81cf98 Mon Sep 17 00:00:00 2001 From: Kirill Malikov Date: Wed, 21 Feb 2024 16:49:45 +0300 Subject: [PATCH] fix nil ptr dereference --- basic.go | 10 ++++++++++ validator.go | 38 ++++++++++++++++++++++---------------- validator_test.go | 27 +++++++++++++++++++++++++-- 3 files changed, 57 insertions(+), 18 deletions(-) diff --git a/basic.go b/basic.go index b82bb31..8c33e43 100644 --- a/basic.go +++ b/basic.go @@ -60,3 +60,13 @@ func toString(v any) (string, bool) { return "", false } + +func hasRequiredRule(rules []Rule) (Required, bool) { + for _, r := range rules { + if v, ok := r.(Required); ok { + return v, ok + } + } + + return Required{}, false +} diff --git a/validator.go b/validator.go index 28547da..272bd85 100644 --- a/validator.go +++ b/validator.go @@ -13,6 +13,15 @@ func ValidateValue(ctx context.Context, value any, rules ...Rule) error { return nil } + if value == nil { + requiredRule, ok := hasRequiredRule(rules) + if !ok { + return nil + } + + return requiredRule.ValidateValue(ctx, value) + } + dataSet, err := normalizeDataSet(value) if err != nil { return err @@ -25,23 +34,16 @@ func ValidateValue(ctx context.Context, value any, rules ...Rule) error { rules = normalizeRules(rules) result := NewResult() - for _, validatorRule := range rules { - if _, ok := validatorRule.(Required); !ok { - if value == nil { - // if value is not required and is nil - continue - } - } - - if err := validatorRule.ValidateValue(ctx, value); err != nil { + for _, r := range rules { + if err := r.ValidateValue(ctx, value); err != nil { var errRes Result if errors.As(err, &errRes) { - for _, rErr := range errRes.Errors() { - result = result.WithError(rErr) - } - } else { - return err + result = result.WithError(errRes.Errors()...) + + continue } + + return err } } @@ -110,11 +112,11 @@ func Validate(ctx context.Context, dataSet any, rules RuleSet) error { for _, err := range errs { err.Message = DefaultTranslator.Translate(ctx, err.Message, err.Params) summaryResult = summaryResult.WithError(err) - //summaryResult = summaryResult.WithError( + // summaryResult = summaryResult.WithError( // NewValidationError(DefaultTranslator.Translate(ctx, err.Message, err.Params)). // WithParams(err.Params). // WithValuePath(err.ValuePath), - //) + // ) } } @@ -126,6 +128,10 @@ func Validate(ctx context.Context, dataSet any, rules RuleSet) error { } func normalizeDataSet(ds any) (DataSet, error) { + if ds == nil { + return set.NewDataSetAny(ds), nil + } + rt := reflect.TypeOf(ds) if rt.Kind() == reflect.Pointer { rt = rt.Elem() diff --git a/validator_test.go b/validator_test.go index 023bdc2..512aedb 100644 --- a/validator_test.go +++ b/validator_test.go @@ -7,7 +7,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestValidate_Int_Successfully(t *testing.T) { +func TestValidateValue_Int_Successfully(t *testing.T) { ctx := context.Background() rules := []Rule{ NewRequired(), @@ -18,7 +18,7 @@ func TestValidate_Int_Successfully(t *testing.T) { assert.NoError(t, err) } -func TestValidate_Int_Failure(t *testing.T) { +func TestValidateValue_Int_Failure(t *testing.T) { ctx := context.Background() rules := []Rule{ NewRequired(), @@ -39,6 +39,16 @@ func TestValidate_Int_Failure(t *testing.T) { assert.Equal(t, expectedResult, err) } +func TestValidateValue_IntNilPtrValue_Successfully(t *testing.T) { + ctx := context.Background() + rules := []Rule{ + NewNumber(1, 3), + } + + err := ValidateValue(ctx, nil, rules...) + assert.NoError(t, err) +} + func TestValidate_Map_Successfully(t *testing.T) { ctx := context.Background() rules := RuleSet{ @@ -54,3 +64,16 @@ func TestValidate_Map_Successfully(t *testing.T) { err := Validate(ctx, data, rules) assert.NoError(t, err) } + +func TestValidate_Nil_Failure(t *testing.T) { + ctx := context.Background() + rules := RuleSet{ + "count": { + NewRequired(), + NewNumber(1, 3), + }, + } + + err := Validate(ctx, nil, rules) + assert.Error(t, err) +}