Skip to content

Commit

Permalink
More comprehensive TransformersVision test cases
Browse files Browse the repository at this point in the history
Multi-image context, using `apply_chat_template`
  • Loading branch information
parkervg committed Aug 1, 2024
1 parent 6200d40 commit aad0ee9
Showing 1 changed file with 119 additions and 13 deletions.
132 changes: 119 additions & 13 deletions tests/generate/test_integration_transformers_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,22 @@

import pytest
from PIL import Image
from transformers import LlavaForConditionalGeneration
from transformers import AutoProcessor, LlavaForConditionalGeneration

import outlines
from outlines.models.transformers_vision import transformers_vision

IMAGE_URLS = [
"https://upload.wikimedia.org/wikipedia/commons/2/25/Siam_lilacpoint.jpg",
"https://upload.wikimedia.org/wikipedia/commons/7/71/2010-kodiak-bear-1.jpg",
"https://upload.wikimedia.org/wikipedia/commons/b/be/Tamias-rufus-001.jpg",
]


def img_from_url(url):
img_byte_stream = BytesIO(urlopen(url).read())
return Image.open(img_byte_stream).convert("RGB")


@pytest.fixture(scope="session")
def model(tmp_path_factory):
Expand All @@ -19,23 +30,118 @@ def model(tmp_path_factory):


@pytest.fixture(scope="session")
def image(tmp_path_factory):
def img_from_url(url):
img_byte_stream = BytesIO(urlopen(url).read())
return Image.open(img_byte_stream).convert("RGB")
def larger_model_with_processor(tmp_path_factory):
return transformers_vision(
"llava-hf/llava-interleave-qwen-0.5b-hf",
model_class=LlavaForConditionalGeneration,
device="cpu",
)


@pytest.fixture(scope="session")
def processor(tmp_path_factory):
return AutoProcessor.from_pretrained("llava-hf/llava-interleave-qwen-0.5b-hf")


def test_single_image_text_gen(model, processor):
conversation = [
{
"role": "user",
"content": [{"type": "text", "text": "What is this?"}, {"type": "image"}],
},
]
generator = outlines.generate.text(model)
sequence = generator(
processor.apply_chat_template(conversation),
[img_from_url(IMAGE_URLS[0])],
seed=10000,
max_tokens=10,
)
assert isinstance(sequence, str)


def test_multi_image_text_gen(model, processor):
"""If the length of image tags and number of images we pass are > 1 and equal,
we should yield a successful generation.
"""
conversation = [
{
"role": "user",
"content": [
{"type": "text", "text": "What do all these have in common?"},
]
+ [{"type": "image"} for _ in range(len(IMAGE_URLS))],
},
]
generator = outlines.generate.text(model)
sequence = generator(
processor.apply_chat_template(conversation),
[img_from_url(i) for i in IMAGE_URLS],
seed=10000,
max_tokens=10,
)
assert isinstance(sequence, str)


def test_mismatched_image_text_gen(model, processor):
"""If the length of image tags and number of images we pass are unequal,
we should raise an error.
"""
conversation = [
{
"role": "user",
"content": [
{"type": "text", "text": "I'm passing 3 images, but only 1 image tag"},
{"type": "image"},
],
},
]
generator = outlines.generate.text(model)
with pytest.raises(ValueError):
_ = generator(
processor.apply_chat_template(conversation),
[img_from_url(i) for i in IMAGE_URLS],
seed=10000,
max_tokens=10,
)


return img_from_url(
"https://upload.wikimedia.org/wikipedia/commons/2/25/Siam_lilacpoint.jpg"
def test_default_apply_chat_template(larger_model_with_processor):
"""If we pass in a list of dicts in the format that huggingface chat_template expects,
and specify `apply_chat_template=True` when calling the generator,
we should yield a successful generation.
"""
conversation = [
{
"role": "user",
"content": [{"type": "text", "text": "What is this?"}, {"type": "image"}],
},
]
generator = outlines.generate.text(larger_model_with_processor)
sequence = generator(
conversation,
[img_from_url(IMAGE_URLS[0])],
seed=10000,
max_tokens=10,
apply_chat_template=True,
)
assert isinstance(sequence, str)


def test_text_gen(model, image):
assert model.model.device.type == "cpu"
description_generator = outlines.generate.text(model)
sequence = description_generator(
"<|im_start|>user\n<image>\nWhat is this?<|im_end|>\n<|im_start|>assistant\n",
[image],
def test_single_image_choice(model, processor):
conversation = [
{
"role": "user",
"content": [{"type": "text", "text": "What is this?"}, {"type": "image"}],
},
]
choices = ["cat", "dog"]
generator = outlines.generate.choice(model, choices)
sequence = generator(
processor.apply_chat_template(conversation),
[img_from_url(IMAGE_URLS[0])],
seed=10000,
max_tokens=10,
)
assert isinstance(sequence, str)
assert sequence in choices

0 comments on commit aad0ee9

Please sign in to comment.