Fix llama.cpp integration tests
dtiarks committed Feb 6, 2024
1 parent 34978bb commit 604e6b5
Showing 1 changed file with 51 additions and 57 deletions.
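
In outline: each test previously took a `model_download` fixture and built its own model with `llamacpp(TEST_MODEL, "cpu")`; this commit moves construction into a single session-scoped fixture that every test receives and resets. A minimal sketch of the pattern, condensed from the diff below (`test_example` is a hypothetical name for illustration):

# Session-scoped fixture: fetch the GGUF weights and build the model once per
# test session; pytest hands the same instance to every test that asks for it.
@pytest.fixture(scope="session")
def model(tmp_path_factory):
    tmp_path_factory.mktemp("./llama-test-model")
    hf_hub_download(
        repo_id="M4-ai/TinyMistral-248M-v2-Instruct-GGUF",
        filename="TinyMistral-248M-v2-Instruct.Q4_K_M.gguf",
    )
    return llamacpp(TEST_MODEL, "cpu")


# Tests take the shared fixture and clear any state left by the previous test.
def test_example(model):
    model.reset()
    ...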
108 changes: 51 additions & 57 deletions tests/generate/test_integration_llamacpp.py
@@ -15,7 +15,7 @@
 
 
 @pytest.fixture(scope="session")
-def model_download(tmp_path_factory):
+def model(tmp_path_factory):
     tmp_path_factory.mktemp("./llama-test-model")
     hf_hub_download(
         repo_id="M4-ai/TinyMistral-248M-v2-Instruct-GGUF",
@@ -24,9 +24,12 @@ def model_download(tmp_path_factory):
         filename="TinyMistral-248M-v2-Instruct.Q4_K_M.gguf",
     )
 
+    model = llamacpp(TEST_MODEL, "cpu")
+    return model
+
 
-def test_llamacpp_integration_text(model_download):
-    model = llamacpp(TEST_MODEL, "cpu")
+def test_llamacpp_integration_text(model):
+    model.reset()
     sequence = generate.text(model)(
         "<|im_start|>user\nWrite a short sentence<|im_end|>\n<|im_start|>assistant\n"
     )
@@ -49,8 +52,8 @@ def test_llamacpp_integration_text(model_download):
     assert isinstance(sequence[0], str)
 
 
-def test_llamacpp_integration_streaming(model_download):
-    model = llamacpp(TEST_MODEL, "cpu")
+def test_llamacpp_integration_streaming(model):
+    model.reset()
     sequence = generate.text(model).stream(
         "<|im_start|>user\nWrite a short sentence<|im_end|>\n<|im_start|>assistant\n",
         max_tokens=10,
@@ -72,17 +75,17 @@ def test_llamacpp_integration_streaming(model_download):
     assert isinstance(tokens[0], str)
 
 
-def test_llamacpp_integration_text_stop(model_download):
-    model = llamacpp(TEST_MODEL, "cpu")
+def test_llamacpp_integration_text_stop(model):
+    model.reset()
     prompt = (
         "<|im_start|>user\nWrite a short sentence<|im_end|>\n<|im_start|>assistant\n"
     )
     sequence = generate.text(model)(prompt, stop_at="a")
     assert sequence[len(prompt) :].find("a") == -1
 
 
-def test_llamacpp_various_regexes(model_download):
-    model = llamacpp(TEST_MODEL, "cpu")
+def test_llamacpp_various_regexes(model):
+    model.reset()
     prompt = (
         "<|im_start|>user\nWrite an email address<|im_end|>\n<|im_start|>assistant\n"
     )
@@ -94,8 +97,8 @@ def test_llamacpp_various_regexes(model_download):
     assert re.fullmatch(regex_str, sequence) is not None
 
 
-def test_llamacpp_various_regexes_prompt_list(model_download):
-    model = llamacpp(TEST_MODEL, "cpu")
+def test_llamacpp_various_regexes_prompt_list(model):
+    model.reset()
     prompt = (
         "<|im_start|>user\nWrite an email address<|im_end|>\n<|im_start|>assistant\n"
     )
@@ -108,8 +111,8 @@ def test_llamacpp_various_regexes_prompt_list(model_download):
     assert re.fullmatch(regex_str, sequence[1]) is not None
 
 
-def test_llamacpp_integration_integer(model_download):
-    model = llamacpp(TEST_MODEL, "cpu")
+def test_llamacpp_integration_integer(model):
+    model.reset()
     prompt = (
         "<|im_start|>user\nWrite a short sentence<|im_end|>\n<|im_start|>assistant\n"
     )
@@ -119,8 +122,8 @@ def test_llamacpp_integration_integer(model_download):
     int(sequence)
 
 
-def test_llamacpp_integration_integer_array(model_download):
-    model = llamacpp(TEST_MODEL, "cpu")
+def test_llamacpp_integration_integer_array(model):
+    model.reset()
     prompts = ["Give me a number", "And another one"]
     sequence = generate.format(model, int)(prompts, max_tokens=10)
     assert isinstance(sequence, list)
@@ -129,8 +132,8 @@ def test_llamacpp_integration_integer_array(model_download):
     int(sequence[1])
 
 
-def test_llamacpp_integration_float(model_download):
-    model = llamacpp(TEST_MODEL, "cpu")
+def test_llamacpp_integration_float(model):
+    model.reset()
     prompt = (
         "<|im_start|>user\nWrite a short sentence<|im_end|>\n<|im_start|>assistant\n"
     )
@@ -140,8 +143,8 @@ def test_llamacpp_integration_float(model_download):
     float(sequence)
 
 
-def test_llamacpp_integration_bool(model_download):
-    model = llamacpp(TEST_MODEL, "cpu")
+def test_llamacpp_integration_bool(model):
+    model.reset()
     prompt = (
         "<|im_start|>user\nIs this True or False?<|im_end|>\n<|im_start|>assistant\n"
     )
@@ -151,8 +154,8 @@ def test_llamacpp_integration_bool(model_download):
     bool(sequence)
 
 
-def test_llamacpp_integration_date(model_download):
-    model = llamacpp(TEST_MODEL, "cpu")
+def test_llamacpp_integration_date(model):
+    model.reset()
     prompt = (
         "<|im_start|>user\nWhat day is it today?<|im_end|>\n<|im_start|>assistant\n"
     )
@@ -162,26 +165,26 @@ def test_llamacpp_integration_date(model_download):
     datetime.datetime.strptime(sequence, "%Y-%m-%d")
 
 
-def test_llamacpp_integration_time(model_download):
-    model = llamacpp(TEST_MODEL, "cpu")
+def test_llamacpp_integration_time(model):
+    model.reset()
     prompt = "<|im_start|>user\nWhat time is it?<|im_end|>\n<|im_start|>assistant\n"
     sequence = generate.format(model, datetime.time)(prompt, max_tokens=10)
 
     assert sequence != ""
     datetime.datetime.strptime(sequence, "%H:%M:%S")
 
 
-def test_llamacpp_integration_datetime(model_download):
-    model = llamacpp(TEST_MODEL, "cpu")
+def test_llamacpp_integration_datetime(model):
+    model.reset()
     prompt = "<|im_start|>user\nWhat time is it?<|im_end|>\n<|im_start|>assistant\n"
     sequence = generate.format(model, datetime.datetime)(prompt, max_tokens=20)
 
     assert sequence != 0
     datetime.datetime.strptime(sequence, "%Y-%m-%d %H:%M:%S")
 
 
-def test_llamacpp_integration_choice(model_download):
-    model = llamacpp(TEST_MODEL, "cpu")
+def test_llamacpp_integration_choice(model):
+    model.reset()
     prompt = (
         "<|im_start|>user\nWrite a short sentence<|im_end|>\n<|im_start|>assistant\n"
     )
@@ -190,8 +193,7 @@ def test_llamacpp_integration_choice(model_download):
     assert sequence == "test" or sequence == "choice"
 
 
-def test_llamacpp_json_basic(model_download):
-    model = llamacpp(TEST_MODEL, "cpu")
+def test_llamacpp_json_basic(model):
     model.reset()
     prompt = "<|im_start|>user\nOutput some JSON<|im_end|>\n<|im_start|>assistant\n"
 
@@ -212,8 +214,7 @@ class Spam(BaseModel):
     assert len(result.spam) <= 10
 
 
-def test_llamacpp_json_schema(model_download):
-    model = llamacpp(TEST_MODEL, "cpu")
+def test_llamacpp_json_schema(model):
     model.reset()
     prompt = "<|im_start|>user\nOutput some JSON<|im_end|>\n<|im_start|>assistant\n"
 
@@ -236,8 +237,7 @@ def test_llamacpp_json_schema(model_download):
     assert isinstance(result["bar"], str)
 
 
-def test_llamacpp_json_batch(model_download):
-    model = llamacpp(TEST_MODEL, "cpu")
+def test_llamacpp_json_batch(model):
     model.reset()
     prompts = [
         "<|im_start|>user\nOutput a valid JSON object. Only use alpha numeric characters as keys.<|im_end|>\n<|im_start|>assistant\n",
@@ -257,8 +257,7 @@ class Spam(BaseModel):
     assert isinstance(result[1], BaseModel)
 
 
-def test_llamacpp_json_str_enum(model_download):
-    model = llamacpp(TEST_MODEL, "cpu")
+def test_llamacpp_json_str_enum(model):
     model.reset()
     prompt = "<|im_start|>user\nOutput a valid JSON object. Only use alpha numeric characters as keys.<|im_end|>\n<|im_start|>assistant\n"
 
@@ -279,48 +278,44 @@ class User(BaseModel):
     assert result.name in ["John", "Marc", "Michel"]
 
 
-def test_llamacpp_json_int_enum(model_download):
-    model = llamacpp(TEST_MODEL, "cpu")
+def test_llamacpp_json_array(model):
     model.reset()
     prompt = "<|im_start|>user\nOutput a valid JSON object. Only use alpha numeric characters as keys.<|im_end|>\n<|im_start|>assistant\n"
 
-    class Id(int, Enum):
-        one = 1
-        two = 2
-
     class User(BaseModel):
-        id: Id
+        id: int
+        value: List[float]
 
     result = generate.json(model, User)(
         prompt, max_tokens=500, kw_args={"temperature": 0.0}
     )
     assert isinstance(result, BaseModel)
     assert isinstance(result.id, int)
-    assert result.id in [1, 2]
+    assert isinstance(result.value, list)
+    for value in result.value:
+        assert isinstance(value, float) or isinstance(value, int)
 
 
-def test_llamacpp_json_array(model_download):
-    model = llamacpp(TEST_MODEL, "cpu")
+def test_llamacpp_json_int_enum(model):
    model.reset()
     prompt = "<|im_start|>user\nOutput a valid JSON object. Only use alpha numeric characters as keys.<|im_end|>\n<|im_start|>assistant\n"
 
+    class Id(int, Enum):
+        one = 1
+        two = 2
+
     class User(BaseModel):
-        id: int
-        value: List[float]
+        id: Id
 
     result = generate.json(model, User)(
         prompt, max_tokens=500, kw_args={"temperature": 0.0}
     )
     assert isinstance(result, BaseModel)
     assert isinstance(result.id, int)
-    assert isinstance(result.value, list)
-    for value in result.value:
-        assert isinstance(value, float) or isinstance(value, int)
+    assert result.id in [1, 2]
 
 
 @pytest.mark.xfail(reason="The implementation of `anyOf` is incorrect")
-def test_llamacpp_json_union(model_download):
-    model = llamacpp(TEST_MODEL, "cpu")
+def test_llamacpp_json_union(model):
     model.reset()
     prompt = "<|im_start|>user\nOutput some JSON<|im_end|>\n<|im_start|>assistant\n"
 
@@ -339,8 +334,7 @@ class Spam(BaseModel):
     )
 
 
-def test_llamacpp_json_function(model_download):
-    model = llamacpp(TEST_MODEL, "cpu")
+def test_llamacpp_json_function(model):
     model.reset()
     prompt = "<|im_start|>user\nOutput arguments for the function<|im_end|>\n<|im_start|>assistant\n"
 
@@ -354,8 +348,8 @@ def function(foo: int, bar: List[int]):
     assert isinstance(function(**sequence), int)
 
 
-def test_llamacpp_reduced_vocabulary_caching(model_download):
-    model = llamacpp(TEST_MODEL, "cpu")
+def test_llamacpp_reduced_vocabulary_caching(model):
+    model.reset()
     model2 = llamacpp(TEST_MODEL, "cpu")
 
     # TODO: We might actually want only one copy of a given tokenizer.

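To reproduce locally, running this file alone with pytest should suffice (assuming a development install of the project with its llama.cpp dependencies):

pytest tests/generate/test_integration_llamacpp.py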