diff --git a/README.md b/README.md index 711da17..112f5f4 100644 --- a/README.md +++ b/README.md @@ -137,8 +137,9 @@ Generated documentation for the latest version of the Go implementation can be a If you're looking to integrate Cedar into a production system, please be sure the read the [security best practices](https://docs.cedarpolicy.com/other/security.html) ## Backward Compatibility Considerations - -x/exp - code in this directory is not subject to the semantic version constraints of the rest of the module and breaking changes may be made at any time +- `x/exp` - code in this directory is not subject to the semantic versioning constraints of the rest of the module and breaking changes may be made at any time. +- Variadics may be added to functions that do not have them to expand the arguments of a function or method. +- Concrete types may be replaced with compatible interfaces to expand the variety of arguments a function or method can take. ## Change log diff --git a/authorize.go b/authorize.go index 302d6d7..d8164fa 100644 --- a/authorize.go +++ b/authorize.go @@ -18,7 +18,7 @@ const ( // IsAuthorized uses the combination of the PolicySet and Entities to determine // if the given Request to determine Decision and Diagnostic. -func (p PolicySet) IsAuthorized(entities types.EntityMap, req Request) (Decision, Diagnostic) { +func (p PolicySet) IsAuthorized(entities types.EntityGetter, req Request) (Decision, Diagnostic) { env := eval.Env{ Entities: entities, Principal: req.Principal, diff --git a/internal/eval/evalers.go b/internal/eval/evalers.go index 4ca1bce..d8cded9 100644 --- a/internal/eval/evalers.go +++ b/internal/eval/evalers.go @@ -22,7 +22,7 @@ func zeroValue() types.Value { } type Env struct { - Entities types.EntityMap + Entities types.EntityGetter Principal, Action, Resource types.Value Context types.Value } @@ -754,7 +754,7 @@ func (n *attributeAccessEval) Eval(env Env) (types.Value, error) { if vv == unspecified { return zeroValue(), fmt.Errorf("cannot access attribute `%s` of %w", n.attribute, errUnspecifiedEntity) } - rec, ok := env.Entities[vv] + rec, ok := env.Entities.Get(vv) if !ok { return zeroValue(), fmt.Errorf("entity `%v` %w", vv.String(), errEntityNotExist) } @@ -792,7 +792,7 @@ func (n *hasEval) Eval(env Env) (types.Value, error) { var record types.Record switch vv := v.(type) { case types.EntityUID: - if rec, ok := env.Entities[vv]; ok { + if rec, ok := env.Entities.Get(vv); ok { record = rec.Attributes } case types.Record: @@ -861,12 +861,12 @@ func entityInOne(env Env, entity types.EntityUID, parent types.EntityUID) bool { var todo []types.EntityUID var candidate = entity for { - if fe, ok := env.Entities[candidate]; ok { + if fe, ok := env.Entities.Get(candidate); ok { if fe.Parents.Contains(parent) { return true } fe.Parents.Iterate(func(k types.EntityUID) bool { - p, ok := env.Entities[k] + p, ok := env.Entities.Get(k) if !ok || p.Parents.Len() == 0 || k == entity || known.Contains(k) { return true } @@ -890,12 +890,12 @@ func entityInSet(env Env, entity types.EntityUID, parents mapset.Container[types var todo []types.EntityUID var candidate = entity for { - if fe, ok := env.Entities[candidate]; ok { + if fe, ok := env.Entities.Get(candidate); ok { if fe.Parents.Intersects(parents) { return true } fe.Parents.Iterate(func(k types.EntityUID) bool { - p, ok := env.Entities[k] + p, ok := env.Entities.Get(k) if !ok || p.Parents.Len() == 0 || k == entity || known.Contains(k) { return true } diff --git a/internal/eval/partial.go b/internal/eval/partial.go index 07106cd..33dede4 100644 --- a/internal/eval/partial.go +++ b/internal/eval/partial.go @@ -504,7 +504,7 @@ func (n *partialHasEval) Eval(env Env) (types.Value, error) { var record types.Record switch vv := v.(type) { case types.EntityUID: - if rec, ok := env.Entities[vv]; ok { + if rec, ok := env.Entities.Get(vv); ok { record = rec.Attributes } case types.Record: diff --git a/types/entity_map.go b/types/entity_map.go index 70b9034..7de8039 100644 --- a/types/entity_map.go +++ b/types/entity_map.go @@ -8,6 +8,13 @@ import ( "golang.org/x/exp/maps" ) +// An EntityGetter is an interface for retrieving an Entity by EntityUID. +type EntityGetter interface { + Get(uid EntityUID) (Entity, bool) +} + +var _ EntityGetter = EntityMap{} + // An EntityMap is a collection of all the entities that are needed to evaluate // authorization requests. The key is an EntityUID which uniquely identifies // the Entity (it must be the same as the UID within the Entity itself.) @@ -37,3 +44,8 @@ func (e *EntityMap) UnmarshalJSON(b []byte) error { func (e EntityMap) Clone() EntityMap { return maps.Clone(e) } + +func (e EntityMap) Get(uid EntityUID) (Entity, bool) { + ent, ok := e[uid] + return ent, ok +} diff --git a/types/entity_map_test.go b/types/entity_map_test.go index 507e38d..a8a2afe 100644 --- a/types/entity_map_test.go +++ b/types/entity_map_test.go @@ -22,6 +22,21 @@ func TestEntities(t *testing.T) { testutil.Equals(t, clone, e) }) + t.Run("Get", func(t *testing.T) { + t.Parallel() + ent := types.Entity{ + UID: types.NewEntityUID("Type", "id"), + Attributes: types.NewRecord(types.RecordMap{"key": types.Long(42)}), + } + e := types.EntityMap{ + ent.UID: ent, + } + got, ok := e.Get(ent.UID) + testutil.Equals(t, ok, true) + testutil.Equals(t, got, ent) + _, ok = e.Get(types.NewEntityUID("Type", "id2")) + testutil.Equals(t, ok, false) + }) } func TestEntitiesJSON(t *testing.T) {