设为首页 加入收藏

TOP

ShuffleNet总结(二)
2017-12-23 06:06:50 】 浏览:385
Tags:ShuffleNet 总结
含的信息可能是相同的,如果在不同的组之后交换一些通道,那么就能交换信息,使得各个组的信息更丰富,能提取到的特征自然就更多,这样是有利于得到更好的结果。

shufleChannels

图5 分组交换通道的示意图,a)是不交换通道但是分成3组了,要吧看到,不同的组是完全独立的;b)每组内又分成3组,不分别交换到其它组中,这样信息就发生了交换,c)这个是与b)是等价的。

ShuffleUnit

ShuffleUnit的设计参考了ResNet,总有两个基本单元,两人个基本单元功能不一样,将他们组合起来就可以得到ShuffleNet。这样的设计可以在增加网络的深度(比mobilenet深约一倍)的同时,减少参数总量与计算量(本人运行Cifar10时,速度大约是molibenet的10倍)。

shufleunit

图6 b)与c)是两人个ShuffleNet的基本单元,这两个单元是参考了a)的设计,单元b)输出与输入的Shape一致,只是丰富了每个通道的信息,单元c)增加了一倍的通道数且输出的$H_{out}$、$W_{out}$ 比$H_{in}$、$W_{in}$减少了一倍

源码解读

def combine(residual, data, combine):
    if combine == 'add':
        return residual + data
    elif combine == 'concat':
        return mx.sym.concat(residual, data, dim=1)
    return None

add是代表图6中的单元b),concat是代表图6中的单元c)。

def channel_shuffle(data, groups):
    data = mx.sym.reshape(data, shape=(0, -4, groups, -1, -2))
    data = mx.sym.swapaxes(data, 1, 2)
    data = mx.sym.reshape(data, shape=(0, -3, -2))
    return data

这个函数就是交换通道的函数,函数的第一行data = mx.sym.reshape(data, shape=(0, -4, groups, -1, -2))是将输入为\(n \times C_{in} \times H_{in} \times W_{in}\)reshape成\(n \times (C_{in}/g) \times g\times H_{in} \times W_{in}\),要注意的是mxnet中reshape不会改变张量在内存中的排列顺序。至于要mxnet中的0,-1,-2,-3,-4的具体意义可以这样看到:

import mxnet as mx
print(help(mx.sym.reshape))

可以看到输出以下(只提取出一小部分,其余的可用上述方法查看),这里有各个参数的具体意义:

- ``0``  copy this dimension from the input to the output shape.
- ``-1`` infers the dimension of the output shape by using the remainder of the input dimensions
- ``-2`` copy all/remainder of the input dimensions to the output shape.
- ``-3`` use the product of two consecutive dimensions of the input shape as the output dimension.
- ``-4`` split one dimension of the input into two dimensions passed subsequent to -4 in shape (can contain -1).

函数的第二行是交换第一与第二个维度,那么现在这个symbol的符号的shape就变成了\(n \times g \times (C_{in}/g) \times H_{in} \times W_{in}\)。这里的第零个维度是\(n\)要注意的是交换维度改变了张量在内存中的排列顺序,改变了内存中的顺序实现上就是完成了图5c)中的Channel Shuffle操作,不同的颜色代码数据在原来内存中的位置。
函数的最后一行合并了第一与第二个维度,输出的张量与输入的张量shape都是\(n \times C_{in} \times H_{in} \times W_{in}\)

def shuffleUnit(residual, in_channels, out_channels, combine_type, groups=3, grouped_conv=True):

    if combine_type == 'add':
        DWConv_stride = 1
    elif combine_type == 'concat':
        DWConv_stride = 2
        out_channels -= in_channels

    first_groups = groups if grouped_conv else 1

    bottleneck_channels = out_channels // 4

    data = mx.sym.Convolution(data=residual, num_filter=bottleneck_channels, 
                      kernel=(1, 1), stride=(1, 1), num_group=first_groups)
    data = mx.sym.BatchNorm(data=data)
    data = mx.sym.Activation(data=data, act_type='relu')

    data = channel_shuffle(data, groups)

    data = mx.sym.Convolution(data=data, num_filter=bottleneck_channels, kernel=(3, 3), 
                       pad=(1, 1), stride=(DWConv_stride, DWConv_stride), num_group=groups)
    data = mx.sym.BatchNorm(data=data)

    data = mx.sym.Convolution(data=data, num_filter=out_channels, 
                       kernel=(1, 1), stride=(1, 1), num_group=groups)
    data = mx.sym.BatchNorm(data=data)

    if combine_type == 'concat':
        residual = mx.sym.Pooling(data=residual, kernel=(3, 3), pool_type='avg', 
                              stride=(2, 2), pad=(1, 1))

    data = combine(residual, data, combine_type)

    return data

ShuffleUnit这个函数实现上是实现图6的b)与c),add对应成b),comcat对应于c)。

def make_stage(data, stage, groups=3):
    stage_repeats = [3, 7, 3]

    grouped_conv = stage > 2

    if groups == 1:
        out_channels = [-1, 24, 144, 288,
首页 上一页 1 2 3 下一页 尾页 2/3/3
】【打印繁体】【投稿】【收藏】 【推荐】【举报】【评论】 【关闭】 【返回顶部
上一篇每天一点爬虫(一) 下一篇机器学习实践之决策树算法学习

最新文章

热门文章

Hot 文章

Python

C 语言

C++基础

大数据基础

linux编程基础

C/C++面试题目