-
Notifications
You must be signed in to change notification settings - Fork 5
/
submit_kitti.py
178 lines (133 loc) · 5.82 KB
/
submit_kitti.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
from __future__ import absolute_import, division, print_function
from options import Options
options = Options()
opts = options.parse()
from trainer import *
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
from torch.nn.parameter import Parameter
from dataloaders.utils import get_dataset, custom_collate
from metrics import Evaluator, TimeAverageMeter
import utils
import random
import numpy as np
import os
import network
from dataloaders.datasets import Cityscapes, CityLostFound
from dataloaders import custom_transforms as sw
import skimage.io
from tqdm import tqdm
from matplotlib import pyplot as plt
from PIL import Image
def check_path(path):
    """Create directory *path* (and any missing parents) if it does not exist.

    The previous implementation guarded ``os.makedirs`` with an
    ``os.path.exists`` check, which is race-prone when several worker
    processes create the same output directory concurrently (another
    process can create it between the check and the call).
    ``exist_ok=True`` already makes the call idempotent, so the
    look-before-you-leap check is dropped.

    Args:
        path: Directory path to ensure exists. May be nested.
    """
    os.makedirs(path, exist_ok=True)
# for submit to KITTI 2015 test benchmark
if __name__ == '__main__':
    # Restrict visible devices BEFORE any CUDA context is created, then pick
    # the torch device (falls back to CPU when CUDA is unavailable).
    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda:{}'.format(opts.gpu_id) if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # NOTE(review): len() of the gpu_id *string* counts characters, so an id
    # like "0,1" yields 3, not 2 — confirm the expected gpu_id format in Options.
    n_gpus = len(opts.gpu_id)
    print("Number of used GPU : {}".format(n_gpus))
    print("Used GPU ID : {}".format(opts.gpu_id))

    # Setup random seed
    # Seed all three RNG sources for reproducibility; cudnn.benchmark trades
    # determinism for speed by autotuning conv kernels (input sizes are fixed here).
    torch.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)
    torch.backends.cudnn.benchmark = True

    # Dataset root becomes <data_root>/<dataset>; 19 classes matches the
    # Cityscapes label set used for the segmentation head.
    opts.data_root = os.path.join(opts.data_root, opts.dataset)
    opts.num_classes = 19

    # Inverse of ImageNet normalization, for visualization.
    # NOTE(review): declared but never used below — possibly leftover from a
    # visualization path; harmless.
    denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225])

    # Spatial sizes as (width, height); *_feats sizes are 1/4 resolution,
    # matching the network's feature stride.
    # NOTE(review): target_size* variables are not referenced later in this
    # script — padding below uses opts.val_img_height/width instead.
    random_crop_size = (896, 256)
    target_size_crops = random_crop_size
    target_size_crops_feats = (random_crop_size[0] // 4, random_crop_size[1] // 4)

    target_size = (1280, 384)
    target_size_feats = (1280 // 4, 384 // 4)

    # Test-time transform: tensor conversion only, no augmentation.
    test_transform = sw.Compose(
        [
            sw.Tensor(),
        ]
    )

    # NOTE(review): the Cityscapes dataset class is used even for the KITTI
    # submission — presumably it handles KITTI via dataset_name/opts; confirm.
    test_dst = Cityscapes(root=opts.data_root, dataset_name=opts.dataset,
                          mode='test', transform=test_transform, opts=opts)

    # Deterministic order (shuffle=False) so output files map 1:1 to inputs.
    test_loader = data.DataLoader(
        test_dst, batch_size=opts.val_batch_size, shuffle=False, num_workers=4,
        pin_memory=True, drop_last=False,
        collate_fn=custom_collate)

    # Joint disparity + segmentation network; returns a disparity pyramid and
    # left-image segmentation logits (see forward call below).
    model = network.RODSNet(opts,
                            opts.max_disp,
                            num_classes=opts.num_classes,
                            device=device,
                            refinement_type=opts.refinement_type,
                            )
    model.to(device)

    # NOTE(review): evaluator is constructed but never used in this script
    # (test split has no labels); harmless.
    evaluator = Evaluator(opts.num_classes)

    if opts.resume is not None:
        if not os.path.isfile(opts.resume):
            raise RuntimeError("=> no checkpoint found at '{}'".format(opts.resume))

        # https://github.com/VainF/DeepLabV3Plus-Pytorch/issues/8#issuecomment-605601402, @PytaichukBohdan
        # Partial checkpoint loading: keep only keys present in the current
        # model so architecture variants can still reuse compatible weights.
        checkpoint = torch.load(opts.resume, map_location=device)
        loaded_pt = checkpoint['model_state']
        model_dict = model.state_dict()
        # 1. filter out unnecessary keys
        pretrained_dict = {k: v for k, v in loaded_pt.items() if k in model_dict}
        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrained_dict)
        # 3. load the new state dict
        model.load_state_dict(model_dict, strict=False)
    else:
        # No checkpoint: the script proceeds with random weights (outputs
        # would be meaningless for an actual benchmark submission).
        print("[!] No checkpoints found, Retrain...")

    num_params = utils.count_parameters(model)
    print('=> Number of trainable parameters: %d' % num_params)

    # Inference
    model.eval()

    inference_time = 0
    # NOTE(review): len(test_loader) is the number of *batches*; with
    # val_batch_size > 1 the final mean-time print is per-batch, not per-image.
    num_samples = len(test_loader)
    print('=> %d samples found in the test set' % num_samples)

    for i, sample in enumerate(test_loader):
        left = sample['left'].to(device, dtype=torch.float32)
        right = sample['right'].to(device, dtype=torch.float32)

        # Pad
        # Pad each image up to the fixed validation size expected by the
        # network: zeros added on the TOP and RIGHT only, so the original
        # content sits in the bottom-left corner (undone by the crop below).
        ori_height, ori_width = left.size()[2:]
        if ori_height < opts.val_img_height or ori_width < opts.val_img_width:
            top_pad = opts.val_img_height - ori_height
            right_pad = opts.val_img_width - ori_width

            # Pad size: (left_pad, right_pad, top_pad, bottom_pad)
            left = F.pad(left, (0, right_pad, top_pad, 0))
            right = F.pad(right, (0, right_pad, top_pad, 0))

        # warming up
        # Ten throwaway forward passes on the first batch so cudnn autotuning
        # and lazy CUDA initialization do not pollute the timing below.
        # NOTE(review): this runs OUTSIDE torch.no_grad(), so each pass builds
        # an autograd graph; graphs are freed when the outputs are rebound,
        # but wrapping in no_grad() would reduce peak memory — confirm.
        if i==0:
            left_temp = left.clone()
            right_temp = right.clone()
            for j in range(10):
                pred_disp_pyramid, left_seg = model(left_temp, right_temp)

        with torch.no_grad():
            # NOTE(review): `time` is not imported in this file's visible
            # imports — presumably provided via `from trainer import *`; also
            # no torch.cuda.synchronize() here, so on GPU this may time kernel
            # launch rather than completion — verify.
            time_start = time.time()
            pred_disp_pyramid, left_seg = model(left, right)
            model_time = time.time() - time_start

        # Finest/last level of the disparity pyramid is the final prediction.
        pred_disp = pred_disp_pyramid[-1]
        inference_time += model_time
        print('=> Inferencing %d/%d, time:%.3f' % (i, num_samples, model_time))

        image = left[0].detach().cpu().numpy()
        right_image = right[0].detach().cpu().numpy()

        # Crop
        # Undo the top/right padding applied above (same condition, and
        # top_pad/right_pad are still in scope from the padding branch).
        # A `[:-0]` slice would empty the tensor, hence the right_pad != 0 split.
        if ori_height < opts.val_img_height or ori_width < opts.val_img_width:
            if right_pad != 0:
                pred_disp = pred_disp[:, top_pad:, :-right_pad]
                image = image[:, top_pad:, :-right_pad]
                right_image = right_image[:, top_pad:, :-right_pad]
            else:
                pred_disp = pred_disp[:, top_pad:]
                image = image[:, top_pad:]
                right_image = right_image[:, top_pad:]

        # Save one disparity map per image in the batch, mirroring the input
        # file layout but with image_2 -> disp_0 (KITTI submission layout).
        for b in range(pred_disp.size(0)):
            disp = pred_disp[b].detach().cpu().numpy()  # [H, W]
            save_name = sample['left_name'][b]
            save_name = save_name.replace('image_2', 'disp_0')
            save_name_disp = os.path.join(opts.output_dir, save_name)
            check_path(os.path.dirname(save_name_disp))
            # Disparity encoded as uint16 at 1/256 px precision (disp * 256),
            # the format the KITTI stereo benchmark expects.
            skimage.io.imsave(save_name_disp, (disp * 256.).astype(np.uint16))

    print("mean time of our models:%.3f" % (inference_time/num_samples))