[Deep Learning] Reference Material for Various Loss Function Implementations in PyTorch
2020. 11. 2. 16:52
Below I introduce a good GitHub repository that implements the various loss functions (along with several activation functions and training utilities) used in deep learning.
github.com/CoinCheung/pytorch-loss
from pytorch_loss import SwishV1, SwishV2, SwishV3
from pytorch_loss import HSwishV1, HSwishV2, HSwishV3
from pytorch_loss import MishV1, MishV2, MishV3
from pytorch_loss import convert_to_one_hot, convert_to_one_hot_cu, OnehotEncoder
from pytorch_loss import EMA
from pytorch_loss import TripletLoss
from pytorch_loss import SoftDiceLossV1, SoftDiceLossV2, SoftDiceLossV3
from pytorch_loss import PCSoftmaxCrossEntropyV1, PCSoftmaxCrossEntropyV2
from pytorch_loss import LargeMarginSoftmaxV1, LargeMarginSoftmaxV2, LargeMarginSoftmaxV3
from pytorch_loss import LabelSmoothSoftmaxCEV1, LabelSmoothSoftmaxCEV2, LabelSmoothSoftmaxCEV3
from pytorch_loss import generalized_iou_loss
from pytorch_loss import FocalLossV1, FocalLossV2, FocalLossV3
from pytorch_loss import Dual_Focal_loss
from pytorch_loss import GeneralizedSoftDiceLoss, BatchSoftDiceLoss
from pytorch_loss import AMSoftmax
from pytorch_loss import AffinityFieldLoss
from pytorch_loss import OhemCELoss, OhemLargeMarginLoss
from pytorch_loss import LovaszSoftmax
from pytorch_loss import CoordConv2d, DY_Conv2d
As the imports above show, everything can be pulled in and used easily, and the implementations are clean; a short usage sketch follows.
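Each loss is an ordinary nn.Module, so the usual pattern is to construct a criterion once and call it on logits and labels. Below is a minimal usage sketch for the label-smoothing cross-entropy; the constructor arguments (lb_smooth, ignore_index, reduction) and the N x C x H x W logit layout follow the repository README, so double-check them against the version you install.

import torch
from pytorch_loss import LabelSmoothSoftmaxCEV1

# construct the criterion once (arguments assumed from the repository README)
criteria = LabelSmoothSoftmaxCEV1(lb_smooth=0.1, ignore_index=255, reduction='mean')

# dense-prediction style inputs: N x C x H x W logits, N x H x W integer labels
logits = torch.randn(8, 19, 64, 64, requires_grad=True)
labels = torch.randint(0, 19, (8, 64, 64))

loss = criteria(logits, labels)
loss.backward()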
The example below is the repository's hard swish implementation, where hswish(x) = x * ReLU6(x + 3) / 6.
import torch
import torch.nn as nn
import torch.nn.functional as F
##
# version 1: use pytorch autograd
class HSwishV1(nn.Module):
    def __init__(self):
        super(HSwishV1, self).__init__()

    def forward(self, feat):
        return feat * F.relu6(feat + 3) / 6
##
# version 2: use derived formula to compute grad
class HSwishFunctionV2(torch.autograd.Function):

    @staticmethod
    def forward(ctx, feat):
        # act = (feat + 3).mul_(feat).div_(6).clip_(0)
        act = F.relu6(feat + 3).mul_(feat).div_(6)
        ctx.variables = feat
        return act

    @staticmethod
    def backward(ctx, grad_output):
        feat = ctx.variables
        # d/dx [x * relu6(x + 3) / 6] = relu6(x + 3) / 6 + x / 6 for -3 < x < 3, else relu6(x + 3) / 6
        grad = F.relu6(feat + 3).div_(6)
        grad.add_(torch.where(
            torch.eq(-3 < feat, feat < 3),  # True exactly where -3 < feat < 3
            torch.ones_like(feat).div_(6),
            torch.zeros_like(feat)).mul_(feat))
        grad *= grad_output
        return grad
class HSwishV2(nn.Module):
    def __init__(self):
        super(HSwishV2, self).__init__()

    def forward(self, feat):
        return HSwishFunctionV2.apply(feat)
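##
# (Addition, not part of the repository code: a quick sanity check that the autograd
#  version V1 and the hand-derived-gradient version V2 agree on both the forward
#  values and the gradients.)
if __name__ == "__main__":
    x1 = torch.randn(4, 8, requires_grad=True)
    x2 = x1.detach().clone().requires_grad_(True)
    y1 = HSwishV1()(x1)
    y2 = HSwishV2()(x2)
    y1.sum().backward()
    y2.sum().backward()
    print('forward match:', torch.allclose(y1, y2, atol=1e-6))
    print('gradient match:', torch.allclose(x1.grad, x2.grad, atol=1e-6))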
##
# version 3: write with cuda which requires less memory and can be faster
import swish_cpp
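# note (addition, not from the repository): swish_cpp is the CUDA extension shipped with
# the repo and must be compiled before this import succeeds. A hypothetical way to
# JIT-build it with torch.utils.cpp_extension.load (the source file names below are
# assumptions, check the repo's csrc/ directory):
#   from torch.utils.cpp_extension import load
#   swish_cpp = load(name='swish_cpp', sources=['csrc/swish.cpp', 'csrc/swish_kernel.cu'])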
class HSwishFunctionV3(torch.autograd.Function):

    @staticmethod
    def forward(ctx, feat):
        ctx.feat = feat
        return swish_cpp.hswish_forward(feat)

    @staticmethod
    def backward(ctx, grad_output):
        feat = ctx.feat
        return swish_cpp.hswish_backward(grad_output, feat)


class HSwishV3(nn.Module):
    def __init__(self):
        super(HSwishV3, self).__init__()

    def forward(self, feat):
        return HSwishFunctionV3.apply(feat)
if __name__ == "__main__":
    import torchvision

    # borrow conv1/bn1 weights from a pretrained resnet50 so both test nets start identically
    net = torchvision.models.resnet50(pretrained=True)
    sd = {k: v for k, v in net.state_dict().items() if k.startswith('conv1.') or k.startswith('bn1.')}

    class Net(nn.Module):
        def __init__(self, act='hswishv1'):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(3, 64, 7, 2, 3)
            self.bn1 = nn.BatchNorm2d(64)
            if act == 'hswishv1':
                self.act1 = HSwishV1()
            elif act == 'hswishv2':
                self.act1 = HSwishV2()
            elif act == 'hswishv3':
                self.act1 = HSwishV3()
            self.dense = nn.Linear(64, 10, bias=False)
            self.crit = nn.CrossEntropyLoss()
            state = self.state_dict()
            state.update(sd)
            self.load_state_dict(state)
            # torch.nn.init.constant_(self.dense.weight, 1)

        def forward(self, feat, label):
            feat = self.conv1(feat)
            feat = self.bn1(feat)
            feat = self.act1(feat)
            feat = torch.mean(feat, dim=(2, 3))
            logits = self.dense(feat)
            loss = self.crit(logits, label)
            return loss

    # two nets that differ only in the hswish version; with identical weights and inputs,
    # their losses and parameters should stay numerically (almost) the same during training
    net1 = Net(act='hswishv1')
    net2 = Net(act='hswishv3')
    net2.load_state_dict(net1.state_dict())
    net1.cuda()
    net2.cuda()
    opt1 = torch.optim.SGD(net1.parameters(), lr=1e-3)
    opt2 = torch.optim.SGD(net2.parameters(), lr=1e-3)

    bs = 32
    for i in range(10000):
        inten = torch.randn(bs, 3, 224, 224).cuda().detach()
        label = torch.randint(0, 10, (bs, )).cuda().detach()

        loss1 = net1(inten, label)
        opt1.zero_grad()
        loss1.backward()
        opt1.step()

        loss2 = net2(inten, label)
        opt2.zero_grad()
        loss2.backward()
        opt2.step()

        if i % 200 == 0:
            print('====')
            print('loss diff: ', loss1.item() - loss2.item())
            print('weight diff: ', torch.sum(torch.abs(net1.conv1.weight - net2.conv1.weight)).item())

    # numerical gradient check of the cuda version (note: gradcheck is normally run with
    # double-precision inputs, so float32 may only pass with loose tolerances)
    from torch.autograd import gradcheck
    inten = torch.randn(3, 4, 6, 6).cuda()
    inten.requires_grad_(True)
    gradcheck(HSwishFunctionV3.apply, [inten, ])