PyTorchのSeq2Seqをミニバッチ化するときに気をつけたこと
概要
PyTorchチュートリアルに、英仏の機械翻訳モデルを作成するTranslation with a Sequence to Sequence Network and Attentionがあります。 このチュートリアルは、教師データを一つずつ与える形になっており、結構遅いのです。 なので、バッチでの学習ができるように修正を試みたところ、注意ポイントがいくつかあったのでまとめておきます。
RNNのバッチ学習の実装
RNNでバッチ学習を行う際に問題となるのが、入力されるデータ系列の長さがバッチ内で異なることです。 この問題には一般的に、バッチ内での長さを揃えるためのパディングと、パディングした部分が学習の邪魔にならないようにするマスキングを実装して対処する必要があります。
実装自体は割と簡単にできますが、きちんと実装しないと学習が全然進まなかったりするので注意が必要です。
パディング
パディング自体はそう難しい処理ではありませんが、ググったりフォーラムを参照したり、調べ始めるといろいろやり方があって混乱してしまいました。結果として2つに落ち着きました。
- 元データに対して、パディングトークン(今回は0)を必要な長さになるまで追加する。
- 元データを個別にTensorにし、
pad_sequence
を使う。
大規模な学習を行う必要のない場合は、前者で良さそうです。
1つ目のやり方
こちらは説明の必要も無いくらい簡単にできます。 単純に足りない長さ分のパディングトークンを追加するだけです。
max_length = 10 seq = [1, 1, 1, 1, 1] seq += [0] * (max_length - len(seq)) seq
[1, 1, 1, 1, 1, 0, 0, 0, 0, 0]
悩むのは、max_length
をバッチ内での最大系列長にするか、グローバルな定数にしておくかぐらいでしょうか。
2つ目のやり方
こちらはpad_sequence
を使う方法ですが、シーケンスの長さが降順になるようにソートする必要があり、少々めんどくさいです。
a = torch.ones(2) b = torch.ones(3) c = torch.ones(4) d = torch.ones(3) e = torch.ones(1) # 長さが降順になるようにソート sorted_tensors = sorted([a, b, c, d, e], key=lambda x: x.shape[0], reverse=True) # Padding nn.utils.rnn.pad_sequence(sorted_tensors, batch_first=True)
tensor([[ 1., 1., 1., 1.], [ 1., 1., 1., 0.], [ 1., 1., 1., 0.], [ 1., 1., 0., 0.], [ 1., 0., 0., 0.]])
pad_sequence
は、すでにTensorとして用意されているデータをリストに格納し、パディング処理するという用途には向いていそうですが、これを単体で使うことはあまり無い気がしています。
公式ドキュメントのFAQにある、My recurrent network doesn’t work with data parallelismを読む限りでは、複数GPUや分散環境での学習時に必要な、PackedSequenceという仕組みを使うときに使うようです。単一の環境で行う場合は、無理に使う必要も無いでしょう。
マスキング
マスキングは、以下の2点が適切に行われる必要があるでしょう。
- Embedding
- Loss
Embedding
Embeddingでは、padding_idx
引数を指定する必要があります。
これを指定することで、パディングされた部分の埋め込みをすべて0にすることができます。
num_input = 10 emb_size = 5 embedding = nn.Embedding(num_input, emb_size, padding_idx=0) batch = torch.LongTensor([ [1, 2, 3, 0, 0, 0, 0, 0, 0, 0], [1, 2, 3, 4, 5, 0, 0, 0, 0, 0] ]) embedding(batch)
tensor([[[ 0.6500, 0.1616, -1.1696, -0.0516, -0.9050], [-0.4270, 1.1525, -0.8994, -1.0899, -0.6576], [ 0.4006, 0.3189, 0.1728, 1.4344, 2.0811], [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]], [[ 0.6500, 0.1616, -1.1696, -0.0516, -0.9050], [-0.4270, 1.1525, -0.8994, -1.0899, -0.6576], [ 0.4006, 0.3189, 0.1728, 1.4344, 2.0811], [ 0.6779, -0.7535, 0.1944, 0.8275, -0.5984], [ 0.9328, -1.4141, 1.0738, 1.5253, -1.1572], [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]])
Loss
Seq2Seqをミニバッチで行う場合、損失の計算を行う際にパディング部分を適切に処理しないと、損失の計算結果がで大きく変わってしまいます。これによって、損失を過大評価したり過小評価したりといったことが生じてしまうので、パディング部分をマスクして損失は計算する必要があります。
loss1 = nn.NLLLoss() loss2 = nn.NLLLoss(ignore_index=0) # ignore_indexを指定 pred = torch.rand(100, 5) true = torch.randint(high=5, size=(100,), dtype=torch.long) print(loss1(pred, true)) print(loss1(pred[true > 0], true[true > 0])) print(loss2(pred, true))
tensor(-0.4701) tensor(-0.4473) tensor(-0.4473)
パディング部分をマスクした損失は、上記の例だとloss1(pred[true > 0], true[true > 0])
で計算することも可能ですが、loss2
のようにignore_index
を指定することでも実現できます。面倒なので、ignore_index
を指定したほうが良いでしょう。
その他
PyTorchに限らず、深層学習の実装を行う際は、層に期待されている入出力がどういうサイズのテンソルなのかを適切に把握しておく必要があります。PyTorchのドキュメントでは、引数にテンソルを取る場合のサイズや順序を明確に示しているので、ドキュメントをよく読みましょう。
僕が実装の過程で躓いたのは以下の2つくらいでした。
- RNNの
batch_first
引数:デフォルトでは入力も出力も(seq_len, batch, input_size)
というサイズのTensorだが、これをTrue
にすると(batch, seq_len, input_size)
になる。 - 内積計算:バッチに含まれているTensor同士の内積の計算は、
torch.bmm
で実現できる。torch.transpose
やsqueeze、unsqeezeなどと組み合わせてサイズの順番を適切にしてから計算する必要がある。
まとめ
PyTorchのチュートリアルも、Attention機構になってくると複雑になってきます。 パディングとマスキング周りは結構調べながら実装しました。テンソルのサイズは最初は混乱しますが、丁寧に一つ一つの処理を追い、transposeやsqueeze, unsqeezeを駆使しながら実装するのは、パズルのようで楽しい作業でもあります。
一応、今回実装したものをNotebookにしてあります。Decoderは、チュートリアルとは別のバージョンを一つ実装しています。Attentionっていろいろ種類がありますね。 いずれもろくにチューニングしていないので、精度的に良いという感じではありません。残念ながら、Attentionの結果も納得の行くものにはなっていません。
PyTorchにはfairseqのようなパッケージもあるので、こういうものを活用して、Seq2Seqは手軽に実装できるようになっておきたいです。