Skip to content

Commit

Permalink
Multi system message constructor, and types for model + chat roles. (#14
Browse files Browse the repository at this point in the history
)

* types for chat role and model

* sys message part

* constructor for multi system messages

* multi system message builder tests

* wspace

* more constructors

* revert rl header err checking to main
  • Loading branch information
WillMatthews committed Aug 30, 2024
1 parent 59ad2d1 commit bc51f1b
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 32 deletions.
22 changes: 13 additions & 9 deletions common.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
package anthropic

type Model string

const (
ModelClaudeInstant1Dot2 = "claude-instant-1.2"
ModelClaude2Dot0 = "claude-2.0"
ModelClaude2Dot1 = "claude-2.1"
ModelClaude3Opus20240229 = "claude-3-opus-20240229"
ModelClaude3Sonnet20240229 = "claude-3-sonnet-20240229"
ModelClaude3Dot5Sonnet20240620 = "claude-3-5-sonnet-20240620"
ModelClaude3Haiku20240307 = "claude-3-haiku-20240307"
ModelClaudeInstant1Dot2 Model = "claude-instant-1.2"
ModelClaude2Dot0 Model = "claude-2.0"
ModelClaude2Dot1 Model = "claude-2.1"
ModelClaude3Opus20240229 Model = "claude-3-opus-20240229"
ModelClaude3Sonnet20240229 Model = "claude-3-sonnet-20240229"
ModelClaude3Dot5Sonnet20240620 Model = "claude-3-5-sonnet-20240620"
ModelClaude3Haiku20240307 Model = "claude-3-haiku-20240307"
)

type ChatRole string

const (
RoleUser = "user"
RoleAssistant = "assistant"
RoleUser ChatRole = "user"
RoleAssistant ChatRole = "assistant"
)
4 changes: 2 additions & 2 deletions complete.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
)

type CompleteRequest struct {
Model string `json:"model"`
Model Model `json:"model"`
Prompt string `json:"prompt"`
MaxTokensToSample int `json:"max_tokens_to_sample"`

Expand Down Expand Up @@ -38,7 +38,7 @@ type CompleteResponse struct {
Completion string `json:"completion"`
// possible values are: stop_sequence、max_tokens、null
StopReason string `json:"stop_reason"`
Model string `json:"model"`
Model Model `json:"model"`
}

func (c *Client) CreateComplete(ctx context.Context, request CompleteRequest) (response CompleteResponse, err error) {
Expand Down
47 changes: 37 additions & 10 deletions message.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ const (
)

type MessagesRequest struct {
Model string `json:"model"`
Model Model `json:"model"`
Messages []Message `json:"messages"`
MaxTokens int `json:"max_tokens"`

Expand Down Expand Up @@ -87,8 +87,23 @@ type MessageSystemPart struct {
CacheControl *MessageCacheControl `json:"cache_control,omitempty"`
}

func NewMultiSystemMessages(texts ...string) []MessageSystemPart {
var systemParts []MessageSystemPart
for _, text := range texts {
systemParts = append(systemParts, NewSystemMessagePart(text))
}
return systemParts
}

func NewSystemMessagePart(text string) MessageSystemPart {
return MessageSystemPart{
Type: "text",
Text: text,
}
}

type Message struct {
Role string `json:"role"`
Role ChatRole `json:"role"`
Content []MessageContent `json:"content"`
}

Expand Down Expand Up @@ -169,12 +184,8 @@ func NewToolResultMessageContent(toolUseID, content string, isError bool) Messag

func NewToolUseMessageContent(toolUseID, name string, input json.RawMessage) MessageContent {
return MessageContent{
Type: MessagesContentTypeToolUse,
MessageContentToolUse: &MessageContentToolUse{
ID: toolUseID,
Name: name,
Input: input,
},
Type: MessagesContentTypeToolUse,
MessageContentToolUse: NewMessageContentToolUse(toolUseID, name, input),
}
}

Expand Down Expand Up @@ -252,12 +263,28 @@ type MessageContentImageSource struct {
Data any `json:"data"`
}

func NewMessageContentImageSource(imageSourceType, mediaType string, data any) MessageContentImageSource {
return MessageContentImageSource{
Type: imageSourceType,
MediaType: mediaType,
Data: data,
}
}

type MessageContentToolUse struct {
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input json.RawMessage `json:"input,omitempty"`
}

func NewMessageContentToolUse(toolUseId, name string, input json.RawMessage) *MessageContentToolUse {
return &MessageContentToolUse{
ID: toolUseId,
Name: name,
Input: input,
}
}

func (c *MessageContentToolUse) UnmarshalInput(v any) error {
return json.Unmarshal(c.Input, v)
}
Expand All @@ -267,9 +294,9 @@ type MessagesResponse struct {

ID string `json:"id"`
Type MessagesResponseType `json:"type"`
Role string `json:"role"`
Role ChatRole `json:"role"`
Content []MessageContent `json:"content"`
Model string `json:"model"`
Model Model `json:"model"`
StopReason MessagesStopReason `json:"stop_reason"`
StopSequence string `json:"stop_sequence"`
Usage MessagesUsage `json:"usage"`
Expand Down
84 changes: 73 additions & 11 deletions message_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,18 +47,79 @@ func TestMessages(t *testing.T) {
anthropic.WithEmptyMessagesLimit(100),
anthropic.WithHTTPClient(http.DefaultClient),
)
resp, err := client.CreateMessages(context.Background(), anthropic.MessagesRequest{
Model: anthropic.ModelClaudeInstant1Dot2,
Messages: []anthropic.Message{
anthropic.NewUserTextMessage("What is your name?"),
},
MaxTokens: 1000,

t.Run("create messages success", func(t *testing.T) {
resp, err := client.CreateMessages(context.Background(), anthropic.MessagesRequest{
Model: anthropic.ModelClaudeInstant1Dot2,
Messages: []anthropic.Message{
anthropic.NewUserTextMessage("What is your name?"),
},
MaxTokens: 1000,
})
if err != nil {
t.Fatalf("CreateMessages error: %v", err)
}

t.Logf("CreateMessages resp: %+v", resp)
})

t.Run("create messages success with single system message", func(t *testing.T) {
resp, err := client.CreateMessages(context.Background(), anthropic.MessagesRequest{
Model: anthropic.ModelClaudeInstant1Dot2,
Messages: []anthropic.Message{
anthropic.NewUserTextMessage("What is your name?"),
},
MaxTokens: 1000,
System: "test system message",
})
if err != nil {
t.Fatalf("CreateMessages error: %v", err)
}

t.Logf("CreateMessages resp: %+v", resp)
})

t.Run("create messages success with single multi-system message", func(t *testing.T) {
resp, err := client.CreateMessages(context.Background(), anthropic.MessagesRequest{
Model: anthropic.ModelClaudeInstant1Dot2,
Messages: []anthropic.Message{
anthropic.NewUserTextMessage("What is your name?"),
},
MaxTokens: 1000,
MultiSystem: anthropic.NewMultiSystemMessages("test single multi-system message"),
})
if err != nil {
t.Fatalf("CreateMessages error: %v", err)
}

t.Logf("CreateMessages resp: %+v", resp)
})

t.Run("create messages success with multi-system messages", func(t *testing.T) {
resp, err := client.CreateMessages(context.Background(), anthropic.MessagesRequest{
Model: anthropic.ModelClaudeInstant1Dot2,
Messages: []anthropic.Message{
anthropic.NewUserTextMessage("What is your name?"),
},
MaxTokens: 1000,
MultiSystem: anthropic.NewMultiSystemMessages(
"test multi-system messages",
"here",
"are",
"some",
"more",
"messages",
"for",
"testing",
),
})
if err != nil {
t.Fatalf("CreateMessages error: %v", err)
}

t.Logf("CreateMessages resp: %+v", resp)
})
if err != nil {
t.Fatalf("CreateMessages error: %v", err)
}

t.Logf("CreateMessages resp: %+v", resp)
}

func TestNewUserTextMessage(t *testing.T) {
Expand Down Expand Up @@ -251,12 +312,13 @@ func TestMessagesVision(t *testing.T) {
{
Role: anthropic.RoleUser,
Content: []anthropic.MessageContent{
anthropic.NewImageMessageContent(anthropic.NewMessageContentImageSource("base64", imageMediaType, imageData)),
anthropic.NewImageMessageContent(anthropic.MessageContentImageSource{
Type: "base64",
MediaType: imageMediaType,
Data: imageData,
}),
anthropic.NewTextMessageContent("Describe this image."),
anthropic.NewTextMessageContent("Describe these images."),
},
},
},
Expand Down

0 comments on commit bc51f1b

Please sign in to comment.