diff --git a/pkg/downloader/downloader.go b/pkg/downloader/downloader.go
index ade98bd3b90..8a3160854c7 100644
--- a/pkg/downloader/downloader.go
+++ b/pkg/downloader/downloader.go
@@ -229,6 +229,31 @@ func Download(ctx context.Context, local, remote string, opts ...Opt) (*Result,
 		return res, nil
 	}
 
+	shad := cacheDirectoryPath(o.cacheDir, remote)
+	if err := os.MkdirAll(shad, 0o700); err != nil {
+		return nil, err
+	}
+
+	var res *Result
+	err := lockutil.WithDirLock(shad, func() error {
+		var err error
+		res, err = getCached(ctx, localPath, remote, o)
+		if err != nil {
+			return err
+		}
+		if res != nil {
+			return nil
+		}
+		res, err = fetch(ctx, localPath, remote, o)
+		return err
+	})
+	return res, err
+}
+
+// getCached tries to copy the file from the cache to local path. Return result,
+// nil if the file was copied, nil, nil if the file is not in the cache or the
+// cache needs update, or nil, error on fatal error.
+func getCached(ctx context.Context, localPath, remote string, o options) (*Result, error) {
 	shad := cacheDirectoryPath(o.cacheDir, remote)
 	shadData := filepath.Join(shad, "data")
 	shadTime := filepath.Join(shad, "time")
@@ -237,53 +262,62 @@ func Download(ctx context.Context, local, remote string, opts ...Opt) (*Result,
 	if err != nil {
 		return nil, err
 	}
-	if _, err := os.Stat(shadData); err == nil {
-		logrus.Debugf("file %q is cached as %q", localPath, shadData)
-		useCache := true
-		if _, err := os.Stat(shadDigest); err == nil {
-			logrus.Debugf("Comparing digest %q with the cached digest file %q, not computing the actual digest of %q",
-				o.expectedDigest, shadDigest, shadData)
-			if err := validateCachedDigest(shadDigest, o.expectedDigest); err != nil {
-				return nil, err
-			}
-			if err := copyLocal(ctx, localPath, shadData, ext, o.decompress, "", ""); err != nil {
+	if _, err := os.Stat(shadData); err != nil {
+		return nil, nil
+	}
+	ext := path.Ext(remote)
+	logrus.Debugf("file %q is cached as %q", localPath, shadData)
+	if _, err := os.Stat(shadDigest); err == nil {
+		logrus.Debugf("Comparing digest %q with the cached digest file %q, not computing the actual digest of %q",
+			o.expectedDigest, shadDigest, shadData)
+		if err := validateCachedDigest(shadDigest, o.expectedDigest); err != nil {
+			return nil, err
+		}
+		if err := copyLocal(ctx, localPath, shadData, ext, o.decompress, "", ""); err != nil {
+			return nil, err
+		}
+	} else {
+		if match, lmCached, lmRemote, err := matchLastModified(ctx, shadTime, remote); err != nil {
+			logrus.WithError(err).Info("Failed to retrieve last-modified for cached digest-less image; using cached image.")
+		} else if match {
+			if err := copyLocal(ctx, localPath, shadData, ext, o.decompress, o.description, o.expectedDigest); err != nil {
 				return nil, err
 			}
 		} else {
-			if match, lmCached, lmRemote, err := matchLastModified(ctx, shadTime, remote); err != nil {
-				logrus.WithError(err).Info("Failed to retrieve last-modified for cached digest-less image; using cached image.")
-			} else if match {
-				if err := copyLocal(ctx, localPath, shadData, ext, o.decompress, o.description, o.expectedDigest); err != nil {
-					return nil, err
-				}
-			} else {
-				logrus.Infof("Re-downloading digest-less image: last-modified mismatch (cached: %q, remote: %q)", lmCached, lmRemote)
-				useCache = false
-			}
-		}
-		if useCache {
-			res := &Result{
-				Status:          StatusUsedCache,
-				CachePath:       shadData,
-				LastModified:    readTime(shadTime),
-				ContentType:     readFile(shadType),
-				ValidatedDigest: o.expectedDigest != "",
-			}
-			return res, nil
+			logrus.Infof("Re-downloading digest-less image: last-modified mismatch (cached: %q, remote: %q)", lmCached, lmRemote)
+			return nil, nil
 		}
 	}
-	if err := os.MkdirAll(shad, 0o700); err != nil {
+	res := &Result{
+		Status:          StatusUsedCache,
+		CachePath:       shadData,
+		LastModified:    readTime(shadTime),
+		ContentType:     readFile(shadType),
+		ValidatedDigest: o.expectedDigest != "",
+	}
+	return res, nil
+}
+
+// fetch downloads remote to the cache and copy the cached file to local path.
+func fetch(ctx context.Context, localPath, remote string, o options) (*Result, error) {
+	shad := cacheDirectoryPath(o.cacheDir, remote)
+	shadData := filepath.Join(shad, "data")
+	shadTime := filepath.Join(shad, "time")
+	shadType := filepath.Join(shad, "type")
+	shadDigest, err := cacheDigestPath(shad, o.expectedDigest)
+	if err != nil {
 		return nil, err
 	}
+	ext := path.Ext(remote)
 	shadURL := filepath.Join(shad, "url")
-	if err := writeFirst(shadURL, []byte(remote), 0o644); err != nil {
+	if err := os.WriteFile(shadURL, []byte(remote), 0o644); err != nil {
 		return nil, err
 	}
 	if err := downloadHTTP(ctx, shadData, shadTime, shadType, remote, o.description, o.expectedDigest); err != nil {
 		return nil, err
 	}
 	if shadDigest != "" && o.expectedDigest != "" {
-		if err := writeFirst(shadDigest, []byte(o.expectedDigest.String()), 0o644); err != nil {
+		if err := os.WriteFile(shadDigest, []byte(o.expectedDigest.String()), 0o644); err != nil {
 			return nil, err
 		}
 	}
@@ -327,18 +361,33 @@ func Cached(remote string, opts ...Opt) (*Result, error) {
 	if err != nil {
 		return nil, err
 	}
+
+	// Checking if data file exists is safe without locking.
 	if _, err := os.Stat(shadData); err != nil {
 		return nil, err
 	}
-	if _, err := os.Stat(shadDigest); err != nil {
-		if err := validateCachedDigest(shadDigest, o.expectedDigest); err != nil {
-			return nil, err
-		}
-	} else {
-		if err := validateLocalFileDigest(shadData, o.expectedDigest); err != nil {
-			return nil, err
+
+	// But validating the digest or the data file must take the lock to avoid races
+	// with parallel downloads.
+	if err := os.MkdirAll(shad, 0o700); err != nil {
+		return nil, err
+	}
+	err = lockutil.WithDirLock(shad, func() error {
+		if _, err := os.Stat(shadDigest); err != nil {
+			if err := validateCachedDigest(shadDigest, o.expectedDigest); err != nil {
+				return err
+			}
+		} else {
+			if err := validateLocalFileDigest(shadData, o.expectedDigest); err != nil {
+				return err
+			}
 		}
+		return nil
+	})
+	if err != nil {
+		return nil, err
 	}
+
 	res := &Result{
 		Status:          StatusUsedCache,
 		CachePath:       shadData,
@@ -612,13 +661,13 @@ func downloadHTTP(ctx context.Context, localPath, lastModified, contentType, url
 	}
 	if lastModified != "" {
 		lm := resp.Header.Get("Last-Modified")
-		if err := writeFirst(lastModified, []byte(lm), 0o644); err != nil {
+		if err := os.WriteFile(lastModified, []byte(lm), 0o644); err != nil {
 			return err
 		}
 	}
 	if contentType != "" {
 		ct := resp.Header.Get("Content-Type")
-		if err := writeFirst(contentType, []byte(ct), 0o644); err != nil {
+		if err := os.WriteFile(contentType, []byte(ct), 0o644); err != nil {
 			return err
 		}
 	}
@@ -679,19 +728,7 @@ func downloadHTTP(ctx context.Context, localPath, lastModified, contentType, url
 		return err
 	}
 
-	// If localPath was created by a parallel download keep it. Replacing it
-	// while another process is copying it to the destination may fail the
-	// clonefile syscall. We use a lock to ensure that only one process updates
-	// data, and when we return data file exists.
-
-	return lockutil.WithDirLock(filepath.Dir(localPath), func() error {
-		if _, err := os.Stat(localPath); err == nil {
-			return nil
-		} else if !errors.Is(err, os.ErrNotExist) {
-			return err
-		}
-		return os.Rename(localPathTmp, localPath)
-	})
+	return os.Rename(localPathTmp, localPath)
 }
 
 var tempfileCount atomic.Uint64
@@ -706,18 +743,6 @@ func perProcessTempfile(path string) string {
 	return fmt.Sprintf("%s.tmp.%d.%d", path, os.Getpid(), tempfileCount.Add(1))
 }
 
-// writeFirst writes data to path unless path already exists.
-func writeFirst(path string, data []byte, perm os.FileMode) error {
-	return lockutil.WithDirLock(filepath.Dir(path), func() error {
-		if _, err := os.Stat(path); err == nil {
-			return nil
-		} else if !errors.Is(err, os.ErrNotExist) {
-			return err
-		}
-		return os.WriteFile(path, data, perm)
-	})
-}
-
 // CacheEntries returns a map of cache entries.
 // The key is the SHA256 of the URL.
 // The value is the path to the cache entry.
diff --git a/pkg/downloader/downloader_test.go b/pkg/downloader/downloader_test.go
index 2e4483194cb..045805bc00d 100644
--- a/pkg/downloader/downloader_test.go
+++ b/pkg/downloader/downloader_test.go
@@ -8,7 +8,6 @@ import (
 	"os/exec"
 	"path/filepath"
 	"runtime"
-	"slices"
 	"strings"
 	"testing"
 	"time"
@@ -31,11 +30,6 @@ type downloadResult struct {
 // races quicker. 20 parallel downloads take about 120 milliseconds on M1 Pro.
 const parallelDownloads = 20
 
-// When downloading in parallel usually all downloads completed with
-// StatusDownload, but some may be delayed and find the data file when they
-// start. Can be reproduced locally using 100 parallel downloads.
-var parallelStatus = []Status{StatusDownloaded, StatusUsedCache}
-
 func TestDownloadRemote(t *testing.T) {
 	ts := httptest.NewServer(http.FileServer(http.Dir("testdata")))
 	t.Cleanup(ts.Close)
@@ -103,15 +97,10 @@ func TestDownloadRemote(t *testing.T) {
 					results <- downloadResult{r, err}
 				}()
 			}
-			// We must process all results before cleanup.
-			for i := 0; i < parallelDownloads; i++ {
-				result := <-results
-				if result.err != nil {
-					t.Errorf("Download failed: %s", result.err)
-				} else if !slices.Contains(parallelStatus, result.r.Status) {
-					t.Errorf("Expected download status %s, got %s", parallelStatus, result.r.Status)
-				}
-			}
+			// Only one thread should download, the rest should use the cache.
+			downloaded, cached := countResults(t, results)
+			assert.Equal(t, downloaded, 1)
+			assert.Equal(t, cached, parallelDownloads-1)
 		})
 	})
 	t.Run("caching-only mode", func(t *testing.T) {
@@ -146,15 +135,10 @@ func TestDownloadRemote(t *testing.T) {
 					results <- downloadResult{r, err}
 				}()
 			}
-			// We must process all results before cleanup.
-			for i := 0; i < parallelDownloads; i++ {
-				result := <-results
-				if result.err != nil {
-					t.Errorf("Download failed: %s", result.err)
-				} else if !slices.Contains(parallelStatus, result.r.Status) {
-					t.Errorf("Expected download status %s, got %s", parallelStatus, result.r.Status)
-				}
-			}
+			// Only one thread should download, the rest should use the cache.
+			downloaded, cached := countResults(t, results)
+			assert.Equal(t, downloaded, 1)
+			assert.Equal(t, cached, parallelDownloads-1)
 		})
 	})
 	t.Run("cached", func(t *testing.T) {
@@ -188,6 +172,26 @@ func TestDownloadRemote(t *testing.T) {
 	})
 }
 
+func countResults(t *testing.T, results chan downloadResult) (downloaded, cached int) {
+	t.Helper()
+	for i := 0; i < parallelDownloads; i++ {
+		result := <-results
+		if result.err != nil {
+			t.Errorf("Download failed: %s", result.err)
+		} else {
+			switch result.r.Status {
+			case StatusDownloaded:
+				downloaded++
+			case StatusUsedCache:
+				cached++
+			default:
+				t.Errorf("Unexpected download status %q", result.r.Status)
+			}
+		}
+	}
+	return downloaded, cached
+}
+
 func TestRedownloadRemote(t *testing.T) {
 	remoteDir := t.TempDir()
 	ts := httptest.NewServer(http.FileServer(http.Dir(remoteDir)))
@@ -203,18 +207,26 @@ func TestRedownloadRemote(t *testing.T) {
 		assert.NilError(t, os.Chtimes(remoteFile, time.Now(), time.Now().Add(-time.Hour)))
 		opt := []Opt{cacheOpt}
 
-		r, err := Download(context.Background(), filepath.Join(downloadDir, "digest-less1.txt"), ts.URL+"/digest-less.txt", opt...)
+		// Download on the first call
+		r, err := Download(context.Background(), filepath.Join(downloadDir, "1"), ts.URL+"/digest-less.txt", opt...)
 		assert.NilError(t, err)
 		assert.Equal(t, StatusDownloaded, r.Status)
-		r, err = Download(context.Background(), filepath.Join(downloadDir, "digest-less2.txt"), ts.URL+"/digest-less.txt", opt...)
+
+		// Next download will use the cached download
+		r, err = Download(context.Background(), filepath.Join(downloadDir, "2"), ts.URL+"/digest-less.txt", opt...)
 		assert.NilError(t, err)
 		assert.Equal(t, StatusUsedCache, r.Status)
 
-		// modifying remote file will cause redownload
+		// Modifying remote file will cause redownload
 		assert.NilError(t, os.Chtimes(remoteFile, time.Now(), time.Now()))
-		r, err = Download(context.Background(), filepath.Join(downloadDir, "digest-less3.txt"), ts.URL+"/digest-less.txt", opt...)
+		r, err = Download(context.Background(), filepath.Join(downloadDir, "3"), ts.URL+"/digest-less.txt", opt...)
 		assert.NilError(t, err)
 		assert.Equal(t, StatusDownloaded, r.Status)
+
+		// Next download will use the cached download
+		r, err = Download(context.Background(), filepath.Join(downloadDir, "4"), ts.URL+"/digest-less.txt", opt...)
+		assert.NilError(t, err)
+		assert.Equal(t, StatusUsedCache, r.Status)
 	})
 
 	t.Run("has-digest", func(t *testing.T) {