Merge pull request #877 from mirackara/openaichanges

openai changes

nr-swilloughby authored Mar 26, 2024
2 parents f1274d0 + 8bd6b0b commit cae089b
Showing 3 changed files with 203 additions and 57 deletions.
@@ -10,15 +10,30 @@ import (

"github.com/newrelic/go-agent/v3/integrations/nropenai"
"github.com/newrelic/go-agent/v3/newrelic"
"github.com/pkoukk/tiktoken-go"
openai "github.com/sashabaranov/go-openai"
)

// Simulates feedback being sent to New Relic. Feedback on a chat completion requires
// access to the wrapper returned by the completion call: ChatCompletionResponseWrapper
// from NRCreateChatCompletion, or ChatCompletionStreamWrapper from NRCreateChatCompletionStream.
func SendFeedback(app *newrelic.Application, resp nropenai.ChatCompletionStreamWrapper) {
traceID := resp.TraceID
rating := "5"
category := "informative"
message := "The response was concise yet thorough."
customMetadata := map[string]interface{}{
"foo": "bar",
"pi": 3.14,
}

app.RecordLLMFeedbackEvent(traceID, rating, category, message, customMetadata)
}

func main() {
// Start New Relic Application
app, err := newrelic.NewApplication(
newrelic.ConfigAppName("Basic OpenAI App"),
newrelic.ConfigLicense(os.Getenv("NEW_RELIC_LICENSE_KEY")),
newrelic.ConfigDebugLogger(os.Stdout),
// Enable AI Monitoring
// NOTE - If High Security Mode is enabled, AI Monitoring will always be disabled
newrelic.ConfigAIMonitoringEnabled(true),
@@ -27,7 +42,33 @@ func main() {
panic(err)
}
app.WaitForConnection(10 * time.Second)
// SetLLMTokenCountCallback registers a custom token-counting callback. If it is left
// unset and newrelic.ConfigAIMonitoringRecordContentEnabled() is disabled, no token
// counts will be reported.
app.SetLLMTokenCountCallback(func(modelName string, content string) int {
var tokensPerMessage, tokensPerName int
switch modelName {
case "gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
"gpt-4-0314",
"gpt-4-32k-0314",
"gpt-4-0613",
"gpt-4-32k-0613":
tokensPerMessage = 3
tokensPerName = 1
case "gpt-3.5-turbo-0301":
tokensPerMessage = 4
tokensPerName = -1
}

tkm, err := tiktoken.EncodingForModel(modelName)
if err != nil {
fmt.Println("error getting tokens", err)
return 0
}
token := tkm.Encode(content, nil, nil)
totalTokens := len(token) + tokensPerMessage + tokensPerName
return totalTokens
})
// OpenAI Config - Additionally, NRDefaultAzureConfig(apiKey, baseURL string) can be used for Azure
cfg := nropenai.NRDefaultConfig(os.Getenv("OPEN_AI_API_KEY"))
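// For Azure OpenAI, a minimal sketch using the NRDefaultAzureConfig(apiKey, baseURL string)
// helper mentioned above (the endpoint URL below is a placeholder):
// cfg := nropenai.NRDefaultAzureConfig(os.Getenv("AZURE_OPENAI_API_KEY"), "https://your-resource.openai.azure.com/")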

@@ -43,13 +84,13 @@ func main() {

// GPT Request
req := openai.ChatCompletionRequest{
Model: openai.GPT3Dot5Turbo,
Model: openai.GPT4,
Temperature: 0.7,
MaxTokens: 1500,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: "Say this is a test",
Content: "What is observability in software engineering?",
},
},
Stream: true,
@@ -59,10 +100,9 @@ func main() {
stream, err := nropenai.NRCreateChatCompletionStream(client, ctx, req, app)

if err != nil {
panic(err)
}
defer stream.Close()

fmt.Printf("Stream response: ")
for {
var response openai.ChatCompletionStreamResponse
@@ -78,6 +118,8 @@ func main() {

fmt.Print(response.Choices[0].Delta.Content)
}
stream.Close()
SendFeedback(app, *stream)
// Shutdown Application
app.Shutdown(5 * time.Second)
}
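For comparison, a non-streaming call can be sketched as follows. This assumes the same app and client setup as the example above; the argument order for NRCreateChatCompletion is assumed to mirror NRCreateChatCompletionStream minus the context, and the returned ChatCompletionResponseWrapper is assumed to expose TraceID for feedback, as the SendFeedback comment describes. The full signature is not shown in this diff, so treat both as assumptions.

// Sketch: non-streaming completion with feedback (assumed signature).
resp, err := nropenai.NRCreateChatCompletion(client, req, app)
if err != nil {
panic(err)
}
app.RecordLLMFeedbackEvent(resp.TraceID, "5", "informative", "Helpful answer.", nil)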
115 changes: 99 additions & 16 deletions v3/integrations/nropenai/nropenai.go
@@ -149,23 +149,49 @@ type ChatCompletionResponseWrapper struct {

// Wrapper for ChatCompletionStream that is returned from NRCreateChatCompletionStream
type ChatCompletionStreamWrapper struct {
stream *openai.ChatCompletionStream
txn *newrelic.Transaction
app *newrelic.Application
stream *openai.ChatCompletionStream
streamResp openai.ChatCompletionResponse
responseStr string
uuid string
txn *newrelic.Transaction
cw *ClientWrapper
role string
model string
StreamingData map[string]interface{}
isRoleAdded bool
TraceID string
}

// Wrapper for Recv() method that calls the underlying stream's Recv() method
func (w *ChatCompletionStreamWrapper) Recv() (openai.ChatCompletionStreamResponse, error) {
response, err := w.stream.Recv()

if err != nil {
return response, err
}
if !w.isRoleAdded && (response.Choices[0].Delta.Role == "assistant" || response.Choices[0].Delta.Role == "user" || response.Choices[0].Delta.Role == "system") {
w.isRoleAdded = true
w.role = response.Choices[0].Delta.Role
}
if response.Choices[0].FinishReason != "stop" {
w.responseStr += response.Choices[0].Delta.Content
w.streamResp.ID = response.ID
w.streamResp.Model = response.Model
w.model = response.Model
}

return response, nil
}

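// Close records the LlmChatCompletionSummary event and the accumulated
// LlmChatCompletionMessage event for the streamed response, ends the
// transaction, and closes the underlying stream. Call it after the Recv loop completes.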
func (w *ChatCompletionStreamWrapper) Close() {
w.StreamingData["response.model"] = w.model
w.app.RecordCustomEvent("LlmChatCompletionSummary", w.StreamingData)

NRCreateChatCompletionMessageStream(w.app, uuid.MustParse(w.uuid), w, w.cw)

w.txn.End()
w.stream.Close()
}

@@ -188,7 +214,8 @@ func NRCreateChatCompletionSummary(txn *newrelic.Transaction, app *newrelic.Appl
}
}
// Start span
integrationsupport.AddAgentAttribute(txn, "llm", "", true)
txn.AddAttribute("llm", true)

chatCompletionSpan := txn.StartSegment("Llm/completion/OpenAI/CreateChatCompletion")
// Track Total time taken for the chat completion or embedding call to complete in milliseconds
start := time.Now()
@@ -272,6 +299,58 @@ func NRCreateChatCompletionSummary(txn *newrelic.Transaction, app *newrelic.Appl
TraceID: traceID,
}
}

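// NRCreateChatCompletionMessageStream records an LlmChatCompletionMessage event
// for the content accumulated while streaming. It is invoked from
// ChatCompletionStreamWrapper.Close().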
func NRCreateChatCompletionMessageStream(app *newrelic.Application, uuid uuid.UUID, sw *ChatCompletionStreamWrapper, cw *ClientWrapper) {
spanID := sw.txn.GetTraceMetadata().SpanID
traceID := sw.txn.GetTraceMetadata().TraceID

appCfg, ok := app.Config()
if !ok {
appCfg.AppName = "Unknown"
}
integrationsupport.AddAgentAttribute(sw.txn, "llm", "", true)
chatCompletionMessageSpan := sw.txn.StartSegment("Llm/completion/OpenAI/CreateChatCompletionMessageStream")

ChatCompletionMessageData := map[string]interface{}{}
// If the response doesn't have an ID, use the UUID from the summary
if sw.streamResp.ID == "" {
ChatCompletionMessageData["id"] = uuid.String()
} else {
ChatCompletionMessageData["id"] = sw.streamResp.ID
}

// Request model (populated from the model echoed back on the stream response)
ChatCompletionMessageData["request.model"] = sw.model

if appCfg.AIMonitoring.RecordContent.Enabled {
ChatCompletionMessageData["content"] = sw.responseStr
}

ChatCompletionMessageData["role"] = sw.role

// New Relic Attributes
ChatCompletionMessageData["sequence"] = 1
ChatCompletionMessageData["vendor"] = "OpenAI"
ChatCompletionMessageData["ingest_source"] = "Go"
ChatCompletionMessageData["span_id"] = spanID
ChatCompletionMessageData["trace_id"] = traceID
tmpMessage := openai.ChatCompletionMessage{
Content: sw.responseStr,
Role: sw.role,
// Name is not provided in the stream response, so we don't include it in token counting
Name: "",
}
tokenCount, tokensCounted := TokenCountingHelper(app, tmpMessage, sw.model)
if tokensCounted {
ChatCompletionMessageData["token_count"] = tokenCount
}

// If custom attributes are set, add them to the data
ChatCompletionMessageData = AppendCustomAttributesToEvent(cw, ChatCompletionMessageData)
chatCompletionMessageSpan.End()
// Record Custom Event for each message
app.RecordCustomEvent("LlmChatCompletionMessage", ChatCompletionMessageData)

}

func NRCreateChatCompletionMessageInput(txn *newrelic.Transaction, app *newrelic.Application, req openai.ChatCompletionRequest, uuid uuid.UUID, cw *ClientWrapper) {
spanID := txn.GetTraceMetadata().SpanID
traceID := txn.GetTraceMetadata().TraceID
@@ -297,13 +376,13 @@ func NRCreateChatCompletionMessageInput(txn *newrelic.Transaction, app *newrelic

// New Relic Attributes
ChatCompletionMessageData["sequence"] = 0
ChatCompletionMessageData["vendor"] = "openai"
ChatCompletionMessageData["ingest_source"] = "go"
ChatCompletionMessageData["vendor"] = "OpenAI"
ChatCompletionMessageData["ingest_source"] = "Go"
ChatCompletionMessageData["span_id"] = spanID
ChatCompletionMessageData["trace_id"] = traceID
contentTokens, contentCounted := app.InvokeLLMTokenCountCallback(req.Model, req.Messages[0].Content)

if contentCounted {
if contentCounted && app.HasLLMTokenCountCallback() {
ChatCompletionMessageData["token_count"] = contentTokens
}

@@ -348,8 +427,8 @@ func NRCreateChatCompletionMessage(txn *newrelic.Transaction, app *newrelic.Appl

// New Relic Attributes
ChatCompletionMessageData["sequence"] = i + 1
ChatCompletionMessageData["vendor"] = "openai"
ChatCompletionMessageData["ingest_source"] = "go"
ChatCompletionMessageData["vendor"] = "OpenAI"
ChatCompletionMessageData["ingest_source"] = "Go"
ChatCompletionMessageData["span_id"] = spanID
ChatCompletionMessageData["trace_id"] = traceID
tokenCount, tokensCounted := TokenCountingHelper(app, choice.Message, resp.Model)
@@ -369,13 +448,16 @@ func NRCreateChatCompletionMessage(txn *newrelic.Transaction, app *newrelic.Appl
}

func TokenCountingHelper(app *newrelic.Application, message openai.ChatCompletionMessage, model string) (numTokens int, tokensCounted bool) {
contentTokens, contentCounted := app.InvokeLLMTokenCountCallback(model, message.Content)
roleTokens, roleCounted := app.InvokeLLMTokenCountCallback(model, message.Role)
messageTokens, messageCounted := app.InvokeLLMTokenCountCallback(model, message.Name)
var messageTokens int
if message.Name != "" {
messageTokens, _ = app.InvokeLLMTokenCountCallback(model, message.Name)
}
numTokens += contentTokens + roleTokens + messageTokens

return numTokens, (contentCounted && roleCounted && messageCounted)
return numTokens, (contentCounted && roleCounted)
}

// NRCreateChatCompletion is a wrapper for the OpenAI CreateChatCompletion method.
Expand Down Expand Up @@ -465,7 +547,8 @@ func NRCreateEmbedding(cw *ClientWrapper, req openai.EmbeddingRequest, app *newr
// cast input as string
input := GetInput(req.Input).(string)
tokenCount, tokensCounted := app.InvokeLLMTokenCountCallback(string(resp.Model), input)
if tokensCounted {

if tokensCounted && app.HasLLMTokenCountCallback() {
EmbeddingsData["token_count"] = tokenCount
}

@@ -552,8 +635,8 @@ func NRCreateChatCompletionStream(cw *ClientWrapper, ctx context.Context, req op
StreamingData["vendor"] = "OpenAI"
StreamingData["ingest_source"] = "Go"
StreamingData["appName"] = config.AppName
app.RecordCustomEvent("LlmChatCompletionSummary", StreamingData)
txn.End()
return &ChatCompletionStreamWrapper{stream: stream, txn: txn}, nil

NRCreateChatCompletionMessageInput(txn, app, req, uuid, cw)
return &ChatCompletionStreamWrapper{app: app, stream: stream, txn: txn, uuid: uuid.String(), cw: cw, StreamingData: StreamingData, TraceID: traceID}, nil

}
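Note the behavioral change in this hunk: the LlmChatCompletionSummary event and the transaction end are no longer recorded when the stream is created; they are deferred to ChatCompletionStreamWrapper.Close(), after the streamed content has been accumulated.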