Add impersonate feature to API /v1/chat/completions #6342

Open · wants to merge 4 commits into base: dev
29 changes: 23 additions & 6 deletions extensions/openai/completions.py
@@ -16,6 +16,7 @@
 from extensions.openai.utils import debug_msg
 from modules import shared
 from modules.chat import (
+    get_stopping_strings,
     generate_chat_prompt,
     generate_chat_reply,
     load_character_memoized,
@@ -242,6 +243,9 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
     # generation parameters
     generate_params = process_parameters(body, is_legacy=is_legacy)
     continue_ = body['continue_']
+    impersonate = body['impersonate']
+    if impersonate:
+        continue_ = False  # When impersonating, continue_ must be False; see impersonate_wrapper in chat.py
 
     # Instruction template
     if body['instruction_template_str']:
@@ -294,6 +298,7 @@ def chat_streaming_chunk(content):
 
     def chat_streaming_chunk(content):
         # begin streaming
+        role = 'user' if impersonate else 'assistant'
         chunk = {
             "id": cmpl_id,
             "object": object_type,
@@ -302,7 +307,7 @@ def chat_streaming_chunk(content):
             resp_list: [{
                 "index": 0,
                 "finish_reason": None,
-                "delta": {'role': 'assistant', 'content': content},
+                "delta": {'role': role, 'content': content},
             }],
         }
 
@@ -314,7 +319,9 @@ def chat_streaming_chunk(content):
         return chunk
 
     # generate reply #######################################
-    prompt = generate_chat_prompt(user_input, generate_params, _continue=continue_)
+    prompt = generate_chat_prompt(user_input, generate_params, _continue=continue_, impersonate=impersonate)
+    if impersonate:
+        prompt += user_input
     if prompt_only:
         yield {'prompt': prompt}
         return
@@ -324,14 +331,23 @@ def chat_streaming_chunk(content):
     if stream:
         yield chat_streaming_chunk('')
 
-    generator = generate_chat_reply(
-        user_input, generate_params, regenerate=False, _continue=continue_, loading_message=False)
+    if impersonate:
+        stopping_strings = get_stopping_strings(generate_params)
+        generator = generate_reply(prompt, generate_params, stopping_strings=stopping_strings, is_chat=True)
+    else:
+        generator = generate_chat_reply(
+            user_input, generate_params, regenerate=False, _continue=continue_, loading_message=False)
 
     answer = ''
     seen_content = ''
 
     for a in generator:
-        answer = a['internal'][-1][1]
+        if impersonate:
+            # generate_chat_reply returns the entire message so far, while generate_reply yields only the new text,
+            # so prepend user_input to keep the output consistent.
+            answer = user_input + a
+        else:
+            answer = a['internal'][-1][1]
         if stream:
             len_seen = len(seen_content)
             new_content = answer[len_seen:]
@@ -360,6 +376,7 @@ def chat_streaming_chunk(content):
 
             yield chunk
     else:
+        role = 'user' if impersonate else 'assistant'
         resp = {
             "id": cmpl_id,
             "object": object_type,
@@ -368,7 +385,7 @@ def chat_streaming_chunk(content):
             resp_list: [{
                 "index": 0,
                 "finish_reason": stop_reason,
-                "message": {"role": "assistant", "content": answer}
+                "message": {"role": role, "content": answer}
             }],
             "usage": {
                 "prompt_tokens": token_count,
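Taken together, the completions.py changes route an impersonate request through generate_reply on the raw prompt (with the last user message appended) instead of generate_chat_reply, and attribute the result to the user role. A minimal client sketch of the non-streaming path, assuming a local server at http://127.0.0.1:5000 (the address and message contents are illustrative; only the impersonate flag comes from this PR):

```python
import requests

URL = "http://127.0.0.1:5000/v1/chat/completions"  # assumed default local address

payload = {
    "messages": [
        {"role": "assistant", "content": "How can I help you today?"},
        {"role": "user", "content": "I was wondering"},  # partial user message to continue
    ],
    "impersonate": True,  # new flag from this PR; continue_ is forced to False server-side
}

message = requests.post(URL, json=payload).json()["choices"][0]["message"]

# With impersonate=True the completion is attributed to the user, not the assistant,
# and the content starts with the original user text followed by the continuation.
assert message["role"] == "user"
print(message["content"])
```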
2 changes: 2 additions & 0 deletions extensions/openai/typing.py
@@ -114,6 +114,8 @@ class ChatCompletionRequestParams(BaseModel):
 
     continue_: bool = Field(default=False, description="Makes the last bot message in the history be continued instead of starting a new message.")
 
+    impersonate: bool = Field(default=False, description="Impersonate the user in the chat. Makes the model continue generating the last user message.")
+
 
 class ChatCompletionRequest(GenerationOptions, ChatCompletionRequestParams):
     pass
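Because the new field defaults to False, existing clients are unaffected; only requests that explicitly send impersonate take the new code path. When streaming, the role switch appears in every delta chunk. A rough sketch of consuming the stream (same assumed local address as above; the SSE parsing is simplified and the [DONE] sentinel is assumed to follow the usual OpenAI-compatible convention):

```python
import json
import requests

URL = "http://127.0.0.1:5000/v1/chat/completions"  # assumed default local address

body = {
    "messages": [{"role": "user", "content": "Let me explain my situation:"}],
    "impersonate": True,
    "stream": True,
}

with requests.post(URL, json=body, stream=True) as response:
    for line in response.iter_lines():
        if not line.startswith(b"data: "):
            continue  # skip keep-alives and blank lines
        data = line[len(b"data: "):]
        if data.strip() == b"[DONE]":
            break
        delta = json.loads(data)["choices"][0]["delta"]
        # Each streamed delta carries role 'user' when impersonating.
        print(delta.get("content", ""), end="", flush=True)
```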