diff --git a/src/azarrot/frontends/openai_frontend.py b/src/azarrot/frontends/openai_frontend.py index f4e57f6..7ea8962 100644 --- a/src/azarrot/frontends/openai_frontend.py +++ b/src/azarrot/frontends/openai_frontend.py @@ -179,20 +179,32 @@ def __to_backend_generation_messages( return result def __to_backend_tool_parameters( - self, tool_parameters: list[dict[str, Any]] | None + self, tool_parameters: dict[str, Any] | None ) -> list[LocalizedToolParameter]: if tool_parameters is None: return [] - return [ - LocalizedToolParameter( - name=parameter["name"], - description=parameter.get("description"), - type=parameter["type"], - required=parameter.get("required", False), - ) - for parameter in tool_parameters - ] + param_type = tool_parameters["type"] + + if param_type != "object": + raise ValueError(f"Unsupported tool parameter type {param_type}") + + required_params = tool_parameters.get("required", []) + + params = [] + + if "properties" in tool_parameters: + for k, v in tool_parameters["properties"].items(): + p = LocalizedToolParameter( + name=k, + description=v.get("description"), + type=v.get("type"), + required=k in required_params + ) + + params.append(p) + + return params def __to_backend_tools_info( self, tools_info: list[ToolInfo] | None, tools_choice: Literal["none", "auto", "required"] | ToolChoice | None diff --git a/src/azarrot/frontends/openai_support/openai_data.py b/src/azarrot/frontends/openai_support/openai_data.py index 29f304f..67e4b8b 100644 --- a/src/azarrot/frontends/openai_support/openai_data.py +++ b/src/azarrot/frontends/openai_support/openai_data.py @@ -61,7 +61,7 @@ class ChatCompletionStreamOptions(BaseModel): class ToolFunctionInfo(BaseModel): description: str | None = None name: str - parameters: list[dict[str, Any]] | None = None + parameters: dict[str, Any] | None = None class ToolInfo(BaseModel):