From 1ad818bb3b6edc01be3b280ace1ab50f0781b32e Mon Sep 17 00:00:00 2001
From: Grant Linville
Date: Fri, 30 Aug 2024 09:31:47 -0400
Subject: [PATCH] enhance: avoid context limit (#832)

Signed-off-by: Grant Linville
---
 pkg/openai/client.go | 46 ++++++++++++++++++++++++++++++++++++++++++++++
 pkg/openai/count.go  | 34 ++++++++++++++++++++++++--------
 2 files changed, 72 insertions(+), 8 deletions(-)

diff --git a/pkg/openai/client.go b/pkg/openai/client.go
index 42a1a39e..61a7ec77 100644
--- a/pkg/openai/client.go
+++ b/pkg/openai/client.go
@@ -2,6 +2,7 @@ package openai
 
 import (
 	"context"
+	"errors"
 	"io"
 	"log/slog"
 	"os"
@@ -24,6 +25,7 @@ import (
 const (
 	DefaultModel    = openai.GPT4o
 	BuiltinCredName = "sys.openai"
+	TooLongMessage  = "Error: tool call output is too long"
 )
 
 var (
@@ -317,6 +319,14 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
 	}
 
 	if messageRequest.Chat {
+		// Check the last message. If it is from a tool call, and if it takes up more than 80% of the budget on its own, reject it.
+		lastMessage := msgs[len(msgs)-1]
+		if lastMessage.Role == string(types.CompletionMessageRoleTypeTool) && countMessage(lastMessage) > int(float64(getBudget(messageRequest.MaxTokens))*0.8) {
+			// We need to update it in the msgs slice for right now and in the messageRequest for future calls.
+			msgs[len(msgs)-1].Content = TooLongMessage
+			messageRequest.Messages[len(messageRequest.Messages)-1].Content = types.Text(TooLongMessage)
+		}
+
 		msgs = dropMessagesOverCount(messageRequest.MaxTokens, msgs)
 	}
 
@@ -383,6 +393,16 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
 		return nil, err
 	} else if !ok {
 		response, err = c.call(ctx, request, id, status)
+
+		// If we got back a context length exceeded error, keep retrying and shrinking the message history until we pass.
+		var apiError *openai.APIError
+		if errors.As(err, &apiError) && apiError.Code == "context_length_exceeded" && messageRequest.Chat {
+			// Decrease maxTokens by 10% to make garbage collection more aggressive.
+			// The retry loop will further decrease maxTokens if needed.
+			maxTokens := decreaseTenPercent(messageRequest.MaxTokens)
+			response, err = c.contextLimitRetryLoop(ctx, request, id, maxTokens, status)
+		}
+
 		if err != nil {
 			return nil, err
 		}
@@ -421,6 +441,32 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
 	return &result, nil
 }
 
+func (c *Client) contextLimitRetryLoop(ctx context.Context, request openai.ChatCompletionRequest, id string, maxTokens int, status chan<- types.CompletionStatus) ([]openai.ChatCompletionStreamResponse, error) {
+	var (
+		response []openai.ChatCompletionStreamResponse
+		err      error
+	)
+
+	for range 10 { // maximum 10 tries
+		// Try to drop older messages again, with a decreased max tokens.
+		request.Messages = dropMessagesOverCount(maxTokens, request.Messages)
+		response, err = c.call(ctx, request, id, status)
+		if err == nil {
+			return response, nil
+		}
+
+		var apiError *openai.APIError
+		if errors.As(err, &apiError) && apiError.Code == "context_length_exceeded" {
+			// Decrease maxTokens and try again
+			maxTokens = decreaseTenPercent(maxTokens)
+			continue
+		}
+
+		return nil, err
+	}
+
+	return nil, err
+}
+
 func appendMessage(msg types.CompletionMessage, response openai.ChatCompletionStreamResponse) types.CompletionMessage {
 	msg.Usage.CompletionTokens = types.FirstSet(msg.Usage.CompletionTokens, response.Usage.CompletionTokens)
 	msg.Usage.PromptTokens = types.FirstSet(msg.Usage.PromptTokens, response.Usage.PromptTokens)

diff --git a/pkg/openai/count.go b/pkg/openai/count.go
index 47c5c9bd..ffd902e5 100644
--- a/pkg/openai/count.go
+++ b/pkg/openai/count.go
@@ -1,20 +1,30 @@
 package openai
 
-import openai "github.com/gptscript-ai/chat-completion-client"
+import (
+	openai "github.com/gptscript-ai/chat-completion-client"
+)
+
+const DefaultMaxTokens = 128_000
+
+func decreaseTenPercent(maxTokens int) int {
+	maxTokens = getBudget(maxTokens)
+	return int(float64(maxTokens) * 0.9)
+}
+
+func getBudget(maxTokens int) int {
+	if maxTokens == 0 {
+		return DefaultMaxTokens
+	}
+	return maxTokens
+}
 
 func dropMessagesOverCount(maxTokens int, msgs []openai.ChatCompletionMessage) (result []openai.ChatCompletionMessage) {
 	var (
 		lastSystem   int
 		withinBudget int
-		budget       = maxTokens
+		budget       = getBudget(maxTokens)
 	)
 
-	if maxTokens == 0 {
-		budget = 300_000
-	} else {
-		budget *= 3
-	}
-
 	for i, msg := range msgs {
 		if msg.Role == openai.ChatMessageRoleSystem {
 			budget -= countMessage(msg)
@@ -33,6 +43,14 @@ func dropMessagesOverCount(maxTokens int, msgs []openai.ChatCompletionMessage) (
 		}
 	}
 
+	// OpenAI gets upset if there is a tool message without a tool call preceding it.
+	// Check the oldest message within budget, and if it is a tool message, just drop it.
+	// We do this in a loop because it is possible for multiple tool messages to be in a row,
+	// due to parallel tool calls.
+	for withinBudget < len(msgs) && msgs[withinBudget].Role == openai.ChatMessageRoleTool {
+		withinBudget++
+	}
+
 	if withinBudget == len(msgs)-1 {
 		// We are going to drop all non system messages, which seems useless, so just return them
 		// all and let it fail
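
Illustrative note, not part of the patch: the recovery path is geometric backoff on the token budget. Call applies one decreaseTenPercent before entering contextLimitRetryLoop, and each context_length_exceeded failure inside the loop shrinks the budget by another 10% before dropMessagesOverCount prunes the oldest non-system messages to fit. A minimal standalone sketch, reusing getBudget and decreaseTenPercent verbatim from pkg/openai/count.go above, shows the worst-case budgets when no explicit limit is set:

package main

import "fmt"

// Copied verbatim from pkg/openai/count.go in the patch above.
const DefaultMaxTokens = 128_000

func getBudget(maxTokens int) int {
	if maxTokens == 0 {
		return DefaultMaxTokens
	}
	return maxTokens
}

func decreaseTenPercent(maxTokens int) int {
	maxTokens = getBudget(maxTokens)
	return int(float64(maxTokens) * 0.9)
}

func main() {
	// Call decreases once before the retry loop; each failed attempt
	// inside contextLimitRetryLoop decreases again, for at most 10 tries.
	budget := decreaseTenPercent(0) // 0 means "no limit set", so start from DefaultMaxTokens
	for attempt := 1; attempt <= 10; attempt++ {
		fmt.Printf("attempt %2d: budget = %d tokens\n", attempt, budget)
		budget = decreaseTenPercent(budget)
	}
}

With integer truncation at each step, attempt 1 runs with a budget of 115200 tokens and attempt 10 with 44629, about 35% of the 128k default, so the loop either finds a message history that fits within its ten tries or returns the last API error.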