This repository has been archived by the owner on Oct 16, 2023. It is now read-only.
Commit
Merge pull request #94 from dujiangsu/feature/opt-175b
add basic model as component
Showing 7 changed files with 473 additions and 0 deletions.
@@ -0,0 +1 @@
from .model_factory import gpt2_8B, gpt3, hf_gpt2
@@ -0,0 +1,94 @@
import math
import torch
from torch import nn, dtype

from colossalai.nn.layer.utils import divide
from colossalai.nn import Linear1D_Col, Linear1D_Row

from energonai.utils import get_current_device


class MultiHeadAttention1D(nn.Module):
    def __init__(self,
                 hidden_size: int,
                 num_heads: int,
                 bias: bool = True,
                 dtype: dtype = torch.float16,
                 max_seq_len: int = 512,
                 fused_qkv: bool = True,
                 is_decoder: bool = True
                 ) -> None:
        super().__init__()

        self.hidden_size = hidden_size
        self.attention_head_size = divide(hidden_size, num_heads)
        self.fused_qkv = fused_qkv
        self.is_decoder = is_decoder

        # Column-parallel QKV projection(s): each tensor-parallel rank holds a slice of the heads.
        if fused_qkv:
            self.query_key_value = Linear1D_Col(hidden_size, 3 * hidden_size, bias=bias, dtype=dtype)
        else:
            self.query_ = Linear1D_Col(hidden_size, hidden_size, bias=bias, dtype=dtype)
            self.key_ = Linear1D_Col(hidden_size, hidden_size, bias=bias, dtype=dtype)
            self.value_ = Linear1D_Col(hidden_size, hidden_size, bias=bias, dtype=dtype)

        self.softmax = nn.Softmax(dim=-1)

        # Row-parallel output projection; parallel_input=True lets it consume the sharded heads directly.
        self.dense = Linear1D_Row(hidden_size, hidden_size, bias=True, dtype=dtype, parallel_input=True)

        if is_decoder:
            # Precompute a lower-triangular causal mask up to max_seq_len and a large negative
            # bias used to blank out future positions before the softmax.
            self.causal_mask = torch.tril(torch.ones((max_seq_len, max_seq_len), dtype=torch.uint8,
                                                     device=get_current_device())).view(1, 1, max_seq_len, max_seq_len).bool()
            self.causal_mask_bias = torch.tensor(-1e4, dtype=dtype, device=get_current_device())

    def _split_heads(self, tensor, num_heads, attn_head_size):
        # (batch, seq, hidden) -> (batch, num_heads, seq, attn_head_size)
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
        tensor = tensor.view(new_shape)
        return tensor.permute(0, 2, 1, 3)

    def forward(self,
                hidden_states,
                attention_mask=None):

        if self.fused_qkv:
            qkv = self.query_key_value(hidden_states)
            all_head_size = qkv.shape[-1] // 3
            num_attention_heads = divide(all_head_size, self.attention_head_size)

            qkv = self._split_heads(qkv, num_attention_heads, 3 * self.attention_head_size)
            q, k, v = torch.chunk(qkv, 3, dim=-1)
        else:
            q = self.query_(hidden_states)
            k = self.key_(hidden_states)
            v = self.value_(hidden_states)
            all_head_size = q.shape[-1]
            num_attention_heads = divide(all_head_size, self.attention_head_size)
            q = self._split_heads(q, num_attention_heads, self.attention_head_size)
            k = self._split_heads(k, num_attention_heads, self.attention_head_size)
            v = self._split_heads(v, num_attention_heads, self.attention_head_size)

        # Scaled dot-product attention scores.
        hidden_states = torch.matmul(q, k.transpose(-1, -2))
        hidden_states = hidden_states / math.sqrt(self.attention_head_size)
        q_len, k_len = q.size(-2), k.size(-2)

        if self.is_decoder:
            # Keep scores where the causal mask is True; replace future positions with -1e4.
            hidden_states = torch.where(self.causal_mask[:, :, 0:q_len, 0:k_len], hidden_states, self.causal_mask_bias)

        if attention_mask is not None:
            hidden_states = hidden_states + attention_mask
        hidden_states = self.softmax(hidden_states)

        hidden_states = torch.matmul(hidden_states, v)

        # (batch, num_heads, seq, head_size) -> (batch, seq, num_heads, head_size) -> (batch, seq, hidden)
        hidden_states = hidden_states.transpose(1, 2)

        new_context_layer_shape = hidden_states.size()[:-2] + (all_head_size,)

        hidden_states = hidden_states.reshape(new_context_layer_shape)
        hidden_states = self.dense(hidden_states)

        return hidden_states


# causal_mask = torch.tril(torch.ones((q_len, k_len), dtype=torch.uint8,
#                                     device=get_current_device())).view(1, 1, q_len, k_len).bool()
# hidden_states = torch.where(causal_mask, hidden_states, torch.tensor(-1e4, dtype=hidden_states.dtype, device=get_current_device()))
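For orientation, a minimal usage sketch of the attention module above. It is only a sketch: the import path is not shown in this diff, the sizes are illustrative, and the Linear1D_Col / Linear1D_Row layers assume ColossalAI has already been launched with a 1D tensor-parallel process group.

```python
import torch
from energonai.model.attention import MultiHeadAttention1D  # hypothetical path; not shown in this diff

# Assumes colossalai.launch(...) has already set up a 1D tensor-parallel context.
# hidden_size must be divisible by num_heads; values are illustrative.
attn = MultiHeadAttention1D(hidden_size=768, num_heads=12, dtype=torch.float16)

x = torch.randn(2, 128, 768, dtype=torch.float16, device=torch.device('cuda'))  # (batch, seq, hidden)

# Causal masking is applied internally because is_decoder defaults to True; an additive
# padding mask (0 for visible tokens, -1e4 for padded ones) may also be passed.
out = attn(x)  # (batch, seq, hidden)
```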
@@ -0,0 +1,21 @@
from torch import dtype, nn
from colossalai.nn import Classifier1D


class LMHead1D(nn.Module):
    def __init__(self,
                 hidden_size: int,
                 vocab_size: int,
                 word_embedding_weight: nn.Parameter = None,
                 bias: bool = False,
                 dtype: dtype = None) -> None:
        super().__init__()
        # Classifier1D maps hidden states to vocabulary logits; passing word_embedding_weight
        # ties the output projection to the input embedding table.
        self.dense = Classifier1D(hidden_size, vocab_size, word_embedding_weight, bias=bias, dtype=dtype)

    @property
    def weight(self):
        return self.dense.weight

    def forward(self, x):
        x = self.dense(x)
        return x
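Below is a hypothetical sketch of using the head on its own, without weight tying; it assumes Classifier1D allocates its own projection when no weight is passed, uses illustrative GPT-2-like sizes, and presumes a launched ColossalAI 1D tensor-parallel context.

```python
import torch

# Hypothetical usage of LMHead1D from the snippet above; sizes are illustrative.
lm_head = LMHead1D(hidden_size=768, vocab_size=50257, dtype=torch.float16)

hidden = torch.randn(2, 128, 768, dtype=torch.float16)  # decoder output: (batch, seq, hidden)
logits = lm_head(hidden)                                 # (batch, seq, vocab_size)
```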
@@ -0,0 +1,49 @@
import torch
from torch import nn as nn
from torch import dtype
from energonai.nn import VocabParallelEmbedding1D
from energonai.utils import get_current_device


class Embedding1D(nn.Module):
    def __init__(self,
                 hidden_size: int,
                 vocab_size: int,
                 max_seq_len: int,
                 num_tokentypes: int = 0,
                 padding_idx: int = 0,
                 dtype: dtype = None,
                 ) -> None:
        super().__init__()

        # Token, position and (optional) token-type embeddings.
        self.word_embeddings = VocabParallelEmbedding1D(vocab_size, hidden_size, padding_idx=padding_idx, dtype=dtype, skip_tp=True)

        self.position_embeddings = VocabParallelEmbedding1D(max_seq_len, hidden_size, dtype=dtype, skip_tp=True)

        if num_tokentypes > 0:
            self.tokentype_embeddings = VocabParallelEmbedding1D(num_tokentypes, hidden_size, dtype=dtype, skip_tp=True)
        else:
            self.tokentype_embeddings = None

        # self.position_ids = torch.arange(max_seq_len, dtype=torch.long, device=get_current_device()).expand((1, -1))

    @property
    def word_embedding_weight(self):
        # Exposed so an LM head can tie its output projection to this table.
        return self.word_embeddings.weight

    def forward(self,
                input_ids,
                position_ids=None,
                tokentype_ids=None,
                past_key_values_length: int = 0):

        seq_length = input_ids.size(1)
        if position_ids is None:
            # Default to absolute positions 0..seq_length-1.
            position_ids = torch.arange(seq_length, dtype=torch.long, device=get_current_device()).unsqueeze(0)
            # position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        x = self.word_embeddings(input_ids) + self.position_embeddings(position_ids)

        if self.tokentype_embeddings is not None and tokentype_ids is not None:
            x = x + self.tokentype_embeddings(tokentype_ids)

        return x
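Taken together with the LM head earlier in this commit, here is a hypothetical sketch of the input/output embedding tying these two modules are set up for (the word_embedding_weight property exists precisely so it can be handed to LMHead1D). Sizes are illustrative and a launched ColossalAI 1D tensor-parallel context is assumed.

```python
import torch

# Hypothetical wiring of Embedding1D and LMHead1D from the snippets above; GPT-2-like sizes.
embed = Embedding1D(hidden_size=768, vocab_size=50257, max_seq_len=512, dtype=torch.float16)
lm_head = LMHead1D(hidden_size=768, vocab_size=50257,
                   word_embedding_weight=embed.word_embedding_weight,  # tie output projection to input table
                   dtype=torch.float16)

input_ids = torch.randint(0, 50257, (2, 128), device=torch.device('cuda'))
hidden = embed(input_ids)   # (batch, seq, hidden); position ids are generated internally
logits = lm_head(hidden)    # (batch, seq, vocab_size)
```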
@@ -0,0 +1,64 @@
from typing import Callable
import torch
from torch import dtype
from torch import nn
from colossalai.nn import LayerNorm1D

from .mlp import MLP1D
from .attention import MultiHeadAttention1D


class Block1D(nn.Module):
    def __init__(self,
                 hidden_size: int,
                 num_heads: int,
                 mlp_ratio: float,
                 activation: Callable = nn.functional.gelu,
                 layernorm_epsilon: float = 1e-5,
                 dtype: dtype = torch.float16,
                 bias: bool = True,
                 apply_post_layernorm: bool = False,
                 max_seq_len: int = 512,
                 fused_qkv: bool = True,
                 is_decoder: bool = True) -> None:
        super().__init__()

        # apply_post_layernorm=False gives the pre-LayerNorm (GPT-2 style) residual layout;
        # True gives the post-LayerNorm layout of the original Transformer.
        self.apply_post_layernorm = apply_post_layernorm
        self.norm1 = LayerNorm1D(hidden_size, eps=layernorm_epsilon)

        self.attn = MultiHeadAttention1D(hidden_size=hidden_size,
                                         num_heads=num_heads,
                                         bias=bias,
                                         dtype=dtype,
                                         max_seq_len=max_seq_len,
                                         fused_qkv=fused_qkv,
                                         is_decoder=is_decoder)

        self.norm2 = LayerNorm1D(hidden_size, eps=layernorm_epsilon)

        self.mlp = MLP1D(hidden_size=hidden_size,
                         mlp_ratio=mlp_ratio,
                         activation=activation,
                         dtype=dtype,
                         bias=bias)

    def forward(self, hidden_states, attention_mask=None):

        # The residual is captured before normalization for pre-LN, after it for post-LN.
        if not self.apply_post_layernorm:
            residual = hidden_states
        hidden_states = self.norm1(hidden_states)

        if self.apply_post_layernorm:
            residual = hidden_states
        hidden_states = residual + self.attn(hidden_states, attention_mask)

        if not self.apply_post_layernorm:
            residual = hidden_states

        hidden_states = self.norm2(hidden_states)

        if self.apply_post_layernorm:
            residual = hidden_states
        hidden_states = residual + self.mlp(hidden_states)

        return hidden_states
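For illustration, a hypothetical sketch of stacking this block. With the default apply_post_layernorm=False the block runs in the pre-LayerNorm arrangement: each sub-layer normalizes its input and adds the result back onto the un-normalized residual, the layout used by GPT-2-style decoders. Sizes below are illustrative and a launched ColossalAI 1D tensor-parallel context is assumed.

```python
import torch
from torch import nn

# Two pre-LN decoder blocks with GPT-2-like sizes (hypothetical values).
blocks = nn.ModuleList([
    Block1D(hidden_size=768, num_heads=12, mlp_ratio=4.0,
            dtype=torch.float16, apply_post_layernorm=False)
    for _ in range(2)
])

hidden = torch.randn(2, 128, 768, dtype=torch.float16, device=torch.device('cuda'))
for block in blocks:
    hidden = block(hidden)  # shape is preserved: (batch, seq, hidden)
```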
@@ -0,0 +1,26 @@
from typing import Callable
import torch
from torch import dtype, nn
from colossalai.nn import Linear1D_Col, Linear1D_Row, Classifier1D


class MLP1D(nn.Module):

    def __init__(self,
                 hidden_size: int,
                 mlp_ratio: float,
                 activation: Callable,
                 dtype: dtype = torch.float16,
                 bias: bool = True):
        super().__init__()
        intermediate_dim = int(hidden_size * mlp_ratio)
        # Column-parallel expansion (output stays sharded) followed by a row-parallel
        # projection back to hidden_size that consumes the sharded input directly.
        self.dense_1 = Linear1D_Col(hidden_size, intermediate_dim, bias=bias, dtype=dtype, gather_output=False)
        self.activation = activation
        self.dense_2 = Linear1D_Row(intermediate_dim, hidden_size, bias=bias, dtype=dtype, parallel_input=True)

    def forward(self, hidden_states):
        hidden_states = self.dense_1(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.dense_2(hidden_states)
        return hidden_states
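Finally, a hypothetical sketch of the feed-forward block on its own. The column-parallel dense_1 keeps its output sharded (gather_output=False) and the row-parallel dense_2 consumes that sharded activation directly (parallel_input=True), so the intermediate activations never need to be gathered; mlp_ratio=4.0 gives the conventional 4 x hidden_size intermediate width. Sizes are illustrative and a launched ColossalAI 1D tensor-parallel context is assumed.

```python
import torch
from torch import nn

# Hypothetical standalone use of MLP1D from the snippet above, GPT-2-like sizes.
mlp = MLP1D(hidden_size=768, mlp_ratio=4.0, activation=nn.functional.gelu, dtype=torch.float16)

x = torch.randn(2, 128, 768, dtype=torch.float16, device=torch.device('cuda'))
y = mlp(x)  # (batch, seq, hidden), same shape as the input
```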