forked from pytorch/serve
-
Notifications
You must be signed in to change notification settings - Fork 0
/
bert_tokenizer.py
49 lines (43 loc) · 1.54 KB
/
bert_tokenizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import uuid
import json
from argparse import ArgumentParser

from transformers import AutoTokenizer


def _str2bool(value):
    """Parse a command-line boolean string.

    argparse's ``type=bool`` is a classic pitfall: any non-empty string is
    truthy, so ``--do_lower_case False`` used to evaluate to ``True``.  This
    helper accepts the usual textual spellings of a boolean instead, while
    keeping the original ``--do_lower_case <value>`` CLI shape.
    """
    return str(value).strip().lower() in ("true", "1", "yes", "y")


# Build a KServe V2 inference-request JSON file from a text input tokenized
# with a (BERT-style) HuggingFace tokenizer.
parser = ArgumentParser()
parser.add_argument("--input_text", help="Input text", type=str, required=True)
parser.add_argument("--model_name", help="bert model name", default="bert-base-uncased", type=str)
# BUG FIX: was ``type=bool``, which treats every non-empty string as True.
parser.add_argument("--do_lower_case", help="Use lower case", default=True, type=_str2bool)
parser.add_argument("--max_length", help="Max length of the string", default=150, type=int)
parser.add_argument("--result_file", help="Path to result file", default="bert_v2.json", type=str)
args = vars(parser.parse_args())

# BUG FIX: honor the parsed --do_lower_case flag instead of hard-coding True.
tokenizer = AutoTokenizer.from_pretrained(
    args["model_name"], do_lower_case=args["do_lower_case"]
)

print("Tokenizing input")
# ``padding="max_length"`` replaces the deprecated ``pad_to_max_length=True``;
# explicit ``truncation=True`` caps over-length inputs at max_length tokens
# (the old flag's implied behavior, now required to be stated).
tokenized_input = tokenizer.encode_plus(
    args["input_text"],
    max_length=args["max_length"],
    padding="max_length",
    truncation=True,
    add_special_tokens=True,
    return_tensors="pt",
)
input_ids = tokenized_input["input_ids"]            # tensor of shape (1, max_length)
attention_mask = tokenized_input["attention_mask"]  # tensor of shape (1, max_length)

# KServe/Triton V2 inference protocol: each input's "shape" field must be a
# list of ints.  BUG FIX: the original emitted a bare int (shape[1]), which
# does not conform to the protocol.
request = {
    "id": str(uuid.uuid4()),
    "inputs": [
        {
            "name": "input_ids",
            "shape": [input_ids.shape[1]],
            "datatype": "INT64",
            "data": input_ids[0].tolist(),
        },
        {
            "name": "attention_masks",
            "shape": [attention_mask.shape[1]],
            "datatype": "INT64",
            "data": attention_mask[0].tolist(),
        },
    ],
}

result_file = args["result_file"]
print("Generating input file: ", result_file)
with open(result_file, "w") as outfile:
    json.dump(request, outfile)