Skip to content

Commit

Permalink
add json generation support for openai
Browse files Browse the repository at this point in the history
add json generation support for openai
  • Loading branch information
JerryKwan committed Jul 22, 2024
1 parent f6a6c29 commit c50091c
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 9 deletions.
40 changes: 34 additions & 6 deletions outlines/generate/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from functools import singledispatch
from typing import Callable, Optional, Union

from pydantic import BaseModel
from pydantic import BaseModel, create_model

from outlines.fsm.json_schema import build_regex_from_schema, get_schema_from_signature
from outlines.generate.api import SequenceGenerator
Expand Down Expand Up @@ -71,8 +71,36 @@ def json(
@json.register(OpenAI)
def json_openai(
    model, schema_object: Union[str, object, Callable], sampler: Sampler = multinomial()
) -> Callable:
    """Build a JSON-structured generation function for an OpenAI model.

    Parameters
    ----------
    model
        The `OpenAI` model instance used to generate text.
    schema_object
        A Pydantic model class, a callable whose signature defines the schema,
        or a string containing a JSON Schema specification.
    sampler
        The sampling algorithm. The OpenAI API only supports multinomial
        sampling.

    Returns
    -------
    Callable
        A function `generate_json(prompt, max_tokens=1000, max_retries=3)`
        that returns a validated Pydantic instance when a Pydantic schema was
        given, and a parsed JSON object (dict) otherwise.

    Raises
    ------
    NotImplementedError
        If a sampler other than multinomial is requested.
    ValueError
        If `schema_object` is of an unsupported type.
    """
    if not isinstance(sampler, multinomial):
        raise NotImplementedError(
            r"The OpenAI API does not support any other sampling algorithm "
            + "than the multinomial sampler."
        )

    # Normalize the schema to a JSON Schema string; keep the Pydantic class
    # (when given one) so responses can be validated back into instances.
    response_model = None
    if isinstance(schema_object, type(BaseModel)):
        response_model = schema_object
        schema = pyjson.dumps(schema_object.model_json_schema())
    elif callable(schema_object):
        schema = pyjson.dumps(get_schema_from_signature(schema_object))
    elif isinstance(schema_object, str):
        schema = schema_object
    else:
        raise ValueError(
            f"Cannot parse schema {schema_object}. The schema must be either "
            + "a Pydantic object, a function or a string that contains the JSON "
            + "Schema specification"
        )

    def generate_json(prompt: str, max_tokens: int = 1000, max_retries: int = 3):
        """Generate a JSON response and parse/validate it, retrying on failure."""
        last_error: Optional[Exception] = None
        for _ in range(max(1, max_retries)):
            response = model.generate_json(prompt, schema, max_tokens)
            try:
                if response_model is not None:
                    # Validate against the user's Pydantic model.
                    return response_model.model_validate_json(response)
                # No Pydantic model available (string/callable schema):
                # return the parsed JSON object.
                return pyjson.loads(response)
            except Exception as e:
                # Model produced invalid/ill-formed JSON; retry.
                last_error = e
        raise last_error

    return generate_json
33 changes: 30 additions & 3 deletions outlines/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import warnings
from dataclasses import asdict, dataclass, field, replace
from itertools import zip_longest
from textwrap import dedent
from typing import Callable, Dict, List, Optional, Set, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -250,9 +251,36 @@ def generate_choice(

return choice

def generate_json(self):
def generate_json(
    self,
    prompt: str,
    schema: str,
    max_tokens: Optional[int] = None,
) -> str:
    """Call the OpenAI API to generate a JSON object.

    Parameters
    ----------
    prompt
        The user prompt describing the content to extract.
    schema
        A JSON Schema specification, as a string, that the response must
        conform to.
    max_tokens
        Maximum number of tokens to generate; `None` uses the config default.

    Returns
    -------
    str
        The raw JSON text returned by the API.
    """
    # Ask the API for JSON mode so the output is syntactically valid JSON.
    config = replace(
        self.config, max_tokens=max_tokens, response_format={"type": "json_object"}
    )

    # JSON mode requires instructing the model about the expected structure;
    # embed the schema in a system prompt and ask for an instance, not the
    # schema itself.
    system_prompt = dedent(
        f"""
        As a genius expert, your task is to understand the content and provide
        the parsed objects in json that match the following json_schema:\n
        {schema}
        Make sure to return an instance of the JSON, not the schema itself
        """
    )

    response, prompt_tokens, completion_tokens = generate_chat(
        prompt, system_prompt, self.client, config
    )
    # Track token usage on the model instance for cost accounting.
    self.prompt_tokens += prompt_tokens
    self.completion_tokens += completion_tokens
    return "".join(response)

def __str__(self):
return self.__class__.__name__ + " API"
Expand Down Expand Up @@ -309,7 +337,6 @@ async def call_api(prompt, system_prompt, config):
[responses["choices"][i]["message"]["content"] for i in range(config.n)]
)
usage = responses["usage"]

return results, usage["prompt_tokens"], usage["completion_tokens"]


Expand Down

0 comments on commit c50091c

Please sign in to comment.