Skip to content

Commit

Permalink
add headers property to response in krypto middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
James-Pickett committed Sep 30, 2024
1 parent 4d952b2 commit 4e75514
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 21 deletions.
41 changes: 31 additions & 10 deletions ee/localserver/krypto-ec-middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -316,15 +316,44 @@ func (e *kryptoEcMiddleware) Wrap(next http.Handler) http.Handler {
bhr := &bufferedHttpResponse{}
next.ServeHTTP(bhr, newReq)

// add headers to the response map
var responseMap map[string]interface{}
bhrBytes := bhr.Bytes()
if err := json.Unmarshal(bhrBytes, &responseMap); err != nil {
traces.SetError(span, err)
e.slogger.Log(r.Context(), slog.LevelError,
"unable to unmarshal response",
"err", err,
)
responseMap = map[string]any{
"headers": bhr.Header(),

// the request body was not in json format, just pass it through as "msg"
// property of json
"msg": string(bhrBytes),
}
} else {
responseMap["headers"] = bhr.Header()
}

responseBytes, err := json.Marshal(responseMap)
if err != nil {
traces.SetError(span, err)
e.slogger.Log(r.Context(), slog.LevelError,
"unable to marshal response",
"err", err,
)
}

var response []byte
// it's possible the keys will be noop keys, then they will error or give nil when crypto.Signer funcs are called
// krypto library has a nil check for the object but not the funcs, so if are getting nil from the funcs, just
// pass nil to krypto
// hardware signing is not implemented for darwin
if runtime.GOOS != "darwin" && e.hardwareSigner != nil && e.hardwareSigner.Public() != nil {
response, err = challengeBox.Respond(e.localDbSigner, e.hardwareSigner, bhr.Bytes())
response, err = challengeBox.Respond(e.localDbSigner, e.hardwareSigner, responseBytes)
} else {
response, err = challengeBox.Respond(e.localDbSigner, nil, bhr.Bytes())
response, err = challengeBox.Respond(e.localDbSigner, nil, responseBytes)
}

if err != nil {
Expand All @@ -344,14 +373,6 @@ func (e *kryptoEcMiddleware) Wrap(next http.Handler) http.Handler {

w.Header().Add(kolideKryptoHeaderKey, kolideKryptoEccHeader20230130Value)

for k, v := range bhr.Header() {
if len(v) == 0 {
continue
}

w.Header().Add(k, v[0])
}

// arguable the png things here should be their own handler. But doing that means another layer
// buffering the http response, so it feels a bit silly. When we ditch the v1/v2 switcher, we can
// be a bit more clever and move this.
Expand Down
66 changes: 55 additions & 11 deletions ee/localserver/krypto-ec-middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"log/slog"
"math/big"
Expand Down Expand Up @@ -139,12 +140,18 @@ func TestKryptoEcMiddleware(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

responseData := tt.responseData
// generate the response we want the handler to return
if responseData == nil {
responseData = []byte(ulid.New())
responseMap := make(map[string]any)
const testMsgKey = "msg"

responseValue := string(tt.responseData)
if responseValue == "" {
responseValue = ulid.New()
}

responseMap[testMsgKey] = responseValue

responseDataRaw := mustMarshal(t, responseMap)

testHandler := tt.handler

// this handler is what will respond to the request made by the kryptoEcMiddleware.Wrap handler
Expand All @@ -163,7 +170,7 @@ func TestKryptoEcMiddleware(t *testing.T) {
defer r.Body.Close()

require.Equal(t, cmdReqBody, reqBodyRaw)
w.Write(responseData)
w.Write(responseDataRaw)
})
}

Expand Down Expand Up @@ -212,10 +219,6 @@ func TestKryptoEcMiddleware(t *testing.T) {

require.Equal(t, kolideKryptoEccHeader20230130Value, rr.Header().Get(kolideKryptoHeaderKey))

if runtime.GOOS == "darwin" {
require.Equal(t, (0 * time.Second).String(), rr.Header().Get(kolideDurationSinceLastPresenceDetection))
}

// try to open the response
returnedResponseBytes, err := base64.StdEncoding.DecodeString(rr.Body.String())
require.NoError(t, err)
Expand All @@ -227,8 +230,20 @@ func TestKryptoEcMiddleware(t *testing.T) {
opened, err := responseUnmarshalled.Open(*privateEncryptionKey)
require.NoError(t, err)
require.Equal(t, challengeData, opened.ChallengeData)
require.Equal(t, responseData, opened.ResponseData)

opendResponseValue, err := extractJsonProperty[string](opened.ResponseData, testMsgKey)
require.NoError(t, err)
require.Equal(t, responseValue, opendResponseValue)

require.WithinDuration(t, time.Now(), time.Unix(opened.Timestamp, 0), time.Second*5)

responseHeaders, err := extractJsonProperty[map[string][]string](opened.ResponseData, "headers")
require.NoError(t, err)

// check that the presence detection interval is present
if runtime.GOOS == "darwin" {
require.Equal(t, (0 * time.Second).String(), responseHeaders[kolideDurationSinceLastPresenceDetection][0])
}
})
}
})
Expand Down Expand Up @@ -383,7 +398,11 @@ func Test_AllowedOrigin(t *testing.T) {
opened, err := responseUnmarshalled.Open(*privateEncryptionKey)
require.NoError(t, err)
require.Equal(t, challengeData, opened.ChallengeData)
require.Equal(t, responseData, opened.ResponseData)

openedResponseValue, err := extractJsonProperty[string](opened.ResponseData, "msg")
require.NoError(t, err)

require.Equal(t, responseData, []byte(openedResponseValue))
require.WithinDuration(t, time.Now(), time.Unix(opened.Timestamp, 0), time.Second*5)

})
Expand Down Expand Up @@ -448,3 +467,28 @@ func mustMarshal(t *testing.T, v interface{}) []byte {
require.NoError(t, err)
return b
}

func extractJsonProperty[T any](jsonData []byte, property string) (T, error) {
var result map[string]json.RawMessage

// Unmarshal the JSON data into a map with json.RawMessage
err := json.Unmarshal(jsonData, &result)
if err != nil {
return *new(T), err
}

// Retrieve the field from the map
value, ok := result[property]
if !ok {
return *new(T), fmt.Errorf("property %s not found", property)
}

// Unmarshal the value into the type T
var extractedValue T
err = json.Unmarshal(value, &extractedValue)
if err != nil {
return *new(T), err
}

return extractedValue, nil
}

0 comments on commit 4e75514

Please sign in to comment.