-
Notifications
You must be signed in to change notification settings - Fork 4
/
simtrack.py
111 lines (101 loc) · 5.1 KB
/
simtrack.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
from lib.test.tracker.basetracker import BaseTracker
import torch
from lib.train.data.processing_utils import sample_target
# for debug
import cv2
import os
from lib.utils.merge import merge_template_search
from lib.models.stark import build_simtrack
from lib.test.tracker.stark_utils import Preprocessor
from lib.utils.box_ops import clip_box
# import warning
# warnings.filterwarnings('ignore')
class SimTrack(BaseTracker):
    """Inference-time wrapper for the SimTrack tracker.

    The template is embedded once in ``initialize`` from the first frame's
    ground-truth box; every subsequent frame is handled by ``track``, which
    crops a search region around the previous state, runs the backbone and
    box head, and maps the predicted box back to full-image coordinates.
    """

    def __init__(self, params, dataset_name):
        super(SimTrack, self).__init__(params)
        network = build_simtrack(params.cfg)
        # BUG FIX: the previous code loaded weights from the literal string
        # 'please add the model path here', which raised at startup. Load the
        # checkpoint path configured on params instead; training checkpoints
        # store the model weights under the 'net' key.
        network.load_state_dict(torch.load(self.params.checkpoint, map_location='cpu')['net'], strict=True)
        self.cfg = params.cfg
        self.network = network.cuda()
        self.network.eval()
        self.preprocessor = Preprocessor()
        # Current target box [x, y, w, h] in full-image coordinates.
        self.state = None
        # for debug: when True, each tracked frame is saved with the predicted
        # box drawn on it (see track()).
        self.debug = False
        self.frame_id = 0
        if self.debug:
            self.save_dir = "debug"
            if not os.path.exists(self.save_dir):
                os.makedirs(self.save_dir)
        # whether to also report the boxes from all queries, not just the mean
        self.save_all_boxes = params.save_all_boxes
        self.z_dict1 = {}

    def initialize(self, image, info: dict):
        """Embed the template from the first frame.

        Args:
            image: first frame (H, W, 3) array.
            info: dict containing 'init_bbox' as [x, y, w, h].

        Returns:
            dict with 'all_boxes' when ``save_all_boxes`` is set, else None.
        """
        # forward the template once
        z_patch_arr, rz_factor, z_amask_arr = sample_target(image, info['init_bbox'], self.params.template_factor,
                                                            output_sz=self.params.template_size)
        template = self.preprocessor.process(z_patch_arr, z_amask_arr)
        bbox = info['init_bbox']
        # Target size rescaled to template-crop resolution.
        bbox_sz = torch.tensor(bbox[2:]) * rz_factor
        # Center the (rescaled) target box inside the square template crop,
        # expressed as [x, y, w, h] in template-crop pixels.
        template_anno = torch.tensor(
            [int(self.params.template_size / 2 - bbox_sz[0] / 2), int(self.params.template_size / 2 - bbox_sz[1] / 2),
             bbox_sz[0], bbox_sz[1]]).cuda()
        # Annotation is normalized to [0, 1] by the template size.
        self.init_input = {'template': template, 'template_anno': template_anno / self.params.template_size}
        # save states
        self.state = info['init_bbox']
        self.frame_id = 0
        if self.save_all_boxes:
            '''save all predicted boxes'''
            all_boxes_save = info['init_bbox'] * self.cfg.MODEL.NUM_OBJECT_QUERIES
            return {"all_boxes": all_boxes_save}

    def track(self, image, info: dict = None):
        """Track the target in one new frame.

        Args:
            image: current frame (H, W, 3) array.
            info: unused; kept for interface compatibility.

        Returns:
            dict with 'target_bbox' ([x, y, w, h] clipped to the image) and,
            when ``save_all_boxes`` is set, 'all_boxes' (flat list of 4N values).
        """
        H, W, _ = image.shape
        self.frame_id += 1
        x_patch_arr, resize_factor, x_amask_arr = sample_target(image, self.state, self.params.search_factor,
                                                                output_sz=self.params.search_size)  # (x1, y1, w, h)
        search = self.preprocessor.process(x_patch_arr, x_amask_arr)
        with torch.no_grad():
            x_dict = self.network.forward_backbone([self.init_input['template'], search, self.init_input['template_anno']])
            # run the head
            out_dict, _, _ = self.network.forward_head(seq_dict=[x_dict], run_box_head=True)
        pred_boxes = out_dict['pred_boxes'].view(-1, 4)
        # Baseline: take the mean of all predicted boxes as the final result,
        # rescaled from the normalized search crop back to image pixels.
        pred_box = (pred_boxes.mean(dim=0) * self.params.search_size / resize_factor).tolist()  # (cx, cy, w, h)
        # get the final box result, clipped to stay inside the frame
        self.state = clip_box(self.map_box_back(pred_box, resize_factor), H, W, margin=10)
        # for debug: dump the frame with the predicted box drawn on it
        if self.debug:
            x1, y1, w, h = self.state
            image_BGR = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
            cv2.rectangle(image_BGR, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color=(0, 0, 255), thickness=2)
            save_path = os.path.join(self.save_dir, "%04d.jpg" % self.frame_id)
            cv2.imwrite(save_path, image_BGR)
        if self.save_all_boxes:
            '''save all predictions'''
            all_boxes = self.map_box_back_batch(pred_boxes * self.params.search_size / resize_factor, resize_factor)
            all_boxes_save = all_boxes.view(-1).tolist()  # (4N, )
            return {"target_bbox": self.state,
                    "all_boxes": all_boxes_save}
        else:
            return {"target_bbox": self.state}

    def map_box_back(self, pred_box: list, resize_factor: float):
        """Map a (cx, cy, w, h) box from search-crop pixels to image coords.

        The search crop is centered on the previous state's center; the crop's
        top-left corner in image coordinates is (center - half_side).
        """
        cx_prev, cy_prev = self.state[0] + 0.5 * self.state[2], self.state[1] + 0.5 * self.state[3]
        cx, cy, w, h = pred_box
        half_side = 0.5 * self.params.search_size / resize_factor
        cx_real = cx + (cx_prev - half_side)
        cy_real = cy + (cy_prev - half_side)
        return [cx_real - 0.5 * w, cy_real - 0.5 * h, w, h]

    def map_box_back_batch(self, pred_box: torch.Tensor, resize_factor: float):
        """Batched version of :meth:`map_box_back` for an (N, 4) tensor."""
        cx_prev, cy_prev = self.state[0] + 0.5 * self.state[2], self.state[1] + 0.5 * self.state[3]
        cx, cy, w, h = pred_box.unbind(-1)  # (N, 4) --> (N,)
        half_side = 0.5 * self.params.search_size / resize_factor
        cx_real = cx + (cx_prev - half_side)
        cy_real = cy + (cy_prev - half_side)
        return torch.stack([cx_real - 0.5 * w, cy_real - 0.5 * h, w, h], dim=-1)
def get_tracker_class():
    """Factory hook used by the test harness to obtain the tracker class."""
    tracker_cls = SimTrack
    return tracker_cls