diff --git a/docs/api/middleware/cors.md b/docs/api/middleware/cors.md index 9a28342fd3..ca250833d6 100644 --- a/docs/api/middleware/cors.md +++ b/docs/api/middleware/cors.md @@ -10,6 +10,8 @@ The middleware conforms to the `access-control-allow-origin` specification by pa For more control, `AllowOriginsFunc` can be used to programatically determine if an origin is allowed. If no match was found in `AllowOrigins` and if `AllowOriginsFunc` returns true then the 'access-control-allow-origin' response header is set to the 'origin' request header. +When defining your Origins make sure they are properly formatted. The middleware validates and normalizes the provided origins, ensuring they're in the correct format by checking for valid schemes (http or https), and removing any trailing slashes. + ## Signatures ```go @@ -56,18 +58,27 @@ app.Use(cors.New(cors.Config{ })) ``` +**Note: The following configuration is considered insecure and will result in a panic.** + +```go +app.Use(cors.New(cors.Config{ + AllowOrigins: "*", + AllowCredentials: true, +})) +``` + ## Config -| Property | Type | Description | Default | -|:-----------------|:---------------------------|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:-----------------------------------| -| Next | `func(*fiber.Ctx) bool` | Next defines a function to skip this middleware when returned true. | `nil` | -| AllowOriginsFunc | `func(origin string) bool` | AllowOriginsFunc defines a function that will set the 'access-control-allow-origin' response header to the 'origin' request header when returned true. | `nil` | -| AllowOrigins | `string` | AllowOrigin defines a comma separated list of origins that may access the resource. | `"*"` | -| AllowMethods | `string` | AllowMethods defines a list of methods allowed when accessing the resource. This is used in response to a preflight request. | `"GET,POST,HEAD,PUT,DELETE,PATCH"` | -| AllowHeaders | `string` | AllowHeaders defines a list of request headers that can be used when making the actual request. This is in response to a preflight request. | `""` | -| AllowCredentials | `bool` | AllowCredentials indicates whether or not the response to the request can be exposed when the credentials flag is true. | `false` | -| ExposeHeaders | `string` | ExposeHeaders defines a whitelist headers that clients are allowed to access. | `""` | -| MaxAge | `int` | MaxAge indicates how long (in seconds) the results of a preflight request can be cached. If you pass MaxAge 0, Access-Control-Max-Age header will not be added and browser will use 5 seconds by default. To disable caching completely, pass MaxAge value negative. It will set the Access-Control-Max-Age header 0. | `0` | +| Property | Type | Description | Default | +|:-----------------|:---------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:-----------------------------------| +| Next | `func(*fiber.Ctx) bool` | Next defines a function to skip this middleware when returned true. | `nil` | +| AllowOriginsFunc | `func(origin string) bool` | AllowOriginsFunc defines a function that will set the 'access-control-allow-origin' response header to the 'origin' request header when returned true. This allows for dynamic evaluation of allowed origins. Note if AllowCredentials is true, wildcard origins will be not have the 'access-control-allow-credentials' header set to 'true'. | `nil` | +| AllowOrigins | `string` | AllowOrigin defines a comma separated list of origins that may access the resource. | `"*"` | +| AllowMethods | `string` | AllowMethods defines a list of methods allowed when accessing the resource. This is used in response to a preflight request. | `"GET,POST,HEAD,PUT,DELETE,PATCH"` | +| AllowHeaders | `string` | AllowHeaders defines a list of request headers that can be used when making the actual request. This is in response to a preflight request. | `""` | +| AllowCredentials | `bool` | AllowCredentials indicates whether or not the response to the request can be exposed when the credentials flag is true. When used as part of a response to a preflight request, this indicates whether or not the actual request can be made using credentials. Note: If true, AllowOrigins cannot be set to a wildcard ("*") to prevent security vulnerabilities. | `false` | +| ExposeHeaders | `string` | ExposeHeaders defines a whitelist headers that clients are allowed to access. | `""` | +| MaxAge | `int` | MaxAge indicates how long (in seconds) the results of a preflight request can be cached. If you pass MaxAge 0, Access-Control-Max-Age header will not be added and browser will use 5 seconds by default. To disable caching completely, pass MaxAge value negative. It will set the Access-Control-Max-Age header 0. | `0` | ## Default Config diff --git a/middleware/cors/cors.go b/middleware/cors/cors.go index ebc1c6b1cb..2ca3767d1f 100644 --- a/middleware/cors/cors.go +++ b/middleware/cors/cors.go @@ -16,12 +16,14 @@ type Config struct { Next func(c *fiber.Ctx) bool // AllowOriginsFunc defines a function that will set the 'access-control-allow-origin' - // response header to the 'origin' request header when returned true. + // response header to the 'origin' request header when returned true. This allows for + // dynamic evaluation of allowed origins. Note if AllowCredentials is true, wildcard origins + // will be not have the 'access-control-allow-credentials' header set to 'true'. // // Optional. Default: nil AllowOriginsFunc func(origin string) bool - // AllowOrigin defines a list of origins that may access the resource. + // AllowOrigin defines a comma separated list of origins that may access the resource. // // Optional. Default value "*" AllowOrigins string @@ -41,7 +43,8 @@ type Config struct { // AllowCredentials indicates whether or not the response to the request // can be exposed when the credentials flag is true. When used as part of // a response to a preflight request, this indicates whether or not the - // actual request can be made using credentials. + // actual request can be made using credentials. Note: If true, AllowOrigins + // cannot be set to a wildcard ("*") to prevent security vulnerabilities. // // Optional. Default value false. AllowCredentials bool @@ -105,6 +108,26 @@ func New(config ...Config) fiber.Handler { log.Warn("[CORS] Both 'AllowOrigins' and 'AllowOriginsFunc' have been defined.") } + // Validate CORS credentials configuration + if cfg.AllowCredentials && cfg.AllowOrigins == "*" { + panic("[CORS] Insecure setup, 'AllowCredentials' is set to true, and 'AllowOrigins' is set to a wildcard.") + } + + // Validate and normalize static AllowOrigins if not using AllowOriginsFunc + if cfg.AllowOriginsFunc == nil && cfg.AllowOrigins != "" && cfg.AllowOrigins != "*" { + validatedOrigins := []string{} + for _, origin := range strings.Split(cfg.AllowOrigins, ",") { + isValid, normalizedOrigin := normalizeOrigin(origin) + if isValid { + validatedOrigins = append(validatedOrigins, normalizedOrigin) + } else { + log.Warnf("[CORS] Invalid origin format in configuration: %s", origin) + panic("[CORS] Invalid origin provided in configuration") + } + } + cfg.AllowOrigins = strings.Join(validatedOrigins, ",") + } + // Convert string to slice allowOrigins := strings.Split(strings.ReplaceAll(cfg.AllowOrigins, " ", ""), ",") @@ -123,22 +146,18 @@ func New(config ...Config) fiber.Handler { return c.Next() } - // Get origin header - origin := c.Get(fiber.HeaderOrigin) + // Get originHeader header + originHeader := c.Get(fiber.HeaderOrigin) allowOrigin := "" // Check allowed origins - for _, o := range allowOrigins { - if o == "*" { + for _, origin := range allowOrigins { + if origin == "*" { allowOrigin = "*" break } - if o == origin { - allowOrigin = o - break - } - if matchSubdomain(origin, o) { - allowOrigin = origin + if validateDomain(originHeader, origin) { + allowOrigin = originHeader break } } @@ -147,8 +166,8 @@ func New(config ...Config) fiber.Handler { // handling the value in 'AllowOrigins' does // not result in allowOrigin being set. if allowOrigin == "" && cfg.AllowOriginsFunc != nil { - if cfg.AllowOriginsFunc(origin) { - allowOrigin = origin + if cfg.AllowOriginsFunc(originHeader) { + allowOrigin = originHeader } } @@ -173,9 +192,17 @@ func New(config ...Config) fiber.Handler { c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin) c.Set(fiber.HeaderAccessControlAllowMethods, allowMethods) - // Set Allow-Credentials if set to true if cfg.AllowCredentials { - c.Set(fiber.HeaderAccessControlAllowCredentials, "true") + // When AllowCredentials is true, set the Access-Control-Allow-Origin to the specific origin instead of '*' + if allowOrigin != "*" && allowOrigin != "" { + c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin) + c.Set(fiber.HeaderAccessControlAllowCredentials, "true") + } else if allowOrigin == "*" { + log.Warn("[CORS] 'AllowCredentials' is true, but 'AllowOrigins' cannot be set to '*'.") + } + } else { + // For non-credential requests, it's safe to set to '*' or specific origins + c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin) } // Set Allow-Headers if not empty diff --git a/middleware/cors/cors_test.go b/middleware/cors/cors_test.go index 22fef8442b..9fc2852556 100644 --- a/middleware/cors/cors_test.go +++ b/middleware/cors/cors_test.go @@ -35,7 +35,7 @@ func Test_CORS_Negative_MaxAge(t *testing.T) { ctx := &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodOptions) - ctx.Request.Header.Set(fiber.HeaderOrigin, "localhost") + ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") app.Handler()(ctx) utils.AssertEqual(t, "0", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlMaxAge))) @@ -72,7 +72,46 @@ func Test_CORS_Wildcard(t *testing.T) { app := fiber.New() // OPTIONS (preflight) response headers when AllowOrigins is * app.Use(New(Config{ - AllowOrigins: "*", + AllowOrigins: "*", + MaxAge: 3600, + ExposeHeaders: "X-Request-ID", + AllowHeaders: "Authentication", + })) + // Get handler pointer + handler := app.Handler() + + // Make request + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI("/") + ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") + ctx.Request.Header.SetMethod(fiber.MethodOptions) + + // Perform request + handler(ctx) + + // Check result + utils.AssertEqual(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) // Validates request is not reflecting origin in the response + utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials))) + utils.AssertEqual(t, "3600", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlMaxAge))) + utils.AssertEqual(t, "Authentication", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowHeaders))) + + // Test non OPTIONS (preflight) response headers + ctx = &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fiber.MethodGet) + handler(ctx) + + utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials))) + utils.AssertEqual(t, "X-Request-ID", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlExposeHeaders))) +} + +// go test -run -v Test_CORS_Origin_AllowCredentials +func Test_CORS_Origin_AllowCredentials(t *testing.T) { + t.Parallel() + // New fiber instance + app := fiber.New() + // OPTIONS (preflight) response headers when AllowOrigins is * + app.Use(New(Config{ + AllowOrigins: "http://localhost", AllowCredentials: true, MaxAge: 3600, ExposeHeaders: "X-Request-ID", @@ -84,14 +123,14 @@ func Test_CORS_Wildcard(t *testing.T) { // Make request ctx := &fasthttp.RequestCtx{} ctx.Request.SetRequestURI("/") - ctx.Request.Header.Set(fiber.HeaderOrigin, "localhost") + ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") ctx.Request.Header.SetMethod(fiber.MethodOptions) // Perform request handler(ctx) // Check result - utils.AssertEqual(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) + utils.AssertEqual(t, "http://localhost", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) utils.AssertEqual(t, "true", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials))) utils.AssertEqual(t, "3600", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlMaxAge))) utils.AssertEqual(t, "Authentication", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowHeaders))) @@ -105,6 +144,57 @@ func Test_CORS_Wildcard(t *testing.T) { utils.AssertEqual(t, "X-Request-ID", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlExposeHeaders))) } +// go test -run -v Test_CORS_Wildcard_AllowCredentials_Panic +// Test for fiber-ghsa-fmg4-x8pw-hjhg +func Test_CORS_Wildcard_AllowCredentials_Panic(t *testing.T) { + t.Parallel() + // New fiber instance + app := fiber.New() + + didPanic := false + func() { + defer func() { + if r := recover(); r != nil { + didPanic = true + } + }() + + app.Use(New(Config{ + AllowOrigins: "*", + AllowCredentials: true, + })) + }() + + if !didPanic { + t.Errorf("Expected a panic when AllowOrigins is '*' and AllowCredentials is true") + } +} + +// go test -run -v Test_CORS_Invalid_Origin_Panic +func Test_CORS_Invalid_Origin_Panic(t *testing.T) { + t.Parallel() + // New fiber instance + app := fiber.New() + + didPanic := false + func() { + defer func() { + if r := recover(); r != nil { + didPanic = true + } + }() + + app.Use(New(Config{ + AllowOrigins: "localhost", + AllowCredentials: true, + })) + }() + + if !didPanic { + t.Errorf("Expected a panic when Origin is missing scheme") + } +} + // go test -run -v Test_CORS_Subdomain func Test_CORS_Subdomain(t *testing.T) { t.Parallel() @@ -193,12 +283,9 @@ func Test_CORS_AllowOriginScheme(t *testing.T) { shouldAllowOrigin: false, }, { - pattern: "http://*.example.com", - reqOrigin: `http://1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\ - .1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\ - .1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\ - .1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.example.com`, - shouldAllowOrigin: false, + pattern: "http://*.example.com", + reqOrigin: "http://1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.example.com", + shouldAllowOrigin: true, }, { pattern: "http://example.com", @@ -471,12 +558,13 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc_AllUseCases(t *testing.T) { } // The fix for issue #2422 -func Test_CORS_AllowCredetials(t *testing.T) { +func Test_CORS_AllowCredentials(t *testing.T) { testCases := []struct { - Name string - Config Config - RequestOrigin string - ResponseOrigin string + Name string + Config Config + RequestOrigin string + ResponseOrigin string + ResponseCredentials string }{ { Name: "AllowOriginsFuncDefined", @@ -488,19 +576,35 @@ func Test_CORS_AllowCredetials(t *testing.T) { }, RequestOrigin: "http://aaa.com", // The AllowOriginsFunc config was defined, should use the real origin of the function - ResponseOrigin: "http://aaa.com", + ResponseOrigin: "http://aaa.com", + ResponseCredentials: "true", }, { - Name: "AllowOriginsFuncNotDefined", + Name: "fiber-ghsa-fmg4-x8pw-hjhg-wildcard-credentials", Config: Config{ AllowCredentials: true, + AllowOriginsFunc: func(origin string) bool { + return true + }, + }, + RequestOrigin: "*", + ResponseOrigin: "*", + // Middleware will validate that wildcard wont set credentials to true + ResponseCredentials: "", + }, + { + Name: "AllowOriginsFuncNotDefined", + Config: Config{ + // Setting this to true will cause the middleware to panic since default AllowOrigins is "*" + AllowCredentials: false, }, RequestOrigin: "http://aaa.com", // None of the AllowOrigins or AllowOriginsFunc config was defined, should use the default origin of "*" // which will cause the CORS error in the client: // The value of the 'Access-Control-Allow-Origin' header in the response must not be the wildcard '*' // when the request's credentials mode is 'include'. - ResponseOrigin: "*", + ResponseOrigin: "*", + ResponseCredentials: "", }, { Name: "AllowOriginsDefined", @@ -508,8 +612,9 @@ func Test_CORS_AllowCredetials(t *testing.T) { AllowCredentials: true, AllowOrigins: "http://aaa.com", }, - RequestOrigin: "http://aaa.com", - ResponseOrigin: "http://aaa.com", + RequestOrigin: "http://aaa.com", + ResponseOrigin: "http://aaa.com", + ResponseCredentials: "true", }, { Name: "AllowOriginsDefined/UnallowedOrigin", @@ -517,8 +622,9 @@ func Test_CORS_AllowCredetials(t *testing.T) { AllowCredentials: true, AllowOrigins: "http://aaa.com", }, - RequestOrigin: "http://bbb.com", - ResponseOrigin: "", + RequestOrigin: "http://bbb.com", + ResponseOrigin: "", + ResponseCredentials: "", }, } @@ -536,9 +642,7 @@ func Test_CORS_AllowCredetials(t *testing.T) { handler(ctx) - if tc.Config.AllowCredentials { - utils.AssertEqual(t, "true", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials))) - } + utils.AssertEqual(t, tc.ResponseCredentials, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials))) utils.AssertEqual(t, tc.ResponseOrigin, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) }) } diff --git a/middleware/cors/utils.go b/middleware/cors/utils.go index 8b6114bdab..313a430c77 100644 --- a/middleware/cors/utils.go +++ b/middleware/cors/utils.go @@ -1,56 +1,83 @@ package cors import ( + "net/url" "strings" ) +// matchScheme compares the scheme of the domain and pattern func matchScheme(domain, pattern string) bool { didx := strings.Index(domain, ":") pidx := strings.Index(pattern, ":") return didx != -1 && pidx != -1 && domain[:didx] == pattern[:pidx] } -// matchSubdomain compares authority with wildcard -func matchSubdomain(domain, pattern string) bool { - if !matchScheme(domain, pattern) { - return false +// validateDomain checks if the domain matches the pattern +func validateDomain(domain, pattern string) bool { + // Directly compare the domain and pattern for an exact match. + if domain == pattern { + return true } - didx := strings.Index(domain, "://") - pidx := strings.Index(pattern, "://") - if didx == -1 || pidx == -1 { - return false + + // Normalize domain and pattern to exclude schemes and ports for matching purposes + normalizedDomain := normalizeDomain(domain) + normalizedPattern := normalizeDomain(pattern) + + // Handling the case where pattern is a wildcard subdomain pattern. + if strings.HasPrefix(normalizedPattern, "*.") { + // Trim leading "*." from pattern for comparison. + trimmedPattern := normalizedPattern[2:] + + // Check if the domain ends with the trimmed pattern. + if strings.HasSuffix(normalizedDomain, trimmedPattern) { + // Ensure that the domain is not exactly the base domain. + if normalizedDomain != trimmedPattern { + // Special handling to prevent "example.com" matching "*.example.com". + if strings.TrimSuffix(normalizedDomain, trimmedPattern) != "" { + return true + } + } + } } - domAuth := domain[didx+3:] - // to avoid long loop by invalid long domain - const maxDomainLen = 253 - if len(domAuth) > maxDomainLen { - return false + + return false +} + +// normalizeDomain removes the scheme and port from the input domain +func normalizeDomain(input string) string { + // Remove scheme + input = strings.TrimPrefix(strings.TrimPrefix(input, "http://"), "https://") + + // Find and remove port, if present + if portIndex := strings.Index(input, ":"); portIndex != -1 { + input = input[:portIndex] } - patAuth := pattern[pidx+3:] - - domComp := strings.Split(domAuth, ".") - patComp := strings.Split(patAuth, ".") - const divHalf = 2 - for i := len(domComp)/divHalf - 1; i >= 0; i-- { - opp := len(domComp) - 1 - i - domComp[i], domComp[opp] = domComp[opp], domComp[i] + + return input +} + +// normalizeOrigin checks if the provided origin is in a correct format +// and normalizes it by removing any path or trailing slash. +// It returns a boolean indicating whether the origin is valid +// and the normalized origin. +func normalizeOrigin(origin string) (bool, string) { + parsedOrigin, err := url.Parse(origin) + if err != nil { + return false, "" } - for i := len(patComp)/divHalf - 1; i >= 0; i-- { - opp := len(patComp) - 1 - i - patComp[i], patComp[opp] = patComp[opp], patComp[i] + + // Validate the scheme is either http or https + if parsedOrigin.Scheme != "http" && parsedOrigin.Scheme != "https" { + return false, "" } - for i, v := range domComp { - if len(patComp) <= i { - return false - } - p := patComp[i] - if p == "*" { - return true - } - if p != v { - return false - } + // Validate there is a host present. The presence of a path, query, or fragment components + // is checked, but a trailing "/" (indicative of the root) is allowed for the path and will be normalized + if parsedOrigin.Host == "" || (parsedOrigin.Path != "" && parsedOrigin.Path != "/") || parsedOrigin.RawQuery != "" || parsedOrigin.Fragment != "" { + return false, "" } - return false + + // Normalize the origin by constructing it from the scheme and host. + // The path or trailing slash is not included in the normalized origin. + return true, strings.ToLower(parsedOrigin.Scheme) + "://" + strings.ToLower(parsedOrigin.Host) } diff --git a/middleware/cors/utils_test.go b/middleware/cors/utils_test.go new file mode 100644 index 0000000000..3acd692521 --- /dev/null +++ b/middleware/cors/utils_test.go @@ -0,0 +1,145 @@ +package cors + +import ( + "testing" +) + +// go test -run -v Test_normalizeOrigin +func Test_normalizeOrigin(t *testing.T) { + testCases := []struct { + origin string + expectedValid bool + expectedOrigin string + }{ + {"http://example.com", true, "http://example.com"}, // Simple case should work. + {"http://example.com/", true, "http://example.com"}, // Trailing slash should be removed. + {"http://example.com:3000", true, "http://example.com:3000"}, // Port should be preserved. + {"http://example.com:3000/", true, "http://example.com:3000"}, // Trailing slash should be removed. + {"http://", false, ""}, // Invalid origin should not be accepted. + {"http://example.com/path", false, ""}, // Path should not be accepted. + {"http://example.com?query=123", false, ""}, // Query should not be accepted. + {"http://example.com#fragment", false, ""}, // Fragment should not be accepted. + {"http://localhost", true, "http://localhost"}, // Localhost should be accepted. + {"http://127.0.0.1", true, "http://127.0.0.1"}, // IPv4 address should be accepted. + {"http://[::1]", true, "http://[::1]"}, // IPv6 address should be accepted. + {"http://[::1]:8080", true, "http://[::1]:8080"}, // IPv6 address with port should be accepted. + {"http://[::1]:8080/", true, "http://[::1]:8080"}, // IPv6 address with port and trailing slash should be accepted. + {"http://[::1]:8080/path", false, ""}, // IPv6 address with port and path should not be accepted. + {"http://[::1]:8080?query=123", false, ""}, // IPv6 address with port and query should not be accepted. + {"http://[::1]:8080#fragment", false, ""}, // IPv6 address with port and fragment should not be accepted. + {"http://[::1]:8080/path?query=123#fragment", false, ""}, // IPv6 address with port, path, query, and fragment should not be accepted. + {"http://[::1]:8080/path?query=123#fragment/", false, ""}, // IPv6 address with port, path, query, fragment, and trailing slash should not be accepted. + {"http://[::1]:8080/path?query=123#fragment/invalid", false, ""}, // IPv6 address with port, path, query, fragment, trailing slash, and invalid segment should not be accepted. + {"http://[::1]:8080/path?query=123#fragment/invalid/", false, ""}, // IPv6 address with port, path, query, fragment, trailing slash, and invalid segment with trailing slash should not be accepted. + {"http://[::1]:8080/path?query=123#fragment/invalid/segment", false, ""}, // IPv6 address with port, path, query, fragment, trailing slash, and invalid segment with additional segment should not be accepted. + } + + for _, tc := range testCases { + valid, normalizedOrigin := normalizeOrigin(tc.origin) + + if valid != tc.expectedValid { + t.Errorf("Expected origin '%s' to be valid: %v, but got: %v", tc.origin, tc.expectedValid, valid) + } + + if normalizedOrigin != tc.expectedOrigin { + t.Errorf("Expected normalized origin '%s' for origin '%s', but got: '%s'", tc.expectedOrigin, tc.origin, normalizedOrigin) + } + } +} + +// go test -run -v Test_matchScheme +func Test_matchScheme(t *testing.T) { + testCases := []struct { + domain string + pattern string + expected bool + }{ + {"http://example.com", "http://example.com", true}, // Exact match should work. + {"https://example.com", "http://example.com", false}, // Scheme mismatch should matter. + {"http://example.com", "https://example.com", false}, // Scheme mismatch should matter. + {"http://example.com", "http://example.org", true}, // Different domains should not matter. + {"http://example.com", "http://example.com:8080", true}, // Port should not matter. + {"http://example.com:8080", "http://example.com", true}, // Port should not matter. + {"http://example.com:8080", "http://example.com:8081", true}, // Different ports should not matter. + {"http://localhost", "http://localhost", true}, // Localhost should match. + {"http://127.0.0.1", "http://127.0.0.1", true}, // IPv4 address should match. + {"http://[::1]", "http://[::1]", true}, // IPv6 address should match. + } + + for _, tc := range testCases { + result := matchScheme(tc.domain, tc.pattern) + + if result != tc.expected { + t.Errorf("Expected matchScheme('%s', '%s') to be %v, but got %v", tc.domain, tc.pattern, tc.expected, result) + } + } +} + +// go test -run -v Test_validateOrigin +func Test_validateOrigin(t *testing.T) { + testCases := []struct { + domain string + pattern string + expected bool + }{ + {"http://example.com", "http://example.com", true}, // Exact match should work. + {"https://example.com", "http://example.com", false}, // Scheme mismatch should matter in CORS context. + {"http://example.com", "https://example.com", false}, // Scheme mismatch should matter in CORS context. + {"http://example.com", "http://example.org", false}, // Different domains should not match. + {"http://example.com", "http://example.com:8080", false}, // Port mismatch should matter. + {"http://example.com:8080", "http://example.com", false}, // Port mismatch should matter. + {"http://example.com:8080", "http://example.com:8081", false}, // Different ports should not match. + {"example.com", "example.com", true}, // Simplified form, assuming scheme and port are not considered here, but in practice, they are part of the origin. + {"sub.example.com", "example.com", false}, // Subdomain should not match the base domain directly. + {"sub.example.com", "*.example.com", true}, // Correct assumption for wildcard subdomain matching. + {"example.com", "*.example.com", false}, // Base domain should not match its wildcard subdomain pattern. + {"sub.example.com", "*.com", true}, // Technically correct for pattern matching, but broad wildcard use like this is not recommended for CORS. + {"sub.sub.example.com", "*.example.com", true}, // Nested subdomain should match the wildcard pattern. + {"example.com", "*.org", false}, // Different TLDs should not match. + {"example.com", "example.org", false}, // Different domains should not match. + {"example.com:8080", "*.example.com", false}, // Different ports mean different origins. + {"example.com", "sub.example.net", false}, // Different domains should not match. + {"http://localhost", "http://localhost", true}, // Localhost should match. + {"http://127.0.0.1", "http://127.0.0.1", true}, // IPv4 address should match. + {"http://[::1]", "http://[::1]", true}, // IPv6 address should match. + } + + for _, tc := range testCases { + result := validateDomain(tc.domain, tc.pattern) + + if result != tc.expected { + t.Errorf("Expected validateOrigin('%s', '%s') to be %v, but got %v", tc.domain, tc.pattern, tc.expected, result) + } + } +} + +// go test -run -v Test_normalizeDomain +func Test_normalizeDomain(t *testing.T) { + testCases := []struct { + input string + expectedOutput string + }{ + {"http://example.com", "example.com"}, // Simple case with http scheme. + {"https://example.com", "example.com"}, // Simple case with https scheme. + {"http://example.com:3000", "example.com"}, // Case with port. + {"https://example.com:3000", "example.com"}, // Case with port and https scheme. + {"http://example.com/path", "example.com/path"}, // Case with path. + {"http://example.com?query=123", "example.com?query=123"}, // Case with query. + {"http://example.com#fragment", "example.com#fragment"}, // Case with fragment. + {"example.com", "example.com"}, // Case without scheme. + {"example.com:8080", "example.com"}, // Case without scheme but with port. + {"sub.example.com", "sub.example.com"}, // Case with subdomain. + {"sub.sub.example.com", "sub.sub.example.com"}, // Case with nested subdomain. + {"http://localhost", "localhost"}, // Case with localhost. + {"http://127.0.0.1", "127.0.0.1"}, // Case with IPv4 address. + {"http://[::1]", "[::1]"}, // Case with IPv6 address. + } + + for _, tc := range testCases { + output := normalizeDomain(tc.input) + + if output != tc.expectedOutput { + t.Errorf("Expected normalized domain '%s' for input '%s', but got: '%s'", tc.expectedOutput, tc.input, output) + } + } +}