From 606f636379b28d3054b80d55f0b3f8dfb7d95736 Mon Sep 17 00:00:00 2001 From: Lucas Rodriguez Date: Tue, 20 Aug 2024 11:11:25 -0500 Subject: [PATCH] feat: support context cancellation in file store (#803) Closes #619 --------- Signed-off-by: Lucas Rodriguez --- content/file/file.go | 8 ++-- content/file/utils.go | 9 ++++- content/file/utils_test.go | 78 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 90 insertions(+), 5 deletions(-) diff --git a/content/file/file.go b/content/file/file.go index 3f1e8c086..4e6b1477d 100644 --- a/content/file/file.go +++ b/content/file/file.go @@ -395,7 +395,7 @@ func (s *Store) Predecessors(ctx context.Context, node ocispec.Descriptor) ([]oc } // Add adds a file into the file store. -func (s *Store) Add(_ context.Context, name, mediaType, path string) (ocispec.Descriptor, error) { +func (s *Store) Add(ctx context.Context, name, mediaType, path string) (ocispec.Descriptor, error) { if s.isClosedSet() { return ocispec.Descriptor{}, ErrStoreClosed } @@ -426,7 +426,7 @@ func (s *Store) Add(_ context.Context, name, mediaType, path string) (ocispec.De // generate descriptor var desc ocispec.Descriptor if fi.IsDir() { - desc, err = s.descriptorFromDir(name, mediaType, path) + desc, err = s.descriptorFromDir(ctx, name, mediaType, path) } else { desc, err = s.descriptorFromFile(fi, mediaType, path) } @@ -505,7 +505,7 @@ func (s *Store) pushDir(name, target string, expected ocispec.Descriptor, conten } // descriptorFromDir generates descriptor from the given directory. -func (s *Store) descriptorFromDir(name, mediaType, dir string) (desc ocispec.Descriptor, err error) { +func (s *Store) descriptorFromDir(ctx context.Context, name, mediaType, dir string) (desc ocispec.Descriptor, err error) { // make a temp file to store the gzip gz, err := s.tempFile() if err != nil { @@ -532,7 +532,7 @@ func (s *Store) descriptorFromDir(name, mediaType, dir string) (desc ocispec.Des tw := io.MultiWriter(gzw, tarDigester.Hash()) buf := bufPool.Get().(*[]byte) defer bufPool.Put(buf) - if err := tarDirectory(dir, name, tw, s.TarReproducible, *buf); err != nil { + if err := tarDirectory(ctx, dir, name, tw, s.TarReproducible, *buf); err != nil { return ocispec.Descriptor{}, fmt.Errorf("failed to tar %s: %w", dir, err) } diff --git a/content/file/utils.go b/content/file/utils.go index c42013d88..3c0077617 100644 --- a/content/file/utils.go +++ b/content/file/utils.go @@ -18,6 +18,7 @@ package file import ( "archive/tar" "compress/gzip" + "context" "errors" "fmt" "io" @@ -31,7 +32,7 @@ import ( // tarDirectory walks the directory specified by path, and tar those files with a new // path prefix. -func tarDirectory(root, prefix string, w io.Writer, removeTimes bool, buf []byte) (err error) { +func tarDirectory(ctx context.Context, root, prefix string, w io.Writer, removeTimes bool, buf []byte) (err error) { tw := tar.NewWriter(w) defer func() { closeErr := tw.Close() @@ -45,6 +46,12 @@ func tarDirectory(root, prefix string, w io.Writer, removeTimes bool, buf []byte return err } + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + // Rename path name, err := filepath.Rel(root, path) if err != nil { diff --git a/content/file/utils_test.go b/content/file/utils_test.go index ed2acab7b..7d93f73ac 100644 --- a/content/file/utils_test.go +++ b/content/file/utils_test.go @@ -16,11 +16,89 @@ limitations under the License. package file import ( + "compress/gzip" + "context" + "errors" "os" "path/filepath" "testing" ) +func Test_tarDirectory(t *testing.T) { + setup := func(t *testing.T) (tmpdir string, gz *os.File, gw *gzip.Writer) { + tmpdir = t.TempDir() + + paths := []string{ + filepath.Join(tmpdir, "file1.txt"), + filepath.Join(tmpdir, "file2.txt"), + } + + for _, p := range paths { + err := os.WriteFile(p, []byte("test content"), 0644) + if err != nil { + t.Fatal(err) + } + } + + gz, err := os.CreateTemp(tmpdir, "tarDirectory-*") + if err != nil { + t.Fatal(err) + } + + return tmpdir, gz, gzip.NewWriter(gz) + } + + t.Run("success", func(t *testing.T) { + tmpdir, gz, gw := setup(t) + defer func() { + if err := gw.Close(); err != nil { + t.Fatal(err) + } + if err := gz.Close(); err != nil { + t.Fatal(err) + } + }() + + err := tarDirectory(context.Background(), tmpdir, "prefix", gw, false, nil) + if err != nil { + t.Fatal(err) + } + + _, err = gz.Stat() + if err != nil { + t.Fatal(err) + } + }) + + t.Run("context cancellation", func(t *testing.T) { + tmpdir, gz, gw := setup(t) + defer func() { + if err := gw.Close(); err != nil { + t.Fatal(err) + } + if err := gz.Close(); err != nil { + t.Fatal(err) + } + }() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := tarDirectory(ctx, tmpdir, "prefix", gw, false, nil) + if err == nil { + t.Fatal("expected context cancellation error, got nil") + } + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context.Canceled error, got %v", err) + } + + _, err = gz.Stat() + if err != nil { + t.Fatal(err) + } + }) +} + func Test_ensureBasePath(t *testing.T) { root := t.TempDir() if err := os.MkdirAll(filepath.Join(root, "hello world", "foo", "bar"), 0700); err != nil {