Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chat command with gpt-wrapper integration #545

Merged
merged 19 commits into from
Jun 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ jobs:
echo "GR_ARGS=${args}" >> $GITHUB_ENV
- name: Echo GoReleaser Args
run: echo ${{ env.GR_ARGS }}
- name: Config Git credentials
run: git config --global url."https://${{ secrets.PERSONAL_ACCESS_TOKEN }}@github.com".insteadOf "https://github.com"
- name: Run GoReleaser
uses: goreleaser/goreleaser-action@v3
with:
Expand Down
2 changes: 2 additions & 0 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ func main() {
tenantConfigurationWrapper := wrappers.NewHTTPTenantConfigurationWrapper(tenantConfigurationPath)
jwtWrapper := wrappers.NewJwtWrapper()
scaRealTimeWrapper := wrappers.NewHTTPScaRealTimeWrapper()
chatWrapper := wrappers.NewChatWrapper()

astCli := commands.NewAstCLI(
scansWrapper,
Expand All @@ -92,6 +93,7 @@ func main() {
tenantConfigurationWrapper,
jwtWrapper,
scaRealTimeWrapper,
chatWrapper,
)
exitListener()
err = astCli.Execute()
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ require (
)

require (
github.com/checkmarxDev/gpt-wrapper v0.0.0-20230620151243-0a3131178ae4 // indirect
github.com/fsnotify/fsnotify v1.6.0 // indirect
github.com/google/go-cmp v0.5.9 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
Expand Down
8 changes: 8 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,14 @@ github.com/census-instrumentation/opencensus-proto v0.4.1/go.mod h1:4T9NM4+4Vw91
github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc=
github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/checkmarxDev/gpt-wrapper v0.0.0-20230613122817-da0c37ce1aa6 h1:EFnj6OniqFxKmeWZmtPesK/K5NhpP0AOLPnfK0R9tWw=
github.com/checkmarxDev/gpt-wrapper v0.0.0-20230613122817-da0c37ce1aa6/go.mod h1:l+0rISRGaps2HWkpvKbYPE1nsNx28vBj6bKorEm1M5o=
github.com/checkmarxDev/gpt-wrapper v0.0.0-20230614144539-0f97d561dc9f h1:KW8EJM77EoICG54vLWx9CN7s0hXHJjtFODsO3xFEFGI=
github.com/checkmarxDev/gpt-wrapper v0.0.0-20230614144539-0f97d561dc9f/go.mod h1:l+0rISRGaps2HWkpvKbYPE1nsNx28vBj6bKorEm1M5o=
github.com/checkmarxDev/gpt-wrapper v0.0.0-20230616130821-990a46a50d40 h1:gJdxyugcYjcQLWo4PKzZUOJ5/Dy7tw8lcztabkpiObQ=
github.com/checkmarxDev/gpt-wrapper v0.0.0-20230616130821-990a46a50d40/go.mod h1:l+0rISRGaps2HWkpvKbYPE1nsNx28vBj6bKorEm1M5o=
github.com/checkmarxDev/gpt-wrapper v0.0.0-20230620151243-0a3131178ae4 h1:e4EGi1vqMlJ52qo6NsUkrKhUsprS6fYH4mNiE13omxU=
github.com/checkmarxDev/gpt-wrapper v0.0.0-20230620151243-0a3131178ae4/go.mod h1:l+0rISRGaps2HWkpvKbYPE1nsNx28vBj6bKorEm1M5o=
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
Expand Down
150 changes: 150 additions & 0 deletions internal/commands/chat.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
package commands

import (
"fmt"
"os"

"github.com/checkmarx/ast-cli/internal/commands/util/printer"
"github.com/checkmarx/ast-cli/internal/logger"
"github.com/checkmarx/ast-cli/internal/params"
"github.com/checkmarx/ast-cli/internal/wrappers"
"github.com/checkmarxDev/gpt-wrapper/pkg/connector"
"github.com/checkmarxDev/gpt-wrapper/pkg/message"
"github.com/checkmarxDev/gpt-wrapper/pkg/role"
"github.com/checkmarxDev/gpt-wrapper/pkg/wrapper"
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/spf13/cobra"
)

const systemInput = `You are the Checkmarx AI Guided Remediation bot who can answer technical questions related to the results of Infrastructure as Code Security.
You should be able to analyze and understand both the technical aspects of the security results and the common queries users may have about the results.
You should also be capable of delivering clear, concise, and informative answers to help take appropriate action based on the findings.
If a question irrelevant to the mentioned Infrastructure as Code Security source or result is asked,
answer 'I am the AI Guided Remediation assistant and can answer only on questions related to the selected result'.`

const assistantInputFormat = `Checkmarx Infrastructure as Code Security has scanned this source code and reported the result.
This is the source code:
` + "```" + `
%s
` + "```" + `
and this is the result (vulnerability or security issue) found by Infrastructure as Code Security:
'%s' is detected in line %s with severity '%s'.`

const userInputFormat = `The user question is:
'<|IAC_QUESTION_START|>'
"%s"
'<|IAC_QUESTION_END|>'`

// dropLen number of messages to drop when limit is reached, 4 due to 2 from prompt, 1 from user question, 1 from reply
const dropLen = 4

const ConversationIDErrorFormat = "Invalid conversation ID %s."
const FileErrorFormat = "It seems that %s is not available for AI Guided Remediation. Please ensure that you have opened the correct workspace or the relevant file."

type OutputModel struct {
ConversationID string `json:"conversationId"`
Response []string `json:"response"`
}

func NewChatCommand(chatWrapper wrappers.ChatWrapper) *cobra.Command {
chatCmd := &cobra.Command{
Use: "chat",
Short: "Interact with OpenAI models",
Long: "Interact with OpenAI models",
RunE: runChat(chatWrapper),
}

chatCmd.Flags().String(params.ChatAPIKey, "", "OpenAI API key")
chatCmd.Flags().String(params.ChatConversationID, "", "ID of existing conversation")
chatCmd.Flags().String(params.ChatUserInput, "", "User question")
chatCmd.Flags().String(params.ChatModel, "", "OpenAI model version")
chatCmd.Flags().String(params.ChatResultFile, "", "IaC result code file")
chatCmd.Flags().String(params.ChatResultLine, "", "IaC result line")
chatCmd.Flags().String(params.ChatResultSeverity, "", "IaC result severity")
chatCmd.Flags().String(params.ChatResultVulnerability, "", "IaC result vulnerability name")

_ = chatCmd.MarkFlagRequired(params.ChatUserInput)
_ = chatCmd.MarkFlagRequired(params.ChatAPIKey)
_ = chatCmd.MarkFlagRequired(params.ChatResultFile)
_ = chatCmd.MarkFlagRequired(params.ChatResultLine)
_ = chatCmd.MarkFlagRequired(params.ChatResultSeverity)
_ = chatCmd.MarkFlagRequired(params.ChatResultVulnerability)

return chatCmd
}

func runChat(chatWrapper wrappers.ChatWrapper) func(cmd *cobra.Command, args []string) error {
return func(cmd *cobra.Command, args []string) error {
chatAPIKey, _ := cmd.Flags().GetString(params.ChatAPIKey)
chatConversationID, _ := cmd.Flags().GetString(params.ChatConversationID)
chatModel, _ := cmd.Flags().GetString(params.ChatModel)
chatResultFile, _ := cmd.Flags().GetString(params.ChatResultFile)
chatResultLine, _ := cmd.Flags().GetString(params.ChatResultLine)
chatResultSeverity, _ := cmd.Flags().GetString(params.ChatResultSeverity)
chatResultVulnerability, _ := cmd.Flags().GetString(params.ChatResultVulnerability)
userInput, _ := cmd.Flags().GetString(params.ChatUserInput)

statefulWrapper := wrapper.NewStatefulWrapper(connector.NewFileSystemConnector(""), chatAPIKey, chatModel, dropLen)

if chatConversationID == "" {
chatConversationID = statefulWrapper.GenerateId().String()
}

id, err := uuid.Parse(chatConversationID)
if err != nil {
logger.PrintIfVerbose(err.Error())
return outputError(cmd, id, errors.Errorf(ConversationIDErrorFormat, chatConversationID))
}

chatResultCode, err := os.ReadFile(chatResultFile)
if err != nil {
logger.PrintIfVerbose(err.Error())
return outputError(cmd, id, errors.Errorf(FileErrorFormat, chatResultFile))
}

newMessages := buildMessages(chatResultCode, chatResultVulnerability, chatResultLine, chatResultSeverity, userInput)
response, err := chatWrapper.Call(statefulWrapper, id, newMessages)
if err != nil {
return outputError(cmd, id, err)
}

responseContent := getMessageContents(response)

return printer.Print(cmd.OutOrStdout(), &OutputModel{
ConversationID: id.String(),
Response: responseContent,
}, printer.FormatJSON)
}
}

func getMessageContents(response []message.Message) []string {
var responseContent []string
for _, r := range response {
responseContent = append(responseContent, r.Content)
}
return responseContent
}

func buildMessages(chatResultCode []byte,
chatResultVulnerability, chatResultLine, chatResultSeverity, userInput string) []message.Message {
var newMessages []message.Message
newMessages = append(newMessages, message.Message{
Role: role.System,
Content: systemInput,
}, message.Message{
Role: role.Assistant,
Content: fmt.Sprintf(assistantInputFormat, string(chatResultCode), chatResultVulnerability, chatResultLine, chatResultSeverity),
}, message.Message{
Role: role.User,
Content: fmt.Sprintf(userInputFormat, userInput),
})
return newMessages
}

func outputError(cmd *cobra.Command, id uuid.UUID, err error) error {
return printer.Print(cmd.OutOrStdout(), &OutputModel{
ConversationID: id.String(),
Response: []string{err.Error()},
}, printer.FormatJSON)
}
63 changes: 63 additions & 0 deletions internal/commands/chat_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package commands

import (
"fmt"
"io"
"strings"
"testing"

"github.com/google/uuid"
"gotest.tools/assert"
)

func TestChatHelp(t *testing.T) {
execCmdNilAssertion(t, "help", "chat")
}

func TestChatInvalidId(t *testing.T) {
buffer, err := executeRedirectedTestCommand("chat",
"--conversation-id", "invalidId",
"--chat-apikey", "apiKey",
"--user-input", "userInput",
"--result-file", "file",
"--result-line", "0",
"--result-severity", "LOW",
"--result-vulnerability", "Vulnerability")
assert.NilError(t, err)
output, err := io.ReadAll(buffer)
assert.NilError(t, err)
s := string(output)
assert.Assert(t, strings.Contains(s, fmt.Sprintf(ConversationIDErrorFormat, "invalidId")), s)
}

func TestChatInvalidFile(t *testing.T) {
buffer, err := executeRedirectedTestCommand("chat",
"--conversation-id", uuid.New().String(),
"--chat-apikey", "apiKey",
"--user-input", "userInput",
"--result-file", "invalidfile",
"--result-line", "0",
"--result-severity", "LOW",
"--result-vulnerability", "Vulnerability")
assert.NilError(t, err)
output, err := io.ReadAll(buffer)
assert.NilError(t, err)
s := string(output)
assert.Assert(t, strings.Contains(s, fmt.Sprintf(FileErrorFormat, "invalidfile")), s)
}

func TestChatCorrectResponse(t *testing.T) {
buffer, err := executeRedirectedTestCommand("chat",
"--conversation-id", uuid.New().String(),
"--chat-apikey", "apiKey",
"--user-input", "userInput",
"--result-file", "./data/Dockerfile",
"--result-line", "0",
"--result-severity", "LOW",
"--result-vulnerability", "Vulnerability")
assert.NilError(t, err)
output, err := io.ReadAll(buffer)
assert.NilError(t, err)
s := strings.ToLower(string(output))
assert.Assert(t, strings.Contains(s, "mock"), s)
}
4 changes: 4 additions & 0 deletions internal/commands/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ func NewAstCLI(
tenantWrapper wrappers.TenantConfigurationWrapper,
jwtWrapper wrappers.JWTWrapper,
scaRealTimeWrapper wrappers.ScaRealTimeWrapper,
chatWrapper wrappers.ChatWrapper,
) *cobra.Command {
// Create the root
rootCmd := &cobra.Command{
Expand Down Expand Up @@ -165,6 +166,8 @@ func NewAstCLI(
configCmd := util.NewConfigCommand()
triageCmd := NewResultsPredicatesCommand(resultsPredicatesWrapper)

chatCmd := NewChatCommand(chatWrapper)

rootCmd.AddCommand(
scanCmd,
projectCmd,
Expand All @@ -174,6 +177,7 @@ func NewAstCLI(
authCmd,
utilsCmd,
configCmd,
chatCmd,
)

rootCmd.SilenceErrors = true
Expand Down
12 changes: 12 additions & 0 deletions internal/commands/root_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package commands

import (
"bytes"
"fmt"
"log"
"os"
Expand Down Expand Up @@ -52,6 +53,7 @@ func createASTTestCommand() *cobra.Command {
tenantConfigurationMockWrapper := &mock.TenantConfigurationMockWrapper{}
jwtWrapper := &mock.JWTMockWrapper{}
scaRealtimeMockWrapper := &mock.ScaRealTimeHTTPMockWrapper{}
chatWrapper := &mock.ChatMockWrapper{}

return NewAstCLI(
scansMockWrapper,
Expand All @@ -77,6 +79,7 @@ func createASTTestCommand() *cobra.Command {
tenantConfigurationMockWrapper,
jwtWrapper,
scaRealtimeMockWrapper,
chatWrapper,
)
}

Expand All @@ -100,6 +103,15 @@ func executeTestCommand(cmd *cobra.Command, args ...string) error {
return cmd.Execute()
}

func executeRedirectedTestCommand(args ...string) (*bytes.Buffer, error) {
buffer := bytes.NewBufferString("")
cmd := createASTTestCommand()
cmd.SetArgs(args)
cmd.SilenceUsage = true
cmd.SetOut(buffer)
return buffer, cmd.Execute()
}

func execCmdNilAssertion(t *testing.T, args ...string) {
err := executeTestCommand(createASTTestCommand(), args...)
assert.NilError(t, err)
Expand Down
10 changes: 10 additions & 0 deletions internal/params/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,16 @@ const (
RepoNameFlagUsage = "Github repository details"
PRNumberFlag = "pr-number"
PRNumberFlagUsage = "Pull Request number for posting notifications and comments"

// Chat
ChatAPIKey = "chat-apikey"
ChatConversationID = "conversation-id"
ChatUserInput = "user-input"
ChatModel = "model"
ChatResultFile = "result-file"
ChatResultLine = "result-line"
ChatResultSeverity = "result-severity"
ChatResultVulnerability = "result-vulnerability"
)

// Parameter values
Expand Down
18 changes: 18 additions & 0 deletions internal/wrappers/chat-http.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package wrappers

import (
gptWrapperMessage "github.com/checkmarxDev/gpt-wrapper/pkg/message"
gptWrapper "github.com/checkmarxDev/gpt-wrapper/pkg/wrapper"
"github.com/google/uuid"
)

type ChatHTTPWrapper struct {
}

func (c ChatHTTPWrapper) Call(w gptWrapper.StatefulWrapper, id uuid.UUID, messages []gptWrapperMessage.Message) ([]gptWrapperMessage.Message, error) {
return w.Call(id, messages)
}

func NewChatWrapper() ChatWrapper {
return ChatHTTPWrapper{}
}
11 changes: 11 additions & 0 deletions internal/wrappers/chat.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package wrappers

import (
gptWrapperMessage "github.com/checkmarxDev/gpt-wrapper/pkg/message"
gptWrapper "github.com/checkmarxDev/gpt-wrapper/pkg/wrapper"
"github.com/google/uuid"
)

type ChatWrapper interface {
Call(gptWrapper.StatefulWrapper, uuid.UUID, []gptWrapperMessage.Message) ([]gptWrapperMessage.Message, error)
}
18 changes: 18 additions & 0 deletions internal/wrappers/mock/chat-mock.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package mock

import (
gptWrapperMessage "github.com/checkmarxDev/gpt-wrapper/pkg/message"
gptWrapperRole "github.com/checkmarxDev/gpt-wrapper/pkg/role"
gptWrapper "github.com/checkmarxDev/gpt-wrapper/pkg/wrapper"
"github.com/google/uuid"
)

type ChatMockWrapper struct {
}

func (c ChatMockWrapper) Call(_ gptWrapper.StatefulWrapper, _ uuid.UUID, _ []gptWrapperMessage.Message) ([]gptWrapperMessage.Message, error) {
return []gptWrapperMessage.Message{{
Role: gptWrapperRole.Assistant,
Content: "Mock message",
}}, nil
}
Loading
Loading