728x90
반응형

 

딥러닝에서 사용되는 다양한 손실 함수를 구현해 놓은 좋은 GitHub 저장소를 아래와 같이 소개한다.

 

 

 

github.com/CoinCheung/pytorch-loss

 

CoinCheung/pytorch-loss

label-smooth, amsoftmax, focal-loss, triplet-loss. Maybe useful - CoinCheung/pytorch-loss

github.com

 

from pytorch_loss import SwishV1, SwishV2, SwishV3
from pytorch_loss import HSwishV1, HSwishV2, HSwishV3
from pytorch_loss import MishV1, MishV2, MishV3
from pytorch_loss import convert_to_one_hot, convert_to_one_hot_cu, OnehotEncoder
from pytorch_loss import EMA

from pytorch_loss import TripletLoss
from pytorch_loss import SoftDiceLossV1, SoftDiceLossV2, SoftDiceLossV3
from pytorch_loss import PCSoftmaxCrossEntropyV1, PCSoftmaxCrossEntropyV2
from pytorch_loss import LargeMarginSoftmaxV1, LargeMarginSoftmaxV2, LargeMarginSoftmaxV3
from pytorch_loss import LabelSmoothSoftmaxCEV1, LabelSmoothSoftmaxCEV2, LabelSmoothSoftmaxCEV3
from pytorch_loss import generalized_iou_loss
from pytorch_loss import FocalLossV1, FocalLossV2, FocalLossV3
from pytorch_loss import Dual_Focal_loss
from pytorch_loss import GeneralizedSoftDiceLoss, BatchSoftDiceLoss
from pytorch_loss import AMSoftmax
from pytorch_loss import AffinityFieldLoss
from pytorch_loss import OhemCELoss, OhemLargeMarginLoss
from pytorch_loss import LovaszSoftmax

from pytorch_loss import CoordConv2d, DY_Conv2d

 

위와 같이 손쉽게 쓸 수 있고, 아래와 같이 구현을 깔끔하게 해놓았다.

 

아래는 hard swish 구현 예시이다.

 

import torch
import torch.nn as nn
import torch.nn.functional as F


##
# version 1: use pytorch autograd
class HSwishV1(nn.Module):
    """Hard-swish activation built from stock PyTorch ops.

    Evaluates ``feat * relu6(feat + 3) / 6`` element-wise; the backward
    pass is left to autograd.
    """

    def __init__(self):
        super().__init__()

    def forward(self, feat):
        """Return the hard-swish of ``feat`` (same shape as the input)."""
        # relu6(feat + 3) is the piecewise-linear gate, clamped to [0, 6].
        gate = F.relu6(feat + 3)
        return feat * gate / 6


##
# version 2: use derived formula to compute grad
class HSwishFunctionV2(torch.autograd.Function):
    """Hard-swish, ``x * relu6(x + 3) / 6``, with a hand-written gradient.

    The backward pass evaluates the closed-form derivative instead of
    letting autograd differentiate through the forward ops.
    """

    @staticmethod
    def forward(ctx, feat):
        """Compute hard-swish of ``feat`` and save ``feat`` for backward.

        Args:
            ctx: autograd context object.
            feat: input tensor.

        Returns:
            A new tensor holding ``feat * relu6(feat + 3) / 6``.
        """
        # Save the input through save_for_backward (not as a plain ctx
        # attribute) so autograd tracks it and can detect in-place
        # modification of the input before backward runs.
        ctx.save_for_backward(feat)
        # ``feat + 3`` allocates a fresh tensor, so the in-place ops below
        # never write into ``feat`` itself.
        act = F.relu6(feat + 3).mul_(feat).div_(6)
        return act

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient of the loss with respect to ``feat``.

        With ``g(x) = relu6(x + 3) / 6`` the derivative of ``x * g(x)`` is
        ``g(x) + x * g'(x)``, where ``g'(x)`` is ``1/6`` strictly inside
        ``(-3, 3)`` and ``0`` elsewhere.
        """
        feat, = ctx.saved_tensors
        # First term: g(x).
        grad = F.relu6(feat + 3).div_(6)
        # Second term: x * g'(x); the mask selects the open interval (-3, 3).
        inside = (feat > -3) & (feat < 3)
        grad.add_(inside.to(feat.dtype).div_(6).mul_(feat))
        # Chain rule with the incoming gradient.
        grad *= grad_output
        return grad


class HSwishV2(nn.Module):
    """Module wrapper that applies ``HSwishFunctionV2`` to its input."""

    def __init__(self):
        super().__init__()

    def forward(self, feat):
        """Return ``HSwishFunctionV2.apply(feat)``."""
        activated = HSwishFunctionV2.apply(feat)
        return activated


##
# version 3: write with cuda which requires less memory and can be faster
import swish_cpp
class HSwishFunctionV3(torch.autograd.Function):
    """Hard-swish autograd function backed by the ``swish_cpp`` extension.

    Both passes are delegated to ``swish_cpp``; this class only wires the
    extension into autograd.
    NOTE(review): the extension is presumably a CUDA build, so inputs are
    expected to live on a GPU -- confirm against the extension source.
    """

    @staticmethod
    def forward(ctx, feat):
        """Run ``swish_cpp.hswish_forward`` on ``feat`` and save ``feat``.

        Args:
            ctx: autograd context object.
            feat: input tensor.

        Returns:
            Whatever ``swish_cpp.hswish_forward(feat)`` returns.
        """
        # Save the input through save_for_backward (not as a plain ctx
        # attribute) so autograd tracks it and can detect in-place
        # modification of the input before backward runs.
        ctx.save_for_backward(feat)
        return swish_cpp.hswish_forward(feat)

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``swish_cpp.hswish_backward(grad_output, feat)``."""
        feat, = ctx.saved_tensors
        return swish_cpp.hswish_backward(grad_output, feat)


class HSwishV3(nn.Module):
    """Module wrapper that applies ``HSwishFunctionV3`` to its input."""

    def __init__(self):
        super().__init__()

    def forward(self, feat):
        """Return ``HSwishFunctionV3.apply(feat)``."""
        activated = HSwishFunctionV3.apply(feat)
        return activated


# Script entry point: trains two otherwise identical small networks, one using
# HSwishV1 and one using HSwishV3, on random data and prints how far their
# losses and conv1 weights drift apart; then runs gradcheck on
# HSwishFunctionV3.
if __name__ == "__main__":
    import torchvision
    net = torchvision.models.resnet50(pretrained=True)
    # Keep only the state-dict entries of the stem conv and its batch norm.
    sd = {k: v for k, v in net.state_dict().items() if k.startswith('conv1.') or k.startswith('bn1.')}

    class Net(nn.Module):
        """conv1 -> bn1 -> activation -> spatial mean -> linear -> cross-entropy.

        ``act`` picks the activation: 'hswishv1', 'hswishv2' or 'hswishv3'.
        NOTE(review): any other value leaves ``self.act1`` unset, so
        ``forward`` would fail on attribute lookup -- confirm this is
        acceptable for a test script.
        """
        def __init__(self, act='hswishv1'):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(3, 64, 7, 2, 3)
            self.bn1 = nn.BatchNorm2d(64)
            if act == 'hswishv1':
                self.act1 = HSwishV1()
            elif act == 'hswishv2':
                self.act1 = HSwishV2()
            elif act == 'hswishv3':
                self.act1 = HSwishV3()
            self.dense = nn.Linear(64, 10, bias=False)
            self.crit = nn.CrossEntropyLoss()
            # Overlay the filtered resnet50 entries on this module's own
            # state dict; keys absent from ``sd`` keep their initial values.
            state = self.state_dict()
            state.update(sd)
            self.load_state_dict(state)
            #  torch.nn.init.constant_(self.dense.weight, 1)
        def forward(self, feat, label):
            # Returns the cross-entropy loss, not the logits.
            feat = self.conv1(feat)
            feat = self.bn1(feat)
            feat = self.act1(feat)
            # Average over dims 2 and 3 before the linear layer.
            feat = torch.mean(feat, dim=(2, 3))
            logits = self.dense(feat)
            loss = self.crit(logits, label)
            return loss

    net1 = Net(act='hswishv1')
    net2 = Net(act='hswishv3')
    # Copy net1's parameters into net2 so both start from the same state.
    net2.load_state_dict(net1.state_dict())
    net1.cuda()
    net2.cuda()
    opt1 = torch.optim.SGD(net1.parameters(), lr=1e-3)
    opt2 = torch.optim.SGD(net2.parameters(), lr=1e-3)
    bs = 32
    for i in range(10000):
        # Fresh random batch each step; both networks see the same batch.
        inten = torch.randn(bs, 3, 224, 224).cuda().detach()
        label = torch.randint(0, 10, (bs, )).cuda().detach()

        loss1 = net1(inten, label)
        opt1.zero_grad()
        loss1.backward()
        opt1.step()

        loss2 = net2(inten, label)
        opt2.zero_grad()
        loss2.backward()
        opt2.step()

        # Every 200 steps report the loss gap and the summed absolute
        # difference of the conv1 weights.
        if i % 200 == 0:
            print('====')
            print('loss diff: ', loss1.item() - loss2.item())
            print('weight diff: ', torch.sum(torch.abs(net1.conv1.weight - net2.conv1.weight)).item())

    # Numerical gradient check of the extension-backed function.
    # NOTE(review): torch.randn yields the default dtype (float32 unless it
    # was changed), while gradcheck's default tolerances are meant for
    # double-precision inputs -- confirm whether ``.double()`` was intended.
    from torch.autograd import gradcheck
    inten = torch.randn(3, 4, 6, 6).cuda()
    inten.requires_grad_(True)
    gradcheck(HSwishFunctionV3.apply, [inten, ])
728x90
반응형