From cb9cb8673d2398860140ba61b890041bc47f3d91 Mon Sep 17 00:00:00 2001
From: shenggan
Date: Wed, 17 Aug 2022 14:24:20 +0800
Subject: [PATCH] update auto chunk and add performance benchmark (#1)

* add automatic chunk size and update softmax_cuda_kernel
* add performance benchmark for evoformer
---
 benchmark/perf.py                            | 172 ++++++++++++++++++
 inference.py                                 |  14 +-
 xtrimomultimer/model_acc/__init__.py         |   2 +-
 .../cuda_native/csrc/softmax_cuda_kernel.cu  |  48 -----
 4 files changed, 185 insertions(+), 51 deletions(-)
 create mode 100644 benchmark/perf.py

diff --git a/benchmark/perf.py b/benchmark/perf.py
new file mode 100644
index 0000000..9e11438
--- /dev/null
+++ b/benchmark/perf.py
@@ -0,0 +1,172 @@
+import argparse
+
+import torch
+import torch.nn as nn
+
+from xtrimomultimer.model_acc.distributed import init_dap
+from xtrimomultimer.model_acc.evoformer import EvoformerBlock as FFEvoformerBlock
+
+
+def main():
+
+    parser = argparse.ArgumentParser(description='Evoformer Standalone Perf Benchmark')
+    parser.add_argument("--dap-size", default=1, type=int, help='Dynamic Axial Parallelism (DAP) size')
+    parser.add_argument('--msa-length', default=132, type=int, help='Sequence length of the MSA')
+    parser.add_argument('--res-length', default=256, type=int, help='Number of residues')
+    parser.add_argument('--trials', default=50, type=int, help='Number of timed trials to execute')
+    parser.add_argument('--warmup-trials', default=5, type=int, help='Warmup trials to discard')
+    parser.add_argument('--layers', default=12, type=int, help='Number of Evoformer layers to execute')
+    parser.add_argument('--cm', default=256, type=int, help='MSA hidden dimension')
+    parser.add_argument('--cz', default=128, type=int, help='Pair hidden dimension')
+    parser.add_argument('--heads', default=8, type=int, help='Number of multi-head attention heads')
+    parser.add_argument('--openfold',
+                        action='store_true',
+                        help='Benchmark the Evoformer implementation from OpenFold.')
+    parser.add_argument('--fwd', action='store_true', help='Only execute the forward pass.')
+    parser.add_argument('--prof', action='store_true', help='Run with the PyTorch profiler.')
+
+    args = parser.parse_args()
+
+    init_dap(args.dap_size)
+
+    precision = torch.bfloat16
+    if args.dap_size > 1:
+        # (PyTorch issue) All-to-all communication does not yet support bfloat16, so use float16 with DAP
+        precision = torch.float16
+
+    if not torch.cuda.is_available():
+        raise NotImplementedError('Running on CPU is not supported')
+
+    torch.manual_seed(42)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(42)
+
+    if args.openfold:
+        from openfold.model.evoformer import EvoformerBlock
+
+        class OpenFoldEvoformer(nn.Module):
+
+            def __init__(self, d_node, d_pair):
+                super(OpenFoldEvoformer, self).__init__()
+                self.d_node = d_node
+                self.d_pair = d_pair
+
+                self.c_hidden_msa_att = int(d_node / 8)
+                self.c_hidden_pair_att = int(d_pair / 8)
+
+                self.EvoformerBlock = EvoformerBlock(c_m=d_node,
+                                                     c_z=d_pair,
+                                                     c_hidden_msa_att=self.c_hidden_msa_att,
+                                                     c_hidden_opm=self.c_hidden_msa_att,
+                                                     c_hidden_mul=self.d_pair,
+                                                     c_hidden_pair_att=self.c_hidden_pair_att,
+                                                     no_heads_msa=8,
+                                                     no_heads_pair=4,
+                                                     transition_n=4,
+                                                     msa_dropout=0.15,
+                                                     pair_dropout=0.25,
+                                                     inf=1e9,
+                                                     eps=1e-10)
+
+            def forward(self, node, pair, node_mask, pair_mask):
+                node, pair = self.EvoformerBlock(node, pair, node_mask, pair_mask)
+                return node, pair
+
+    attn_layers = []
+    for idx in range(0, args.layers):
+        if args.openfold:
+            attn_layers.append(OpenFoldEvoformer(d_node=args.cm, d_pair=args.cz))
+        else:
+            attn_layers.append(
+                FFEvoformerBlock(c_m=args.cm,
+                                 c_z=args.cz,
+                                 first_block=(idx == 0),
+                                 last_block=(idx == args.layers - 1)))
+        attn_layers[idx].cuda()
+        attn_layers[idx].to(dtype=precision)
+
+    start_evt_fwd = []
+    start_evt_bwd = []
+    stop_evt_bwd = []
+    for recorded_trial in range(0, args.trials):
+        start_evt_fwd.append(torch.cuda.Event(enable_timing=True))
+        start_evt_bwd.append(torch.cuda.Event(enable_timing=True))
+        stop_evt_bwd.append(torch.cuda.Event(enable_timing=True))
+
+    inputs_node = torch.randn(args.msa_length // args.dap_size,
+                              args.res_length,
+                              args.cm,
+                              dtype=precision,
+                              device=torch.device("cuda")).requires_grad_(True)
+    inputs_pair = torch.randn(args.res_length // args.dap_size,
+                              args.res_length,
+                              args.cz,
+                              dtype=precision,
+                              device=torch.device("cuda")).requires_grad_(True)
+    node_mask = torch.ones((args.msa_length, args.res_length),
+                           dtype=precision,
+                           device=torch.device("cuda")).requires_grad_(False)
+    pair_mask = torch.ones((args.res_length, args.res_length),
+                           dtype=precision,
+                           device=torch.device("cuda")).requires_grad_(False)
+    grads_pair = torch.randn_like(inputs_pair)  # upstream gradient for the pair output
+
+    if args.prof:
+        prof = torch.profiler.profile(
+            schedule=torch.profiler.schedule(wait=1,
+                                             warmup=args.warmup_trials,
+                                             active=args.trials,
+                                             repeat=1),
+            on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/fastfold'),
+            profile_memory=False,
+            record_shapes=False,
+            with_stack=False)
+        prof.start()
+
+    for trial in range(0, args.trials + args.warmup_trials):
+        layer_inputs = inputs_node, inputs_pair
+        evt_idx = trial - args.warmup_trials
+
+        torch.distributed.barrier()
+        torch.cuda.synchronize()
+
+        if evt_idx >= 0:
+            start_evt_fwd[evt_idx].record()
+
+        for lyr_idx in range(0, args.layers):
+            layer_inputs = attn_layers[lyr_idx].forward(*layer_inputs, node_mask, pair_mask)
+
+        torch.cuda.synchronize()
+
+        if evt_idx >= 0:
+            start_evt_bwd[evt_idx].record()
+
+        if not args.fwd:
+            layer_inputs[1].backward(grads_pair)
+
+        if evt_idx >= 0:
+            stop_evt_bwd[evt_idx].record()
+
+        if args.prof:
+            prof.step()
+
+    if args.prof:
+        prof.stop()
+
+    torch.distributed.barrier()
+    torch.cuda.synchronize()
+    elapsed_time_fwd = 0.0
+    elapsed_time_bwd = 0.0
+    for evt_idx in range(0, args.trials):
+        elapsed_time_fwd += start_evt_fwd[evt_idx].elapsed_time(start_evt_bwd[evt_idx])
+        elapsed_time_bwd += start_evt_bwd[evt_idx].elapsed_time(stop_evt_bwd[evt_idx])
+
+    print("[ MSA Attn ] Input: {:4d}, {:4d}, ({:4d} {:4d}) Fwd Time / Layer: {:.3f} ms Bwd Time / Layer: {:.3f} ms".format(
+        args.msa_length, args.res_length,
+        args.cm, args.cz,
+        elapsed_time_fwd / (args.trials * args.layers),
+        elapsed_time_bwd / (args.trials * args.layers)))
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
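perf.py times each trial with paired torch.cuda.Event records rather than host-side timers, because kernel launches are asynchronous and a wall clock would mostly measure Python launch overhead. Below is a minimal standalone sketch of the same measurement pattern; the time_gpu helper and the work callable are illustrative names, not part of the patch:

    import torch

    def time_gpu(work, trials=50, warmup=5):
        # Average GPU time of `work` in ms, using CUDA events as perf.py does.
        starts = [torch.cuda.Event(enable_timing=True) for _ in range(trials)]
        stops = [torch.cuda.Event(enable_timing=True) for _ in range(trials)]
        for trial in range(trials + warmup):
            idx = trial - warmup              # negative during warmup trials
            torch.cuda.synchronize()          # drain queued work before recording
            if idx >= 0:
                starts[idx].record()
            work()
            if idx >= 0:
                stops[idx].record()
        torch.cuda.synchronize()              # ensure all recorded events completed
        # elapsed_time() returns milliseconds between two recorded events
        return sum(s.elapsed_time(e) for s, e in zip(starts, stops)) / trials

    if torch.cuda.is_available():
        a = torch.randn(4096, 4096, device="cuda")
        print(f"matmul: {time_gpu(lambda: a @ a):.3f} ms")

A typical invocation of the benchmark itself, using the flags defined above, would be: python benchmark/perf.py --layers 12 --trials 50 --msa-length 132 --res-length 256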
diff --git a/inference.py b/inference.py
index 23dafba..6150755 100644
--- a/inference.py
+++ b/inference.py
@@ -64,6 +64,16 @@
 torch.set_float32_matmul_precision("high")
 
 
+def get_chunk_size(seq_len):
+    if seq_len < 1024:
+        chunk_size = None
+    elif seq_len < 2048:
+        chunk_size = 64
+    else:
+        chunk_size = 1
+    return chunk_size
+
+
 def precompute_alignments(tags, seqs, alignment_dir, args):
     for tag, seq in zip(tags, seqs):
         tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta")
@@ -365,8 +375,8 @@ def predict_structure(
     os.makedirs(output_dir, exist_ok=True)
 
     if is_fastfold_optimize:
-        model.globals.chunk_size = 64
-        set_chunk_size(64)
+        model.globals.chunk_size = get_chunk_size(processed_feature_dict["aatype"].shape[0])
+        set_chunk_size(model.globals.chunk_size)
 
     batch = processed_feature_dict
     with torch.no_grad():
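The new get_chunk_size heuristic derives the chunk size from the residue count (processed_feature_dict["aatype"].shape[0]): below 1024 residues chunking is disabled entirely (None), below 2048 a moderate chunk of 64 is used, and beyond that it falls back to the most memory-conservative chunk of 1, trading throughput for the ability to fit very long sequences in GPU memory. A quick boundary check; the loop is illustrative, with the thresholds copied from the function above:

    def get_chunk_size(seq_len):
        # Same piecewise heuristic as added to inference.py above.
        if seq_len < 1024:
            return None  # short sequence: run attention un-chunked
        elif seq_len < 2048:
            return 64    # mid-length: moderate chunking
        else:
            return 1     # very long: smallest chunks, lowest peak memory

    for n in (256, 1023, 1024, 2047, 2048, 4096):
        print(n, get_chunk_size(n))
    # -> 256 None, 1023 None, 1024 64, 2047 64, 2048 1, 4096 1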
diff --git a/xtrimomultimer/model_acc/__init__.py b/xtrimomultimer/model_acc/__init__.py
index 1e7ef84..b1dcb2a 100644
--- a/xtrimomultimer/model_acc/__init__.py
+++ b/xtrimomultimer/model_acc/__init__.py
@@ -2,4 +2,4 @@
 from .ops import OutProductMean, set_chunk_size
 from .triangle import PairStack
 
-__all__ = ["MSAStack", "OutProductMean", "PairStack", "Evoformer", "set_chunk_size"]
+__all__ = ["MSAStack", "OutProductMean", "PairStack", "set_chunk_size"]
diff --git a/xtrimomultimer/model_acc/kernel/cuda_native/csrc/softmax_cuda_kernel.cu b/xtrimomultimer/model_acc/kernel/cuda_native/csrc/softmax_cuda_kernel.cu
index 0014459..b4f4314 100644
--- a/xtrimomultimer/model_acc/kernel/cuda_native/csrc/softmax_cuda_kernel.cu
+++ b/xtrimomultimer/model_acc/kernel/cuda_native/csrc/softmax_cuda_kernel.cu
@@ -227,22 +227,6 @@ at::Tensor softmax(at::Tensor input, long long rows, long long cols) {
     COLS_CASE(14)
     COLS_CASE(15)
     COLS_CASE(16)
-    COLS_CASE(17)
-    COLS_CASE(18)
-    COLS_CASE(19)
-    COLS_CASE(20)
-    COLS_CASE(21)
-    COLS_CASE(22)
-    COLS_CASE(23)
-    COLS_CASE(24)
-    COLS_CASE(25)
-    COLS_CASE(26)
-    COLS_CASE(27)
-    COLS_CASE(28)
-    COLS_CASE(29)
-    COLS_CASE(30)
-    COLS_CASE(31)
-    COLS_CASE(32)
 #undef COLS_CASE
     else {
         int grid_dim;
@@ -476,22 +460,6 @@ at::Tensor fused_scale_mask_softmax_forward(at::Tensor input, at::Tensor mask, l
     COLS_CASE(14)
     COLS_CASE(15)
     COLS_CASE(16)
-    COLS_CASE(17)
-    COLS_CASE(18)
-    COLS_CASE(19)
-    COLS_CASE(20)
-    COLS_CASE(21)
-    COLS_CASE(22)
-    COLS_CASE(23)
-    COLS_CASE(24)
-    COLS_CASE(25)
-    COLS_CASE(26)
-    COLS_CASE(27)
-    COLS_CASE(28)
-    COLS_CASE(29)
-    COLS_CASE(30)
-    COLS_CASE(31)
-    COLS_CASE(32)
 #undef COLS_CASE
     else {
         int grid_dim;
@@ -746,22 +714,6 @@ at::Tensor fused_scale_mask_bias_softmax_forward(at::Tensor input, at::Tensor ma
     COLS_CASE(14)
     COLS_CASE(15)
     COLS_CASE(16)
-    COLS_CASE(17)
-    COLS_CASE(18)
-    COLS_CASE(19)
-    COLS_CASE(20)
-    COLS_CASE(21)
-    COLS_CASE(22)
-    COLS_CASE(23)
-    COLS_CASE(24)
-    COLS_CASE(25)
-    COLS_CASE(26)
-    COLS_CASE(27)
-    COLS_CASE(28)
-    COLS_CASE(29)
-    COLS_CASE(30)
-    COLS_CASE(31)
-    COLS_CASE(32)
 #undef COLS_CASE
     else {
         int grid_dim;
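All three trimmed dispatch tables cap the specialized softmax kernels at COLS_CASE(16); inputs that previously matched cases 17 through 32 now fall through to the general-purpose kernel in the else branch (the int grid_dim path), which presumably trades some template specialization for a smaller binary and faster builds without changing results. For reference, here is a plain PyTorch sketch of what fused_scale_mask_bias_softmax_forward computes; the argument names and the scale-mask-bias ordering are assumptions, since the exact CUDA signatures are truncated in the hunk headers above:

    import torch

    def fused_scale_mask_bias_softmax_ref(x, mask, bias, scale):
        # Eager reference: scale the logits, suppress masked positions with a
        # large negative value (the Evoformer config above uses inf=1e9),
        # add the pair bias, then softmax over the last dimension.
        x = x * scale
        x = x.masked_fill(mask == 0, -1e9)
        x = x + bias
        return torch.softmax(x, dim=-1)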