-
Notifications
You must be signed in to change notification settings - Fork 12
/
inspect_sample.py
82 lines (66 loc) · 2.81 KB
/
inspect_sample.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import os
import sys
import argparse
import pickle
import torch
from pathlib import Path
from typing import Any, Dict
from alphafold.Model.features import AlphaFoldFeatures
from alphafold.Model.alphafold import AlphaFold
from alphafold.Common import protein
from custom_config import model_config
import deepspeed
def _parse_cli():
    """Parse the command-line options of this sample-inspection script.

    Returns:
        The ``argparse.Namespace`` holding ``dataset_dir``, ``sample_name``,
        ``log_dir``, ``model_name``, ``precision`` and
        ``deepspeed_config_path``.
    """
    cli = argparse.ArgumentParser(description='Train deep protein docking')
    # (flag, default, extra add_argument keywords), registered in this order.
    # Alternative defaults tried previously: sample '2xon_1_a_features.pkl',
    # log_dir None, model 'model_small', precision 'bf16'.
    options = (
        ('-dataset_dir', '/gpfs/gpfs0/g.derevyanko/OpenFold2Dataset/Features', {'type': str}),
        ('-sample_name', '4mxn_1_a_features.pkl', {'type': str}),
        ('-log_dir', 'LogTrain', {'type': str}),
        ('-model_name', 'model_tiny', {'type': str}),
        ('-precision', 'fp16', {}),
        ('-deepspeed_config_path', 'deepspeed_config.json', {'type': str}),
    )
    for flag, default, extra in options:
        cli.add_argument(flag, default=default, **extra)
    return cli.parse_args()


if __name__ == '__main__':
    # Load one pickled feature sample, run a single forward pass of the model
    # through a DeepSpeed engine and report where NaNs show up.
    args = _parse_cli()
    sample_path = Path(args.dataset_dir) / Path(args.sample_name)

    # Reduced-precision dtype selected by -precision; any other value leaves
    # the feature object's dtype untouched and keeps float32 in the batch.
    reduced_dtype = {
        'bf16': torch.bfloat16,
        'fp16': torch.float16,
    }.get(args.precision)

    config = model_config(args.model_name)
    feature_pipeline = AlphaFoldFeatures(config=config, device=None, is_training=True)
    feature_pipeline.device = 'cuda'
    if reduced_dtype is not None:
        feature_pipeline.dtype = reduced_dtype

    model = AlphaFold(
        config=config.model,
        target_dim=22,
        msa_dim=49,
        extra_msa_dim=25,
    ).to(device='cuda')

    # Describe a one-process "cluster" through the environment before the
    # distributed backend is initialised (MPI auto-discovery is switched off).
    os.environ.update({
        'RANK': '0',
        'LOCAL_RANK': '0',
        'WORLD_SIZE': '1',
        'MASTER_ADDR': '127.0.0.1',
        'MASTER_PORT': '6000',
    })
    deepspeed.init_distributed(auto_mpi_discovery=False)
    engine, _optimizer, _, _ = deepspeed.initialize(
        model=model,
        model_parameters=model.parameters(),
        config=args.deepspeed_config_path,
        dist_init_required=True,
    )

    # NOTE: pickle.load can execute arbitrary code; only point this script at
    # feature files from a trusted source.
    with open(sample_path, 'rb') as handle:
        raw_features = pickle.load(handle)

    # float32 entries go to the selected precision (or stay float32), float64
    # entries always go to float32 -- presumably a source-dtype -> target-dtype
    # map; confirm against AlphaFoldFeatures.convert.
    float32_target = torch.float32 if reduced_dtype is None else reduced_dtype
    batch = feature_pipeline.convert(
        feature_pipeline(raw_features),
        dtypes={torch.float32: float32_target, torch.float64: torch.float32},
    )

    # NaN report for every input feature.
    for name in batch.keys():
        print(name, torch.isnan(batch[name]).any())

    output, loss = engine(batch)
    print('Loss:', loss)

    # NaN report for every model output.
    for name in output.keys():
        print(name, torch.isnan(output[name]).any())