Skip to content

Commit

Permalink
Add server and client implementations for MPA. (#364)
Browse files Browse the repository at this point in the history
These changes are sufficient for MPA when using a direct connection to the server. Here's a few sample commands you can run in parallel to try it out.

```
go run ./cmd/sansshell-server
go run ./cmd/sanssh -client-cert ./auth/mtls/testdata/client.pem -client-key ./auth/mtls/testdata/client.key -mpa -targets localhost healthcheck validate
go run ./cmd/sanssh -client-cert ./services/mpa/testdata/approver.pem -client-key ./services/mpa/testdata/approver.key -targets localhost mpa approve a59c2fef-748944da-336c9d35
```

I've added some new testdata certs because I'm forbidding cases where approver == requester. I've updated the sansshell server code to allow any request if it's requested by our "normal" client cert and approved by our "approver" client cert.

The output of `-mpa` prints a nonconfigurable help message to stderr while waiting on approval. If the command is already approved, the message won't show up.

```
$ sanssh -mpa -targets localhost healthcheck validate
Multi party auth requested, ask an approver to run:
  sanssh --targets localhost:50042 mpa approve a59c2fef-748944da-336c9d35
Target localhost:50042 (0) healthy`
```

This implements the client and server portion, but not the proxy portion. The proxy part mostly builds on top of what I have here and will take advantage of some other features I'm implementing.

- #361 for implementing the proxy equivalent of `ServerMPAAuthzHook()`
- #358 for implementing the proxy equivalents of `mpahooks.UnaryClientIntercepter()` and `mpahooks.StreamClientIntercepter()`
- #359 so that MPA can use the identity of the caller to the proxy instead of the identity of the proxy.

Part of #346
  • Loading branch information
stvnrhodes authored Nov 10, 2023
1 parent aabd3f5 commit 5f1ff80
Show file tree
Hide file tree
Showing 16 changed files with 1,469 additions and 1 deletion.
22 changes: 21 additions & 1 deletion auth/opa/rpcauth/input.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ import (
)

// RPCAuthInput is used as policy input to validate Sansshell RPCs
// NOTE: RPCAuthInputForLogging must be updated when this changes.
type RPCAuthInput struct {
// The GRPC method name, as '/Package.Service/Method'
Method string `json:"method"`
Expand All @@ -52,6 +51,9 @@ type RPCAuthInput struct {
// Information about the host serving the RPC.
Host *HostAuthInput `json:"host"`

// Information about approvers when using multi-party authentication.
Approvers []*PrincipalAuthInput `json:"approvers"`

// Information about the environment in which the policy evaluation is
// happening.
Environment *EnvironmentInput `json:"environment"`
Expand Down Expand Up @@ -153,9 +155,27 @@ func NewRPCAuthInput(ctx context.Context, method string, req proto.Message) (*RP
return out, nil
}

type peerInfoKey struct{}

// AddPeerToContext adds a PeerAuthInput to the context. This is typically
// added by the rpcauth grpc interceptors.
func AddPeerToContext(ctx context.Context, p *PeerAuthInput) context.Context {
if p == nil {
return ctx
}
return context.WithValue(ctx, peerInfoKey{}, p)
}

// PeerInputFromContext populates peer information from the supplied
// context, if available.
func PeerInputFromContext(ctx context.Context) *PeerAuthInput {
// If this runs after rpcauth hooks, we can return richer data that includes
// information added by the hooks.
cached, ok := ctx.Value(peerInfoKey{}).(*PeerAuthInput)
if ok {
return cached
}

out := &PeerAuthInput{}
p, ok := peer.FromContext(ctx)
if !ok {
Expand Down
29 changes: 29 additions & 0 deletions auth/opa/rpcauth/rpcauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"context"
"fmt"
"strings"
"sync"

"github.com/go-logr/logr"
"go.opentelemetry.io/otel/attribute"
Expand Down Expand Up @@ -171,6 +172,7 @@ func (g *Authorizer) Authorize(ctx context.Context, req interface{}, info *grpc.
if err := g.Eval(ctx, authInput); err != nil {
return nil, err
}
ctx = AddPeerToContext(ctx, authInput.Peer)
return handler(ctx, req)
}

Expand All @@ -187,6 +189,7 @@ func (g *Authorizer) AuthorizeClient(ctx context.Context, method string, req, re
if err := g.Eval(ctx, authInput); err != nil {
return err
}
ctx = AddPeerToContext(ctx, authInput.Peer)
return invoker(ctx, method, req, reply, cc, opts...)
}

Expand All @@ -209,6 +212,16 @@ type wrappedClientStream struct {
grpc.ClientStream
method string
authz *Authorizer

peerMu sync.Mutex
lastPeerAuthInput *PeerAuthInput
}

func (e *wrappedClientStream) Context() context.Context {
e.peerMu.Lock()
ctx := AddPeerToContext(e.ClientStream.Context(), e.lastPeerAuthInput)
e.peerMu.Unlock()
return ctx
}

// see: grpc.ClientStream.SendMsg
Expand All @@ -225,6 +238,9 @@ func (e *wrappedClientStream) SendMsg(req interface{}) error {
if err := e.authz.Eval(ctx, authInput); err != nil {
return err
}
e.peerMu.Lock()
e.lastPeerAuthInput = authInput.Peer
e.peerMu.Unlock()
return e.ClientStream.SendMsg(req)
}

Expand All @@ -243,6 +259,16 @@ type wrappedStream struct {
grpc.ServerStream
info *grpc.StreamServerInfo
authz *Authorizer

peerMu sync.Mutex
lastPeerAuthInput *PeerAuthInput
}

func (e *wrappedStream) Context() context.Context {
e.peerMu.Lock()
ctx := AddPeerToContext(e.ServerStream.Context(), e.lastPeerAuthInput)
e.peerMu.Unlock()
return ctx
}

// see: grpc.ServerStream.RecvMsg
Expand All @@ -266,5 +292,8 @@ func (e *wrappedStream) RecvMsg(req interface{}) error {
if err := e.authz.Eval(ctx, authInput); err != nil {
return err
}
e.peerMu.Lock()
e.lastPeerAuthInput = authInput.Peer
e.peerMu.Unlock()
return nil
}
7 changes: 7 additions & 0 deletions cmd/sanssh/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import (
"github.com/Snowflake-Labs/sansshell/proxy/proxy"

cmdUtil "github.com/Snowflake-Labs/sansshell/cmd/util"
"github.com/Snowflake-Labs/sansshell/services/mpa/mpahooks"
"github.com/Snowflake-Labs/sansshell/services/util"
)

Expand Down Expand Up @@ -74,6 +75,8 @@ type RunState struct {
// BatchSize if non-zero will do the requested operation to the targets but in
// N calls to the proxy where N is the target list size divided by BatchSize.
BatchSize int
// If true, add an interceptor that performs the multi-party auth flow
EnableMPA bool
}

const (
Expand Down Expand Up @@ -317,6 +320,10 @@ func Run(ctx context.Context, rs RunState) {
streamInterceptors = append(streamInterceptors, clientAuthz.AuthorizeClientStream)
unaryInterceptors = append(unaryInterceptors, clientAuthz.AuthorizeClient)
}
if rs.EnableMPA {
unaryInterceptors = append(unaryInterceptors, mpahooks.UnaryClientIntercepter())
streamInterceptors = append(streamInterceptors, mpahooks.StreamClientIntercepter())
}
// timeout interceptor should be the last item in ops so that it's executed first.
streamInterceptors = append(streamInterceptors, StreamClientTimeoutInterceptor(rs.IdleTimeout))
unaryInterceptors = append(unaryInterceptors, UnaryClientTimeoutInterceptor(rs.IdleTimeout))
Expand Down
4 changes: 4 additions & 0 deletions cmd/sanssh/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ import (
_ "github.com/Snowflake-Labs/sansshell/services/healthcheck/client"
_ "github.com/Snowflake-Labs/sansshell/services/httpoverrpc/client"
_ "github.com/Snowflake-Labs/sansshell/services/localfile/client"
_ "github.com/Snowflake-Labs/sansshell/services/mpa/client"
_ "github.com/Snowflake-Labs/sansshell/services/packages/client"
_ "github.com/Snowflake-Labs/sansshell/services/power/client"
_ "github.com/Snowflake-Labs/sansshell/services/process/client"
Expand Down Expand Up @@ -84,6 +85,7 @@ If port is blank the default of %d will be used`, proxyEnv, defaultProxyPort))
verbosity = flag.Int("v", -1, "Verbosity level. > 0 indicates more extensive logging")
prefixHeader = flag.Bool("h", false, "If true prefix each line of output with '<index>-<target>: '")
batchSize = flag.Int("batch-size", 0, "If non-zero will perform the proxy->target work in batches of this size (with any remainder done at the end).")
mpa = flag.Bool("mpa", false, "Request multi-party approval for commands. This will create an MPA request, wait for approval, and then execute the command.")

// targets will be bound to --targets for sending a single request to N nodes.
targetsFlag util.StringSliceCommaOrWhitespaceFlag
Expand Down Expand Up @@ -118,6 +120,7 @@ func init() {
subcommands.ImportantFlag("justification")
subcommands.ImportantFlag("client-policy")
subcommands.ImportantFlag("client-policy-file")
subcommands.ImportantFlag("mpa")
subcommands.ImportantFlag("v")
}

Expand Down Expand Up @@ -192,6 +195,7 @@ func main() {
ClientPolicy: clientPolicy,
PrefixOutput: *prefixHeader,
BatchSize: *batchSize,
EnableMPA: *mpa,
}
ctx := logr.NewContext(context.Background(), logger)

Expand Down
10 changes: 10 additions & 0 deletions cmd/sansshell-server/default-policy.rego
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,13 @@ allow {
allow {
input.method = "/SysInfo.SysInfo/Dmesg"
}

# Allow anything with MPA
allow {
input.peer.principal.id = "sanssh"
input.approvers[_].id = "approver"
}

allow {
startswith(input.method, "/Mpa.Mpa/")
}
2 changes: 2 additions & 0 deletions cmd/sansshell-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ import (
fdbserver "github.com/Snowflake-Labs/sansshell/services/fdb/server"
_ "github.com/Snowflake-Labs/sansshell/services/healthcheck/server"
_ "github.com/Snowflake-Labs/sansshell/services/localfile/server"
mpa "github.com/Snowflake-Labs/sansshell/services/mpa/server"
_ "github.com/Snowflake-Labs/sansshell/services/power/server"

// Packages needs a real import to bind flags.
Expand Down Expand Up @@ -171,6 +172,7 @@ func main() {
server.WithParsedPolicy(parsed),
server.WithJustification(*justification),
server.WithAuthzHook(rpcauth.PeerPrincipalFromCertHook()),
server.WithAuthzHook(mpa.ServerMPAAuthzHook()),
server.WithRawServerOption(func(s *grpc.Server) { reflection.Register(s) }),
server.WithRawServerOption(func(s *grpc.Server) { channelz.RegisterChannelzServiceToServer(s) }),
server.WithDebugPort(*debugport),
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ require (
github.com/go-logr/stdr v1.2.2
github.com/google/go-cmp v0.6.0
github.com/google/subcommands v1.2.0
github.com/gowebpki/jcs v1.0.1
github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus v1.0.0
github.com/open-policy-agent/opa v0.58.0
github.com/pkg/errors v0.9.1
Expand Down
2 changes: 2 additions & 0 deletions go.sum

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 5f1ff80

Please sign in to comment.