diff --git a/examples/bert/bert.py b/examples/bert/bert.py
index 579d71d..6370f06 100644
--- a/examples/bert/bert.py
+++ b/examples/bert/bert.py
@@ -246,16 +246,19 @@ def __init__(self,
 
     def forward(self, hidden_states=None, input_ids=None, attention_mask=None, seq_lens=None):
 
-        batch_size = None
-        max_padding_size = None
-        if seq_lens is not None:
-            batch_size = input_ids.shape[0]
-            max_padding_size = input_ids.shape[1]
+        batch_size = input_ids.shape[0]
+        max_padding_size = input_ids.shape[1]
 
         if self.first:
             hidden_states = self.embed(input_ids=input_ids, position_ids=None, tokentype_ids=None, seq_lens=seq_lens, batch_size=batch_size, max_padding_size=max_padding_size)  # , seq_lens
 
+        if attention_mask is not None:
+            attention_mask = attention_mask.view(batch_size, -1)
+            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+            attention_mask = attention_mask.to(dtype=hidden_states.dtype)  # fp16 compatibility
+            attention_mask = (1.0 - attention_mask) * -10000.0
+
         for block in self.blocks:
             hidden_states = block(hidden_states=hidden_states, attention_mask=attention_mask, seq_lens=seq_lens, batch_size=batch_size, max_padding_size=max_padding_size)
 
diff --git a/examples/bert/bert_server.py b/examples/bert/bert_server.py
index 4f8fe1a..c94d981 100644
--- a/examples/bert/bert_server.py
+++ b/examples/bert/bert_server.py
@@ -19,7 +19,7 @@ def run():
     batch_size = 32
 
     input_ids = torch.randint(1, 10, (batch_size, seq_len), dtype=torch.int64)
-    attention_mask = torch.randint(0, 1, (batch_size, 1, seq_len, seq_len), dtype=torch.int64)
+    attention_mask = torch.randint(0, 1, (batch_size, 1, seq_len), dtype=torch.int64)
     # seq_lens = torch.randint(1, 128, (batch_size, ), dtype=torch.int64)  # generate seq_lens randomly
     hidden_states = None
     sample = dict(hidden_states=hidden_states, input_ids=input_ids, attention_mask=attention_mask)
@@ -36,8 +36,8 @@ def run():
     batch_size = 32
 
     input_ids = torch.randint(1, 10, (batch_size, seq_len), dtype=torch.int64)
-    attention_mask = torch.randint(0, 1, (batch_size, 1, seq_len, seq_len), dtype=torch.int64)
-    seq_lens = torch.randint(1, 128, (batch_size, ), dtype=torch.int64)  # generate seq_lens randomly
+    attention_mask = torch.randint(0, 1, (batch_size, 1, seq_len), dtype=torch.int64)
+    seq_lens = torch.randint(1, 128, (batch_size, ), dtype=torch.int)  # generate seq_lens randomly
     hidden_states = None
     sample = dict(hidden_states=hidden_states, input_ids=input_ids, attention_mask=attention_mask, seq_lens=seq_lens)
 
diff --git a/examples/hf_gpt2/hf_gpt2.py b/examples/hf_gpt2/hf_gpt2.py
index 44a31bf..ad806ef 100644
--- a/examples/hf_gpt2/hf_gpt2.py
+++ b/examples/hf_gpt2/hf_gpt2.py
@@ -377,7 +377,7 @@ def __init__(self,
 
         self.head = GPTLMHead1D(dim=dim, vocab_size=vocab_size, dtype=dtype)  # word_embeeding_weight=self.embed.word_embedding_weight not in the same process
 
-    def forward(self, hidden_states=None, input_ids=None, attention_mask=None):
+    def forward(self, hidden_states=None, input_ids=None, attention_mask=None, seq_lens=None):
         topk = 5  # TODO: add as a parameter
         if self.first:
             hidden_states = self.embed(input_ids)
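
The new block in the BERT example's forward (examples/bert/bert.py) turns the (batch_size, 1, seq_len) integer padding mask built in bert_server.py into the additive mask the attention layers consume. The following is a minimal, self-contained sketch of that transformation; the helper name build_additive_mask and the toy sizes are illustrative, not part of the patch.

import torch

def build_additive_mask(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # attention_mask: (batch_size, 1, seq_len) integer padding mask, 1 = keep token, 0 = padding.
    batch_size = attention_mask.shape[0]
    attention_mask = attention_mask.view(batch_size, -1)       # (batch, seq_len)
    attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)  # (batch, 1, 1, seq_len), broadcastable over heads and query positions
    attention_mask = attention_mask.to(dtype=dtype)            # match hidden_states dtype for fp16 runs
    return (1.0 - attention_mask) * -10000.0                   # 0 where attended, -10000 where masked

batch_size, seq_len = 2, 8
mask = torch.ones(batch_size, 1, seq_len, dtype=torch.int64)
mask[0, 0, 6:] = 0                                             # pad out the tail of the first sequence
additive = build_additive_mask(mask, torch.float16)
print(additive.shape)                                          # torch.Size([2, 1, 1, 8])
print(additive[0, 0, 0, -1])                                   # tensor(-10000., dtype=torch.float16)

The broadcastable (batch, 1, 1, seq_len) shape is what allows a single per-token mask to be added to every head's (seq_len, seq_len) score matrix before the softmax, which is also why the server no longer needs to ship a full (batch, 1, seq_len, seq_len) mask.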
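
On the request side, bert_server.py now builds the mask with shape (batch_size, 1, seq_len) instead of (batch_size, 1, seq_len, seq_len) and can pass per-sequence lengths. A hedged sketch of such a sample follows; seq_len = 128 and the use of torch.ones are illustrative choices (the hunk does not show the real seq_len, and torch.randint's upper bound is exclusive, so randint(0, 1, ...) would yield an all-zero mask).

import torch

seq_len = 128                                                             # assumed for illustration; not shown in the hunk
batch_size = 32

input_ids = torch.randint(1, 10, (batch_size, seq_len), dtype=torch.int64)
attention_mask = torch.ones(batch_size, 1, seq_len, dtype=torch.int64)    # 1 = real token, 0 = padding
seq_lens = torch.randint(1, 128, (batch_size,), dtype=torch.int)          # actual length of each sequence
sample = dict(hidden_states=None, input_ids=input_ids,
              attention_mask=attention_mask, seq_lens=seq_lens)

The hf_gpt2 change appears to follow the same convention: its forward now accepts a seq_lens keyword, presumably so a sample dict of this shape can be dispatched to either example model without filtering its keys.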