diff --git a/middleware/cors/cors.go b/middleware/cors/cors.go index c347e43525..ebc1c6b1cb 100644 --- a/middleware/cors/cors.go +++ b/middleware/cors/cors.go @@ -94,13 +94,14 @@ func New(config ...Config) fiber.Handler { if cfg.AllowMethods == "" { cfg.AllowMethods = ConfigDefault.AllowMethods } - if cfg.AllowOrigins == "" { + // When none of the AllowOrigins or AllowOriginsFunc config was defined, set the default AllowOrigins value with "*" + if cfg.AllowOrigins == "" && cfg.AllowOriginsFunc == nil { cfg.AllowOrigins = ConfigDefault.AllowOrigins } } // Warning logs if both AllowOrigins and AllowOriginsFunc are set - if cfg.AllowOrigins != ConfigDefault.AllowOrigins && cfg.AllowOriginsFunc != nil { + if cfg.AllowOrigins != "" && cfg.AllowOriginsFunc != nil { log.Warn("[CORS] Both 'AllowOrigins' and 'AllowOriginsFunc' have been defined.") } @@ -145,7 +146,7 @@ func New(config ...Config) fiber.Handler { // Run AllowOriginsFunc if the logic for // handling the value in 'AllowOrigins' does // not result in allowOrigin being set. - if (allowOrigin == "" || allowOrigin == ConfigDefault.AllowOrigins) && cfg.AllowOriginsFunc != nil { + if allowOrigin == "" && cfg.AllowOriginsFunc != nil { if cfg.AllowOriginsFunc(origin) { allowOrigin = origin } diff --git a/middleware/cors/cors_test.go b/middleware/cors/cors_test.go index 692a24bfcc..22fef8442b 100644 --- a/middleware/cors/cors_test.go +++ b/middleware/cors/cors_test.go @@ -331,9 +331,9 @@ func Test_CORS_AllowOriginsFunc(t *testing.T) { // Perform request handler(ctx) - // Allow-Origin header should be "*" because http://google.com does not satisfy 'strings.Contains(origin, "example-2")' - // and AllowOrigins has not been set so the default "*" is used - utils.AssertEqual(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) + // Allow-Origin header should be empty because http://google.com does not satisfy 'strings.Contains(origin, "example-2")' + // and AllowOrigins has not been set + utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) ctx.Request.Reset() ctx.Response.Reset() @@ -348,3 +348,198 @@ func Test_CORS_AllowOriginsFunc(t *testing.T) { // Allow-Origin header should be "http://example-2.com" utils.AssertEqual(t, "http://example-2.com", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) } + +func Test_CORS_AllowOriginsAndAllowOriginsFunc_AllUseCases(t *testing.T) { + testCases := []struct { + Name string + Config Config + RequestOrigin string + ResponseOrigin string + }{ + { + Name: "AllowOriginsDefined/AllowOriginsFuncUndefined/OriginAllowed", + Config: Config{ + AllowOrigins: "http://aaa.com", + AllowOriginsFunc: nil, + }, + RequestOrigin: "http://aaa.com", + ResponseOrigin: "http://aaa.com", + }, + { + Name: "AllowOriginsDefined/AllowOriginsFuncUndefined/OriginNotAllowed", + Config: Config{ + AllowOrigins: "http://aaa.com", + AllowOriginsFunc: nil, + }, + RequestOrigin: "http://bbb.com", + ResponseOrigin: "", + }, + { + Name: "AllowOriginsDefined/AllowOriginsFuncReturnsTrue/OriginAllowed", + Config: Config{ + AllowOrigins: "http://aaa.com", + AllowOriginsFunc: func(origin string) bool { + return true + }, + }, + RequestOrigin: "http://aaa.com", + ResponseOrigin: "http://aaa.com", + }, + { + Name: "AllowOriginsDefined/AllowOriginsFuncReturnsTrue/OriginNotAllowed", + Config: Config{ + AllowOrigins: "http://aaa.com", + AllowOriginsFunc: func(origin string) bool { + return true + }, + }, + RequestOrigin: "http://bbb.com", + ResponseOrigin: "http://bbb.com", + }, + { + Name: "AllowOriginsDefined/AllowOriginsFuncReturnsFalse/OriginAllowed", + Config: Config{ + AllowOrigins: "http://aaa.com", + AllowOriginsFunc: func(origin string) bool { + return false + }, + }, + RequestOrigin: "http://aaa.com", + ResponseOrigin: "http://aaa.com", + }, + { + Name: "AllowOriginsDefined/AllowOriginsFuncReturnsFalse/OriginNotAllowed", + Config: Config{ + AllowOrigins: "http://aaa.com", + AllowOriginsFunc: func(origin string) bool { + return false + }, + }, + RequestOrigin: "http://bbb.com", + ResponseOrigin: "", + }, + { + Name: "AllowOriginsEmpty/AllowOriginsFuncUndefined/OriginAllowed", + Config: Config{ + AllowOrigins: "", + AllowOriginsFunc: nil, + }, + RequestOrigin: "http://aaa.com", + ResponseOrigin: "*", + }, + { + Name: "AllowOriginsEmpty/AllowOriginsFuncReturnsTrue/OriginAllowed", + Config: Config{ + AllowOrigins: "", + AllowOriginsFunc: func(origin string) bool { + return true + }, + }, + RequestOrigin: "http://aaa.com", + ResponseOrigin: "http://aaa.com", + }, + { + Name: "AllowOriginsEmpty/AllowOriginsFuncReturnsFalse/OriginNotAllowed", + Config: Config{ + AllowOrigins: "", + AllowOriginsFunc: func(origin string) bool { + return false + }, + }, + RequestOrigin: "http://aaa.com", + ResponseOrigin: "", + }, + } + + for _, tc := range testCases { + t.Run(tc.Name, func(t *testing.T) { + app := fiber.New() + app.Use("/", New(tc.Config)) + + handler := app.Handler() + + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI("/") + ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderOrigin, tc.RequestOrigin) + + handler(ctx) + + utils.AssertEqual(t, tc.ResponseOrigin, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) + }) + } +} + +// The fix for issue #2422 +func Test_CORS_AllowCredetials(t *testing.T) { + testCases := []struct { + Name string + Config Config + RequestOrigin string + ResponseOrigin string + }{ + { + Name: "AllowOriginsFuncDefined", + Config: Config{ + AllowCredentials: true, + AllowOriginsFunc: func(origin string) bool { + return true + }, + }, + RequestOrigin: "http://aaa.com", + // The AllowOriginsFunc config was defined, should use the real origin of the function + ResponseOrigin: "http://aaa.com", + }, + { + Name: "AllowOriginsFuncNotDefined", + Config: Config{ + AllowCredentials: true, + }, + RequestOrigin: "http://aaa.com", + // None of the AllowOrigins or AllowOriginsFunc config was defined, should use the default origin of "*" + // which will cause the CORS error in the client: + // The value of the 'Access-Control-Allow-Origin' header in the response must not be the wildcard '*' + // when the request's credentials mode is 'include'. + ResponseOrigin: "*", + }, + { + Name: "AllowOriginsDefined", + Config: Config{ + AllowCredentials: true, + AllowOrigins: "http://aaa.com", + }, + RequestOrigin: "http://aaa.com", + ResponseOrigin: "http://aaa.com", + }, + { + Name: "AllowOriginsDefined/UnallowedOrigin", + Config: Config{ + AllowCredentials: true, + AllowOrigins: "http://aaa.com", + }, + RequestOrigin: "http://bbb.com", + ResponseOrigin: "", + }, + } + + for _, tc := range testCases { + t.Run(tc.Name, func(t *testing.T) { + app := fiber.New() + app.Use("/", New(tc.Config)) + + handler := app.Handler() + + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI("/") + ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderOrigin, tc.RequestOrigin) + + handler(ctx) + + if tc.Config.AllowCredentials { + utils.AssertEqual(t, "true", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials))) + } + utils.AssertEqual(t, tc.ResponseOrigin, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) + }) + } +}