Skip to content

Commit

Permalink
Add flag to skip bucket access check
Browse files Browse the repository at this point in the history
1. Add support for PV volume attribute  'skipCSIBucketAccessCheck' to
   skip CSI bucket access checks
2. Restore the bucket access check logic that was removed as part of GoogleCloudPlatform#255
  • Loading branch information
saikat-royc committed May 24, 2024
1 parent a07268c commit 01475a0
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 13 deletions.
45 changes: 44 additions & 1 deletion pkg/csi_driver/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.
package driver

import (
"fmt"
"os"
"strings"
"time"
Expand Down Expand Up @@ -110,10 +111,11 @@ func (s *nodeServer) NodePublishVolume(ctx context.Context, req *csi.NodePublish
fuseMountOptions = joinMountOptions(fuseMountOptions, capMount.GetMountFlags())
}

fuseMountOptions, err := parseVolumeAttributes(fuseMountOptions, vc)
fuseMountOptions, skipBucketAccessCheck, err := parseVolumeAttributes(fuseMountOptions, vc)
if err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
klog.V(4).Infof("NodePublishVolume on volume %q has skipBucketAccessCheck %t", bucketName, skipBucketAccessCheck)

if vc[VolumeContextKeyEphemeral] == TrueStr {
bucketName = vc[VolumeContextKeyBucketName]
Expand All @@ -137,6 +139,32 @@ func (s *nodeServer) NodePublishVolume(ctx context.Context, req *csi.NodePublish
}
defer s.volumeLocks.Release(targetPath)

// Check if the given Service Account has the access to the GCS bucket, and the bucket exists.
if bucketName != "_" && !skipBucketAccessCheck {
storageService, err := s.prepareStorageService(ctx, req.GetVolumeContext())
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "failed to prepare storage service: %v", err)
}
defer storageService.Close()

if exist, err := storageService.CheckBucketExists(ctx, &storage.ServiceBucket{Name: bucketName}); !exist {
code := codes.Internal
if storage.IsNotExistErr(err) {
code = codes.NotFound
}

if storage.IsPermissionDeniedErr(err) {
code = codes.PermissionDenied
}

if storage.IsCanceledErr(err) {
code = codes.Aborted
}

return nil, status.Errorf(code, "failed to get GCS bucket %q: %v", bucketName, err)
}
}

// Check if the sidecar container was injected into the Pod
pod, err := s.k8sClients.GetPod(ctx, vc[VolumeContextKeyPodNamespace], vc[VolumeContextKeyPodName])
if err != nil {
Expand Down Expand Up @@ -198,6 +226,10 @@ func (s *nodeServer) NodePublishVolume(ctx context.Context, req *csi.NodePublish
code = codes.PermissionDenied
}

if strings.Contains(errMsgStr, "bucket doesn't exist") {
code = codes.NotFound
}

return nil, status.Errorf(code, "the sidecar container failed with error: %v", errMsgStr)
}

Expand Down Expand Up @@ -322,3 +354,14 @@ func (s *nodeServer) isDirMounted(targetPath string) (bool, error) {

return false, nil
}

// prepareStorageService prepares the GCS Storage Service using the Kubernetes Service Account from VolumeContext.
func (s *nodeServer) prepareStorageService(ctx context.Context, vc map[string]string) (storage.Service, error) {
ts := s.driver.config.TokenManager.GetTokenSourceFromK8sServiceAccount(vc[VolumeContextKeyPodNamespace], vc[VolumeContextKeyServiceAccountName], vc[VolumeContextKeyServiceAccountToken])
storageService, err := s.storageServiceManager.SetupService(ctx, ts)
if err != nil {
return nil, fmt.Errorf("storage service manager failed to setup service: %w", err)
}

return storageService, nil
}
20 changes: 14 additions & 6 deletions pkg/csi_driver/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ const (
VolumeContextKeyMetadataTypeCacheCapacity = "metadataTypeCacheCapacity"
VolumeContextKeyMetadataCacheTTLSeconds = "metadataCacheTTLSeconds"
VolumeContextKeyGcsfuseLoggingSeverity = "gcsfuseLoggingSeverity"
VolumeContextKeySkipCSIBucketAccessCheck = "skipCSIBucketAccessCheck"
)

func NewVolumeCapabilityAccessMode(mode csi.VolumeCapability_AccessMode_Mode) *csi.VolumeCapability_AccessMode {
Expand Down Expand Up @@ -161,14 +162,15 @@ var volumeAttributesToMountOptionsMapping = map[string]string{
VolumeContextKeyMetadataTypeCacheCapacity: "metadata-cache:type-cache-max-size-mb:",
VolumeContextKeyMetadataCacheTTLSeconds: "metadata-cache:ttl-secs:",
VolumeContextKeyGcsfuseLoggingSeverity: "logging:severity:",
VolumeContextKeySkipCSIBucketAccessCheck: "skip-csi-bucket-access-check",
}

// parseVolumeAttributes parses volume attributes and convert them to gcsfuse mount options.
func parseVolumeAttributes(fuseMountOptions []string, volumeContext map[string]string) ([]string, error) {
func parseVolumeAttributes(fuseMountOptions []string, volumeContext map[string]string) ([]string, bool, error) {
if mountOptions, ok := volumeContext[VolumeContextKeyMountOptions]; ok {
fuseMountOptions = joinMountOptions(fuseMountOptions, strings.Split(mountOptions, ","))
}

skipCSIBucketAccessCheck := false
for volumeAttribute, mountOption := range volumeAttributesToMountOptionsMapping {
value, ok := volumeContext[volumeAttribute]
if !ok {
Expand All @@ -183,7 +185,7 @@ func parseVolumeAttributes(fuseMountOptions []string, volumeContext map[string]s
case VolumeContextKeyFileCacheCapacity, VolumeContextKeyMetadataStatCacheCapacity, VolumeContextKeyMetadataTypeCacheCapacity:
quantity, err := resource.ParseQuantity(value)
if err != nil {
return nil, fmt.Errorf("volume attribute %v only accepts a valid Quantity value, got %q, error: %w", volumeAttribute, value, err)
return nil, skipCSIBucketAccessCheck, fmt.Errorf("volume attribute %v only accepts a valid Quantity value, got %q, error: %w", volumeAttribute, value, err)
}

megabytes := quantity.Value()
Expand All @@ -203,7 +205,7 @@ func parseVolumeAttributes(fuseMountOptions []string, volumeContext map[string]s
if boolVal, err := strconv.ParseBool(value); err == nil {
mountOptionWithValue = mountOption + strconv.FormatBool(boolVal)
} else {
return nil, fmt.Errorf("volume attribute %v only accepts a valid bool value, got %q", volumeAttribute, value)
return nil, skipCSIBucketAccessCheck, fmt.Errorf("volume attribute %v only accepts a valid bool value, got %q", volumeAttribute, value)
}

// parse int volume attributes
Expand All @@ -215,17 +217,23 @@ func parseVolumeAttributes(fuseMountOptions []string, volumeContext map[string]s

mountOptionWithValue = mountOption + strconv.Itoa(intVal)
} else {
return nil, fmt.Errorf("volume attribute %v only accepts a valid int value, got %q", volumeAttribute, value)
return nil, skipCSIBucketAccessCheck, fmt.Errorf("volume attribute %v only accepts a valid int value, got %q", volumeAttribute, value)
}

case VolumeContextKeySkipCSIBucketAccessCheck:
if boolVal, err := strconv.ParseBool(value); err == nil {
skipCSIBucketAccessCheck = boolVal
} else {
return nil, skipCSIBucketAccessCheck, fmt.Errorf("volume attribute %v only accepts a valid bool value, got %q", volumeAttribute, value)
}
default:
mountOptionWithValue = mountOption + value
}

fuseMountOptions = joinMountOptions(fuseMountOptions, []string{mountOptionWithValue})
}

return fuseMountOptions, nil
return fuseMountOptions, skipCSIBucketAccessCheck, nil
}

func putExitFile(pod *corev1.Pod, emptyDirBasePath string) error {
Expand Down
39 changes: 33 additions & 6 deletions pkg/csi_driver/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,11 @@ func TestParseVolumeAttributes(t *testing.T) {
t.Run("parsing volume attributes into mount options", func(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
volumeContext map[string]string
expectedMountOptions []string
expectedErr bool
name string
volumeContext map[string]string
expectedMountOptions []string
expectedSkipBucketAccessCheck bool
expectedErr bool
}{
{
name: "should return correct fileCacheCapacity 1",
Expand Down Expand Up @@ -302,19 +303,45 @@ func TestParseVolumeAttributes(t *testing.T) {
volumeAttributesToMountOptionsMapping[VolumeContextKeyMetadataCacheTTLSeconds] + "3600",
},
},
{
name: "should return correct mount options, and skip bucket access check flag",
expectedSkipBucketAccessCheck: true,
volumeContext: map[string]string{
VolumeContextKeyMountOptions: "implicit-dirs,uid=1001",
VolumeContextKeyGcsfuseLoggingSeverity: "trace",
VolumeContextKeyFileCacheCapacity: "500Gi",
VolumeContextKeyFileCacheForRangeRead: "true",
VolumeContextKeyMetadataStatCacheCapacity: "-100",
VolumeContextKeyMetadataTypeCacheCapacity: "0",
VolumeContextKeyMetadataCacheTTLSeconds: "3600",
VolumeContextKeySkipCSIBucketAccessCheck: "true",
},
expectedMountOptions: []string{
"implicit-dirs",
"uid=1001",
volumeAttributesToMountOptionsMapping[VolumeContextKeyGcsfuseLoggingSeverity] + "trace",
volumeAttributesToMountOptionsMapping[VolumeContextKeyFileCacheCapacity] + "512000",
volumeAttributesToMountOptionsMapping[VolumeContextKeyFileCacheForRangeRead] + "true",
volumeAttributesToMountOptionsMapping[VolumeContextKeyMetadataStatCacheCapacity] + "-1",
volumeAttributesToMountOptionsMapping[VolumeContextKeyMetadataTypeCacheCapacity] + "0",
volumeAttributesToMountOptionsMapping[VolumeContextKeyMetadataCacheTTLSeconds] + "3600",
},
},
}

for _, tc := range testCases {
t.Logf("test case: %s", tc.name)
output, err := parseVolumeAttributes([]string{}, tc.volumeContext)

output, skipCSIBucketAccessCheck, err := parseVolumeAttributes([]string{}, tc.volumeContext)
if (err != nil) != tc.expectedErr {
t.Errorf("Got error %v, but expected error %v", err, tc.expectedErr)
}

if tc.expectedErr {
continue
}
if tc.expectedSkipBucketAccessCheck != skipCSIBucketAccessCheck {
t.Errorf("Got skipBucketAccessCheck %v, but expected %v", skipCSIBucketAccessCheck, tc.expectedSkipBucketAccessCheck)
}

less := func(a, b string) bool { return a > b }
if diff := cmp.Diff(output, tc.expectedMountOptions, cmpopts.SortSlices(less)); diff != "" {
Expand Down

0 comments on commit 01475a0

Please sign in to comment.