diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py index d60a54d25..1c611ca34 100644 --- a/dspy/clients/lm.py +++ b/dspy/clients/lm.py @@ -30,8 +30,10 @@ def __init__(self, model, model_type='chat', temperature=0.0, max_tokens=1000, c self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs) self.history = [] - # Exceptions for O-1 models, which requires a temperature of 1.0. - if "o1-" in model: self.kwargs['temperature'] = 1.0 + if "o1-" in model: + assert max_tokens >= 5000 and temperature == 1.0, \ + "OpenAI's o1-* models require passing temperature=1.0 and max_tokens >= 5000 to `dspy.LM(...)`" + def __call__(self, prompt=None, messages=None, **kwargs): # Build the request.