diff --git a/.changes/unreleased/ENHANCEMENTS-20230222-001027.yaml b/.changes/unreleased/ENHANCEMENTS-20230222-001027.yaml new file mode 100644 index 0000000000..3dad317c82 --- /dev/null +++ b/.changes/unreleased/ENHANCEMENTS-20230222-001027.yaml @@ -0,0 +1,6 @@ +kind: ENHANCEMENTS +body: Add `SchemaValidateDiagFunc` variants of helper/validation `All` and `Any` as + `AllDiag` and `AnyDiag`. +time: 2023-02-22T00:10:27.474224+02:00 +custom: + Issue: "1155" diff --git a/helper/validation/meta.go b/helper/validation/meta.go index f24f7fa28a..941476429c 100644 --- a/helper/validation/meta.go +++ b/helper/validation/meta.go @@ -45,6 +45,18 @@ func All(validators ...schema.SchemaValidateFunc) schema.SchemaValidateFunc { } } +// AllDiag returns a SchemaValidateDiagFunc which tests if the provided value +// passes all provided SchemaValidateDiagFunc +func AllDiag(validators ...schema.SchemaValidateDiagFunc) schema.SchemaValidateDiagFunc { + return func(i interface{}, k cty.Path) diag.Diagnostics { + var diags diag.Diagnostics + for _, validator := range validators { + diags = append(diags, validator(i, k)...) + } + return diags + } +} + // Any returns a SchemaValidateFunc which tests if the provided value // passes any of the provided SchemaValidateFunc func Any(validators ...schema.SchemaValidateFunc) schema.SchemaValidateFunc { @@ -63,6 +75,22 @@ func Any(validators ...schema.SchemaValidateFunc) schema.SchemaValidateFunc { } } +// AnyDiag returns a SchemaValidateDiagFunc which tests if the provided value +// passes any of the provided SchemaValidateDiagFunc +func AnyDiag(validators ...schema.SchemaValidateDiagFunc) schema.SchemaValidateDiagFunc { + return func(i interface{}, k cty.Path) diag.Diagnostics { + var diags diag.Diagnostics + for _, validator := range validators { + validatorDiags := validator(i, k) + if len(validatorDiags) == 0 { + return diag.Diagnostics{} + } + diags = append(diags, validatorDiags...) + } + return diags + } +} + // ToDiagFunc is a wrapper for legacy schema.SchemaValidateFunc // converting it to schema.SchemaValidateDiagFunc func ToDiagFunc(validator schema.SchemaValidateFunc) schema.SchemaValidateDiagFunc { diff --git a/helper/validation/meta_test.go b/helper/validation/meta_test.go index a8cbb78a66..767f955aec 100644 --- a/helper/validation/meta_test.go +++ b/helper/validation/meta_test.go @@ -71,6 +71,34 @@ func TestValidationAll(t *testing.T) { }) } +func TestValidationAllDiag(t *testing.T) { + runDiagTestCases(t, []diagTestCase{ + { + val: "valid", + f: AllDiag( + ToDiagFunc(StringLenBetween(5, 42)), + ToDiagFunc(StringMatch(regexp.MustCompile(`[a-zA-Z0-9]+`), "value must be alphanumeric")), + ), + }, + { + val: "foo", + f: AllDiag( + ToDiagFunc(StringLenBetween(5, 42)), + ToDiagFunc(StringMatch(regexp.MustCompile(`[a-zA-Z0-9]+`), "value must be alphanumeric")), + ), + expectedDiagSummary: regexp.MustCompile(`expected length of [\w]+ to be in the range \(5 - 42\), got foo`), + }, + { + val: "!!!!!", + f: AllDiag( + ToDiagFunc(StringLenBetween(5, 42)), + ToDiagFunc(StringMatch(regexp.MustCompile(`[a-zA-Z0-9]+`), "value must be alphanumeric")), + ), + expectedDiagSummary: regexp.MustCompile("value must be alphanumeric"), + }, + }) +} + func TestValidationAny(t *testing.T) { runTestCases(t, []testCase{ { @@ -106,6 +134,41 @@ func TestValidationAny(t *testing.T) { }) } +func TestValidationAnyDiag(t *testing.T) { + runDiagTestCases(t, []diagTestCase{ + { + val: 43, + f: AnyDiag( + ToDiagFunc(IntAtLeast(42)), + ToDiagFunc(IntAtMost(5)), + ), + }, + { + val: 4, + f: AnyDiag( + ToDiagFunc(IntAtLeast(42)), + ToDiagFunc(IntAtMost(5)), + ), + }, + { + val: 7, + f: AnyDiag( + ToDiagFunc(IntAtLeast(42)), + ToDiagFunc(IntAtMost(5)), + ), + expectedDiagSummary: regexp.MustCompile(`expected [\w]+ to be at least \(42\), got 7`), + }, + { + val: 7, + f: AnyDiag( + ToDiagFunc(IntAtLeast(42)), + ToDiagFunc(IntAtMost(5)), + ), + expectedDiagSummary: regexp.MustCompile(`expected [\w]+ to be at most \(5\), got 7`), + }, + }) +} + func TestToDiagFunc(t *testing.T) { t.Parallel() diff --git a/helper/validation/testing.go b/helper/validation/testing.go index 8dadd66fcf..a5aa6e0454 100644 --- a/helper/validation/testing.go +++ b/helper/validation/testing.go @@ -4,9 +4,11 @@ package validation import ( + "fmt" "regexp" + "testing" - testing "github.com/mitchellh/go-testing-interface" + "github.com/hashicorp/go-cty/cty" "github.com/hashicorp/terraform-plugin-sdk/v2/diag" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" @@ -18,23 +20,53 @@ type testCase struct { expectedErr *regexp.Regexp } -func runTestCases(t testing.T, cases []testCase) { +func runTestCases(t *testing.T, cases []testCase) { t.Helper() for i, tc := range cases { - _, errs := tc.f(tc.val, "test_property") + t.Run(fmt.Sprintf("TestCase_%d", i), func(t *testing.T) { + _, errs := tc.f(tc.val, "test_property") - if len(errs) == 0 && tc.expectedErr == nil { - continue - } + if len(errs) == 0 && tc.expectedErr == nil { + return + } - if len(errs) != 0 && tc.expectedErr == nil { - t.Fatalf("expected test case %d to produce no errors, got %v", i, errs) - } + if len(errs) != 0 && tc.expectedErr == nil { + t.Fatalf("expected test case %d to produce no errors, got %v", i, errs) + } - if !matchAnyError(errs, tc.expectedErr) { - t.Fatalf("expected test case %d to produce error matching \"%s\", got %v", i, tc.expectedErr, errs) - } + if !matchAnyError(errs, tc.expectedErr) { + t.Fatalf("expected test case %d to produce error matching \"%s\", got %v", i, tc.expectedErr, errs) + } + }) + } +} + +type diagTestCase struct { + val interface{} + f schema.SchemaValidateDiagFunc + expectedDiagSummary *regexp.Regexp +} + +func runDiagTestCases(t *testing.T, cases []diagTestCase) { + t.Helper() + + for i, tc := range cases { + t.Run(fmt.Sprintf("TestCase_%d", i), func(t *testing.T) { + diags := tc.f(tc.val, cty.GetAttrPath("test_property")) + + if len(diags) == 0 && tc.expectedDiagSummary == nil { + return + } + + if len(diags) != 0 && tc.expectedDiagSummary == nil { + t.Fatalf("expected test case %d to produce no diagnostics, got %v", i, diags) + } + + if !matchAnyDiagSummary(diags, tc.expectedDiagSummary) { + t.Fatalf("expected test case %d to produce diagnostic summary matching \"%s\", got %v", i, tc.expectedDiagSummary, diags) + } + }) } }