
offer support for chinese
lucidrains committed Apr 15, 2021
1 parent 329fc0a commit dbbbcfd
Showing 5 changed files with 73 additions and 4 deletions.
14 changes: 14 additions & 0 deletions README.md
@@ -440,6 +440,20 @@ ex.
$ python train_dalle.py --image_text_folder ./path/to/data --bpe_path ./path/to/bpe.json
```

#### Chinese

You can train with a <a href="https://huggingface.co/bert-base-chinese">pretrained Chinese tokenizer</a> offered by Hugging Face 🤗 simply by passing in the extra flag `--chinese`.

ex.

```sh
$ python train_dalle.py --chinese --image_text_folder ./path/to/data
```

```sh
$ python generate.py --chinese --text '追老鼠的猫'
```
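
Under the hood, `--chinese` selects the new `ChineseTokenizer` added in `dalle_pytorch/tokenizer.py` below, which wraps Hugging Face's pretrained BERT vocabulary. A minimal sketch of the equivalent load:

```python
from transformers import BertTokenizer

# the --chinese flag loads this pretrained vocabulary behind the scenes
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
ids = tokenizer.encode('追老鼠的猫', add_special_tokens = False)
```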

## Citations

```bibtex
43 changes: 42 additions & 1 deletion dalle_pytorch/tokenizer.py
@@ -2,7 +2,9 @@
# to give users a quick easy start to training DALL-E without doing BPE

import torch

from tokenizers import Tokenizer
from transformers import BertTokenizer

import html
import os
@@ -123,6 +125,7 @@ def encode(self, text):
    def decode(self, tokens, remove_start_end = True):
        if torch.is_tensor(tokens):
            tokens = tokens.tolist()

        if remove_start_end:
            # strip <|startoftext|> (49406), <|endoftext|> (49407) and padding (0)
            tokens = [token for token in tokens if token not in (49406, 49407, 0)]
        text = ''.join([self.decoder[token] for token in tokens])
@@ -160,7 +163,10 @@ def __init__(self, bpe_path = None):
        self.vocab_size = tokenizer.get_vocab_size()

    def decode(self, tokens):
-       tokens = [token for token in tokens.tolist() if token not in (0,)]
+       if torch.is_tensor(tokens):
+           tokens = tokens.tolist()
+
+       tokens = [token for token in tokens if token not in (0,)]
        return self.tokenizer.decode(tokens, skip_special_tokens = True)

    def encode(self, text):
@@ -182,3 +188,38 @@ def tokenize(self, texts, context_length = 256, truncate_text = False):
            result[i, :len(tokens)] = torch.tensor(tokens)

        return result

# chinese tokenizer

class ChineseTokenizer:
    def __init__(self):
        tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
        self.tokenizer = tokenizer
        self.vocab_size = tokenizer.vocab_size

    def decode(self, tokens):
        if torch.is_tensor(tokens):
            tokens = tokens.tolist()

        tokens = [token for token in tokens if token not in (0,)]
        return self.tokenizer.decode(tokens)

    def encode(self, text):
        return torch.tensor(self.tokenizer.encode(text, add_special_tokens = False))

    def tokenize(self, texts, context_length = 256, truncate_text = False):
        if isinstance(texts, str):
            texts = [texts]

        all_tokens = [self.encode(text) for text in texts]

        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
        for i, tokens in enumerate(all_tokens):
            if len(tokens) > context_length:
                if truncate_text:
                    tokens = tokens[:context_length]
                else:
                    raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
            result[i, :len(tokens)] = torch.tensor(tokens)

        return result
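
For a quick sanity check, here is a minimal usage sketch of the new class (assuming the `bert-base-chinese` files can be fetched from the Hugging Face hub on first use):

```python
from dalle_pytorch.tokenizer import ChineseTokenizer

tokenizer = ChineseTokenizer()

ids = tokenizer.encode('追老鼠的猫')      # 1-D LongTensor of token ids, no [CLS]/[SEP] added
batch = tokenizer.tokenize('追老鼠的猫')  # (1, 256) zero-padded LongTensor
text = tokenizer.decode(ids)             # round-trips back (BERT decoding may reinsert spaces)
```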
11 changes: 10 additions & 1 deletion generate.py
@@ -43,14 +43,23 @@
parser.add_argument('--bpe_path', type = str,
                    help='path to your huggingface BPE json file')

parser.add_argument('--chinese', dest='chinese', action = 'store_true')

parser.add_argument('--taming', dest='taming', action='store_true')

args = parser.parse_args()

# helper fns

def exists(val):
    return val is not None

# tokenizer

-if args.bpe_path is not None:
+if exists(args.bpe_path):
    tokenizer = HugTokenizer(args.bpe_path)
elif args.chinese:
    tokenizer = ChineseTokenizer()
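# with neither flag set, the default `tokenizer` from dalle_pytorch.tokenizer is used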

# load DALL-E

3 changes: 2 additions & 1 deletion setup.py
@@ -4,7 +4,7 @@
  name = 'dalle-pytorch',
  packages = find_packages(),
  include_package_data = True,
- version = '0.10.1',
+ version = '0.10.2',
  license='MIT',
  description = 'DALL-E - Pytorch',
  author = 'Phil Wang',
@@ -27,6 +27,7 @@
    'tokenizers',
    'torch>=1.6',
    'torchvision',
    'transformers',
    'tqdm'
  ],
  classifiers=[
6 changes: 5 additions & 1 deletion train_dalle.py
@@ -21,7 +21,7 @@

from dalle_pytorch import distributed_utils
from dalle_pytorch import OpenAIDiscreteVAE, VQGanVAE1024, DiscreteVAE, DALLE
-from dalle_pytorch.tokenizer import tokenizer, HugTokenizer
+from dalle_pytorch.tokenizer import tokenizer, HugTokenizer, ChineseTokenizer

# argument parsing

@@ -41,6 +41,8 @@
parser.add_argument('--truncate_captions', dest='truncate_captions', action='store_true',
                    help='Captions passed in which exceed the max token length will be truncated if this is set.')

parser.add_argument('--chinese', dest='chinese', action = 'store_true')

parser.add_argument('--taming', dest='taming', action='store_true')

parser.add_argument('--bpe_path', type = str,
@@ -89,6 +91,8 @@ def exists(val):

if exists(args.bpe_path):
    tokenizer = HugTokenizer(args.bpe_path)
elif args.chinese:
    tokenizer = ChineseTokenizer()
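# with neither --bpe_path nor --chinese set, the module-level `tokenizer` imported above is used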

# reconstitute vae

