This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

Commit

Merge pull request #57 from hpcaitech/feature/accumulate
update bert
MaruyamaAya authored May 11, 2022
2 parents b493ee3 + ded6f94 commit 21390f7
Showing 3 changed files with 12 additions and 9 deletions.
examples/bert/bert.py (13 changes: 8 additions & 5 deletions)
@@ -246,16 +246,19 @@ def __init__(self,
 
     def forward(self, hidden_states=None, input_ids=None, attention_mask=None, seq_lens=None):
 
-        batch_size = None
-        max_padding_size = None
-        if seq_lens is not None:
-            batch_size = input_ids.shape[0]
-            max_padding_size = input_ids.shape[1]
+        batch_size = input_ids.shape[0]
+        max_padding_size = input_ids.shape[1]
 
         if self.first:
             hidden_states = self.embed(input_ids=input_ids, position_ids=None, tokentype_ids=None, seq_lens=seq_lens,
                                        batch_size=batch_size, max_padding_size=max_padding_size) # , seq_lens
 
+        if attention_mask is not None:
+            attention_mask = attention_mask.view(batch_size, -1)
+            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+            attention_mask = attention_mask.to(dtype=hidden_states.dtype) # fp16 compatibility
+            attention_mask = (1.0 - attention_mask) * -10000.0
+
         for block in self.blocks:
             hidden_states = block(hidden_states=hidden_states, attention_mask=attention_mask, seq_lens=seq_lens,
                                   batch_size=batch_size, max_padding_size=max_padding_size)
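The block added above is the usual BERT "extended attention mask" conversion: a 0/1 padding mask is reshaped so it broadcasts over attention heads and query positions, cast to the hidden-state dtype for fp16, and turned into an additive bias of 0.0 (keep) or -10000.0 (mask). A minimal standalone sketch of the same transformation, with illustrative shapes that are assumptions rather than values from this repository:

import torch

batch_size, seq_len = 2, 8
# Assumed input: a [batch, seq_len] padding mask with 1 = real token, 0 = padding.
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.int64)
attention_mask[:, 5:] = 0  # pretend the last three positions are padding

mask = attention_mask.view(batch_size, -1)   # [batch, seq_len]
mask = mask.unsqueeze(1).unsqueeze(2)        # [batch, 1, 1, seq_len], broadcastable over heads and queries
mask = mask.to(dtype=torch.float16)          # fp16 compatibility, mirroring the diff
mask = (1.0 - mask) * -10000.0               # 0.0 for real tokens, -10000.0 for padding

print(mask.shape)  # torch.Size([2, 1, 1, 8])

Adding this bias to the raw attention scores before the softmax drives the weights of padded positions toward zero.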
examples/bert/bert_server.py (6 changes: 3 additions & 3 deletions)
@@ -19,7 +19,7 @@ def run():
     batch_size = 32
 
     input_ids = torch.randint(1, 10, (batch_size, seq_len), dtype=torch.int64)
-    attention_mask = torch.randint(0, 1, (batch_size, 1, seq_len, seq_len), dtype=torch.int64)
+    attention_mask = torch.randint(0, 1, (batch_size, 1, seq_len), dtype=torch.int64)
     # seq_lens = torch.randint(1, 128, (batch_size, ), dtype=torch.int64) # generate seq_lens randomly
     hidden_states = None
     sample = dict(hidden_states=hidden_states, input_ids=input_ids, attention_mask=attention_mask)
@@ -36,8 +36,8 @@ def run():
     batch_size = 32
 
     input_ids = torch.randint(1, 10, (batch_size, seq_len), dtype=torch.int64)
-    attention_mask = torch.randint(0, 1, (batch_size, 1, seq_len, seq_len), dtype=torch.int64)
-    seq_lens = torch.randint(1, 128, (batch_size, ), dtype=torch.int64) # generate seq_lens randomly
+    attention_mask = torch.randint(0, 1, (batch_size, 1, seq_len), dtype=torch.int64)
+    seq_lens = torch.randint(1, 128, (batch_size, ), dtype=torch.int) # generate seq_lens randomly
     hidden_states = None
     sample = dict(hidden_states=hidden_states, input_ids=input_ids, attention_mask=attention_mask, seq_lens=seq_lens)
 
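One caveat worth noting about the server example, unchanged by this commit: torch.randint samples from the half-open range [low, high), so torch.randint(0, 1, ...) always returns zeros and the generated attention_mask contains no 1s. If a random mix of padded and unpadded positions were actually wanted for testing, a variant like the following sketch (illustrative values, not part of this commit) would produce it:

import torch

batch_size, seq_len = 32, 128  # batch_size from the example; seq_len is an assumption
# high=2 makes the sampled values fall in {0, 1}.
attention_mask = torch.randint(0, 2, (batch_size, 1, seq_len), dtype=torch.int64)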
examples/hf_gpt2/hf_gpt2.py (2 changes: 1 addition & 1 deletion)
@@ -377,7 +377,7 @@ def __init__(self,
         self.head = GPTLMHead1D(dim=dim, vocab_size=vocab_size,
                                 dtype=dtype) # word_embeeding_weight=self.embed.word_embedding_weight not in the same process
 
-    def forward(self, hidden_states=None, input_ids=None, attention_mask=None):
+    def forward(self, hidden_states=None, input_ids=None, attention_mask=None, seq_lens=None):
         topk = 5 # TODO: add as a parameter
         if self.first:
             hidden_states = self.embed(input_ids)
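The GPT-2 example only widens the forward signature with an unused seq_lens=None keyword, presumably so the same sample dict (which now carries seq_lens) can be passed to either example model. A tiny self-contained sketch (hypothetical names, not code from this repository) of why a defaulted keyword keeps both old and new call sites working:

# Hypothetical stand-ins for the old and new forward signatures.
def forward_old(hidden_states=None, input_ids=None, attention_mask=None):
    return input_ids

def forward_new(hidden_states=None, input_ids=None, attention_mask=None, seq_lens=None):
    # seq_lens is accepted (and ignored here) so callers that pass it do not fail.
    return input_ids

sample = dict(hidden_states=None, input_ids=[1, 2, 3], attention_mask=None, seq_lens=None)
forward_new(**sample)    # fine
# forward_old(**sample)  # would raise TypeError: unexpected keyword argument 'seq_lens'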
