diff --git a/internal/actions/dump/dump.go b/internal/actions/dump/dump.go index 2fdf666c..383a8cd2 100644 --- a/internal/actions/dump/dump.go +++ b/internal/actions/dump/dump.go @@ -67,10 +67,8 @@ func (action Dump) Run(ctx context.Context) (err error) { startTime := time.Now() - bar := progressbar.New(-1, "downloading", action.Spinner) + bar, plogger := progressbar.New(os.Stderr, -1, "downloading", action.Spinner) defer bar.Close() - plogger := progressbar.NewBarSafeLogger(os.Stderr, bar) - log.SetOutput(plogger) errGroup, ctx := errgroup.WithContext(ctx) @@ -142,7 +140,6 @@ func (action Dump) Run(ctx context.Context) (err error) { } _ = bar.Finish() - log.SetOutput(os.Stderr) // Close file err = f.Close() diff --git a/internal/actions/restore/restore.go b/internal/actions/restore/restore.go index f318a637..35fec5ba 100644 --- a/internal/actions/restore/restore.go +++ b/internal/actions/restore/restore.go @@ -45,11 +45,8 @@ func (action Restore) Run(ctx context.Context) (err error) { startTime := time.Now() - bar := progressbar.New(-1, "uploading", action.Spinner) + bar, errLog := progressbar.New(os.Stderr, -1, "uploading", action.Spinner) defer bar.Close() - errLog := progressbar.NewBarSafeLogger(os.Stderr, bar) - outLog := progressbar.NewBarSafeLogger(os.Stdout, bar) - log.SetOutput(errLog) errGroup, ctx := errgroup.WithContext(ctx) @@ -59,7 +56,7 @@ func (action Restore) Run(ctx context.Context) (err error) { defer func(pr io.ReadCloser) { _ = pr.Close() }(pr) - return action.runInDatabasePod(ctx, pr, outLog, errLog, action.Format) + return action.runInDatabasePod(ctx, pr, errLog, errLog, action.Format) }) errGroup.Go(func() error { @@ -131,7 +128,7 @@ func (action Restore) Run(ctx context.Context) (err error) { defer func(pr io.ReadCloser) { _ = pr.Close() }(pr) - return action.runInDatabasePod(ctx, pr, outLog, errLog, sqlformat.Gzip) + return action.runInDatabasePod(ctx, pr, errLog, errLog, sqlformat.Gzip) }) } @@ -158,7 +155,6 @@ func (action Restore) Run(ctx context.Context) (err error) { } _ = bar.Finish() - log.SetOutput(os.Stderr) log.WithFields(log.Fields{ "file": action.Filename, diff --git a/internal/progressbar/logger.go b/internal/progressbar/logger.go index 95111598..f1e6febe 100644 --- a/internal/progressbar/logger.go +++ b/internal/progressbar/logger.go @@ -12,8 +12,9 @@ func NewBarSafeLogger(w io.Writer, bar *ProgressBar) *BarSafeLogger { } type BarSafeLogger struct { - out io.Writer - bar *ProgressBar + out io.Writer + bar *ProgressBar + atStart bool } func (l *BarSafeLogger) Write(p []byte) (int, error) { @@ -24,8 +25,10 @@ func (l *BarSafeLogger) Write(p []byte) (int, error) { l.bar.mu.Lock() defer l.bar.mu.Unlock() - if _, err := l.out.Write([]byte("\r\x1B[K")); err != nil { - return 0, err + if !l.atStart { + if _, err := l.out.Write([]byte("\r\x1B[K")); err != nil { + return 0, err + } } n, err := l.out.Write(p) @@ -33,8 +36,13 @@ func (l *BarSafeLogger) Write(p []byte) (int, error) { return n, err } - if _, err := l.out.Write([]byte(l.bar.String())); err != nil { - return n, err + if p[len(p)-1] == '\n' { + if _, err := l.out.Write([]byte(l.bar.String())); err != nil { + return n, err + } + l.atStart = false + } else { + l.atStart = true } return n, nil diff --git a/internal/progressbar/progressbar.go b/internal/progressbar/progressbar.go index 245bc4ec..d414c3b3 100644 --- a/internal/progressbar/progressbar.go +++ b/internal/progressbar/progressbar.go @@ -14,7 +14,7 @@ import ( log "github.com/sirupsen/logrus" ) -func New(max int64, label string, spinnerKey string) *ProgressBar { +func New(w io.Writer, max int64, label string, spinnerKey string) (*ProgressBar, *BarSafeLogger) { s, ok := spinner.Map[spinnerKey] if !ok { log.WithField("spinner", spinnerKey).Warn("invalid spinner") @@ -61,7 +61,7 @@ func New(max int64, label string, spinnerKey string) *ProgressBar { if bar.IsFinished() { return } - if bar.mu.TryLock() { + if !bar.logger.atStart && bar.mu.TryLock() { _ = bar.RenderBlank() _, _ = os.Stderr.Write([]byte(bar.String())) bar.mu.Unlock() @@ -70,7 +70,9 @@ func New(max int64, label string, spinnerKey string) *ProgressBar { } }() - return bar + logger := NewBarSafeLogger(w, bar) + log.SetOutput(logger) + return bar, logger } type ProgressBar struct { @@ -78,10 +80,14 @@ type ProgressBar struct { mu sync.Mutex cancelChan chan struct{} cancelOnce sync.Once + logger BarSafeLogger } func (p *ProgressBar) Finish() error { - p.Close() + defer func() { + p.Close() + log.SetOutput(os.Stderr) + }() return p.ProgressBar.Finish() }