Batch Normalization

Batch Normalization

传统的 Mini-batch 随机梯度下降法训练神经网络时,调参工作变得非常复杂,网络学习的效率不高,主要是因为 Internal Convariate Shift 问题,而 Batch Normalization 就是用来解决这个问题的一个方案。

Internal Convariate Shift

在深层网络的训练过程中,网络参数不断变化,导致每一层的输入数据分布很可能会发生变化,这个过程被称为 Internal Convariate Shift.

ICS 会带来很多问题:

  • 底层网络参数的变化会导致高层次网络的数据输入分布改变,高层次的网络就需要不停地调整以适应这种变化,有可能导致学习速率过慢;
  • 如果激活函数使用的是饱和类激活函数(例如:sigmoid, tanh 等函数),随着权重 $W$ 的不断增大,输出 $Z=Wx+b$ 会接近梯度饱和区,梯度消失问题严重,参数更新速度变慢,网络的收敛速度也受到严重影响。

之前的一些解决方案

针对梯度消失的问题,使用 ReLU 代替饱和类激活函数。

使用 Whitening 方法(e.g. PCA Whitening, ZCA Whitening),这两种方法可以使输入特征具有相同的均值和方差,并且去除特征之间的相关性。但如果对每一层的数据都进行 Whitening,开销会非常大(PCA 分解时间复杂度很高),并且 Whitening 一定程度上改变了网络每一层的分布,数据的表达能力减弱。

Batch Normalization

通过一定的规范化手段,使得每个隐藏层的输出(也就是下一个隐藏层的输入)的均值和方差变为 0 和 1,使之分布的变化不是那么剧烈。这样做可以一定程度上远离梯度饱和区,解决梯度消失的问题。另外减小底层权重变化对上层分布的影响,相使得输入对稳定,高层的变化可以不那么剧烈。

但 BN 之后的输出大多位于激活函数的线性区,这意味着网络的非线性表达能力下降,所以可以加入 $\gamma, \beta$ 两个可学习参数,将 BN 变为:

即特征每个维度对 batch 求平均,之后再做一个线性变换,再经过一个激活函数。

预测阶段,测试的样本可能很少,这时使用和训练阶段一样的正则化可能导致测试集的 $\mu, \sigma^2$ 是整个分布的有偏估计。解决方案是保留训练集每一组 batch 在网络每一层的 $\mu_{batch}, \sigma_{batch}$,使用整个训练集的统计量来对测试集进行归一化:

之后:

就可以正常进行 BN 操作了。

这适用于训练集和测试集同分布的情况,如果是迁移学习等不同分布的情况,应该需要另行处理。

参考文献

[1] Batch Normalization 原理与实战

[2] 【深度学习】深入理解Batch Normalization批标准化

[3] Why does Batch Normalization work?

打赏
  • 版权声明: 本博客所有文章除特别声明外,均采用 Apache License 2.0 许可协议。转载请注明出处!
  • © 2020 Bowen
  • Powered by Hexo Theme Ayer

请我喝杯咖啡吧~

支付宝
微信