From ed01e8b43361f43830ba99e6dedeb59e5f42fb41 Mon Sep 17 00:00:00 2001
From: Steven Rhodes
Date: Thu, 14 Dec 2023 17:04:02 -0800
Subject: [PATCH] Refactor fdbmovedata with more concurrency (#388)

This code was originally written to work with only a single caller. It took a
lock for the entire call, and no other callers could make progress until the
lock was released. This was done for simplicity of implementation, but turned
out to be too simple.

If the caller of FDBMoveDataWait died, the context would be cancelled and one
of two scenarios could occur.

- If the command was writing to only stdout, the stdout goroutine would exit
  due to the failure to send on the return stream. The stderr goroutine would
  continue to run, waiting for some stderr to get written. For long-running
  commands, the pipe buffer (~64KB) would eventually fill up and the child
  process would stall, unable to make progress.
- If the command was writing to both stdout and stderr, both goroutines would
  exit when they failed to send on the stream, and the handler would reach
  `cmd.Wait`. Long-running commands hit the same failure mode. Because the
  server held the lock for the entire wait call, everything else got stuck
  behind it.

I've refactored to support a few different features:

- Locks are now taken in a much more fine-grained way. There shouldn't be any
  locks that last longer than it takes to spawn a process.
- Multiple FDBMoveDataWait calls are supported. Each call will stream back
  stdout/stderr. We'll still forget about a command as soon as it finishes.
- Stdout and stderr are collected regardless of whether FDBMoveDataWait is
  running. We retain a maximum of 1MB of each to prevent this from consuming
  unlimited memory.

We still don't support running multiple FDBMoveDataCopy commands at once. If
we want to do so, all we need to do is remove the check on
`len(s.operations)` and update the tests.
---
 services/fdb/server/fdbmovedata.go      | 240 +++++++++++++++++-------
 services/fdb/server/fdbmovedata_test.go |  64 ++++++-
 services/fdb/server/server_test.go      |   2 +-
 3 files changed, 231 insertions(+), 75 deletions(-)

diff --git a/services/fdb/server/fdbmovedata.go b/services/fdb/server/fdbmovedata.go
index 19e7c03d..7c84905f 100644
--- a/services/fdb/server/fdbmovedata.go
+++ b/services/fdb/server/fdbmovedata.go
@@ -27,9 +27,11 @@ import (
 	"github.com/Snowflake-Labs/sansshell/services"
 	pb "github.com/Snowflake-Labs/sansshell/services/fdb"
 	"github.com/go-logr/logr"
+	"golang.org/x/sync/errgroup"
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/status"
+	"google.golang.org/protobuf/proto"
 )
 
 var (
@@ -40,12 +42,88 @@ var (
 	generateFDBMoveDataArgs = generateFDBMoveDataArgsImpl
 )
 
+const maxHistoryBytes = 1024 * 1024
+
+// reader allows reading incoming data from a multiReader
+type reader struct {
+	next   chan []byte
+	parent *multiReader
+}
+
+func (r *reader) Next(ctx context.Context) ([]byte, error) {
+	select {
+	case <-ctx.Done():
+		return nil, ctx.Err()
+	case b := <-r.next:
+		return b, nil
+	case <-r.parent.finished:
+		return nil, r.parent.finalErr
+	}
+}
+
+// multiReader continuously reads from the provided io.Reader and provides ways
+// for multiple callers to read from the data it provides. It internally retains
+// maxHistoryBytes of past data.
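+// Slow readers may miss chunks: sends to a reader's buffered channel are
+// non-blocking, and once the history grows past maxHistoryBytes the oldest
+// maxHistoryBytes/2 bytes are dropped.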
+type multiReader struct {
+	src      io.Reader
+	history  []byte
+	readers  []*reader
+	finalErr error
+	finished chan struct{}
+	mu       sync.Mutex
+}
+
+func newMultiReader(src io.Reader) *multiReader {
+	m := &multiReader{src: src, finished: make(chan struct{})}
+	go m.backgroundRead()
+	return m
+}
+
+func (m *multiReader) backgroundRead() {
+	for {
+		b := make([]byte, 512)
+		n, err := m.src.Read(b)
+		if err != nil {
+			m.finalErr = err
+			close(m.finished)
+			return
+		}
+		b = b[:n]
+		m.mu.Lock()
+		for _, r := range m.readers {
+			select {
+			case r.next <- b:
+			default:
+			}
+		}
+		m.history = append(m.history, b...)
+		if len(m.history) > maxHistoryBytes {
+			m.history = m.history[maxHistoryBytes/2:]
+		}
+		m.mu.Unlock()
+	}
+}
+
+// Reader returns the bytes read so far and a reader to use for subsequent reads.
+func (m *multiReader) Reader() ([]byte, *reader) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	r := &reader{next: make(chan []byte, 10), parent: m}
+	m.readers = append(m.readers, r)
+	return m.history, r
+}
+
+type moveOperation struct {
+	req     *pb.FDBMoveDataCopyRequest
+	stdout  *multiReader
+	stderr  *multiReader
+	done    chan struct{}
+	exitErr *exec.ExitError
+}
+
 type fdbmovedata struct {
-	mu     sync.Mutex
-	id     int64
-	cmd    *exec.Cmd
-	stdout io.ReadCloser
-	stderr io.ReadCloser
+	mu         sync.Mutex
+	operations map[int64]*moveOperation
 }
 
 func generateFDBMoveDataArgsImpl(req *pb.FDBMoveDataCopyRequest) ([]string, error) {
@@ -63,24 +141,22 @@ func generateFDBMoveDataArgsImpl(req *pb.FDBMoveDataCopyRequest) ([]string, erro
 }
 func (s *fdbmovedata) FDBMoveDataCopy(ctx context.Context, req *pb.FDBMoveDataCopyRequest) (*pb.FDBMoveDataCopyResponse, error) {
-	lockSuccess := s.mu.TryLock()
-	if !(lockSuccess) {
-		return nil, status.Errorf(codes.Internal, "Copy or Wait command already running")
-	}
+	s.mu.Lock()
 	defer s.mu.Unlock()
-	logger := logr.FromContextOrDiscard(ctx)
 	// The sansshell server should only run one copy command at a time
-	if !(s.cmd == nil) {
returning early") - logger.Info("command details", "cmd", s.cmd.String()) - logger.Info("command running with id", "id", s.id) - earlyresp := &pb.FDBMoveDataCopyResponse{ - Id: s.id, - Existing: true, + for id, o := range s.operations { + if proto.Equal(o.req, req) { + return &pb.FDBMoveDataCopyResponse{ + Id: id, + Existing: true, + }, nil } - return earlyresp, nil } - s.id = rand.Int63() + if len(s.operations) > 0 { + return nil, status.Errorf(codes.Internal, "Copy command already running") + } + logger := logr.FromContextOrDiscard(ctx) + id := rand.Int63() command, err := generateFDBMoveDataArgs(req) if err != nil { return nil, err @@ -103,81 +179,109 @@ func (s *fdbmovedata) FDBMoveDataCopy(ctx context.Context, req *pb.FDBMoveDataCo } logger.Info("executing local command", "cmd", cmd.String()) - logger.Info("command running with id", "id", s.id) - s.cmd = cmd - s.stdout = stdout - s.stderr = stderr - err = s.cmd.Start() - if err != nil { - s.cmd = nil - s.id = 0 + logger.Info("command running with id", "id", id) + if err = cmd.Start(); err != nil { return nil, status.Errorf(codes.Internal, "error running fdbmovedata cmd (%+v): %v", command, err) } + op := &moveOperation{ + req: req, + done: make(chan struct{}), + stdout: newMultiReader(stdout), + stderr: newMultiReader(stderr), + } + s.operations[id] = op + go func() { + // Wait for output to be done, then check command status + <-op.stdout.finished + <-op.stderr.finished + err := cmd.Wait() + if exitErr, ok := err.(*exec.ExitError); ok { + op.exitErr = exitErr + } + logger.Info("fdbmovedata command finished", "id", id, "err", err) + close(op.done) + }() + resp := &pb.FDBMoveDataCopyResponse{ - Id: s.id, + Id: id, Existing: false, } return resp, nil } func (s *fdbmovedata) FDBMoveDataWait(req *pb.FDBMoveDataWaitRequest, stream pb.FDBMove_FDBMoveDataWaitServer) error { - s.mu.Lock() - defer s.mu.Unlock() ctx := stream.Context() logger := logr.FromContextOrDiscard(ctx) - if !(req.Id == s.id) { - logger.Info("Provided ID and stored ID do not match", "providedID", req.Id, "storedID", s.id) - return status.Errorf(codes.Internal, "Provided ID %d does not match stored ID %d", req.Id, s.id) + + s.mu.Lock() + op, found := s.operations[req.Id] + var ids []int64 + for id := range s.operations { + ids = append(ids, id) } - if s.cmd == nil { - logger.Info("No command running on the server") - return status.Errorf(codes.Internal, "No command running on the server") + s.mu.Unlock() + if !found { + logger.Info("Provided ID and stored IDs do not match", "providedID", req.Id, "storedID", ids) + return status.Errorf(codes.Internal, "Provided ID %d does not match stored IDs %v", req.Id, ids) } - wg := &sync.WaitGroup{} - // Send stderr asynchronously - stderr := s.stderr - wg.Add(1) - go func() { - defer wg.Done() + stdoutHistory, stdout := op.stdout.Reader() + stderrHistory, stderr := op.stderr.Reader() + if stdoutHistory != nil || stderrHistory != nil { + if err := stream.Send(&pb.FDBMoveDataWaitResponse{Stdout: stdoutHistory, Stderr: stderrHistory}); err != nil { + return err + } + } + + // Send output asynchronously + g, gCtx := errgroup.WithContext(ctx) + g.Go(func() error { for { - buf := make([]byte, 1024) - n, err := stderr.Read(buf) + buf, err := stderr.Next(gCtx) if err != nil { - return + if err == io.EOF { + return nil + } + return fmt.Errorf("could not read stdout: %v", err) } - if err := stream.Send(&pb.FDBMoveDataWaitResponse{Stderr: buf[:n]}); err != nil { - return + if err := stream.Send(&pb.FDBMoveDataWaitResponse{Stderr: buf}); err != 
+				return err
 			}
 		}
-	}()
-
-	// Send stdout asynchronously
-	stdout := s.stdout
-	wg.Add(1)
-	go func() {
-		defer wg.Done()
+	})
+	g.Go(func() error {
 		for {
-			buf := make([]byte, 1024)
-			n, err := stdout.Read(buf)
+			buf, err := stdout.Next(gCtx)
 			if err != nil {
-				return
+				if err == io.EOF {
+					return nil
+				}
+				return fmt.Errorf("could not read stdout: %v", err)
 			}
-			if err := stream.Send(&pb.FDBMoveDataWaitResponse{Stdout: buf[:n]}); err != nil {
-				return
+			if err := stream.Send(&pb.FDBMoveDataWaitResponse{Stdout: buf}); err != nil {
+				return err
 			}
 		}
-	}()
-	wg.Wait()
-	err := s.cmd.Wait()
-	if exitErr, ok := err.(*exec.ExitError); ok {
-		return stream.Send(&pb.FDBMoveDataWaitResponse{RetCode: int32(exitErr.ExitCode())})
+	})
+
+	if err := g.Wait(); err != nil {
+		return err
+	}
+	select {
+	case <-ctx.Done():
+		return ctx.Err()
+	case <-op.done:
 	}
 	// clear the cmd to allow another call
-	s.cmd = nil
-	s.id = 0
-	return err
+	s.mu.Lock()
+	delete(s.operations, req.Id)
+	s.mu.Unlock()
+
+	if op.exitErr != nil {
+		return stream.Send(&pb.FDBMoveDataWaitResponse{RetCode: int32(op.exitErr.ExitCode())})
+	}
+	return nil
 }
 
 // Register is called to expose this handler to the gRPC server
@@ -186,5 +290,5 @@ func (s *fdbmovedata) Register(gs *grpc.Server) {
 }
 
 func init() {
-	services.RegisterSansShellService(&fdbmovedata{})
+	services.RegisterSansShellService(&fdbmovedata{operations: make(map[int64]*moveOperation)})
 }
diff --git a/services/fdb/server/fdbmovedata_test.go b/services/fdb/server/fdbmovedata_test.go
index 8d6ca990..e58df073 100644
--- a/services/fdb/server/fdbmovedata_test.go
+++ b/services/fdb/server/fdbmovedata_test.go
@@ -20,6 +20,7 @@ import (
 	"context"
 	"fmt"
 	"io"
+	"time"
 	"strings"
 	"testing"
 
@@ -93,7 +94,7 @@ func TestFDBMoveData(t *testing.T) {
 			}
 			_, err = waitResp.Recv()
 			if err != io.EOF {
-				t.Error("unexpected")
+				t.Errorf("unexpected err %v", err)
 			}
 		})
 	}
@@ -118,7 +119,7 @@ func TestFDBMoveDataDouble(t *testing.T) {
 
 	generateFDBMoveDataArgs = func(req *pb.FDBMoveDataCopyRequest) ([]string, error) {
 		_, err = savedGenerateFDBMoveDataArgs(req)
-		return []string{sh, "-c", "/bin/sleep 1; echo done"}, err
+		return []string{sh, "-c", "/bin/sleep 0.1; echo done"}, err
 	}
 	for _, tc := range []struct {
 		name string
@@ -161,14 +162,15 @@ func TestFDBMoveDataDouble(t *testing.T) {
 			waitResp1, err1 := client1.FDBMoveDataWait(ctx, waitReq)
 			testutil.FatalOnErr("fdbmovedata wait1 failed", err1, t)
 			for _, want1 := range tc.outputWait {
-				rs, err1 := waitResp1.Recv()
-				if err1 != io.EOF {
-					testutil.FatalOnErr("fdbmovedata wait1 failed", err1, t)
-				}
+				rs, err := waitResp1.Recv()
+				testutil.FatalOnErr("fdbmovedata wait1 failed", err, t)
 				if !(proto.Equal(want1, rs)) {
 					t.Errorf("want: %v, got: %v", want1, rs)
 				}
 			}
+			if _, err := waitResp1.Recv(); err != io.EOF {
+				testutil.FatalOnErr("fdbmovedata wait1 EOF failed", err, t)
+			}
 
 			waitResp2, err2 := client2.FDBMoveDataWait(ctx, waitReq)
 			testutil.FatalOnErr("fdbmovedata wait2 failed", err2, t)
@@ -179,3 +181,53 @@ func TestFDBMoveDataDouble(t *testing.T) {
 		})
 	}
 }
+
+func TestFDBMoveDataResumed(t *testing.T) {
+	ctx := context.Background()
+	conn, err := grpc.DialContext(ctx, "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithTransportCredentials(insecure.NewCredentials()))
+	testutil.FatalOnErr("grpc.DialContext(bufnet)", err, t)
+	t.Cleanup(func() { conn.Close() })
+
+	client := pb.NewFDBMoveClient(conn)
+
+	savedGenerateFDBMoveDataArgs := generateFDBMoveDataArgs
+	t.Cleanup(func() {
+		generateFDBMoveDataArgs = savedGenerateFDBMoveDataArgs
+	})
+
+	sh := testutil.ResolvePath(t, "/bin/sh")
+
+	generateFDBMoveDataArgs = func(req *pb.FDBMoveDataCopyRequest) ([]string, error) {
+		_, err = savedGenerateFDBMoveDataArgs(req)
+		return []string{sh, "-c", "/bin/sleep 0.2; for i in {1..16000}; do echo filling-up-pipe-buffer; done"}, err
+	}
+
+	resp, err := client.FDBMoveDataCopy(ctx, &pb.FDBMoveDataCopyRequest{
+		ClusterFile:        "1",
+		TenantGroup:        "2",
+		SourceCluster:      "3",
+		DestinationCluster: "4",
+		NumProcs:           5,
+	})
+	testutil.FatalOnErr("fdbmovedata copy failed", err, t)
+	waitReq := &pb.FDBMoveDataWaitRequest{Id: resp.Id}
+	shortCtx, cancel := context.WithTimeout(ctx, 20*time.Millisecond)
+	defer cancel()
+	shortWaitResp, err := client.FDBMoveDataWait(shortCtx, waitReq)
+	testutil.FatalOnErr("fdbmovedata wait failed", err, t)
+	_, err = shortWaitResp.Recv()
+	testutil.WantErr("fdbmovedata wait", err, true, t)
+
+	time.Sleep(20 * time.Millisecond)
+
+	waitResp, err := client.FDBMoveDataWait(ctx, waitReq)
+	testutil.FatalOnErr("second fdbmovedata wait failed", err, t)
+	for {
+		if _, err := waitResp.Recv(); err != nil {
+			if err != io.EOF {
+				testutil.FatalOnErr("fdbmovedata wait exited with error", err, t)
+			}
+			break
+		}
+	}
+}
diff --git a/services/fdb/server/server_test.go b/services/fdb/server/server_test.go
index 9d9c462c..38127e28 100644
--- a/services/fdb/server/server_test.go
+++ b/services/fdb/server/server_test.go
@@ -52,7 +52,7 @@ func TestMain(m *testing.M) {
 	fds := &fdbserver{}
 	fds.Register(s)
 
-	fdbm := &fdbmovedata{}
+	fdbm := &fdbmovedata{operations: make(map[int64]*moveOperation)}
 	fdbm.Register(s)
 
 	go func() {