👾Tricks for Image Datasets
Summary of Preprocessing Tricks for Image Datasets in Numerous Research Papers’ Source Code
import os.path
from itertools import cycle
import numpy as np
import torch
from torch.utils.data import DataLoader
from options.vae_options import VAEOptions
from vae.data.base_dataset import BaseDataset, get_transform
from PIL import Image
import random
import warnings
from vae.vae_utils import imshow_grid
warnings.filterwarnings("ignore")
"""
The source code refers to the data processing code of AttentionGAN.
AttentionGAN Code: https://github.com/Ha0Tang/AttentionGAN.git
AttentionGAN Paper: https://arxiv.org/abs/1911.11897
"""
IMG_EXTENSIONS = [
'.jpg', '.JPG', '.jpeg', '.JPEG',
'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]
class ImageDataset(BaseDataset):
    """
    This dataset class can load unaligned/unpaired datasets.
    """

    def __init__(self, opt):
        """Initialize this dataset class.
        Parameters: opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseDataset.__init__(self, opt)
        """
        1. Collect all image files under the data root.
        * "opt.data_root" is the root directory of the data.
        * "IMG_EXTENSIONS" is the collection of file extensions for the supported image formats.
        """
        FileList = []
        for dirname in os.listdir(opt.data_root):
            path = os.path.join(opt.data_root, dirname)
            if '.DS_Store' in path:  # [1:] Exclude .DS_Store files created by macOS
                continue
            for filename in os.listdir(path):
                if filename.endswith(tuple(IMG_EXTENSIONS)):
                    FileList.append(os.path.join(opt.data_root, dirname, filename))
        random.shuffle(FileList)  # [2:] Shuffle the order of FileList.
        self.pic_paths = FileList
        self.pics_size = len(self.pic_paths)
        input_nc = self.opt.input_nc    # [3:] get the number of channels of the input image
        output_nc = self.opt.output_nc  # get the number of channels of the output image
        self.transform_pic = get_transform(self.opt, grayscale=(input_nc == 1))
        """
        2. Build a dictionary of image data in the format { label : [image_paths] }.
        * Each image path has the form '/path/to/data/label-xxx/demo.jpg',
          where "label" is the integer label of the image.
        """
        self.pic_dict = {}
        for i in range(self.pics_size):
            image_path = self.pic_paths[i]
            # [1:] Split the path into the parent directory and the image name.
            root_path, image_name = os.path.split(image_path)
            # [2:] The path looks like '/path/to/data/label-xxx/demo.jpg' or '/path/to/data/label/demo.jpg'
            label = int(os.path.basename(root_path).split('-')[0])
            try:
                self.pic_dict[label]
            except KeyError:
                self.pic_dict[label] = []
            self.pic_dict[label].append(image_path)
    def __getitem__(self, index):
        """
        Return a data point and its metadata information.
        Parameters: index (int) -- a random integer for data indexing
        Returns Image A, Image B (an image with the same label as A), and the label.
        """
        pic_path = self.pic_paths[index % self.pics_size]  # make sure the index is within the range
        imageA = Image.open(pic_path).convert('RGB')
        imageA = self.transform_pic(imageA)
        label = int(os.path.basename(os.path.dirname(pic_path)).split('-')[0])
        imageB_path = random.SystemRandom().choice(self.pic_dict[label])
        imageB = Image.open(imageB_path).convert('RGB')
        imageB = self.transform_pic(imageB)
        return imageA, imageB, label

    def __len__(self):
        """Return the total number of images in the dataset."""
        return self.pics_size
if __name__ == '__main__':
    """
    Test code for the data loader.
    """
    opt = VAEOptions().parse()  # get training options
    dataset = ImageDataset(opt)
    print("dataset [%s] was created" % type(dataset).__name__)
    dataloader = cycle(DataLoader(
        dataset,
        batch_size=opt.batch_size,
        shuffle=not opt.serial_batches,
        num_workers=int(opt.num_threads)))
    print(dataset.pic_dict.keys())
    image_batch, image_batch_2, labels_batch = next(dataloader)
    print(labels_batch)
    # Convert the (N, C, H, W) tensors to (N, H, W, C) for visualization.
    image_batch = np.transpose(image_batch, (0, 2, 3, 1))
    imshow_grid(image_batch)
    image_batch_2 = np.transpose(image_batch_2, (0, 2, 3, 1))
    imshow_grid(image_batch_2)
1. Aligning image formats using the transforms module

The torchvision.transforms module provides many common image transformations, including:

* CenterCrop: Crops a given area from the center of the image.
* ColorJitter: Randomly changes the brightness, contrast and saturation of an image.
* Grayscale: Converts an image to grayscale.
* Normalize: Normalizes a tensor image with a given mean and standard deviation.
* RandomHorizontalFlip: Horizontally flips an image with a given probability.
* RandomRotation: Rotates an image by a random angle within a given range.
* Resize: Resizes the input image to the given size.
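As a quick, self-contained sketch (not taken from the referenced repository; the file name and sizes below are placeholders), several of these transforms can be chained with transforms.Compose and applied directly to a PIL image:

import torchvision.transforms as transforms
from PIL import Image

# Illustrative example only; the concrete values and 'demo.jpg' are placeholders.
demo_transform = transforms.Compose([
    transforms.Resize(286),                  # resize the shorter side to 286 pixels
    transforms.CenterCrop(256),              # crop a 256x256 patch from the center
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomHorizontalFlip(p=0.5),  # flip left-right with probability 0.5
    transforms.RandomRotation(degrees=10),   # rotate by a random angle in [-10, 10] degrees
    transforms.ToTensor(),                   # convert to a (C, H, W) float tensor in [0, 1]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # map values to [-1, 1]
])

image = Image.open('demo.jpg').convert('RGB')  # placeholder path
tensor = demo_transform(image)                 # resulting shape: (3, 256, 256)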
# The code in this section is referenced from the code in the link below:
# Code: https://github.com/Ha0Tang/AttentionGAN.git
# Paper: https://arxiv.org/abs/1911.11897
import torchvision.transforms as transforms
from PIL import Image

def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
    transform_list = []
    if grayscale:
        transform_list.append(transforms.Grayscale(1))
    if 'resize' in opt.preprocess:
        osize = [opt.load_size, opt.load_size]
        # Resize the image to load_size x load_size
        transform_list.append(transforms.Resize(osize, method))
    elif 'scale_width' in opt.preprocess:
        # Scale the image so its width equals load_size, keeping the aspect ratio
        transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))
    if 'crop' in opt.preprocess:
        if params is None:
            # Crop a random patch of size crop_size
            transform_list.append(transforms.RandomCrop(opt.crop_size))
        else:
            # Crop at the position given in params
            transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
    if opt.preprocess == 'none':
        # Round the image dimensions to a multiple of the base (4 here)
        transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))
    if not opt.no_flip:
        if params is None:
            # Randomly flip the image horizontally
            transform_list.append(transforms.RandomHorizontalFlip())
        elif params['flip']:
            # Flip the image as dictated by params
            transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
    if convert:
        # Convert the image to a tensor
        transform_list += [transforms.ToTensor()]
        if grayscale:
            # Normalize the single-channel image
            transform_list += [transforms.Normalize((0.5,), (0.5,))]
        else:
            # Normalize the three-channel image
            transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)
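get_transform relies on four private helpers (__scale_width, __crop, __make_power_2, __flip) that are defined elsewhere in the referenced repository's base_dataset.py and are not part of this excerpt. A sketch consistent with that code is shown below so the snippet is self-contained; treat it as an approximation rather than a verbatim copy.

# Sketch of the helpers used by get_transform, adapted from the referenced repository.
def __make_power_2(img, base, method=Image.BICUBIC):
    # Round both dimensions to the nearest multiple of `base` so later
    # strided convolutions and poolings divide the spatial size evenly.
    ow, oh = img.size
    h = int(round(oh / base) * base)
    w = int(round(ow / base) * base)
    if h == oh and w == ow:
        return img
    return img.resize((w, h), method)

def __scale_width(img, target_width, method=Image.BICUBIC):
    # Scale the image so its width equals target_width, keeping the aspect ratio.
    ow, oh = img.size
    if ow == target_width:
        return img
    w = target_width
    h = int(target_width * oh / ow)
    return img.resize((w, h), method)

def __crop(img, pos, size):
    # Crop a size x size patch whose top-left corner is at pos = (x1, y1).
    ow, oh = img.size
    x1, y1 = pos
    tw = th = size
    if ow > tw or oh > th:
        return img.crop((x1, y1, x1 + tw, y1 + th))
    return img

def __flip(img, flip):
    # Flip the image left-right when the flag is set.
    if flip:
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    return img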