Skip to content

Commit

Permalink
Add support for suffix domains in allowed origins in http input plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
HeadHunter483 committed Sep 13, 2024
1 parent bbc5cd0 commit 44ac10b
Show file tree
Hide file tree
Showing 2 changed files with 316 additions and 4 deletions.
58 changes: 54 additions & 4 deletions plugin/input/http/http.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package http

import (
"fmt"
"io"
"net"
"net/http"
Expand Down Expand Up @@ -164,6 +165,7 @@ type Config struct {
// > @3@4@5@6
// >
// > CORS config.
// > Allowed origins support only one wildcard symbol. `http://*.example.com` - valid, `http://*.example.*.com` - invalid.
// > See CORSConfig for details.
CORS CORSConfig `json:"cors" child:"true"` // *
}
Expand Down Expand Up @@ -198,22 +200,66 @@ type AuthConfig struct {
Secrets map[string]string `json:"secrets"` // *
}

type originDomain struct {
domain string
prefix string
suffix string
}

type CORSConfig struct {
AllowedOrigins []string `json:"allowed_origins"`
DefaultOrigin string `json:"default_origin" default:"*"`
AllowedHeaders []string `json:"allowed_headers"`
ExposedHeaders []string `json:"exposed_headers"`

allowedOriginsDomains []originDomain
allowedOriginsAll bool
}

func (c *CORSConfig) getAllowedByOrigin(originHeader string) string {
for _, allowed := range c.AllowedOrigins {
if strings.HasSuffix(originHeader, allowed) || strings.HasPrefix(originHeader, allowed) {
return originHeader
func (c *CORSConfig) getAllowedByOrigin(origin string) string {
if c.allowedOriginsAll {
return origin
}

for _, ao := range c.allowedOriginsDomains {
if ao.domain != "" && origin == ao.domain {
return origin
}

pslen := len(ao.prefix) + len(ao.suffix)
if pslen > 0 && len(origin) > pslen && strings.HasPrefix(origin, ao.prefix) && strings.HasSuffix(origin, ao.suffix) {
return origin
}
}

return c.DefaultOrigin
}

func (c *CORSConfig) prepareAllowedOrigins() error {
for _, ao := range c.AllowedOrigins {
ao = strings.ToLower(ao)
if ao == "*" {
c.allowedOriginsAll = true
c.allowedOriginsDomains = nil
break
}
if wildcard := strings.IndexByte(ao, '*'); wildcard != -1 {
if strings.Contains(ao[wildcard+1:], "*") {
return fmt.Errorf("invalid origin %q, only one wildcard per origin is allowed", ao)
}
c.allowedOriginsDomains = append(c.allowedOriginsDomains, originDomain{
prefix: ao[:wildcard],
suffix: ao[wildcard+1:],
})
continue
}
c.allowedOriginsDomains = append(c.allowedOriginsDomains, originDomain{
domain: ao,
})
}
return nil
}

func init() {
fd.DefaultPluginRegistry.RegisterInput(&pipeline.PluginStaticInfo{
Type: "http",
Expand All @@ -239,6 +285,10 @@ func (p *Plugin) Start(config pipeline.AnyConfig, params *pipeline.InputPluginPa
}
}

if err := p.config.CORS.prepareAllowedOrigins(); err != nil {
p.logger.Fatal("failed to prepare allowed origins", zap.Error(err))
}

p.controller = params.Controller
p.controller.DisableStreams()
p.sourceIDs = make([]pipeline.SourceID, 0)
Expand Down
262 changes: 262 additions & 0 deletions plugin/input/http/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -660,3 +660,265 @@ func TestGzip(t *testing.T) {
})
}
}

func TestCORSPrepareAllowedOrigins(t *testing.T) {
t.Parallel()
r := require.New(t)

type allowedOriginsCfg struct {
allowedOriginsDomains []originDomain
allowedOriginsAll bool
}

tests := []struct {
Name string
Args []string
Want allowedOriginsCfg
WantErr bool
}{
{
Name: "ok_simple_host",
Args: []string{
"http://example.com",
},
Want: allowedOriginsCfg{
allowedOriginsDomains: []originDomain{
{
domain: "http://example.com",
},
},
allowedOriginsAll: false,
},
},
{
Name: "ok_suffix",
Args: []string{
"*.example.com",
},
Want: allowedOriginsCfg{
allowedOriginsDomains: []originDomain{
{
suffix: ".example.com",
},
},
allowedOriginsAll: false,
},
},
{
Name: "ok_prefix",
Args: []string{
"http://example.*",
},
Want: allowedOriginsCfg{
allowedOriginsDomains: []originDomain{
{
prefix: "http://example.",
},
},
allowedOriginsAll: false,
},
},
{
Name: "ok_prefix_and_suffix",
Args: []string{
"http://*.example.com",
},
Want: allowedOriginsCfg{
allowedOriginsDomains: []originDomain{
{
prefix: "http://",
suffix: ".example.com",
},
},
allowedOriginsAll: false,
},
},
{
Name: "ok_mixed",
Args: []string{
"*.example.com",
"http://otherexample.com",
},
Want: allowedOriginsCfg{
allowedOriginsDomains: []originDomain{
{
suffix: ".example.com",
},
{
domain: "http://otherexample.com",
},
},
allowedOriginsAll: false,
},
},
{
Name: "ok_wildcard",
Args: []string{
"*",
},
Want: allowedOriginsCfg{
allowedOriginsDomains: nil,
allowedOriginsAll: true,
},
},
{
Name: "ok_wildcard_mixed",
Args: []string{
"example.com",
"*.example.com",
"*",
},
Want: allowedOriginsCfg{
allowedOriginsDomains: nil,
allowedOriginsAll: true,
},
},
{
Name: "invalid_domain",
Args: []string{
"*.*example.com",
},
Want: allowedOriginsCfg{},
WantErr: true,
},
{
Name: "ok_host_with_port",
Args: []string{
"http://localhost:8090",
},
Want: allowedOriginsCfg{
allowedOriginsDomains: []originDomain{
{
domain: "http://localhost:8090",
},
},
allowedOriginsAll: false,
},
},
}

for _, tc := range tests {
tc := tc
t.Run(tc.Name, func(t *testing.T) {
t.Parallel()
corsCfg := CORSConfig{
AllowedOrigins: tc.Args,
}
err := corsCfg.prepareAllowedOrigins()
if tc.WantErr {
r.Error(err, "expected an error")
return
}
r.NoError(err, "expected no errors")
r.Equal(tc.Want.allowedOriginsAll, corsCfg.allowedOriginsAll, "allowedOriginsAll not equal")
r.Equal(len(tc.Want.allowedOriginsDomains), len(corsCfg.allowedOriginsDomains), "allowedOriginsDomains not equal")
for i := range tc.Want.allowedOriginsDomains {
wantDomain := tc.Want.allowedOriginsDomains[i]
gotDomain := corsCfg.allowedOriginsDomains[i]
r.Equal(wantDomain.domain, gotDomain.domain, "domains are not equal")
r.Equal(wantDomain.suffix, gotDomain.suffix, "domain suffixes are not equal")
}
})
}
}

func TestCORSGetAllowedByOrigin(t *testing.T) {
t.Parallel()
r := require.New(t)

tests := []struct {
Name string
DefaultOrigin string
AllowedOrigins []string
CheckOrigin string
Want string
}{
{
Name: "ok_domain",
DefaultOrigin: "http://default.com",
AllowedOrigins: []string{
"http://example.com",
},
CheckOrigin: "http://example.com",
Want: "http://example.com",
},
{
Name: "ok_suffix",
DefaultOrigin: "http://default.com",
AllowedOrigins: []string{
"*.example.com",
},
CheckOrigin: "http://test.example.com",
Want: "http://test.example.com",
},
{
Name: "ok_prefix",
DefaultOrigin: "http://default.com",
AllowedOrigins: []string{
"http://example.*",
},
CheckOrigin: "http://example.example.org",
Want: "http://example.example.org",
},
{
Name: "ok_prefix_and_suffix",
DefaultOrigin: "http://default.com",
AllowedOrigins: []string{
"http://*.example.com:8090",
},
CheckOrigin: "http://subtest.test.example.com:8090",
Want: "http://subtest.test.example.com:8090",
},
{
Name: "ok_wildcard",
DefaultOrigin: "http://default.com",
AllowedOrigins: []string{
"*",
},
CheckOrigin: "http://example.com",
Want: "http://example.com",
},
{
Name: "default_origin",
DefaultOrigin: "http://default.com",
AllowedOrigins: []string{
"http://example.com",
},
CheckOrigin: "http://otherexample.com",
Want: "http://default.com",
},
{
Name: "ok_host_port",
DefaultOrigin: "http://default.com",
AllowedOrigins: []string{
"http://localhost:8090",
},
CheckOrigin: "http://localhost:8090",
Want: "http://localhost:8090",
},
{
Name: "ok_host_port_suffix",
DefaultOrigin: "http://default.com",
AllowedOrigins: []string{
"*.example.com:8090",
},
CheckOrigin: "http://test.example.com:8090",
Want: "http://test.example.com:8090",
},
}

for _, tc := range tests {
tc := tc
t.Run(tc.Name, func(t *testing.T) {
t.Parallel()
corsCfg := CORSConfig{
AllowedOrigins: tc.AllowedOrigins,
DefaultOrigin: tc.DefaultOrigin,
}
err := corsCfg.prepareAllowedOrigins()
r.NoError(err, "expected no errors")
gotOrigin := corsCfg.getAllowedByOrigin(tc.CheckOrigin)
r.Equal(tc.Want, gotOrigin, "origin not equal")
})
}
}

0 comments on commit 44ac10b

Please sign in to comment.