Skip to content

Commit

Permalink
Merge pull request #446 from jsafrane/snapshot-restore-source
Browse files Browse the repository at this point in the history
Fix CreateVolumeResponse after snapshot restore
  • Loading branch information
k8s-ci-robot authored Feb 4, 2020
2 parents d11901e + a7bd72b commit db95482
Show file tree
Hide file tree
Showing 4 changed files with 185 additions and 13 deletions.
4 changes: 3 additions & 1 deletion pkg/cloud/cloud.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ type Disk struct {
VolumeID string
CapacityGiB int64
AvailabilityZone string
SnapshotID string
}

// DiskOptions represents parameters to create an EBS volume
Expand Down Expand Up @@ -318,7 +319,7 @@ func (c *cloud) CreateDisk(ctx context.Context, volumeName string, diskOptions *
return nil, fmt.Errorf("failed to get an available volume in EC2: %v", err)
}

return &Disk{CapacityGiB: size, VolumeID: volumeID, AvailabilityZone: zone}, nil
return &Disk{CapacityGiB: size, VolumeID: volumeID, AvailabilityZone: zone, SnapshotID: snapshotID}, nil
}

func (c *cloud) DeleteDisk(ctx context.Context, volumeID string) (bool, error) {
Expand Down Expand Up @@ -485,6 +486,7 @@ func (c *cloud) GetDiskByName(ctx context.Context, name string, capacityBytes in
VolumeID: aws.StringValue(volume.VolumeId),
CapacityGiB: volSizeBytes,
AvailabilityZone: aws.StringValue(volume.AvailabilityZone),
SnapshotID: aws.StringValue(volume.SnapshotId),
}, nil
}

Expand Down
40 changes: 28 additions & 12 deletions pkg/driver/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,24 @@ func (d *controllerService) CreateVolume(ctx context.Context, req *csi.CreateVol
}
}

snapshotID := ""
volumeSource := req.GetVolumeContentSource()
if volumeSource != nil {
if _, ok := volumeSource.GetType().(*csi.VolumeContentSource_Snapshot); !ok {
return nil, status.Error(codes.InvalidArgument, "Unsupported volumeContentSource type")
}
sourceSnapshot := volumeSource.GetSnapshot()
if sourceSnapshot == nil {
return nil, status.Error(codes.InvalidArgument, "Error retrieving snapshot from the volumeContentSource")
}
snapshotID = sourceSnapshot.GetSnapshotId()
}

// volume exists already
if disk != nil {
if disk.SnapshotID != snapshotID {
return nil, status.Errorf(codes.AlreadyExists, "Volume already exists, but was restored from a different snapshot than %s", snapshotID)
}
return newCreateVolumeResponse(disk), nil
}

Expand All @@ -177,18 +193,7 @@ func (d *controllerService) CreateVolume(ctx context.Context, req *csi.CreateVol
AvailabilityZone: zone,
Encrypted: isEncrypted,
KmsKeyID: kmsKeyID,
}

volumeSource := req.GetVolumeContentSource()
if volumeSource != nil {
if _, ok := volumeSource.GetType().(*csi.VolumeContentSource_Snapshot); !ok {
return nil, status.Error(codes.InvalidArgument, "Unsupported volumeContentSource type")
}
sourceSnapshot := volumeSource.GetSnapshot()
if sourceSnapshot == nil {
return nil, status.Error(codes.InvalidArgument, "Error retrieving snapshot from the volumeContentSource")
}
opts.SnapshotID = sourceSnapshot.GetSnapshotId()
SnapshotID: snapshotID,
}

disk, err = d.cloud.CreateDisk(ctx, volName, opts)
Expand Down Expand Up @@ -513,6 +518,16 @@ func pickAvailabilityZone(requirement *csi.TopologyRequirement) string {
}

func newCreateVolumeResponse(disk *cloud.Disk) *csi.CreateVolumeResponse {
var src *csi.VolumeContentSource
if disk.SnapshotID != "" {
src = &csi.VolumeContentSource{
Type: &csi.VolumeContentSource_Snapshot{
Snapshot: &csi.VolumeContentSource_SnapshotSource{
SnapshotId: disk.SnapshotID,
},
},
}
}
return &csi.CreateVolumeResponse{
Volume: &csi.Volume{
VolumeId: disk.VolumeID,
Expand All @@ -523,6 +538,7 @@ func newCreateVolumeResponse(disk *cloud.Disk) *csi.CreateVolumeResponse {
Segments: map[string]string{TopologyKey: disk.AvailabilityZone},
},
},
ContentSource: src,
},
}
}
Expand Down
153 changes: 153 additions & 0 deletions pkg/driver/controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,159 @@ func TestCreateVolume(t *testing.T) {
}
},
},
{
name: "restore snapshot",
testFunc: func(t *testing.T) {
req := &csi.CreateVolumeRequest{
Name: "random-vol-name",
CapacityRange: stdCapRange,
VolumeCapabilities: stdVolCap,
Parameters: nil,
VolumeContentSource: &csi.VolumeContentSource{
Type: &csi.VolumeContentSource_Snapshot{
Snapshot: &csi.VolumeContentSource_SnapshotSource{
SnapshotId: "snapshot-id",
},
},
},
}

ctx := context.Background()

mockDisk := &cloud.Disk{
VolumeID: req.Name,
AvailabilityZone: expZone,
CapacityGiB: util.BytesToGiB(stdVolSize),
SnapshotID: "snapshot-id",
}

mockCtl := gomock.NewController(t)
defer mockCtl.Finish()

mockCloud := mocks.NewMockCloud(mockCtl)
mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(stdVolSize)).Return(nil, cloud.ErrNotFound)
mockCloud.EXPECT().CreateDisk(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Any()).Return(mockDisk, nil)

awsDriver := controllerService{
cloud: mockCloud,
driverOptions: &DriverOptions{},
}

rsp, err := awsDriver.CreateVolume(ctx, req)
if err != nil {
srvErr, ok := status.FromError(err)
if !ok {
t.Fatalf("Could not get error status code from error: %v", srvErr)
}
t.Fatalf("Unexpected error: %v", srvErr.Code())
}

snapshotID := ""
if rsp.Volume != nil && rsp.Volume.ContentSource != nil && rsp.Volume.ContentSource.GetSnapshot() != nil {
snapshotID = rsp.Volume.ContentSource.GetSnapshot().SnapshotId
}
if rsp.Volume.ContentSource.GetSnapshot().SnapshotId != "snapshot-id" {
t.Errorf("Unexpected snapshot ID: %q", snapshotID)
}
},
},
{
name: "restore snapshot, volume already exists",
testFunc: func(t *testing.T) {
req := &csi.CreateVolumeRequest{
Name: "random-vol-name",
CapacityRange: stdCapRange,
VolumeCapabilities: stdVolCap,
Parameters: nil,
VolumeContentSource: &csi.VolumeContentSource{
Type: &csi.VolumeContentSource_Snapshot{
Snapshot: &csi.VolumeContentSource_SnapshotSource{
SnapshotId: "snapshot-id",
},
},
},
}

ctx := context.Background()

mockDisk := &cloud.Disk{
VolumeID: req.Name,
AvailabilityZone: expZone,
CapacityGiB: util.BytesToGiB(stdVolSize),
SnapshotID: "snapshot-id",
}

mockCtl := gomock.NewController(t)
defer mockCtl.Finish()

mockCloud := mocks.NewMockCloud(mockCtl)
mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(stdVolSize)).Return(mockDisk, nil)

awsDriver := controllerService{
cloud: mockCloud,
driverOptions: &DriverOptions{},
}

rsp, err := awsDriver.CreateVolume(ctx, req)
if err != nil {
srvErr, ok := status.FromError(err)
if !ok {
t.Fatalf("Could not get error status code from error: %v", srvErr)
}
t.Fatalf("Unexpected error: %v", srvErr.Code())
}

snapshotID := ""
if rsp.Volume != nil && rsp.Volume.ContentSource != nil && rsp.Volume.ContentSource.GetSnapshot() != nil {
snapshotID = rsp.Volume.ContentSource.GetSnapshot().SnapshotId
}
if rsp.Volume.ContentSource.GetSnapshot().SnapshotId != "snapshot-id" {
t.Errorf("Unexpected snapshot ID: %q", snapshotID)
}
},
},
{
name: "restore snapshot, volume already exists with different snapshot ID",
testFunc: func(t *testing.T) {
req := &csi.CreateVolumeRequest{
Name: "random-vol-name",
CapacityRange: stdCapRange,
VolumeCapabilities: stdVolCap,
Parameters: nil,
VolumeContentSource: &csi.VolumeContentSource{
Type: &csi.VolumeContentSource_Snapshot{
Snapshot: &csi.VolumeContentSource_SnapshotSource{
SnapshotId: "snapshot-id",
},
},
},
}

ctx := context.Background()

mockDisk := &cloud.Disk{
VolumeID: req.Name,
AvailabilityZone: expZone,
CapacityGiB: util.BytesToGiB(stdVolSize),
SnapshotID: "another-snapshot-id",
}

mockCtl := gomock.NewController(t)
defer mockCtl.Finish()

mockCloud := mocks.NewMockCloud(mockCtl)
mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(stdVolSize)).Return(mockDisk, nil)

awsDriver := controllerService{
cloud: mockCloud,
driverOptions: &DriverOptions{},
}

if _, err := awsDriver.CreateVolume(ctx, req); err == nil {
t.Error("CreateVolume with invalid SnapshotID unexpectedly succeeded")
}
},
},
{
name: "fail no name",
testFunc: func(t *testing.T) {
Expand Down
1 change: 1 addition & 0 deletions tests/sanity/fake_cloud_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ func (c *fakeCloudProvider) CreateDisk(ctx context.Context, volumeName string, d
VolumeID: fmt.Sprintf("vol-%d", r1.Uint64()),
CapacityGiB: util.BytesToGiB(diskOptions.CapacityBytes),
AvailabilityZone: diskOptions.AvailabilityZone,
SnapshotID: diskOptions.SnapshotID,
},
tags: diskOptions.Tags,
}
Expand Down

0 comments on commit db95482

Please sign in to comment.