Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support to restrict localserver response handling to specific origins #1641

Merged
merged 9 commits into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 45 additions & 5 deletions ee/localserver/krypto-ec-middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"crypto/ecdsa"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
Expand All @@ -20,7 +21,6 @@ import (
"github.com/kolide/launcher/pkg/log/multislogger"
"github.com/kolide/launcher/pkg/traces"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
)

const (
Expand All @@ -35,6 +35,7 @@ type v2CmdRequestType struct {
Body []byte
CallbackUrl string
CallbackHeaders map[string][]string
AllowedOrigins []string
}

func (cmdReq v2CmdRequestType) CallbackReq() (*http.Request, error) {
Expand Down Expand Up @@ -79,8 +80,9 @@ func newKryptoEcMiddleware(slogger *slog.Logger, localDbSigner, hardwareSigner c
type callbackErrors string

const (
timeOutOfRangeErr callbackErrors = "time-out-of-range"
responseFailureErr callbackErrors = "response-failure"
timeOutOfRangeErr callbackErrors = "time-out-of-range"
responseFailureErr callbackErrors = "response-failure"
originDisallowedErr callbackErrors = "origin-disallowed"
)

type callbackDataStruct struct {
Expand All @@ -96,7 +98,7 @@ type callbackDataStruct struct {
// Also, because the URL is the box, we cannot cleanly do this through middleware. It reqires a lot of passing data
// around through context. Doing it here, as part of kryptoEcMiddleware, allows for a fairly succint defer.
//
// Note that this should be a goroutine.
// Note that because this is a network call, it should be called in a goroutine.
func (e *kryptoEcMiddleware) sendCallback(req *http.Request, data *callbackDataStruct) {
if req == nil {
return
Expand Down Expand Up @@ -216,12 +218,47 @@ func (e *kryptoEcMiddleware) Wrap(next http.Handler) http.Handler {
}()
}

// Check if the origin is in the allowed list. See https://github.com/kolide/k2/issues/9634
if len(cmdReq.AllowedOrigins) > 0 {
RebeccaMahany marked this conversation as resolved.
Show resolved Hide resolved
allowed := false
for _, ao := range cmdReq.AllowedOrigins {
if strings.EqualFold(ao, r.Header.Get("Origin")) {
allowed = true
break
}
}

if !allowed {
span.SetAttributes(attribute.String("origin", r.Header.Get("Origin")))
traces.SetError(span, fmt.Errorf("origin %s is not allowed", r.Header.Get("Origin")))
e.slogger.Log(r.Context(), slog.LevelError,
"origin is not allowed",
"allowlist", cmdReq.AllowedOrigins,
"origin", r.Header.Get("Origin"),
)

w.WriteHeader(http.StatusUnauthorized)
callbackData.Error = originDisallowedErr
return
}

e.slogger.Log(r.Context(), slog.LevelDebug,
"origin matches allowlist",
"origin", r.Header.Get("Origin"),
)
} else {
e.slogger.Log(r.Context(), slog.LevelDebug,
"origin is allowed by default, no allowlist",
RebeccaMahany marked this conversation as resolved.
Show resolved Hide resolved
"origin", r.Header.Get("Origin"),
)
}

// Check the timestamp, this prevents people from saving a challenge and then
// reusing it a bunch. However, it will fail if the clocks are too far out of sync.
timestampDelta := time.Now().Unix() - challengeBox.Timestamp()
if timestampDelta > timestampValidityRange || timestampDelta < -timestampValidityRange {
span.SetAttributes(attribute.Int64("timestamp_delta", timestampDelta))
span.SetStatus(codes.Error, "timestamp is out of range")
traces.SetError(span, errors.New("timestamp is out of range"))
e.slogger.Log(r.Context(), slog.LevelError,
"timestamp is out of range",
"delta", timestampDelta,
Expand All @@ -234,13 +271,16 @@ func (e *kryptoEcMiddleware) Wrap(next http.Handler) http.Handler {

newReq := &http.Request{
Method: http.MethodPost,
Header: make(http.Header),
URL: &url.URL{
Scheme: r.URL.Scheme,
Host: r.Host,
Path: cmdReq.Path,
},
}

newReq.Header.Set("Origin", r.Header.Get("Origin"))

// setting the newReq context to the current request context
// allows the trace to continue to the inner request,
// maintains the same lifetime as the original request,
Expand Down
147 changes: 147 additions & 0 deletions ee/localserver/krypto-ec-middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,153 @@ func TestKryptoEcMiddleware(t *testing.T) {
}
}

func Test_AllowedOrigin(t *testing.T) {
t.Parallel()

counterpartyKey, err := echelper.GenerateEcdsaKey()
require.NoError(t, err)

challengeId := []byte(ulid.New())
challengeData := []byte(ulid.New())

var tests = []struct {
name string
requestOrigin string
allowedOrigins []string
logStr string
expectedStatus int
}{
{
name: "no allowed specified",
requestOrigin: "https://auth.example.com",
expectedStatus: http.StatusOK,
logStr: "origin is allowed by default",
},
{
name: "no allowed specified missing origin",
expectedStatus: http.StatusOK,
logStr: "origin is allowed by default",
},
{
name: "allowed specified missing origin",
allowedOrigins: []string{"https://auth.example.com", "https://login.example.com"},
expectedStatus: http.StatusUnauthorized,
logStr: "origin is not allowed",
},
{
name: "allowed specified origin mismatch",
allowedOrigins: []string{"https://auth.example.com", "https://login.example.com"},
requestOrigin: "https://not-it.example.com",
expectedStatus: http.StatusUnauthorized,
logStr: "origin is not allowed",
},
{
name: "scheme mismatch",
allowedOrigins: []string{"https://auth.example.com"},
requestOrigin: "http://auth.example.com",
expectedStatus: http.StatusUnauthorized,
logStr: "origin is not allowed",
},
{
name: "allowed specified origin matches",
allowedOrigins: []string{"https://auth.example.com", "https://login.example.com"},
requestOrigin: "https://auth.example.com",
expectedStatus: http.StatusOK,
logStr: "origin matches allowlist",
},
{
name: "allowed specified origin matches 2",
allowedOrigins: []string{"https://auth.example.com", "https://login.example.com"},
requestOrigin: "https://login.example.com",
expectedStatus: http.StatusOK,
logStr: "origin matches allowlist",
},
{
name: "allowed specified origin matches casing",
allowedOrigins: []string{"https://auth.example.com", "https://login.example.com"},
requestOrigin: "https://AuTh.ExAmPlE.cOm",
expectedStatus: http.StatusOK,
logStr: "origin matches allowlist",
},
}

for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

cmdReqBody := []byte(randomStringWithSqlCharacters(t, 100000))

cmdReq := v2CmdRequestType{
Path: "whatevs",
Body: cmdReqBody,
AllowedOrigins: tt.allowedOrigins,
}

challengeBytes, privateEncryptionKey, err := challenge.Generate(counterpartyKey, challengeId, challengeData, mustMarshal(t, cmdReq))
require.NoError(t, err)
encodedChallenge := base64.StdEncoding.EncodeToString(challengeBytes)

responseData := []byte(ulid.New())

testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
reqBodyRaw, err := io.ReadAll(r.Body)
require.NoError(t, err)
defer r.Body.Close()

require.Equal(t, cmdReqBody, reqBodyRaw)
w.Write(responseData)
})

var logBytes bytes.Buffer
slogger := multislogger.New(slog.NewTextHandler(&logBytes, &slog.HandlerOptions{
Level: slog.LevelDebug,
})).Logger

// set up middlewares
kryptoEcMiddleware := newKryptoEcMiddleware(slogger, ecdsaKey(t), nil, counterpartyKey.PublicKey)
require.NoError(t, err)

h := kryptoEcMiddleware.Wrap(testHandler)

req := makeGetRequest(t, encodedChallenge)
req.Header.Set("origin", tt.requestOrigin)

rr := httptest.NewRecorder()
h.ServeHTTP(rr, req)

// FIXME: add some log string tests
//spew.Dump(logBytes.String(), tt.logStr)

require.Equal(t, tt.expectedStatus, rr.Code)

if tt.logStr != "" {
assert.Contains(t, logBytes.String(), tt.logStr)
}

if tt.expectedStatus != http.StatusOK {
return
}

// try to open the response
returnedResponseBytes, err := base64.StdEncoding.DecodeString(rr.Body.String())
require.NoError(t, err)

responseUnmarshalled, err := challenge.UnmarshalResponse(returnedResponseBytes)
require.NoError(t, err)
require.Equal(t, challengeId, responseUnmarshalled.ChallengeId)

opened, err := responseUnmarshalled.Open(*privateEncryptionKey)
require.NoError(t, err)
require.Equal(t, challengeData, opened.ChallengeData)
require.Equal(t, responseData, opened.ResponseData)
require.WithinDuration(t, time.Now(), time.Unix(opened.Timestamp, 0), time.Second*5)

})
}

}

func ecdsaKey(t *testing.T) *ecdsa.PrivateKey {
key, err := echelper.GenerateEcdsaKey()
require.NoError(t, err)
Expand Down
2 changes: 2 additions & 0 deletions ee/localserver/request-id.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type (
Nonce string
Timestamp time.Time
Status status
Origin string
}

status struct {
Expand Down Expand Up @@ -78,6 +79,7 @@ func (ls *localServer) requestIdHandlerFunc(w http.ResponseWriter, r *http.Reque
response := requestIdsResponse{
Nonce: ulid.New(),
Timestamp: time.Now(),
Origin: r.Header.Get("Origin"),
directionless marked this conversation as resolved.
Show resolved Hide resolved
Status: status{
EnrollmentStatus: string(enrollmentStatus),
},
Expand Down
6 changes: 4 additions & 2 deletions ee/localserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,8 +353,10 @@ func (ls *localServer) startListener() (net.Listener, error) {

func (ls *localServer) preflightCorsHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Think harder, maybe?
// https://stackoverflow.com/questions/12830095/setting-http-headers
// We don't believe we can meaningfully enforce a CORS style check here -- those are enforced by the browser.
// And we recognize there are some patterns that bypass the browsers CORS enforcement. However, we do implement
// origin enforcement as an allowlist inside kryptoEcMiddleware
// See https://github.com/kolide/k2/issues/9634
if origin := r.Header.Get("Origin"); origin != "" {
w.Header().Set("Access-Control-Allow-Origin", origin)
}
Expand Down
Loading