PyTorchのRNNとRNNCell
概要
PyTorchでRNNを使った実装しようとするとき、torch.nn.RNNとtorch.nn.RNNCellというものがあることに気がつきました。 それぞれの違いを明らかにして、注意点を整理しておきたいのです。
リカレント層の実装方法
PyTorchチュートリアルの、名前分類をこなしていて、RNNの実装方法について調べようと思ったのがことの発端。チュートリアルでは、RNNモジュールをイチからで実装しているが、実務上イチからRNNを実装することはほぼ無いと思われるので、調べてみたら、torch.nn.RNNとtorch.nn.RNNCellを見つけました。また、代表的なリカレント系レイヤーであるLSTMとGRUについても、torch.nn.LSTM、torch.nn.GRU以外に、torch.nn.LSTMCellとtorch.nn.GRUCellがあることがわかりました。
そんなわけで、「Cell」の有無で何が違うのかを調べてみました。
RNNとRNNCellの違い
公式のフォーラムに、ズバリそのものの質問がありました。図を引用してみましょう。ここでは、時系列データが想定されていて、長さ7の入力系列を3層のRNNで受け取り、長さ7の出力系列を出力しています。図ではRNNの出力系列は7つになっていますが、Attention機構をつくらない場合は、出力系列の最後(一番右上の青い四角)のみを利用して、その後の分類やら回帰のためのレイヤーへとつなげることが一般的かと思います。 入力系列が赤、RNNが緑、出力系列が青で示されています。
この図の中で、torch.nn.RNNCellとtorch.nn.RNNがどこに対応するのかを見てみましょう。 まず、torch.nn.RNNCellは下図の赤枠で示すような範囲、つまり、一つ一つの緑の四角に対応します。torch.nn.RNNCellは、下と左から1つずつの入力を受け、上と右に出力します。つまり、入力系列の一つと前のRNNCellの状態を初期状態として入力し、更新された隠れ状態を上(スタックされたRNNCellまたは出力系列)と右(次のRNNCellへの状態入力)に出力しています。
次に、torch.nn.RNNは、下図の青枠で示すような範囲、つまり緑の四角すべてです。torch.nn..RNNは、下から入力系列を受け、上に出力系列を出力します。系列中のRNNCell間の状態のやり取りは、全てtorch.nn.RNN内部で行われます。
torch.nn.RNNCellの使い方
具体的なRNNCellの使い方を確認しましょう。チュートリアルで行っている、名前から国を予測するModuleを、GRUCellを使って実装してみました。 GRUでやっていますが、RNNでもほとんど同じです。
データの準備などはチュートリアルどおりなので割愛するとして、具体的にどのようにネットワークを定義するのかを示します。一応Notebookを上げておきます。
class StackedGRU(nn.Module): def __init__(self, input_size, hidden_size, output_size): super(StackedGRU, self).__init__() self.hidden_size = hidden_size self.gru1 = nn.GRUCell(input_size, hidden_size) self.gru2 = nn.GRUCell(hidden_size, hidden_size) self.linear = nn.Linear(hidden_size, output_size) self.softmax = nn.LogSoftmax(dim=1) def forward(self, input, hiddens): hidden1 = self.gru1(input, hiddens[0]) hidden2 = self.gru2(hidden1, hiddens[1]) output = self.linear(hidden2) output = self.softmax(output) return output, [hidden1, hidden2] def initHidden(self): return [torch.zeros(1, self.hidden_size), torch.zeros(1, self.hidden_size)] n_hidden = 128 rnn = StackedGRU(n_letters, n_hidden, n_categories)
この例ではGRUCellを2層にスタックしています。forwardメソッドを確認すると、引数はinput(入力系列の1要素)とhiddens(2つのGRUCellの状態)で、所属クラスの予測値(output)と、2つのGRUCellの状態(hidden1とhidden2)を返します。これらの返り値は、最終的な所属クラスの予測や次の入力のhiddensになります。以下のtrain関数を見ればどんなふうに使われるのかがわかるでしょう。
def train(category_tensor, line_tensor): hidden = rnn.initHidden() rnn.zero_grad() for i in range(line_tensor.size()[0]): # 一文字ずつ入力する output, hidden = rnn(line_tensor[i], hidden) # hiddenは次の入力にする。outputは最後以外毎回捨てる。 loss = criterion(output, category_tensor) loss.backward() # Add parameters' gradients to their values, multiplied by learning rate for p in rnn.parameters(): p.data.add_(-learning_rate, p.grad.data) return output, loss.item()
RNNCellでは、一文字ずつ入力し、RNNCellの状態を更新し、最後に所属クラスを得る、という流れを自分で実装するので、何をやっているのかがわかりやすいといえばわかりやすいですが、めんどくさいといえばめんどくさいです。
torch.nn.RNNの使い方
次に、RNNを見ていきましょう。こちらもGRUでやっています。Notebookも一応上げておきます。
class BidirectionalGRU(nn.Module): def __init__(self, input_size, hidden_size, output_size): super(BidirectionalGRU, self).__init__() self.hidden_size = hidden_size self.num_layers = 1 self.bigru = nn.GRU(input_size, hidden_size, num_layers=self.num_layers, bidirectional=True) self.linear = nn.Linear(hidden_size*2, output_size) self.softmax = nn.LogSoftmax(dim=1) def forward(self, input): input = input.to(device) hidden = torch.zeros(self.num_layers*2, 1, self.hidden_size).to(device) # 各GRUの初期状態は、一つのTensor。 output, _ = self.bigru(input, hidden) # (出力系列, 各GRUの状態が一つのTensorにまとめられたもの) output = output[-1] # 出力系列の最後のみ使用 output = self.linear(output) output = self.softmax(output) return output n_hidden = 128 rnn = BidirectionalGRU(n_letters, n_hidden, n_categories) rnn = rnn.to(device)
RNNはRNNCellとは違い、複数層のリカレントレイヤーを一行で書ける分、引数がいくらか複雑です。
インスタンス作成時の引数を見てみると、num_layers
は、リカレントレイヤーを何層スタックするのかを表す引数で、bidirectional
は双方向にするかを表す引数です。推論時の引数にちょっと注意が必要です*1。入力系列が第1引数で、第2引数に各GRUCellの状態の初期値を与える必要があるのですが、リストやタプルではなくて一つのTensorにしてまとめてあげる必要があります。ここのサイズの設定がややこしいのでちゃんと理解しておく必要があります。
hidden = torch.zeros(self.num_layers*2, 1, self.hidden_size).to(device) # 各GRUの初期状態は、一つのTensor。
サイズは、(Cellの個数, 1, 隠れ層の次元数)という順番で設定します。1になっているところは、この例のバッチサイズが1だからです。上の例では、GRUのスタック数が1で、双方向なので2倍しているという感じです。
LSTMの場合は、内部状態が2つあるためにまた違った指定になるようで、hとcを別々のTensorにして、タプルとして入力する必要があるようです(参考)。
torch.nn.RNNを使えば、学習時にfor文で回す必要がなくなります。こっちのほうがスッキリして書けるので好きです。
def train(category_tensor, line_tensor): rnn.zero_grad() output = rnn(line_tensor) # 一発で所属クラスの予測ができる。 loss = criterion(output, category_tensor.to(device)) loss.backward() # Add parameters' gradients to their values, multiplied by learning rate for p in rnn.parameters(): p.data.add_(-learning_rate, p.grad.data) return output, loss.item()
まとめ
PyTorchのtorch.nn.RNNとtorch.nn.RNNCellの違いについて確認しました。torch.nn.RNNを使うときは、内部状態の次元数、いくつスタックしているのか、双方向か否かをきちんと把握しておく必要があります。
*1:RNNの推論時に内部状態の初期値は、デフォルト値としてゼロが入るので、特にこだわりが無ければ入れないでいいです。しかし、RNNCellの内部状態の初期値の設定は、生成モデルなどでは必ず使う部分なので、デフォルト値以外の指定方法も知っておくべきだと思います。