diff --git a/velero-plugin-for-microsoft-azure/volume_snapshotter.go b/velero-plugin-for-microsoft-azure/volume_snapshotter.go index 8a9c548..5ed2055 100644 --- a/velero-plugin-for-microsoft-azure/volume_snapshotter.go +++ b/velero-plugin-for-microsoft-azure/volume_snapshotter.go @@ -358,8 +358,11 @@ func getComputeResourceName(subscription, resourceGroup, resource, name string) return fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Compute/%s/%s", subscription, resourceGroup, resource, name) } -var snapshotURIRegexp = regexp.MustCompile( - `^\/subscriptions\/(?P.*)\/resourceGroups\/(?P.*)\/providers\/Microsoft.Compute\/snapshots\/(?P.*)$`) +var ( + snapshotURIRegexp = regexp.MustCompile( + `^\/subscriptions\/(?P.*)\/resourceGroups\/(?P.*)\/providers\/Microsoft.Compute\/snapshots\/(?P.*)$`) + diskURIRegexp = regexp.MustCompile(`\/Microsoft.Compute\/disks\/.*$`) +) // parseFullSnapshotName takes a fully-qualified snapshot name and returns // a snapshot identifier or an error if the snapshot name does not match the @@ -396,6 +399,10 @@ func (b *VolumeSnapshotter) GetVolumeID(unstructuredPV runtime.Unstructured) (st return "", errors.WithStack(err) } + if pv.Spec.CSI != nil && pv.Spec.CSI.Driver == "disk.csi.azure.com" { + return strings.TrimPrefix(diskURIRegexp.FindString(pv.Spec.CSI.VolumeHandle), "/Microsoft.Compute/disks/"), nil + } + if pv.Spec.AzureDisk == nil { return "", nil } @@ -413,12 +420,19 @@ func (b *VolumeSnapshotter) SetVolumeID(unstructuredPV runtime.Unstructured, vol return nil, errors.WithStack(err) } - if pv.Spec.AzureDisk == nil { - return nil, errors.New("spec.azureDisk not found") - } + if pv.Spec.CSI != nil && pv.Spec.CSI.Driver == "disk.csi.azure.com" { + pv.Spec.CSI.VolumeHandle = getComputeResourceName(b.disksSubscription, b.disksResourceGroup, disksResource, volumeID) + if _, exist := pv.Spec.CSI.VolumeAttributes["csi.storage.k8s.io/pv/name"]; exist { + pv.Spec.CSI.VolumeAttributes["csi.storage.k8s.io/pv/name"] = volumeID + } + } else { + if pv.Spec.AzureDisk == nil { + return nil, errors.New("spec.azureDisk not found") + } - pv.Spec.AzureDisk.DiskName = volumeID - pv.Spec.AzureDisk.DataDiskURI = getComputeResourceName(b.disksSubscription, b.disksResourceGroup, disksResource, volumeID) + pv.Spec.AzureDisk.DiskName = volumeID + pv.Spec.AzureDisk.DataDiskURI = getComputeResourceName(b.disksSubscription, b.disksResourceGroup, disksResource, volumeID) + } res, err := runtime.DefaultUnstructuredConverter.ToUnstructured(pv) if err != nil { diff --git a/velero-plugin-for-microsoft-azure/volume_snapshotter_test.go b/velero-plugin-for-microsoft-azure/volume_snapshotter_test.go index 1f909de..857ee69 100644 --- a/velero-plugin-for-microsoft-azure/volume_snapshotter_test.go +++ b/velero-plugin-for-microsoft-azure/volume_snapshotter_test.go @@ -52,6 +52,16 @@ func TestGetVolumeID(t *testing.T) { volumeID, err = b.GetVolumeID(pv) assert.NoError(t, err) assert.Equal(t, "foo", volumeID) + + // CSI driver + csi := map[string]interface{}{ + "driver": "disk.csi.azure.com", + "volumeHandle": " /subscriptions/subscription-id/resourceGroups/resource-group-name/providers/Microsoft.Compute/disks/bar", + } + pv.Object["spec"].(map[string]interface{})["csi"] = csi + volumeID, err = b.GetVolumeID(pv) + assert.NoError(t, err) + assert.Equal(t, "bar", volumeID) } func TestSetVolumeID(t *testing.T) { @@ -92,6 +102,24 @@ func TestSetVolumeID(t *testing.T) { require.NotNil(t, res.Spec.AzureDisk) assert.Equal(t, "revised", res.Spec.AzureDisk.DiskName) assert.Equal(t, "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/disks/revised", res.Spec.AzureDisk.DataDiskURI) + + // CSI driver + csi := map[string]interface{}{ + "driver": "disk.csi.azure.com", + "volumeHandle": " /subscriptions/subscription-id/resourceGroups/resource-group-name/providers/Microsoft.Compute/disks/foo", + "volumeAttributes": map[string]string{ + "csi.storage.k8s.io/pv/name": "foo", + }, + } + pv.Object["spec"].(map[string]interface{})["csi"] = csi + updatedPV, err = b.SetVolumeID(pv, "updated") + require.NoError(t, err) + + res = new(v1.PersistentVolume) + require.NoError(t, runtime.DefaultUnstructuredConverter.FromUnstructured(updatedPV.UnstructuredContent(), res)) + require.NotNil(t, res.Spec.CSI) + assert.Equal(t, "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/disks/updated", res.Spec.CSI.VolumeHandle) + assert.Equal(t, "updated", res.Spec.CSI.VolumeAttributes["csi.storage.k8s.io/pv/name"]) } func TestParseFullSnapshotName(t *testing.T) {