diff --git a/examples/babyagi.py b/examples/babyagi.py
index e32131253..d21630da0 100644
--- a/examples/babyagi.py
+++ b/examples/babyagi.py
@@ -28,9 +28,6 @@ def perform_task_ppt(objective: str, task: str):
     """
 
 
-perform_task = text.function(model, perform_task_ppt)
-
-
 #####################
 # Create a new task #
 #####################
@@ -67,9 +64,6 @@ def create_tasks_fmt(result: str) -> List[str]:
     return task_list
 
 
-create_tasks = text.function(model, create_tasks_ppt, create_tasks_fmt)
-
-
 ########################
 # Prioritize new tasks #
 ########################
@@ -104,9 +98,6 @@ def prioritize_tasks_fmt(result: str):
     return task_list
 
 
-prioritize_tasks = text.function(model, prioritize_tasks_ppt, prioritize_tasks_fmt)
-
-
 objective = "Becoming rich while doing nothing."
 first_task = {
     "task_id": 1,
@@ -134,18 +125,23 @@ def one_cycle(objective: str, task_list, next_task_id: int):
     """
     task = task_list.popleft()
 
-    result = perform_task(objective, task)
-    new_tasks = create_tasks(
+
+    prompt = perform_task_ppt(objective, task)
+    result = model(prompt)
+
+    prompt = create_tasks_ppt(
         objective, first_task["task_name"], result, [first_task["task_name"]]
     )
+    new_tasks = model(prompt)
 
     for task in new_tasks:
         next_task_id += 1
         task_list.append({"task_id": next_task_id, "task_name": task})
 
-    prioritized_tasks = prioritize_tasks(
+    prompt = prioritize_tasks_ppt(
         objective, [task["task_name"] for task in task_list], next_task_id
     )
+    prioritized_tasks = model(prompt)
 
     return task, result, prioritized_tasks, next_task_id
 
diff --git a/examples/math_generate_code.py b/examples/math_generate_code.py
index 069211583..b2b25a94f 100644
--- a/examples/math_generate_code.py
+++ b/examples/math_generate_code.py
@@ -34,11 +34,7 @@ def execute_code(code):
     return result
 
 
-answer_with_code = text.function(
-    models.text_completion.openai("text-davinci-003"),
-    answer_with_code_prompt,
-    execute_code,
-)
-
-result = answer_with_code(question, examples)
+prompt = answer_with_code_prompt(question, examples)
+answer = models.text_completion.openai("text-davinci-003")(prompt)
+result = execute_code(answer)
 print(f"It takes Carla {result:.0f} minutes to download the file.")
diff --git a/outlines/text/__init__.py b/outlines/text/__init__.py
index 8870c7a1f..b1ae976c9 100644
--- a/outlines/text/__init__.py
+++ b/outlines/text/__init__.py
@@ -1,3 +1,2 @@
-from .functions import function
 from .generate import continuation
 from .prompts import prompt, render
diff --git a/outlines/text/functions.py b/outlines/text/functions.py
deleted file mode 100644
index c3c6bfa50..000000000
--- a/outlines/text/functions.py
+++ /dev/null
@@ -1,63 +0,0 @@
-import functools
-from dataclasses import dataclass
-from typing import Callable, Optional, Union
-
-from pydantic import BaseModel
-
-FunctionType = type(lambda x: None)
-BaseModelType = type(BaseModel)
-
-
-@dataclass
-class function:
-    """Represents a function that uses a language model to generate its output.
-
-    When called, the `function` instance passes the arguments to the prompt
-    function, the rendered prompt is passed to the language model, and its
-    result to an (optional) validation function.
-
-    Attributes
-    ----------
-    model
-        A function that takes a string and returns a string that contains the
-        model's return value.
-    prompt
-        A prompt-generating function.
-    validator
-        A function that takes the output of the language model, parses it and
-        returns it in a normalized format.
-
-    """
-
-    model: Callable
-    prompt: Callable
-    validator: Optional[Union[Callable, BaseModel]] = None
-
-    def __call__(self, *args, **kwargs):
-        rendered_prompt = self.prompt(*args, **kwargs)
-        result = self.model(rendered_prompt)
-        validated_result = validate(self.validator, result)
-        return validated_result
-
-
-@functools.singledispatch
-def validate(validator, result):
-    if validator is not None:
-        raise NotImplementedError(
-            f"Cannot validate the input with validator of type {type(validator)}"
-        )
-    else:
-        return result
-
-
-@validate.register(BaseModelType)
-def validate_pydantic(validator, result):
-    if hasattr(validator, "model_validate_json"):
-        return validator.model_validate_json(result)
-    else:  # pragma: no cover
-        return validator.parse_raw(result)
-
-
-@validate.register(FunctionType)
-def validate_function(validator, result):
-    return validator(result)
diff --git a/tests/text/test_function.py b/tests/text/test_function.py
deleted file mode 100644
index ac473b613..000000000
--- a/tests/text/test_function.py
+++ /dev/null
@@ -1,51 +0,0 @@
-import json
-
-from pydantic import BaseModel
-
-import outlines.text as text
-
-
-def test_function_no_validator():
-    def passthrough_model(prompt: str):
-        return prompt
-
-    @text.prompt
-    def prompt(query: str):
-        "{{query}}"
-
-    fn = text.function(passthrough_model, prompt)
-    assert fn("Hello") == "Hello"
-
-
-def test_function_fn_validator():
-    def constant_model(_):
-        return "[1, 2, 3]"
-
-    @text.prompt
-    def prompt(query: str):
-        "{{query}}"
-
-    def validator(result):
-        return json.loads(result)
-
-    fn = text.function(constant_model, prompt, validator)
-    assert fn("Hello") == [1, 2, 3]
-
-
-def test_function_pydantic_validator():
-    class Response(BaseModel):
-        thought: str
-        command: str
-
-    def constant_model(_):
-        return '{"thought": "test thought", "command": "resume"}'
-
-    @text.prompt
-    def prompt(query: str):
-        "{{query}}"
-
-    fn = text.function(constant_model, prompt, Response)
-    result = fn("Hello")
-    assert isinstance(result, Response)
-    assert result.thought == "test thought"
-    assert result.command == "resume"
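With `text.function` removed, the examples render the prompt, call the model, and post-process the output as three explicit steps. A minimal sketch of that pattern, assuming an OpenAI API key is configured; the `answer_ppt` prompt and the question are made up for illustration:

    import outlines.models as models
    import outlines.text as text


    @text.prompt
    def answer_ppt(question: str):
        "Answer the following question: {{question}}"


    model = models.text_completion.openai("text-davinci-003")

    prompt = answer_ppt("How many minutes are there in two hours?")  # render the prompt
    answer = model(prompt)                                           # pass it to the model
    result = answer.strip()                                          # post-process the raw completion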