Skip to content

Commit

Permalink
Merge pull request #1523 from stanfordnlp/improve_adapter_raw_demos
Browse files Browse the repository at this point in the history
Various improvements for adapters & LMs
  • Loading branch information
okhat authored Sep 22, 2024
2 parents e5a7dba + e09e936 commit 34399b7
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 27 deletions.
15 changes: 15 additions & 0 deletions dspy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,18 @@

configure = settings.configure
context = settings.context


import dspy.teleprompt

LabeledFewShot = dspy.teleprompt.LabeledFewShot
BootstrapFewShot = dspy.teleprompt.BootstrapFewShot
BootstrapFewShotWithRandomSearch = dspy.teleprompt.BootstrapFewShotWithRandomSearch
BootstrapRS = dspy.teleprompt.BootstrapFewShotWithRandomSearch
COPRO = dspy.teleprompt.COPRO
MIPROv2 = dspy.teleprompt.MIPROv2
Ensemble = dspy.teleprompt.Ensemble


def inspect_history(*args, **kwargs):
return settings.lm.inspect_history(*args, **kwargs)
83 changes: 61 additions & 22 deletions dspy/adapters/chat_adapter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import re
import textwrap
from .base import Adapter

field_header_pattern = re.compile(r'\[\[\[ ### (\w+) ### \]\]\]')
field_header_pattern = re.compile(r'\[\[ ## (\w+) ## \]\]')


class ChatAdapter(Adapter):
Expand All @@ -11,21 +12,22 @@ def __init__(self):
def format(self, signature, demos, inputs):
messages = []

# TODO: Extract `raw_demos` out of `demos`, i.e. demos where some of the output_fields are not filled in.
# raw_demos = [demo for demo in demos if not all(k in demo for k in signature.output_fields)]
# demos = [demo for demo in demos if demo not in raw_demos]
# Extract demos where some of the output_fields are not filled in.
incomplete_demos = [demo for demo in demos if not all(k in demo for k in signature.fields)]
complete_demos = [demo for demo in demos if demo not in incomplete_demos]
incomplete_demos = [demo for demo in incomplete_demos \
if any(k in demo for k in signature.input_fields) and \
any(k in demo for k in signature.output_fields)]

messages.append({"role": "system", "content": prepare_instructions(signature)})
# messages.append({"role": "system", "content": prepare_instructions(signature, raw_demos)})
demos = incomplete_demos + complete_demos

# TODO: Remove the raw_demos from demos.
messages.append({"role": "system", "content": prepare_instructions(signature)})

for demo in demos:
output_fields_, demo_ = list(signature.output_fields.keys()) + ['completed'], {**demo, 'completed': ''}
messages.append({"role": "user", "content": format_chat_turn(signature.input_fields.keys(), demo)})
messages.append({"role": "assistant", "content": format_chat_turn(output_fields_, demo_)})
messages.append(format_turn(signature, demo, role="user", incomplete=demo in incomplete_demos))
messages.append(format_turn(signature, demo, role="assistant", incomplete=demo in incomplete_demos))

messages.append({"role": "user", "content": format_chat_turn(signature.input_fields.keys(), inputs)})
messages.append(format_turn(signature, inputs, role="user"))

return messages

Expand All @@ -48,16 +50,52 @@ def parse(self, signature, completion):

return fields

def format_blob(blob):
if '\n' not in blob and "«" not in blob and "»" not in blob: return f"«{blob}»"

modified_blob = blob.replace('\n', '\n ')
return f"«««\n {modified_blob}\n»»»"


def format_list(items):
if len(items) == 0: return "N/A"
if len(items) == 1: return format_blob(items[0])

return "\n".join([f"[{idx+1}] {format_blob(txt)}" for idx, txt in enumerate(items)])


def format_fields(fields):
return '\n\n'.join([f"[[[ ### {k} ### ]]]\n{v}" for k, v in fields.items()]).strip()
output = []
for k, v in fields.items():
v = v if not isinstance(v, list) else format_list(v)
output.append(f"[[ ## {k} ## ]]\n{v}")

def format_chat_turn(field_names, values):
# TODO: Reinstate validation after dealing with raw_demos in the system messages.
# if not set(values).issuperset(set(field_names)):
# raise ValueError(f"Expected {field_names} but got {values.keys()}")
return '\n\n'.join(output).strip()



def format_turn(signature, values, role, incomplete=False):
content = []

if role == "user":
field_names = signature.input_fields.keys()
if incomplete:
content.append("This is an example of the task, though some input or output fields are not supplied.")
else:
field_names, values = list(signature.output_fields.keys()) + ['completed'], {**values, 'completed': ''}

if not incomplete:
if not set(values).issuperset(set(field_names)):
raise ValueError(f"Expected {field_names} but got {values.keys()}")

return format_fields({k: values.get(k, "Not supplied for this particular example.") for k in field_names})
content.append(format_fields({k: values.get(k, "Not supplied for this particular example.") for k in field_names}))

if role == "user":
content.append("Respond with the corresponding output fields, starting with the field " +
", then ".join(f"`{f}`" for f in signature.output_fields) +
", and then ending with the marker for `completed`.")

return {"role": role, "content": '\n\n'.join(content).strip()}

def enumerate_fields(fields):
parts = []
Expand All @@ -78,12 +116,13 @@ def prepare_instructions(signature):
parts.append(format_fields({f : f"{{{f}}}" for f in signature.output_fields}))
parts.append(format_fields({'completed' : ""}))

objective = ('\n' + ' ' * 8).join([''] + signature.instructions.splitlines())
instructions = textwrap.dedent(signature.instructions)
objective = ('\n' + ' ' * 8).join([''] + instructions.splitlines())
parts.append(f"In adhering to this structure, your objective is: {objective}")

parts.append("You will receive some input fields in each interaction. " +
"Respond only with the corresponding output fields, starting with the field " +
", then ".join(f"`{f}`" for f in signature.output_fields) +
", and then ending with the marker for `completed`.")
# parts.append("You will receive some input fields in each interaction. " +
# "Respond only with the corresponding output fields, starting with the field " +
# ", then ".join(f"`{f}`" for f in signature.output_fields) +
# ", and then ending with the marker for `completed`.")

return '\n\n'.join(parts).strip()
10 changes: 6 additions & 4 deletions dspy/clients/lm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import ujson
import functools
from pathlib import Path

try:
import warnings
Expand All @@ -9,7 +10,8 @@
import litellm

from litellm.caching import Cache
litellm.cache = Cache(disk_cache_dir=".dspy_cache", type="disk")
disk_cache_dir = os.environ.get('DSPY_CACHEDIR') or os.path.join(Path.home(), '.dspy_cache')
litellm.cache = Cache(disk_cache_dir=disk_cache_dir, type="disk")

except ImportError:
class LitellmPlaceholder:
Expand All @@ -18,11 +20,11 @@ def __getattr__(self, _): raise ImportError("The LiteLLM package is not installe
litellm = LitellmPlaceholder()

class LM:
def __init__(self, model, model_type='chat', temperature=0.0, cache=True, **kwargs):
def __init__(self, model, model_type='chat', temperature=0.0, max_tokens=1000, cache=True, **kwargs):
self.model = model
self.model_type = model_type
self.cache = cache
self.kwargs = dict(temperature=temperature, **kwargs)
self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)
self.history = []

def __call__(self, prompt=None, messages=None, **kwargs):
Expand Down Expand Up @@ -90,7 +92,7 @@ def _red(text: str, end: str = "\n"):
def _inspect_history(lm, n: int = 1):
"""Prints the last n prompts and their completions."""

for item in reversed(lm.history[-n:]):
for item in lm.history[-n:]:
messages = item["messages"] or [{"role": "user", "content": item['prompt']}]
outputs = item["outputs"]

Expand Down
7 changes: 6 additions & 1 deletion dspy/predict/chain_of_thought.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,13 @@ def __init__(self, signature, rationale_type=None, activated=True, **config):
desc = f"${{produce the {last_key}}}. We ..."

rationale_type = rationale_type or dspy.OutputField(prefix=prefix, desc=desc)

# Add "rationale" field to the output signature.
extended_signature = signature.prepend("rationale", rationale_type, type_=str)
if isinstance(dspy.settings.lm, dspy.LM):
extended_signature = signature.prepend("reasoning", rationale_type, type_=str)
else:
extended_signature = signature.prepend("rationale", rationale_type, type_=str)

self._predict = dspy.Predict(extended_signature, **config)
self._predict.extended_signature = extended_signature

Expand Down
1 change: 1 addition & 0 deletions dspy/teleprompt/random_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def compile(self, student, *, teacher=None, trainset, valset=None, restrict=None
best_program = program

scores.append(score)
print(f"Scores so far: {scores}")
print(f"Best score so far: {max(scores)}")

score_data.append((score, subscores, seed, program))
Expand Down

0 comments on commit 34399b7

Please sign in to comment.