Skip to content

Commit

Permalink
types: make unmarshalExtensionArg into a generic constructor function
Browse files Browse the repository at this point in the history
Signed-off-by: Patrick Jakubowski <[email protected]>
  • Loading branch information
patjakdev committed Nov 12, 2024
1 parent a326a32 commit eea990b
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 28 deletions.
6 changes: 1 addition & 5 deletions types/datetime.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,15 +250,11 @@ func (a Datetime) String() string {
// - { "fn": "datetime", "arg": "1970-01-01" }
// - "1970-01-01"
func (a *Datetime) UnmarshalJSON(b []byte) error {
arg, err := unmarshalExtensionArg(b, "datetime")
aa, err := unmarshalExtensionValue(b, "datetime", ParseDatetime)
if err != nil {
return err
}

aa, err := ParseDatetime(arg)
if err != nil {
return err
}
*a = aa
return nil
}
Expand Down
14 changes: 8 additions & 6 deletions types/decimal.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,17 +160,19 @@ func (d Decimal) String() string {
return res[:right]
}

// UnmarshalJSON implements encoding/json.Unmarshaler for Decimal
//
// It is capable of unmarshaling 3 different representations supported by Cedar
// - { "__extn": { "fn": "decimal", "arg": "1234.5678" }}
// - { "fn": "decimal", "arg": "1234.5678" }
// - "1234.5678"
func (d *Decimal) UnmarshalJSON(b []byte) error {
arg, err := unmarshalExtensionArg(b, "decimal")
dd, err := unmarshalExtensionValue(b, "decimal", ParseDecimal)
if err != nil {
return err
}

vv, err := ParseDecimal(arg)
if err != nil {
return err
}
*d = vv
*d = dd
return nil
}

Expand Down
6 changes: 1 addition & 5 deletions types/duration.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,15 +218,11 @@ func (v Duration) String() string {
// - { "fn": "duration", "arg": "1h10m" }
// - "1h10m"
func (v *Duration) UnmarshalJSON(b []byte) error {
arg, err := unmarshalExtensionArg(b, "duration")
vv, err := unmarshalExtensionValue(b, "duration", ParseDuration)
if err != nil {
return err
}

vv, err := ParseDuration(arg)
if err != nil {
return err
}
*v = vv
return nil
}
Expand Down
12 changes: 7 additions & 5 deletions types/ipaddr.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,16 +103,18 @@ func (c IPAddr) Contains(o IPAddr) bool {
return c.Prefix().Contains(o.Addr()) && c.Prefix().Bits() <= o.Prefix().Bits()
}

// UnmarshalJSON implements encoding/json.Unmarshaler for IPAddr
//
// It is capable of unmarshaling 3 different representations supported by Cedar
// - { "__extn": { "fn": "ip", "arg": "12.34.56.78" }}
// - { "fn": "ip", "arg": "12.34.56.78" }
// - "12.34.56.78"
func (v *IPAddr) UnmarshalJSON(b []byte) error {
arg, err := unmarshalExtensionArg(b, "ip")
vv, err := unmarshalExtensionValue(b, "ip", ParseIPAddr)
if err != nil {
return err
}

vv, err := ParseIPAddr(arg)
if err != nil {
return err
}
*v = vv
return nil
}
Expand Down
20 changes: 13 additions & 7 deletions types/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,35 +127,41 @@ func UnmarshalJSON(b []byte, v *Value) error {
return nil
}

func unmarshalExtensionArg(b []byte, extName string) (string, error) {
func unmarshalExtensionValue[T any](b []byte, extName string, parse func(string) (T, error)) (T, error) {
var zeroT T
var arg string
if len(b) > 0 && b[0] == '"' {
if err := json.Unmarshal(b, &arg); err != nil {
return "", errors.Join(errJSONDecode, err)
return zeroT, errors.Join(errJSONDecode, err)
}
} else {
var res extValueJSON
if err := json.Unmarshal(b, &res); err != nil {
return "", errors.Join(errJSONDecode, err)
return zeroT, errors.Join(errJSONDecode, err)
}
if res.Extn == nil {
// If we didn't find an Extn, maybe it's just an extn.
var res2 extn
_ = json.Unmarshal(b, &res2)
// We've tried Ext.Fn and Fn, so no good.
if res2.Fn == "" {
return "", errJSONExtNotFound
return zeroT, errJSONExtNotFound
}
if res2.Fn != extName {
return "", errJSONExtFnMatch
return zeroT, errJSONExtFnMatch
}
arg = res2.Arg
} else if res.Extn.Fn != extName {
return "", errJSONExtFnMatch
return zeroT, errJSONExtFnMatch
} else {
arg = res.Extn.Arg
}
}

return arg, nil
v, err := parse(arg)
if err != nil {
return zeroT, err
}

return v, nil
}

0 comments on commit eea990b

Please sign in to comment.