From 44ac10bc419b3c3f07632239d1b6518cdeb9d639 Mon Sep 17 00:00:00 2001 From: Oleg Don <31502412+HeadHunter483@users.noreply.github.com> Date: Thu, 12 Sep 2024 18:02:43 +0500 Subject: [PATCH] Add support for suffix domains in allowed origins in http input plugin --- plugin/input/http/http.go | 58 +++++++- plugin/input/http/http_test.go | 262 +++++++++++++++++++++++++++++++++ 2 files changed, 316 insertions(+), 4 deletions(-) diff --git a/plugin/input/http/http.go b/plugin/input/http/http.go index be2459045..bb65825ef 100644 --- a/plugin/input/http/http.go +++ b/plugin/input/http/http.go @@ -1,6 +1,7 @@ package http import ( + "fmt" "io" "net" "net/http" @@ -164,6 +165,7 @@ type Config struct { // > @3@4@5@6 // > // > CORS config. + // > Allowed origins support only one wildcard symbol. `http://*.example.com` - valid, `http://*.example.*.com` - invalid. // > See CORSConfig for details. CORS CORSConfig `json:"cors" child:"true"` // * } @@ -198,22 +200,66 @@ type AuthConfig struct { Secrets map[string]string `json:"secrets"` // * } +type originDomain struct { + domain string + prefix string + suffix string +} + type CORSConfig struct { AllowedOrigins []string `json:"allowed_origins"` DefaultOrigin string `json:"default_origin" default:"*"` AllowedHeaders []string `json:"allowed_headers"` ExposedHeaders []string `json:"exposed_headers"` + + allowedOriginsDomains []originDomain + allowedOriginsAll bool } -func (c *CORSConfig) getAllowedByOrigin(originHeader string) string { - for _, allowed := range c.AllowedOrigins { - if strings.HasSuffix(originHeader, allowed) || strings.HasPrefix(originHeader, allowed) { - return originHeader +func (c *CORSConfig) getAllowedByOrigin(origin string) string { + if c.allowedOriginsAll { + return origin + } + + for _, ao := range c.allowedOriginsDomains { + if ao.domain != "" && origin == ao.domain { + return origin + } + + pslen := len(ao.prefix) + len(ao.suffix) + if pslen > 0 && len(origin) > pslen && strings.HasPrefix(origin, ao.prefix) && strings.HasSuffix(origin, ao.suffix) { + return origin } } + return c.DefaultOrigin } +func (c *CORSConfig) prepareAllowedOrigins() error { + for _, ao := range c.AllowedOrigins { + ao = strings.ToLower(ao) + if ao == "*" { + c.allowedOriginsAll = true + c.allowedOriginsDomains = nil + break + } + if wildcard := strings.IndexByte(ao, '*'); wildcard != -1 { + if strings.Contains(ao[wildcard+1:], "*") { + return fmt.Errorf("invalid origin %q, only one wildcard per origin is allowed", ao) + } + c.allowedOriginsDomains = append(c.allowedOriginsDomains, originDomain{ + prefix: ao[:wildcard], + suffix: ao[wildcard+1:], + }) + continue + } + c.allowedOriginsDomains = append(c.allowedOriginsDomains, originDomain{ + domain: ao, + }) + } + return nil +} + func init() { fd.DefaultPluginRegistry.RegisterInput(&pipeline.PluginStaticInfo{ Type: "http", @@ -239,6 +285,10 @@ func (p *Plugin) Start(config pipeline.AnyConfig, params *pipeline.InputPluginPa } } + if err := p.config.CORS.prepareAllowedOrigins(); err != nil { + p.logger.Fatal("failed to prepare allowed origins", zap.Error(err)) + } + p.controller = params.Controller p.controller.DisableStreams() p.sourceIDs = make([]pipeline.SourceID, 0) diff --git a/plugin/input/http/http_test.go b/plugin/input/http/http_test.go index 75ed1a40f..d9de29a31 100644 --- a/plugin/input/http/http_test.go +++ b/plugin/input/http/http_test.go @@ -660,3 +660,265 @@ func TestGzip(t *testing.T) { }) } } + +func TestCORSPrepareAllowedOrigins(t *testing.T) { + t.Parallel() + r := require.New(t) + + type allowedOriginsCfg struct { + allowedOriginsDomains []originDomain + allowedOriginsAll bool + } + + tests := []struct { + Name string + Args []string + Want allowedOriginsCfg + WantErr bool + }{ + { + Name: "ok_simple_host", + Args: []string{ + "http://example.com", + }, + Want: allowedOriginsCfg{ + allowedOriginsDomains: []originDomain{ + { + domain: "http://example.com", + }, + }, + allowedOriginsAll: false, + }, + }, + { + Name: "ok_suffix", + Args: []string{ + "*.example.com", + }, + Want: allowedOriginsCfg{ + allowedOriginsDomains: []originDomain{ + { + suffix: ".example.com", + }, + }, + allowedOriginsAll: false, + }, + }, + { + Name: "ok_prefix", + Args: []string{ + "http://example.*", + }, + Want: allowedOriginsCfg{ + allowedOriginsDomains: []originDomain{ + { + prefix: "http://example.", + }, + }, + allowedOriginsAll: false, + }, + }, + { + Name: "ok_prefix_and_suffix", + Args: []string{ + "http://*.example.com", + }, + Want: allowedOriginsCfg{ + allowedOriginsDomains: []originDomain{ + { + prefix: "http://", + suffix: ".example.com", + }, + }, + allowedOriginsAll: false, + }, + }, + { + Name: "ok_mixed", + Args: []string{ + "*.example.com", + "http://otherexample.com", + }, + Want: allowedOriginsCfg{ + allowedOriginsDomains: []originDomain{ + { + suffix: ".example.com", + }, + { + domain: "http://otherexample.com", + }, + }, + allowedOriginsAll: false, + }, + }, + { + Name: "ok_wildcard", + Args: []string{ + "*", + }, + Want: allowedOriginsCfg{ + allowedOriginsDomains: nil, + allowedOriginsAll: true, + }, + }, + { + Name: "ok_wildcard_mixed", + Args: []string{ + "example.com", + "*.example.com", + "*", + }, + Want: allowedOriginsCfg{ + allowedOriginsDomains: nil, + allowedOriginsAll: true, + }, + }, + { + Name: "invalid_domain", + Args: []string{ + "*.*example.com", + }, + Want: allowedOriginsCfg{}, + WantErr: true, + }, + { + Name: "ok_host_with_port", + Args: []string{ + "http://localhost:8090", + }, + Want: allowedOriginsCfg{ + allowedOriginsDomains: []originDomain{ + { + domain: "http://localhost:8090", + }, + }, + allowedOriginsAll: false, + }, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.Name, func(t *testing.T) { + t.Parallel() + corsCfg := CORSConfig{ + AllowedOrigins: tc.Args, + } + err := corsCfg.prepareAllowedOrigins() + if tc.WantErr { + r.Error(err, "expected an error") + return + } + r.NoError(err, "expected no errors") + r.Equal(tc.Want.allowedOriginsAll, corsCfg.allowedOriginsAll, "allowedOriginsAll not equal") + r.Equal(len(tc.Want.allowedOriginsDomains), len(corsCfg.allowedOriginsDomains), "allowedOriginsDomains not equal") + for i := range tc.Want.allowedOriginsDomains { + wantDomain := tc.Want.allowedOriginsDomains[i] + gotDomain := corsCfg.allowedOriginsDomains[i] + r.Equal(wantDomain.domain, gotDomain.domain, "domains are not equal") + r.Equal(wantDomain.suffix, gotDomain.suffix, "domain suffixes are not equal") + } + }) + } +} + +func TestCORSGetAllowedByOrigin(t *testing.T) { + t.Parallel() + r := require.New(t) + + tests := []struct { + Name string + DefaultOrigin string + AllowedOrigins []string + CheckOrigin string + Want string + }{ + { + Name: "ok_domain", + DefaultOrigin: "http://default.com", + AllowedOrigins: []string{ + "http://example.com", + }, + CheckOrigin: "http://example.com", + Want: "http://example.com", + }, + { + Name: "ok_suffix", + DefaultOrigin: "http://default.com", + AllowedOrigins: []string{ + "*.example.com", + }, + CheckOrigin: "http://test.example.com", + Want: "http://test.example.com", + }, + { + Name: "ok_prefix", + DefaultOrigin: "http://default.com", + AllowedOrigins: []string{ + "http://example.*", + }, + CheckOrigin: "http://example.example.org", + Want: "http://example.example.org", + }, + { + Name: "ok_prefix_and_suffix", + DefaultOrigin: "http://default.com", + AllowedOrigins: []string{ + "http://*.example.com:8090", + }, + CheckOrigin: "http://subtest.test.example.com:8090", + Want: "http://subtest.test.example.com:8090", + }, + { + Name: "ok_wildcard", + DefaultOrigin: "http://default.com", + AllowedOrigins: []string{ + "*", + }, + CheckOrigin: "http://example.com", + Want: "http://example.com", + }, + { + Name: "default_origin", + DefaultOrigin: "http://default.com", + AllowedOrigins: []string{ + "http://example.com", + }, + CheckOrigin: "http://otherexample.com", + Want: "http://default.com", + }, + { + Name: "ok_host_port", + DefaultOrigin: "http://default.com", + AllowedOrigins: []string{ + "http://localhost:8090", + }, + CheckOrigin: "http://localhost:8090", + Want: "http://localhost:8090", + }, + { + Name: "ok_host_port_suffix", + DefaultOrigin: "http://default.com", + AllowedOrigins: []string{ + "*.example.com:8090", + }, + CheckOrigin: "http://test.example.com:8090", + Want: "http://test.example.com:8090", + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.Name, func(t *testing.T) { + t.Parallel() + corsCfg := CORSConfig{ + AllowedOrigins: tc.AllowedOrigins, + DefaultOrigin: tc.DefaultOrigin, + } + err := corsCfg.prepareAllowedOrigins() + r.NoError(err, "expected no errors") + gotOrigin := corsCfg.getAllowedByOrigin(tc.CheckOrigin) + r.Equal(tc.Want, gotOrigin, "origin not equal") + }) + } +}