From 15659fda3835e362b2a395ef4aa0707a7f06313c Mon Sep 17 00:00:00 2001 From: Gabe Cook Date: Fri, 6 Oct 2023 01:46:03 -0500 Subject: [PATCH] feat(inplace): Write to a temporary file to ensure data integrity --- cmd/cmd.go | 38 ++++++++++++++++++-------------------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/cmd/cmd.go b/cmd/cmd.go index eab1efb..8d72931 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "os" + "path/filepath" "strings" "github.com/clevyr/yampl/internal/config" @@ -115,22 +116,9 @@ func openAndTemplate(cmd *cobra.Command, conf config.Config, p string) (err erro }(conf.Log) conf.Log = log.WithField("file", p) - var f *os.File - if conf.Inplace { - stat, err := os.Stat(p) - if err != nil { - return fmt.Errorf("%s: %w", p, err) - } - - f, err = os.OpenFile(p, os.O_RDWR, stat.Mode()) - if err != nil { - return fmt.Errorf("%s: %w", p, err) - } - } else { - f, err = os.Open(p) - if err != nil { - return fmt.Errorf("%s: %w", p, err) - } + f, err := os.Open(p) + if err != nil { + return fmt.Errorf("%s: %w", p, err) } defer func(f *os.File) { _ = f.Close() @@ -141,16 +129,26 @@ func openAndTemplate(cmd *cobra.Command, conf config.Config, p string) (err erro return fmt.Errorf("%s: %w", p, err) } + _ = f.Close() + if conf.Inplace { - if err := f.Truncate(int64(len(s))); err != nil { + temp, err := os.CreateTemp("", "yampl_*_"+filepath.Base(p)) + if err != nil { return fmt.Errorf("%s: %w", p, err) } + defer func() { + _ = os.RemoveAll(temp.Name()) + }() - if _, err := f.Seek(0, io.SeekStart); err != nil { + if _, err := temp.WriteString(s); err != nil { return fmt.Errorf("%s: %w", p, err) } - if _, err := f.WriteString(s); err != nil { + if err := temp.Close(); err != nil { + return fmt.Errorf("%s: %w", p, err) + } + + if err := os.Rename(temp.Name(), p); err != nil { return fmt.Errorf("%s: %w", p, err) } } else { @@ -159,7 +157,7 @@ func openAndTemplate(cmd *cobra.Command, conf config.Config, p string) (err erro } } - return f.Close() + return nil } func templateReader(conf config.Config, r io.Reader) (string, error) {