AI Development/PyTorch

[Pytorch] 이미지 데이터세트에 대한 평균(mean)과 표준편차(std) 구하기

꾸준희 2021. 5. 11. 15:50
728x90
반응형

 

http://aikorea.org/cs231n/neural-networks-2-kr/

 

이미지 데이터는 촬영된 환경에 따라 명도나 채도 등이 서로 모두 다르기 때문에

영상 기반 딥러닝 모델을 학습시키기 전에 모든 이미지들을 동일한 환경으로 맞춰주는 작업이 필요하다.

즉, 전체 이미지에 대한 화소 값의 평균(mean)과, 표준편차(standard deviation)를 구하여 이 값들을 영상에 일괄적으로 적용하는 과정이 필요하다. 

 

보통 Imagenet 데이터 세트에서 계산된 평균과 표준을 사용하게 된다. 이는 수백만 개의 이미지를 기반으로 계산된다.

 

자신의 데이터 세트에서 처음부터 학습하려는 경우 평균과 표준을 계산할 수 있지만,

그렇지 않은 경우(대부분) 자체 평균 및 표준이있는 Imagenet 으로 학습된 pre-trained model 을 사용하는 것이 좋다. 

 

 

 

 

 

 

파이토치 데이터 세트에서 평균과 표준편차 구하는 예제 (mnist)

import os
import torch
from torchvision import datasets, transforms
from torch.utils.data.dataset import Dataset
from tqdm.notebook import tqdm
from time import time

N_CHANNELS = 1

dataset = datasets.MNIST("data", download=True,
                 train=True, transform=transforms.ToTensor())
full_loader = torch.utils.data.DataLoader(dataset, shuffle=False, num_workers=os.cpu_count())

before = time()
mean = torch.zeros(1)
std = torch.zeros(1)
print('==> Computing mean and std..')
for inputs, _labels in tqdm(full_loader):
    for i in range(N_CHANNELS):
        mean[i] += inputs[:,i,:,:].mean()
        std[i] += inputs[:,i,:,:].std()
mean.div_(len(dataset))
std.div_(len(dataset))
print(mean, std)

print("time elapsed: ", time()-before)

 

 

 

 

파이토치 데이터 세트에서 정규화 예시

 import torch
 from torchvision import transforms, datasets

 data_transform = transforms.Compose([
         transforms.RandomSizedCrop(224),
         transforms.RandomHorizontalFlip(),
         transforms.ToTensor(),
         transforms.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225])
     ])
 hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
                                            transform=data_transform)
 dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
                                              batch_size=4, shuffle=True,
                                              num_workers=4)

 

 

 

 

참고로 이미지 처리에서는 각 픽셀 값이 이미 동일한 스케일(0~255)을 갖고 있는 경우가 대부분 이기 때문에
정규화 전처리 기법을 "반드시" 사용해야 하는 것은 아니라고 한다. 

 

 

 

 

 

 

 

참고자료 1 : stackoverflow.com/questions/58151507/why-pytorch-officially-use-mean-0-485-0-456-0-406-and-std-0-229-0-224-0-2

 

Why Pytorch officially use mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225] to normalize images?

In this page (https://pytorch.org/docs/stable/torchvision/models.html), it says that "All pre-trained models expect input images normalized in the same way, i.e. mini-batches of 3-channel RGB image...

stackoverflow.com

참고자료 2 : github.com/pytorch/examples/tree/master/imagenet

 

pytorch/examples

A set of examples around pytorch in Vision, Text, Reinforcement Learning, etc. - pytorch/examples

github.com

 

728x90
반응형