Skip to content

Commit

Permalink
fix: Fix the quotation issue of deny message in ai-security-guard (#1352
Browse files Browse the repository at this point in the history
)
  • Loading branch information
CH3CHO authored Sep 27, 2024
1 parent 1b119ed commit 71aae9d
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions plugins/wasm-go/extensions/ai-security-guard/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"crypto/sha1"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
mrand "math/rand"
Expand Down Expand Up @@ -194,7 +195,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log
}

func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action {
proxywasm.LogDebugf("checking request body...")
log.Debugf("checking request body...")
content := gjson.GetBytes(body, config.requestContentJsonPath).String()
model := gjson.GetBytes(body, "model").Raw
ctx.SetContext("requestModel", model)
Expand Down Expand Up @@ -231,25 +232,35 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
respAdvice := respData.Get("Advice")
respResult := respData.Get("Result")
var denyMessage string
messageNeedSerialization := true
if config.protocolOriginal {
// not openai
if config.denyMessage != "" {
denyMessage = config.denyMessage
} else if respAdvice.Exists() {
denyMessage = respAdvice.Array()[0].Get("Answer").Raw
messageNeedSerialization = false
} else {
denyMessage = DefaultDenyMessage
}
} else {
// openai
if respAdvice.Exists() {
denyMessage = respAdvice.Array()[0].Get("Answer").Raw
messageNeedSerialization = false
} else if config.denyMessage != "" {
denyMessage = config.denyMessage
} else {
denyMessage = DefaultDenyMessage
}
}
if messageNeedSerialization {
if data, err := json.Marshal(denyMessage); err == nil {
denyMessage = string(data)
} else {
denyMessage = fmt.Sprintf("\"%s\"", DefaultDenyMessage)
}
}
if respResult.Array()[0].Get("Label").String() != "nonLabel" {
proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_risklabel"}, []byte(respResult.Array()[0].Get("Label").String()))
proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_deny_phase"}, []byte("request"))
Expand Down Expand Up @@ -280,7 +291,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
}
return types.ActionPause
} else {
proxywasm.LogDebugf("request content is empty. skip")
log.Debugf("request content is empty. skip")
return types.ActionContinue
}
}
Expand Down Expand Up @@ -320,7 +331,7 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log
}

func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action {
proxywasm.LogDebugf("checking response body...")
log.Debugf("checking response body...")
hdsMap := ctx.GetContext("headers").(map[string][]string)
isStreamingResponse := strings.Contains(strings.Join(hdsMap["content-type"], ";"), "event-stream")
model := ctx.GetStringContext("requestModel", "unknown")
Expand Down Expand Up @@ -411,7 +422,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
}
return types.ActionPause
} else {
proxywasm.LogDebugf("request content is empty. skip")
log.Debugf("request content is empty. skip")
return types.ActionContinue
}
}
Expand Down

0 comments on commit 71aae9d

Please sign in to comment.