1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121
|
"""Train a Fast R-CNN network on a region of interest database."""
import _init_paths from fast_rcnn.train import get_training_roidb, train_net from fast_rcnn.config import cfg, cfg_from_file, cfg_from_list, get_output_dir from datasets.factory import get_imdb from networks.factory import get_network import argparse import pprint import numpy as np import sys import pdb
def parse_args(argv=None):
    """Parse the command-line arguments for Fast R-CNN training.

    Args:
        argv: Optional list of argument strings to parse, excluding the
            program name. Defaults to ``sys.argv[1:]``, so calling
            ``parse_args()`` behaves exactly as before.

    Returns:
        argparse.Namespace with the attributes ``device``, ``device_id``,
        ``solver``, ``max_iters``, ``pretrained_model``, ``cfg_file``,
        ``imdb_name``, ``randomize``, ``network_name`` and ``set_cfgs``.

    Raises:
        SystemExit: with status 1, after printing the help text, when no
            arguments are supplied. argparse itself also exits on invalid
            arguments.
    """
    parser = argparse.ArgumentParser(description='Train a Fast R-CNN network')
    parser.add_argument('--device', dest='device', help='device to use',
                        default='cpu', type=str)
    parser.add_argument('--device_id', dest='device_id',
                        help='device id to use', default=0, type=int)
    parser.add_argument('--solver', dest='solver', help='solver prototxt',
                        default=None, type=str)
    parser.add_argument('--iters', dest='max_iters',
                        help='number of iterations to train',
                        default=70000, type=int)
    parser.add_argument('--weights', dest='pretrained_model',
                        help='initialize with pretrained model weights',
                        default=None, type=str)
    parser.add_argument('--cfg', dest='cfg_file', help='optional config file',
                        default=None, type=str)
    parser.add_argument('--imdb', dest='imdb_name', help='dataset to train on',
                        default='kitti_train', type=str)
    parser.add_argument('--rand', dest='randomize',
                        help='randomize (do not use a fixed seed)',
                        action='store_true')
    parser.add_argument('--network', dest='network_name',
                        help='name of the network', default=None, type=str)
    # REMAINDER: every token after --set is collected as config key/value pairs.
    parser.add_argument('--set', dest='set_cfgs', help='set config keys',
                        default=None, nargs=argparse.REMAINDER)

    if argv is None:
        argv = sys.argv[1:]

    # Running with no arguments at all is treated as a request for usage help.
    if len(argv) == 0:
        parser.print_help()
        sys.exit(1)

    return parser.parse_args(argv)
if __name__ == '__main__':
    args = parse_args()

    print('Called with args:')
    print(args)

    # Apply config overrides: the config file first, then the --set key/value
    # pairs. Because --set is applied later, those keys presumably win over
    # the file.
    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    print('Using config:')
    pprint.pprint(cfg)

    # Seed numpy's global RNG from the config unless --rand was requested.
    if not args.randomize:
        np.random.seed(cfg.RNG_SEED)

    imdb = get_imdb(args.imdb_name)
    print('Loaded dataset `{:s}` for training'.format(imdb.name))
    roidb = get_training_roidb(imdb)

    output_dir = get_output_dir(imdb, None)
    print('Output will be saved to `{:s}`'.format(output_dir))

    # NOTE(review): device_name is only printed here; it is not passed to
    # get_network or train_net, so within this script --device/--device_id
    # do not influence where training runs. TODO confirm intended placement.
    device_name = '/{}:{:d}'.format(args.device, args.device_id)
    print(device_name)

    network = get_network(args.network_name)
    print('Use network `{:s}` in training'.format(args.network_name))

    train_net(network, imdb, roidb, output_dir,
              pretrained_model=args.pretrained_model,
              max_iters=args.max_iters)
|