Skip to content

Commit

Permalink
🩹 Fix: CORS middleware should use the defined AllowedOriginsFunc conf…
Browse files Browse the repository at this point in the history
…ig when AllowedOrigins is empty (#2771)
  • Loading branch information
muhammadkholidb authored Dec 22, 2023
1 parent 43fa236 commit 1fac52a
Show file tree
Hide file tree
Showing 2 changed files with 202 additions and 6 deletions.
7 changes: 4 additions & 3 deletions middleware/cors/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,14 @@ func New(config ...Config) fiber.Handler {
if cfg.AllowMethods == "" {
cfg.AllowMethods = ConfigDefault.AllowMethods
}
if cfg.AllowOrigins == "" {
// When none of the AllowOrigins or AllowOriginsFunc config was defined, set the default AllowOrigins value with "*"
if cfg.AllowOrigins == "" && cfg.AllowOriginsFunc == nil {
cfg.AllowOrigins = ConfigDefault.AllowOrigins
}
}

// Warning logs if both AllowOrigins and AllowOriginsFunc are set
if cfg.AllowOrigins != ConfigDefault.AllowOrigins && cfg.AllowOriginsFunc != nil {
if cfg.AllowOrigins != "" && cfg.AllowOriginsFunc != nil {
log.Warn("[CORS] Both 'AllowOrigins' and 'AllowOriginsFunc' have been defined.")
}

Expand Down Expand Up @@ -145,7 +146,7 @@ func New(config ...Config) fiber.Handler {
// Run AllowOriginsFunc if the logic for
// handling the value in 'AllowOrigins' does
// not result in allowOrigin being set.
if (allowOrigin == "" || allowOrigin == ConfigDefault.AllowOrigins) && cfg.AllowOriginsFunc != nil {
if allowOrigin == "" && cfg.AllowOriginsFunc != nil {
if cfg.AllowOriginsFunc(origin) {
allowOrigin = origin
}
Expand Down
201 changes: 198 additions & 3 deletions middleware/cors/cors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -331,9 +331,9 @@ func Test_CORS_AllowOriginsFunc(t *testing.T) {
// Perform request
handler(ctx)

// Allow-Origin header should be "*" because http://google.com does not satisfy 'strings.Contains(origin, "example-2")'
// and AllowOrigins has not been set so the default "*" is used
utils.AssertEqual(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
// Allow-Origin header should be empty because http://google.com does not satisfy 'strings.Contains(origin, "example-2")'
// and AllowOrigins has not been set
utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))

ctx.Request.Reset()
ctx.Response.Reset()
Expand All @@ -348,3 +348,198 @@ func Test_CORS_AllowOriginsFunc(t *testing.T) {
// Allow-Origin header should be "http://example-2.com"
utils.AssertEqual(t, "http://example-2.com", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
}

func Test_CORS_AllowOriginsAndAllowOriginsFunc_AllUseCases(t *testing.T) {
testCases := []struct {
Name string
Config Config
RequestOrigin string
ResponseOrigin string
}{
{
Name: "AllowOriginsDefined/AllowOriginsFuncUndefined/OriginAllowed",
Config: Config{
AllowOrigins: "http://aaa.com",
AllowOriginsFunc: nil,
},
RequestOrigin: "http://aaa.com",
ResponseOrigin: "http://aaa.com",
},
{
Name: "AllowOriginsDefined/AllowOriginsFuncUndefined/OriginNotAllowed",
Config: Config{
AllowOrigins: "http://aaa.com",
AllowOriginsFunc: nil,
},
RequestOrigin: "http://bbb.com",
ResponseOrigin: "",
},
{
Name: "AllowOriginsDefined/AllowOriginsFuncReturnsTrue/OriginAllowed",
Config: Config{
AllowOrigins: "http://aaa.com",
AllowOriginsFunc: func(origin string) bool {
return true
},
},
RequestOrigin: "http://aaa.com",
ResponseOrigin: "http://aaa.com",
},
{
Name: "AllowOriginsDefined/AllowOriginsFuncReturnsTrue/OriginNotAllowed",
Config: Config{
AllowOrigins: "http://aaa.com",
AllowOriginsFunc: func(origin string) bool {
return true
},
},
RequestOrigin: "http://bbb.com",
ResponseOrigin: "http://bbb.com",
},
{
Name: "AllowOriginsDefined/AllowOriginsFuncReturnsFalse/OriginAllowed",
Config: Config{
AllowOrigins: "http://aaa.com",
AllowOriginsFunc: func(origin string) bool {
return false
},
},
RequestOrigin: "http://aaa.com",
ResponseOrigin: "http://aaa.com",
},
{
Name: "AllowOriginsDefined/AllowOriginsFuncReturnsFalse/OriginNotAllowed",
Config: Config{
AllowOrigins: "http://aaa.com",
AllowOriginsFunc: func(origin string) bool {
return false
},
},
RequestOrigin: "http://bbb.com",
ResponseOrigin: "",
},
{
Name: "AllowOriginsEmpty/AllowOriginsFuncUndefined/OriginAllowed",
Config: Config{
AllowOrigins: "",
AllowOriginsFunc: nil,
},
RequestOrigin: "http://aaa.com",
ResponseOrigin: "*",
},
{
Name: "AllowOriginsEmpty/AllowOriginsFuncReturnsTrue/OriginAllowed",
Config: Config{
AllowOrigins: "",
AllowOriginsFunc: func(origin string) bool {
return true
},
},
RequestOrigin: "http://aaa.com",
ResponseOrigin: "http://aaa.com",
},
{
Name: "AllowOriginsEmpty/AllowOriginsFuncReturnsFalse/OriginNotAllowed",
Config: Config{
AllowOrigins: "",
AllowOriginsFunc: func(origin string) bool {
return false
},
},
RequestOrigin: "http://aaa.com",
ResponseOrigin: "",
},
}

for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
app := fiber.New()
app.Use("/", New(tc.Config))

handler := app.Handler()

ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderOrigin, tc.RequestOrigin)

handler(ctx)

utils.AssertEqual(t, tc.ResponseOrigin, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
})
}
}

// The fix for issue #2422
func Test_CORS_AllowCredetials(t *testing.T) {
testCases := []struct {
Name string
Config Config
RequestOrigin string
ResponseOrigin string
}{
{
Name: "AllowOriginsFuncDefined",
Config: Config{
AllowCredentials: true,
AllowOriginsFunc: func(origin string) bool {
return true
},
},
RequestOrigin: "http://aaa.com",
// The AllowOriginsFunc config was defined, should use the real origin of the function
ResponseOrigin: "http://aaa.com",
},
{
Name: "AllowOriginsFuncNotDefined",
Config: Config{
AllowCredentials: true,
},
RequestOrigin: "http://aaa.com",
// None of the AllowOrigins or AllowOriginsFunc config was defined, should use the default origin of "*"
// which will cause the CORS error in the client:
// The value of the 'Access-Control-Allow-Origin' header in the response must not be the wildcard '*'
// when the request's credentials mode is 'include'.
ResponseOrigin: "*",
},
{
Name: "AllowOriginsDefined",
Config: Config{
AllowCredentials: true,
AllowOrigins: "http://aaa.com",
},
RequestOrigin: "http://aaa.com",
ResponseOrigin: "http://aaa.com",
},
{
Name: "AllowOriginsDefined/UnallowedOrigin",
Config: Config{
AllowCredentials: true,
AllowOrigins: "http://aaa.com",
},
RequestOrigin: "http://bbb.com",
ResponseOrigin: "",
},
}

for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
app := fiber.New()
app.Use("/", New(tc.Config))

handler := app.Handler()

ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderOrigin, tc.RequestOrigin)

handler(ctx)

if tc.Config.AllowCredentials {
utils.AssertEqual(t, "true", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials)))
}
utils.AssertEqual(t, tc.ResponseOrigin, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
})
}
}

0 comments on commit 1fac52a

Please sign in to comment.