AI Development/PyTorch
[Pytorch] 이미지 데이터세트에 대한 평균(mean)과 표준편차(std) 구하기
꾸준희
2021. 5. 11. 15:50
728x90
반응형
이미지 데이터는 촬영된 환경에 따라 명도나 채도 등이 서로 모두 다르기 때문에
영상 기반 딥러닝 모델을 학습시키기 전에 모든 이미지들을 동일한 환경으로 맞춰주는 작업이 필요하다.
즉, 전체 이미지에 대한 화소 값의 평균(mean)과, 표준편차(standard deviation)를 구하여 이 값들을 영상에 일괄적으로 적용하는 과정이 필요하다.
보통 Imagenet 데이터 세트에서 계산된 평균과 표준을 사용하게 된다. 이는 수백만 개의 이미지를 기반으로 계산된다.
자신의 데이터 세트에서 처음부터 학습하려는 경우 평균과 표준을 계산할 수 있지만,
그렇지 않은 경우(대부분) 자체 평균 및 표준이있는 Imagenet 으로 학습된 pre-trained model 을 사용하는 것이 좋다.
파이토치 데이터 세트에서 평균과 표준편차 구하는 예제 (mnist)
import os
import torch
from torchvision import datasets, transforms
from torch.utils.data.dataset import Dataset
from tqdm.notebook import tqdm
from time import time
N_CHANNELS = 1
dataset = datasets.MNIST("data", download=True,
train=True, transform=transforms.ToTensor())
full_loader = torch.utils.data.DataLoader(dataset, shuffle=False, num_workers=os.cpu_count())
before = time()
mean = torch.zeros(1)
std = torch.zeros(1)
print('==> Computing mean and std..')
for inputs, _labels in tqdm(full_loader):
for i in range(N_CHANNELS):
mean[i] += inputs[:,i,:,:].mean()
std[i] += inputs[:,i,:,:].std()
mean.div_(len(dataset))
std.div_(len(dataset))
print(mean, std)
print("time elapsed: ", time()-before)
파이토치 데이터 세트에서 정규화 예시
import torch
from torchvision import transforms, datasets
data_transform = transforms.Compose([
transforms.RandomSizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
batch_size=4, shuffle=True,
num_workers=4)
참고로 이미지 처리에서는 각 픽셀 값이 이미 동일한 스케일(0~255)을 갖고 있는 경우가 대부분 이기 때문에
정규화 전처리 기법을 "반드시" 사용해야 하는 것은 아니라고 한다.
참고자료 2 : github.com/pytorch/examples/tree/master/imagenet
728x90
반응형