PyTorchのDatasetとDataLoader
概要
PyTorchのチュートリアルData Loading and Processing Tutorial をやってみて、DatasetとDataLoaderの使い方を学ぶのです。
DatasetとDataLoader
そもそも、深層学習で用いる教師データは、以下のような処理を必要とします。
- データの読み込み
- データの前処理
- データ拡張
- データの順番をシャッフルする
- ミニバッチを作成する
- シングルマシン or 分散システム中の複数のGPUにミニバッチを配布する
ここらへんの処理を簡単に実現できるのが、DdatasetとDataLoaderです。
Dataset
抽象基底クラスtorch.utils.data.Dataset
を継承して、以下のメソッドを実装すれば、独自のDatasetを定義できます。
__len__
: データセットに含まれる全要素の数を返す。__getitem__
: i番目のサンプルをdataset[i]という形で取得できるようにするためのメソッド。
class FaceLandmarksDataset(Dataset): """Face Landmarks dataset.""" def __init__(self, csv_file, root_dir, transform=None): self.landmarks_frame = pd.read_csv(csv_file) self.root_dir = root_dir self.transform = transform def __len__(self): return len(self.landmarks_frame) def __getitem__(self, idx): img_name = os.path.join(self.root_dir, self.landmarks_frame.iloc[idx, 0]) image = io.imread(img_name) landmarks = self.landmarks_frame.iloc[idx, 1:].as_matrix() landmarks = landmarks.astype('float').reshape(-1, 2) sample = {'image': image, 'landmarks': landmarks} if self.transform: sample = self.transform(sample) return sample
これは、以下のようにiterableなオブジェクトとして扱うことができます。
face_dataset = FaceLandmarksDataset( csv_file='faces/face_landmarks.csv', root_dir='faces/') for i, sample in enumerate(face_dataset): print(sample['image'].shape, sample['landmarks'].shape) if i > 4: break
(324, 215, 3) (68, 2) (500, 333, 3) (68, 2) (250, 258, 3) (68, 2) (434, 290, 3) (68, 2) (828, 630, 3) (68, 2) (402, 500, 3) (68, 2)
DataLoader
Datasetは単体では、ただのiterableなオブジェクトに見えますが、DataLoaderと組み合わせることで、最初に挙げた処理を効率的に記述することができます。
まずはバッチサイズが1のミニバッチを生成するDataLoaderを作ってみます。
face_dataset = FaceLandmarksDataset( csv_file='faces/face_landmarks.csv', root_dir='faces/') dataloader = DataLoader(face_dataset, batch_size=1, shuffle=True, num_workers=4) for i_batch, sample_batched in enumerate(dataloader): print(i_batch, sample_batched['image'].size(), sample_batched['landmarks'].size()) if i_batch > 4: break
0 torch.Size([1, 500, 365, 3]) torch.Size([1, 68, 2]) 1 torch.Size([1, 333, 500, 3]) torch.Size([1, 68, 2]) 2 torch.Size([1, 500, 335, 3]) torch.Size([1, 68, 2]) 3 torch.Size([1, 402, 500, 3]) torch.Size([1, 68, 2]) 4 torch.Size([1, 334, 500, 3]) torch.Size([1, 68, 2]) 5 torch.Size([1, 333, 500, 3]) torch.Size([1, 68, 2])
このように、ミニバッチをDatasetから簡単に作成することができました。上記の例では、imageのサイズが全てバラバラなので、このままでは、複数のサンプルのimageをまとめてミニバッチを作ることはできません。ミニバッチを作るためには、Datasetの__getitem__
内でサンプルであるimageのサイズを揃える前処理が必要です。
TransformとCompose
チュートリアルでは、callableなRescaleとRandomCroopというクラスを作成していますが、特にパラメータが必要なければ、sampleを引数として受け取り、変更されたsampleを返す関数でも問題ありません。
def flip_img(sample): """ 画像だけを左右反転 """ image, landmarks = sample['image'], sample['landmarks'] image = image[:, ::-1] # 画像データを左右反転 return {'image': image, 'landmarks': landmarks}
また、複数のTransformをまとめて、順次処理のパイプラインにすることもできます。
torchvision.transforms.Compose
は、以下のように使用できます。
composed = transforms.Compose([ flip_img, Rescale(256), RandomCrop(224), ToTensor() ])
ComposeをDatasetの__getitem__
中で使用することで、ミニバッチ中の画像サイズの大きさを揃えることができます。これでミニバッチにもできますね。
transformed_dataset = FaceLandmarksDataset( csv_file='faces/face_landmarks.csv', root_dir='faces/', transform=transforms.Compose([ # transform引数にcomposeを与える Rescale(256), RandomCrop(224), ToTensor() ])) dataloader = DataLoader(transformed_dataset, batch_size=4, shuffle=True, num_workers=4) for i_batch, sample_batched in enumerate(dataloader): print(i_batch, sample_batched['image'].size(), sample_batched['landmarks'].size()) if i_batch == 3: break
0 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2]) 1 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2]) 2 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2]) 3 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
(おまけ)DataLoaderのcollate_fn
DataLoaderにはcollate_fn
という引数を指定できます。実装をみてみると、リストとして処理されるミニバッチ内のサンプルを、最後にTensorに変換する処理を行っていることがわかります。
今ひとつ使い所がわかりませんが、1つのサンプルからデータを抽出してミニバッチを作成するときなどに使うのかもしれません。
まとめ
チュートリアルをもとに、DatasetとDataLoaderの使い方を確認してきました。
DatasetとDataLoaderは適切に使えれば、データの前処理をきれいに書くことができる便利なクラスです。Datasetは、最初にデータを一括して読み込んでインメモリで処理したいとか、毎回ディスクから画像を読み込んでデータ拡張をしたいといった様々なケースに対応可能なようになっています。
効率的な処理ときれいなコードのために、習得しておきたいところです。