
Commit

update auto chunk and add performance benchmark (#1)
* add automatic chunk size and update softmax_cuda_kernel

* add performance benchmark for evoformer
Shenggan authored Aug 17, 2022
1 parent 8aad02f commit cb9cb86
Showing 4 changed files with 185 additions and 51 deletions.
172 changes: 172 additions & 0 deletions benchmark/perf.py
@@ -0,0 +1,172 @@
import argparse

import torch
import torch.nn as nn

from xtrimomultimer.model_acc.distributed import init_dap
from xtrimomultimer.model_acc.evoformer import EvoformerBlock as FFEvoformerBlock


def main():

    parser = argparse.ArgumentParser(description='Evoformer Standalone Perf Benchmark')
    parser.add_argument("--dap-size", default=1, type=int, help='dap size')
    parser.add_argument('--msa-length', default=132, type=int, help='Sequence Length of MSA')
    parser.add_argument('--res-length', default=256, type=int, help='Sequence Length of Residues')
    parser.add_argument('--trials', default=50, type=int, help='Number of Trials to Execute')
    parser.add_argument('--warmup-trials', default=5, type=int, help='Warmup Trials to discard')
    parser.add_argument('--layers', default=12, type=int, help='Evoformer Layers to Execute')
    parser.add_argument('--cm', default=256, type=int, help='MSA hidden dimension')
    parser.add_argument('--cz', default=128, type=int, help='Pair hidden dimension')
    parser.add_argument('--heads', default=8, type=int, help='Number of Multihead Attention heads')
    parser.add_argument('--openfold',
                        action='store_true',
                        help='Benchmark with Evoformer Implementation from OpenFold.')
    parser.add_argument('--fwd', action='store_true', help='Only execute Fwd Pass.')
    parser.add_argument('--prof', action='store_true', help='run with profiler.')

    args = parser.parse_args()

    init_dap(args.dap_size)

    precision = torch.bfloat16
    if args.dap_size > 1:
        # (PyTorch issue) Currently All2All communication does not support the Bfloat16 datatype in PyTorch
        precision = torch.float16

    if not torch.cuda.is_available():
        raise NotImplementedError('Running on CPU is not supported')

    torch.manual_seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(42)

    if args.openfold:
        from openfold.model.evoformer import EvoformerBlock

        class OpenFoldEvoformer(nn.Module):

            def __init__(self, d_node, d_pair):
                super(OpenFoldEvoformer, self).__init__()
                self.d_node = d_node
                self.d_pair = d_pair

                self.c_hidden_msa_att = int(d_node / 8)
                self.c_hidden_pair_att = int(d_pair / 8)

                self.EvoformerBlock = EvoformerBlock(c_m=d_node,
                                                     c_z=d_pair,
                                                     c_hidden_msa_att=self.c_hidden_msa_att,
                                                     c_hidden_opm=self.c_hidden_msa_att,
                                                     c_hidden_mul=self.d_pair,
                                                     c_hidden_pair_att=self.c_hidden_pair_att,
                                                     no_heads_msa=8,
                                                     no_heads_pair=4,
                                                     transition_n=4,
                                                     msa_dropout=0.15,
                                                     pair_dropout=0.25,
                                                     inf=1e9,
                                                     eps=1e-10)

            def forward(self, node, pair, node_mask, pair_mask):
                node, pair = self.EvoformerBlock(node, pair, node_mask, pair_mask)
                return node, pair

    attn_layers = []
    for idx in range(0, args.layers):
        if args.openfold:
            attn_layers.append(OpenFoldEvoformer(d_node=args.cm, d_pair=args.cz))
        else:
            attn_layers.append(
                FFEvoformerBlock(c_m=args.cm,
                                 c_z=args.cz,
                                 first_block=(idx == 0),
                                 last_block=(idx == args.layers - 1)))
        attn_layers[idx].cuda()
        attn_layers[idx].to(dtype=precision)

    start_evt_fwd = []
    start_evt_bwd = []
    stop_evt_bwd = []
    for recorded_trial in range(0, args.trials):
        start_evt_fwd.append(torch.cuda.Event(enable_timing=True))
        start_evt_bwd.append(torch.cuda.Event(enable_timing=True))
        stop_evt_bwd.append(torch.cuda.Event(enable_timing=True))

    inputs_node = torch.randn(args.msa_length // args.dap_size,
                              args.res_length,
                              args.cm,
                              dtype=precision,
                              device=torch.device("cuda")).requires_grad_(True)
    inputs_pair = torch.randn(args.res_length // args.dap_size,
                              args.res_length,
                              args.cz,
                              dtype=precision,
                              device=torch.device("cuda")).requires_grad_(True)
    node_mask = torch.ones((args.msa_length, args.res_length),
                           dtype=precision,
                           device=torch.device("cuda")).requires_grad_(False)
    pair_mask = torch.ones((args.res_length, args.res_length),
                           dtype=precision,
                           device=torch.device("cuda")).requires_grad_(False)
    grads_node = torch.randn_like(inputs_pair)

    if args.prof:
        prof = torch.profiler.profile(
            schedule=torch.profiler.schedule(wait=1,
                                             warmup=args.warmup_trials,
                                             active=args.trials,
                                             repeat=1),
            on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/fastfold'),
            profile_memory=False,
            record_shapes=False,
            with_stack=False)
        prof.start()

    for trial in range(0, args.trials + args.warmup_trials):
        layer_inputs = inputs_node, inputs_pair
        evt_idx = trial - args.warmup_trials

        torch.distributed.barrier()
        torch.cuda.synchronize()

        if evt_idx >= 0:
            start_evt_fwd[evt_idx].record()

        for lyr_idx in range(0, args.layers):
            layer_inputs = attn_layers[lyr_idx].forward(*layer_inputs, node_mask, pair_mask)

        torch.cuda.synchronize()

        if evt_idx >= 0:
            start_evt_bwd[evt_idx].record()

        if not args.fwd:
            layer_inputs[1].backward(grads_node)

        if evt_idx >= 0:
            stop_evt_bwd[evt_idx].record()

        if args.prof:
            prof.step()

    if args.prof:
        prof.stop()

    torch.distributed.barrier()
    torch.cuda.synchronize()
    elapsed_time_fwd = 0.0
    elapsed_time_bwd = 0.0
    for evt_idx in range(0, args.trials):
        elapsed_time_fwd += start_evt_fwd[evt_idx].elapsed_time(start_evt_bwd[evt_idx])
        elapsed_time_bwd += start_evt_bwd[evt_idx].elapsed_time(stop_evt_bwd[evt_idx])

    print("[ MSA Attn ] Input: {:4d}, {:4d}, ({:4d} {:4d}) Fwd Time / Layer: {:.3f} ms Bwd Time / Layer: {:.3f} ms".format(
        args.msa_length, args.res_length,
        args.cm, args.cz,
        elapsed_time_fwd / (args.trials * args.layers),
        elapsed_time_bwd / (args.trials * args.layers)))


if __name__ == '__main__':
    main()
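
The benchmark brackets every measured region with torch.distributed.barrier() and torch.cuda.synchronize() and times the passes with CUDA events rather than host-side timers, so queued GPU work is actually finished before a timestamp is read. For reference, a minimal self-contained sketch of that timing pattern (not part of this commit; it assumes only that a CUDA device is available):

```python
import torch

# Record events around the region of interest; elapsed_time() is only
# meaningful after a synchronize, once the GPU has finished the queued work.
start = torch.cuda.Event(enable_timing=True)
stop = torch.cuda.Event(enable_timing=True)

x = torch.randn(4096, 4096, device="cuda")
start.record()
y = x @ x
stop.record()
torch.cuda.synchronize()
print(f"matmul: {start.elapsed_time(stop):.3f} ms")
```
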
14 changes: 12 additions & 2 deletions inference.py
@@ -64,6 +64,16 @@
torch.set_float32_matmul_precision("high")


def get_chunk_size(seq_len):
    if seq_len < 1024:
        chunk_size = None
    elif seq_len < 2048:
        chunk_size = 64
    else:
        chunk_size = 1
    return chunk_size


def precompute_alignments(tags, seqs, alignment_dir, args):
    for tag, seq in zip(tags, seqs):
        tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta")
@@ -365,8 +375,8 @@ def predict_structure(
    os.makedirs(output_dir, exist_ok=True)

    if is_fastfold_optimize:
-        model.globals.chunk_size = 64
-        set_chunk_size(64)
+        model.globals.chunk_size = get_chunk_size(processed_feature_dict["aatype"].shape[0])
+        set_chunk_size(model.globals.chunk_size)

    batch = processed_feature_dict
    with torch.no_grad():
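
The new get_chunk_size heuristic above trades speed for memory: sequences shorter than 1024 residues run unchunked, mid-length ones use chunks of 64, and very long ones fall back to a chunk size of 1 so peak activation memory stays bounded. For intuition, here is a minimal sketch of chunked evaluation along one dimension (a hypothetical helper for illustration, not the repo's set_chunk_size implementation):

```python
import torch

def chunked_apply(fn, x, chunk_size, dim=0):
    """Apply fn to slices of x along dim and concatenate the results.

    chunk_size=None processes the whole tensor at once (fastest); smaller
    chunks lower peak activation memory at the cost of more kernel launches.
    """
    if chunk_size is None:
        return fn(x)
    pieces = [fn(piece) for piece in torch.split(x, chunk_size, dim=dim)]
    return torch.cat(pieces, dim=dim)
```
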
2 changes: 1 addition & 1 deletion xtrimomultimer/model_acc/__init__.py
@@ -2,4 +2,4 @@
from .ops import OutProductMean, set_chunk_size
from .triangle import PairStack

__all__ = ["MSAStack", "OutProductMean", "PairStack", "Evoformer", "set_chunk_size"]
__all__ = ["MSAStack", "OutProductMean", "PairStack", "set_chunk_size"]
@@ -227,22 +227,6 @@ at::Tensor softmax(at::Tensor input, long long rows, long long cols) {
COLS_CASE(14)
COLS_CASE(15)
COLS_CASE(16)
-COLS_CASE(17)
-COLS_CASE(18)
-COLS_CASE(19)
-COLS_CASE(20)
-COLS_CASE(21)
-COLS_CASE(22)
-COLS_CASE(23)
-COLS_CASE(24)
-COLS_CASE(25)
-COLS_CASE(26)
-COLS_CASE(27)
-COLS_CASE(28)
-COLS_CASE(29)
-COLS_CASE(30)
-COLS_CASE(31)
-COLS_CASE(32)
#undef COLS_CASE
else {
int grid_dim;
@@ -476,22 +460,6 @@ at::Tensor fused_scale_mask_softmax_forward(at::Tensor input, at::Tensor mask, l
COLS_CASE(14)
COLS_CASE(15)
COLS_CASE(16)
-COLS_CASE(17)
-COLS_CASE(18)
-COLS_CASE(19)
-COLS_CASE(20)
-COLS_CASE(21)
-COLS_CASE(22)
-COLS_CASE(23)
-COLS_CASE(24)
-COLS_CASE(25)
-COLS_CASE(26)
-COLS_CASE(27)
-COLS_CASE(28)
-COLS_CASE(29)
-COLS_CASE(30)
-COLS_CASE(31)
-COLS_CASE(32)
#undef COLS_CASE
else {
int grid_dim;
@@ -746,22 +714,6 @@ at::Tensor fused_scale_mask_bias_softmax_forward(at::Tensor input, at::Tensor ma
COLS_CASE(14)
COLS_CASE(15)
COLS_CASE(16)
-COLS_CASE(17)
-COLS_CASE(18)
-COLS_CASE(19)
-COLS_CASE(20)
-COLS_CASE(21)
-COLS_CASE(22)
-COLS_CASE(23)
-COLS_CASE(24)
-COLS_CASE(25)
-COLS_CASE(26)
-COLS_CASE(27)
-COLS_CASE(28)
-COLS_CASE(29)
-COLS_CASE(30)
-COLS_CASE(31)
-COLS_CASE(32)
#undef COLS_CASE
else {
int grid_dim;
