[AutoMM] Support customizing use_fast for AutoTokenizer (open-mmlab#3379)
zhiqiangdon authored Jul 7, 2023
1 parent 03cc58c commit fa5d2d4
Showing 5 changed files with 82 additions and 1 deletion.
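For context, a minimal sketch of the user-facing change (the toy DataFrame is illustrative; the hyperparameter keys are the same ones exercised by the new unit test below):

import pandas as pd

from autogluon.multimodal import MultiModalPredictor

# Toy text dataset, purely illustrative.
train_df = pd.DataFrame({"text": ["good", "bad", "fine", "poor"], "label": [1, 0, 1, 0]})

predictor = MultiModalPredictor(label="label")
predictor.fit(
    train_data=train_df,
    time_limit=5,
    hyperparameters={
        "model.hf_text.checkpoint_name": "albert-base-v2",
        "model.hf_text.use_fast": False,  # opt out of the fast Rust-based tokenizer
    },
)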
@@ -62,6 +62,7 @@ model:
    data_types:
      - "text"
    tokenizer_name: "hf_auto"
    use_fast: True # Use a fast Rust-based tokenizer if it is supported for a given model. If a fast tokenizer is not available for a given model, a normal Python-based tokenizer is returned instead.
    max_text_len: 512 # If None or <=0, then use the max length of pretrained models.
    insert_sep: True
    low_cpu_mem_usage: False
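For reference (not part of this diff), the flag maps onto the use_fast argument of transformers.AutoTokenizer.from_pretrained, which selects between the Rust-backed and pure-Python tokenizer classes. A quick sketch with the albert-base-v2 checkpoint used in the new unit test:

from transformers import AutoTokenizer

fast = AutoTokenizer.from_pretrained("albert-base-v2", use_fast=True)
slow = AutoTokenizer.from_pretrained("albert-base-v2", use_fast=False)
print(type(fast).__name__)  # AlbertTokenizerFast (Rust-backed)
print(type(slow).__name__)  # AlbertTokenizer (pure Python)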
14 changes: 13 additions & 1 deletion multimodal/src/autogluon/multimodal/data/process_text.py
@@ -95,6 +95,7 @@ def __init__(
        train_augment_types: Optional[List[str]] = None,
        template_config: Optional[DictConfig] = None,
        normalize_text: Optional[bool] = False,
        use_fast: Optional[bool] = True,
    ):
        """
        Parameters
@@ -125,6 +126,11 @@ def __init__(
            Whether to normalize text to resolve encoding problems.
            Examples of normalized texts can be found at
            https://github.com/autogluon/autogluon/tree/master/examples/automm/kaggle_feedback_prize#15-a-few-examples-of-normalized-texts
        use_fast
            Use a fast Rust-based tokenizer if it is supported for a given model.
            If a fast tokenizer is not available for a given model, a normal Python-based tokenizer is returned instead.
            See: https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoTokenizer.from_pretrained.use_fast
        """
        self.prefix = model.prefix
        self.tokenizer_name = tokenizer_name
@@ -136,6 +142,7 @@ def __init__(
        self.tokenizer = self.get_pretrained_tokenizer(
            tokenizer_name=tokenizer_name,
            checkpoint_name=model.checkpoint_name,
            use_fast=use_fast,
        )
        if hasattr(self.tokenizer, "deprecation_warnings"):
            # Disable the warning "Token indices sequence length is longer than the specified maximum sequence..."
@@ -410,6 +417,7 @@ def get_special_tokens(tokenizer):
def get_pretrained_tokenizer(
    tokenizer_name: str,
    checkpoint_name: str,
    use_fast: Optional[bool] = True,
):
    """
    Load the tokenizer for a pre-trained huggingface checkpoint.
@@ -420,14 +428,18 @@ def get_pretrained_tokenizer(
        The tokenizer type, e.g., "bert", "clip", "electra", and "hf_auto".
    checkpoint_name
        Name of a pre-trained checkpoint.
    use_fast
        Use a fast Rust-based tokenizer if it is supported for a given model.
        If a fast tokenizer is not available for a given model, a normal Python-based tokenizer is returned instead.
        See: https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoTokenizer.from_pretrained.use_fast

    Returns
    -------
    A tokenizer instance.
    """
    try:
        tokenizer_class = ALL_TOKENIZERS[tokenizer_name]
-        return tokenizer_class.from_pretrained(checkpoint_name)
+        return tokenizer_class.from_pretrained(checkpoint_name, use_fast=use_fast)
    except TypeError as e:
        try:
            tokenizer_class = ALL_TOKENIZERS["bert"]
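A minimal sketch of calling the updated helper directly, assuming ALL_TOKENIZERS maps "hf_auto" to transformers.AutoTokenizer (that mapping is outside this diff):

from autogluon.multimodal.data.process_text import get_pretrained_tokenizer

# With use_fast=False, the "hf_auto" path should return the slow (Python)
# tokenizer class when one exists for the checkpoint.
tokenizer = get_pretrained_tokenizer(
    tokenizer_name="hf_auto",
    checkpoint_name="albert-base-v2",
    use_fast=False,
)
print(type(tokenizer).__name__)  # expected: AlbertTokenizer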
1 change: 1 addition & 0 deletions multimodal/src/autogluon/multimodal/utils/data.py
@@ -157,6 +157,7 @@ def create_data_processor(
            train_augment_types=OmegaConf.select(model_config, "text_train_augment_types"),
            template_config=getattr(config.data, "templates", OmegaConf.create({"turn_on": False})),
            normalize_text=getattr(config.data.text, "normalize_text", False),
            use_fast=OmegaConf.select(model_config, "use_fast", default=True),
        )
    elif data_type == CATEGORICAL:
        data_processor = CategoricalProcessor(
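The OmegaConf.select call keeps the change backward compatible: model configs without a use_fast key fall back to True. A small sketch of that behavior:

from omegaconf import OmegaConf

with_flag = OmegaConf.create({"use_fast": False})
without_flag = OmegaConf.create({"checkpoint_name": "albert-base-v2"})

print(OmegaConf.select(with_flag, "use_fast", default=True))     # False
print(OmegaConf.select(without_flag, "use_fast", default=True))  # True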
1 change: 1 addition & 0 deletions multimodal/tests/hf_model_list.yaml
@@ -40,6 +40,7 @@ others_2:
  - t5-small
  - microsoft/layoutlmv3-base
  - microsoft/layoutlmv2-base-uncased
  - albert-base-v2
predictor:
  - CLTL/MedRoBERTa.nl
  - google/electra-small-discriminator
66 changes: 66 additions & 0 deletions multimodal/tests/unittests/others_2/test_data_processors.py
@@ -0,0 +1,66 @@
import os
import shutil
import tempfile

import pytest
from transformers import AlbertTokenizer, AlbertTokenizerFast

from autogluon.multimodal import MultiModalPredictor
from autogluon.multimodal.constants import TEXT

from ..utils.unittest_datasets import AEDataset, HatefulMeMesDataset, IDChangeDetectionDataset, PetFinderDataset

ALL_DATASETS = {
    "petfinder": PetFinderDataset,
    "hateful_memes": HatefulMeMesDataset,
    "ae": AEDataset,
}


@pytest.mark.parametrize(
    "checkpoint_name,use_fast,tokenizer_type",
    [
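        # use_fast=None leaves model.hf_text.use_fast unset, so the config default (a fast tokenizer) applies.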
        (
            "albert-base-v2",
            None,
            AlbertTokenizerFast,
        ),
        (
            "albert-base-v2",
            True,
            AlbertTokenizerFast,
        ),
        (
            "albert-base-v2",
            False,
            AlbertTokenizer,
        ),
    ],
)
def test_tokenizer_use_fast(checkpoint_name, use_fast, tokenizer_type):
    dataset = ALL_DATASETS["ae"]()
    metric_name = dataset.metric

    predictor = MultiModalPredictor(
        label=dataset.label_columns[0],
        problem_type=dataset.problem_type,
        eval_metric=metric_name,
    )
    hyperparameters = {
        "data.categorical.convert_to_text": True,
        "data.numerical.convert_to_text": True,
        "model.hf_text.checkpoint_name": checkpoint_name,
    }
    if use_fast is not None:
        hyperparameters["model.hf_text.use_fast"] = use_fast

    with tempfile.TemporaryDirectory() as save_path:
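        # TemporaryDirectory has already created save_path; remove it so fit() can create the directory fresh.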
        if os.path.isdir(save_path):
            shutil.rmtree(save_path)
        predictor.fit(
            train_data=dataset.train_df,
            time_limit=5,
            save_path=save_path,
            hyperparameters=hyperparameters,
        )
        assert isinstance(predictor._data_processors[TEXT][0].tokenizer, tokenizer_type)
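The new test can be run on its own with a standard pytest invocation, e.g. pytest multimodal/tests/unittests/others_2/test_data_processors.py -k test_tokenizer_use_fast.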
