From fee5aa72178b57414f409bfc31e752080f4676fc Mon Sep 17 00:00:00 2001
From: chenmoneygithub
Date: Mon, 7 Oct 2024 17:18:27 -0700
Subject: [PATCH 1/3] format

---
 dspy/clients/lm.py | 80 +++++++++++++++++++++++++---------------------
 1 file changed, 43 insertions(+), 37 deletions(-)

diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py
index 86c8062c8..6dc32940e 100644
--- a/dspy/clients/lm.py
+++ b/dspy/clients/lm.py
@@ -1,51 +1,46 @@
-import os
-import uuid
-import ujson
 import functools
-from pathlib import Path
+import os
+import uuid
 from datetime import datetime
+from pathlib import Path
 
-try:
-    import warnings
-    with warnings.catch_warnings():
-        warnings.filterwarnings("ignore", category=UserWarning)
-        if "LITELLM_LOCAL_MODEL_COST_MAP" not in os.environ:
-            os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
-        import litellm
-        litellm.telemetry = False
+import litellm
+import ujson
+from litellm.caching import Cache
 
-    from litellm.caching import Cache
-    disk_cache_dir = os.environ.get('DSPY_CACHEDIR') or os.path.join(Path.home(), '.dspy_cache')
-    litellm.cache = Cache(disk_cache_dir=disk_cache_dir, type="disk")
+disk_cache_dir = os.environ.get("DSPY_CACHEDIR") or os.path.join(Path.home(), ".dspy_cache")
+litellm.cache = Cache(disk_cache_dir=disk_cache_dir, type="disk")
+litellm.telemetry = False
 
-except ImportError:
-    class LitellmPlaceholder:
-        def __getattr__(self, _): raise ImportError("The LiteLLM package is not installed. Run `pip install litellm`.")
+if "LITELLM_LOCAL_MODEL_COST_MAP" not in os.environ:
+    os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
 
-    litellm = LitellmPlaceholder()
 
 class LM:
-    def __init__(self, model, model_type='chat', temperature=0.0, max_tokens=1000, cache=True, **kwargs):
+    def __init__(self, model, model_type="chat", temperature=0.0, max_tokens=1000, cache=True):
         self.model = model
         self.model_type = model_type
         self.cache = cache
-        self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)
+        self.temperature = temperature
+        self.max_tokens = max_tokens
         self.history = []
 
         if "o1-" in model:
-            assert max_tokens >= 5000 and temperature == 1.0, \
-                "OpenAI's o1-* models require passing temperature=1.0 and max_tokens >= 5000 to `dspy.LM(...)`"
-
-    
+            assert (
+                max_tokens >= 5000 and temperature == 1.0
+            ), "OpenAI's o1-* models require passing temperature=1.0 and max_tokens >= 5000 to `dspy.LM(...)`"
+
     def __call__(self, prompt=None, messages=None, **kwargs):
         # Build the request.
         cache = kwargs.pop("cache", self.cache)
         messages = messages or [{"role": "user", "content": prompt}]
-        kwargs = {**self.kwargs, **kwargs}
+        kwargs = {"temperature": self.temperature, "max_tokens": self.max_tokens, **kwargs}
 
         # Make the request and handle LRU & disk caching.
-        if self.model_type == "chat": completion = cached_litellm_completion if cache else litellm_completion
-        else: completion = cached_litellm_text_completion if cache else litellm_text_completion
+        if self.model_type == "chat":
+            completion = cached_litellm_completion if cache else litellm_completion
+        else:
+            completion = cached_litellm_text_completion if cache else litellm_text_completion
 
         response = completion(ujson.dumps(dict(model=self.model, messages=messages, **kwargs)))
         outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]]
@@ -64,7 +59,7 @@ def __call__(self, prompt=None, messages=None, **kwargs):
         )
         self.history.append(entry)
         return outputs
-    
+
     def inspect_history(self, n: int = 1):
         _inspect_history(self, n)
 
@@ -73,14 +68,17 @@ def inspect_history(self, n: int = 1):
 def cached_litellm_completion(request):
     return litellm_completion(request, cache={"no-cache": False, "no-store": False})
 
+
 def litellm_completion(request, cache={"no-cache": True, "no-store": True}):
     kwargs = ujson.loads(request)
     return litellm.completion(cache=cache, **kwargs)
 
+
 @functools.lru_cache(maxsize=None)
 def cached_litellm_text_completion(request):
     return litellm_text_completion(request, cache={"no-cache": False, "no-store": False})
 
+
 def litellm_text_completion(request, cache={"no-cache": True, "no-store": True}):
     kwargs = ujson.loads(request)
 
@@ -93,32 +91,40 @@ def litellm_text_completion(request, cache={"no-cache": True, "no-store": True})
     api_base = kwargs.pop("api_base", None) or os.getenv(f"{provider}_API_BASE")
 
     # Build the prompt from the messages.
-    prompt = '\n\n'.join([x['content'] for x in kwargs.pop("messages")] + ['BEGIN RESPONSE:'])
+    prompt = "\n\n".join([x["content"] for x in kwargs.pop("messages")] + ["BEGIN RESPONSE:"])
 
-    return litellm.text_completion(cache=cache, model=f'text-completion-openai/{model}', api_key=api_key,
-                                   api_base=api_base, prompt=prompt, **kwargs)
+    return litellm.text_completion(
+        cache=cache,
+        model=f"text-completion-openai/{model}",
+        api_key=api_key,
+        api_base=api_base,
+        prompt=prompt,
+        **kwargs,
+    )
 
 
 def _green(text: str, end: str = "\n"):
     return "\x1b[32m" + str(text).lstrip() + "\x1b[0m" + end
 
+
 def _red(text: str, end: str = "\n"):
     return "\x1b[31m" + str(text) + "\x1b[0m" + end
 
+
 def _inspect_history(lm, n: int = 1):
     """Prints the last n prompts and their completions."""
 
     for item in lm.history[-n:]:
-        messages = item["messages"] or [{"role": "user", "content": item['prompt']}]
+        messages = item["messages"] or [{"role": "user", "content": item["prompt"]}]
         outputs = item["outputs"]
         timestamp = item.get("timestamp", "Unknown time")
 
         print("\n\n\n")
         print("\x1b[34m" + f"[{timestamp}]" + "\x1b[0m" + "\n")
-        
+
         for msg in messages:
             print(_red(f"{msg['role'].capitalize()} message:"))
-            print(msg['content'].strip())
+            print(msg["content"].strip())
             print("\n")
 
         print(_red("Response:"))
@@ -127,5 +133,5 @@ def _inspect_history(lm, n: int = 1):
         if len(outputs) > 1:
             choices_text = f" \t (and {len(outputs)-1} other completions)"
             print(_red(choices_text, end=""))
-    
-    print("\n\n\n")
\ No newline at end of file
+
+    print("\n\n\n")
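Note on the caching path in PATCH 1/3 (an illustrative sketch, not code from the patch): `LM.__call__` serializes the request dict to a JSON string with ujson before handing it to `cached_litellm_completion`, because `functools.lru_cache` only accepts hashable arguments; the cached wrapper then passes {"no-cache": False, "no-store": False} so litellm's disk cache under DSPY_CACHEDIR (default ~/.dspy_cache) is consulted as well. A minimal self-contained sketch of that two-level idea, with a hypothetical `backend_completion` standing in for `litellm.completion`:

import functools

import ujson


def backend_completion(**kwargs):
    # Hypothetical stand-in for litellm.completion so the sketch runs offline.
    return {"choices": [{"message": {"content": "ok"}}], "model": kwargs.get("model")}


@functools.lru_cache(maxsize=None)
def cached_completion(request: str):
    # The request arrives as a JSON string because lru_cache needs hashable keys.
    kwargs = ujson.loads(request)
    return backend_completion(**kwargs)


# Identical serialized requests hit the in-process LRU cache and return
# the very same response object.
req = ujson.dumps(dict(model="openai/gpt-4o-mini", messages=[{"role": "user", "content": "hi"}]))
assert cached_completion(req) is cached_completion(req)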
From 4ef3841cf256ad4444e7d5522810ffafac0bbdaa Mon Sep 17 00:00:00 2001
From: chenmoneygithub
Date: Mon, 7 Oct 2024 22:25:08 -0700
Subject: [PATCH 2/3] add kwargs back

---
 dspy/clients/lm.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py
index 6dc32940e..39a7ff8b9 100644
--- a/dspy/clients/lm.py
+++ b/dspy/clients/lm.py
@@ -17,12 +17,13 @@
 
 
 class LM:
-    def __init__(self, model, model_type="chat", temperature=0.0, max_tokens=1000, cache=True):
+    def __init__(self, model, model_type="chat", temperature=0.0, max_tokens=1000, cache=True, **kwargs):
         self.model = model
         self.model_type = model_type
         self.cache = cache
         self.temperature = temperature
         self.max_tokens = max_tokens
+        self.kwargs = kwargs
         self.history = []
 
         if "o1-" in model:
@@ -34,7 +35,7 @@ def __call__(self, prompt=None, messages=None, **kwargs):
         # Build the request.
         cache = kwargs.pop("cache", self.cache)
         messages = messages or [{"role": "user", "content": prompt}]
-        kwargs = {"temperature": self.temperature, "max_tokens": self.max_tokens, **kwargs}
+        kwargs = {"temperature": self.temperature, "max_tokens": self.max_tokens, **self.kwargs, **kwargs}
 
         # Make the request and handle LRU & disk caching.
         if self.model_type == "chat":
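Note on PATCH 2/3 (illustrative): restoring **kwargs keeps arbitrary constructor arguments, and the merge in `__call__` decides precedence by dict-unpacking order: the explicit defaults first, then the extra constructor kwargs, then per-call kwargs, with later entries overwriting earlier ones. A small worked example of that rule (the values are made up):

# Mirrors `{"temperature": ..., "max_tokens": ..., **self.kwargs, **kwargs}`.
defaults = {"temperature": 0.0, "max_tokens": 1000}
init_extras = {"top_p": 0.9}    # as if constructed with LM(..., top_p=0.9)
per_call = {"max_tokens": 256}  # as if called with lm(prompt, max_tokens=256)

merged = {**defaults, **init_extras, **per_call}
assert merged == {"temperature": 0.0, "max_tokens": 256, "top_p": 0.9}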
From 069808f01c2a6181bd3cd8a42c2178863b67e196 Mon Sep 17 00:00:00 2001
From: Omar Khattab
Date: Tue, 8 Oct 2024 09:07:51 -0700
Subject: [PATCH 3/3] Update lm.py

---
 dspy/clients/lm.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py
index 39a7ff8b9..994c8bc41 100644
--- a/dspy/clients/lm.py
+++ b/dspy/clients/lm.py
@@ -21,9 +21,7 @@ def __init__(self, model, model_type="chat", temperature=0.0, max_tokens=1000, c
         self.model = model
         self.model_type = model_type
         self.cache = cache
-        self.temperature = temperature
-        self.max_tokens = max_tokens
-        self.kwargs = kwargs
+        self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)
         self.history = []
 
         if "o1-" in model:
@@ -35,7 +33,7 @@ def __call__(self, prompt=None, messages=None, **kwargs):
         # Build the request.
         cache = kwargs.pop("cache", self.cache)
         messages = messages or [{"role": "user", "content": prompt}]
-        kwargs = {"temperature": self.temperature, "max_tokens": self.max_tokens, **self.kwargs, **kwargs}
+        kwargs = {**self.kwargs, **kwargs}
 
         # Make the request and handle LRU & disk caching.
         if self.model_type == "chat":
@@ -59,6 +57,7 @@ def __call__(self, prompt=None, messages=None, **kwargs):
             model_type=self.model_type,
         )
         self.history.append(entry)
+
         return outputs
 
     def inspect_history(self, n: int = 1):
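Note on the series as a whole (a usage sketch under assumptions, not an excerpt from the patches): after PATCH 3/3 the constructor folds temperature, max_tokens, and any extra kwargs into a single `self.kwargs` dict, and `__call__` overlays per-call kwargs on top of it. The model name, sampling values, and prompt below are illustrative, and a matching provider API key is assumed to be configured in the environment:

import dspy

# Constructor defaults plus one extra sampling parameter; after PATCH 3/3 all of
# these end up in lm.kwargs.
lm = dspy.LM("openai/gpt-4o-mini", temperature=0.0, max_tokens=1000, top_p=0.95)

# Per-call kwargs override the stored defaults for this request only.
outputs = lm("Summarize this patch series in one sentence.", max_tokens=64)
print(outputs[0])

# Pretty-print the most recent prompt/response pair recorded in lm.history.
lm.inspect_history(n=1)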