Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enhance: avoid context limit #832

Merged
merged 4 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions pkg/openai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package openai

import (
"context"
"errors"
"io"
"log/slog"
"math"
"os"
"slices"
"sort"
Expand All @@ -24,6 +26,7 @@ import (
const (
	// DefaultModel is the model used when a completion request does not name one.
	DefaultModel = openai.GPT4o
	// BuiltinCredName is the credential identifier "sys.openai" — presumably the
	// name under which the built-in OpenAI API key is stored; verify against the
	// credential store usage elsewhere in the package.
	BuiltinCredName = "sys.openai"
	// TooLongMessage replaces a tool call output that would, on its own, consume
	// more than 80% of the token budget (see the check in Call).
	TooLongMessage = "Error: tool call output is too long"
)

var (
Expand Down Expand Up @@ -317,6 +320,14 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
}

if messageRequest.Chat {
// Check the last message. If it is from a tool call, and if it takes up more than 80% of the budget on its own, reject it.
lastMessage := msgs[len(msgs)-1]
if lastMessage.Role == string(types.CompletionMessageRoleTypeTool) && countMessage(lastMessage) > int(math.Round(float64(getBudget(messageRequest.MaxTokens))*0.8)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: probably don't need to math.Round here. Not much of a difference between 102,399 and 102,400.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

// We need to update it in the msgs slice for right now and in the messageRequest for future calls.
msgs[len(msgs)-1].Content = TooLongMessage
messageRequest.Messages[len(messageRequest.Messages)-1].Content = types.Text(TooLongMessage)
}

msgs = dropMessagesOverCount(messageRequest.MaxTokens, msgs)
}

Expand Down Expand Up @@ -383,6 +394,16 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
return nil, err
} else if !ok {
response, err = c.call(ctx, request, id, status)

// If we got back a context length exceeded error, keep retrying and shrinking the message history until we pass.
var apiError *openai.APIError
if err != nil && errors.As(err, &apiError) && apiError.Code == "context_length_exceeded" && messageRequest.Chat {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: errors.As takes care of the err != nil check

Suggested change
if err != nil && errors.As(err, &apiError) && apiError.Code == "context_length_exceeded" && messageRequest.Chat {
if errors.As(err, &apiError) && apiError.Code == "context_length_exceeded" && messageRequest.Chat {

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

// Decrease maxTokens by 10% to make garbage collection more aggressive.
// The retry loop will further decrease maxTokens if needed.
maxTokens := decreaseTenPercent(messageRequest.MaxTokens)
response, err = c.contextLimitRetryLoop(ctx, request, id, maxTokens, status)
}

if err != nil {
return nil, err
}
Expand Down Expand Up @@ -421,6 +442,32 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
return &result, nil
}

func (c *Client) contextLimitRetryLoop(ctx context.Context, request openai.ChatCompletionRequest, id string, maxTokens int, status chan<- types.CompletionStatus) ([]openai.ChatCompletionStreamResponse, error) {
var (
response []openai.ChatCompletionStreamResponse
err error
)

for range 10 { // maximum 10 tries
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the first use in our code base?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so!

// Try to drop older messages again, with a decreased max tokens.
request.Messages = dropMessagesOverCount(maxTokens, request.Messages)
response, err = c.call(ctx, request, id, status)
if err == nil {
return response, nil
}

var apiError *openai.APIError
if errors.As(err, &apiError) && apiError.Code == "context_length_exceeded" {
// Decrease maxTokens and try again
maxTokens = decreaseTenPercent(maxTokens)
continue
}
return nil, err
}

return nil, err
}

func appendMessage(msg types.CompletionMessage, response openai.ChatCompletionStreamResponse) types.CompletionMessage {
msg.Usage.CompletionTokens = types.FirstSet(msg.Usage.CompletionTokens, response.Usage.CompletionTokens)
msg.Usage.PromptTokens = types.FirstSet(msg.Usage.PromptTokens, response.Usage.PromptTokens)
Expand Down
36 changes: 28 additions & 8 deletions pkg/openai/count.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,32 @@
package openai

import openai "github.com/gptscript-ai/chat-completion-client"
import (
"math"

openai "github.com/gptscript-ai/chat-completion-client"
)

// DefaultMaxTokens is the token budget assumed when a request leaves
// MaxTokens unset (0); see getBudget.
const DefaultMaxTokens = 128_000

func decreaseTenPercent(maxTokens int) int {
maxTokens = getBudget(maxTokens)
return int(math.Round(float64(maxTokens) * 0.9))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same nit about math.Round

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

}

// getBudget resolves the effective token budget for a request: a non-zero
// maxTokens is used as-is, while 0 means "unset" and falls back to
// DefaultMaxTokens.
func getBudget(maxTokens int) int {
	if maxTokens != 0 {
		return maxTokens
	}
	return DefaultMaxTokens
}

func dropMessagesOverCount(maxTokens int, msgs []openai.ChatCompletionMessage) (result []openai.ChatCompletionMessage) {
var (
lastSystem int
withinBudget int
budget = maxTokens
budget = getBudget(maxTokens)
)

if maxTokens == 0 {
budget = 300_000
} else {
budget *= 3
}

for i, msg := range msgs {
if msg.Role == openai.ChatMessageRoleSystem {
budget -= countMessage(msg)
Expand All @@ -33,6 +45,14 @@ func dropMessagesOverCount(maxTokens int, msgs []openai.ChatCompletionMessage) (
}
}

// OpenAI gets upset if there is a tool message without a tool call preceding it.
// Check the oldest message within budget, and if it is a tool message, just drop it.
// We do this in a loop because it is possible for multiple tool messages to be in a row,
// due to parallel tool calls.
for withinBudget < len(msgs) && msgs[withinBudget].Role == openai.ChatMessageRoleTool {
withinBudget++
}

if withinBudget == len(msgs)-1 {
// We are going to drop all non system messages, which seems useless, so just return them
// all and let it fail
Expand Down