Skip to content

Commit

Permalink
Make prompt validation opt-in (#11973)
Browse files Browse the repository at this point in the history
By default replace input_variables with the correct value

<!-- Thank you for contributing to LangChain!

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes (if applicable),
  - **Dependencies:** any dependencies required for this change,
- **Tag maintainer:** for a quicker response, tag the relevant
maintainer (see below),
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/langchain-ai/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in `docs/extras`
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
  • Loading branch information
nfcampos committed Oct 18, 2023
2 parents 9bc7e18 + 653cf56 commit 6bd9c1d
Show file tree
Hide file tree
Showing 11 changed files with 192 additions and 34 deletions.
15 changes: 15 additions & 0 deletions libs/langchain/langchain/prompts/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import warnings
from abc import ABC
from string import Formatter
from typing import Any, Callable, Dict, List, Literal, Set

from langchain.schema.messages import BaseMessage, HumanMessage
Expand Down Expand Up @@ -99,6 +100,20 @@ def check_valid_template(
)


def get_template_variables(template: str, template_format: str) -> List[str]:
if template_format == "jinja2":
# Get the variables for the template
input_variables = _get_jinja2_variables_from_template(template)
elif template_format == "f-string":
input_variables = {
v for _, v, _, _ in Formatter().parse(template) if v is not None
}
else:
raise ValueError(f"Unsupported template format: {template_format}")

return sorted(input_variables)


class StringPromptValue(PromptValue):
"""String prompt value."""

Expand Down
4 changes: 3 additions & 1 deletion libs/langchain/langchain/prompts/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,8 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
"""List of input variables in template messages. Used for validation."""
messages: List[MessageLike]
"""List of messages consisting of either message prompt templates or messages."""
validate_template: bool = False
"""Whether or not to try validating the template."""

def __add__(self, other: Any) -> ChatPromptTemplate:
"""Combine two prompt templates.
Expand Down Expand Up @@ -432,7 +434,7 @@ def validate_input_variables(cls, values: dict) -> dict:
input_types[message.variable_name] = List[AnyMessage]
if "partial_variables" in values:
input_vars = input_vars - set(values["partial_variables"])
if "input_variables" in values:
if "input_variables" in values and values.get("validate_template"):
if input_vars != set(values["input_variables"]):
raise ValueError(
"Got mismatched input_variables. "
Expand Down
15 changes: 12 additions & 3 deletions libs/langchain/langchain/prompts/few_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
from __future__ import annotations

from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Union

from langchain.prompts.base import (
DEFAULT_FORMATTER_MAPPING,
StringPromptTemplate,
check_valid_template,
get_template_variables,
)
from langchain.prompts.chat import BaseChatPromptTemplate, BaseMessagePromptTemplate
from langchain.prompts.example_selector.base import BaseExampleSelector
Expand Down Expand Up @@ -77,7 +78,7 @@ def is_lc_serializable(cls) -> bool:
"""Return whether or not the class is serializable."""
return False

validate_template: bool = True
validate_template: bool = False
"""Whether or not to try validating the template."""

input_variables: List[str]
Expand All @@ -95,7 +96,7 @@ def is_lc_serializable(cls) -> bool:
prefix: str = ""
"""A prompt template string to put before the examples."""

template_format: str = "f-string"
template_format: Union[Literal["f-string"], Literal["jinja2"]] = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""

@root_validator()
Expand All @@ -107,6 +108,14 @@ def template_is_valid(cls, values: Dict) -> Dict:
values["template_format"],
values["input_variables"] + list(values["partial_variables"]),
)
elif values.get("template_format"):
values["input_variables"] = [
var
for var in get_template_variables(
values["prefix"] + values["suffix"], values["template_format"]
)
if var not in values["partial_variables"]
]
return values

class Config:
Expand Down
8 changes: 7 additions & 1 deletion libs/langchain/langchain/prompts/few_shot_with_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
template_format: str = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""

validate_template: bool = True
validate_template: bool = False
"""Whether or not to try validating the template."""

@root_validator(pre=True)
Expand Down Expand Up @@ -72,6 +72,12 @@ def template_is_valid(cls, values: Dict) -> Dict:
f"Got input_variables={input_variables}, but based on "
f"prefix/suffix expected {expected_input_variables}"
)
else:
values["input_variables"] = sorted(
set(values["suffix"].input_variables)
| set(values["prefix"].input_variables if values["prefix"] else [])
- set(values["partial_variables"])
)
return values

class Config:
Expand Down
33 changes: 16 additions & 17 deletions libs/langchain/langchain/prompts/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@
from __future__ import annotations

from pathlib import Path
from string import Formatter
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Union

from langchain.prompts.base import (
DEFAULT_FORMATTER_MAPPING,
StringPromptTemplate,
_get_jinja2_variables_from_template,
check_valid_template,
get_template_variables,
)
from langchain.pydantic_v1 import root_validator

Expand Down Expand Up @@ -53,10 +52,10 @@ def lc_attributes(self) -> Dict[str, Any]:
template: str
"""The prompt template."""

template_format: str = "f-string"
template_format: Union[Literal["f-string"], Literal["jinja2"]] = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""

validate_template: bool = True
validate_template: bool = False
"""Whether or not to try validating the template."""

def __add__(self, other: Any) -> PromptTemplate:
Expand Down Expand Up @@ -127,6 +126,14 @@ def template_is_valid(cls, values: Dict) -> Dict:
check_valid_template(
values["template"], values["template_format"], all_inputs
)
elif values.get("template_format"):
values["input_variables"] = [
var
for var in get_template_variables(
values["template"], values["template_format"]
)
if var not in values["partial_variables"]
]
return values

@classmethod
Expand Down Expand Up @@ -202,25 +209,17 @@ def from_template(
Returns:
The prompt template loaded from the template.
"""
if template_format == "jinja2":
# Get the variables for the template
input_variables = _get_jinja2_variables_from_template(template)
elif template_format == "f-string":
input_variables = {
v for _, v, _, _ in Formatter().parse(template) if v is not None
}
else:
raise ValueError(f"Unsupported template format: {template_format}")

input_variables = get_template_variables(template, template_format)
_partial_variables = partial_variables or {}

if _partial_variables:
input_variables = {
input_variables = [
var for var in input_variables if var not in _partial_variables
}
]

return cls(
input_variables=sorted(input_variables),
input_variables=input_variables,
template=template,
template_format=template_format,
partial_variables=_partial_variables,
Expand Down
4 changes: 2 additions & 2 deletions libs/langchain/langchain/tools/sql_database/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,11 @@ def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
values["llm_chain"] = LLMChain(
llm=values.get("llm"),
prompt=PromptTemplate(
template=QUERY_CHECKER, input_variables=["query", "dialect"]
template=QUERY_CHECKER, input_variables=["dialect", "query"]
),
)

if values["llm_chain"].prompt.input_variables != ["query", "dialect"]:
if values["llm_chain"].prompt.input_variables != ["dialect", "query"]:
raise ValueError(
"LLM chain for QueryCheckerTool must have input variables ['query', 'dialect']"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,5 @@
"partial_variables": {},
"template": "Given the following question and student answer, provide a correct answer and score the student answer.\nQuestion: {question}\nStudent Answer: {student_answer}\nCorrect Answer:",
"template_format": "f-string",
"validate_template": true,
"_type": "prompt"
}
}
15 changes: 13 additions & 2 deletions libs/langchain/tests/unit_tests/prompts/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,13 +186,24 @@ def test_chat_prompt_template_with_messages() -> None:
def test_chat_invalid_input_variables_extra() -> None:
messages = [HumanMessage(content="foo")]
with pytest.raises(ValueError):
ChatPromptTemplate(messages=messages, input_variables=["foo"])
ChatPromptTemplate(
messages=messages, input_variables=["foo"], validate_template=True
)
assert (
ChatPromptTemplate(messages=messages, input_variables=["foo"]).input_variables
== []
)


def test_chat_invalid_input_variables_missing() -> None:
messages = [HumanMessagePromptTemplate.from_template("{foo}")]
with pytest.raises(ValueError):
ChatPromptTemplate(messages=messages, input_variables=[])
ChatPromptTemplate(
messages=messages, input_variables=[], validate_template=True
)
assert ChatPromptTemplate(
messages=messages, input_variables=[]
).input_variables == ["foo"]


def test_infer_variables() -> None:
Expand Down
48 changes: 48 additions & 0 deletions libs/langchain/tests/unit_tests/prompts/test_few_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,14 @@ def test_prompt_missing_input_variables() -> None:
suffix=template,
examples=[],
example_prompt=EXAMPLE_PROMPT,
validate_template=True,
)
assert FewShotPromptTemplate(
input_variables=[],
suffix=template,
examples=[],
example_prompt=EXAMPLE_PROMPT,
).input_variables == ["foo"]

# Test when missing in prefix
template = "This is a {foo} test."
Expand All @@ -78,7 +85,15 @@ def test_prompt_missing_input_variables() -> None:
examples=[],
prefix=template,
example_prompt=EXAMPLE_PROMPT,
validate_template=True,
)
assert FewShotPromptTemplate(
input_variables=[],
suffix="foo",
examples=[],
prefix=template,
example_prompt=EXAMPLE_PROMPT,
).input_variables == ["foo"]


def test_prompt_extra_input_variables() -> None:
Expand All @@ -91,7 +106,14 @@ def test_prompt_extra_input_variables() -> None:
suffix=template,
examples=[],
example_prompt=EXAMPLE_PROMPT,
validate_template=True,
)
assert FewShotPromptTemplate(
input_variables=input_variables,
suffix=template,
examples=[],
example_prompt=EXAMPLE_PROMPT,
).input_variables == ["foo"]


def test_few_shot_functionality() -> None:
Expand Down Expand Up @@ -248,7 +270,15 @@ def test_prompt_jinja2_missing_input_variables(
examples=example_jinja2_prompt[1],
example_prompt=example_jinja2_prompt[0],
template_format="jinja2",
validate_template=True,
)
assert FewShotPromptTemplate(
input_variables=[],
suffix=suffix,
examples=example_jinja2_prompt[1],
example_prompt=example_jinja2_prompt[0],
template_format="jinja2",
).input_variables == ["bar"]

# Test when missing in prefix
with pytest.warns(UserWarning):
Expand All @@ -259,7 +289,16 @@ def test_prompt_jinja2_missing_input_variables(
examples=example_jinja2_prompt[1],
example_prompt=example_jinja2_prompt[0],
template_format="jinja2",
validate_template=True,
)
assert FewShotPromptTemplate(
input_variables=["bar"],
suffix=suffix,
prefix=prefix,
examples=example_jinja2_prompt[1],
example_prompt=example_jinja2_prompt[0],
template_format="jinja2",
).input_variables == ["bar", "foo"]


@pytest.mark.requires("jinja2")
Expand All @@ -277,7 +316,16 @@ def test_prompt_jinja2_extra_input_variables(
examples=example_jinja2_prompt[1],
example_prompt=example_jinja2_prompt[0],
template_format="jinja2",
validate_template=True,
)
assert FewShotPromptTemplate(
input_variables=["bar", "foo", "extra", "thing"],
suffix=suffix,
prefix=prefix,
examples=example_jinja2_prompt[1],
example_prompt=example_jinja2_prompt[0],
template_format="jinja2",
).input_variables == ["bar", "foo"]


def test_few_shot_chat_message_prompt_template() -> None:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Test few shot prompt template."""

import pytest

from langchain.prompts.few_shot_with_templates import FewShotPromptWithTemplates
from langchain.prompts.prompt import PromptTemplate

Expand Down Expand Up @@ -38,3 +40,37 @@ def test_prompttemplate_prefix_suffix() -> None:
"Now you try to talk about party."
)
assert output == expected_output


def test_prompttemplate_validation() -> None:
"""Test that few shot works when prefix and suffix are PromptTemplates."""
prefix = PromptTemplate(
input_variables=["content"], template="This is a test about {content}."
)
suffix = PromptTemplate(
input_variables=["new_content"],
template="Now you try to talk about {new_content}.",
)

examples = [
{"question": "foo", "answer": "bar"},
{"question": "baz", "answer": "foo"},
]
with pytest.raises(ValueError):
FewShotPromptWithTemplates(
suffix=suffix,
prefix=prefix,
input_variables=[],
examples=examples,
example_prompt=EXAMPLE_PROMPT,
example_separator="\n",
validate_template=True,
)
assert FewShotPromptWithTemplates(
suffix=suffix,
prefix=prefix,
input_variables=[],
examples=examples,
example_prompt=EXAMPLE_PROMPT,
example_separator="\n",
).input_variables == ["content", "new_content"]
Loading

0 comments on commit 6bd9c1d

Please sign in to comment.