728x90
반응형

 

Weight Standardization Paper : arxiv.org/abs/1903.10520

 

Micro-Batch Training with Batch-Channel Normalization and Weight Standardization

Batch Normalization (BN) has become an out-of-box technique to improve deep network training. However, its effectiveness is limited for micro-batch training, i.e., each GPU typically has only 1-2 images for training, which is inevitable for many computer v

arxiv.org

 

 

 

Weight Standardization은 Normalization 기법 중 하나이며, Batch Normalization이 minibatch 단위로 normalization을 수행하기 때문에 모델의 성능이 large batch size에 의존하게 된다는 단점이 있다. 이러한 문제를 해결하기 위하여 Group Normalization 등 다양한 기법이 제안되었으나, 일반적인 large batch size 를 사용하는 상황에서는 batch normalization 성능에 미치지 못하여 활용도가 현저히 떨어지게 된다. 

 

그래서 등장한게 Weight Standardization 인데, 이는 group normalization 과 같이 minibatch dependency를 완전히 제거하면서 large batch size 를 사용하는 학습에서 batch normalization 보다 좋은 성능을 달성함을 보였다. 

 

Weight Standardization의 아이디어는 아래와 같다.

 

 

Batch Norm 이나 Group Norm 같은 기존의 방법들은 주로 feature activation을 대상으로 normalization을 수행하지만, Weight Standardization은 weight(conv filter)를 대상으로 normalization 을 수행한다. 이는 conv filter의 mean 값을 0으로, variance 값을 1로 조정하게 된다. 즉, loss와 gradient의 landscape를 smoothing 하는 효과를 가져오게 된다.

 

Original filter weight 를 $ W\in \mathbb{R}^{O\times I} $ ($O$: number of output channels, $I$: number of input channels x kernel size) 라고 할 때, normalization 된 filter weight $ \hat{W}\in \mathbb{R}^{O\times I} $ 는 다음과 같이 계산한다. 

 

$ \hat{W}= \left [ \hat{W}_{i, j}\mid \hat{W}_{i, j} = \frac{W_{i, j} - \mu w_{i, \cdot }}{\sigma w_{i, \cdot } + \epsilon } \right ]  $

 

$ y = \hat{W} \ast x $

 

이 때, $ \mu w_{i, \cdot } $ 는 다음과 같이 계산된다. 

 

$ \mu w_{i, \cdot } = \frac{1}{I}\sum_{j=1}^{I}W_{i, \cdot }, \sigma w_{i, \cdot } = \sqrt{\frac{1}{I}\sum_{i=1}^{I}\left ( W_{i, j} - \mu w_{i, \cdot } \right )} $

 

 

 

smoothness를 formulate 하기 위해 Lipschitzness 라는 개념이 도입된다. 

 

아래와 같은 조건을 만족한다면 $ f $ : $ L-Lipschitz $ 라고 한다.

 

$ \forall x_{1}, x_{2} : \left | f(x_{1}) - f(x) \right |  
\leq L \left \| x_{1} - x_{2} \right \| $

 

이 때 $ L $을 Lipxchitz constant 라고 부르며, 이 값이 작을 수록 function $f$ 가 smooth 해짐을 의미한다고 한다 . 또한 Lipschitz constant는 gradient 크기에 의해 좌우되기 때문에 Loss(L)의 landscape를 smooth 하게 만들기 위해서는 gradient를 줄여야하며, gradient의 landscape를 smooth 하게 만들기 위해서는 gradient의 gradient($ \Delta ^{2}L $) 즉, Hessian($ H $)을 줄여야한다. 

 

이렇듯 Weight Standardization 을 사용하면 mini batch 에 대한 dependency 가 없으므로 batch size와 완전 무관하게 동작하며, CNN에서 weight는 actiavtion에 비해 용량이 훨씬 적기 때문에 memory 및 time에 효과적이며, 특히 inference 시 weight 가 fix 되기 때문에 computation이 전혀 없다는 장점을 가진다. 또한 batch norm, group norm 등 activation을 대상으로 한 normalization 과 동시에 사용하여 성능을 더욱 향상 시킬 수 있다고 한다. 

 

 

 

 

 

Weight Standarization의 2D, 3D 구현은 아래와 같다. 

 

 

Weight Standardization 2D : github.com/joe-siyuan-qiao/WeightStandardization

 

joe-siyuan-qiao/WeightStandardization

Standardizing weights to accelerate micro-batch training - joe-siyuan-qiao/WeightStandardization

github.com

class Conv2d(nn.Conv2d):

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride,
                 padding, dilation, groups, bias)

    def forward(self, x):
        weight = self.weight
        weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2,
                                  keepdim=True).mean(dim=3, keepdim=True)
        weight = weight - weight_mean
        std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
        weight = weight / std.expand_as(weight)
        return F.conv2d(x, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

 

 

Weight Standardization 3D : github.com/Sungman-Cho/weight-standardization-3d

 

Sungman-Cho/weight-standardization-3d

Contribute to Sungman-Cho/weight-standardization-3d development by creating an account on GitHub.

github.com

import torch
import torch.nn as nn
import torch.nn.functional as F

class Conv3d(nn.Conv3d):
    def __init__(self, in_channels, output_channels, kernel_size, 
                stride=1, padding=0, dilation=1, groups=1, bias=True):
        super(Conv3d, self).__init__(in_channels, output_channels, kernel_size, stride, padding, dilation, groups, bias)

    def forward(self, x):
        w = self.weight
        w_mean = w.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True).mean(dim=4, keepdim=True)
        w = w - w_mean
        std = w.view(w.size(0), -1).std(dim=1).view(-1,1,1,1,1) + 1e-5
        w = w / std.expand_as(w)
        return F.conv3d(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups)


if __name__ == '__main__':
    conv3d = Conv3d(in_channels=3, output_channels=8, kernel_size=1)
    # b, c, z, h, w
    x = torch.randn(8, 3, 5, 32, 32).float()
    x = conv3d(x)

    print(x)

 

 

 

 

 

 

 

참고자료 : blog.lunit.io/2019/05/28/weight-standardization/

 

Weight Standardization

본 포스트는 최근 발표된 새로운 normalization 기법인 Weight Standardization에 대해 소개합니다. Introduction Normalization은 머신러닝에서 데이터의 불필요한 정보를 제거하고 학습을 용이하게 하기 위해

blog.lunit.io

 

728x90
반응형