人工知能とか犬とか

人工知能と犬に興味があります。しょぼしょぼ更新してゆきます。

PyTorchのDatasetとDataLoader

概要

PyTorchのチュートリアルData Loading and Processing Tutorial をやってみて、DatasetとDataLoaderの使い方を学ぶのです。

DatasetとDataLoader

そもそも、深層学習で用いる教師データは、以下のような処理を必要とします。

  • データの読み込み
  • データの前処理
  • データ拡張
  • データの順番をシャッフルする
  • ミニバッチを作成する
  • シングルマシン or 分散システム中の複数のGPUにミニバッチを配布する

ここらへんの処理を簡単に実現できるのが、DdatasetとDataLoaderです。

Dataset

抽象基底クラスtorch.utils.data.Datasetを継承して、以下のメソッドを実装すれば、独自のDatasetを定義できます。

  • __len__: データセットに含まれる全要素の数を返す。
  • __getitem__: i番目のサンプルをdataset[i]という形で取得できるようにするためのメソッド。
class FaceLandmarksDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, csv_file, root_dir, transform=None):
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir,
                                self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
        landmarks = self.landmarks_frame.iloc[idx, 1:].as_matrix()
        landmarks = landmarks.astype('float').reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks}

        if self.transform:
            sample = self.transform(sample)

        return sample

これは、以下のようにiterableなオブジェクトとして扱うことができます。

face_dataset = FaceLandmarksDataset(
    csv_file='faces/face_landmarks.csv',
    root_dir='faces/')

for i, sample in enumerate(face_dataset):
    print(sample['image'].shape, sample['landmarks'].shape)
    
    if i > 4:
        break
(324, 215, 3) (68, 2)
(500, 333, 3) (68, 2)
(250, 258, 3) (68, 2)
(434, 290, 3) (68, 2)
(828, 630, 3) (68, 2)
(402, 500, 3) (68, 2)

DataLoader

Datasetは単体では、ただのiterableなオブジェクトに見えますが、DataLoaderと組み合わせることで、最初に挙げた処理を効率的に記述することができます。

まずはバッチサイズが1のミニバッチを生成するDataLoaderを作ってみます。

face_dataset = FaceLandmarksDataset(
    csv_file='faces/face_landmarks.csv',
    root_dir='faces/')
dataloader = DataLoader(face_dataset, batch_size=1,
                        shuffle=True, num_workers=4)

for i_batch, sample_batched in enumerate(dataloader):
    print(i_batch, sample_batched['image'].size(),
          sample_batched['landmarks'].size())
    
    if i_batch > 4:
        break
0 torch.Size([1, 500, 365, 3]) torch.Size([1, 68, 2])
1 torch.Size([1, 333, 500, 3]) torch.Size([1, 68, 2])
2 torch.Size([1, 500, 335, 3]) torch.Size([1, 68, 2])
3 torch.Size([1, 402, 500, 3]) torch.Size([1, 68, 2])
4 torch.Size([1, 334, 500, 3]) torch.Size([1, 68, 2])
5 torch.Size([1, 333, 500, 3]) torch.Size([1, 68, 2])

このように、ミニバッチをDatasetから簡単に作成することができました。上記の例では、imageのサイズが全てバラバラなので、このままでは、複数のサンプルのimageをまとめてミニバッチを作ることはできません。ミニバッチを作るためには、Datasetの__getitem__内でサンプルであるimageのサイズを揃える前処理が必要です。

TransformとCompose

チュートリアルでは、callableなRescaleとRandomCroopというクラスを作成していますが、特にパラメータが必要なければ、sampleを引数として受け取り、変更されたsampleを返す関数でも問題ありません。

def flip_img(sample):
    """
    画像だけを左右反転
    """
    image, landmarks = sample['image'], sample['landmarks']
    image = image[:, ::-1]  # 画像データを左右反転
    return {'image': image, 'landmarks': landmarks}

また、複数のTransformをまとめて、順次処理のパイプラインにすることもできます。 torchvision.transforms.Composeは、以下のように使用できます。

composed = transforms.Compose([
    flip_img,
    Rescale(256),
    RandomCrop(224),
    ToTensor()
])

ComposeをDatasetの__getitem__中で使用することで、ミニバッチ中の画像サイズの大きさを揃えることができます。これでミニバッチにもできますね。

transformed_dataset = FaceLandmarksDataset(
    csv_file='faces/face_landmarks.csv',
    root_dir='faces/',
    transform=transforms.Compose([  # transform引数にcomposeを与える
       Rescale(256),
       RandomCrop(224),
       ToTensor()
    ]))

dataloader = DataLoader(transformed_dataset, batch_size=4,
                        shuffle=True, num_workers=4)

for i_batch, sample_batched in enumerate(dataloader):
    print(i_batch, sample_batched['image'].size(),
          sample_batched['landmarks'].size())

    if i_batch == 3:
        break
0 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
1 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
2 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
3 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])

(おまけ)DataLoaderのcollate_fn

DataLoaderにはcollate_fnという引数を指定できます。実装をみてみると、リストとして処理されるミニバッチ内のサンプルを、最後にTensorに変換する処理を行っていることがわかります。

今ひとつ使い所がわかりませんが、1つのサンプルからデータを抽出してミニバッチを作成するときなどに使うのかもしれません。

まとめ

チュートリアルをもとに、DatasetとDataLoaderの使い方を確認してきました。

DatasetとDataLoaderは適切に使えれば、データの前処理をきれいに書くことができる便利なクラスです。Datasetは、最初にデータを一括して読み込んでインメモリで処理したいとか、毎回ディスクから画像を読み込んでデータ拡張をしたいといった様々なケースに対応可能なようになっています。

効率的な処理ときれいなコードのために、習得しておきたいところです。