Skip to content

Commit

Permalink
More HF Llama tokenizer space fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Sep 16, 2023
1 parent 702bbe7 commit 9229a5d
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 20 deletions.
64 changes: 44 additions & 20 deletions outlines/models/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,54 @@
import torch
from transformers.file_utils import SPIECE_UNDERLINE

try:
from transformers.models.llama.tokenization_llama import LlamaTokenizer
except ImportError:
from outlines.models.tokenizer import Tokenizer

class LlamaTokenizer: # type: ignore
pass
if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer

__all__ = ["transformers"]

try:
from transformers.models.llama.tokenization_llama_fast import LlamaTokenizerFast
except ImportError:

class LlamaTokenizerFast: # type: ignore
pass
def get_llama_tokenizer_types():
    """Get all the Llama tokenizer types/classes that need work-arounds.

    When they can't be imported, a dummy class is created.

    Returns
    -------
    A tuple of the four (possibly dummy) tokenizer classes, suitable for
    use as the second argument of ``isinstance``.
    """
    # Each import is guarded individually: a given `transformers` release may
    # provide some of these classes but not others, and `transformers` itself
    # may be absent entirely.  The dummy fallbacks make the later
    # `isinstance(..., get_llama_tokenizer_types())` check simply return
    # False instead of raising.
    try:
        from transformers.models.llama import LlamaTokenizer
    except ImportError:

        class LlamaTokenizer:  # type: ignore
            pass

    try:
        from transformers.models.llama import LlamaTokenizerFast
    except ImportError:

        class LlamaTokenizerFast:  # type: ignore
            pass

    try:
        from transformers.models.code_llama import CodeLlamaTokenizer
    except ImportError:

        class CodeLlamaTokenizer:  # type: ignore
            pass

    try:
        from transformers.models.code_llama import CodeLlamaTokenizerFast
    except ImportError:

        class CodeLlamaTokenizerFast:  # type: ignore
            pass

    return (
        LlamaTokenizer,
        LlamaTokenizerFast,
        CodeLlamaTokenizer,
        CodeLlamaTokenizerFast,
    )


class Transformers:
Expand Down Expand Up @@ -83,9 +109,7 @@ def __init__(self, model_name: str, **kwargs):
self.pad_token = self.tokenizer.pad_token

self.vocabulary = self.tokenizer.get_vocab()
self.is_sentencepiece = isinstance(
self.tokenizer, (LlamaTokenizerFast, LlamaTokenizer)
)
self.is_llama = isinstance(self.tokenizer, get_llama_tokenizer_types())

def encode(
self, prompt: Union[str, List[str]], **kwargs
Expand All @@ -102,9 +126,9 @@ def decode(self, token_ids: torch.LongTensor) -> List[str]:
def convert_token_to_string(self, token: str) -> str:
    """Convert a single token to its string form.

    Parameters
    ----------
    token
        A single token from the tokenizer's vocabulary.

    Returns
    -------
    The decoded string, with a leading space restored when the Llama
    tokenizers would otherwise drop it.
    """
    string = self.tokenizer.convert_tokens_to_string([token])

    if self.is_llama:
        # A hack to handle missing spaces in HF's Llama tokenizers:
        # a token starting with the SentencePiece underline marker, or the
        # raw space byte-token "<0x20>", denotes a leading space that
        # `convert_tokens_to_string` drops when given a single token.
        if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
            return " " + string

    return string
Expand Down
4 changes: 4 additions & 0 deletions tests/models/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,13 @@ def test_llama_tokenizer():

# Broken
assert tokenizer.tokenizer.convert_tokens_to_string(["▁baz"]) == "baz"
assert tokenizer.tokenizer.convert_tokens_to_string(["<0x20>"]) == ""
assert tokenizer.tokenizer.convert_tokens_to_string(["▁▁▁"]) == " "

# Not broken
assert tokenizer.convert_token_to_string("▁baz") == " baz"
assert tokenizer.convert_token_to_string("<0x20>") == " "
assert tokenizer.convert_token_to_string("▁▁▁") == " "


def test_model():
Expand Down

0 comments on commit 9229a5d

Please sign in to comment.