This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

update bert #57

Merged 1 commit on May 11, 2022
13 changes: 8 additions & 5 deletions examples/bert/bert.py
@@ -246,16 +246,19 @@ def __init__(self,

     def forward(self, hidden_states=None, input_ids=None, attention_mask=None, seq_lens=None):

-        batch_size = None
-        max_padding_size = None
-        if seq_lens is not None:
-            batch_size = input_ids.shape[0]
-            max_padding_size = input_ids.shape[1]
+        batch_size = input_ids.shape[0]
+        max_padding_size = input_ids.shape[1]

         if self.first:
             hidden_states = self.embed(input_ids=input_ids, position_ids=None, tokentype_ids=None, seq_lens=seq_lens,
                                        batch_size=batch_size, max_padding_size=max_padding_size) # , seq_lens

+        if attention_mask is not None:
+            attention_mask = attention_mask.view(batch_size, -1)
+            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+            attention_mask = attention_mask.to(dtype=hidden_states.dtype) # fp16 compatibility
+            attention_mask = (1.0 - attention_mask) * -10000.0
+
         for block in self.blocks:
             hidden_states = block(hidden_states=hidden_states, attention_mask=attention_mask, seq_lens=seq_lens,
                                   batch_size=batch_size, max_padding_size=max_padding_size)
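The mask handling added here flattens the incoming 0/1 padding mask, expands it to (batch, 1, 1, seq_len), and converts it into an additive bias for the attention scores. A minimal standalone sketch of the same transformation (function and variable names are illustrative, not taken from the repo):

import torch

def extend_attention_mask(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # (batch, seq_len) -> (batch, 1, 1, seq_len) so the mask broadcasts over heads and query positions
    mask = attention_mask.view(attention_mask.shape[0], -1)
    mask = mask.unsqueeze(1).unsqueeze(2)
    mask = mask.to(dtype=dtype)        # fp16 compatibility
    return (1.0 - mask) * -10000.0     # 0.0 where tokens are kept, -10000.0 where they are padding

# Positions marked 1 stay visible; the trailing 0 becomes a large negative bias before softmax.
example = torch.tensor([[1, 1, 1, 0]])
print(extend_attention_mask(example, torch.float16))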
6 changes: 3 additions & 3 deletions examples/bert/bert_server.py
@@ -19,7 +19,7 @@ def run():
     batch_size = 32

     input_ids = torch.randint(1, 10, (batch_size, seq_len), dtype=torch.int64)
-    attention_mask = torch.randint(0, 1, (batch_size, 1, seq_len, seq_len), dtype=torch.int64)
+    attention_mask = torch.randint(0, 1, (batch_size, 1, seq_len), dtype=torch.int64)
     # seq_lens = torch.randint(1, 128, (batch_size, ), dtype=torch.int64) # generate seq_lens randomly
     hidden_states = None
     sample = dict(hidden_states=hidden_states, input_ids=input_ids, attention_mask=attention_mask)
@@ -36,8 +36,8 @@ def run():
     batch_size = 32

     input_ids = torch.randint(1, 10, (batch_size, seq_len), dtype=torch.int64)
-    attention_mask = torch.randint(0, 1, (batch_size, 1, seq_len, seq_len), dtype=torch.int64)
-    seq_lens = torch.randint(1, 128, (batch_size, ), dtype=torch.int64) # generate seq_lens randomly
+    attention_mask = torch.randint(0, 1, (batch_size, 1, seq_len), dtype=torch.int64)
+    seq_lens = torch.randint(1, 128, (batch_size, ), dtype=torch.int) # generate seq_lens randomly
     hidden_states = None
     sample = dict(hidden_states=hidden_states, input_ids=input_ids, attention_mask=attention_mask, seq_lens=seq_lens)

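Both hunks shrink the mask the server generates from (batch, 1, seq_len, seq_len) to a per-token padding mask of shape (batch, 1, seq_len), and the second hunk switches seq_lens to torch.int. Since torch.randint's upper bound is exclusive, randint(0, 1, ...) produces an all-zero mask; purely as an illustration (not part of this PR), a non-trivial padding mask consistent with random seq_lens could be built like this:

import torch

batch_size, seq_len = 32, 128

# Draw a length for every sample, then mark position j of sample i with 1 when j < seq_lens[i].
seq_lens = torch.randint(1, seq_len + 1, (batch_size,), dtype=torch.int)
attention_mask = (torch.arange(seq_len).unsqueeze(0) < seq_lens.unsqueeze(1)).to(torch.int64)
attention_mask = attention_mask.unsqueeze(1)  # shape (batch_size, 1, seq_len)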
2 changes: 1 addition & 1 deletion examples/hf_gpt2/hf_gpt2.py
@@ -377,7 +377,7 @@ def __init__(self,
         self.head = GPTLMHead1D(dim=dim, vocab_size=vocab_size,
                                 dtype=dtype)  # word_embeeding_weight=self.embed.word_embedding_weight not in the same process

-    def forward(self, hidden_states=None, input_ids=None, attention_mask=None):
+    def forward(self, hidden_states=None, input_ids=None, attention_mask=None, seq_lens=None):
         topk = 5  # TODO: add as a parameter
         if self.first:
             hidden_states = self.embed(input_ids)
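Adding a seq_lens=None keyword keeps the GPT-2 example's forward signature compatible with callers that now pass seq_lens alongside the other fields, as the BERT server above does. A hypothetical helper illustrating the point (names are placeholders, not from the repo):

import torch

def run_sample(model: torch.nn.Module, sample: dict):
    # Splat the same request dict into any example model's forward; because the
    # hf_gpt2 forward now accepts a seq_lens keyword, a dict that includes
    # seq_lens no longer raises a TypeError for the GPT-2 pipeline.
    return model(**sample)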