Skip to content

Commit

Permalink
Merge pull request #733 from thedadams/allow-provider-restart
Browse files Browse the repository at this point in the history
feat: allow providers to be restarted if they stop
  • Loading branch information
thedadams authored Aug 7, 2024
2 parents 4fd8e8a + c0507a2 commit fb32c4c
Showing 1 changed file with 18 additions and 34 deletions.
52 changes: 18 additions & 34 deletions pkg/remote/remote.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,9 @@ import (
)

type Client struct {
clientsLock sync.Mutex
modelsLock sync.Mutex
cache *cache.Client
clients map[string]*openai.Client
models map[string]*openai.Client
modelToProvider map[string]string
runner *runner.Runner
envs []string
credStore credentials.CredentialStore
Expand All @@ -43,14 +42,19 @@ func New(r *runner.Runner, envs []string, cache *cache.Client, credStore credent
}

func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
c.clientsLock.Lock()
client, ok := c.models[messageRequest.Model]
c.clientsLock.Unlock()
c.modelsLock.Lock()
provider, ok := c.modelToProvider[messageRequest.Model]
c.modelsLock.Unlock()

if !ok {
return nil, fmt.Errorf("failed to find remote model %s", messageRequest.Model)
}

client, err := c.load(ctx, provider)
if err != nil {
return nil, err
}

toolName, modelName := types.SplitToolRef(messageRequest.Model)
if modelName == "" {
// modelName is empty, then the messageRequest.Model is not of the form 'modelName from provider'
Expand Down Expand Up @@ -96,19 +100,19 @@ func (c *Client) Supports(ctx context.Context, modelString string) (bool, error)
return false, nil
}

client, err := c.load(ctx, providerName)
_, err := c.load(ctx, providerName)
if err != nil {
return false, err
}

c.clientsLock.Lock()
defer c.clientsLock.Unlock()
c.modelsLock.Lock()
defer c.modelsLock.Unlock()

if c.models == nil {
c.models = map[string]*openai.Client{}
if c.modelToProvider == nil {
c.modelToProvider = map[string]string{}
}

c.models[modelString] = client
c.modelToProvider[modelString] = providerName
return true, nil
}

Expand Down Expand Up @@ -141,24 +145,11 @@ func (c *Client) clientFromURL(ctx context.Context, apiURL string) (*openai.Clie
}

func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, error) {
c.clientsLock.Lock()
defer c.clientsLock.Unlock()

client, ok := c.clients[toolName]
if ok {
return client, nil
}

if c.clients == nil {
c.clients = make(map[string]*openai.Client)
}

if isHTTPURL(toolName) {
remoteClient, err := c.clientFromURL(ctx, toolName)
if err != nil {
return nil, err
}
c.clients[toolName] = remoteClient
return remoteClient, nil
}

Expand All @@ -174,22 +165,15 @@ func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, err
return nil, err
}

if strings.HasSuffix(url, "/") {
url += "v1"
} else {
url += "/v1"
}

client, err = openai.NewClient(ctx, c.credStore, openai.Options{
BaseURL: url,
client, err := openai.NewClient(ctx, c.credStore, openai.Options{
BaseURL: strings.TrimSuffix(url, "/") + "/v1",
Cache: c.cache,
CacheKey: prg.EntryToolID,
})
if err != nil {
return nil, err
}

c.clients[toolName] = client
return client, nil
}

Expand Down

0 comments on commit fb32c4c

Please sign in to comment.