Faster R-CNN源码阅读之十:Faster R-CNN/lib/fast_rcnn/train.py

Faster R-CNN源码阅读之十:Faster R-CNN/lib/fast_rcnn/train.py

一、介绍

   本文件代码由Faster R-CNN官方提供,我只是在官方代码的基础上增加了注释,一方面方便自己学习,另一方面贴出来和大家一起交流。
   该文件中函数的主要作用是训练整个Faster R-CNN网络(TensorFlow版本)。

二、代码以及注释
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
# coding=utf-8
# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------

"""Train a Fast R-CNN network."""

from fast_rcnn.config import cfg
import gt_data_layer.roidb as gdl_roidb
import roi_data_layer.roidb as rdl_roidb
from roi_data_layer.layer import RoIDataLayer
from utils.timer import Timer
import numpy as np
import os
import tensorflow as tf
import sys
from tensorflow.python.client import timeline
import time


class SolverWrapper(object):
    """A simple wrapper around the training loop (originally Caffe's solver).

    This wrapper gives us control over the snapshot process, which we use to
    unnormalize the learned bounding-box regression weights before they are
    written to disk (and to restore the normalized weights afterwards).
    """

    def __init__(self, sess, saver, network, imdb, roidb, output_dir, pretrained_model=None):
        """Initialize the SolverWrapper.

        :param sess: TensorFlow session (not used by the constructor; kept for
            interface compatibility)
        :param saver: tf.train.Saver used to write checkpoints in snapshot()
        :param network: the Faster R-CNN network object being trained
        :param imdb: image database
        :param roidb: list of roidb (RoI database) entries
        :param output_dir: directory where snapshots are written
        :param pretrained_model: optional path to pretrained weights
        """
        self.net = network
        self.imdb = imdb
        self.roidb = roidb
        self.output_dir = output_dir
        self.pretrained_model = pretrained_model

        print('Computing bounding-box regression targets...')
        if cfg.TRAIN.BBOX_REG:
            # Means / stds of the regression targets (per class, according to
            # the original author's note); snapshot() uses them to unnormalize
            # the bbox_pred layer before saving.
            self.bbox_means, self.bbox_stds = rdl_roidb.add_bbox_regression_targets(roidb)
        print('done')

        # Checkpoint writer used by snapshot().
        self.saver = saver

    def snapshot(self, sess, iter):
        """Save a checkpoint with unnormalized bbox regression weights.

        When bbox regression is enabled and the network has a 'bbox_pred'
        layer, its weights are multiplied by the target stds and its biases
        are scaled by the stds and shifted by the means before saving, so
        the checkpoint can be used directly at test time. The original
        (normalized) values are restored afterwards, even if saving fails.

        :param sess: TensorFlow session holding the variables
        :param iter: zero-based iteration index; the file name uses iter + 1
        """
        net = self.net
        unnormalize = bool(cfg.TRAIN.BBOX_REG and 'bbox_pred' in net.layers)

        if unnormalize:
            # Keep the original values so they can be restored after saving.
            with tf.variable_scope('bbox_pred', reuse=True):
                weights = tf.get_variable("weights")
                biases = tf.get_variable("biases")
            orig_0, orig_1 = sess.run([weights, biases])

            # Scale and shift with bbox regression unnormalization.
            weights_shape = weights.get_shape().as_list()
            sess.run(net.bbox_weights_assign,
                     feed_dict={net.bbox_weights: orig_0 * np.tile(self.bbox_stds, (weights_shape[0], 1))})
            sess.run(net.bbox_bias_assign,
                     feed_dict={net.bbox_biases: orig_1 * self.bbox_stds + self.bbox_means})

        try:
            if not os.path.exists(self.output_dir):
                os.makedirs(self.output_dir)

            infix = ('_' + cfg.TRAIN.SNAPSHOT_INFIX
                     if cfg.TRAIN.SNAPSHOT_INFIX != '' else '')
            filename = (cfg.TRAIN.SNAPSHOT_PREFIX + infix +
                        '_iter_{:d}'.format(iter + 1) + '.ckpt')
            filename = os.path.join(self.output_dir, filename)

            self.saver.save(sess, filename)
            print('Wrote snapshot to: {:s}'.format(filename))
        finally:
            if unnormalize:
                # Restore the normalized values so training continues unchanged,
                # whether or not the save succeeded.
                sess.run(net.bbox_weights_assign, feed_dict={net.bbox_weights: orig_0})
                sess.run(net.bbox_bias_assign, feed_dict={net.bbox_biases: orig_1})

    def _modified_smooth_l1(self, sigma, bbox_pred, bbox_targets, bbox_inside_weights, bbox_outside_weights):
        """Element-wise smooth L1 loss with inside/outside weights.

        ResultLoss = outside_weights * SmoothL1(inside_weights * (bbox_pred - bbox_targets))
        SmoothL1(x) = 0.5 * (sigma * x)^2,   if |x| < 1 / sigma^2
                      |x| - 0.5 / sigma^2,   otherwise

        :return: tensor of per-element losses (not reduced)
        """
        sigma2 = sigma * sigma

        # x in the formula above: the inside-weighted prediction error.
        inside_mul = tf.multiply(bbox_inside_weights, tf.subtract(bbox_pred, bbox_targets))
        # 1.0 where |x| < 1 / sigma^2 (quadratic branch), 0.0 elsewhere.
        smooth_l1_sign = tf.cast(tf.less(tf.abs(inside_mul), 1.0 / sigma2), tf.float32)
        smooth_l1_option1 = tf.multiply(tf.multiply(inside_mul, inside_mul), 0.5 * sigma2)
        smooth_l1_option2 = tf.subtract(tf.abs(inside_mul), 0.5 / sigma2)
        # Pick the branch per element using the 0/1 mask.
        smooth_l1_result = tf.add(tf.multiply(smooth_l1_option1, smooth_l1_sign),
                                  tf.multiply(smooth_l1_option2, tf.subtract(1.0, smooth_l1_sign)))

        return tf.multiply(bbox_outside_weights, smooth_l1_result)

    def _build_losses(self):
        """Build the RPN and Fast R-CNN loss tensors.

        :return: (rpn_cross_entropy, rpn_loss_box, cross_entropy, loss_box)
        """
        # ---- RPN classification loss ('rpn-data' comes from the anchor target layer) ----
        rpn_data = self.net.get_output('rpn-data')
        rpn_cls_score = tf.reshape(self.net.get_output('rpn_cls_score_reshape'), [-1, 2])
        rpn_label = tf.reshape(rpn_data[0], [-1])
        # Anchors labelled -1 are ignored: keep scores and labels for the rest only.
        rpn_keep = tf.where(tf.not_equal(rpn_label, -1))
        rpn_cls_score = tf.reshape(tf.gather(rpn_cls_score, rpn_keep), [-1, 2])
        rpn_label = tf.reshape(tf.gather(rpn_label, rpn_keep), [-1])
        rpn_cross_entropy = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(logits=rpn_cls_score, labels=rpn_label))

        # ---- RPN bbox regression loss ----
        rpn_bbox_pred = self.net.get_output('rpn_bbox_pred')
        # Permute axes with [0, 2, 3, 1] so the channel axis is last, matching
        # rpn_bbox_pred (the original comment states the target layout is [N, H, W, C]).
        rpn_bbox_targets = tf.transpose(rpn_data[1], [0, 2, 3, 1])
        rpn_bbox_inside_weights = tf.transpose(rpn_data[2], [0, 2, 3, 1])
        rpn_bbox_outside_weights = tf.transpose(rpn_data[3], [0, 2, 3, 1])
        rpn_smooth_l1 = self._modified_smooth_l1(3.0, rpn_bbox_pred, rpn_bbox_targets,
                                                 rpn_bbox_inside_weights, rpn_bbox_outside_weights)
        # Sum over axes 1..3 per image, then average over axis 0.
        rpn_loss_box = tf.reduce_mean(tf.reduce_sum(rpn_smooth_l1, axis=[1, 2, 3]))

        # ---- Fast R-CNN classification loss ('roi-data' comes from the proposal target layer) ----
        cls_score = self.net.get_output('cls_score')
        roi_data = self.net.get_output('roi-data')
        label = tf.reshape(roi_data[1], [-1])
        cross_entropy = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(logits=cls_score, labels=label))

        # ---- Fast R-CNN bbox regression loss ----
        bbox_pred = self.net.get_output('bbox_pred')
        # roi_data[2..4]: bbox targets, inside weights, outside weights.
        smooth_l1 = self._modified_smooth_l1(1.0, bbox_pred, roi_data[2], roi_data[3], roi_data[4])
        loss_box = tf.reduce_mean(tf.reduce_sum(smooth_l1, axis=[1]))

        return rpn_cross_entropy, rpn_loss_box, cross_entropy, loss_box

    @staticmethod
    def _write_timeline(run_metadata):
        """Write a Chrome trace of one step to '<ms timestamp>-train-timeline.ctf.json'."""
        trace = timeline.Timeline(step_stats=run_metadata.step_stats)
        trace_path = str(int(time.time() * 1000)) + '-train-timeline.ctf.json'
        with open(trace_path, 'w') as trace_file:
            trace_file.write(trace.generate_chrome_trace_format(show_memory=False))

    def train_model(self, sess, max_iters):
        """Run the network training loop.

        Builds the losses, a momentum optimizer and a staircase learning-rate
        schedule, initializes all variables (then loads pretrained weights if
        configured), and runs ``max_iters`` SGD steps. Progress is printed
        every cfg.TRAIN.DISPLAY iterations. A snapshot is written every
        cfg.TRAIN.SNAPSHOT_ITERS iterations and once more at the end.

        :param sess: TensorFlow session used for all runs
        :param max_iters: number of training iterations; when it is <= 0 no
            step is run and no snapshot is written
        """
        data_layer = get_data_layer(self.roidb, self.imdb.num_classes)

        rpn_cross_entropy, rpn_loss_box, cross_entropy, loss_box = self._build_losses()
        # The total loss is the sum of the four losses.
        loss = cross_entropy + loss_box + rpn_cross_entropy + rpn_loss_box

        # Learning rate is multiplied by 0.1 every cfg.TRAIN.STEPSIZE steps.
        global_step = tf.Variable(0, trainable=False)
        lr = tf.train.exponential_decay(cfg.TRAIN.LEARNING_RATE, global_step,
                                        cfg.TRAIN.STEPSIZE, 0.1, staircase=True)
        momentum = cfg.TRAIN.MOMENTUM
        train_op = tf.train.MomentumOptimizer(lr, momentum).minimize(loss, global_step=global_step)

        # Initialize every variable, then overwrite with pretrained weights if given.
        sess.run(tf.global_variables_initializer())
        if self.pretrained_model is not None:
            print('Loading pretrained model weights from {:s}'.format(self.pretrained_model))
            self.net.load(self.pretrained_model, sess, self.saver, True)

        last_snapshot_iter = -1
        timer = Timer()
        for step in range(max_iters):
            blobs = data_layer.forward()

            feed_dict = {self.net.data: blobs['data'],
                         self.net.im_info: blobs['im_info'],
                         self.net.keep_prob: 0.5,
                         self.net.gt_boxes: blobs['gt_boxes']}

            # Full tracing only when cfg.TRAIN.DEBUG_TIMELINE is enabled.
            run_options = None
            run_metadata = None
            if cfg.TRAIN.DEBUG_TIMELINE:
                run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                run_metadata = tf.RunMetadata()

            timer.tic()
            rpn_loss_cls_value, rpn_loss_box_value, loss_cls_value, loss_box_value, _ = sess.run(
                [rpn_cross_entropy, rpn_loss_box, cross_entropy, loss_box, train_op],
                feed_dict=feed_dict,
                options=run_options,
                run_metadata=run_metadata)
            timer.toc()

            if cfg.TRAIN.DEBUG_TIMELINE:
                self._write_timeline(run_metadata)

            if (step + 1) % cfg.TRAIN.DISPLAY == 0:
                print('iter: %d / %d, total loss: %.4f, rpn_loss_cls: %.4f, rpn_loss_box: %.4f, '
                      'loss_cls: %.4f, loss_box: %.4f, lr: %f' %
                      (step + 1, max_iters,
                       rpn_loss_cls_value + rpn_loss_box_value + loss_cls_value + loss_box_value,
                       rpn_loss_cls_value, rpn_loss_box_value, loss_cls_value, loss_box_value,
                       sess.run(lr)))
                print('speed: {:.3f}s / iter'.format(timer.average_time))

            if (step + 1) % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                last_snapshot_iter = step
                self.snapshot(sess, step)

        # Final snapshot, unless the last iteration was already snapshotted.
        # Guarded so that max_iters <= 0 no longer raises UnboundLocalError.
        final_iter = max_iters - 1
        if final_iter >= 0 and last_snapshot_iter != final_iter:
            self.snapshot(sess, final_iter)


def get_training_roidb(imdb):
    """Prepare and return the roidb (Region of Interest database) for training.

    If cfg.TRAIN.USE_FLIPPED is set, horizontally flipped copies of the
    images are appended to imdb first. The roidb is then prepared with the
    gt_data_layer helper when both cfg.TRAIN.HAS_RPN and cfg.IS_MULTISCALE
    are set, and with the roi_data_layer helper otherwise.

    :param imdb: image database passed to prepare_roidb
    :return: imdb.roidb after preparation
    """
    if cfg.TRAIN.USE_FLIPPED:
        print('Appending horizontally-flipped training examples...')
        imdb.append_flipped_images()
        print('done')

    print('Preparing training data...')
    use_gt_layer = cfg.TRAIN.HAS_RPN and cfg.IS_MULTISCALE
    roidb_module = gdl_roidb if use_gt_layer else rdl_roidb
    roidb_module.prepare_roidb(imdb)
    print('done')

    return imdb.roidb


def get_data_layer(roidb, num_classes):
    """Return the data layer that feeds training minibatches.

    A GtDataLayer is used only when both cfg.TRAIN.HAS_RPN and
    cfg.IS_MULTISCALE are set; otherwise a RoIDataLayer is used.

    :param roidb: list of roidb entries
    :param num_classes: number of object classes (passed to RoIDataLayer)
    :return: the data layer instance
    """
    if cfg.TRAIN.HAS_RPN and cfg.IS_MULTISCALE:
        # GtDataLayer is not among this module's top-level imports, so it is
        # imported here; without it this branch raised NameError. Importing
        # lazily keeps the common single-scale path independent of it.
        # NOTE(review): confirm GtDataLayer's constructor accepts a roidb.
        from gt_data_layer.layer import GtDataLayer
        return GtDataLayer(roidb)
    return RoIDataLayer(roidb, num_classes)


def filter_roidb(roidb):
    """Drop roidb entries that contain no usable RoIs.

    An entry is kept when its 'max_overlaps' values contain at least one
    foreground RoI (overlap >= cfg.TRAIN.FG_THRESH) or at least one
    background RoI (overlap in [cfg.TRAIN.BG_THRESH_LO, cfg.TRAIN.BG_THRESH_HI)).

    :param roidb: list of roidb entry dicts, each with a 'max_overlaps' key
    :return: a new list with only the usable entries (order preserved)
    """

    def has_usable_rois(entry):
        max_ov = entry['max_overlaps']
        fg_hits = np.where(max_ov >= cfg.TRAIN.FG_THRESH)[0]
        if len(fg_hits) > 0:
            return True
        bg_mask = (max_ov < cfg.TRAIN.BG_THRESH_HI) & (max_ov >= cfg.TRAIN.BG_THRESH_LO)
        return len(np.where(bg_mask)[0]) > 0

    total = len(roidb)
    kept = []
    for entry in roidb:
        if has_usable_rois(entry):
            kept.append(entry)
    remaining = len(kept)
    print('Filtered {} roidb entries: {} -> {}'.format(total - remaining, total, remaining))
    return kept


def train_net(network, imdb, roidb, output_dir, pretrained_model=None, max_iters=40000):
    """Train a Faster R-CNN network.

    :param network: the Faster R-CNN network to train
    :param imdb: image database
    :param roidb: RoI database; entries without usable RoIs are filtered out
    :param output_dir: directory where weight snapshots are saved
    :param pretrained_model: optional path to pretrained weights
    :param max_iters: maximum number of training iterations
    :return: None
    """
    usable_roidb = filter_roidb(roidb)
    checkpoint_saver = tf.train.Saver(max_to_keep=100)
    session_config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=session_config) as session:
        solver = SolverWrapper(session, checkpoint_saver, network, imdb, usable_roidb,
                               output_dir, pretrained_model=pretrained_model)
        print('Solving...')
        solver.train_model(session, max_iters)
        print('done solving')

Comments

Your browser is out-of-date!

Update your browser to view this website correctly. Update my browser now

×