Faster R-CNN源码阅读之三:Faster R-CNN/lib/networks/VGGnet_test.py

Faster R-CNN源码阅读之三:Faster R-CNN/lib/networks/VGGnet_test.py

一、介绍

   本demo由Faster R-CNN官方提供,我只是在官方的代码上增加了注释,一方面方便我自己学习,另一方面贴出来和大家一起交流。
   该文件定义了 VGGnet_test 类(继承自 Network),用于搭建 Faster R-CNN 测试阶段所使用的、基于 VGG16 的网络结构。

二、代码以及注释
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# -*- coding:utf-8 -*-
import tensorflow as tf
from networks.network import Network

# Number of classes, background included.
n_classes = 21
# Feature stride.
_feat_stride = [16]
# Anchor scales.
anchor_scales = [8, 16, 32]


class VGGnet_test(Network):
    """VGG network used at test time.

    The class inherits from ``Network``, so the whole graph is declared in
    :meth:`setup` through chained layer calls (``feed(...).conv(...)...``).
    """

    def __init__(self, trainable=True):
        """Create the input placeholders, register them and build the graph.

        Args:
            trainable: Stored as ``self.trainable``. How the flag is consumed
                is decided by ``Network``, which is not visible here.
        """
        # Outputs of the previous layer.
        self.inputs = []

        # Image input data: a 4-D float tensor whose last dimension is 3.
        self.data = tf.placeholder(tf.float32, shape=[None, None, None, 3])

        # Image size information: a 2-D float tensor with 3 values per row.
        self.im_info = tf.placeholder(tf.float32, shape=[None, 3])

        # Keep probability for dropout (no layer built in setup() reads it).
        self.keep_prob = tf.placeholder(tf.float32)

        # Every layer of the network, keyed by layer name; the two input
        # placeholders are registered up front so feed() can refer to them.
        self.layers = {'data': self.data, 'im_info': self.im_info}

        # Whether the network is trainable.
        self.trainable = trainable

        # Build the network structure.
        self.setup()

    def setup(self):
        """Declare the layers: VGG16 base, RPN, RoI proposal and RCNN head."""
        # Output count shared by both RPN heads.
        # NOTE(review): the factor 3 is presumably the number of anchor
        # aspect ratios per scale -- confirm against proposal_layer.
        num_anchors = len(anchor_scales) * 3

        # Basic VGG16 structure. The conv1_* and conv2_* layers are created
        # with trainable=False; the remaining layers do not pass the flag.
        # NOTE(review): conv positional arguments are presumably
        # (kernel_h, kernel_w, out_channels, stride_h, stride_w) -- confirm
        # against Network.conv.
        (
            self.feed('data')
            .conv(3, 3, 64, 1, 1, name='conv1_1', trainable=False)
            .conv(3, 3, 64, 1, 1, name='conv1_2', trainable=False)
            .max_pool(2, 2, 2, 2, padding='VALID', name='pool1')
            .conv(3, 3, 128, 1, 1, name='conv2_1', trainable=False)
            .conv(3, 3, 128, 1, 1, name='conv2_2', trainable=False)
            .max_pool(2, 2, 2, 2, padding='VALID', name='pool2')
            .conv(3, 3, 256, 1, 1, name='conv3_1')
            .conv(3, 3, 256, 1, 1, name='conv3_2')
            .conv(3, 3, 256, 1, 1, name='conv3_3')
            .max_pool(2, 2, 2, 2, padding='VALID', name='pool3')
            .conv(3, 3, 512, 1, 1, name='conv4_1')
            .conv(3, 3, 512, 1, 1, name='conv4_2')
            .conv(3, 3, 512, 1, 1, name='conv4_3')
            .max_pool(2, 2, 2, 2, padding='VALID', name='pool4')
            .conv(3, 3, 512, 1, 1, name='conv5_1')
            .conv(3, 3, 512, 1, 1, name='conv5_2')
            .conv(3, 3, 512, 1, 1, name='conv5_3')
        )

        # ========= RPN ============
        # Shared 3x3 conv on conv5_3, followed by the 1x1 score head.
        (
            self.feed('conv5_3')
            .conv(3, 3, 512, 1, 1, name='rpn_conv/3x3')
            .conv(1, 1, num_anchors * 2, 1, 1,
                  padding='VALID', relu=False, name='rpn_cls_score')
        )

        # 1x1 box-regression head on the same shared RPN conv output.
        (
            self.feed('rpn_conv/3x3')
            .conv(1, 1, num_anchors * 4, 1, 1,
                  padding='VALID', relu=False, name='rpn_bbox_pred')
        )

        # ========= RoI Proposal ============
        # Reshape the scores with argument 2, apply softmax, then reshape
        # back using the original rpn_cls_score output count.
        (
            self.feed('rpn_cls_score')
            .reshape_layer(2, name='rpn_cls_score_reshape')
            .softmax(name='rpn_cls_prob')
        )

        (
            self.feed('rpn_cls_prob')
            .reshape_layer(num_anchors * 2, name='rpn_cls_prob_reshape')
        )

        # Proposal layer in 'TEST' mode; its output is registered as 'rois'.
        (
            self.feed('rpn_cls_prob_reshape', 'rpn_bbox_pred', 'im_info')
            .proposal_layer(_feat_stride, anchor_scales, 'TEST', name='rois')
        )

        # ========= RCNN ============
        # RoI pooling over conv5_3 with the proposals, two 4096-unit fc
        # layers, then the per-class score layer and its softmax.
        (
            self.feed('conv5_3', 'rois')
            .roi_pool(7, 7, 1.0 / 16, name='pool_5')
            .fc(4096, name='fc6')
            .fc(4096, name='fc7')
            .fc(n_classes, relu=False, name='cls_score')
            .softmax(name='cls_prob')
        )

        # Box-regression head fed from fc7: 4 outputs per class.
        (
            self.feed('fc7')
            .fc(n_classes * 4, relu=False, name='bbox_pred')
        )

Comments

Your browser is out-of-date!

Update your browser to view this website correctly. Update my browser now

×