Skip to content

Commit

Permalink
feat: set sampling defaults for all request types
Browse files (browse the repository at this point in the history)
  • Loading branch information
sjmonson committed Sep 30, 2024
1 parent 53d5370 commit ce4c6ec
Showing 1 changed file with 20 additions and 13 deletions.
33 changes: 20 additions & 13 deletions plugins/openai_plugin.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -70,6 +70,12 @@ def _parse_args(self, args):
if self.api not in APIS:
logger.error("Invalid api type: %s", self.api)

# TODO Make this configurable
self.request_defaults = dict(
temperature = 0.0,
seed = 42,
)

def _process_resp(self, resp: bytes) -> Optional[dict]:
try:
_, found, data = resp.partition(b"data: ")
Expand All @@ -94,24 +100,23 @@ def request_http(self, query: dict, user_id: int, test_end_time: float = 0):
headers = {"Content-Type": "application/json"}

if self.api == 'chat':
data = {
request = {
"messages": [
{"role": "user", "content": query["text"]}
],
"max_tokens": query["output_tokens"],
"temperature": 0.1,
}
else: # self.api == 'legacy'
data = {
request = {
"prompt": query["text"],
"max_tokens": query["output_tokens"],
"min_tokens": query["output_tokens"],
"temperature": 1.0, # FIXME: This seems wrong
"top_p": 0.9, # FIXME: Standardize on something
"seed": 10, # FIXME: Standardize on something
}
if self.model_name is not None:
data["model"] = self.model_name
request["model"] = self.model_name

# Merge request and defaults
data = self.request_defaults | request

response = None
try:
Expand Down Expand Up @@ -173,26 +178,28 @@ def request_http(self, query: dict, user_id: int, test_end_time: float = 0):
def streaming_request_http(self, query: dict, user_id: int, test_end_time: float):
headers = {"Content-Type": "application/json"}

data = {
request = {
"max_tokens": query["output_tokens"],
"temperature": 0.1,
"stream": True,
"stream_options": {
"include_usage": True
}
}

if self.api == 'chat':
data["messages"] = [
request["messages"] = [
{ "role": "user", "content": query["text"] }
]
else: # self.api == 'legacy'
data["prompt"] = query["text"],
data["min_tokens"] = query["output_tokens"]
request["prompt"] = query["text"],
request["min_tokens"] = query["output_tokens"]

# some runtimes only serve one model, won't check this.
if self.model_name is not None:
data["model"] = self.model_name
request["model"] = self.model_name

# Merge request and defaults
data = self.request_defaults | request

result = RequestResult(user_id, query.get("input_id"))

Expand Down

0 comments on commit ce4c6ec

Please sign in to comment.