feature(dspy): added compatibility with ANY llamaindex LLM #1225
base: main
@@ -0,0 +1,335 @@
"""
This code implements a wrapper around the llama_index library to emulate a DSPy LM.

It allows any llama_index LLM to be used within the DSPy framework, since DSPy ships with built-in support for only a limited set of LLM providers.

The code is a slightly modified copy of dspy/dsp/modules/azure_openai.py.

It works by creating a dummy OpenAI client that wraps any llama_index LLM object and implements .complete and .chat.

Tested with Python 3.12 and:

dspy==0.1.4
dspy-ai==2.4.9
llama-index==0.10.35
llama-index-llms-openai==0.1.18
"""

import json
import logging
from typing import Any, Literal

from easydict import EasyDict
from llama_index.core.base.llms.types import ChatMessage
from llama_index.core.llms import LLM


def LlamaIndexOpenAIClientWrapper(llm: LLM):
Review comment: To clarify, does this support compatibility with all LlamaIndex LLM integrations? It seems to me like the definition here is only for OpenAI clients, which would leave this change redundant to just using DSPy.OpenAI. If it is not isolated to just OpenAI, let's refactor this function to LlamaIndexLLMWrapper.

Author reply: This is supposed to support all llamaindex LLMs. However, the way I use this is by wrapping the llamaindex LLM to look like an OpenAI client.
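For illustration, here is a minimal sketch of the usage pattern under discussion, with a non-OpenAI llama_index backend. It assumes the optional llama-index-llms-ollama integration and a locally served model, neither of which is part of this PR; DspyLlamaIndexWrapper is the class defined further down in this file.

import dspy
from llama_index.llms.ollama import Ollama  # assumed optional integration, not part of this PR

local_llm = Ollama(model='llama3')  # hypothetical local model name
dspy_lm = DspyLlamaIndexWrapper(local_llm)  # class defined later in this file
dspy.settings.configure(lm=dspy_lm)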
    # Forward chat-style calls to the wrapped llama_index LLM.
    def chat(messages: list[ChatMessage], **kwargs) -> Any:
        return llm.chat([ChatMessage(**message) for message in messages], **kwargs)

    # Forward completion-style calls to the wrapped llama_index LLM.
    def complete(prompt: str, **kwargs) -> Any:
        return llm.complete(prompt, **kwargs)

    # Dummy client exposing both the OpenAI v1 attribute layout
    # (client.chat.completions.create / client.completions.create) and the
    # legacy layout (ChatCompletion.create / Completion.create).
    client = EasyDict(
        {
            'chat': EasyDict({'completions': EasyDict({'create': chat})}),
            'completions': EasyDict({'create': complete}),  # must be 'completions' so client.completions.create resolves (see v1_completions_request)
            'ChatCompletion': EasyDict({'create': chat}),
            'Completion': EasyDict({'create': complete}),
        }
    )
    return client


# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(message)s',
    handlers=[logging.FileHandler('azure_openai_usage.log')],
)

import functools

import backoff
import dsp
import openai
from dsp.modules.cache_utils import CacheMemory, NotebookCacheMemory, cache_turn_on
from dsp.modules.lm import LM

try:
    OPENAI_LEGACY = int(openai.version.__version__[0]) == 0
except Exception:
    OPENAI_LEGACY = True

try:
    import openai.error
    from openai.openai_object import OpenAIObject

    ERRORS = (
        openai.error.RateLimitError,
        openai.error.ServiceUnavailableError,
        openai.error.APIError,
    )
except Exception:
    ERRORS = (openai.RateLimitError, openai.APIError)
    OpenAIObject = dict


def backoff_hdlr(details):
    """Handler from https://pypi.org/project/backoff/"""
    print(
        'Backing off {wait:0.1f} seconds after {tries} tries calling function {target} with kwargs {kwargs}'.format(**details),
    )


class DspyLlamaIndexWrapper(LM):
Review comment: Refactor the name to LlamaIndexWrapper, and please add it to the imports in lm.py as done with the other modules.
"""Wrapper around Azure's API for OpenAI. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is the azure openai documentation still relevant here? please refactor if not |
||
    Args:
        api_base (str): Azure URL endpoint for model calling, often called 'azure_endpoint'.
        api_version (str): Version identifier for the API.
        model (str, optional): OpenAI or Azure supported LLM model to use. Defaults to "text-davinci-002".
        api_key (Optional[str], optional): API provider authentication token. Defaults to None.
        model_type (Literal["chat", "text"], optional): The type of model that was specified. Mainly to decide the optimal prompting strategy. Defaults to "chat".
        **kwargs: Additional arguments to pass to the API provider.
    """

    def __init__(
        self,
        llm: LLM,
        model_type: Literal['chat', 'text'] = 'chat',
        **kwargs,
    ):
        super().__init__(llm.model)
        self.provider = 'openai'

        self.llm = llm
        self.client = LlamaIndexOpenAIClientWrapper(llm)
        model = llm.model
        self.model_type = model_type

        # if not OPENAI_LEGACY and "model" not in kwargs:
        #     if "deployment_id" in kwargs:
        #         kwargs["model"] = kwargs["deployment_id"]
        #         del kwargs["deployment_id"]

        #     if "api_version" in kwargs:
        #         del kwargs["api_version"]

        if 'model' not in kwargs:
            kwargs['model'] = model

        self.kwargs = {
            'temperature': 0.0,
            'max_tokens': 150,
            'top_p': 1,
            'frequency_penalty': 0,
            'presence_penalty': 0,
            'n': 1,
            **kwargs,
        }  # TODO: add kwargs above for </s>

        self.history: list[dict[str, Any]] = []

    def _openai_client(self):
Review comment: Refactor this if the scope of this LM is not limited to just OpenAI.
        # if OPENAI_LEGACY:
        #     return openai

        return self.client

    def log_usage(self, response):
        """Log the total tokens from the Azure OpenAI API response."""
        usage_data = response.get('usage')
        if usage_data:
            total_tokens = usage_data.get('total_tokens')
            logging.info(f'{total_tokens}')

    def basic_request(self, prompt: str, **kwargs):
        raw_kwargs = kwargs

        kwargs = {**self.kwargs, **kwargs}
        if self.model_type == 'chat':
            # caching mechanism requires hashable kwargs
            kwargs['messages'] = [{'role': 'user', 'content': prompt}]
            kwargs = {'stringify_request': json.dumps(kwargs)}
            # response = chat_request(self.client, **kwargs)
            # if OPENAI_LEGACY:
            #     return _cached_gpt3_turbo_request_v2_wrapped(**kwargs)
            # else:
            return v1_chat_request(self.client, **kwargs)

        else:
            kwargs['prompt'] = prompt
            response = self.completions_request(**kwargs)

        history = {
            'prompt': prompt,
            'response': response,
            'kwargs': kwargs,
            'raw_kwargs': raw_kwargs,
        }
        self.history.append(history)

        return response

    @backoff.on_exception(
        backoff.expo,
        ERRORS,
        max_time=1000,
        on_backoff=backoff_hdlr,
    )
    def request(self, prompt: str, **kwargs):
        """Handles retrieval of GPT-3 completions whilst handling rate limiting and caching."""
        if 'model_type' in kwargs:
            del kwargs['model_type']

        return self.basic_request(prompt, **kwargs)

    def _get_choice_text(self, choice: dict[str, Any]) -> str:
        if self.model_type == 'chat':
            return choice['message']['content']
        return choice['text']

    def __call__(
        self,
        prompt: str,
        only_completed: bool = True,
        return_sorted: bool = False,
        **kwargs,
    ) -> list[dict[str, Any]]:
        """Retrieves completions from OpenAI Model.

        Args:
            prompt (str): prompt to send to GPT-3
            only_completed (bool, optional): return only completed responses and ignores completion due to length. Defaults to True.
            return_sorted (bool, optional): sort the completion choices using the returned probabilities. Defaults to False.

        Returns:
            list[dict[str, Any]]: list of completion choices
        """

        assert only_completed, 'for now'
        assert return_sorted is False, 'for now'

        response = self.request(prompt, **kwargs)

        try:
            if dsp.settings.log_openai_usage:
                self.log_usage(response)
        except Exception:
            pass

        choices = response['choices']

        completed_choices = [c for c in choices if c['finish_reason'] != 'length']

        if only_completed and len(completed_choices):
            choices = completed_choices

        completions = [self._get_choice_text(c) for c in choices]
        if return_sorted and kwargs.get('n', 1) > 1:
            scored_completions = []

            for c in choices:
                tokens, logprobs = (
                    c['logprobs']['tokens'],
                    c['logprobs']['token_logprobs'],
                )

                if '<|endoftext|>' in tokens:
                    index = tokens.index('<|endoftext|>') + 1
                    tokens, logprobs = tokens[:index], logprobs[:index]

                avglog = sum(logprobs) / len(logprobs)
                scored_completions.append((avglog, self._get_choice_text(c)))

            scored_completions = sorted(scored_completions, reverse=True)
            completions = [c for _, c in scored_completions]

        return completions

    def completions_request(self, **kwargs):
        # if OPENAI_LEGACY:
        #     return cached_gpt3_request_v2_wrapped(**kwargs)
        return v1_completions_request(self.client, **kwargs)


def v1_chat_request(client, **kwargs):
    @functools.lru_cache(maxsize=None if cache_turn_on else 0)
    @NotebookCacheMemory.cache
    def v1_cached_gpt3_turbo_request_v2_wrapped(**kwargs):
        @CacheMemory.cache
        def v1_cached_gpt3_turbo_request_v2(**kwargs):
            if 'stringify_request' in kwargs:
                kwargs = json.loads(kwargs['stringify_request'])
            return client.chat.completions.create(**kwargs)

        return v1_cached_gpt3_turbo_request_v2(**kwargs)

    response = v1_cached_gpt3_turbo_request_v2_wrapped(**kwargs)

    try:
        response = response.model_dump()
    except Exception:
        response = response.raw
        response['choices'] = [json.loads(x.json()) for x in response['choices']]
        response['usage'] = json.loads(response['usage'].json())
    return response


def v1_completions_request(client, **kwargs):
    @functools.lru_cache(maxsize=None if cache_turn_on else 0)
    @NotebookCacheMemory.cache
    def v1_cached_gpt3_request_v2_wrapped(**kwargs):
        @CacheMemory.cache
        def v1_cached_gpt3_request_v2(**kwargs):
            return client.completions.create(**kwargs)

        return v1_cached_gpt3_request_v2(**kwargs)

    return v1_cached_gpt3_request_v2_wrapped(**kwargs).model_dump()


## ======== test =========

if __name__ == '__main__':
Review comment: Remove this test block and add it to the documentation for the LM.
    print('Testing DspyLlamaIndexWrapper')
    import os

    import dspy
    from dspy.datasets.gsm8k import GSM8K, gsm8k_metric
    from llama_index.llms.openai import OpenAI

    llm = OpenAI(api_key=os.environ['OPENAI_API_KEY'], model='gpt-3.5-turbo')
    dspy_llm = DspyLlamaIndexWrapper(llm)

    # Load math questions from the GSM8K dataset.
    gsm8k = GSM8K()
    gsm8k_trainset, gsm8k_devset = gsm8k.train[:10], gsm8k.dev[:10]

    class CoT(dspy.Module):
        def __init__(self):
            super().__init__()
            self.prog = dspy.ChainOfThought('question -> answer')

        def forward(self, question):
            response = self.prog(question=question)
            return response

    ##

    dspy.settings.configure(lm=dspy_llm)

    from dspy.teleprompt import BootstrapFewShot

    # Set up the optimizer: we want to "bootstrap" (i.e., self-generate) 4-shot examples of our CoT program.
    config = dict(max_bootstrapped_demos=4, max_labeled_demos=4)

    # Optimize! Use the `gsm8k_metric` here. In general, the metric is going to tell the optimizer how well it's doing.
    teleprompter = BootstrapFewShot(metric=gsm8k_metric, **config)
    optimized_cot = teleprompter.compile(CoT(), trainset=gsm8k_trainset)
    print(f'{optimized_cot=}')
Review comment: Feel free to add these comments into the documentation for the LlamaIndexWrapper LM.
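As a starting point for that documentation, here is a minimal usage sketch distilled from the test block above; the model name and environment-variable handling are only examples.

import os

import dspy
from llama_index.llms.openai import OpenAI

# Wrap a llama_index LLM so DSPy can drive it like any other LM.
llm = OpenAI(api_key=os.environ['OPENAI_API_KEY'], model='gpt-3.5-turbo')
dspy.settings.configure(lm=DspyLlamaIndexWrapper(llm))

# Use it through a standard DSPy module.
predict = dspy.ChainOfThought('question -> answer')
print(predict(question='What is 2 + 2?').answer)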