Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
Signed-off-by: dbczumar <corey.zumar@databricks.com>
  • Loading branch information
dbczumar committed Oct 10, 2024
1 parent 8a8f800 commit b5ebfa5
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions dspy/utils/dummies.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@

from dsp.modules import LM as DSPLM
from dsp.utils.utils import dotdict
from dspy.adapters.chat_adapter import field_header_pattern, format_fields
from dspy.adapters.chat_adapter import FieldInfoWithName, field_header_pattern, format_fields
from dspy.clients.lm import LM
from dspy.signatures.field import OutputField


class DSPDummyLM(DSPLM):
Expand Down Expand Up @@ -170,6 +171,10 @@ def _use_example(self, messages):
return output["content"]

def __call__(self, prompt=None, messages=None, **kwargs):
def format_answer_field(field_name, answer):
field = FieldInfoWithName(name=field_name, info=OutputField())
return format_fields(fields_with_values={field: answer})

# Build the request.
outputs = []
for _ in range(kwargs.get("n", 1)):
Expand All @@ -181,12 +186,12 @@ def __call__(self, prompt=None, messages=None, **kwargs):
elif isinstance(self.answers, dict):
outputs.append(
next(
(format_fields(v) for k, v in self.answers.items() if k in messages[-1]["content"]),
(format_answer_field(k, v) for k, v in self.answers.items() if k in messages[-1]["content"]),
"No more responses",
)
)
else:
outputs.append(format_fields(next(self.answers, {"answer": "No more responses"})))
outputs.append(format_answer_field(**next(self.answers, {"answer": "No more responses"})))

# Logging, with removed api key & where `cost` is None on cache hit.
kwargs = {k: v for k, v in kwargs.items() if not k.startswith("api_")}
Expand Down

0 comments on commit b5ebfa5

Please sign in to comment.