Skip to content

Commit

Permalink
chore(ai-proxy): use logger with filter name (erda-project#6370)
Browse files Browse the repository at this point in the history
* polish dashscope code

* fix: only wrap user prompt if text is not empty

* use context logger with module name
  • Loading branch information
sfwn committed Jun 18, 2024
1 parent 3986e72 commit 52c72a9
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 15 deletions.
5 changes: 2 additions & 3 deletions internal/apps/ai-proxy/filters/azure-director/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import (

"github.com/pkg/errors"
"github.com/sashabaranov/go-openai"
"github.com/sirupsen/logrus"
"sigs.k8s.io/yaml"

"github.com/erda-project/erda-infra/base/logs"
Expand Down Expand Up @@ -368,13 +367,13 @@ func (f *AzureDirector) AddContextMessages(ctx context.Context) error {
infor := reverseproxy.NewInfor(ctx, req)
var openaiReq openai.ChatCompletionRequest
if err := json.NewDecoder(infor.BodyBuffer()).Decode(&openaiReq); err != nil && err != io.EOF {
logrus.Errorf("failed to decode request body, err: %v", err)
ctxhelper.GetLogger(ctx).Errorf("failed to decode request body, err: %v", err)
return
}
openaiReq.Messages = messageGroup.AllMessages
b, err := json.Marshal(&openaiReq)
if err != nil {
logrus.Errorf("failed to marshal request body, err: %v", err)
ctxhelper.GetLogger(ctx).Errorf("failed to marshal request body, err: %v", err)
return
}
infor.SetBody(io.NopCloser(strings.NewReader(string(b))), int64(len(b)))
Expand Down
10 changes: 9 additions & 1 deletion internal/apps/ai-proxy/filters/context-chat/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,15 @@ func (c *SessionContext) OnRequest(ctx context.Context, _ http.ResponseWriter, i
// handle user message, wrap by '|start| your question here |end|'
// to avoid from content-filter
if msg.Role == openai.ChatMessageRoleUser {
msg.Content = vars.WrapUserPrompt(msg.Content)
if msg.Content != "" {
msg.Content = vars.WrapUserPrompt(msg.Content)
} else {
for i, part := range msg.MultiContent {
if part.Text != "" {
msg.MultiContent[i].Text = vars.WrapUserPrompt(part.Text)
}
}
}
}
requestedMessages = append(requestedMessages, msg)
}
Expand Down
4 changes: 2 additions & 2 deletions internal/apps/ai-proxy/filters/dashscope-director/req.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,11 @@ func oneDirector(ctx context.Context, w http.ResponseWriter, infor reverseproxy.
case metadata.AliyunDashScopeRequestTypeOpenAI:
bodyObj = oreq
case metadata.AliyunDashScopeRequestTypeDs:
qwreq, err := sdk.ConvertOpenAIChatRequestToDsRequest(oreq, model.Type)
dsreq, err := sdk.ConvertOpenAIChatRequestToDsRequest(ctx, oreq, model.Type)
if err != nil {
return reverseproxy.Intercept, fmt.Errorf("failed to convert openai chat request to dashscope request, err: %v", err)
}
bodyObj = qwreq
bodyObj = dsreq
default:
return reverseproxy.Intercept, fmt.Errorf("unsupported metadata.public.request_type: %s", modelMeta.Public.RequestType)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/apps/ai-proxy/filters/dashscope-director/resp.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func (f *DashScopeDirector) dsHandleResponseStreamChunk(ctx context.Context, w r
// convert ds response to openai response
openaiChunk, err := sdk.ConvertDsStreamChunkToOpenAIFormat(*lastCompleteDeltaResp, modelName)
if err != nil {
return reverseproxy.Intercept, fmt.Errorf("failed to convert qwenVL response to openai response, err: %v", err)
return reverseproxy.Intercept, fmt.Errorf("failed to convert dashscope response to openai response, err: %v", err)
}
b, err := json.Marshal(&openaiChunk)
if err != nil {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,17 @@
package sdk

import (
"context"
"fmt"
"strings"

"github.com/sashabaranov/go-openai"
"github.com/sirupsen/logrus"

"github.com/erda-project/erda-proto-go/apps/aiproxy/model/pb"
"github.com/erda-project/erda/internal/apps/ai-proxy/common/ctxhelper"
)

func ConvertOpenAIChatRequestToDsRequest(oreq openai.ChatCompletionRequest, targetModelType pb.ModelType) (DsRequest, error) {
func ConvertOpenAIChatRequestToDsRequest(ctx context.Context, oreq openai.ChatCompletionRequest, targetModelType pb.ModelType) (DsRequest, error) {
var dsReq DsRequest
dsReq.Model = oreq.Model
for _, om := range oreq.Messages {
Expand Down Expand Up @@ -63,7 +64,7 @@ func ConvertOpenAIChatRequestToDsRequest(oreq openai.ChatCompletionRequest, targ
case openai.ChatMessagePartTypeImageURL:
parts = append(parts, DsRequestContentPart{Image: omc.ImageURL.URL})
default:
logrus.Warnf("unsupported message part type: %s", omc.Type)
ctxhelper.GetLogger(ctx).Warnf("unsupported message part type: %s", omc.Type)
}
}
}
Expand Down
9 changes: 4 additions & 5 deletions internal/apps/ai-proxy/filters/openai-director/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import (

"github.com/pkg/errors"
"github.com/sashabaranov/go-openai"
"github.com/sirupsen/logrus"
"sigs.k8s.io/yaml"

"github.com/erda-project/erda-infra/base/logs"
Expand Down Expand Up @@ -198,13 +197,13 @@ func (f *OpenaiDirector) AddModelInRequestBody(ctx context.Context) error {
// read body to json, then add a `model` field, then write back to body
var body map[string]interface{}
if err := json.NewDecoder(infor.BodyBuffer()).Decode(&body); err != nil && err != io.EOF {
logrus.Errorf("failed to decode request body, err: %v", err)
ctxhelper.GetLogger(ctx).Errorf("failed to decode request body, err: %v", err)
return
}
body["model"] = model.Name
b, err := json.Marshal(body)
if err != nil {
logrus.Errorf("failed to marshal request body, err: %v", err)
ctxhelper.GetLogger(ctx).Errorf("failed to marshal request body, err: %v", err)
return
}
infor.SetBody(io.NopCloser(strings.NewReader(string(b))), int64(len(b)))
Expand All @@ -221,13 +220,13 @@ func (f *OpenaiDirector) AddContextMessages(ctx context.Context) error {
infor := reverseproxy.NewInfor(ctx, req)
var openaiReq openai.ChatCompletionRequest
if err := json.NewDecoder(infor.BodyBuffer()).Decode(&openaiReq); err != nil && err != io.EOF {
logrus.Errorf("failed to decode request body, err: %v", err)
ctxhelper.GetLogger(ctx).Errorf("failed to decode request body, err: %v", err)
return
}
openaiReq.Messages = messageGroup.AllMessages
b, err := json.Marshal(&openaiReq)
if err != nil {
logrus.Errorf("failed to marshal request body, err: %v", err)
ctxhelper.GetLogger(ctx).Errorf("failed to marshal request body, err: %v", err)
return
}
infor.SetBody(io.NopCloser(strings.NewReader(string(b))), int64(len(b)))
Expand Down

0 comments on commit 52c72a9

Please sign in to comment.