一、介绍
本demo由Faster R-CNN官方提供,我只是在官方的代码上增加了注释,一方面方便我自己学习,另一方面贴出来和大家一起交流。
这里对之前使用Faster R-CNN的demo进行预测时候的代码进行补完。
二、代码和注释
文件目录:Faster-RCNN/lib/fast_rcnn/test.py
1 | def im_detect(sess, net, im, boxes=None): |
文件目录:Faster-RCNN/lib/fast_rcnn/test.py 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28def _get_blobs(im, rois):
"""
Convert an image and RoIs within that image into network inputs.
将im和图片内部的rois转换成网络的输入。
:param im: 图片的像素矩阵
:param rois: rois
:returns:
"""
# 由于这里是Faster RCNN,cfg.TEST.HAS_RPN默认为True,不使用RPN的代码此处忽略注释。
if cfg.TEST.HAS_RPN:
# 定义一个字典
blobs = {'data': None, 'rois': None}
# 存储blob数据块和缩放系数。
blobs['data'], im_scale_factors = _get_image_blob(im)
else:
blobs = {'data': None, 'rois': None}
blobs['data'], im_scale_factors = _get_image_blob(im)
# 多尺度图像(图像金字塔)
if cfg.IS_MULTISCALE:
if cfg.IS_EXTRAPOLATING:
blobs['rois'] = _get_rois_blob(rois, cfg.TEST.SCALES)
else:
blobs['rois'] = _get_rois_blob(rois, cfg.TEST.SCALES_BASE)
else:
blobs['rois'] = _get_rois_blob(rois, cfg.TEST.SCALES_BASE)
# 返回blob数据和缩放系数
return blobs, im_scale_factors
文件目录:Faster-RCNN/lib/fast_rcnn/test.py 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51def _get_image_blob(im):
"""Converts an image into a network input.
Arguments:
im (ndarray): a color image in BGR order
Returns:
blob (ndarray): a data blob holding an image pyramid
im_scale_factors (list): list of image scales (relative to im) used
in the image pyramid
将图片的像素矩阵转换成网络输入。
:param im: 图片的像素矩阵
:returns:
"""
# 将原始的像素矩阵复制一份
im_orig = im.astype(np.float32, copy=True)
# 图片像素的归一化,这里采用了减去各个通道设定的像素均值的方法。
im_orig -= cfg.PIXEL_MEANS
# 图片的shape
im_shape = im_orig.shape
# 获取图片尺寸(高度,宽度)的短边和长边
im_size_min = np.min(im_shape[0:2])
im_size_max = np.max(im_shape[0:2])
# 保存所放过的图片的list
processed_ims = []
# 保存缩放的系数
im_scale_factors = []
# 对每个缩放的尺寸, cfg.TEST.SCALES一般取值[600]。
for target_size in cfg.TEST.SCALES:
# 计算对短边的缩放系数
im_scale = float(target_size) / float(im_size_min)
# Prevent the biggest axis from being more than MAX_SIZE
# 为了防止缩放之后长边过长(超过设定的cfg.TEST.MAX_SIZE, 该值一般取值1000),
# 如果过长,则im scale主要表示针对长边的缩放,将长边缩放到可接受的最大长度。
if np.round(im_scale * im_size_max) > cfg.TEST.MAX_SIZE:
im_scale = float(cfg.TEST.MAX_SIZE) / float(im_size_max)
# 对图片进行缩放
im = cv2.resize(im_orig, None, None, fx=im_scale, fy=im_scale,
interpolation=cv2.INTER_LINEAR)
# 将图片的缩放系数和缩放之后的图片分别保存到list中。
im_scale_factors.append(im_scale)
processed_ims.append(im)
# Create a blob to hold the input images
# 对处理之后的图片进行处理,产生blob
blob = im_list_to_blob(processed_ims)
# 返回产生的blob和缩放系数的list
return blob, np.array(im_scale_factors)
文件目录:Faster-RCNN/lib/utils/blob.py 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22def im_list_to_blob(ims):
"""Convert a list of images into a network input.
Assumes images are already prepared (means subtracted, BGR order, ...).
将包含若干图片像素信息的list转换成blob数据块。这里的处理仅仅只是将所有的图片进行左上角的对齐。
:param ims: 一个list,里面包含若干个图片的像素信息。
:return: 处理之后的blob数据块。
"""
# 返回各个维度的最大长度,这里真真有用的是最大的高度和宽度。
max_shape = np.array([im.shape for im in ims]).max(axis=0)
# 获取图片的总数目
num_images = len(ims)
# 根据图片总数目,最大高度宽度等信息,生成一个全0numpy数组,用以将图片的左上角对齐。
blob = np.zeros((num_images, max_shape[0], max_shape[1], 3), dtype=np.float32)
# 对每个图片
for i in xrange(num_images):
im = ims[i]
# 进行赋值操作,这样的复制过程正好从blob数组的左上角开始。
blob[i, 0:im.shape[0], 0:im.shape[1], :] = im
# 返回
return blob
文件目录:Faster-RCNN/lib/utils/timer.py 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51# coding=utf-8
# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------
import time
# 一个简单的计时器
class Timer(object):
"""A simple timer."""
def __init__(self):
# 总时间
self.total_time = 0.
# 被调用的次数
self.calls = 0
# 开始时间
self.start_time = 0.
# 结束时间和开始时间之间的时间差
self.diff = 0.
# 平均时间
self.average_time = 0.
def tic(self):
'''
记录开始时间
:return: None
'''
# using time.time instead of time.clock because time time.clock
# does not normalize for multithreading
self.start_time = time.time()
def toc(self, average=True):
'''
结束计时
:param average: True或者False,为True时返回平均时间,否则返回时间差
:return: 平均时间或者时间差
'''
# 记录结束时距离开始的时间差值
self.diff = time.time() - self.start_time
# 将时间差值加到总时间上,并把调用次数加1
self.total_time += self.diff
self.calls += 1
# 重新计算平均时间
self.average_time = self.total_time / self.calls
# 根据average返回
if average:
return self.average_time
else:
return self.diff