From eea990b444506f2fb08d57bb69942ddb7ceaedc0 Mon Sep 17 00:00:00 2001 From: Patrick Jakubowski Date: Tue, 12 Nov 2024 12:47:22 -0800 Subject: [PATCH] types: make unmarshalExtensionArg into a generic constructor function Signed-off-by: Patrick Jakubowski --- types/datetime.go | 6 +----- types/decimal.go | 14 ++++++++------ types/duration.go | 6 +----- types/ipaddr.go | 12 +++++++----- types/json.go | 20 +++++++++++++------- 5 files changed, 30 insertions(+), 28 deletions(-) diff --git a/types/datetime.go b/types/datetime.go index 631c87c..6d8079f 100644 --- a/types/datetime.go +++ b/types/datetime.go @@ -250,15 +250,11 @@ func (a Datetime) String() string { // - { "fn": "datetime", "arg": "1970-01-01" } // - "1970-01-01" func (a *Datetime) UnmarshalJSON(b []byte) error { - arg, err := unmarshalExtensionArg(b, "datetime") + aa, err := unmarshalExtensionValue(b, "datetime", ParseDatetime) if err != nil { return err } - aa, err := ParseDatetime(arg) - if err != nil { - return err - } *a = aa return nil } diff --git a/types/decimal.go b/types/decimal.go index 19873a3..539bade 100644 --- a/types/decimal.go +++ b/types/decimal.go @@ -160,17 +160,19 @@ func (d Decimal) String() string { return res[:right] } +// UnmarshalJSON implements encoding/json.Unmarshaler for Decimal +// +// It is capable of unmarshaling 3 different representations supported by Cedar +// - { "__extn": { "fn": "decimal", "arg": "1234.5678" }} +// - { "fn": "decimal", "arg": "1234.5678" } +// - "1234.5678" func (d *Decimal) UnmarshalJSON(b []byte) error { - arg, err := unmarshalExtensionArg(b, "decimal") + dd, err := unmarshalExtensionValue(b, "decimal", ParseDecimal) if err != nil { return err } - vv, err := ParseDecimal(arg) - if err != nil { - return err - } - *d = vv + *d = dd return nil } diff --git a/types/duration.go b/types/duration.go index 9acc05f..9443b5c 100644 --- a/types/duration.go +++ b/types/duration.go @@ -218,15 +218,11 @@ func (v Duration) String() string { // - { "fn": "duration", "arg": "1h10m" } // - "1h10m" func (v *Duration) UnmarshalJSON(b []byte) error { - arg, err := unmarshalExtensionArg(b, "duration") + vv, err := unmarshalExtensionValue(b, "duration", ParseDuration) if err != nil { return err } - vv, err := ParseDuration(arg) - if err != nil { - return err - } *v = vv return nil } diff --git a/types/ipaddr.go b/types/ipaddr.go index e3a3ae9..607b24c 100644 --- a/types/ipaddr.go +++ b/types/ipaddr.go @@ -103,16 +103,18 @@ func (c IPAddr) Contains(o IPAddr) bool { return c.Prefix().Contains(o.Addr()) && c.Prefix().Bits() <= o.Prefix().Bits() } +// UnmarshalJSON implements encoding/json.Unmarshaler for IPAddr +// +// It is capable of unmarshaling 3 different representations supported by Cedar +// - { "__extn": { "fn": "ip", "arg": "12.34.56.78" }} +// - { "fn": "ip", "arg": "12.34.56.78" } +// - "12.34.56.78" func (v *IPAddr) UnmarshalJSON(b []byte) error { - arg, err := unmarshalExtensionArg(b, "ip") + vv, err := unmarshalExtensionValue(b, "ip", ParseIPAddr) if err != nil { return err } - vv, err := ParseIPAddr(arg) - if err != nil { - return err - } *v = vv return nil } diff --git a/types/json.go b/types/json.go index 37d691f..e974a91 100644 --- a/types/json.go +++ b/types/json.go @@ -127,16 +127,17 @@ func UnmarshalJSON(b []byte, v *Value) error { return nil } -func unmarshalExtensionArg(b []byte, extName string) (string, error) { +func unmarshalExtensionValue[T any](b []byte, extName string, parse func(string) (T, error)) (T, error) { + var zeroT T var arg string if len(b) > 0 && b[0] == '"' { if err := json.Unmarshal(b, &arg); err != nil { - return "", errors.Join(errJSONDecode, err) + return zeroT, errors.Join(errJSONDecode, err) } } else { var res extValueJSON if err := json.Unmarshal(b, &res); err != nil { - return "", errors.Join(errJSONDecode, err) + return zeroT, errors.Join(errJSONDecode, err) } if res.Extn == nil { // If we didn't find an Extn, maybe it's just an extn. @@ -144,18 +145,23 @@ func unmarshalExtensionArg(b []byte, extName string) (string, error) { _ = json.Unmarshal(b, &res2) // We've tried Ext.Fn and Fn, so no good. if res2.Fn == "" { - return "", errJSONExtNotFound + return zeroT, errJSONExtNotFound } if res2.Fn != extName { - return "", errJSONExtFnMatch + return zeroT, errJSONExtFnMatch } arg = res2.Arg } else if res.Extn.Fn != extName { - return "", errJSONExtFnMatch + return zeroT, errJSONExtFnMatch } else { arg = res.Extn.Arg } } - return arg, nil + v, err := parse(arg) + if err != nil { + return zeroT, err + } + + return v, nil }