diff --git a/cmd/webhooks/main.go b/cmd/webhooks/main.go index 74240ff4b41..76bfead8ab0 100644 --- a/cmd/webhooks/main.go +++ b/cmd/webhooks/main.go @@ -78,18 +78,8 @@ func main() { cfg.QPS = webhooksClientRequestQPS cfg.Burst = webhooksClientRequestBurst // Configuring minimum TLS version for the webhook server - var minTLSVersion uint16 - - switch tlsMinVersion { - case "1.0": - minTLSVersion = tls.VersionTLS10 - case "1.1": - minTLSVersion = tls.VersionTLS11 - case "1.2": - minTLSVersion = tls.VersionTLS11 - case "1.3": - minTLSVersion = tls.VersionTLS13 - default: + minTLSVersion, err := kedautil.ParseTLSMinVersionAsString(tlsMinVersion) + if err != nil { setupLog.Error(fmt.Errorf("unsupported minimum TLS version"), fmt.Sprintf("option %s non recognized", tlsMinVersion)) os.Exit(1) } diff --git a/pkg/util/tls_config.go b/pkg/util/tls_config.go index 078b602a579..85503dd9e0a 100644 --- a/pkg/util/tls_config.go +++ b/pkg/util/tls_config.go @@ -23,7 +23,6 @@ import ( "fmt" "os" - "github.com/go-logr/logr" "github.com/youmark/pkcs8" ctrl "sigs.k8s.io/controller-runtime" ) @@ -31,8 +30,13 @@ import ( var minTLSVersion uint16 func init() { - setupLog := ctrl.Log.WithName("tls_setup") - minTLSVersion = initMinTLSVersion(setupLog) + var err error + + version, _ := os.LookupEnv("KEDA_HTTP_MIN_TLS_VERSION") + if minTLSVersion, err = ParseTLSMinVersionAsString(version); err != nil { + ctrl.Log.WithName("tls_setup").Info(err.Error()) + } + } // NewTLSConfigWithPassword returns a *tls.Config using the given ceClient cert, ceClient key, @@ -86,25 +90,23 @@ func GetMinTLSVersion() uint16 { return minTLSVersion } -func initMinTLSVersion(logger logr.Logger) uint16 { - version, found := os.LookupEnv("KEDA_HTTP_MIN_TLS_VERSION") - minVersion := tls.VersionTLS12 - if found { - switch version { - case "TLS13": - minVersion = tls.VersionTLS13 - case "TLS12": - minVersion = tls.VersionTLS12 - case "TLS11": - minVersion = tls.VersionTLS11 - case "TLS10": - minVersion = tls.VersionTLS10 - default: - logger.Info(fmt.Sprintf("%s is not a valid value, using `TLS12`. Allowed values are: `TLS13`,`TLS12`,`TLS11`,`TLS10`", version)) - minVersion = tls.VersionTLS12 - } +func ParseTLSMinVersionAsString(value string) (uint16, error) { + switch value { + case "": + minTLSVersion = tls.VersionTLS12 + case "1.0", "TLS10": + minTLSVersion = tls.VersionTLS10 + case "1.1", "TLS11": + minTLSVersion = tls.VersionTLS11 + case "1.2", "TLS12": + minTLSVersion = tls.VersionTLS12 + case "1.3", "TLS13": + minTLSVersion = tls.VersionTLS13 + default: + return tls.VersionTLS12, fmt.Errorf("%s is not a valid value, using `TLS12`. Allowed values are: `TLS13`,`TLS12`,`TLS11`,`TLS10`", value) } - return uint16(minVersion) + + return minTLSVersion, nil } func decryptClientKey(clientKey, clientKeyPassword string) ([]byte, error) { diff --git a/pkg/util/tls_config_test.go b/pkg/util/tls_config_test.go index 5e19b202d68..cd4259ccb74 100644 --- a/pkg/util/tls_config_test.go +++ b/pkg/util/tls_config_test.go @@ -19,11 +19,8 @@ package util import ( "crypto/tls" "crypto/x509" - "os" "strings" "testing" - - "github.com/go-logr/logr" ) var randomCACert = `-----BEGIN CERTIFICATE----- @@ -256,46 +253,38 @@ func TestNewTLSConfig_WithPassword(t *testing.T) { } type minTLSVersionTestData struct { - envSet bool envValue string expectedVersion uint16 } var minTLSVersionTestDatas = []minTLSVersionTestData{ { - envSet: true, envValue: "TLS10", expectedVersion: tls.VersionTLS10, }, { - envSet: true, envValue: "TLS11", expectedVersion: tls.VersionTLS11, }, { - envSet: true, envValue: "TLS12", expectedVersion: tls.VersionTLS12, }, { - envSet: true, envValue: "TLS13", expectedVersion: tls.VersionTLS13, }, { - envSet: false, expectedVersion: tls.VersionTLS12, }, } func TestResolveMinTLSVersion(t *testing.T) { - defer os.Unsetenv("KEDA_HTTP_MIN_TLS_VERSION") for _, testData := range minTLSVersionTestDatas { - os.Unsetenv("KEDA_HTTP_MIN_TLS_VERSION") - if testData.envSet { - os.Setenv("KEDA_HTTP_MIN_TLS_VERSION", testData.envValue) + minVersion, err := ParseTLSMinVersionAsString(testData.envValue) + if err != nil { + t.Errorf("Expected nil, got an error: %s", err.Error()) } - minVersion := initMinTLSVersion(logr.Discard()) if testData.expectedVersion != minVersion { t.Error("Failed to resolve minTLSVersion correctly", "wants", testData.expectedVersion, "got", minVersion)