Skip to content

Commit

Permalink
checkpoint again
Browse files Browse the repository at this point in the history
  • Loading branch information
stvnrhodes committed Sep 30, 2023
1 parent dad1f54 commit c800db1
Show file tree
Hide file tree
Showing 11 changed files with 378 additions and 161 deletions.
16 changes: 16 additions & 0 deletions auth/opa/rpcauth/hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,19 @@ func JustificationHook(justificationFunc func(string) error) RPCAuthzHook {
return nil
})
}

// PeerPrincipalFromCertHook returns an RPCAuthzHook that sets principal
// information based on the peer's certificate, using the common name as
// the id and the organizational units as the groups.
func PeerPrincipalFromCertHook() RPCAuthzHook {
return RPCAuthzHookFunc(func(_ context.Context, input *RPCAuthInput) error {
if input.Peer == nil || input.Peer.Cert == nil {
return nil
}
input.Host.Principal = &PrincipalAuthInput{
ID: input.Peer.Cert.Subject.CommonName,
Groups: input.Peer.Cert.Subject.OrganizationalUnit,
}
return nil
})
}
2 changes: 0 additions & 2 deletions auth/opa/rpcauth/input.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,6 @@ func PeerInputFromContext(ctx context.Context) *PeerAuthInput {
}
out.Net = NetInputFromAddr(p.Addr)
out.Cert = CertInputFrom(p.AuthInfo)
// DO NOT SUBMIT
out.Principal = &PrincipalAuthInput{ID: out.Cert.Subject.CommonName}
return out
}

Expand Down
1 change: 1 addition & 0 deletions cmd/proxy-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ func main() {
server.WithCredSource(*credSource),
server.WithHostPort(*hostport),
server.WithJustification(*justification),
server.WithAuthzHook(rpcauth.PeerPrincipalFromCertHook()),
server.WithRawServerOption(func(s *grpc.Server) { reflection.Register(s) }),
server.WithRawServerOption(func(s *grpc.Server) { channelz.RegisterChannelzServiceToServer(s) }),
server.WithRawServerOption(srv.Register),
Expand Down
4 changes: 4 additions & 0 deletions cmd/sansshell-server/default-policy.rego
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ allow {
input.method = "/Mpa.Mpa/Approve"
}

allow {
input.method = "/Mpa.Mpa/List"
}

allow {
input.type = "LocalFile.ReadActionRequest"
input.message.file.filename = "/etc/hosts"
Expand Down
1 change: 1 addition & 0 deletions cmd/sansshell-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ func main() {
server.WithHostPort(*hostport),
server.WithParsedPolicy(parsed),
server.WithJustification(*justification),
server.WithAuthzHook(rpcauth.PeerPrincipalFromCertHook()),
server.WithRawServerOption(func(s *grpc.Server) { reflection.Register(s) }),
server.WithRawServerOption(func(s *grpc.Server) { channelz.RegisterChannelzServiceToServer(s) }),
server.WithDebugPort(*debugport),
Expand Down
37 changes: 32 additions & 5 deletions proxy/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@ import (
proxypb "github.com/Snowflake-Labs/sansshell/proxy"
)

// Interceptor intercepts the execution of an RPC through the proxy. Interceptors
// can be added to a Conn by modifying its Interceptors field. When an interceptor
// is set on a ClientConn, all proxy RPC invocations are delegated to the interceptor,
// and it is the responsibility of the interceptor to call invoker to complete the
// processing of the RPC.
type Interceptor func(ctx context.Context, conn *Conn, method string, args any, invoker Invoker, opts ...grpc.CallOption) (<-chan *Ret, error)

// Invoker is called by Interceptor to complete RPCs.
type Invoker func(ctx context.Context, method string, args any, opts ...grpc.CallOption) (<-chan *Ret, error)

// Conn is a grpc.ClientConnInterface which is connected to the proxy
// converting calls into RPC the proxy understands.
type Conn struct {
Expand All @@ -55,6 +65,11 @@ type Conn struct {

// If this is true we're not proxy but instead direct connect.
direct bool

// Interceptors allow intercepting Invoke and InvokeOneMany calls
// that go through a proxy.
// It is unsafe to modify Intercepters while calls are in progress.
Interceptors []Interceptor
}

// Ret defines the internal API for getting responses from the proxy.
Expand Down Expand Up @@ -91,7 +106,7 @@ type proxyStream struct {
}

// Invoke - see grpc.ClientConnInterface
func (p *Conn) Invoke(ctx context.Context, method string, args interface{}, reply interface{}, opts ...grpc.CallOption) error {
func (p *Conn) Invoke(ctx context.Context, method string, args any, reply any, opts ...grpc.CallOption) error {
if p.Direct() {
// TODO(jchacon): Add V1 style logging indicating pass through in use.
return p.cc.Invoke(ctx, method, args, reply, opts...)
Expand Down Expand Up @@ -247,7 +262,7 @@ func (p *proxyStream) closeClients() error {
}

// see grpc.ClientStream
func (p *proxyStream) SendMsg(args interface{}) error {
func (p *proxyStream) SendMsg(args any) error {
if p.sendClosed {
return status.Error(codes.FailedPrecondition, "sending on a closed connection")
}
Expand All @@ -262,13 +277,13 @@ func (p *proxyStream) SendMsg(args interface{}) error {
}

// see grpc.ClientStream
func (p *proxyStream) RecvMsg(m interface{}) error {
func (p *proxyStream) RecvMsg(m any) error {
// Up front check for nothing left since we closed all streams.
if len(p.ids) == 0 {
return io.EOF
}

// Since the API is an interface{} we can change what this normally
// Since the API is an any we can change what this normally
// expects from a proto.Message to a *[]*Ret instead.
//
// Anything else is an error if we have > 1 target. In the one target
Expand Down Expand Up @@ -527,7 +542,19 @@ func (p *Conn) createStreams(ctx context.Context, method string) (proxypb.Proxy_
// NOTE: The returned channel must be read until it closes in order to avoid leaking goroutines.
//
// TODO(jchacon): Should add the ability to specify remote dial timeout in the connection to the proxy.
func (p *Conn) InvokeOneMany(ctx context.Context, method string, args interface{}, opts ...grpc.CallOption) (<-chan *Ret, error) {
func (p *Conn) InvokeOneMany(ctx context.Context, method string, args any, opts ...grpc.CallOption) (<-chan *Ret, error) {
invoke := p.invokeOneMany
for _, intercept := range p.Interceptors {
intercept := intercept
inner := invoke
invoke = func(ctx context.Context, method string, args any, opts ...grpc.CallOption) (<-chan *Ret, error) {
return intercept(ctx, p, method, args, inner, opts...)
}
}
return invoke(ctx, method, args, opts...)
}

func (p *Conn) invokeOneMany(ctx context.Context, method string, args any, opts ...grpc.CallOption) (<-chan *Ret, error) {
requestMsg, ok := args.(proto.Message)
if !ok {
return nil, status.Error(codes.InvalidArgument, "args must be a proto.Message")
Expand Down
42 changes: 34 additions & 8 deletions services/mpa/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"flag"
"fmt"
"os"
"strings"

"github.com/Snowflake-Labs/sansshell/client"
pb "github.com/Snowflake-Labs/sansshell/services/mpa"
Expand Down Expand Up @@ -136,7 +137,9 @@ func (p *approveCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...inter
return subcommands.ExitSuccess
}

type listCmd struct{}
type listCmd struct {
verbose bool
}

func (*listCmd) Name() string { return "list" }
func (*listCmd) Synopsis() string { return "Lists out pending MPA requests on machines" }
Expand All @@ -146,12 +149,14 @@ func (*listCmd) Usage() string {
`
}

func (p *listCmd) SetFlags(f *flag.FlagSet) {}
func (p *listCmd) SetFlags(f *flag.FlagSet) {
f.BoolVar(&p.verbose, "v", false, "Verbose: list full details of MPA request")
}

func (p *listCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
state := args[0].(*util.ExecuteState)
if f.NArg() != 1 {
fmt.Fprintln(os.Stderr, "Please specify an ID to approve.")
if f.NArg() != 0 {
fmt.Fprintln(os.Stderr, "List takes no args.")
return subcommands.ExitUsageError
}

Expand All @@ -167,11 +172,33 @@ func (p *listCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interfac
}
for r := range resp {
if r.Error != nil {
fmt.Fprintf(state.Err[r.Index], "Failed to wait for approval: %v\n", r.Error)
fmt.Fprintln(state.Err[r.Index], r.Error)
continue
}
for _, id := range r.Resp.Id {
fmt.Fprintln(state.Out[r.Index], id)
for _, item := range r.Resp.Item {
msg := []string{item.Id}
if p.verbose {
if len(item.Approver) > 0 {
var approvers []string
for _, a := range item.Approver {
approvers = append(approvers, a.Id)
}
msg = append(msg, fmt.Sprintf("(approved by %v)", strings.Join(approvers, ",")))
}
msg = append(msg, protojson.MarshalOptions{UseProtoNames: true}.Format(item.Action))
} else {
msg = append(msg, item.Action.GetMethod())
if item.Action.GetUser() != "" {
msg = append(msg, "from", item.Action.GetUser())
}
if item.Action.GetJustification() != "" {
msg = append(msg, "for", item.Action.GetJustification())
}
if len(item.Approver) > 0 {
msg = append(msg, "(approved)")
}
}
fmt.Fprintln(state.Out[r.Index], strings.Join(msg, " "))
}
}
return subcommands.ExitSuccess
Expand Down Expand Up @@ -257,7 +284,6 @@ func (p *getCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface
fmt.Fprintf(state.Err[r.Index], "Error: action was nil when looking up MPA request")
continue
}
// DO NOT SUBMIT: Pretty print message
fmt.Fprintln(state.Out[r.Index], protojson.MarshalOptions{UseProtoNames: true, Multiline: true}.Format(r.Resp))
}
return subcommands.ExitSuccess
Expand Down
Loading

0 comments on commit c800db1

Please sign in to comment.