This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

Commit

Merge pull request #94 from dujiangsu/feature/opt-175b
add basic model as component
MaruyamaAya authored Jul 8, 2022
2 parents fa363d3 + 659e368 commit 69bf5a6
Showing 7 changed files with 473 additions and 0 deletions.
1 change: 1 addition & 0 deletions energonai/model/__init__.py
@@ -0,0 +1 @@
from .model_factory import gpt2_8B, gpt3, hf_gpt2
94 changes: 94 additions & 0 deletions energonai/model/attention.py
@@ -0,0 +1,94 @@
import math
import torch
from torch import nn, dtype

from colossalai.nn.layer.utils import divide
from colossalai.nn import Linear1D_Col, Linear1D_Row

from energonai.utils import get_current_device


class MultiHeadAttention1D(nn.Module):
    def __init__(self,
                 hidden_size: int,
                 num_heads: int,
                 bias: bool = True,
                 dtype: dtype = torch.float16,
                 max_seq_len: int = 512,
                 fused_qkv: bool = True,
                 is_decoder: bool = True
                 ) -> None:
        super().__init__()

        self.hidden_size = hidden_size
        self.attention_head_size = divide(hidden_size, num_heads)
        self.fused_qkv = fused_qkv
        self.is_decoder = is_decoder

        if fused_qkv:
            self.query_key_value = Linear1D_Col(hidden_size, 3 * hidden_size, bias=bias, dtype=dtype)
        else:
            self.query_ = Linear1D_Col(hidden_size, hidden_size, bias=bias, dtype=dtype)
            self.key_ = Linear1D_Col(hidden_size, hidden_size, bias=bias, dtype=dtype)
            self.value_ = Linear1D_Col(hidden_size, hidden_size, bias=bias, dtype=dtype)

        self.softmax = nn.Softmax(dim=-1)

        self.dense = Linear1D_Row(hidden_size, hidden_size, bias=True, dtype=dtype, parallel_input=True)

        if is_decoder:
            self.causal_mask = torch.tril(torch.ones((max_seq_len, max_seq_len), dtype=torch.uint8,
                                                     device=get_current_device())).view(1, 1, max_seq_len, max_seq_len).bool()
            self.causal_mask_bias = torch.tensor(-1e4, dtype=dtype, device=get_current_device())

    def _split_heads(self, tensor, num_heads, attn_head_size):
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
        tensor = tensor.view(new_shape)
        return tensor.permute(0, 2, 1, 3)

    def forward(self,
                hidden_states,
                attention_mask=None):

        if self.fused_qkv:
            qkv = self.query_key_value(hidden_states)
            all_head_size = qkv.shape[-1] // 3
            num_attention_heads = divide(all_head_size, self.attention_head_size)

            qkv = self._split_heads(qkv, num_attention_heads, 3 * self.attention_head_size)
            q, k, v = torch.chunk(qkv, 3, dim=-1)
        else:
            q = self.query_(hidden_states)
            k = self.key_(hidden_states)
            v = self.value_(hidden_states)
            all_head_size = q.shape[-1]
            num_attention_heads = divide(all_head_size, self.attention_head_size)
            q = self._split_heads(q, num_attention_heads, self.attention_head_size)
            k = self._split_heads(k, num_attention_heads, self.attention_head_size)
            v = self._split_heads(v, num_attention_heads, self.attention_head_size)

        hidden_states = torch.matmul(q, k.transpose(-1, -2))
        hidden_states = hidden_states / math.sqrt(self.attention_head_size)
        q_len, k_len = q.size(-2), k.size(-2)

        if self.is_decoder:
            hidden_states = torch.where(self.causal_mask[:, :, 0:q_len, 0:k_len], hidden_states, self.causal_mask_bias)

        if attention_mask is not None:
            hidden_states = hidden_states + attention_mask
        hidden_states = self.softmax(hidden_states)

        hidden_states = torch.matmul(hidden_states, v)

        hidden_states = hidden_states.transpose(1, 2)

        new_context_layer_shape = hidden_states.size()[:-2] + (all_head_size,)

        hidden_states = hidden_states.reshape(new_context_layer_shape)
        hidden_states = self.dense(hidden_states)

        return hidden_states


# causal_mask = torch.tril(torch.ones((q_len, k_len), dtype=torch.uint8,
#                                     device=get_current_device())).view(1, 1, q_len, k_len).bool()
# hidden_states = torch.where(causal_mask, hidden_states, torch.tensor(-1e4, dtype=hidden_states.dtype, device=get_current_device()))
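Illustrative usage of the new attention layer, with hypothetical sizes; this sketch assumes a Colossal-AI 1D tensor-parallel context has already been initialized so that Linear1D_Col/Linear1D_Row construct correctly:

import torch
from energonai.model.attention import MultiHeadAttention1D
from energonai.utils import get_current_device

# Hypothetical configuration; real sizes are wired up by the model factory.
attn = MultiHeadAttention1D(hidden_size=64, num_heads=4, dtype=torch.float16,
                            max_seq_len=128, fused_qkv=True, is_decoder=True)

x = torch.randn(2, 16, 64, dtype=torch.float16, device=get_current_device())
out = attn(x)  # causal self-attention; output keeps the input shape (2, 16, 64)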
21 changes: 21 additions & 0 deletions energonai/model/downstream.py
@@ -0,0 +1,21 @@
from torch import dtype, nn
from colossalai.nn import Classifier1D


class LMHead1D(nn.Module):
    def __init__(self,
                 hidden_size: int,
                 vocab_size: int,
                 word_embedding_weight: nn.Parameter = None,
                 bias: bool = False,
                 dtype: dtype = None) -> None:
        super().__init__()
        self.dense = Classifier1D(hidden_size, vocab_size, word_embedding_weight, bias=bias, dtype=dtype)

    @property
    def weight(self):
        return self.dense.weight

    def forward(self, x):
        x = self.dense(x)
        return x
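LMHead1D wraps Classifier1D so the output projection can reuse the word-embedding weight. A hedged sketch of the tying pattern, with hypothetical sizes and the same parallel-context assumption as above:

import torch
from energonai.model.embedding import Embedding1D
from energonai.model.downstream import LMHead1D

embed = Embedding1D(hidden_size=64, vocab_size=1000, max_seq_len=128, dtype=torch.float16)
head = LMHead1D(hidden_size=64, vocab_size=1000,
                word_embedding_weight=embed.word_embedding_weight,  # tie with the embedding
                dtype=torch.float16)
# head(hidden_states) maps (batch, seq_len, 64) -> (batch, seq_len, 1000) logits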
49 changes: 49 additions & 0 deletions energonai/model/embedding.py
@@ -0,0 +1,49 @@
import torch
from torch import nn as nn
from torch import dtype
from energonai.nn import VocabParallelEmbedding1D
from energonai.utils import get_current_device

class Embedding1D(nn.Module):
    def __init__(self,
                 hidden_size: int,
                 vocab_size: int,
                 max_seq_len: int,
                 num_tokentypes: int = 0,
                 padding_idx: int = 0,
                 dtype: dtype = None,
                 ) -> None:
        super().__init__()

        self.word_embeddings = VocabParallelEmbedding1D(vocab_size, hidden_size, padding_idx=padding_idx, dtype=dtype, skip_tp=True)

        self.position_embeddings = VocabParallelEmbedding1D(max_seq_len, hidden_size, dtype=dtype, skip_tp=True)

        if num_tokentypes > 0:
            self.tokentype_embeddings = VocabParallelEmbedding1D(num_tokentypes, hidden_size, dtype=dtype, skip_tp=True)
        else:
            self.tokentype_embeddings = None

        # self.position_ids = torch.arange(max_seq_len, dtype=torch.long, device=get_current_device()).expand((1, -1))

    @property
    def word_embedding_weight(self):
        return self.word_embeddings.weight

    def forward(self,
                input_ids,
                position_ids=None,
                tokentype_ids=None,
                past_key_values_length: int = 0):

        seq_length = input_ids.size(1)
        if position_ids is None:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=get_current_device()).unsqueeze(0)
            # position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        x = self.word_embeddings(input_ids) + self.position_embeddings(position_ids)

        if self.tokentype_embeddings is not None and tokentype_ids is not None:
            x = x + self.tokentype_embeddings(tokentype_ids)

        return x
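A minimal sketch of the embedding forward pass (hypothetical sizes; position ids default to 0..seq_len-1 when not supplied, and an initialized Colossal-AI context is assumed):

import torch
from energonai.model.embedding import Embedding1D
from energonai.utils import get_current_device

embed = Embedding1D(hidden_size=64, vocab_size=1000, max_seq_len=128, dtype=torch.float16)
input_ids = torch.randint(0, 1000, (2, 16), device=get_current_device())
x = embed(input_ids)  # word + learned position embeddings -> (2, 16, 64)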
64 changes: 64 additions & 0 deletions energonai/model/endecoder.py
@@ -0,0 +1,64 @@
from typing import Callable
import torch
from torch import dtype
from torch import nn
from colossalai.nn import LayerNorm1D

from .mlp import MLP1D
from .attention import MultiHeadAttention1D



class Block1D(nn.Module):
    def __init__(self,
                 hidden_size: int,
                 num_heads: int,
                 mlp_ratio: float,
                 activation: Callable = nn.functional.gelu,
                 layernorm_epsilon: float = 1e-5,
                 dtype: dtype = torch.float16,
                 bias: bool = True,
                 apply_post_layernorm: bool = False,
                 max_seq_len: int = 512,
                 fused_qkv: bool = True,
                 is_decoder: bool = True) -> None:
        super().__init__()

        self.apply_post_layernorm = apply_post_layernorm
        self.norm1 = LayerNorm1D(hidden_size, eps=layernorm_epsilon)

        self.attn = MultiHeadAttention1D(hidden_size=hidden_size,
                                         num_heads=num_heads,
                                         bias=bias,
                                         dtype=dtype,
                                         max_seq_len=max_seq_len,
                                         fused_qkv=fused_qkv,
                                         is_decoder=is_decoder)

        self.norm2 = LayerNorm1D(hidden_size, eps=layernorm_epsilon)

        self.mlp = MLP1D(hidden_size=hidden_size,
                         mlp_ratio=mlp_ratio,
                         activation=activation,
                         dtype=dtype,
                         bias=bias)
    def forward(self, hidden_states, attention_mask=None):

        if not self.apply_post_layernorm:
            residual = hidden_states
        hidden_states = self.norm1(hidden_states)

        if self.apply_post_layernorm:
            residual = hidden_states
        hidden_states = residual + self.attn(hidden_states, attention_mask)

        if not self.apply_post_layernorm:
            residual = hidden_states

        hidden_states = self.norm2(hidden_states)

        if self.apply_post_layernorm:
            residual = hidden_states
        hidden_states = residual + self.mlp(hidden_states)

        return hidden_states
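The factory functions in model_factory.py (collapsed in this view) presumably compose these pieces into full models; the following is a rough, hypothetical sketch of that composition, not the actual factory code, and it again assumes an initialized Colossal-AI 1D tensor-parallel context:

import torch
from torch import nn

from energonai.model.embedding import Embedding1D
from energonai.model.endecoder import Block1D
from energonai.model.downstream import LMHead1D


class TinyDecoder(nn.Module):
    """Illustrative stack of the components added in this commit (hypothetical sizes)."""

    def __init__(self, hidden_size=64, num_heads=4, depth=2, vocab_size=1000, max_seq_len=128):
        super().__init__()
        self.embed = Embedding1D(hidden_size, vocab_size, max_seq_len, dtype=torch.float16)
        self.blocks = nn.ModuleList(
            Block1D(hidden_size, num_heads, mlp_ratio=4.0, dtype=torch.float16, max_seq_len=max_seq_len)
            for _ in range(depth))
        self.head = LMHead1D(hidden_size, vocab_size,
                             word_embedding_weight=self.embed.word_embedding_weight,
                             dtype=torch.float16)

    def forward(self, input_ids, attention_mask=None):
        x = self.embed(input_ids)
        for block in self.blocks:
            x = block(x, attention_mask)
        return self.head(x)  # logits: (batch, seq_len, vocab_size)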
26 changes: 26 additions & 0 deletions energonai/model/mlp.py
@@ -0,0 +1,26 @@

from typing import Callable
import torch
from torch import dtype, nn
from colossalai.nn import Linear1D_Col, Linear1D_Row, Classifier1D


class MLP1D(nn.Module):

    def __init__(self,
                 hidden_size: int,
                 mlp_ratio: float,
                 activation: Callable,
                 dtype: dtype = torch.float16,
                 bias: bool = True):
        super().__init__()
        intermediate_dim = int(hidden_size * mlp_ratio)
        self.dense_1 = Linear1D_Col(hidden_size, intermediate_dim, bias=bias, dtype=dtype, gather_output=False)
        self.activation = activation
        self.dense_2 = Linear1D_Row(intermediate_dim, hidden_size, bias=bias, dtype=dtype, parallel_input=True)

    def forward(self, hidden_states):
        hidden_states = self.dense_1(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.dense_2(hidden_states)
        return hidden_states
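The MLP expands the hidden dimension with a column-parallel Linear1D_Col and folds it back with a row-parallel Linear1D_Row; a hedged sketch with hypothetical sizes, under the same parallel-context assumption:

import torch
from torch import nn
from energonai.model.mlp import MLP1D
from energonai.utils import get_current_device

mlp = MLP1D(hidden_size=64, mlp_ratio=4.0, activation=nn.functional.gelu, dtype=torch.float16)
x = torch.randn(2, 16, 64, dtype=torch.float16, device=get_current_device())
y = mlp(x)  # column-parallel projection to 256, GELU, row-parallel projection back to 64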

