From 6bb9a5632daa6a7260152735f5c7fc568fce1508 Mon Sep 17 00:00:00 2001 From: Matt Siwiec Date: Mon, 18 Dec 2023 09:21:49 -0700 Subject: [PATCH] Check owner permissions in resolvers (#291) * check permissions against ownerID, rather than resource Signed-off-by: Matt Siwiec * establish permissions auth-relationship in ports Signed-off-by: Matt Siwiec --------- Signed-off-by: Matt Siwiec --- internal/ent/schema/port.go | 1 + internal/graphapi/loadbalancer.resolvers.go | 26 ++++++++++---------- internal/graphapi/origin.resolvers.go | 8 +++--- internal/graphapi/pool.resolvers.go | 27 +++++++++++---------- internal/graphapi/port.resolvers.go | 9 ++++--- internal/manualhooks/hooks.go | 10 ++++++++ 6 files changed, 47 insertions(+), 34 deletions(-) diff --git a/internal/ent/schema/port.go b/internal/ent/schema/port.go index 47fe46f3c..2c01d2bbd 100644 --- a/internal/ent/schema/port.go +++ b/internal/ent/schema/port.go @@ -58,6 +58,7 @@ func (Port) Fields() []ent.Field { Annotations( entgql.Type("ID"), entgql.Skip(entgql.SkipWhereInput, entgql.SkipMutationUpdateInput), + pubsubinfo.EventsHookAdditionalSubject("loadbalancer"), ), } } diff --git a/internal/graphapi/loadbalancer.resolvers.go b/internal/graphapi/loadbalancer.resolvers.go index 843317cfd..b151fa407 100644 --- a/internal/graphapi/loadbalancer.resolvers.go +++ b/internal/graphapi/loadbalancer.resolvers.go @@ -27,7 +27,7 @@ func (r *mutationResolver) LoadBalancerCreate(ctx context.Context, input generat } if config.AppConfig.LoadBalancerLimit > 0 { - count, err := r.client.LoadBalancer.Query().Where(predicate.LoadBalancer(loadbalancer.OwnerIDEQ(input.OwnerID))).Count(ctx) + count, err := r.client.LoadBalancer.Query().Where(loadbalancer.OwnerIDEQ(input.OwnerID)).Count(ctx) if err != nil { r.logger.Errorw("failed to query loadbalancer count", "error", err) } @@ -64,10 +64,6 @@ func (r *mutationResolver) LoadBalancerUpdate(ctx context.Context, id gidx.Prefi return nil, err } - if err := permissions.CheckAccess(ctx, id, actionLoadBalancerUpdate); err != nil { - return nil, err - } - lb, err := r.client.LoadBalancer.Get(ctx, id) if err != nil { if generated.IsNotFound(err) { @@ -78,6 +74,10 @@ func (r *mutationResolver) LoadBalancerUpdate(ctx context.Context, id gidx.Prefi return nil, ErrInternalServerError } + if err := permissions.CheckAccess(ctx, lb.OwnerID, actionLoadBalancerUpdate); err != nil { + return nil, err + } + lb, err = lb.Update().SetInput(input).Save(ctx) if err != nil { if generated.IsValidationError(err) { @@ -105,10 +105,6 @@ func (r *mutationResolver) LoadBalancerDelete(ctx context.Context, id gidx.Prefi return nil, err } - if err := permissions.CheckAccess(ctx, id, actionLoadBalancerDelete); err != nil { - return nil, err - } - lb, err := r.client.LoadBalancer.Get(ctx, id) if err != nil { if generated.IsNotFound(err) { @@ -119,6 +115,10 @@ func (r *mutationResolver) LoadBalancerDelete(ctx context.Context, id gidx.Prefi return nil, ErrInternalServerError } + if err := permissions.CheckAccess(ctx, lb.OwnerID, actionLoadBalancerDelete); err != nil { + return nil, err + } + tx, err := r.client.BeginTx(ctx, &sql.TxOptions{}) if err != nil { logger.Errorw("failed to begin transaction", "error", err) @@ -191,10 +191,6 @@ func (r *queryResolver) LoadBalancer(ctx context.Context, id gidx.PrefixedID) (* return nil, err } - if err := permissions.CheckAccess(ctx, id, actionLoadBalancerGet); err != nil { - return nil, err - } - lb, err := r.client.LoadBalancer.Get(ctx, id) if err != nil { if generated.IsNotFound(err) { @@ -205,5 +201,9 @@ func (r *queryResolver) LoadBalancer(ctx context.Context, id gidx.PrefixedID) (* return nil, ErrInternalServerError } + if err := permissions.CheckAccess(ctx, lb.OwnerID, actionLoadBalancerGet); err != nil { + return nil, err + } + return lb, nil } diff --git a/internal/graphapi/origin.resolvers.go b/internal/graphapi/origin.resolvers.go index 48662c192..71e1c4a25 100644 --- a/internal/graphapi/origin.resolvers.go +++ b/internal/graphapi/origin.resolvers.go @@ -74,7 +74,7 @@ func (r *mutationResolver) LoadBalancerOriginUpdate(ctx context.Context, id gidx return nil, err } - ogn, err := r.client.Origin.Get(ctx, id) + ogn, err := r.client.Origin.Query().WithPool().Where(origin.IDEQ(id)).Only(ctx) if err != nil { if generated.IsNotFound(err) { return nil, err @@ -84,7 +84,7 @@ func (r *mutationResolver) LoadBalancerOriginUpdate(ctx context.Context, id gidx return nil, ErrInternalServerError } - if err := permissions.CheckAccess(ctx, ogn.PoolID, actionLoadBalancerPoolUpdate); err != nil { + if err := permissions.CheckAccess(ctx, ogn.Edges.Pool.OwnerID, actionLoadBalancerPoolUpdate); err != nil { return nil, err } @@ -121,7 +121,7 @@ func (r *mutationResolver) LoadBalancerOriginDelete(ctx context.Context, id gidx return nil, err } - ogn, err := r.client.Origin.Get(ctx, id) + ogn, err := r.client.Origin.Query().WithPool().Where(origin.IDEQ(id)).Only(ctx) if err != nil { if generated.IsNotFound(err) { return nil, err @@ -131,7 +131,7 @@ func (r *mutationResolver) LoadBalancerOriginDelete(ctx context.Context, id gidx return nil, ErrInternalServerError } - if err := permissions.CheckAccess(ctx, ogn.PoolID, actionLoadBalancerPoolUpdate); err != nil { + if err := permissions.CheckAccess(ctx, ogn.Edges.Pool.OwnerID, actionLoadBalancerPoolUpdate); err != nil { return nil, err } diff --git a/internal/graphapi/pool.resolvers.go b/internal/graphapi/pool.resolvers.go index 61fffca84..ebadbd842 100644 --- a/internal/graphapi/pool.resolvers.go +++ b/internal/graphapi/pool.resolvers.go @@ -86,10 +86,6 @@ func (r *mutationResolver) LoadBalancerPoolUpdate(ctx context.Context, id gidx.P return nil, err } - if err := permissions.CheckAccess(ctx, id, actionLoadBalancerPoolUpdate); err != nil { - return nil, err - } - pool, err := r.client.Pool.Get(ctx, id) if err != nil { if generated.IsNotFound(err) { @@ -100,6 +96,10 @@ func (r *mutationResolver) LoadBalancerPoolUpdate(ctx context.Context, id gidx.P return nil, ErrInternalServerError } + if err := permissions.CheckAccess(ctx, pool.OwnerID, actionLoadBalancerPoolUpdate); err != nil { + return nil, err + } + ports, err := r.client.Port.Query().Where(port.HasLoadBalancerWith(loadbalancer.OwnerIDEQ(pool.OwnerID))).Where(port.IDIn(input.AddPortIDs...)).All(ctx) if err != nil { logger.Errorw("failed to query input ports", "error", err) @@ -152,11 +152,8 @@ func (r *mutationResolver) LoadBalancerPoolDelete(ctx context.Context, id gidx.P return nil, err } - if err := permissions.CheckAccess(ctx, id, actionLoadBalancerPoolDelete); err != nil { - return nil, err - } - - if _, err := r.client.Pool.Get(ctx, id); err != nil { + p, err := r.client.Pool.Get(ctx, id) + if err != nil { if generated.IsNotFound(err) { return nil, err } @@ -165,6 +162,10 @@ func (r *mutationResolver) LoadBalancerPoolDelete(ctx context.Context, id gidx.P return nil, ErrInternalServerError } + if err := permissions.CheckAccess(ctx, p.OwnerID, actionLoadBalancerPoolDelete); err != nil { + return nil, err + } + tx, err := r.client.BeginTx(ctx, &sql.TxOptions{}) if err != nil { logger.Errorw("failed to begin transaction", "error", err) @@ -230,10 +231,6 @@ func (r *queryResolver) LoadBalancerPool(ctx context.Context, id gidx.PrefixedID return nil, err } - if err := permissions.CheckAccess(ctx, id, actionLoadBalancerPoolGet); err != nil { - return nil, err - } - pool, err := r.client.Pool.Get(ctx, id) if err != nil { if generated.IsNotFound(err) { @@ -244,5 +241,9 @@ func (r *queryResolver) LoadBalancerPool(ctx context.Context, id gidx.PrefixedID return nil, ErrInternalServerError } + if err := permissions.CheckAccess(ctx, pool.OwnerID, actionLoadBalancerPoolGet); err != nil { + return nil, err + } + return pool, nil } diff --git a/internal/graphapi/port.resolvers.go b/internal/graphapi/port.resolvers.go index 7a8102c67..1724c8d80 100644 --- a/internal/graphapi/port.resolvers.go +++ b/internal/graphapi/port.resolvers.go @@ -13,6 +13,7 @@ import ( "go.infratographer.com/load-balancer-api/internal/ent/generated" "go.infratographer.com/load-balancer-api/internal/ent/generated/pool" + "go.infratographer.com/load-balancer-api/internal/ent/generated/port" "go.infratographer.com/load-balancer-api/pkg/metadata" ) @@ -93,7 +94,7 @@ func (r *mutationResolver) LoadBalancerPortUpdate(ctx context.Context, id gidx.P return nil, err } - p, err := r.client.Port.Get(ctx, id) + p, err := r.client.Port.Query().WithLoadBalancer().Where(port.IDEQ(id)).Only(ctx) if err != nil { if generated.IsNotFound(err) { return nil, err @@ -103,7 +104,7 @@ func (r *mutationResolver) LoadBalancerPortUpdate(ctx context.Context, id gidx.P return nil, ErrInternalServerError } - if err := permissions.CheckAccess(ctx, p.LoadBalancerID, actionLoadBalancerUpdate); err != nil { + if err := permissions.CheckAccess(ctx, p.Edges.LoadBalancer.OwnerID, actionLoadBalancerUpdate); err != nil { return nil, err } @@ -164,7 +165,7 @@ func (r *mutationResolver) LoadBalancerPortDelete(ctx context.Context, id gidx.P return nil, err } - p, err := r.client.Port.Get(ctx, id) + p, err := r.client.Port.Query().WithLoadBalancer().Where(port.IDEQ(id)).Only(ctx) if err != nil { if generated.IsNotFound(err) { return nil, err @@ -174,7 +175,7 @@ func (r *mutationResolver) LoadBalancerPortDelete(ctx context.Context, id gidx.P return nil, ErrInternalServerError } - if err := permissions.CheckAccess(ctx, p.LoadBalancerID, actionLoadBalancerUpdate); err != nil { + if err := permissions.CheckAccess(ctx, p.Edges.LoadBalancer.OwnerID, actionLoadBalancerUpdate); err != nil { return nil, err } diff --git a/internal/manualhooks/hooks.go b/internal/manualhooks/hooks.go index e770b094f..0c3608e2f 100644 --- a/internal/manualhooks/hooks.go +++ b/internal/manualhooks/hooks.go @@ -1002,6 +1002,11 @@ func PortHooks() []ent.Hook { }) } + relationships = append(relationships, events.AuthRelationshipRelation{ + Relation: "loadbalancer", + SubjectID: load_balancer_id, + }) + msg := events.ChangeMessage{ EventType: eventType(m.Op()), SubjectID: objID, @@ -1089,6 +1094,11 @@ func PortHooks() []ent.Hook { additionalSubjects = append(additionalSubjects, dbObj.Edges.LoadBalancer.OwnerID) additionalSubjects = append(additionalSubjects, dbObj.Edges.LoadBalancer.ProviderID) + relationships = append(relationships, events.AuthRelationshipRelation{ + Relation: "loadbalancer", + SubjectID: dbObj.LoadBalancerID, + }) + // we have all the info we need, now complete the mutation before we process the event retValue, err := next.Mutate(ctx, m) if err != nil {