From 0f337e9c1455dcdb6d4dd4e423345cf642629630 Mon Sep 17 00:00:00 2001 From: Jian Zeng Date: Fri, 5 Jan 2024 22:54:13 +0800 Subject: [PATCH] feat: make errdefs.IsXXX helper functions work with wrapped errors Signed-off-by: Jian Zeng --- errdefs/helpers_test.go | 66 +++++++++++++++++++++++++++++++++++++++++ errdefs/is.go | 6 ++++ 2 files changed, 72 insertions(+) diff --git a/errdefs/helpers_test.go b/errdefs/helpers_test.go index 41a3b9e18836e..4d902ea819a6e 100644 --- a/errdefs/helpers_test.go +++ b/errdefs/helpers_test.go @@ -2,6 +2,7 @@ package errdefs import ( "errors" + "fmt" "testing" ) @@ -25,6 +26,11 @@ func TestNotFound(t *testing.T) { if !errors.Is(e, errTest) { t.Fatalf("expected not found error to match errTest") } + + wrapped := fmt.Errorf("foo: %w", e) + if !IsNotFound(wrapped) { + t.Fatalf("expected not found error, got: %T", wrapped) + } } func TestConflict(t *testing.T) { @@ -41,6 +47,11 @@ func TestConflict(t *testing.T) { if !errors.Is(e, errTest) { t.Fatalf("expected conflict error to match errTest") } + + wrapped := fmt.Errorf("foo: %w", e) + if !IsConflict(wrapped) { + t.Fatalf("expected conflict error, got: %T", wrapped) + } } func TestForbidden(t *testing.T) { @@ -57,6 +68,11 @@ func TestForbidden(t *testing.T) { if !errors.Is(e, errTest) { t.Fatalf("expected forbidden error to match errTest") } + + wrapped := fmt.Errorf("foo: %w", e) + if !IsForbidden(wrapped) { + t.Fatalf("expected forbidden error, got: %T", wrapped) + } } func TestInvalidParameter(t *testing.T) { @@ -73,6 +89,11 @@ func TestInvalidParameter(t *testing.T) { if !errors.Is(e, errTest) { t.Fatalf("expected invalid argument error to match errTest") } + + wrapped := fmt.Errorf("foo: %w", e) + if !IsInvalidParameter(wrapped) { + t.Fatalf("expected invalid argument error, got: %T", wrapped) + } } func TestNotImplemented(t *testing.T) { @@ -89,6 +110,11 @@ func TestNotImplemented(t *testing.T) { if !errors.Is(e, errTest) { t.Fatalf("expected not implemented error to match errTest") } + + wrapped := fmt.Errorf("foo: %w", e) + if !IsNotImplemented(wrapped) { + t.Fatalf("expected not implemented error, got: %T", wrapped) + } } func TestNotModified(t *testing.T) { @@ -105,6 +131,11 @@ func TestNotModified(t *testing.T) { if !errors.Is(e, errTest) { t.Fatalf("expected not modified error to match errTest") } + + wrapped := fmt.Errorf("foo: %w", e) + if !IsNotModified(wrapped) { + t.Fatalf("expected not modified error, got: %T", wrapped) + } } func TestUnauthorized(t *testing.T) { @@ -121,6 +152,11 @@ func TestUnauthorized(t *testing.T) { if !errors.Is(e, errTest) { t.Fatalf("expected unauthorized error to match errTest") } + + wrapped := fmt.Errorf("foo: %w", e) + if !IsUnauthorized(wrapped) { + t.Fatalf("expected unauthorized error, got: %T", wrapped) + } } func TestUnknown(t *testing.T) { @@ -137,6 +173,11 @@ func TestUnknown(t *testing.T) { if !errors.Is(e, errTest) { t.Fatalf("expected unknown error to match errTest") } + + wrapped := fmt.Errorf("foo: %w", e) + if !IsUnknown(wrapped) { + t.Fatalf("expected unknown error, got: %T", wrapped) + } } func TestCancelled(t *testing.T) { @@ -153,6 +194,11 @@ func TestCancelled(t *testing.T) { if !errors.Is(e, errTest) { t.Fatalf("expected cancelled error to match errTest") } + + wrapped := fmt.Errorf("foo: %w", e) + if !IsCancelled(wrapped) { + t.Fatalf("expected cancelled error, got: %T", wrapped) + } } func TestDeadline(t *testing.T) { @@ -169,6 +215,11 @@ func TestDeadline(t *testing.T) { if !errors.Is(e, errTest) { t.Fatalf("expected deadline error to match errTest") } + + wrapped := fmt.Errorf("foo: %w", e) + if !IsDeadline(wrapped) { + t.Fatalf("expected deadline error, got: %T", wrapped) + } } func TestDataLoss(t *testing.T) { @@ -185,6 +236,11 @@ func TestDataLoss(t *testing.T) { if !errors.Is(e, errTest) { t.Fatalf("expected data loss error to match errTest") } + + wrapped := fmt.Errorf("foo: %w", e) + if !IsDataLoss(wrapped) { + t.Fatalf("expected data loss error, got: %T", wrapped) + } } func TestUnavailable(t *testing.T) { @@ -201,6 +257,11 @@ func TestUnavailable(t *testing.T) { if !errors.Is(e, errTest) { t.Fatalf("expected unavaillable error to match errTest") } + + wrapped := fmt.Errorf("foo: %w", e) + if !IsUnavailable(wrapped) { + t.Fatalf("expected unavaillable error, got: %T", wrapped) + } } func TestSystem(t *testing.T) { @@ -217,4 +278,9 @@ func TestSystem(t *testing.T) { if !errors.Is(e, errTest) { t.Fatalf("expected system error to match errTest") } + + wrapped := fmt.Errorf("foo: %w", e) + if !IsSystem(wrapped) { + t.Fatalf("expected system error, got: %T", wrapped) + } } diff --git a/errdefs/is.go b/errdefs/is.go index b0d745ca7ae35..f94034cbd7dd1 100644 --- a/errdefs/is.go +++ b/errdefs/is.go @@ -9,6 +9,10 @@ type causer interface { Cause() error } +type wrapErr interface { + Unwrap() error +} + func getImplementer(err error) error { switch e := err.(type) { case @@ -28,6 +32,8 @@ func getImplementer(err error) error { return err case causer: return getImplementer(e.Cause()) + case wrapErr: + return getImplementer(e.Unwrap()) default: return err }