diff --git a/fastwalk.go b/fastwalk.go index 6d044ec..521d620 100644 --- a/fastwalk.go +++ b/fastwalk.go @@ -60,9 +60,10 @@ var ErrSkipFiles = errors.New("fastwalk: skip remaining files in directory") // as an error by any function. var SkipDir = fs.SkipDir -// TODO(charlie): Look into implementing the fs.SkipAll behavior of -// filepath.Walk and filepath.WalkDir. This may not be possible without taking -// a performance hit. +// SkipAll is used as a return value from [WalkDirFunc] to indicate that +// all remaining files and directories are to be skipped. It is not returned +// as an error by any function. +var SkipAll = fs.SkipAll // DefaultNumWorkers returns the default number of worker goroutines to use in // [Walk] and is the value of [runtime.GOMAXPROCS](-1) clamped to a range @@ -577,13 +578,18 @@ func (w *walker) joinPaths(dir, base string) string { func (w *walker) onDirEnt(dirName, baseName string, de DirEntry) error { joined := w.joinPaths(dirName, baseName) + err := w.fn(joined, de, nil) typ := de.Type() if typ == os.ModeDir { - w.enqueue(walkItem{dir: joined, info: de}) + if err != nil { + if err == SkipDir { + return nil + } + return err // May be SkipAll + } + w.enqueue(walkItem{dir: joined, info: de, callbackDone: true}) return nil } - - err := w.fn(joined, de, nil) if typ == os.ModeSymlink { if err == ErrTraverseLink { if !w.follow { @@ -594,8 +600,8 @@ func (w *walker) onDirEnt(dirName, baseName string, de DirEntry) error { } err = nil // Ignore ErrTraverseLink when Follow is true. } - if err == filepath.SkipDir { - // Permit SkipDir on symlinks too. + if err == SkipDir { + // Permit SkipDir and SkipAll on symlinks too. return nil } if err == nil && w.follow && w.shouldTraverse(joined, de) { @@ -609,10 +615,10 @@ func (w *walker) onDirEnt(dirName, baseName string, de DirEntry) error { func (w *walker) walk(root string, info DirEntry, runUserCallback bool) error { if runUserCallback { err := w.fn(root, info, nil) - if err == filepath.SkipDir { - return nil - } if err != nil { + if err == SkipDir || err == SkipAll { + return nil + } return err } } diff --git a/fastwalk_darwin.go b/fastwalk_darwin.go index 107e91f..88ad08d 100644 --- a/fastwalk_darwin.go +++ b/fastwalk_darwin.go @@ -71,6 +71,9 @@ func (w *walker) readDir(dirName string) (err error) { de := newUnixDirent(dirName, nm, typ) if w.sortMode == SortNone { if err := w.onDirEnt(dirName, nm, de); err != nil { + if err == SkipAll { + return nil + } if err != ErrSkipFiles { return err } @@ -92,6 +95,9 @@ func (w *walker) readDir(dirName string) (err error) { continue } if err := w.onDirEnt(dirName, d.Name(), d); err != nil { + if err == SkipAll { + return nil + } if err != ErrSkipFiles { return err } diff --git a/fastwalk_portable.go b/fastwalk_portable.go index 6956c36..610d081 100644 --- a/fastwalk_portable.go +++ b/fastwalk_portable.go @@ -36,6 +36,9 @@ func (w *walker) readDir(dirName string) error { e := newDirEntry(dirName, d) if w.sortMode == SortNone { if err := w.onDirEnt(dirName, d.Name(), e); err != nil { + if err == SkipAll { + return nil + } if err != ErrSkipFiles { return err } @@ -57,6 +60,9 @@ func (w *walker) readDir(dirName string) error { continue } if err := w.onDirEnt(dirName, d.Name(), d); err != nil { + if err == SkipAll { + return nil + } if err != ErrSkipFiles { return err } diff --git a/fastwalk_test.go b/fastwalk_test.go index 6f0ecc6..44281c9 100644 --- a/fastwalk_test.go +++ b/fastwalk_test.go @@ -444,6 +444,60 @@ func TestFastWalk_DirEntryType(t *testing.T) { }) } +func TestFastWalk_DirEntryStat(t *testing.T) { + testFastWalk(t, map[string]string{ + "foo/foo.go": "one", + "bar/bar.go": "LINK:../foo/foo.go", + "symdir": "LINK:foo", + }, + func(path string, d fs.DirEntry, err error) error { + requireNoError(t, err) + de := d.(fastwalk.DirEntry) + if _, ok := de.(fastwalk.DirEntry); !ok { + t.Errorf("%q: not a fastwalk.DirEntry: %T", path, de) + } + ls1, err := os.Lstat(path) + if err != nil { + t.Error(err) + } + ls2, err := de.Info() + if err != nil { + t.Error(err) + } + if !os.SameFile(ls1, ls2) { + t.Errorf("Info(%q) = %v; want: %v", path, ls2, ls1) + } + st1, err := os.Stat(path) + if err != nil { + t.Error(err) + } + st2, err := de.Stat() + if err != nil { + t.Error(err) + } + if !os.SameFile(st1, st2) { + t.Errorf("Stat(%q) = %v; want: %v", path, st2, st1) + } + if de.Name() != filepath.Base(path) { + t.Errorf("Name() = %q; want: %q", de.Name(), filepath.Base(path)) + } + if de.Type() != de.Type().Type() { + t.Errorf("%s: type mismatch got: %q want: %q", + path, de.Type(), de.Type().Type()) + } + return nil + }, + map[string]os.FileMode{ + "": os.ModeDir, + "/src": os.ModeDir, + "/src/bar": os.ModeDir, + "/src/bar/bar.go": os.ModeSymlink, + "/src/foo": os.ModeDir, + "/src/foo/foo.go": 0, + "/src/symdir": os.ModeSymlink, + }) +} + func TestFastWalk_SkipDir(t *testing.T) { test := func(t *testing.T, mode fastwalk.SortMode) { conf := fastwalk.DefaultConfig.Copy() @@ -485,6 +539,28 @@ func TestFastWalk_SkipDir(t *testing.T) { } } +// Test that returning SkipDir for the root directory aborts the walk +func TestFastWalk_SkipDir_Root(t *testing.T) { + want := map[string]os.FileMode{ + "": os.ModeDir, + } + conf := fastwalk.DefaultConfig.Copy() + conf.Sort = fastwalk.SortLexical // Needed for ordering + testFastWalkConf(t, conf, map[string]string{ + "a.go": "a", + "b.go": "b", + }, + func(path string, de fs.DirEntry, err error) error { + requireNoError(t, err) + return fastwalk.SkipDir + }, + want) + if len(want) != 1 { + t.Errorf("invalid number of files visited: wanted 1, got %v (%q)", + len(want), want) + } +} + func TestFastWalk_SkipFiles(t *testing.T) { mapKeys := func(m map[string]os.FileMode) []string { a := make([]string, 0, len(m)) @@ -542,6 +618,117 @@ func TestFastWalk_SkipFiles(t *testing.T) { } } +func TestFastWalk_SkipAll(t *testing.T) { + mapKeys := func(m map[string]os.FileMode) []string { + a := make([]string, 0, len(m)) + for k := range m { + a = append(a, k) + } + return a + } + + t.Run("Root", func(t *testing.T) { + want := map[string]os.FileMode{ + "": os.ModeDir, + } + conf := fastwalk.DefaultConfig.Copy() + conf.Sort = fastwalk.SortLexical // Needed for ordering + testFastWalkConf(t, conf, map[string]string{ + "a.go": "a", + "b.go": "b", + }, + func(path string, de fs.DirEntry, err error) error { + requireNoError(t, err) + return fastwalk.SkipAll + }, + want) + if len(want) != 1 { + t.Errorf("invalid number of files visited: wanted 1, got %v (%q)", + len(want), mapKeys(want)) + } + }) + + t.Run("File", func(t *testing.T) { + want := map[string]os.FileMode{ + "": os.ModeDir, + "/src": os.ModeDir, + "/src/a.go": 0, + } + conf := fastwalk.DefaultConfig.Copy() + conf.Sort = fastwalk.SortLexical // Needed for ordering + testFastWalkConf(t, conf, map[string]string{ + "a.go": "a", + "b.go": "b", + }, + func(path string, de fs.DirEntry, err error) error { + requireNoError(t, err) + if de.Name() == "a.go" { + return fastwalk.SkipAll + } + return nil + }, + want) + if len(want) != 3 { + t.Errorf("invalid number of files visited: wanted 3, got %v (%q)", + len(want), mapKeys(want)) + } + }) + + t.Run("Directory", func(t *testing.T) { + want := map[string]os.FileMode{ + "": os.ModeDir, + "/src": os.ModeDir, + "/src/dir1": os.ModeDir, + } + conf := fastwalk.DefaultConfig.Copy() + conf.Sort = fastwalk.SortDirsFirst // Needed for ordering + testFastWalkConf(t, conf, map[string]string{ + "dir1/a.go": "a", + "dir2/a.go": "a", + }, + func(path string, de fs.DirEntry, err error) error { + requireNoError(t, err) + if de.Name() == "dir1" { + return fastwalk.SkipAll + } + return nil + }, + want) + if len(want) != 3 { + t.Errorf("invalid number of files visited: wanted 3, got %v (%q)", + len(want), mapKeys(want)) + } + }) + + t.Run("Symlink", func(t *testing.T) { + want := map[string]os.FileMode{ + "": os.ModeDir, + "/src": os.ModeDir, + "/src/a.go": 0, + "/src/symdir": os.ModeSymlink, + } + conf := fastwalk.DefaultConfig.Copy() + conf.Sort = fastwalk.SortFilesFirst // Needed for ordering + testFastWalkConf(t, conf, map[string]string{ + "a.go": "a", + "foo/foo.go": "one", + "symdir": "LINK:foo", + }, + func(path string, de fs.DirEntry, err error) error { + requireNoError(t, err) + if de.Type()&fs.ModeSymlink != 0 { + return fastwalk.SkipAll + } + return nil + }, + want) + if len(want) != 4 { + t.Errorf("invalid number of files visited: wanted 4, got %v (%q)", + len(want), mapKeys(want)) + } + }) +} + func TestFastWalk_TraverseSymlink(t *testing.T) { testFastWalk(t, map[string]string{ "foo/foo.go": "one", @@ -1016,15 +1203,6 @@ func TestFastWalkJoinPaths(t *testing.T) { } } -func TestSkipAll(t *testing.T) { - err := fastwalk.Walk(nil, ".", func(path string, info fs.DirEntry, err error) error { - return fs.SkipAll - }) - if err != fs.SkipAll { - t.Error("Expected fs.SkipAll to be returned got:", err) - } -} - func BenchmarkSortModeString(b *testing.B) { var s string for i := 0; i < b.N; i++ { diff --git a/fastwalk_unix.go b/fastwalk_unix.go index 69f2e0b..7351a18 100644 --- a/fastwalk_unix.go +++ b/fastwalk_unix.go @@ -75,6 +75,9 @@ func (w *walker) readDir(dirName string) error { de := newUnixDirent(dirName, name, typ) if w.sortMode == SortNone { if err := w.onDirEnt(dirName, name, de); err != nil { + if err == SkipAll { + return nil + } if err == ErrSkipFiles { skipFiles = true continue @@ -97,6 +100,9 @@ func (w *walker) readDir(dirName string) error { continue } if err := w.onDirEnt(dirName, d.Name(), d); err != nil { + if err == SkipAll { + return nil + } if err != ErrSkipFiles { return err }