人工知能とか犬とか

人工知能と犬に興味があります。しょぼしょぼ更新してゆきます。

PyTorchでわんにゃん分類器をつくる

概要

PyTorchで事前学習済みモデルのファインチューニングを行って、犬や猫の種類を分類できる分類器を作成しました。使用している事前学習済みモデルはResNet18、データセットThe Oxford-IIIT Pet Datasetを使用します。

特になんの工夫もしなくても、90%程度の精度で分類が実現できます。

Notebookはgithubに公開しています。

手法

PyTorchチュートリアルTransfer Learning tutorialを元に、事前学習済みのResNet18をファインチューニングすることで、わんにゃん分類器を作ります。

このチュートリアルでは以下の2通りの学習方法を示しています。

  • 事前学習済みモデル全体を学習
  • 事前学習済みモデルに追加した全結合相のみを学習

この記事では前者のみをまとめています。

データセット

チュートリアルでは、ImageNetのサブセットであるアリとハチのデータセットを用いていますが、せっかくならかわいいデータを使いたい。そういうわけで、The Oxford-IIIT Pet Datasetを使用します。

データセットは、犬25種、猫12種、全37個のクラスからなり、各クラスごとに大体200枚の画像が含まれています。

http://www.robots.ox.ac.uk/~vgg/data/pets/breed_count.jpg

データセットは展開して、以下のようなフォルダ構成にします。20%を評価用のデータに使いました

  • train
    • abyssinian
      • Abyssinian_1.jpg
      • Abyssinian_3.jpg
      • Abyssinian_5.jpg
      • ...
    • american_bulldog
    • ...
  • val
    • abyssinian
      • Abyssinian_2.jpg
      • Abyssinian_4.jpg
      • ...
    • american_bulldog
    • ...

実装

基本はチュートリアルのやり方そのままです。 一部混同行列の表示や学習の過程を示す損失と精度のプロットを入れています。詳細はNotebookを見てください。

結果

学習の過程はこんな感じ。 f:id:wanchan-daisuki:20180603133418p:plain

予測精度は92%くらい出せていて、けっこう合っています。

f:id:wanchan-daisuki:20180603135112p:plain

f:id:wanchan-daisuki:20180524130936p:plainf:id:wanchan-daisuki:20180524130935p:plainf:id:wanchan-daisuki:20180524130926p:plain
予測結果

一方で、ハズレの例を見てみると、ラグドールバーマンを混同している例が見られました。似てるからしゃーないね。 あとは、スタッフォードシャーブルテリアアメリカンピットブルテリアを混同していたり、納得の行く間違いが多いですね。

f:id:wanchan-daisuki:20180603133434p:plain

感想

公式のチュートリアルでは2クラスの分類だったのでちょっと感動が薄かったのですが、これくらいのクラス数があっても90%くらいの精度で見分けられると、なかなか楽しいです。

実はこのチュートリアル、だいぶ前にやって今回再びやり直してみたものです。 なので、所々に前のバージョンのチュートリアルのコードが残っているかもしれません。

前のバージョンでは、use_cudaというフラグを使って、CUDAで処理するならこっち、CPUでやるならそっち、というif文がいたるところに存在していました。torch.deviceやTensorsのtoメソッドによってこれが無くなって、だいぶシンプルに書けるようになったと思います。

公式の0.4.0 MigrationGuideのWriting device-agnostic codeを見ると、 今後はこの書き方が推奨されるようですね。

# at beginning of the script
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

...

# then whenever you get a new Tensor or Module
# this won't copy if they are already on the desired device
input = data.to(device)
model = MyModule(...).to(device)