# NOTE(review): removed GitHub web-UI scrape residue (page chrome, fork/notification
# banner, and the 1-58 line-number gutter) that was pasted in above the source and
# was not part of the original test.py file.
import os
import onnxruntime as ort
import torch
from evaluation.eval_wrapper import eval_lane
from utils.common import CallableSession
from utils.common import get_model
from utils.common import merge_config
if __name__ == "__main__":
torch.backends.cudnn.benchmark = True
args, cfg = merge_config()
distributed = False
if "WORLD_SIZE" in os.environ:
distributed = int(os.environ["WORLD_SIZE"]) > 1
cfg.distributed = distributed
if distributed:
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend="nccl", init_method="env://")
if cfg.onnx_path:
providers = ["CUDAExecutionProvider"]
provider_options = [{}]
ort_session_options = ort.SessionOptions()
session = ort.InferenceSession(
cfg.onnx_path,
providers=providers,
provider_options=provider_options,
sess_options=ort_session_options,
)
net = CallableSession(session)
else:
net = get_model(cfg)
if cfg.model_ckpt:
net = torch.load(cfg.model_ckpt, map_location="cpu")["model_ckpt"].cuda()
else:
state_dict = torch.load(cfg.test_model, map_location="cpu")["model"].cuda()
compatible_state_dict = {}
for k, v in state_dict.items():
if "module." in k:
compatible_state_dict[k[7:]] = v
else:
compatible_state_dict[k] = v
net.load_state_dict(compatible_state_dict, strict=True)
if distributed:
net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[args.local_rank])
if not os.path.exists(cfg.test_work_dir):
os.mkdir(cfg.test_work_dir)
eval_lane(net, cfg)