Skip to content

Commit

Permalink
chore: Improve handling of auth data in context (#262)
Browse files Browse the repository at this point in the history
Signed-off-by: jannfis <[email protected]>
  • Loading branch information
jannfis authored Dec 20, 2024
1 parent 354cd77 commit 4adae7a
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 18 deletions.
27 changes: 24 additions & 3 deletions internal/session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,39 @@ import (
"fmt"

"github.com/argoproj-labs/argocd-agent/pkg/types"
"k8s.io/apimachinery/pkg/api/validation"
)

/*
Package session contains various functions to access and manipulate session
data.
*/

// ClientIdFromContext returns the client ID stored in context ctx. If there
// is no client ID in the context, or the client ID is invalid, returns an
// error.
func ClientIdFromContext(ctx context.Context) (string, error) {
val := ctx.Value(types.ContextAgentIdentifier)
if clientId, ok := val.(string); !ok {
clientId, ok := val.(string)
if !ok {
return "", fmt.Errorf("no client identifier found in context")
} else {
return clientId, nil
}
if !IsValidClientId(clientId) {
return "", fmt.Errorf("invalid client identifier: %s", clientId)
}
return clientId, nil
}

// ClientIdToContext returns a copy of context ctx with the clientId stored
func ClientIdToContext(ctx context.Context, clientId string) context.Context {
return context.WithValue(ctx, types.ContextAgentIdentifier, clientId)
}

// IsValidClientId returns true if the string s is considered a valid client
// identifier.
func IsValidClientId(s string) bool {
if errs := validation.NameIsDNSSubdomain(s, false); len(errs) > 0 {
return false
}
return true
}
28 changes: 28 additions & 0 deletions internal/session/session_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package session

import (
"context"
"testing"

"github.com/stretchr/testify/assert"
)

func Test_ClientIdFromContext(t *testing.T) {
t.Run("Successfully extract client ID", func(t *testing.T) {
ctx := ClientIdToContext(context.Background(), "agent")
a, err := ClientIdFromContext(ctx)
assert.NoError(t, err)
assert.Equal(t, "agent", a)
})
t.Run("No client ID in context", func(t *testing.T) {
a, err := ClientIdFromContext(context.Background())
assert.ErrorContains(t, err, "no client identifier")
assert.Empty(t, a)
})
t.Run("Invalid client ID in context", func(t *testing.T) {
ctx := ClientIdToContext(context.Background(), "ag_ent")
a, err := ClientIdFromContext(ctx)
assert.ErrorContains(t, err, "invalid client identifier")
assert.Empty(t, a)
})
}
17 changes: 3 additions & 14 deletions principal/apis/eventstream/eventstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ import (

"github.com/argoproj-labs/argocd-agent/internal/event"
"github.com/argoproj-labs/argocd-agent/internal/queue"
"github.com/argoproj-labs/argocd-agent/internal/session"
"github.com/argoproj-labs/argocd-agent/pkg/api/grpc/eventstreamapi"
"github.com/argoproj-labs/argocd-agent/pkg/types"
"github.com/argoproj/argo-cd/v2/pkg/apis/application/v1alpha1"
"github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
Expand Down Expand Up @@ -136,7 +136,7 @@ func (s *Server) newClientConnection(ctx context.Context, timeout time.Duration)
c := &client{}
c.wg = &sync.WaitGroup{}

agentName, err := agentName(ctx)
agentName, err := session.ClientIdFromContext(ctx)
if err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
Expand All @@ -160,17 +160,6 @@ func (s *Server) newClientConnection(ctx context.Context, timeout time.Duration)
return c, nil
}

// agentName gets the agent name from the context ctx. If no agent identifier
// could be found in the context, returns an error.
func agentName(ctx context.Context) (string, error) {
agentName, ok := ctx.Value(types.ContextAgentIdentifier).(string)
if !ok {
return "", fmt.Errorf("invalid context: no agent name")
}
// TODO: check agentName for validity
return agentName, nil
}

// onDisconnect must be called whenever client c disconnects from the stream
func (s *Server) onDisconnect(c *client) {
c.lock.Lock()
Expand Down Expand Up @@ -379,7 +368,7 @@ func (s *Server) Push(pushs eventstreamapi.EventStream_PushServer) error {
}
defer cancel()

agentName, err := agentName(ctx)
agentName, err := session.ClientIdFromContext(ctx)
if err != nil {
return status.Error(codes.InvalidArgument, err.Error())
}
Expand Down
3 changes: 2 additions & 1 deletion principal/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (

"github.com/argoproj-labs/argocd-agent/internal/auth"
"github.com/argoproj-labs/argocd-agent/internal/grpcutil"
"github.com/argoproj-labs/argocd-agent/internal/session"
"github.com/argoproj-labs/argocd-agent/pkg/types"
middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
"google.golang.org/grpc"
Expand Down Expand Up @@ -117,7 +118,7 @@ func (s *Server) authenticate(ctx context.Context) (context.Context, error) {

// claims at this point is validated and we can propagate values to the
// context.
authCtx := context.WithValue(ctx, types.ContextAgentIdentifier, agentInfo.ClientID)
authCtx := session.ClientIdToContext(ctx, agentInfo.ClientID)
if !s.queues.HasQueuePair(agentInfo.ClientID) {
logCtx.Tracef("Creating a new queue pair for client %s", agentInfo.ClientID)
if err := s.queues.Create(agentInfo.ClientID); err != nil {
Expand Down

0 comments on commit 4adae7a

Please sign in to comment.