From e09e93637c4612be9154d3e0ab28ee95f5bbafc6 Mon Sep 17 00:00:00 2001 From: Omar Khattab Date: Sat, 21 Sep 2024 22:12:12 -0700 Subject: [PATCH] Various improvements for adapters & LMs --- dspy/__init__.py | 15 ++++++ dspy/adapters/chat_adapter.py | 83 +++++++++++++++++++++++--------- dspy/clients/lm.py | 10 ++-- dspy/predict/chain_of_thought.py | 7 ++- dspy/teleprompt/random_search.py | 1 + 5 files changed, 89 insertions(+), 27 deletions(-) diff --git a/dspy/__init__.py b/dspy/__init__.py index 46d2f8169..9ad0f1f74 100644 --- a/dspy/__init__.py +++ b/dspy/__init__.py @@ -58,3 +58,18 @@ configure = settings.configure context = settings.context + + +import dspy.teleprompt + +LabeledFewShot = dspy.teleprompt.LabeledFewShot +BootstrapFewShot = dspy.teleprompt.BootstrapFewShot +BootstrapFewShotWithRandomSearch = dspy.teleprompt.BootstrapFewShotWithRandomSearch +BootstrapRS = dspy.teleprompt.BootstrapFewShotWithRandomSearch +COPRO = dspy.teleprompt.COPRO +MIPROv2 = dspy.teleprompt.MIPROv2 +Ensemble = dspy.teleprompt.Ensemble + + +def inspect_history(*args, **kwargs): + return settings.lm.inspect_history(*args, **kwargs) \ No newline at end of file diff --git a/dspy/adapters/chat_adapter.py b/dspy/adapters/chat_adapter.py index 627f8322a..9acfafcc6 100644 --- a/dspy/adapters/chat_adapter.py +++ b/dspy/adapters/chat_adapter.py @@ -1,7 +1,8 @@ import re +import textwrap from .base import Adapter -field_header_pattern = re.compile(r'\[\[\[ ### (\w+) ### \]\]\]') +field_header_pattern = re.compile(r'\[\[ ## (\w+) ## \]\]') class ChatAdapter(Adapter): @@ -11,21 +12,22 @@ def __init__(self): def format(self, signature, demos, inputs): messages = [] - # TODO: Extract `raw_demos` out of `demos`, i.e. demos where some of the output_fields are not filled in. - # raw_demos = [demo for demo in demos if not all(k in demo for k in signature.output_fields)] - # demos = [demo for demo in demos if demo not in raw_demos] + # Extract demos where some of the output_fields are not filled in. + incomplete_demos = [demo for demo in demos if not all(k in demo for k in signature.fields)] + complete_demos = [demo for demo in demos if demo not in incomplete_demos] + incomplete_demos = [demo for demo in incomplete_demos \ + if any(k in demo for k in signature.input_fields) and \ + any(k in demo for k in signature.output_fields)] - messages.append({"role": "system", "content": prepare_instructions(signature)}) - # messages.append({"role": "system", "content": prepare_instructions(signature, raw_demos)}) + demos = incomplete_demos + complete_demos - # TODO: Remove the raw_demos from demos. + messages.append({"role": "system", "content": prepare_instructions(signature)}) for demo in demos: - output_fields_, demo_ = list(signature.output_fields.keys()) + ['completed'], {**demo, 'completed': ''} - messages.append({"role": "user", "content": format_chat_turn(signature.input_fields.keys(), demo)}) - messages.append({"role": "assistant", "content": format_chat_turn(output_fields_, demo_)}) + messages.append(format_turn(signature, demo, role="user", incomplete=demo in incomplete_demos)) + messages.append(format_turn(signature, demo, role="assistant", incomplete=demo in incomplete_demos)) - messages.append({"role": "user", "content": format_chat_turn(signature.input_fields.keys(), inputs)}) + messages.append(format_turn(signature, inputs, role="user")) return messages @@ -48,16 +50,52 @@ def parse(self, signature, completion): return fields +def format_blob(blob): + if '\n' not in blob and "«" not in blob and "»" not in blob: return f"«{blob}»" + + modified_blob = blob.replace('\n', '\n ') + return f"«««\n {modified_blob}\n»»»" + + +def format_list(items): + if len(items) == 0: return "N/A" + if len(items) == 1: return format_blob(items[0]) + + return "\n".join([f"[{idx+1}] {format_blob(txt)}" for idx, txt in enumerate(items)]) + def format_fields(fields): - return '\n\n'.join([f"[[[ ### {k} ### ]]]\n{v}" for k, v in fields.items()]).strip() + output = [] + for k, v in fields.items(): + v = v if not isinstance(v, list) else format_list(v) + output.append(f"[[ ## {k} ## ]]\n{v}") -def format_chat_turn(field_names, values): - # TODO: Reinstate validation after dealing with raw_demos in the system messages. - # if not set(values).issuperset(set(field_names)): - # raise ValueError(f"Expected {field_names} but got {values.keys()}") + return '\n\n'.join(output).strip() + + + +def format_turn(signature, values, role, incomplete=False): + content = [] + + if role == "user": + field_names = signature.input_fields.keys() + if incomplete: + content.append("This is an example of the task, though some input or output fields are not supplied.") + else: + field_names, values = list(signature.output_fields.keys()) + ['completed'], {**values, 'completed': ''} + + if not incomplete: + if not set(values).issuperset(set(field_names)): + raise ValueError(f"Expected {field_names} but got {values.keys()}") - return format_fields({k: values.get(k, "Not supplied for this particular example.") for k in field_names}) + content.append(format_fields({k: values.get(k, "Not supplied for this particular example.") for k in field_names})) + + if role == "user": + content.append("Respond with the corresponding output fields, starting with the field " + + ", then ".join(f"`{f}`" for f in signature.output_fields) + + ", and then ending with the marker for `completed`.") + + return {"role": role, "content": '\n\n'.join(content).strip()} def enumerate_fields(fields): parts = [] @@ -78,12 +116,13 @@ def prepare_instructions(signature): parts.append(format_fields({f : f"{{{f}}}" for f in signature.output_fields})) parts.append(format_fields({'completed' : ""})) - objective = ('\n' + ' ' * 8).join([''] + signature.instructions.splitlines()) + instructions = textwrap.dedent(signature.instructions) + objective = ('\n' + ' ' * 8).join([''] + instructions.splitlines()) parts.append(f"In adhering to this structure, your objective is: {objective}") - parts.append("You will receive some input fields in each interaction. " + - "Respond only with the corresponding output fields, starting with the field " + - ", then ".join(f"`{f}`" for f in signature.output_fields) + - ", and then ending with the marker for `completed`.") + # parts.append("You will receive some input fields in each interaction. " + + # "Respond only with the corresponding output fields, starting with the field " + + # ", then ".join(f"`{f}`" for f in signature.output_fields) + + # ", and then ending with the marker for `completed`.") return '\n\n'.join(parts).strip() diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py index a442cbe94..ccc730f2b 100644 --- a/dspy/clients/lm.py +++ b/dspy/clients/lm.py @@ -1,6 +1,7 @@ import os import ujson import functools +from pathlib import Path try: import warnings @@ -9,7 +10,8 @@ import litellm from litellm.caching import Cache - litellm.cache = Cache(disk_cache_dir=".dspy_cache", type="disk") + disk_cache_dir = os.environ.get('DSPY_CACHEDIR') or os.path.join(Path.home(), '.dspy_cache') + litellm.cache = Cache(disk_cache_dir=disk_cache_dir, type="disk") except ImportError: class LitellmPlaceholder: @@ -18,11 +20,11 @@ def __getattr__(self, _): raise ImportError("The LiteLLM package is not installe litellm = LitellmPlaceholder() class LM: - def __init__(self, model, model_type='chat', temperature=0.0, cache=True, **kwargs): + def __init__(self, model, model_type='chat', temperature=0.0, max_tokens=1000, cache=True, **kwargs): self.model = model self.model_type = model_type self.cache = cache - self.kwargs = dict(temperature=temperature, **kwargs) + self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs) self.history = [] def __call__(self, prompt=None, messages=None, **kwargs): @@ -90,7 +92,7 @@ def _red(text: str, end: str = "\n"): def _inspect_history(lm, n: int = 1): """Prints the last n prompts and their completions.""" - for item in reversed(lm.history[-n:]): + for item in lm.history[-n:]: messages = item["messages"] or [{"role": "user", "content": item['prompt']}] outputs = item["outputs"] diff --git a/dspy/predict/chain_of_thought.py b/dspy/predict/chain_of_thought.py index e7dc8c8aa..74656697c 100644 --- a/dspy/predict/chain_of_thought.py +++ b/dspy/predict/chain_of_thought.py @@ -27,8 +27,13 @@ def __init__(self, signature, rationale_type=None, activated=True, **config): desc = f"${{produce the {last_key}}}. We ..." rationale_type = rationale_type or dspy.OutputField(prefix=prefix, desc=desc) + # Add "rationale" field to the output signature. - extended_signature = signature.prepend("rationale", rationale_type, type_=str) + if isinstance(dspy.settings.lm, dspy.LM): + extended_signature = signature.prepend("reasoning", rationale_type, type_=str) + else: + extended_signature = signature.prepend("rationale", rationale_type, type_=str) + self._predict = dspy.Predict(extended_signature, **config) self._predict.extended_signature = extended_signature diff --git a/dspy/teleprompt/random_search.py b/dspy/teleprompt/random_search.py index c1aeef54c..57bd38150 100644 --- a/dspy/teleprompt/random_search.py +++ b/dspy/teleprompt/random_search.py @@ -130,6 +130,7 @@ def compile(self, student, *, teacher=None, trainset, valset=None, restrict=None best_program = program scores.append(score) + print(f"Scores so far: {scores}") print(f"Best score so far: {max(scores)}") score_data.append((score, subscores, seed, program))