PyTorchの分散環境学習
概要
PyTorchのチュートリアルに、分散環境での学習に関する記事がある。 自分の家のサーバーはシングルGPUなのだが、最近少々不満を感じてきている。
クラウドでの分散学習を見越して、今後のために勉強しておくことにしたのです。
なお、特に実装はありません。
分散学習の構成要素
torch.distributed
によって、マルチプロセスやクラスターでの実行、要は並列化を用意におこなえるようになります。
以下ではその構成要素を整理しています。
バックエンド
PyTorchでは、torch.multiprocessing
という並列処理の機能がありますが、これは単一のマシン内での並列化にとどまります。一方で、torch.distributed
は、以下のような並列処理用バックエンドを使った大規模クラスタでの並列処理をサポートします。
- TCP:CPUしかサポートしていない。初期実装向け?
- Gloo:CPUとGPU共にサポート。多分一番メジャー?
- NCCL:Gloo同様だが、このチュートリアル作成時に追加されたらしい。まだ詳しいことはわからないが、NVIDIA謹製だけあって、GPU間のやり取りは高速な模様。
- MPI:CPUとGPU共にサポートしているが、PyTorchのバイナリファイルにMPI用の実装が含まれていないので、自分で再コンパイルする必要があるらしい。
process, rank, group, world, size
用語の整理。
- process:分散環境で動作する個々のプロセス。
- rank:個々のprocessを指示するインデックス。Masterのprocessは0で固定みたい。
- group:Collectiveを行う範囲を指定するprocessのサブセット。
- world:全process。
- size:groupやworldに含まれるprocessの数。
通信方法
Point-to-Point
process間でTensorをやり取りをする方法です。
send/recv
dist.send
やdist.recv
によって、ノード間でTensorを送受信します。
isend/irecv
dist.isend
やdist.irecv
によって、ノード間で非同期に、Tensorを送受信します。受信内容の反映は、非同期の送受信が完了するのをwait()
してあげて、初めて保証されます。
Collective
複数のノード間でTensorを配布したり集計したりする方法です。
dist.scatter(tensor, src, scatter_list, group)
:src
のprocessからgroup
内の各processにscatter_list
で指定したTensorを送付し、各processのtensor
に格納する。dist.gather(tensor, dst, gather_list, group)
:group
内の各processからtensor
を集め、dst
processのgather_list
に格納する。dist.all_gather(tensor_list, tensor, group)
:group
内の各processからtensor
を集め、group内の全processのtensor_list
に格納する。dist.broadcast(tensor, src, group)
:group
内のsrc
processのtensor
を各processのtensor
に格納する。dist.reduce(tensor, dst, op, group)
:group
内の各processのtesnsor
を集めて、op
で指定された処理を行い、dst
processのtensor
に格納する。dist.all_reduce(tensor, op, group)
:group
内の各processのtensor
を集めて、op
で指定された処理を行い、各processのtensor
に格納する。dist.barrier(group)
:groupに含まれるprocessがこのポイントに至るまで待機する。
Operation
dist.reduce
とdist.all_reduce
で指定するop
、つまりOperationには、以下の4つがあります。すべて要素ごとに計算は行われます。
dist.reduce_op.SUM
:tensorの要素ごとの和dist.reduce_op.PRODUCT
:tensorの要素ごとの積dist.reduce_op.MAX
:tensorの要素ごとの最大値dist.reduce_op.MIN
:tensorの要素ごとの最小値
共有ファイルシステム
group内のprocess間で、同一のファイルを参照したり上書きしたりしながら処理を進めることがあります。競合が発生しないために、fcntlによるロックをサポートしているファイルシステムを採用する必要があるそうです。普通の分散ファイルシステムなら普通にサポートしていそうですが、どうなんでしょうかね?
まとめ
PyTorch TutorialのWriting Distributed Applications with PyTorchを読んで、まとめてみました。
クラウド上のリソースを使った分散学習で少しでも学習時間が減らせればと思って手を出そうとしているのですが、趣味でやっているので、お財布と要相談ですね…。