Skip to content

Commit

Permalink
Return error when requested and actual snapshot don't match
Browse files Browse the repository at this point in the history
And add unit tests
  • Loading branch information
jsafrane committed Feb 4, 2020
1 parent 9435266 commit a7bd72b
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 12 deletions.
29 changes: 17 additions & 12 deletions pkg/driver/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,24 @@ func (d *controllerService) CreateVolume(ctx context.Context, req *csi.CreateVol
}
}

snapshotID := ""
volumeSource := req.GetVolumeContentSource()
if volumeSource != nil {
if _, ok := volumeSource.GetType().(*csi.VolumeContentSource_Snapshot); !ok {
return nil, status.Error(codes.InvalidArgument, "Unsupported volumeContentSource type")
}
sourceSnapshot := volumeSource.GetSnapshot()
if sourceSnapshot == nil {
return nil, status.Error(codes.InvalidArgument, "Error retrieving snapshot from the volumeContentSource")
}
snapshotID = sourceSnapshot.GetSnapshotId()
}

// volume exists already
if disk != nil {
if disk.SnapshotID != snapshotID {
return nil, status.Errorf(codes.AlreadyExists, "Volume already exists, but was restored from a different snapshot than %s", snapshotID)
}
return newCreateVolumeResponse(disk), nil
}

Expand All @@ -177,18 +193,7 @@ func (d *controllerService) CreateVolume(ctx context.Context, req *csi.CreateVol
AvailabilityZone: zone,
Encrypted: isEncrypted,
KmsKeyID: kmsKeyID,
}

volumeSource := req.GetVolumeContentSource()
if volumeSource != nil {
if _, ok := volumeSource.GetType().(*csi.VolumeContentSource_Snapshot); !ok {
return nil, status.Error(codes.InvalidArgument, "Unsupported volumeContentSource type")
}
sourceSnapshot := volumeSource.GetSnapshot()
if sourceSnapshot == nil {
return nil, status.Error(codes.InvalidArgument, "Error retrieving snapshot from the volumeContentSource")
}
opts.SnapshotID = sourceSnapshot.GetSnapshotId()
SnapshotID: snapshotID,
}

disk, err = d.cloud.CreateDisk(ctx, volName, opts)
Expand Down
153 changes: 153 additions & 0 deletions pkg/driver/controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,159 @@ func TestCreateVolume(t *testing.T) {
}
},
},
{
name: "restore snapshot",
testFunc: func(t *testing.T) {
req := &csi.CreateVolumeRequest{
Name: "random-vol-name",
CapacityRange: stdCapRange,
VolumeCapabilities: stdVolCap,
Parameters: nil,
VolumeContentSource: &csi.VolumeContentSource{
Type: &csi.VolumeContentSource_Snapshot{
Snapshot: &csi.VolumeContentSource_SnapshotSource{
SnapshotId: "snapshot-id",
},
},
},
}

ctx := context.Background()

mockDisk := &cloud.Disk{
VolumeID: req.Name,
AvailabilityZone: expZone,
CapacityGiB: util.BytesToGiB(stdVolSize),
SnapshotID: "snapshot-id",
}

mockCtl := gomock.NewController(t)
defer mockCtl.Finish()

mockCloud := mocks.NewMockCloud(mockCtl)
mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(stdVolSize)).Return(nil, cloud.ErrNotFound)
mockCloud.EXPECT().CreateDisk(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Any()).Return(mockDisk, nil)

awsDriver := controllerService{
cloud: mockCloud,
driverOptions: &DriverOptions{},
}

rsp, err := awsDriver.CreateVolume(ctx, req)
if err != nil {
srvErr, ok := status.FromError(err)
if !ok {
t.Fatalf("Could not get error status code from error: %v", srvErr)
}
t.Fatalf("Unexpected error: %v", srvErr.Code())
}

snapshotID := ""
if rsp.Volume != nil && rsp.Volume.ContentSource != nil && rsp.Volume.ContentSource.GetSnapshot() != nil {
snapshotID = rsp.Volume.ContentSource.GetSnapshot().SnapshotId
}
if rsp.Volume.ContentSource.GetSnapshot().SnapshotId != "snapshot-id" {
t.Errorf("Unexpected snapshot ID: %q", snapshotID)
}
},
},
{
name: "restore snapshot, volume already exists",
testFunc: func(t *testing.T) {
req := &csi.CreateVolumeRequest{
Name: "random-vol-name",
CapacityRange: stdCapRange,
VolumeCapabilities: stdVolCap,
Parameters: nil,
VolumeContentSource: &csi.VolumeContentSource{
Type: &csi.VolumeContentSource_Snapshot{
Snapshot: &csi.VolumeContentSource_SnapshotSource{
SnapshotId: "snapshot-id",
},
},
},
}

ctx := context.Background()

mockDisk := &cloud.Disk{
VolumeID: req.Name,
AvailabilityZone: expZone,
CapacityGiB: util.BytesToGiB(stdVolSize),
SnapshotID: "snapshot-id",
}

mockCtl := gomock.NewController(t)
defer mockCtl.Finish()

mockCloud := mocks.NewMockCloud(mockCtl)
mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(stdVolSize)).Return(mockDisk, nil)

awsDriver := controllerService{
cloud: mockCloud,
driverOptions: &DriverOptions{},
}

rsp, err := awsDriver.CreateVolume(ctx, req)
if err != nil {
srvErr, ok := status.FromError(err)
if !ok {
t.Fatalf("Could not get error status code from error: %v", srvErr)
}
t.Fatalf("Unexpected error: %v", srvErr.Code())
}

snapshotID := ""
if rsp.Volume != nil && rsp.Volume.ContentSource != nil && rsp.Volume.ContentSource.GetSnapshot() != nil {
snapshotID = rsp.Volume.ContentSource.GetSnapshot().SnapshotId
}
if rsp.Volume.ContentSource.GetSnapshot().SnapshotId != "snapshot-id" {
t.Errorf("Unexpected snapshot ID: %q", snapshotID)
}
},
},
{
name: "restore snapshot, volume already exists with different snapshot ID",
testFunc: func(t *testing.T) {
req := &csi.CreateVolumeRequest{
Name: "random-vol-name",
CapacityRange: stdCapRange,
VolumeCapabilities: stdVolCap,
Parameters: nil,
VolumeContentSource: &csi.VolumeContentSource{
Type: &csi.VolumeContentSource_Snapshot{
Snapshot: &csi.VolumeContentSource_SnapshotSource{
SnapshotId: "snapshot-id",
},
},
},
}

ctx := context.Background()

mockDisk := &cloud.Disk{
VolumeID: req.Name,
AvailabilityZone: expZone,
CapacityGiB: util.BytesToGiB(stdVolSize),
SnapshotID: "another-snapshot-id",
}

mockCtl := gomock.NewController(t)
defer mockCtl.Finish()

mockCloud := mocks.NewMockCloud(mockCtl)
mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(stdVolSize)).Return(mockDisk, nil)

awsDriver := controllerService{
cloud: mockCloud,
driverOptions: &DriverOptions{},
}

if _, err := awsDriver.CreateVolume(ctx, req); err == nil {
t.Error("CreateVolume with invalid SnapshotID unexpectedly succeeded")
}
},
},
{
name: "fail no name",
testFunc: func(t *testing.T) {
Expand Down

0 comments on commit a7bd72b

Please sign in to comment.