-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.py
116 lines (96 loc) · 4.37 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# for RA-L
import argparse
import logging
import os
import torch
import torch.distributed as dist
from ssd.engine.inference import do_evaluation
from ssd.config import cfg
from ssd.data.build import make_data_loader
from ssd.engine.trainer import do_train
from ssd.modeling.detector import build_detection_model
from ssd.solver.build import make_optimizer, make_lr_scheduler
from ssd.utils import dist_util, mkdir
from ssd.utils.checkpoint import CheckPointer
from ssd.utils.dist_util import synchronize
from ssd.utils.logger import setup_logger
from ssd.utils.misc import str2bool
def train(cfg, args):
    """Build the SSD detector plus its training machinery, then run training.

    The GPU count (``args.num_gpus``) rescales the schedule: the base learning
    rate is multiplied by it, while the LR-decay steps and the iteration
    budget are integer-divided by it.

    Args:
        cfg: project config node. Reads ``MODEL.DEVICE``, ``SOLVER.LR``,
            ``SOLVER.LR_STEPS``, ``SOLVER.MAX_ITER`` and ``OUTPUT_DIR``.
        args: parsed command-line namespace. Must carry ``distributed``,
            ``num_gpus`` and ``local_rank``; it is also forwarded unchanged to
            ``do_train``.

    Returns:
        Whatever ``do_train`` returns (used by the caller as the trained model).
    """
    logger = logging.getLogger('SSD.trainer')

    # Instantiate the detector and move it to the device named in the config.
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    if args.distributed:
        # Wrap for multi-process training; each process drives the GPU given
        # by its local rank.
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
        )

    # Scale the base LR up by the GPU count, and shrink the LR-schedule
    # milestones (cfg.SOLVER.LR_STEPS) by the same factor.
    optimizer = make_optimizer(cfg, model, cfg.SOLVER.LR * args.num_gpus)
    scheduler = make_lr_scheduler(
        cfg,
        optimizer,
        [step // args.num_gpus for step in cfg.SOLVER.LR_STEPS],
    )

    # Training state starts at iteration 0; whatever checkpointer.load()
    # returns is merged on top (presumably restoring "iteration" when
    # resuming; the author's note says loading is governed by
    # cfg.SOLVER.CHECKPOINTS, so verify in CheckPointer).
    arguments = {"iteration": 0}
    checkpointer = CheckPointer(
        model,
        optimizer,
        scheduler,
        cfg.OUTPUT_DIR,
        dist_util.get_rank() == 0,  # only rank 0 is allowed to save to disk
        logger,
    )
    arguments.update(checkpointer.load())

    train_loader = make_data_loader(
        cfg,
        is_train=True,
        distributed=args.distributed,
        max_iter=cfg.SOLVER.MAX_ITER // args.num_gpus,
        start_iter=arguments['iteration'],
    )

    return do_train(
        cfg, model, train_loader, optimizer, scheduler,
        checkpointer, device, arguments, args,
    )
def _parse_args():
    """Build the training command-line parser and return the parsed namespace."""
    parser = argparse.ArgumentParser(description='Single Shot MultiBox Detector Training With PyTorch')
    parser.add_argument(
        "--config-file",
        default="/home/ubuntu/Code/SSD/configs/vgg_ssd300_voc0712.yaml",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    # The local rank reaches the script in different ways depending on the
    # launcher:
    #   * torch.distributed.launch before PyTorch 2.0 passes --local_rank;
    #   * torch.distributed.launch from PyTorch 2.0 passes --local-rank;
    #   * torchrun passes no flag and only sets the LOCAL_RANK env var.
    # Accept all three. The dest is still "local_rank" because argparse
    # derives it from the first long option.
    parser.add_argument(
        "--local_rank",
        "--local-rank",
        type=int,
        default=int(os.environ.get("LOCAL_RANK", 0)),
    )
    parser.add_argument('--log_step', default=10, type=int, help='Print logs every log_step')
    parser.add_argument('--save_step', default=2500, type=int, help='Save checkpoint every save_step')
    parser.add_argument('--eval_step', default=2500, type=int, help='Evaluate dataset every eval_step, disabled when eval_step < 0')
    parser.add_argument('--use_tensorboard', default=True, type=str2bool)
    parser.add_argument(
        "--skip-test",
        dest="skip_test",
        help="Do not test the final model",
        action="store_true",
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    return parser.parse_args()


def _setup_devices(args):
    """Enable the cuDNN autotuner and, for multi-process runs, join the NCCL group."""
    if torch.cuda.is_available():
        # This flag allows you to enable the inbuilt cudnn auto-tuner to
        # find the best algorithm to use for your hardware.
        torch.backends.cudnn.benchmark = True
    if args.distributed:
        # Bind this process to its GPU before creating the process group.
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
        synchronize()


def main():
    """Entry point: parse args, set up devices/config/logging, train, then evaluate.

    The world size is read from the ``WORLD_SIZE`` environment variable
    (default 1). A value above 1 switches on distributed training.
    Evaluation after training is skipped when ``--skip-test`` is given.
    """
    args = _parse_args()
    num_gpus = int(os.environ.get("WORLD_SIZE", 1))
    args.distributed = num_gpus > 1
    args.num_gpus = num_gpus
    _setup_devices(args)

    # Layer the YAML file, then command-line overrides, onto the global config.
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    if cfg.OUTPUT_DIR:
        mkdir(cfg.OUTPUT_DIR)

    logger = setup_logger("SSD", dist_util.get_rank(), cfg.OUTPUT_DIR)
    logger.info("Using {} GPUs".format(num_gpus))
    logger.info(args)

    logger.info("Loaded configuration file {}".format(args.config_file))
    # Explicit encoding so the config echo does not depend on the system locale.
    with open(args.config_file, "r", encoding="utf-8") as cf:
        config_str = "\n" + cf.read()
    logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    model = train(cfg, args)

    if not args.skip_test:
        logger.info('Start evaluating...')
        torch.cuda.empty_cache()  # speed up evaluating after training finished
        do_evaluation(cfg, model, distributed=args.distributed)
if __name__ == "__main__":
    # Start training only when run as a script, not when imported.
    main()