Batch Normalization
Definition
Transform a data set $x$ with mean $\mu$ and standard deviation $\sigma$ into a data set $x'$ with mean $\beta$ and standard deviation $\gamma$:
$$x'=\gamma\frac{x-\mu}{\sigma}+\beta$$
Purpose
Alleviates the vanishing gradient problem.
Implementation steps
- Compute the mean $\mu$ and standard deviation $\sigma$ of the data set $x$
- Apply the transformation formula above (a minimal sketch follows this list)
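A minimal NumPy sketch of these two steps; the names `gamma`, `beta`, and the small `eps` constant are illustrative assumptions, not part of the text above:

```python
import numpy as np

def batch_norm(x, gamma=1.0, beta=0.0, eps=1e-5):
    """Normalize x to mean 0 / std 1, then rescale to mean beta / std gamma."""
    mu = x.mean(axis=0)                # step 1: per-feature mean of the batch
    sigma = x.std(axis=0)              # step 1: per-feature standard deviation
    x_hat = (x - mu) / (sigma + eps)   # normalize (eps guards against division by 0)
    return gamma * x_hat + beta        # step 2: scale and shift

x = np.random.randn(64, 10) * 3.0 + 5.0     # a batch of 64 samples, 10 features
y = batch_norm(x)
print(y.mean(axis=0).round(3), y.std(axis=0).round(3))  # close to 0 and 1
```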
Questions:
- Why use a moving average to compute the mean and variance?
The short answer: it is not that the mean and variance are computed differently at test time; rather, once training finishes, the mean and variance used at test time are fixed for the whole network. Because we usually optimize with mini-batch gradient descent, during training we can only compute the mean and variance over the current batch. But the normalization we want is with respect to the entire input sample space, so we keep an exponentially weighted moving average of each batch's mean and variance to approximate the statistics of the whole sample space. At inference time it is generally unnecessary, and often not even meaningful, to compute batch statistics; for example, when testing a single input sample, the mean and variance of that one sample carry no information. We therefore directly use the mean and variance estimated over the whole sample space during training. For inference, BN is thus just a fixed linear transformation.
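A sketch of the exponentially weighted update described above; the decay value 0.9 is an illustrative assumption (TensorFlow exposes it as a `decay`/`momentum` hyperparameter):

```python
import numpy as np

decay = 0.9                            # assumed decay factor
moving_mean, moving_var = 0.0, 1.0     # conventional starting values (see next question)

for _ in range(1000):                  # training loop over batches
    batch = np.random.randn(32) * 2.0 + 3.0      # true mean 3, variance 4
    # fold the current batch statistics into the running estimates
    moving_mean = decay * moving_mean + (1 - decay) * batch.mean()
    moving_var = decay * moving_var + (1 - decay) * batch.var()

print(moving_mean, moving_var)         # approaches 3 and 4
```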
- Why are the initial values 0 and 1?
Put plainly, the question is: how do we estimate the true distribution of the data? Initializing the moving mean to 0 and the moving variance to 1 amounts to assuming a standard normal distribution before any batch has been seen; the moving averages then drift toward the true statistics as training proceeds.
https://www.cnblogs.com/34fj/p/8805979.html
TensorFlow implementation
Note: when training, the moving_mean and moving_variance need to be updated. By default the update ops are placed in tf.GraphKeys.UPDATE_OPS, so they need to be added as a dependency to the train_op. Also, be sure to add any batch_normalization ops before getting the update_ops collection. Otherwise, update_ops will be empty, and training/inference will not work properly. For example:
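The pattern the note describes looks like this in TF 1.x graph mode; this is a sketch in which `x`, `training`, `optimizer`, and `loss` are assumed to be defined elsewhere:

```python
import tensorflow as tf

x_norm = tf.layers.batch_normalization(x, training=training)

# ...

# The moving_mean/moving_variance update ops live in tf.GraphKeys.UPDATE_OPS;
# adding them as a dependency makes them run on every training step.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)
```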
Preliminary
Computing the mean and variance
```python
tf.nn.moments(
    x,
    axes,
    shift=None,
    name=None,
    keep_dims=False
)
```
Args:
- x: A Tensor
- axes: Array of ints; the axes along which to compute mean and variance
- name: Name used to scope the operations that compute the moments
- keep_dims: If true, the mean and variance keep the same shape (rank) as the input
Returns:
- Two Tensors: mean and variance
When using these moments for batch normalization (see tf.nn.batch_normalization):
for so-called "global normalization", used with convolutional filters with shape [batch, height, width, depth], pass axes=[0, 1, 2]; that is, the mean and variance are computed over all values belonging to the same channel.
for simple batch normalization pass axes=[0] (batch only).
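A minimal usage sketch of the two cases above (the shapes are illustrative):

```python
import tensorflow as tf

# Global normalization over a conv feature map: one mean/variance per channel.
feat = tf.random_normal([8, 32, 32, 16])           # [batch, height, width, depth]
mean, var = tf.nn.moments(feat, axes=[0, 1, 2])    # both have shape [16]

# Simple batch normalization over fully-connected activations.
acts = tf.random_normal([8, 100])                  # [batch, depth]
mean2, var2 = tf.nn.moments(acts, axes=[0])        # both have shape [100]
```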
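The statistics (or moving averages of them) then feed tf.nn.batch_normalization, which applies the transformation from the definition above:

```python
tf.nn.batch_normalization(
    x,
    mean,
    variance,
    offset,
    scale,
    variance_epsilon,
    name=None
)
```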
Args:
- x: A Tensor
- mean: A mean Tensor
- variance: A variance Tensor
- offset: $\beta$; default None
- scale: $\gamma$; default None
- variance_epsilon: A small float number to avoid dividing by 0
- name: A name for the operation
Properties:
If scale and offset are not provided, the output tensor has mean 0 and variance 1.
mean, variance, offset and scale are all expected to be of one of two shapes:
- In all generality, they can have the same number of dimensions as the input x, with identical sizes as x for the dimensions that are not normalized over (the ‘depth’ dimension(s)), and dimension 1 for the others which are being normalized over. mean and variance in this case would typically be the outputs of tf.nn.moments(…, keep_dims=True) during training, or running averages thereof during inference.
- In the common case where the ‘depth’ dimension is the last dimension in the input tensor x, they may be one dimensional tensors of the same size as the ‘depth’ dimension. This is the case for example for the common [batch, depth] layout of fully-connected layers, and [batch, height, width, depth] for convolutions. mean and variance in this case would typically be the outputs of tf.nn.moments(…, keep_dims=False) during training, or running averages thereof during inference.
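Putting the two ops together (a sketch; the depth of 16 and epsilon of 1e-5 are illustrative):

```python
import tensorflow as tf

x = tf.random_normal([8, 32, 32, 16])           # [batch, height, width, depth]
mean, var = tf.nn.moments(x, axes=[0, 1, 2])    # per-channel statistics, shape [16]

beta = tf.Variable(tf.zeros([16]))              # learned offset
gamma = tf.Variable(tf.ones([16]))              # learned scale

y = tf.nn.batch_normalization(x, mean, var, beta, gamma, variance_epsilon=1e-5)
```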
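Updating the moving statistics by hand is done with tf.assign:

```python
tf.assign(
    ref,
    value,
    validate_shape=None,
    use_locking=None,
    name=None
)
```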
Args:
- ref: A mutable Tensor. Should be from a Variable node. May be uninitialized.
- value: A Tensor. Must have the same type as ref. The value to be assigned to the variable.
- validate_shape: An optional bool. Defaults to True. If true, the operation will validate that the shape of ‘value’ matches the shape of the Tensor being assigned to. If false, ‘ref’ will take on the shape of ‘value’.
- use_locking: An optional bool. Defaults to True. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention.
- name: A name for the operation (optional).
Returns:
- A Tensor that will hold the new value of 'ref' after the assignment has completed.
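For example, the moving-average update from earlier written with tf.assign (the decay of 0.9 and depth of 16 are illustrative):

```python
import tensorflow as tf

decay = 0.9
moving_mean = tf.Variable(tf.zeros([16]), trainable=False)
batch_mean = tf.reduce_mean(tf.random_normal([8, 16]), axis=0)

# new value = decay * old estimate + (1 - decay) * current batch statistic
update_mean = tf.assign(moving_mean,
                        moving_mean * decay + batch_mean * (1 - decay))
```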
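Switching between the training branch (batch statistics) and the inference branch (moving averages) is done with tf.cond:

```python
tf.cond(
    pred,
    true_fn=None,
    false_fn=None,
    strict=False,
    name=None
)
```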
Args:
- pred: A scalar determining whether to return the result of true_fn or false_fn.
- true_fn: The callable to be performed if pred is true.
- false_fn: The callable to be performed if pred is false.
- strict: A boolean that enables/disables 'strict' mode; see the tf.cond documentation.
- name: Optional name prefix for the returned tensors.
Returns:
- Tensors returned by the call to either true_fn or false_fn. If the callables return a singleton list, the element is extracted from the list.
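A minimal usage sketch, mirroring the example in the TF docs:

```python
import tensorflow as tf

a = tf.constant(2)
b = tf.constant(5)

# pred is a scalar boolean tensor; the branches are zero-argument callables.
z = tf.cond(tf.less(a, b), lambda: tf.add(a, b), lambda: tf.multiply(a, b))
```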
BN in a neural network during training and testing
Example
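A sketch that ties the pieces above together: batch statistics plus moving-average updates during training, fixed statistics at inference. The decay of 0.9, epsilon of 1e-5, and the placeholder shapes are illustrative assumptions:

```python
import tensorflow as tf

def batch_norm(x, is_training, decay=0.9, eps=1e-5):
    """BN that uses batch statistics while training and moving averages at test time."""
    depth = x.get_shape().as_list()[-1]
    gamma = tf.Variable(tf.ones([depth]))                         # learned scale
    beta = tf.Variable(tf.zeros([depth]))                         # learned offset
    moving_mean = tf.Variable(tf.zeros([depth]), trainable=False)
    moving_var = tf.Variable(tf.ones([depth]), trainable=False)

    def train_branch():
        # Per-channel statistics of the current batch.
        axes = list(range(len(x.get_shape()) - 1))
        mean, var = tf.nn.moments(x, axes=axes)
        # Fold them into the moving averages before returning.
        update_mean = tf.assign(moving_mean,
                                moving_mean * decay + mean * (1 - decay))
        update_var = tf.assign(moving_var,
                               moving_var * decay + var * (1 - decay))
        with tf.control_dependencies([update_mean, update_var]):
            return tf.identity(mean), tf.identity(var)

    def test_branch():
        # Fixed statistics estimated during training: BN is now a linear transform.
        return moving_mean, moving_var

    mean, var = tf.cond(is_training, train_branch, test_branch)
    return tf.nn.batch_normalization(x, mean, var, beta, gamma, eps)

# Usage: feed is_training so that one graph serves both phases.
x = tf.placeholder(tf.float32, [None, 32, 32, 16])
is_training = tf.placeholder(tf.bool)
y = batch_norm(x, is_training)
```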
Reference
https://www.cnblogs.com/zhengmingli/p/8031690.html
https://www.tensorflow.org/api_docs/python/tf/nn/batch_normalization?hl=zh-cn
https://www.jianshu.com/p/615113382fac
http://lamda.nju.edu.cn/weixs/project/CNNTricks/CNNTricks.html
https://www.zhihu.com/question/38102762
http://www.cnblogs.com/cloud-ken/p/9314769.html