This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

add bert example #39

Merged: 1 commit, merged on Apr 27, 2022
17 changes: 9 additions & 8 deletions examples/bert/bert.py
@@ -101,19 +101,20 @@ def forward(self, hidden_states, attention_mask=None, seq_lens=None, batch_size=
new_qkv_shape = attention_output.shape[:-1] + (num_attention_heads, 3*self.attention_head_size)
attention_output = attention_output.view(new_qkv_shape)

- print(f'1: {attention_output.size()}')
+ # print(f'1: {attention_output.size()}')
if seq_lens is not None:
# TODO: use FasterTransformer's implementation.
attention_output = transpose_pad(attention_output, batch_size, max_padding_size, seq_lens, num_attention_heads, self.attention_head_size*3)
else:
attention_output = attention_output.permute(0, 2, 1, 3)
+ # TODO: make sure self.attention_head_size*3 is correct

- print(f'2: {attention_output.size()}')
+ # print(f'2: {attention_output.size()}')

q, k, v = torch.chunk(attention_output, 3, dim = -1)

attention_output = torch.matmul(q, k.transpose(-1, -2))
- print(f'3: {attention_output.size()}')
+ # print(f'3: {attention_output.size()}')
if self.fuse_scale_mask_softmax:
raise NotImplementedError
else:
@@ -124,15 +125,15 @@ def forward(self, hidden_states, attention_mask=None, seq_lens=None, batch_size=

attention_output = torch.matmul(attention_output, v)

- print(f'4: {attention_output.size()}')
+ # print(f'4: {attention_output.size()}')

if seq_lens is not None:
sum_seq = torch.sum(seq_lens)
attention_output = transpose_depad(attention_output, batch_size, sum_seq, max_padding_size, seq_lens, num_attention_heads, self.attention_head_size)
else:
attention_output = attention_output.permute(0, 2, 1, 3).contiguous()

- print(f'5: {attention_output.size()}')
+ # print(f'5: {attention_output.size()}')


new_context_layer_shape = attention_output.size()[:-2] + (all_head_size,)
@@ -258,8 +259,8 @@ def forward(self, hidden_states=None, input_ids=None, attention_mask=None, seq_l
batch_size = input_ids.shape[0]
max_padding_size = input_ids.shape[1]

- print(self.first)
- print(self.last)
+ # print(self.first)
+ # print(self.last)

if self.first:
hidden_states = self.embed(input_ids=input_ids, position_ids=None, tokentype_ids=None, seq_lens=seq_lens, batch_size=batch_size, max_padding_size=max_padding_size) #, seq_lens
Expand All @@ -269,7 +270,7 @@ def forward(self, hidden_states=None, input_ids=None, attention_mask=None, seq_l

if self.last:
hidden_states = hidden_states[:, 1, :]
- print(f'Hidden States: {hidden_states.size()}')
+ # print(f'Hidden States: {hidden_states.size()}')

return hidden_states

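The hunks above mostly comment out debug prints around the fused-QKV attention path: project once to Q, K, V, reshape into heads, chunk, then two matmuls. As a point of reference, here is a minimal self-contained sketch of that pattern; the sizes and names are illustrative assumptions, not the PR's actual classes, and the padding-aware transpose_pad/transpose_depad branch is omitted.

import math
import torch

# Minimal sketch of the fused-QKV attention pattern in the diff above.
# batch/seq/hidden sizes here are made up for illustration.
batch_size, seq_len, hidden_size, num_heads = 2, 8, 64, 4
head_size = hidden_size // num_heads

hidden_states = torch.randn(batch_size, seq_len, hidden_size)
qkv_proj = torch.nn.Linear(hidden_size, 3 * hidden_size)

qkv = qkv_proj(hidden_states)                                  # (B, S, 3*hidden)
qkv = qkv.view(batch_size, seq_len, num_heads, 3 * head_size)  # (B, S, heads, 3*head)
qkv = qkv.permute(0, 2, 1, 3)                                  # (B, heads, S, 3*head)
q, k, v = torch.chunk(qkv, 3, dim=-1)                          # each (B, heads, S, head)

scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(head_size)
probs = torch.softmax(scores, dim=-1)
context = torch.matmul(probs, v)                               # (B, heads, S, head)
context = context.permute(0, 2, 1, 3).contiguous()
context = context.view(batch_size, seq_len, num_heads * head_size)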
15 changes: 15 additions & 0 deletions examples/bert/bert_config.py
@@ -0,0 +1,15 @@
from bert import bert_small, bert_large, bert_xl, bert_8B, bert_175B
from bert_server import launch_engine

model_class = bert_8B
model_type = "bert"
engine_server = launch_engine
tp_init_size = 2
pp_init_size = 2
host = "127.0.0.1"
port = 29400
half = False
server_host = "127.0.0.1"
server_port = 8010
log_level = "info"
backend = "nccl"
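
For orientation: with tp_init_size = 2 and pp_init_size = 2, this config asks for a 2-way tensor-parallel by 2-way pipeline-parallel grid, i.e. four processes for bert_8B. The arithmetic below is a general sketch of how such a TP x PP grid is usually laid out; the exact rank mapping inside energon is an assumption, not read from its source.

# Sketch of how the parallel degrees in bert_config.py compose (assumed layout).
tp_init_size = 2   # tensor-parallel degree
pp_init_size = 2   # pipeline-parallel degree

world_size = tp_init_size * pp_init_size   # 4 ranks / GPUs expected in total

for rank in range(world_size):
    pp_stage = rank // tp_init_size        # which pipeline stage this rank serves
    tp_slice = rank % tp_init_size         # which tensor shard it holds
    print(f"rank {rank}: pipeline stage {pp_stage}, tensor slice {tp_slice}")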
92 changes: 92 additions & 0 deletions examples/bert/bert_server.py
@@ -0,0 +1,92 @@
import os
import torch
import uvicorn
from fastapi import FastAPI
from fastapi import Response
import torch.distributed.rpc as rpc
from energon.engine import InferenceEngine

app = FastAPI() # create the API application object

@app.get("/") # root route
def root():
return {"200"}

@app.get("/model_with_padding")
def run_with_padding():
# for the performance only
seq_len = 512
batch_size = 32

input_ids = torch.randint(1, 10, (batch_size, seq_len), dtype=torch.int64)
attention_mask = torch.randint(0, 1, (batch_size, 1, seq_len, seq_len), dtype=torch.int64)
# seq_lens = torch.randint(1, 128, (batch_size, ), dtype=torch.int64) # generate seq_lens randomly
hidden_states = None
sample = dict(hidden_states=hidden_states, input_ids=input_ids, attention_mask=attention_mask)

output = engine.run(sample)
output = output.to_here()
print(output)
return {"To return the string result."}

@app.get("/model_rm_padding")
def run_rm_padding():
# for the performance only
seq_len = 512
batch_size = 32

input_ids = torch.randint(1, 10, (batch_size, seq_len), dtype=torch.int64)
attention_mask = torch.randint(0, 1, (batch_size, 1, seq_len, seq_len), dtype=torch.int64)
seq_lens = torch.randint(1, 128, (batch_size, ), dtype=torch.int64) # generate seq_lens randomly
hidden_states = None
sample = dict(hidden_states=hidden_states, input_ids=input_ids, attention_mask=attention_mask, seq_lens=seq_lens)

output = engine.run(sample)
output = output.to_here()
print(output)
return {"To return the string result."}


@app.get("/shutdown")
async def shutdown():
engine.clear()
server.should_exit = True
server.force_exit = True
await server.shutdown()


def launch_engine(model_name,
model_type,
max_batch_size: int = 1,
tp_init_size: int = -1,
pp_init_size: int = -1,
host: str = "localhost",
port: int = 29500,
dtype = torch.float,
checkpoint: str = None,
tokenizer_path: str = None,
server_host = "localhost",
server_port = 8005,
log_level = "info"
):

if checkpoint:
model_config = {'dtype': dtype, 'checkpoint': True, 'checkpoint_path': checkpoint}
else:
model_config = {'dtype': dtype}

global engine
engine = InferenceEngine(model_name,
model_config,
model_type,
max_batch_size = max_batch_size,
tp_init_size = tp_init_size,
pp_init_size = pp_init_size,
host = host,
port = port,
dtype = dtype)

global server
config = uvicorn.Config(app, host=server_host, port=server_port, log_level=log_level)
server = uvicorn.Server(config=config)
server.run()
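
Once launch_engine has started the uvicorn server (127.0.0.1:8010 with the values in bert_config.py), the routes above can be exercised over plain HTTP. A small client sketch using the requests library follows; it assumes only the endpoints defined in this file and the example config's host and port.

import requests

# Assumes the server configured by bert_config.py is running on 127.0.0.1:8010.
base = "http://127.0.0.1:8010"

print(requests.get(f"{base}/").json())                    # liveness check
print(requests.get(f"{base}/model_with_padding").json())  # fixed-length batch
print(requests.get(f"{base}/model_rm_padding").json())    # batch with random seq_lens
requests.get(f"{base}/shutdown")                          # stop the engine and server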
2 changes: 1 addition & 1 deletion examples/hf_gpt2/hf_gpt2_server.py
@@ -14,7 +14,7 @@
def root():
return {"200"}

- @app.get("/run_hf_gpt2/{request}")
+ @app.get("/run/{request}")
def run(request: str, max_seq_length: int):

input_token = tokenizer(request, return_tensors="pt")
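This hunk only renames the GPT-2 example's route from /run_hf_gpt2/{request} to /run/{request}; the prompt still travels in the URL path and max_seq_length as a query parameter. A minimal call against the renamed route could look like the sketch below (the host and port are assumptions, since the hf_gpt2 config is not part of this diff).

import requests
from urllib.parse import quote

# Host and port are assumed; adjust to the hf_gpt2 example's actual config.
base = "http://127.0.0.1:8010"
prompt = quote("Hello, my dog is cute")

resp = requests.get(f"{base}/run/{prompt}", params={"max_seq_length": 64})
print(resp.text)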