diff --git a/plugin/input/http/http.go b/plugin/input/http/http.go index be2459045..79c6ceea4 100644 --- a/plugin/input/http/http.go +++ b/plugin/input/http/http.go @@ -1,6 +1,7 @@ package http import ( + "fmt" "io" "net" "net/http" @@ -198,22 +199,78 @@ type AuthConfig struct { Secrets map[string]string `json:"secrets"` // * } +type originDomain struct { + domain string + suffix string +} + type CORSConfig struct { AllowedOrigins []string `json:"allowed_origins"` DefaultOrigin string `json:"default_origin" default:"*"` AllowedHeaders []string `json:"allowed_headers"` ExposedHeaders []string `json:"exposed_headers"` + + allowedOriginsDomains []originDomain + allowedOriginsAll bool } -func (c *CORSConfig) getAllowedByOrigin(originHeader string) string { - for _, allowed := range c.AllowedOrigins { - if strings.HasSuffix(originHeader, allowed) || strings.HasPrefix(originHeader, allowed) { - return originHeader +func (c *CORSConfig) getAllowedByOrigin(origin string) string { + u, err := url.Parse(origin) + if err != nil { + return c.DefaultOrigin + } + + host := u.Host + if i := strings.IndexByte(host, ':'); i > -1 { + host = host[:i] + } + + if c.allowedOriginsAll { + return host + } + + for _, allowed := range c.allowedOriginsDomains { + if host == allowed.domain { + return host + } + + if allowed.suffix != "" && strings.HasSuffix(host, allowed.suffix) { + return host } } + return c.DefaultOrigin } +func (c *CORSConfig) prepareAllowedOrigins() error { + for _, ao := range c.AllowedOrigins { + ao = strings.ToLower(ao) + if ao == "*" { + c.allowedOriginsAll = true + c.allowedOriginsDomains = nil + break + } + if ao[0] == '*' { + if strings.Contains(ao[1:], "*") { + return fmt.Errorf("invalid origin: %q", ao) + } + domain := ao[1:] + if ao[1] == '.' { + domain = ao[2:] + } + c.allowedOriginsDomains = append(c.allowedOriginsDomains, originDomain{ + domain: domain, + suffix: ao[1:], + }) + continue + } + c.allowedOriginsDomains = append(c.allowedOriginsDomains, originDomain{ + domain: ao, + }) + } + return nil +} + func init() { fd.DefaultPluginRegistry.RegisterInput(&pipeline.PluginStaticInfo{ Type: "http", @@ -239,6 +296,10 @@ func (p *Plugin) Start(config pipeline.AnyConfig, params *pipeline.InputPluginPa } } + if err := p.config.CORS.prepareAllowedOrigins(); err != nil { + p.logger.Fatal("failed to prepare allowed origins", zap.Error(err)) + } + p.controller = params.Controller p.controller.DisableStreams() p.sourceIDs = make([]pipeline.SourceID, 0)