diff --git a/outlines/models/transformers.py b/outlines/models/transformers.py
index 6b443a8fa..80ee813b9 100644
--- a/outlines/models/transformers.py
+++ b/outlines/models/transformers.py
@@ -4,28 +4,54 @@
 import torch
 from transformers.file_utils import SPIECE_UNDERLINE
 
-try:
-    from transformers.models.llama.tokenization_llama import LlamaTokenizer
-except ImportError:
+from outlines.models.tokenizer import Tokenizer
 
-    class LlamaTokenizer:  # type: ignore
-        pass
+if TYPE_CHECKING:
+    from transformers import PreTrainedModel, PreTrainedTokenizer
 
+__all__ = ["transformers"]
 
-try:
-    from transformers.models.llama.tokenization_llama_fast import LlamaTokenizerFast
-except ImportError:
 
-    class LlamaTokenizerFast:  # type: ignore
-        pass
+def get_llama_tokenizer_types():
+    """Get all the Llama tokenizer types/classes that need work-arounds.
 
+    When they can't be imported, a dummy class is created.
 
-from outlines.models.tokenizer import Tokenizer
+    """
+    try:
+        from transformers.models.llama import LlamaTokenizer
+    except ImportError:
 
-if TYPE_CHECKING:
-    from transformers import PreTrainedModel, PreTrainedTokenizer
+        class LlamaTokenizer:  # type: ignore
+            pass
 
-__all__ = ["transformers"]
+    try:
+        from transformers.models.llama import LlamaTokenizerFast
+    except ImportError:
+
+        class LlamaTokenizerFast:  # type: ignore
+            pass
+
+    try:
+        from transformers.models.code_llama import CodeLlamaTokenizer
+    except ImportError:
+
+        class CodeLlamaTokenizer:  # type: ignore
+            pass
+
+    try:
+        from transformers.models.code_llama import CodeLlamaTokenizerFast
+    except ImportError:
+
+        class CodeLlamaTokenizerFast:  # type: ignore
+            pass
+
+    return (
+        LlamaTokenizer,
+        LlamaTokenizerFast,
+        CodeLlamaTokenizer,
+        CodeLlamaTokenizerFast,
+    )
 
 
 class Transformers:
@@ -83,9 +109,7 @@ def __init__(self, model_name: str, **kwargs):
         self.pad_token = self.tokenizer.pad_token
         self.vocabulary = self.tokenizer.get_vocab()
 
-        self.is_sentencepiece = isinstance(
-            self.tokenizer, (LlamaTokenizerFast, LlamaTokenizer)
-        )
+        self.is_llama = isinstance(self.tokenizer, get_llama_tokenizer_types())
 
     def encode(
         self, prompt: Union[str, List[str]], **kwargs
@@ -102,9 +126,9 @@ def decode(self, token_ids: torch.LongTensor) -> List[str]:
     def convert_token_to_string(self, token: str) -> str:
         string = self.tokenizer.convert_tokens_to_string([token])
 
-        if self.is_sentencepiece:
-            # A hack to handle missing spaces from SentencePiece tokenizers
-            if token.startswith(SPIECE_UNDERLINE):
+        if self.is_llama:
+            # A hack to handle missing spaces in HF's Llama tokenizers
+            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
                 return " " + string
 
         return string
diff --git a/tests/models/test_transformers.py b/tests/models/test_transformers.py
index f44aca4c7..a7e9f9842 100644
--- a/tests/models/test_transformers.py
+++ b/tests/models/test_transformers.py
@@ -43,9 +43,13 @@ def test_llama_tokenizer():
 
     # Broken
    assert tokenizer.tokenizer.convert_tokens_to_string(["▁baz"]) == "baz"
+    assert tokenizer.tokenizer.convert_tokens_to_string(["<0x20>"]) == ""
+    assert tokenizer.tokenizer.convert_tokens_to_string(["▁▁▁"]) == "  "
 
     # Not broken
     assert tokenizer.convert_token_to_string("▁baz") == " baz"
+    assert tokenizer.convert_token_to_string("<0x20>") == " "
+    assert tokenizer.convert_token_to_string("▁▁▁") == "   "
 
 
 def test_model():
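
For reference, a minimal standalone sketch of the "Broken" behaviour the new tests pin down. The tokenizer id is an illustrative assumption and not part of this diff; any Llama-family tokenizer should behave the same way:

    from transformers import AutoTokenizer

    # Hypothetical checkpoint choice; swap in any Llama/CodeLlama tokenizer.
    hf_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

    # HF's Llama tokenizers drop the leading space encoded by SPIECE_UNDERLINE
    # ("▁") and return "" for the byte-fallback space token "<0x20>".
    print(repr(hf_tokenizer.convert_tokens_to_string(["▁baz"])))    # 'baz'
    print(repr(hf_tokenizer.convert_tokens_to_string(["<0x20>"])))  # ''

The patched convert_token_to_string re-inserts the dropped space for exactly these token shapes, which is what the "Not broken" assertions verify.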