From 6c6966cfa24e9c120494ebb317c6126aa2ae94af Mon Sep 17 00:00:00 2001 From: shannonsands Date: Sat, 20 Jan 2024 20:05:15 +1000 Subject: [PATCH] Added subclass of OpenAI class to support custom endpoints (#560) Adds a new class to support for providing a custom base_url to the OpenAI client, and a couple of small updates to remove the checks on the model via the "retrieve" api (since it's not always available), sets tiktoken to use gpt4 by default (with the option to override it), and enables passing temperature as a named parameter to the __call__method to override the config. Usage is basically the same as the regular OpenAI model wrapper, with the addition of the base_url parameter. Eg: ``` hermes = CustomOpenAI( model_name="NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO", system_prompt=system_prompt, api_key="xxx", base_url='https://api.together.xyz', ) print(hermes(test_input)) ``` The environment variables checked to set the API key was also updated, so as not to conflict with any OpenAI configurations in place, and sound a little more generic, along with a new env var `INFERENCE_BASE_URL` to act as an alternative to the constructor parameter. --- outlines/models/openai_compatible.py | 152 +++++++++++++++++++++++++++ 1 file changed, 152 insertions(+) create mode 100644 outlines/models/openai_compatible.py diff --git a/outlines/models/openai_compatible.py b/outlines/models/openai_compatible.py new file mode 100644 index 000000000..92a4ef960 --- /dev/null +++ b/outlines/models/openai_compatible.py @@ -0,0 +1,152 @@ +"""Integration with custom OpenAI compatible APIs.""" +import functools +import os +from dataclasses import replace +from typing import List, Optional, Union + +import numpy as np + +from outlines.models.openai import OpenAI, OpenAIConfig, generate_chat + +__all__ = ["OpenAICompatibleAPI", "openai_compatible_api"] + + +class OpenAICompatibleAPI(OpenAI): + """An object that represents an OpenAI-compatible API.""" + + def __init__( + self, + model_name: str, + api_key: Optional[str] = None, + base_url: Optional[str] = None, + max_retries: int = 6, + timeout: Optional[float] = None, + system_prompt: Optional[str] = None, + config: Optional[OpenAIConfig] = None, + encoding="gpt4", # Default for tiktoken, should USUALLY work + ): + """Create an `OpenAI` instance. + + Parameters + ---------- + model_name + Model to use, as defined in OpenAI's documentation + api_key + Secret key to use with the OpenAI compatible API. One can also set the + `INFERENCE_API_KEY` environment variable, or the value of + `openai.api_key`. + base_url + Base URL to use for the API calls. Required if a Custom OpenAI endpoint is used. + Can also be set with the `INFERENCE_BASE_URL` environment variable. + max_retries + The maximum number of retries when calls to the API fail. + timeout + Duration after which the request times out. + system_prompt + The content of the system message that precedes the user's prompt. + config + An instance of `OpenAIConfig`. Can be useful to specify some + parameters that cannot be set by calling this class' methods. + + """ + + try: + import openai + except ImportError: + raise ImportError( + "The `openai` library needs to be installed in order to use Outlines' OpenAI integration." + ) + + if api_key is None: + if os.getenv("INFERENCE_API_KEY") is not None: + api_key = os.getenv("INFERENCE_API_KEY") + elif openai.api_key is not None: + api_key = openai.api_key + else: + raise ValueError( + "You must specify an API key to use the Custom OpenAI API integration." + ) + + if base_url is None: + if os.getenv("INFERENCE_BASE_URL") is not None: + base_url = os.getenv("INFERENCE_BASE_URL") + else: + raise ValueError( + "You must specify a base URL to use the Custom OpenAI API integration." + ) + + if config is not None: + self.config = replace(config, model=model_name) # type: ignore + else: + self.config = OpenAIConfig(model=model_name) + + # This is necesssary because of an issue with the OpenAI API. + # Status updates: https://github.com/openai/openai-python/issues/769 + self.create_client = functools.partial( + openai.AsyncOpenAI, + api_key=api_key, + base_url=base_url, + max_retries=max_retries, + timeout=timeout, + ) + + self.system_prompt = system_prompt + + # We count the total number of prompt and generated tokens as returned + # by the OpenAI API, summed over all the requests performed with this + # model instance. + self.prompt_tokens = 0 + self.completion_tokens = 0 + self.encoding = encoding + + def __call__( + self, + prompt: Union[str, List[str]], + max_tokens: Optional[int] = None, + *, + temperature: float = 1.0, + samples: int = 1, + stop_at: Optional[Union[List[str], str]] = None, + ) -> np.ndarray: + """Call the OpenAI compatible API to generate text. + + Parameters + ---------- + prompt + A string or list of strings that will be used to prompt the model + max_tokens + The maximum number of tokens to generate + temperature + The value of the temperature used to sample tokens + samples + The number of completions to generate for each prompt + stop_at + Up to 4 words where the API will stop the completion. + + """ + config = replace(self.config, max_tokens=max_tokens, n=samples, stop=stop_at, temperature=temperature) # type: ignore + + # We assume it's using the chat completion API style as that's the most commonly supported + client = self.create_client() + response, prompt_tokens, completion_tokens = generate_chat( + prompt, self.system_prompt, client, config + ) + self.prompt_tokens += prompt_tokens + self.completion_tokens += completion_tokens + + return response + + @property + def tokenizer(self): + """Defaults to gpt4, as that seems to work with most custom endpoints. Can be overridden if required in the constructor""" + try: + import tiktoken + except ImportError: + raise ImportError( + "The `tiktoken` library needs to be installed in order to choose `outlines.models.openai` with `is_in`" + ) + + return tiktoken.encoding_for_model(self.encoding) + + +openai_compatible_api = OpenAICompatibleAPI