diff --git a/.github/workflows/test-integrations-data-processing.yml b/.github/workflows/test-integrations-data-processing.yml index 28c788d69a..b9f1b3fdcb 100644 --- a/.github/workflows/test-integrations-data-processing.yml +++ b/.github/workflows/test-integrations-data-processing.yml @@ -70,6 +70,10 @@ jobs: run: | set -x # print commands that are executed ./scripts/runtox.sh "py${{ matrix.python-version }}-openai-latest" --cov=tests --cov=sentry_sdk --cov-report= --cov-branch + - name: Test huggingface_hub latest + run: | + set -x # print commands that are executed + ./scripts/runtox.sh "py${{ matrix.python-version }}-huggingface_hub-latest" --cov=tests --cov=sentry_sdk --cov-report= --cov-branch - name: Test rq latest run: | set -x # print commands that are executed @@ -134,6 +138,10 @@ jobs: run: | set -x # print commands that are executed ./scripts/runtox.sh --exclude-latest "py${{ matrix.python-version }}-openai" --cov=tests --cov=sentry_sdk --cov-report= --cov-branch + - name: Test huggingface_hub pinned + run: | + set -x # print commands that are executed + ./scripts/runtox.sh --exclude-latest "py${{ matrix.python-version }}-huggingface_hub" --cov=tests --cov=sentry_sdk --cov-report= --cov-branch - name: Test rq pinned run: | set -x # print commands that are executed diff --git a/mypy.ini b/mypy.ini index 0d8a60b64c..4f143ede97 100644 --- a/mypy.ini +++ b/mypy.ini @@ -73,6 +73,8 @@ ignore_missing_imports = True ignore_missing_imports = True [mypy-openai.*] ignore_missing_imports = True +[mypy-huggingface_hub.*] +ignore_missing_imports = True [mypy-arq.*] ignore_missing_imports = True [mypy-grpc.*] diff --git a/scripts/split-tox-gh-actions/split-tox-gh-actions.py b/scripts/split-tox-gh-actions/split-tox-gh-actions.py index 53fa55d909..5d5f423857 100755 --- a/scripts/split-tox-gh-actions/split-tox-gh-actions.py +++ b/scripts/split-tox-gh-actions/split-tox-gh-actions.py @@ -73,6 +73,7 @@ "huey", "langchain", "openai", + "huggingface_hub", "rq", ], "Databases": [ diff --git a/sentry_sdk/consts.py b/sentry_sdk/consts.py index 3ffa384e04..a83fde9f1b 100644 --- a/sentry_sdk/consts.py +++ b/sentry_sdk/consts.py @@ -325,6 +325,9 @@ class OP: MIDDLEWARE_STARLITE_SEND = "middleware.starlite.send" OPENAI_CHAT_COMPLETIONS_CREATE = "ai.chat_completions.create.openai" OPENAI_EMBEDDINGS_CREATE = "ai.embeddings.create.openai" + HUGGINGFACE_HUB_CHAT_COMPLETIONS_CREATE = ( + "ai.chat_completions.create.huggingface_hub" + ) LANGCHAIN_PIPELINE = "ai.pipeline.langchain" LANGCHAIN_RUN = "ai.run.langchain" LANGCHAIN_TOOL = "ai.tool.langchain" diff --git a/sentry_sdk/integrations/__init__.py b/sentry_sdk/integrations/__init__.py index f692e88294..fffd573491 100644 --- a/sentry_sdk/integrations/__init__.py +++ b/sentry_sdk/integrations/__init__.py @@ -85,6 +85,7 @@ def iter_default_integrations(with_auto_enabling_integrations): "sentry_sdk.integrations.graphene.GrapheneIntegration", "sentry_sdk.integrations.httpx.HttpxIntegration", "sentry_sdk.integrations.huey.HueyIntegration", + "sentry_sdk.integrations.huggingface_hub.HuggingfaceHubIntegration", "sentry_sdk.integrations.langchain.LangchainIntegration", "sentry_sdk.integrations.loguru.LoguruIntegration", "sentry_sdk.integrations.openai.OpenAIIntegration", diff --git a/sentry_sdk/integrations/huggingface_hub.py b/sentry_sdk/integrations/huggingface_hub.py new file mode 100644 index 0000000000..8e5f0e7339 --- /dev/null +++ b/sentry_sdk/integrations/huggingface_hub.py @@ -0,0 +1,173 @@ +from functools import wraps + +from sentry_sdk import consts +from sentry_sdk.ai.monitoring import record_token_usage +from sentry_sdk.ai.utils import set_data_normalized +from sentry_sdk.consts import SPANDATA + +from typing import Any, Iterable, Callable + +import sentry_sdk +from sentry_sdk.scope import should_send_default_pii +from sentry_sdk.integrations import DidNotEnable, Integration +from sentry_sdk.utils import ( + capture_internal_exceptions, + event_from_exception, + ensure_integration_enabled, +) + +try: + import huggingface_hub.inference._client + + from huggingface_hub import ChatCompletionStreamOutput, TextGenerationOutput +except ImportError: + raise DidNotEnable("Huggingface not installed") + + +class HuggingfaceHubIntegration(Integration): + identifier = "huggingface_hub" + + def __init__(self, include_prompts=True): + # type: (HuggingfaceHubIntegration, bool) -> None + self.include_prompts = include_prompts + + @staticmethod + def setup_once(): + # type: () -> None + huggingface_hub.inference._client.InferenceClient.text_generation = ( + _wrap_text_generation( + huggingface_hub.inference._client.InferenceClient.text_generation + ) + ) + + +def _capture_exception(exc): + # type: (Any) -> None + event, hint = event_from_exception( + exc, + client_options=sentry_sdk.get_client().options, + mechanism={"type": "huggingface_hub", "handled": False}, + ) + sentry_sdk.capture_event(event, hint=hint) + + +def _wrap_text_generation(f): + # type: (Callable[..., Any]) -> Callable[..., Any] + @wraps(f) + @ensure_integration_enabled(HuggingfaceHubIntegration, f) + def new_text_generation(*args, **kwargs): + # type: (*Any, **Any) -> Any + if "prompt" in kwargs: + prompt = kwargs["prompt"] + elif len(args) >= 2: + kwargs["prompt"] = args[1] + prompt = kwargs["prompt"] + args = (args[0],) + args[2:] + else: + # invalid call, let it return error + return f(*args, **kwargs) + + model = kwargs.get("model") + streaming = kwargs.get("stream") + + span = sentry_sdk.start_span( + op=consts.OP.HUGGINGFACE_HUB_CHAT_COMPLETIONS_CREATE, + description="Text Generation", + ) + span.__enter__() + try: + res = f(*args, **kwargs) + except Exception as e: + _capture_exception(e) + span.__exit__(None, None, None) + raise e from None + + integration = sentry_sdk.get_client().get_integration(HuggingfaceHubIntegration) + + with capture_internal_exceptions(): + if should_send_default_pii() and integration.include_prompts: + set_data_normalized(span, SPANDATA.AI_INPUT_MESSAGES, prompt) + + set_data_normalized(span, SPANDATA.AI_MODEL_ID, model) + set_data_normalized(span, SPANDATA.AI_STREAMING, streaming) + + if isinstance(res, str): + if should_send_default_pii() and integration.include_prompts: + set_data_normalized( + span, + "ai.responses", + [res], + ) + span.__exit__(None, None, None) + return res + + if isinstance(res, TextGenerationOutput): + if should_send_default_pii() and integration.include_prompts: + set_data_normalized( + span, + "ai.responses", + [res.generated_text], + ) + if res.details is not None and res.details.generated_tokens > 0: + record_token_usage(span, total_tokens=res.details.generated_tokens) + span.__exit__(None, None, None) + return res + + if not isinstance(res, Iterable): + # we only know how to deal with strings and iterables, ignore + set_data_normalized(span, "unknown_response", True) + span.__exit__(None, None, None) + return res + + if kwargs.get("details", False): + # res is Iterable[TextGenerationStreamOutput] + def new_details_iterator(): + # type: () -> Iterable[ChatCompletionStreamOutput] + with capture_internal_exceptions(): + tokens_used = 0 + data_buf: list[str] = [] + for x in res: + if hasattr(x, "token") and hasattr(x.token, "text"): + data_buf.append(x.token.text) + if hasattr(x, "details") and hasattr( + x.details, "generated_tokens" + ): + tokens_used = x.details.generated_tokens + yield x + if ( + len(data_buf) > 0 + and should_send_default_pii() + and integration.include_prompts + ): + set_data_normalized( + span, SPANDATA.AI_RESPONSES, "".join(data_buf) + ) + if tokens_used > 0: + record_token_usage(span, total_tokens=tokens_used) + span.__exit__(None, None, None) + + return new_details_iterator() + else: + # res is Iterable[str] + + def new_iterator(): + # type: () -> Iterable[str] + data_buf: list[str] = [] + with capture_internal_exceptions(): + for s in res: + if isinstance(s, str): + data_buf.append(s) + yield s + if ( + len(data_buf) > 0 + and should_send_default_pii() + and integration.include_prompts + ): + set_data_normalized( + span, SPANDATA.AI_RESPONSES, "".join(data_buf) + ) + span.__exit__(None, None, None) + + return new_iterator() + + return new_text_generation diff --git a/sentry_sdk/integrations/langchain.py b/sentry_sdk/integrations/langchain.py index 35e955b958..c559870a86 100644 --- a/sentry_sdk/integrations/langchain.py +++ b/sentry_sdk/integrations/langchain.py @@ -63,7 +63,7 @@ def count_tokens(s): # To avoid double collecting tokens, we do *not* measure # token counts for models for which we have an explicit integration -NO_COLLECT_TOKEN_MODELS = ["openai-chat"] +NO_COLLECT_TOKEN_MODELS = ["openai-chat"] # TODO add huggingface and anthropic class LangchainIntegration(Integration): diff --git a/setup.py b/setup.py index e10fe624e1..39934c8aae 100644 --- a/setup.py +++ b/setup.py @@ -60,6 +60,7 @@ def get_file_text(file_name): "grpcio": ["grpcio>=1.21.1"], "httpx": ["httpx>=0.16.0"], "huey": ["huey>=2"], + "huggingface_hub": ["huggingface_hub>=0.22"], "langchain": ["langchain>=0.0.210"], "loguru": ["loguru>=0.5"], "openai": ["openai>=1.0.0", "tiktoken>=0.3.0"], diff --git a/tests/integrations/huggingface_hub/__init__.py b/tests/integrations/huggingface_hub/__init__.py new file mode 100644 index 0000000000..fe1fa0af50 --- /dev/null +++ b/tests/integrations/huggingface_hub/__init__.py @@ -0,0 +1,3 @@ +import pytest + +pytest.importorskip("huggingface_hub") diff --git a/tests/integrations/huggingface_hub/test_huggingface_hub.py b/tests/integrations/huggingface_hub/test_huggingface_hub.py new file mode 100644 index 0000000000..062bd4fb31 --- /dev/null +++ b/tests/integrations/huggingface_hub/test_huggingface_hub.py @@ -0,0 +1,163 @@ +import itertools +import json + +import pytest +from huggingface_hub import ( + InferenceClient, + TextGenerationOutput, + TextGenerationOutputDetails, + TextGenerationStreamOutput, + TextGenerationOutputToken, + TextGenerationStreamDetails, +) +from huggingface_hub.errors import OverloadedError + +from sentry_sdk import start_transaction +from sentry_sdk.integrations.huggingface_hub import HuggingfaceHubIntegration + +from unittest import mock # python 3.3 and above + + +@pytest.mark.parametrize( + "send_default_pii, include_prompts, details_arg", + itertools.product([True, False], repeat=3), +) +def test_nonstreaming_chat_completion( + sentry_init, capture_events, send_default_pii, include_prompts, details_arg +): + sentry_init( + integrations=[HuggingfaceHubIntegration(include_prompts=include_prompts)], + traces_sample_rate=1.0, + send_default_pii=send_default_pii, + ) + events = capture_events() + + client = InferenceClient("some-model") + if details_arg: + client.post = mock.Mock( + return_value=json.dumps( + [ + TextGenerationOutput( + generated_text="the model response", + details=TextGenerationOutputDetails( + finish_reason="TextGenerationFinishReason", + generated_tokens=10, + prefill=[], + tokens=[], # not needed for integration + ), + ) + ] + ).encode("utf-8") + ) + else: + client.post = mock.Mock( + return_value=b'[{"generated_text": "the model response"}]' + ) + with start_transaction(name="huggingface_hub tx"): + response = client.text_generation( + prompt="hello", + details=details_arg, + stream=False, + ) + if details_arg: + assert response.generated_text == "the model response" + else: + assert response == "the model response" + tx = events[0] + assert tx["type"] == "transaction" + span = tx["spans"][0] + assert span["op"] == "ai.chat_completions.create.huggingface_hub" + + if send_default_pii and include_prompts: + assert "hello" in span["data"]["ai.input_messages"] + assert "the model response" in span["data"]["ai.responses"] + else: + assert "ai.input_messages" not in span["data"] + assert "ai.responses" not in span["data"] + + if details_arg: + assert span["measurements"]["ai_total_tokens_used"]["value"] == 10 + + +@pytest.mark.parametrize( + "send_default_pii, include_prompts, details_arg", + itertools.product([True, False], repeat=3), +) +def test_streaming_chat_completion( + sentry_init, capture_events, send_default_pii, include_prompts, details_arg +): + sentry_init( + integrations=[HuggingfaceHubIntegration(include_prompts=include_prompts)], + traces_sample_rate=1.0, + send_default_pii=send_default_pii, + ) + events = capture_events() + + client = InferenceClient("some-model") + client.post = mock.Mock( + return_value=[ + b"data:" + + json.dumps( + TextGenerationStreamOutput( + token=TextGenerationOutputToken( + id=1, special=False, text="the model " + ), + ), + ).encode("utf-8"), + b"data:" + + json.dumps( + TextGenerationStreamOutput( + token=TextGenerationOutputToken( + id=2, special=False, text="response" + ), + details=TextGenerationStreamDetails( + finish_reason="length", + generated_tokens=10, + seed=0, + ), + ) + ).encode("utf-8"), + ] + ) + with start_transaction(name="huggingface_hub tx"): + response = list( + client.text_generation( + prompt="hello", + details=details_arg, + stream=True, + ) + ) + assert len(response) == 2 + print(response) + if details_arg: + assert response[0].token.text + response[1].token.text == "the model response" + else: + assert response[0] + response[1] == "the model response" + + tx = events[0] + assert tx["type"] == "transaction" + span = tx["spans"][0] + assert span["op"] == "ai.chat_completions.create.huggingface_hub" + + if send_default_pii and include_prompts: + assert "hello" in span["data"]["ai.input_messages"] + assert "the model response" in span["data"]["ai.responses"] + else: + assert "ai.input_messages" not in span["data"] + assert "ai.responses" not in span["data"] + + if details_arg: + assert span["measurements"]["ai_total_tokens_used"]["value"] == 10 + + +def test_bad_chat_completion(sentry_init, capture_events): + sentry_init(integrations=[HuggingfaceHubIntegration()], traces_sample_rate=1.0) + events = capture_events() + + client = InferenceClient("some-model") + client.post = mock.Mock(side_effect=OverloadedError("The server is overloaded")) + with pytest.raises(OverloadedError): + client.text_generation(prompt="hello") + + (event,) = events + assert event["level"] == "error" diff --git a/tox.ini b/tox.ini index 47651c0faf..f1bc0e7a5e 100644 --- a/tox.ini +++ b/tox.ini @@ -144,6 +144,9 @@ envlist = {py3.6,py3.11,py3.12}-huey-v{2.0} {py3.6,py3.11,py3.12}-huey-latest + # Huggingface Hub + {py3.9,py3.11,py3.12}-huggingface_hub-{v0.22,latest} + # Langchain {py3.9,py3.11,py3.12}-langchain-0.1 {py3.9,py3.11,py3.12}-langchain-latest @@ -446,6 +449,10 @@ deps = huey-v2.0: huey~=2.0.0 huey-latest: huey + # Huggingface Hub + huggingface_hub-v0.22: huggingface_hub~=0.22.2 + huggingface_hub-latest: huggingface_hub + # Langchain langchain: openai~=1.0.0 langchain-0.1: langchain~=0.1.11 @@ -622,6 +629,7 @@ setenv = graphene: TESTPATH=tests/integrations/graphene httpx: TESTPATH=tests/integrations/httpx huey: TESTPATH=tests/integrations/huey + huggingface_hub: TESTPATH=tests/integrations/huggingface_hub langchain: TESTPATH=tests/integrations/langchain loguru: TESTPATH=tests/integrations/loguru openai: TESTPATH=tests/integrations/openai