This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

66B model load checkpoint #137

Merged
3 commits merged on Aug 30, 2022
11 changes: 3 additions & 8 deletions energonai/model/model_factory.py
@@ -51,15 +51,13 @@ def __init__(self,
                 checkpoint: str = None,
                 model_name: str = None,
                 is_decoder: bool = True,
-                disable_past_cache = False,
-                last_layer_norm = True) -> None:
+                disable_past_cache = False) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.first = first
        self.last = last
        self.max_seq_len = max_seq_len
        self.model_name = model_name
-        self.last_layer_norm = last_layer_norm

        if first:
            self.embed = Embedding1D(hidden_size=hidden_size,
@@ -86,8 +84,7 @@ def __init__(self,
                                           is_decoder=is_decoder,
                                           disable_past_cache=disable_past_cache))
        if last:
-            if self.last_layer_norm:
-                self.norm = LayerNorm1D(normalized_shape=hidden_size, eps=layernorm_epsilon)
+            self.norm = LayerNorm1D(normalized_shape=hidden_size, eps=layernorm_epsilon)
            self.head = LMHead1D(hidden_size=hidden_size, vocab_size=vocab_size, bias=False, dtype=dtype)

    def forward(self, hidden_states=None, input_ids=None, attention_mask=None, seq_lens=None, max_tokens: Optional[int] = None, top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None):
@@ -116,8 +113,7 @@ def forward(self, hidden_states=None, input_ids=None, attention_mask=None, seq_lens=None, max_tokens: Optional[int] = None, top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None):
                                               first_cache = first_cache)

        if self.last:
-            if self.last_layer_norm:
-                hidden_states = self.norm(hidden_states)
+            hidden_states = self.norm(hidden_states)
            hidden_states = self.head(hidden_states)
            hidden_states = self.generate(input_ids, hidden_states, top_k=top_k,
                                          top_p=top_p, temperature=temperature)
@@ -307,7 +303,6 @@ def opt_66B(**kwargs):
                        fused_qkv=False,
                        model_name="opt",
                        disable_past_cache=False,
-                        last_layer_norm = False,
                        **kwargs)
    return create_pipeline_model(**model_kwargs)

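The diff above drops the last_layer_norm switch, so the final LayerNorm now always runs before the LM head whenever last=True. A minimal sketch of that simplified last-stage path, using plain PyTorch modules instead of the EnergonAI 1D-parallel layers (LastStageSketch and the shapes below are illustrative assumptions, not code from this repository):

import torch
import torch.nn as nn

class LastStageSketch(nn.Module):
    # Illustrative stand-in for the last pipeline stage after this change.
    def __init__(self, hidden_size: int, vocab_size: int, layernorm_epsilon: float = 1e-5):
        super().__init__()
        # The LayerNorm is created unconditionally; there is no last_layer_norm flag anymore.
        self.norm = nn.LayerNorm(hidden_size, eps=layernorm_epsilon)
        self.head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.norm(hidden_states)  # always normalize before the head
        return self.head(hidden_states)           # logits over the vocabulary

# Example: one dummy sequence of 8 tokens with hidden size 512.
logits = LastStageSketch(hidden_size=512, vocab_size=50272)(torch.randn(1, 8, 512))
print(logits.shape)  # torch.Size([1, 8, 50272])
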
6 changes: 4 additions & 2 deletions energonai/utils/checkpointing_opt.py
@@ -8,7 +8,7 @@

name_map = {
    'embed_tokens': 'embed.word_embeddings',
-    'embed_positions': 'position_embeddings',
+    'embed_positions': 'embed.position_embeddings',
    # 'layers': 'blocks',
    'self_attn.q_proj': 'attn.query_',
    'self_attn.k_proj': 'attn.key_',
@@ -86,7 +86,9 @@ def processing_OPT(state_dict: OrderedDict):
    # else:
    # new_dict[new_k] = new_v
    # print(new_dict.keys())
-    new_dict['head.dense.weight'] = new_dict['embed.word_embeddings.weight'].clone()
+    if 'head.dense.weight' not in new_dict:
+        new_dict['head.dense.weight'] = new_dict['embed.word_embeddings.weight'].clone()
+
    del new_dict['decoder.version']
    # print("="*100)
    # print(new_dict.keys())
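Taken together, the two fixes above change how OPT checkpoint keys are converted: embed_positions is now remapped under the embed module like the word embeddings, and the LM head is tied to the word embeddings only when the checkpoint does not already ship head weights. A rough sketch of that behavior; rename_keys is a hypothetical helper and only a subset of name_map is shown, so this is not the actual processing_OPT implementation:

from collections import OrderedDict

import torch

# Subset of the key mapping used when converting OPT checkpoints (illustrative).
name_map = {
    'embed_tokens': 'embed.word_embeddings',
    'embed_positions': 'embed.position_embeddings',  # now lives under the embed module
    'self_attn.q_proj': 'attn.query_',
    'self_attn.k_proj': 'attn.key_',
}

def rename_keys(state_dict: OrderedDict) -> OrderedDict:
    # Hypothetical helper: rename checkpoint keys, then tie the head only if it is missing.
    new_dict = OrderedDict()
    for k, v in state_dict.items():
        new_k = k
        for old, new in name_map.items():
            new_k = new_k.replace(old, new)
        new_dict[new_k] = v
    # Tie the LM head to the word embeddings only if the checkpoint carries no head
    # weights of its own (e.g. the 66B fragments already include them).
    if 'head.dense.weight' not in new_dict:
        new_dict['head.dense.weight'] = new_dict['embed.word_embeddings.weight'].clone()
    return new_dict

# Tiny usage example with fake tensors.
ckpt = OrderedDict({'embed_tokens.weight': torch.zeros(4, 2), 'embed_positions.weight': torch.zeros(4, 2)})
print(list(rename_keys(ckpt).keys()))
# ['embed.word_embeddings.weight', 'embed.position_embeddings.weight', 'head.dense.weight']
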
18 changes: 11 additions & 7 deletions examples/opt/opt_config.py
@@ -1,17 +1,20 @@
-from energonai.model import opt_30B, opt_125M
+from energonai.model import opt_30B, opt_125M, opt_66B
from opt_server import launch_engine

# for engine
-model_class = opt_30B
+model_class = opt_66B
model_type = "gpt"
host = "127.0.0.1"
port = 29402
half = True
-# checkpoint = "/data/user/djs_model_checkpoint/opt_metaseq_125m/model/restored.pt"

# If serving using a docker, map your own checkpoint directory to /model_checkpoint
-checkpoint = '/model_checkpoint/'
-# "/data/user/djs_model_checkpoint/opt-30B-singleton/opt_metaseq_30000m/model/restored.pt"
+# checkpoint = '/model_checkpoint/'

+# checkpoint = "/data/user/djs_model_checkpoint/opt_metaseq_125m/model/restored.pt"
+# checkpoint = "/data/user/lclhx/opt-30B"
+checkpoint="/data/user/djs_model_checkpoint/opt-66B-fragment"

backend = "nccl"

# for parallel
@@ -20,8 +23,9 @@

# for server
engine_server = launch_engine
-tokenizer_path = "facebook/opt-350m"
-# server_host = "127.0.0.1"
+# tokenizer_path = "facebook/opt-125m"
+tokenizer_path = "facebook/opt-30b"
Contributor


Why use the 30B tokenizer? Does the 66B tokenizer have a bug?

Contributor Author


Here the config uses the 30B tokenizer by default.
I added the 66B tokenizer as a candidate.

# tokenizer_path = "facebook/opt-66b"
server_host = "0.0.0.0"
server_port = 8020
log_level = "info"
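As the thread above notes, the config keeps facebook/opt-30b as the default tokenizer and lists facebook/opt-66b as a commented-out candidate. A hedged illustration of how a tokenizer_path value like this is typically consumed with Hugging Face transformers; the engine's actual loading code may differ:

from transformers import AutoTokenizer

tokenizer_path = "facebook/opt-30b"  # default in opt_config.py; swap in "facebook/opt-66b" if desired
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
print(tokenizer("Hello, OPT!")["input_ids"])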