From 356672cb56fd5a0eed11e5089ac824c7ab09ffac Mon Sep 17 00:00:00 2001 From: Ye Sijun Date: Fri, 24 Jun 2022 15:45:37 +0800 Subject: [PATCH] refactor: reduce duplicate code Signed-off-by: Ye Sijun (cherry picked from commit 1ab42be15da2578d34f21745563b4e3efc3af455) > Conflicts: > oci/spec_opts.go > oci/spec_opts_linux_test.go Signed-off-by: Akihiro Suda --- oci/spec_opts.go | 53 ++++++---------- oci/spec_opts_linux_test.go | 119 +++++++++++++++++++++++++++++++++++- 2 files changed, 136 insertions(+), 36 deletions(-) diff --git a/oci/spec_opts.go b/oci/spec_opts.go index c23455ce2a15..70b24eabe73f 100644 --- a/oci/spec_opts.go +++ b/oci/spec_opts.go @@ -618,11 +618,8 @@ func WithUIDGID(uid, gid uint32) SpecOpts { func WithUserID(uid uint32) SpecOpts { return func(ctx context.Context, client Client, c *containers.Container, s *Spec) (err error) { setProcess(s) - if c.Snapshotter == "" && c.SnapshotKey == "" { - if !isRootfsAbs(s.Root.Path) { - return errors.Errorf("rootfs absolute path is required") - } - user, err := UserFromPath(s.Root.Path, func(u user.User) bool { + setUser := func(root string) error { + user, err := UserFromPath(root, func(u user.User) bool { return u.Uid == int(uid) }) if err != nil { @@ -634,7 +631,12 @@ func WithUserID(uid uint32) SpecOpts { } s.Process.User.UID, s.Process.User.GID = uint32(user.Uid), uint32(user.Gid) return nil - + } + if c.Snapshotter == "" && c.SnapshotKey == "" { + if !isRootfsAbs(s.Root.Path) { + return errors.New("rootfs absolute path is required") + } + return setUser(s.Root.Path) } if c.Snapshotter == "" { return errors.Errorf("no snapshotter set for container") @@ -649,20 +651,7 @@ func WithUserID(uid uint32) SpecOpts { } mounts = tryReadonlyMounts(mounts) - return mount.WithTempMount(ctx, mounts, func(root string) error { - user, err := UserFromPath(root, func(u user.User) bool { - return u.Uid == int(uid) - }) - if err != nil { - if os.IsNotExist(err) || err == ErrNoUsersFound { - s.Process.User.UID, s.Process.User.GID = uid, 0 - return nil - } - return err - } - s.Process.User.UID, s.Process.User.GID = uint32(user.Uid), uint32(user.Gid) - return nil - }) + return mount.WithTempMount(ctx, mounts, setUser) } } @@ -674,11 +663,8 @@ func WithUsername(username string) SpecOpts { return func(ctx context.Context, client Client, c *containers.Container, s *Spec) (err error) { setProcess(s) if s.Linux != nil { - if c.Snapshotter == "" && c.SnapshotKey == "" { - if !isRootfsAbs(s.Root.Path) { - return errors.Errorf("rootfs absolute path is required") - } - user, err := UserFromPath(s.Root.Path, func(u user.User) bool { + setUser := func(root string) error { + user, err := UserFromPath(root, func(u user.User) bool { return u.Name == username }) if err != nil { @@ -687,6 +673,12 @@ func WithUsername(username string) SpecOpts { s.Process.User.UID, s.Process.User.GID = uint32(user.Uid), uint32(user.Gid) return nil } + if c.Snapshotter == "" && c.SnapshotKey == "" { + if !isRootfsAbs(s.Root.Path) { + return errors.New("rootfs absolute path is required") + } + return setUser(s.Root.Path) + } if c.Snapshotter == "" { return errors.Errorf("no snapshotter set for container") } @@ -700,16 +692,7 @@ func WithUsername(username string) SpecOpts { } mounts = tryReadonlyMounts(mounts) - return mount.WithTempMount(ctx, mounts, func(root string) error { - user, err := UserFromPath(root, func(u user.User) bool { - return u.Name == username - }) - if err != nil { - return err - } - s.Process.User.UID, s.Process.User.GID = uint32(user.Uid), uint32(user.Gid) - return nil - }) + return mount.WithTempMount(ctx, mounts, setUser) } else if s.Windows != nil { s.Process.User.Username = username } else { diff --git a/oci/spec_opts_linux_test.go b/oci/spec_opts_linux_test.go index 32bcc7bfa479..14d639166aa2 100644 --- a/oci/spec_opts_linux_test.go +++ b/oci/spec_opts_linux_test.go @@ -18,6 +18,7 @@ package oci import ( "context" + "fmt" "io/ioutil" "os" "path/filepath" @@ -31,6 +32,123 @@ import ( "golang.org/x/sys/unix" ) +// nolint:gosec +func TestWithUserID(t *testing.T) { + t.Parallel() + + expectedPasswd := `root:x:0:0:root:/root:/bin/ash +guest:x:405:100:guest:/dev/null:/sbin/nologin +` + td := t.TempDir() + apply := fstest.Apply( + fstest.CreateDir("/etc", 0777), + fstest.CreateFile("/etc/passwd", []byte(expectedPasswd), 0777), + ) + if err := apply.Apply(td); err != nil { + t.Fatalf("failed to apply: %v", err) + } + c := containers.Container{ID: t.Name()} + testCases := []struct { + userID uint32 + expectedUID uint32 + expectedGID uint32 + }{ + { + userID: 0, + expectedUID: 0, + expectedGID: 0, + }, + { + userID: 405, + expectedUID: 405, + expectedGID: 100, + }, + { + userID: 1000, + expectedUID: 1000, + expectedGID: 0, + }, + } + for _, testCase := range testCases { + t.Run(fmt.Sprintf("user %d", testCase.userID), func(t *testing.T) { + t.Parallel() + s := Spec{ + Version: specs.Version, + Root: &specs.Root{ + Path: td, + }, + Linux: &specs.Linux{}, + } + err := WithUserID(testCase.userID)(context.Background(), nil, &c, &s) + assert.NoError(t, err) + assert.Equal(t, testCase.expectedUID, s.Process.User.UID) + assert.Equal(t, testCase.expectedGID, s.Process.User.GID) + }) + } +} + +// nolint:gosec +func TestWithUsername(t *testing.T) { + t.Parallel() + + expectedPasswd := `root:x:0:0:root:/root:/bin/ash +guest:x:405:100:guest:/dev/null:/sbin/nologin +` + td := t.TempDir() + apply := fstest.Apply( + fstest.CreateDir("/etc", 0777), + fstest.CreateFile("/etc/passwd", []byte(expectedPasswd), 0777), + ) + if err := apply.Apply(td); err != nil { + t.Fatalf("failed to apply: %v", err) + } + c := containers.Container{ID: t.Name()} + testCases := []struct { + user string + expectedUID uint32 + expectedGID uint32 + err string + }{ + { + user: "root", + expectedUID: 0, + expectedGID: 0, + }, + { + user: "guest", + expectedUID: 405, + expectedGID: 100, + }, + { + user: "1000", + err: "no users found", + }, + { + user: "unknown", + err: "no users found", + }, + } + for _, testCase := range testCases { + t.Run(testCase.user, func(t *testing.T) { + t.Parallel() + s := Spec{ + Version: specs.Version, + Root: &specs.Root{ + Path: td, + }, + Linux: &specs.Linux{}, + } + err := WithUsername(testCase.user)(context.Background(), nil, &c, &s) + if err != nil { + assert.EqualError(t, err, testCase.err) + } + assert.Equal(t, testCase.expectedUID, s.Process.User.UID) + assert.Equal(t, testCase.expectedGID, s.Process.User.GID) + }) + } + +} + // nolint:gosec func TestWithAdditionalGIDs(t *testing.T) { t.Parallel() @@ -55,7 +173,6 @@ sys:x:3:root,bin,adm c := containers.Container{ID: t.Name()} testCases := []struct { - name string user string expected []uint32 }{