👾 Tricks for Image Datasets

A summary of preprocessing tricks for image datasets, collected from the source code of numerous research papers.

import os.path
from itertools import cycle
import numpy as np
import torch
from torch.utils.data import DataLoader
from options.vae_options import VAEOptions
from vae.data.base_dataset import BaseDataset, get_transform
from PIL import Image
import random
import warnings

from vae.vae_utils import imshow_grid

# Silence every warning for the whole process: filterwarnings("ignore") with no
# category or module argument matches all warnings.
# NOTE(review): this also hides warnings raised by any library imported later;
# consider narrowing it by category or module.
warnings.filterwarnings("ignore")

# The source code refers to the data processing code of AttentionGAN.
#   AttentionGAN Code: https://github.com/Ha0Tang/AttentionGAN.git
#   AttentionGAN Paper: https://arxiv.org/abs/1911.11897


# Recognised image file extensions, each listed in lower-case and then upper-case form.
IMG_EXTENSIONS = [
    variant
    for ext in ('.jpg', '.jpeg', '.png', '.ppm', '.bmp')
    for variant in (ext, ext.upper())
]


class ImageDataset(BaseDataset):
    """
    This dataset class can load unaligned/unpaired datasets.

    Each item is a tuple ``(imageA, imageB, label)``:
    - ``imageA`` is an image.
    - ``imageB`` is a second image drawn at random from the same label.
    - ``label`` is that integer label.

    Images are expected one directory level below ``opt.data_root``, in
    sub-directories whose names start with the integer label. Examples:
    '/path/to/data/3-cat/demo.jpg' or '/path/to/data/3/demo.jpg'.
    """

    def __init__(self, opt):
        """Initialize this dataset class.

        Parameters: opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
            This method reads ``opt.data_root`` and ``opt.input_nc``.

        Raises:
            ValueError -- if a sub-directory holding images does not start with an integer label.
        """
        BaseDataset.__init__(self, opt)

        # 1. Collect the image files located one level below opt.data_root.
        extensions = tuple(IMG_EXTENSIONS)  # str.endswith accepts a tuple of suffixes
        file_list = []
        # os.listdir returns entries in arbitrary order. Sorting first makes the
        # shuffle below reproducible once the `random` module is seeded.
        for dirname in sorted(os.listdir(opt.data_root)):
            dir_path = os.path.join(opt.data_root, dirname)
            # Skip any non-directory entry (e.g. macOS '.DS_Store' or a stray
            # README); calling os.listdir on a file would raise NotADirectoryError.
            if not os.path.isdir(dir_path):
                continue
            for filename in sorted(os.listdir(dir_path)):
                if filename.endswith(extensions):
                    file_list.append(os.path.join(dir_path, filename))
        random.shuffle(file_list)
        self.pic_paths = file_list
        self.pics_size = len(self.pic_paths)
        # A single input channel selects the grayscale variant of the pipeline.
        self.transform_pic = get_transform(self.opt, grayscale=(self.opt.input_nc == 1))

        # 2. Group the paths by label: { label : [image_path, ...] }.
        self.pic_dict = {}
        for image_path in self.pic_paths:
            label = self._label_from_path(image_path)
            self.pic_dict.setdefault(label, []).append(image_path)

    @staticmethod
    def _label_from_path(image_path):
        """Return the integer label encoded before the first '-' of the parent directory name.

        '/path/to/data/label-xxx/demo.jpg' and '/path/to/data/label/demo.jpg' both yield int(label).
        """
        return int(os.path.basename(os.path.dirname(image_path)).split('-')[0])

    def _load_image(self, path):
        """Open the image at ``path``, convert it to RGB and apply ``self.transform_pic``."""
        # The context manager closes the file handle. convert() returns a new,
        # fully loaded image, so it stays usable after the file is closed.
        with Image.open(path) as img:
            rgb = img.convert('RGB')
        return self.transform_pic(rgb)

    def __getitem__(self, index):
        """
        Return a data point and its metadata information.
        Parameters: index (int) -- a random integer for data indexing; wrapped modulo the dataset size
        Returns a tuple (imageA, imageB, label):
            imageA -- the image at ``index``.
            imageB -- a random image with the same label as A (possibly A itself).
            label -- the integer label.
        """
        pic_path = self.pic_paths[index % self.pics_size]  # make sure index is within the range
        imageA = self._load_image(pic_path)
        label = self._label_from_path(pic_path)
        # SystemRandom draws from os.urandom, so this choice is not affected by
        # random.seed(). It is created per call rather than stored on self
        # because a SystemRandom instance cannot be pickled (its getstate()
        # raises NotImplementedError).
        imageB_path = random.SystemRandom().choice(self.pic_dict[label])
        imageB = self._load_image(imageB_path)
        return imageA, imageB, label

    def __len__(self):
        """Return the total number of images in the dataset."""
        return self.pics_size


if __name__ == '__main__':
    # Smoke test for the data loader: build the dataset, pull one batch, and
    # display both the anchor images and their same-label partner images.
    opt = VAEOptions().parse()  # get training options
    dataset = ImageDataset(opt)

    print("dataset [%s] was created" % type(dataset).__name__)

    # Only one batch is needed, so a plain iterator is enough.
    # itertools.cycle would keep a reference to every batch it yields.
    dataloader = iter(DataLoader(
        dataset,
        batch_size=opt.batch_size,
        shuffle=not opt.serial_batches,
        num_workers=int(opt.num_threads)))

    print(dataset.pic_dict.keys())
    image_batch, image_batch_2, labels_batch = next(dataloader)
    print(labels_batch)

    # ToTensor yields C x H x W images, so batches are N x C x H x W.
    # Move channels last for display.
    imshow_grid(image_batch.permute(0, 2, 3, 1))
    imshow_grid(image_batch_2.permute(0, 2, 3, 1))

1. Standardizing image formats with the `transforms` module

The `torchvision.transforms` module provides many common image transformations, including:

  • CenterCrop: Crops a given area from the center of the image.

  • ColorJitter: Randomly changes the brightness, contrast, saturation and hue of an image.

  • Grayscale: Converts an image to grayscale.

  • Normalize: Normalizes a tensor image channel by channel with a given mean and standard deviation.

  • RandomHorizontalFlip: Horizontally flips an image with a given probability.

  • RandomRotation: Rotates an image by a random angle within a given range.

  • Resize: Resizes the input image to the given size.

# The code in this section is adapted from the code in the link below:
# Code: https://github.com/Ha0Tang/AttentionGAN.git
# Paper: https://arxiv.org/abs/1911.11897
def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
    """Build a ``transforms.Compose`` preprocessing pipeline described by ``opt``.

    Enabled steps are appended in this order:
    1. grayscale
    2. resize or scale_width
    3. crop
    4. size adjustment (only when preprocess == 'none')
    5. horizontal flip
    6. ToTensor + Normalize

    Parameters:
        opt -- options object. Reads:
            ``preprocess`` -- a string matched by substring for 'resize',
                'scale_width' and 'crop', and compared exactly with 'none'.
            ``load_size``, ``crop_size``, ``no_flip``.
        params -- None for a random crop / random flip. Otherwise a dict whose
            'crop_pos' (when cropping) and 'flip' (when flipping is enabled)
            entries fix the crop position and the flip decision.
        grayscale -- prepend a 1-channel Grayscale step and normalise with
            single-channel statistics.
        method -- resampling filter passed to the resize helpers.
        convert -- when True, append ToTensor and a Normalize with mean = std = 0.5
            for every channel.

    Returns:
        transforms.Compose of the selected steps.

    NOTE(review): ``transforms`` is assumed to be ``torchvision.transforms``.
    The ``__scale_width`` / ``__crop`` / ``__make_power_2`` / ``__flip``
    helpers are assumed to be defined elsewhere in this module.
    """
    steps = []

    if grayscale:
        steps.append(transforms.Grayscale(1))

    # Resize to a square load_size x load_size image, or scale by width;
    # 'resize' takes precedence when both substrings are present.
    if 'resize' in opt.preprocess:
        steps.append(transforms.Resize([opt.load_size, opt.load_size], method))
    elif 'scale_width' in opt.preprocess:
        steps.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))

    # Crop: a fixed position when params are supplied, otherwise a random crop.
    if 'crop' in opt.preprocess:
        if params is not None:
            steps.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
        else:
            steps.append(transforms.RandomCrop(opt.crop_size))

    # With no other preprocessing, adjust the size via __make_power_2 (base=4).
    # NOTE(review): in the upstream code this presumably rounds each side to a
    # multiple of `base` rather than a power of two; confirm against the helper.
    if opt.preprocess == 'none':
        steps.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))

    # Flip: random when params is None; otherwise only if params['flip'] is truthy.
    if not opt.no_flip:
        if params is None:
            steps.append(transforms.RandomHorizontalFlip())
        elif params['flip']:
            steps.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))

    # Convert to a tensor, then map each channel with (x - 0.5) / 0.5.
    if convert:
        half = (0.5,) * (1 if grayscale else 3)
        steps.extend([transforms.ToTensor(), transforms.Normalize(half, half)])

    return transforms.Compose(steps)

Last updated

Was this helpful?