Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ai-cache] Implement a WASM plugin for LLM result retrieval based on vector similarity #1290

Open
wants to merge 27 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
4f7bfbd
fix bugs
johnlanni Jul 31, 2024
0f9e816
fix bugs
Suchun-sv Aug 1, 2024
ff1bce6
fix bugs
Suchun-sv Aug 12, 2024
1e9d42e
init
EnableAsync Aug 15, 2024
f2a9ff6
fix conflict
Suchun-sv Aug 23, 2024
5cbae03
Merge branch 'alibaba:main' into main
Suchun-sv Aug 23, 2024
27b2f71
alter some errors
Suchun-sv Aug 24, 2024
130f2ee
fix: embedding error
EnableAsync Aug 24, 2024
56314d7
fix bugs && update interface design
Suchun-sv Aug 24, 2024
85549d0
fix bugs && refine the variable names
Suchun-sv Aug 25, 2024
8444f5e
update design for cache to support extension
Suchun-sv Aug 25, 2024
a655bc4
Merge branch 'alibaba:main' into main
Suchun-sv Sep 5, 2024
d68fa88
Refined the code; README.md content needs to be updated.
Suchun-sv Sep 5, 2024
5179392
fix bugs, README.md to be updated
Suchun-sv Sep 6, 2024
ece7e2f
fix bugs, refine variable name, update README.md
Suchun-sv Sep 6, 2024
e868a1a
Merge branch 'alibaba:main' into main
Suchun-sv Sep 6, 2024
138a526
delete folder
Suchun-sv Sep 6, 2024
e8ad550
fix typos
Suchun-sv Sep 6, 2024
c83f5c4
fix typos
Suchun-sv Sep 6, 2024
f3d3292
change append to appendMsg
Suchun-sv Sep 6, 2024
b0cf29d
fix bugs and refine code
Suchun-sv Sep 11, 2024
4a18f96
Merge branch 'main' into main
Suchun-sv Sep 11, 2024
21c9a79
fix bugs and update the SetEx function
Suchun-sv Sep 12, 2024
1767896
Merge branch 'main' into main
Suchun-sv Sep 12, 2024
71b9530
Optimize query flow logic (not fully tested)
Suchun-sv Sep 17, 2024
51b9ccc
Fix bugs and verify removal of cache setting
Suchun-sv Sep 21, 2024
3583bc9
fix bugs and update logic as requested
Suchun-sv Sep 21, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion plugins/wasm-go/extensions/ai-cache/.gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# File generated by hgctl. Modify as required.

docker-compose-test/
*

!/.gitignore
Expand Down
97 changes: 77 additions & 20 deletions plugins/wasm-go/extensions/ai-cache/README.md
Original file line number Diff line number Diff line change
@@ -1,40 +1,96 @@
## 简介

**Note**

> 需要数据面的proxy wasm版本大于等于0.2.100
>

> 编译时,需要带上版本的tag,例如:`tinygo build -o main.wasm -scheduler=none -target=wasi -gc=custom -tags="custommalloc nottinygc_finalizer proxy_wasm_version_0_2_100" ./`
>

LLM 结果缓存插件,默认配置方式可以直接用于 openai 协议的结果缓存,同时支持流式和非流式响应的缓存。

## 简介
本插件的逻辑是 1. 通过`文本向量化接口`将请求内容向量化,结果作为 key,原请求作为 value,存入`向量数据库`。2. 同时,将请求内容作为key,LLM响应作为value,存入`缓存数据库`。3. 当有新请求时,通过向量化结果查询最相似的已有请求,若相似度高于设定阈值,则直接返回缓存的响应,否则将新请求和响应存入数据库,以提升处理效率。

> TODO: 是否需要将`文本向量化接口`和`缓存数据库`作为可选项?因为部分向量数据库内置了向量化接口,其次直接使用向量数据库存储响应出错几率可能并不大,且配置项更少。
>

## 配置说明
配置分为 3 个部分:向量数据库(vector);文本向量化接口(embedding);缓存数据库(cache),同时也提供了细粒度的 LLM 请求/响应提取参数配置等。

| Name | Type | Requirement | Default | Description |
| -------- | -------- | -------- | -------- | -------- |
| cacheKeyFrom.requestBody | string | optional | "messages.@reverse.0.content" | 从请求 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 |
| cacheValueFrom.responseBody | string | optional | "choices.0.message.content" | 从响应 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 |
| cacheStreamValueFrom.responseBody | string | optional | "choices.0.delta.content" | 从流式响应 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 |
| cacheKeyPrefix | string | optional | "higress-ai-cache:" | Redis缓存Key的前缀 |
| cacheTTL | integer | optional | 0 | 缓存的过期时间,单位是秒,默认值为0,即永不过期 |
| redis.serviceName | string | requried | - | redis 服务名称,带服务类型的完整 FQDN 名称,例如 my-redis.dns、redis.my-ns.svc.cluster.local |
| redis.servicePort | integer | optional | 6379 | redis 服务端口 |
| redis.timeout | integer | optional | 1000 | 请求 redis 的超时时间,单位为毫秒 |
| redis.username | string | optional | - | 登陆 redis 的用户名 |
| redis.password | string | optional | - | 登陆 redis 的密码 |
| returnResponseTemplate | string | optional | `{"id":"from-cache","choices":[%s],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` | 返回 HTTP 响应的模版,用 %s 标记需要被 cache value 替换的部分 |
| returnStreamResponseTemplate | string | optional | `data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}\n\ndata:[DONE]\n\n` | 返回流式 HTTP 响应的模版,用 %s 标记需要被 cache value 替换的部分 |
## 向量数据库服务(vector)
| Name | Type | Requirement | Default | Description |
| --- | --- | --- | --- | --- |
| vector.type | string | required | "" | 向量存储服务提供者类型,例如 DashVector |
| vector.serviceName | string | required | "" | 向量存储服务名称 |
| vector.serviceDomain | string | required | "" | 向量存储服务域名 |
| vector.servicePort | int64 | optional | 443 | 向量存储服务端口 |
| vector.apiKey | string | optional | "" | 向量存储服务 API Key |
| vector.topK | int | optional | 1 | 返回TopK结果,默认为 1 |
| vector.timeout | uint32 | optional | 10000 | 请求向量存储服务的超时时间,单位为毫秒。默认值是10000,即10秒 |
| vector.collectionID | string | optional | "" | DashVector 向量存储服务 Collection ID |


## 文本向量化服务(embedding)
| Name | Type | Requirement | Default | Description |
| --- | --- | --- | --- | --- |
| embedding.type | string | required | "" | 请求文本向量化服务类型,例如 DashScope |
| embedding.serviceName | string | required | "" | 请求文本向量化服务名称 |
| embedding.serviceDomain | string | required | "" | 请求文本向量化服务域名 |
| embedding.servicePort | int64 | optional | 443 | 请求文本向量化服务端口 |
| embedding.apiKey | string | optional | "" | 请求文本向量化服务的 API Key |
| embedding.timeout | uint32 | optional | 10000 | 请求文本向量化服务的超时时间,单位为毫秒。默认值是10000,即10秒 |
| embedding.model | string | optional | "" | 请求文本向量化服务的模型名称 |


## 缓存服务(cache)
| cache.type | string | required | "" | 缓存服务类型,例如 redis |
| --- | --- | --- | --- | --- |
| cache.serviceName | string | required | "" | 缓存服务名称 |
| cache.serviceDomain | string | required | "" | 缓存服务域名 |
| cache.servicePort | int64 | optional | 6379 | 缓存服务端口 |
| cache.username | string | optional | "" | 缓存服务用户名 |
| cache.password | string | optional | "" | 缓存服务密码 |
| cache.timeout | uint32 | optional | 10000 | 缓存服务的超时时间,单位为毫秒。默认值是10000,即10秒 |
| cache.cacheTTL | int | optional | 0 | 缓存过期时间,单位为秒。默认值是 0,即 永不过期|
| cacheKeyPrefix | string | optional | "higressAiCache:" | 缓存 Key 的前缀,默认值为 "higressAiCache:" |

## 配置示例

## 其他配置
| Name | Type | Requirement | Default | Description |
| --- | --- | --- | --- | --- |
| cacheKeyFrom | string | optional | "messages.@reverse.0.content" | 从请求 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 |
| cacheValueFrom | string | optional | "choices.0.message.content" | 从响应 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 |
| cacheStreamValueFrom | string | optional | "choices.0.delta.content" | 从流式响应 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 |
| cacheToolCallsFrom | string | optional | "choices.0.delta.content.tool_calls" | 从请求 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 |
| responseTemplate | string | optional | `{"id":"ai-cache.hit","choices":[{"index":0,"message":{"role":"assistant","content":%s},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` | 返回 HTTP 响应的模版,用 %s 标记需要被 cache value 替换的部分 |
| streamResponseTemplate | string | optional | `data:{"id":"ai-cache.hit","choices":[{"index":0,"delta":{"role":"assistant","content":%s},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}\n\ndata:[DONE]\n\n` | 返回流式 HTTP 响应的模版,用 %s 标记需要被 cache value 替换的部分 |


## 配置示例
### 基础配置
```yaml
redis:
serviceName: my-redis.dns
timeout: 2000
embedding:
type: dashscope
serviceName: [Your Service Name]
apiKey: [Your Key]

vector:
type: dashvector
serviceName: [Your Service Name]
collectionID: [Your Collection ID]
serviceDomain: [Your domain]
apiKey: [Your key]

cache:
type: redis
serviceName: [Your Service Name]
servicePort: 6379
timeout: 100

```

## 进阶用法

当前默认的缓存 key 是基于 GJSON PATH 的表达式:`messages.@reverse.0.content` 提取,含义是把 messages 数组反转后取第一项的 content;

GJSON PATH 支持条件判断语法,例如希望取最后一个 role 为 user 的 content 作为 key,可以写成: `messages.@reverse.#(role=="user").content`;
Expand All @@ -44,3 +100,4 @@ GJSON PATH 支持条件判断语法,例如希望取最后一个 role 为 user
还可以支持管道语法,例如希望取到数第二个 role 为 user 的 content 作为 key,可以写成:`messages.@reverse.#(role=="user")#.content|1`。

更多用法可以参考[官方文档](https://github.com/tidwall/gjson/blob/master/SYNTAX.md),可以使用 [GJSON Playground](https://gjson.dev/) 进行语法测试。

123 changes: 123 additions & 0 deletions plugins/wasm-go/extensions/ai-cache/cache/provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package cache

import (
"errors"

"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)

const (
PROVIDER_TYPE_REDIS = "redis"
DEFAULT_CACHE_PREFIX = "higressAiCache:"
)

type providerInitializer interface {
ValidateConfig(ProviderConfig) error
CreateProvider(ProviderConfig) (Provider, error)
}

var (
providerInitializers = map[string]providerInitializer{
PROVIDER_TYPE_REDIS: &redisProviderInitializer{},
}
)

type ProviderConfig struct {
// @Title zh-CN redis 缓存服务提供者类型
// @Description zh-CN 缓存服务提供者类型,例如 redis
typ string
// @Title zh-CN redis 缓存服务名称
// @Description zh-CN 缓存服务名称
serviceName string
// @Title zh-CN redis 缓存服务端口
// @Description zh-CN 缓存服务端口,默认值为6379
servicePort int
// @Title zh-CN redis 缓存服务地址
// @Description zh-CN Cache 缓存服务地址,非必填
serviceHost string
// @Title zh-CN 缓存服务用户名
// @Description zh-CN 缓存服务用户名,非必填
username string
// @Title zh-CN 缓存服务密码
// @Description zh-CN 缓存服务密码,非必填
password string
// @Title zh-CN 请求超时
// @Description zh-CN 请求缓存服务的超时时间,单位为毫秒。默认值是10000,即10秒
timeout uint32
// @Title zh-CN 缓存过期时间
// @Description zh-CN 缓存过期时间,单位为秒。默认值是0,即永不过期
cacheTTL int
// @Title 缓存 Key 前缀
// @Description 缓存 Key 的前缀,默认值为 "higressAiCache:"
cacheKeyPrefix string
}

func (c *ProviderConfig) FromJson(json gjson.Result) {
c.typ = json.Get("type").String()
c.serviceName = json.Get("serviceName").String()
c.servicePort = int(json.Get("servicePort").Int())
if !json.Get("servicePort").Exists() {
c.servicePort = 6379
}
c.serviceHost = json.Get("serviceHost").String()
c.username = json.Get("username").String()
if !json.Get("username").Exists() {
c.username = ""
}
c.password = json.Get("password").String()
if !json.Get("password").Exists() {
c.password = ""
}
c.timeout = uint32(json.Get("timeout").Int())
if !json.Get("timeout").Exists() {
c.timeout = 10000
}
c.cacheTTL = int(json.Get("cacheTTL").Int())
if !json.Get("cacheTTL").Exists() {
c.cacheTTL = 0
// c.cacheTTL = 3600000
}
if json.Get("cacheKeyPrefix").Exists() {
c.cacheKeyPrefix = json.Get("cacheKeyPrefix").String()
} else {
c.cacheKeyPrefix = DEFAULT_CACHE_PREFIX
}

}

func (c *ProviderConfig) Validate() error {
if c.typ == "" {
return errors.New("cache service type is required")
}
if c.serviceName == "" {
return errors.New("cache service name is required")
}
if c.cacheTTL < 0 {
return errors.New("cache TTL must be greater than or equal to 0")
}
initializer, has := providerInitializers[c.typ]
if !has {
return errors.New("unknown cache service provider type: " + c.typ)
}
if err := initializer.ValidateConfig(*c); err != nil {
return err
}
return nil
}

func CreateProvider(pc ProviderConfig) (Provider, error) {
initializer, has := providerInitializers[pc.typ]
if !has {
return nil, errors.New("unknown provider type: " + pc.typ)
}
return initializer.CreateProvider(pc)
}

type Provider interface {
GetProviderType() string
Init(username string, password string, timeout uint32) error
Get(key string, cb wrapper.RedisResponseCallback) error
Set(key string, value string, cb wrapper.RedisResponseCallback) error
GetCacheKeyPrefix() string
}
58 changes: 58 additions & 0 deletions plugins/wasm-go/extensions/ai-cache/cache/redis.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package cache

import (
"errors"

"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
)

type redisProviderInitializer struct {
}

func (r *redisProviderInitializer) ValidateConfig(cf ProviderConfig) error {
if len(cf.serviceName) == 0 {
return errors.New("cache service name is required")
}
return nil
}

func (r *redisProviderInitializer) CreateProvider(cf ProviderConfig) (Provider, error) {
rp := redisProvider{
config: cf,
client: wrapper.NewRedisClusterClient(wrapper.FQDNCluster{
FQDN: cf.serviceName,
Host: cf.serviceHost,
Port: int64(cf.servicePort)}),
}
err := rp.Init(cf.username, cf.password, cf.timeout)
return &rp, err
}

type redisProvider struct {
config ProviderConfig
client wrapper.RedisClient
}

func (rp *redisProvider) GetProviderType() string {
return PROVIDER_TYPE_REDIS
}

func (rp *redisProvider) Init(username string, password string, timeout uint32) error {
return rp.client.Init(rp.config.username, rp.config.password, int64(rp.config.timeout))
}

func (rp *redisProvider) Get(key string, cb wrapper.RedisResponseCallback) error {
return rp.client.Get(key, cb)
}

func (rp *redisProvider) Set(key string, value string, cb wrapper.RedisResponseCallback) error {
if rp.config.cacheTTL == 0 {
return rp.client.Set(key, value, cb)
} else {
return rp.client.SetEx(key, value, rp.config.cacheTTL, cb)
}
}

func (rp *redisProvider) GetCacheKeyPrefix() string {
CH3CHO marked this conversation as resolved.
Show resolved Hide resolved
return rp.config.cacheKeyPrefix
}
Loading