Skip to content

Commit

Permalink
now get emitter through context var (#100)
Browse files Browse the repository at this point in the history
* now get emitter through context var

* remove unused parameter

* force langchain version

* make run_sync use main event loop

* typing and error

* param spec mport

* use uvicorn event loop
  • Loading branch information
willydouhard committed Jun 24, 2023
1 parent fd224e7 commit 07df51e
Show file tree
Hide file tree
Showing 16 changed files with 89 additions and 130 deletions.
2 changes: 1 addition & 1 deletion cypress/e2e/sdk_availability/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import chainlit as cl
from chainlit.sync import make_async, run_sync
from chainlit.emitter import get_emitter
from chainlit.context import get_emitter


async def async_function_from_sync():
Expand Down
9 changes: 6 additions & 3 deletions src/chainlit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from chainlit.user_session import user_session
from chainlit.sync import run_sync, make_async
from chainlit.cache import cache
from chainlit.context import emitter_var

if LANGCHAIN_INSTALLED:
from chainlit.lc.callbacks import (
Expand All @@ -43,7 +44,7 @@ def wrap_user_function(user_function: Callable, with_task=False) -> Callable:
Callable: The wrapped function.
"""

async def wrapper(*args, __chainlit_emitter__: ChainlitEmitter):
async def wrapper(*args):
# Get the parameter names of the user-defined function
user_function_params = list(inspect.signature(user_function).parameters.keys())

Expand All @@ -52,8 +53,10 @@ async def wrapper(*args, __chainlit_emitter__: ChainlitEmitter):
param_name: arg for param_name, arg in zip(user_function_params, args)
}

emitter = emitter_var.get()

if with_task:
await __chainlit_emitter__.task_start()
await emitter.task_start()

try:
# Call the user-defined function with the arguments
Expand All @@ -68,7 +71,7 @@ async def wrapper(*args, __chainlit_emitter__: ChainlitEmitter):
await ErrorMessage(content=str(e), author="Error").send()
finally:
if with_task:
await __chainlit_emitter__.task_end()
await emitter.task_end()

return wrapper

Expand Down
6 changes: 2 additions & 4 deletions src/chainlit/action.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pydantic.dataclasses import dataclass
from dataclasses_json import dataclass_json

from chainlit.emitter import get_emit_fn
from chainlit.context import get_emitter
from chainlit.telemetry import trace_event


Expand All @@ -21,9 +21,7 @@ class Action:

def __post_init__(self) -> None:
trace_event(f"init {self.__class__.__name__}")
self.emit = get_emit_fn()
if not self.emit:
raise RuntimeError("Action should be instantiated in a Chainlit context")
self.emit = get_emitter().emit

async def send(self, for_id: str):
trace_event(f"send {self.__class__.__name__}")
Expand Down
14 changes: 1 addition & 13 deletions src/chainlit/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,6 @@
import os
import sys
import uvicorn
import asyncio
import nest_asyncio

nest_asyncio.apply()

from chainlit.config import (
config,
Expand Down Expand Up @@ -52,15 +48,7 @@ def run_chainlit(target: str):

log_level = "debug" if config.run.debug else "error"

# Start the server
async def start():
config = uvicorn.Config(app, host=host, port=port, log_level=log_level)
server = uvicorn.Server(config)
await server.serve()

# Run the asyncio event loop instead of uvloop to enable re entrance
asyncio.run(start())
# uvicorn.run(app, host=host, port=port)
uvicorn.run(app, host=host, port=port, log_level=log_level)


# Define the "run" command for Chainlit CLI
Expand Down
29 changes: 29 additions & 0 deletions src/chainlit/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import contextvars
from typing import TYPE_CHECKING
from asyncio import AbstractEventLoop

if TYPE_CHECKING:
from chainlit.emitter import ChainlitEmitter


class ChainlitContextException(Exception):
def __init__(self, msg="Chainlit context not found", *args, **kwargs):
super().__init__(msg, *args, **kwargs)


emitter_var = contextvars.ContextVar("emitter")
loop_var = contextvars.ContextVar("loop")


def get_emitter() -> "ChainlitEmitter":
try:
return emitter_var.get()
except LookupError:
raise ChainlitContextException()


def get_loop() -> AbstractEventLoop:
try:
return loop_var.get()
except LookupError:
raise ChainlitContextException()
5 changes: 2 additions & 3 deletions src/chainlit/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import aiofiles
from io import BytesIO

from chainlit.emitter import get_emitter, BaseClient
from chainlit.context import get_emitter
from chainlit.client import BaseClient
from chainlit.telemetry import trace_event
from chainlit.types import ElementType, ElementDisplay, ElementSize

Expand Down Expand Up @@ -41,8 +42,6 @@ class Element:
def __post_init__(self) -> None:
trace_event(f"init {self.__class__.__name__}")
self.emitter = get_emitter()
if not self.emitter:
raise RuntimeError("Element should be instantiated in a Chainlit context")

if not self.url and not self.path and not self.content:
raise ValueError("Must provide url, path or content to instantiate element")
Expand Down
30 changes: 1 addition & 29 deletions src/chainlit/emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from chainlit.session import Session
from chainlit.types import AskSpec
from chainlit.client import BaseClient
from chainlit.context import emitter_var
from socketio.exceptions import TimeoutError
import inspect


class ChainlitEmitter:
Expand Down Expand Up @@ -129,31 +129,3 @@ def stream_start(self, msg_dict: Dict):
def send_token(self, id: Union[str, int], token: str):
"""Send a message token to the UI."""
return self.emit("stream_token", {"id": id, "token": token})


def get_emitter() -> Union[ChainlitEmitter, None]:
"""
Get the Chainlit Emitter instance from the current call stack.
This unusual approach is necessary because:
- we need to get the right Emitter instance with the right websocket connection
- to preserve a lean developer experience, we do not pass the Emitter instance to every function call
What happens is that we set __chainlit_emitter__ in the local variables when we receive a websocket message.
Then we can retrieve it from the call stack when we need it, even if the developer's code has no idea about it.
"""
attr = "__chainlit_emitter__"
candidates = [i[0].f_locals.get(attr) for i in inspect.stack()]
emitter = None
for candidate in candidates:
if candidate:
emitter = candidate
break

return emitter


def get_emit_fn():
emitter = get_emitter()
if emitter:
return emitter.emit
return None
1 change: 1 addition & 0 deletions src/chainlit/lc/agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any
from chainlit.lc.callbacks import ChainlitCallbackHandler, AsyncChainlitCallbackHandler
from chainlit.sync import make_async
from chainlit.context import emitter_var


async def run_langchain_agent(agent: Any, input_str: str, use_async: bool):
Expand Down
27 changes: 6 additions & 21 deletions src/chainlit/lc/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
BaseMessage,
LLMResult,
)
from chainlit.emitter import get_emitter, ChainlitEmitter
from chainlit.emitter import ChainlitEmitter
from chainlit.context import get_emitter
from chainlit.message import Message, ErrorMessage
from chainlit.config import config
from chainlit.types import LLMSettings
Expand Down Expand Up @@ -107,14 +108,10 @@ def start_stream(self):
return

if config.code.lc_rename:
author = run_sync(
config.code.lc_rename(author, __chainlit_emitter__=self.emitter)
)
author = run_sync(config.code.lc_rename(author))

self.pop_prompt()

__chainlit_emitter__ = self.emitter

streamed_message = Message(
author=author,
indent=indent,
Expand All @@ -135,11 +132,7 @@ def add_message(self, message, prompt: str = None, error=False):
return

if config.code.lc_rename:
author = run_sync(
config.code.lc_rename(author, __chainlit_emitter__=self.emitter)
)

__chainlit_emitter__ = self.emitter
author = run_sync(config.code.lc_rename(author))

if error:
run_sync(ErrorMessage(author=author, content=message).send())
Expand Down Expand Up @@ -267,14 +260,10 @@ async def start_stream(self):
return

if config.code.lc_rename:
author = await config.code.lc_rename(
author, __chainlit_emitter__=self.emitter
)
author = await config.code.lc_rename(author)

self.pop_prompt()

__chainlit_emitter__ = self.emitter

streamed_message = Message(
author=author,
indent=indent,
Expand All @@ -295,11 +284,7 @@ async def add_message(self, message, prompt: str = None, error=False):
return

if config.code.lc_rename:
author = await config.code.lc_rename(
author, __chainlit_emitter__=self.emitter
)

__chainlit_emitter__ = self.emitter
author = await config.code.lc_rename(author)

if error:
await ErrorMessage(author=author, content=message).send()
Expand Down
4 changes: 1 addition & 3 deletions src/chainlit/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import asyncio

from chainlit.telemetry import trace_event
from chainlit.emitter import get_emitter
from chainlit.context import get_emitter
from chainlit.config import config
from chainlit.types import (
LLMSettings,
Expand Down Expand Up @@ -34,8 +34,6 @@ def __post_init__(self) -> None:
self.temp_id = uuid.uuid4().hex
self.created_at = current_milli_time()
self.emitter = get_emitter()
if not self.emitter:
raise RuntimeError("Message should be instantiated in a Chainlit context")

@abstractmethod
def to_dict(self):
Expand Down
Loading

0 comments on commit 07df51e

Please sign in to comment.