Skip to content

Commit

Permalink
Modify the openai model to conform to the new openai sdk v1.0.0
Browse files Browse the repository at this point in the history
  • Loading branch information
RobinPicard authored and rlouf committed Nov 15, 2023
1 parent 44f79d0 commit b60bb7a
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 70 deletions.
124 changes: 56 additions & 68 deletions outlines/models/openai.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,57 @@
"""Integration with OpenAI's API."""
import functools
import os
from typing import Callable, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union

import numpy as np
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_random_exponential,
)

import outlines
from outlines.caching import cache

__all__ = ["OpenAIAPI", "openai"]

if TYPE_CHECKING:
from openai import AsyncOpenAI


class OpenAIAPI:
def __init__(
self,
model_name: str,
api_key: Optional[str] = os.getenv("OPENAI_API_KEY"),
temperature: float = 1.0,
max_retries: int = 6,
):
self.api_key = api_key
try:
import openai
except ImportError:
raise ImportError(
"The `openai` library needs to be installed in order to use Outlines' OpenAI integration."
)

try:
self.client = openai.AsyncOpenAI(api_key=api_key, max_retries=max_retries)
except openai.OpenAIError as e:
raise e

@error_handler
@cache
async def cached_call_completion_api(*args, **kwargs):
response = await call_completion_api(self.client, *args, **kwargs)
return response

@error_handler
@cache
async def cached_call_chat_completion_api(*args, **kwargs):
response = await call_chat_completion_api(self.client, *args, **kwargs)
return response

if "text-" in model_name:
call_api = call_completion_api
call_api = cached_call_completion_api
format_prompt = lambda x: x
extract_choice = lambda x: x["text"]
elif "gpt-" in model_name:
call_api = call_chat_completion_api
call_api = cached_call_chat_completion_api
format_prompt = lambda x: [{"role": "user", "content": x}]
extract_choice = lambda x: x["message"]["content"]
else:
Expand All @@ -45,7 +65,7 @@ async def generate_base(
max_tokens: int,
stop_at: List[Optional[str]],
samples: int,
api_key: str,
client: openai.AsyncOpenAI,
) -> str:
responses = await call_api(
model_name,
Expand All @@ -55,7 +75,6 @@ async def generate_base(
stop_at,
{},
samples,
api_key,
)

if samples == 1:
Expand All @@ -69,7 +88,11 @@ async def generate_base(

@functools.partial(outlines.vectorize, signature="(),(),(m),(),()->(s)")
async def generate_choice(
prompt: str, max_tokens: int, is_in: List[str], samples: int, api_key: str
prompt: str,
max_tokens: int,
is_in: List[str],
samples: int,
client: openai.AsyncOpenAI,
) -> Union[List[str], str]:
"""Generate a sequence that must be one of many options.
Expand Down Expand Up @@ -117,7 +140,6 @@ async def generate_choice(
[],
mask,
samples,
api_key,
)
decoded.append(extract_choice(response["choices"][0]))
prompt = prompt + "".join(decoded)
Expand All @@ -141,15 +163,11 @@ def __call__(
if is_in is not None and stop_at:
raise TypeError("You cannot set `is_in` and `stop_at` at the same time.")
elif is_in is not None:
return self.generate_choice(
prompt, max_tokens, is_in, samples, self.api_key
)
return self.generate_choice(prompt, max_tokens, is_in, samples, self.client)
else:
if isinstance(stop_at, str):
stop_at = [stop_at]
return self.generate_base(
prompt, max_tokens, stop_at, samples, self.api_key
)
return self.generate_base(prompt, max_tokens, stop_at, samples, self.client)


openai = OpenAIAPI
Expand All @@ -164,93 +182,63 @@ def call(*args, **kwargs):
try:
return api_call_fn(*args, **kwargs)
except (
openai.error.RateLimitError,
openai.error.Timeout,
openai.error.TryAgain,
openai.error.APIConnectionError,
openai.error.ServiceUnavailableError,
openai.APITimeoutError,
openai.InternalServerError,
openai.RateLimitError,
) as e:
raise OSError(f"Could not connect to the OpenAI API: {e}")
except (
openai.error.AuthenticationError,
openai.error.PermissionError,
openai.error.InvalidRequestError,
openai.error.InvalidAPIType,
openai.AuthenticationError,
openai.BadRequestError,
openai.ConflictError,
openai.PermissionDeniedError,
openai.NotFoundError,
openai.UnprocessableEntityError,
) as e:
raise e

return call


retry_config = {
"wait": wait_random_exponential(min=1, max=30),
"stop": stop_after_attempt(6),
"retry": retry_if_exception_type(OSError),
}


@retry(**retry_config)
@error_handler
@cache
async def call_completion_api(
client: "AsyncOpenAI",
model: str,
prompt: str,
max_tokens: int,
temperature: float,
stop_sequences: List[str],
logit_bias: Dict[str, int],
num_samples: int,
api_key: str,
):
try:
import openai
except ImportError:
raise ImportError(
"The `openai` library needs to be installed in order to use Outlines' OpenAI integration."
)

response = await openai.Completion.acreate(
engine=model,
) -> dict:
response = await client.completions.create(
model=model,
prompt=prompt,
temperature=temperature,
max_tokens=max_tokens,
stop=list(stop_sequences) if len(stop_sequences) > 0 else None,
logit_bias=logit_bias,
n=int(num_samples),
api_key=api_key,
)
return response
return response.model_dump()


@retry(**retry_config)
@error_handler
@cache
async def call_chat_completion_api(
client: "AsyncOpenAI",
model: str,
messages: List[Dict[str, str]],
max_tokens: int,
temperature: float,
stop_sequences: List[str],
logit_bias: Dict[str, int],
num_samples: int,
api_key: str,
):
try:
import openai
except ImportError:
raise ImportError(
"The `openai` library needs to be installed in order to use Outlines' OpenAI integration."
)

response = await openai.ChatCompletion.acreate(
) -> dict:
response = await client.chat.completions.create(
model=model,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
stop=list(stop_sequences) if len(stop_sequences) > 0 else None,
logit_bias=logit_bias,
n=int(num_samples),
api_key=api_key,
)

return response
return response.model_dump()
8 changes: 6 additions & 2 deletions tests/text/generate/test_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,17 @@
import pytest
import torch

from outlines import models
from outlines.models import OpenAIAPI
from outlines.models.tokenizer import Tokenizer
from outlines.text.generate.sequence import Sequence


def test_openai_error():
model = models.openai("text-davinci-003")
class Mock(OpenAIAPI):
def __init__(self):
pass

model = Mock()
with pytest.raises(TypeError):
Sequence(model)

Expand Down

0 comments on commit b60bb7a

Please sign in to comment.