diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py index 39a7ff8b9..994c8bc41 100644 --- a/dspy/clients/lm.py +++ b/dspy/clients/lm.py @@ -21,9 +21,7 @@ def __init__(self, model, model_type="chat", temperature=0.0, max_tokens=1000, c self.model = model self.model_type = model_type self.cache = cache - self.temperature = temperature - self.max_tokens = max_tokens - self.kwargs = kwargs + self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs) self.history = [] if "o1-" in model: @@ -35,7 +33,7 @@ def __call__(self, prompt=None, messages=None, **kwargs): # Build the request. cache = kwargs.pop("cache", self.cache) messages = messages or [{"role": "user", "content": prompt}] - kwargs = {"temperature": self.temperature, "max_tokens": self.max_tokens, **self.kwargs, **kwargs} + kwargs = {**self.kwargs, **kwargs} # Make the request and handle LRU & disk caching. if self.model_type == "chat": @@ -59,6 +57,7 @@ def __call__(self, prompt=None, messages=None, **kwargs): model_type=self.model_type, ) self.history.append(entry) + return outputs def inspect_history(self, n: int = 1):