Skip to content

Commit

Permalink
Introduce outlines.models.transformers_vision
Browse files Browse the repository at this point in the history
  • Loading branch information
lapp0 committed Jul 19, 2024
1 parent f6a6c29 commit 9ae6e70
Show file tree
Hide file tree
Showing 10 changed files with 598 additions and 79 deletions.
114 changes: 114 additions & 0 deletions docs/reference/models/transformers_vision.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Transformers Vision

Outlines allows seamless use of [vision models](https://huggingface.co/learn/computer-vision-course/en/unit4/multimodal-models/tasks-models-part1).

`outlines.models.transformers_vision` shares interfaces with, and is based on, [`outlines.models.transformers`](./transformers.md).

Tasks supported include
- image + text -> text
- video + text -> text



## Example: Using [Llava-Next](https://huggingface.co/docs/transformers/en/model_doc/llava_next) Vision Models

Install dependencies
`pip install torchvision pillow flash-attn`

Create the model
```python
import outlines

model = outlines.models.transformers_vision(
"llava-hf/llava-v1.6-mistral-7b-hf",
device="cuda",
)
```

Create convenience function to load a `PIL.Image` from URL
```python
from PIL import Image
from io import BytesIO
from urllib.request import urlopen
def img_from_url(url):
img_byte_stream = BytesIO(urlopen(url).read())
return Image.open(img_byte_stream).convert("RGB")
```

### Describing an image

```python
description_generator = outlines.generate.text(model)
description_generator(
"<image> detailed description:",
[img_from_url("https://upload.wikimedia.org/wikipedia/commons/2/25/Siam_lilacpoint.jpg")]
)
```

> This is a color photograph featuring a Siamese cat with striking blue eyes. The cat has a creamy coat and a light eye color, which is typical for the Siamese breed. Its features include elongated ears, a long, thin tail, and a striking coat pattern. The cat is sitting in an indoor setting, possibly on a cat tower or a similar raised platform, which is covered with a beige fabric, providing a comfortable and soft surface for the cat to rest or perch. The surface of the wall behind the cat appears to be a light-colored stucco or plaster.
#### Multiple Images

To include multiple images in your prompt you simply add more `<image>` tokens to the prompt

```python
image_urls = [
"https://cdn1.byjus.com/wp-content/uploads/2020/08/ShapeArtboard-1-copy-3.png", # triangle
"https://cdn1.byjus.com/wp-content/uploads/2020/08/ShapeArtboard-1-copy-11.png", # hexagon
]
description_generator = outlines.generate.text(model)
description_generator(
"<image><image><image>What shapes are present?",
list(map(img_from_url, image_urls)),
)
```

> There are two shapes present. One shape is a hexagon and the other shape is an triangle. '

### Classifying an Image

```python
pattern = "Mercury|Venus|Earth|Mars|Saturn|Jupiter|Neptune|Uranus|Pluto"
planet_generator = outlines.generate.regex(model, pattern)

planet_generator(
"What planet is this: <image>",
[img_from_url("https://upload.wikimedia.org/wikipedia/commons/e/e3/Saturn_from_Cassini_Orbiter_%282004-10-06%29.jpg")]
)
```

> Saturn

### Extracting Structured Image Data

```python
from pydantic import BaseModel
from typing import List, Optional

# Reuse the img_from_url helper defined earlier.

class ImageData(BaseModel):
caption: str
tags_list: List[str]
object_list: List[str]
is_photo: bool

image_data_generator = outlines.generate.json(model, ImageData)

image_data_generator(
"<image> detailed JSON metadata:",
[img_from_url("https://upload.wikimedia.org/wikipedia/commons/9/98/Aldrin_Apollo_11_original.jpg")]
)
```

> `ImageData(caption='An astronaut on the moon', tags_list=['moon', 'space', 'nasa', 'americanflag'], object_list=['moon', 'moon_surface', 'space_suit', 'americanflag'], is_photo=True)`

## Resources

### Choosing a model
- https://mmbench.opencompass.org.cn/leaderboard
- https://huggingface.co/spaces/WildVision/vision-arena
109 changes: 100 additions & 9 deletions outlines/generate/api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import datetime
from dataclasses import dataclass
from typing import TYPE_CHECKING, Iterator, List, Optional, Union
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple, Union

from outlines.generate.generator import sequence_generator
from outlines.samplers import BeamSearchSampler, GreedySampler, MultinomialSampler
Expand Down Expand Up @@ -479,6 +479,13 @@ def format_sequence(self, sequence: str) -> FormattedOutput:
"""
return sequence

def _format(self, sequences):
    """Recursively apply `format_sequence` to every string in a completion.

    Lists keep their nesting: each leaf string is passed through
    `format_sequence`, and list structure is rebuilt unchanged.
    """
    if not isinstance(sequences, list):
        return self.format_sequence(sequences)
    return [self._format(item) for item in sequences]

def __call__(
self,
prompts: Union[str, List[str]],
Expand All @@ -489,13 +496,6 @@ def __call__(
):
"""Generate text from a prompt of list of prompts."""

def format(sequences):
"""Apply formatting to every string in a completion."""
if isinstance(sequences, list):
return [format(sequence) for sequence in sequences]
else:
return self.format_sequence(sequences)

generation_params = self.prepare_generation_parameters(
max_tokens, stop_at, seed
)
Expand All @@ -508,7 +508,7 @@ def format(sequences):
**model_specific_params,
)

return format(completions)
return self._format(completions)

def stream(
self,
Expand All @@ -529,3 +529,94 @@ def stream(
self.sampling_params,
**model_specific_params,
)


class VisionSequenceGeneratorAdapter(SequenceGeneratorAdapter):
    """Generator adapter for vision models: each call pairs prompts with media
    (e.g. PIL images) that are forwarded to the model's processor."""

    def __call__(  # type: ignore
        self,
        prompts: Union[str, List[str]],
        media: Union[str, Any],
        max_tokens: Optional[int] = None,
        stop_at: Optional[Union[str, List[str]]] = None,
        seed: Optional[int] = None,
        **model_specific_params,
    ):
        """
        Generate text from a prompt or list of prompts.

        Media: A URI to construct media or media object itself. Used as
        AutoProcessor argument.
        """
        prompts, media = self._validate_prompt_media_types(prompts, media)

        generation_params = self.prepare_generation_parameters(
            max_tokens, stop_at, seed
        )

        completions = self.model.generate(
            prompts,
            media,
            generation_params,
            self.logits_processor,
            self.sampling_params,
            **model_specific_params,
        )

        return self._format(completions)

    def stream(  # type: ignore
        self,
        prompts: Union[str, List[str]],
        media: List[Union[str, Any, List[Union[str, Any]]]],
        max_tokens: Optional[int] = None,
        stop_at: Optional[Union[str, List[str]]] = None,
        seed: Optional[int] = None,
        **model_specific_params,
    ):
        """Return a text generator from a prompt or a list of prompts."""
        prompts, media = self._validate_prompt_media_types(prompts, media)
        generation_params = self.prepare_generation_parameters(
            max_tokens, stop_at, seed
        )
        return self.model.stream(
            prompts,
            media,
            generation_params,
            self.logits_processor,
            self.sampling_params,
            **model_specific_params,
        )

    @classmethod
    def _validate_prompt_media_types(
        cls,
        prompts: Union[str, List[str]],
        media: Union[str, Any, List[Union[str, Any]]],
    ) -> Tuple[Union[str, List[str]], Union[Any, List[Any]]]:
        """
        Validate that prompts and media line up, returning them unchanged.

        A single prompt str must come with a list of PIL images; a list of
        prompts must come with a same-length list of lists of PIL images.

        Raises
        ------
        TypeError
            If the (prompts, media) pair does not match either shape.
        """

        def valid_types(prompts, media):
            # Deferred import: PIL is an optional dependency of outlines.
            from PIL import Image  # type: ignore

            if isinstance(prompts, list):
                if not isinstance(media, list) or len(prompts) != len(media):
                    return False
                for subprompt, submedia in zip(prompts, media):
                    if not isinstance(subprompt, str) or not all(
                        isinstance(m, Image.Image) for m in submedia
                    ):
                        return False
            elif isinstance(prompts, str):
                if not all(isinstance(m, Image.Image) for m in media):
                    return False
            return True

        if not valid_types(prompts, media):
            raise TypeError(
                "Expected (prompts, media) to be of type "
                "(str, List[Image]) or (List[str], List[List[Image]]), "
                f"instead got prompts={prompts}, media={media}"
            )

        return prompts, media
17 changes: 15 additions & 2 deletions outlines/generate/fsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@
import interegular

from outlines.fsm.guide import RegexGuide
from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter
from outlines.models import MLXLM, LlamaCpp, Transformers
from outlines.generate.api import (
SequenceGenerator,
SequenceGeneratorAdapter,
VisionSequenceGeneratorAdapter,
)
from outlines.models import MLXLM, LlamaCpp, Transformers, TransformersVision
from outlines.samplers import Sampler, multinomial


Expand All @@ -29,3 +33,12 @@ def fsm_unified(
fsm = RegexGuide.from_interegular_fsm(fsm, model.tokenizer)
logits_processor = FSMLogitsProcessor(tokenizer=model.tokenizer, fsm=fsm)
return SequenceGeneratorAdapter(model, logits_processor, sampler)


@fsm.register(TransformersVision)
def fsm_vision(model, fsm: interegular.fsm.FSM, sampler: Sampler = multinomial()):
    """Build a vision-capable generator constrained by an interegular FSM."""
    # Deferred import mirrors the other dispatch targets in this module.
    from outlines.processors import FSMLogitsProcessor

    guide = RegexGuide.from_interegular_fsm(fsm, model.tokenizer)
    processor = FSMLogitsProcessor(tokenizer=model.tokenizer, fsm=guide)
    return VisionSequenceGeneratorAdapter(model, processor, sampler)
31 changes: 25 additions & 6 deletions outlines/generate/regex.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
from functools import singledispatch

from outlines.fsm.guide import RegexGuide
from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter
from outlines.models import OpenAI
from outlines.models.llamacpp import LlamaCpp
from outlines.models.mlxlm import MLXLM
from outlines.models.transformers import Transformers
from outlines.models.vllm import VLLM
from outlines.generate.api import (
SequenceGenerator,
SequenceGeneratorAdapter,
VisionSequenceGeneratorAdapter,
)
from outlines.models import (
MLXLM,
VLLM,
LlamaCpp,
OpenAI,
Transformers,
TransformersVision,
)
from outlines.samplers import Sampler, multinomial


Expand Down Expand Up @@ -53,6 +60,18 @@ def regex_unified(
return SequenceGeneratorAdapter(model, logits_processor, sampler)


@regex.register(TransformersVision)
def regex_vision(model, regex_str: str, sampler: Sampler = multinomial()):
    """Build a vision-capable generator whose output matches `regex_str`."""
    # Deferred import mirrors the other dispatch targets in this module.
    from outlines.processors import RegexLogitsProcessor

    processor = RegexLogitsProcessor(regex_str, tokenizer=model.tokenizer)
    return VisionSequenceGeneratorAdapter(model, processor, sampler)


@regex.register(VLLM)
def regex_vllm(
model: VLLM,
Expand Down
20 changes: 18 additions & 2 deletions outlines/generate/text.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
from functools import singledispatch

from outlines.fsm.guide import StopAtEOSGuide
from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter
from outlines.models import MLXLM, VLLM, LlamaCpp, OpenAI, Transformers
from outlines.generate.api import (
SequenceGenerator,
SequenceGeneratorAdapter,
VisionSequenceGeneratorAdapter,
)
from outlines.models import (
MLXLM,
VLLM,
LlamaCpp,
OpenAI,
Transformers,
TransformersVision,
)
from outlines.samplers import Sampler, multinomial


Expand Down Expand Up @@ -43,6 +54,11 @@ def text_unified(model, sampler: Sampler = multinomial()):
return SequenceGeneratorAdapter(model, None, sampler)


@text.register(TransformersVision)
def text_vision(model, sampler: Sampler = multinomial()):
    """Unconstrained text generation for TransformersVision models."""
    # No logits processor: output is free-form text.
    return VisionSequenceGeneratorAdapter(model, None, sampler)


@text.register(VLLM)
def text_vllm(model: VLLM, sampler: Sampler = multinomial()):
    """Unconstrained text generation for vLLM models."""
    # No logits processor: output is free-form text.
    return SequenceGeneratorAdapter(model, None, sampler)
Expand Down
2 changes: 2 additions & 0 deletions outlines/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
codebase.
"""

from typing import Union

from .exllamav2 import ExLlamaV2Model, exl2
from .llamacpp import LlamaCpp, llamacpp
from .mlxlm import MLXLM, mlxlm
from .openai import OpenAI, azure_openai, openai
from .transformers import Transformers, TransformerTokenizer, mamba, transformers
from .transformers_vision import TransformersVision, transformers_vision
from .vllm import VLLM, vllm

LogitsGenerator = Union[Transformers, LlamaCpp, ExLlamaV2Model, MLXLM, VLLM]
Loading

0 comments on commit 9ae6e70

Please sign in to comment.