From b8fa2c77082445f7a9a2c2f18775b5cba8caaa72 Mon Sep 17 00:00:00 2001 From: Patrick Zheng Date: Tue, 24 Sep 2024 14:46:38 +0800 Subject: [PATCH] update Signed-off-by: Patrick Zheng --- internal/file/file.go | 37 +++++++++++++++++++++++++++++++++++ internal/file/file_test.go | 40 +++++++++++++++++++++++++------------- verifier/crl/crl.go | 31 +++++------------------------ 3 files changed, 69 insertions(+), 39 deletions(-) diff --git a/internal/file/file.go b/internal/file/file.go index 920fa640..823c884b 100644 --- a/internal/file/file.go +++ b/internal/file/file.go @@ -15,6 +15,7 @@ package file import ( "errors" + "fmt" "io" "io/fs" "os" @@ -23,6 +24,11 @@ import ( "strings" ) +const ( + // tempFileName is the prefix of the temporary file + tempFileName = "notation-*" +) + // ErrNotRegularFile is returned when the file is not an regular file. var ErrNotRegularFile = errors.New("not regular file") @@ -110,3 +116,34 @@ func CopyDirToDir(src, dst string) error { func TrimFileExtension(fileName string) string { return strings.TrimSuffix(fileName, filepath.Ext(fileName)) } + +// WriteFile writes content to path with all parent directories created. +// If path already exists and is a file, WriteFile overwrites it. +func WriteFile(path string, content []byte) (writeErr error) { + if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil { + return err + } + tempFile, err := os.CreateTemp("", tempFileName) + if err != nil { + return fmt.Errorf("failed to create temp file: %w", err) + } + defer func() { + // remove the temp file in case of error + if writeErr != nil { + tempFile.Close() + os.Remove(tempFile.Name()) + } + }() + + if _, err := tempFile.Write(content); err != nil { + return fmt.Errorf("failed to write content: %w", err) + } + + // close before moving + if err := tempFile.Close(); err != nil { + return fmt.Errorf("failed to close temp file: %w", err) + } + + // rename is atomic on UNIX-like platforms + return os.Rename(tempFile.Name(), path) +} diff --git a/internal/file/file_test.go b/internal/file/file_test.go index a108d0da..67152f49 100644 --- a/internal/file/file_test.go +++ b/internal/file/file_test.go @@ -18,6 +18,7 @@ import ( "os" "path/filepath" "runtime" + "strings" "testing" ) @@ -26,7 +27,7 @@ func TestCopyToDir(t *testing.T) { tempDir := t.TempDir() data := []byte("data") filename := filepath.Join(tempDir, "a", "file.txt") - if err := writeFile(filename, data); err != nil { + if err := WriteFile(filename, data); err != nil { t.Fatal(err) } @@ -45,7 +46,7 @@ func TestCopyToDir(t *testing.T) { destDir := t.TempDir() data := []byte("data") filename := filepath.Join(tempDir, "a", "file.txt") - if err := writeFile(filename, data); err != nil { + if err := WriteFile(filename, data); err != nil { t.Fatal(err) } @@ -77,7 +78,7 @@ func TestCopyToDir(t *testing.T) { data := []byte("data") // prepare file filename := filepath.Join(tempDir, "a", "file.txt") - if err := writeFile(filename, data); err != nil { + if err := WriteFile(filename, data); err != nil { t.Fatal(err) } // forbid reading @@ -100,7 +101,7 @@ func TestCopyToDir(t *testing.T) { data := []byte("data") // prepare file filename := filepath.Join(tempDir, "a", "file.txt") - if err := writeFile(filename, data); err != nil { + if err := WriteFile(filename, data); err != nil { t.Fatal(err) } // forbid dest directory operation @@ -123,7 +124,7 @@ func TestCopyToDir(t *testing.T) { data := []byte("data") // prepare file filename := filepath.Join(tempDir, "a", "file.txt") - if err := writeFile(filename, data); err != nil { + if err := WriteFile(filename, data); err != nil { t.Fatal(err) } // forbid writing to destTempDir @@ -140,7 +141,7 @@ func TestCopyToDir(t *testing.T) { tempDir := t.TempDir() data := []byte("data") filename := filepath.Join(tempDir, "a", "file.txt") - if err := writeFile(filename, data); err != nil { + if err := WriteFile(filename, data); err != nil { t.Fatal(err) } @@ -161,6 +162,26 @@ func TestFileNameWithoutExtension(t *testing.T) { } } +func TestWriteFile(t *testing.T) { + tempDir := t.TempDir() + content := []byte("test WriteFile") + + t.Run("permission denied", func(t *testing.T) { + err := os.Chmod(tempDir, 0) + if err != nil { + t.Fatal(err) + } + err = WriteFile(filepath.Join(tempDir, "testFile"), content) + if err == nil || !strings.Contains(err.Error(), "permission denied") { + t.Fatalf("expected permission denied error, but got %s", err) + } + err = os.Chmod(tempDir, 0700) + if err != nil { + t.Fatal(err) + } + }) +} + func validFileContent(t *testing.T, filename string, content []byte) { b, err := os.ReadFile(filename) if err != nil { @@ -170,10 +191,3 @@ func validFileContent(t *testing.T, filename string, content []byte) { t.Fatal("file content is not correct") } } - -func writeFile(path string, data []byte) error { - if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil { - return err - } - return os.WriteFile(path, data, 0600) -} diff --git a/verifier/crl/crl.go b/verifier/crl/crl.go index 556787c8..b8c7fd23 100644 --- a/verifier/crl/crl.go +++ b/verifier/crl/crl.go @@ -28,14 +28,10 @@ import ( "time" corecrl "github.com/notaryproject/notation-core-go/revocation/crl" + "github.com/notaryproject/notation-go/internal/file" "github.com/notaryproject/notation-go/log" ) -const ( - // tempFileName is the prefix of the temporary file - tempFileName = "notation-*" -) - // FileCache implements corecrl.Cache. // // Key: url of the CRL. @@ -126,7 +122,7 @@ func (c *FileCache) Get(ctx context.Context, url string) (*corecrl.Bundle, error } // Set stores the CRL bundle in c with url as key. -func (c *FileCache) Set(ctx context.Context, url string, bundle *corecrl.Bundle) (setErr error) { +func (c *FileCache) Set(ctx context.Context, url string, bundle *corecrl.Bundle) error { logger := log.GetLogger(ctx) logger.Debugf("Storing crl bundle to file cache with key %q ...", url) @@ -145,28 +141,11 @@ func (c *FileCache) Set(ctx context.Context, url string, bundle *corecrl.Bundle) if bundle.DeltaCRL != nil { content.DeltaCRL = bundle.DeltaCRL.Raw } - - // save content to temp file - tempFile, err := os.CreateTemp("", tempFileName) + contentBytes, err := json.Marshal(content) if err != nil { - return fmt.Errorf("failed to store crl bundle in file cache: failed to create temp file: %w", err) - } - defer func() { - // remove the temp file in case of error - if setErr != nil { - defer os.Remove(tempFile.Name()) - } - }() - - if err := json.NewEncoder(tempFile).Encode(content); err != nil { - return fmt.Errorf("failed to store crl bundle in file cache: failed to encode content: %w", err) - } - if err := tempFile.Close(); err != nil { - return fmt.Errorf("failed to store crl bundle in file cache: failed to close temp file: %w", err) + return fmt.Errorf("failed to store crl bundle in file cache: %w", err) } - - // rename is atomic on UNIX-like platforms - if err := os.Rename(tempFile.Name(), filepath.Join(c.root, c.fileName(url))); err != nil { + if err := file.WriteFile(filepath.Join(c.root, c.fileName(url)), contentBytes); err != nil { return fmt.Errorf("failed to store crl bundle in file cache: %w", err) } return nil