zuminote

個人的な勉強記録

PyTorchによる発展ディープラーニング 第1章 1-1 メモ

kaggleのRSNAコンペをやろうとしたのだけど、PyTorchを自分で理解して書けているか微妙な気がしたので、復習。

 去年夏ごろに買って物体認識のところのみやってみたが、何がなんだかよくわからず断念した本だった。今読んだら当時より分かるようにはなっていたので若干の進歩を感じる。

サンプルコードはこちら。

github.com

 

 ここでは、学習済みのVGG-16モデルを使って画像分類をする。

VGG-16の学習済みモデルは、torchvision.modelsのメソッドにより簡単に生成できる。
この学習済みモデルは、ImageNetデータセットのうち、ILSVRC2012データセットで学習させたものである。

vgg16 = models.vgg16(pretrained=True)  

VGG-16はfeaturesとclassifierという2つのモジュールに分かれており、畳み込みとプーリングを重ねたfeaturesで画像の特徴抽出を行い、全結合+ドロップアウトのclassifierに突っ込んで分類を行うというものらしい。

ちなみに16層のカウントは畳み込み層と全結合層のみ数えたもの。ReLU、プーリング層、ドロップアウト層を含めると全部で38層から成っている。

準備1. 入力画像の前処理クラスBaseTransformの実装

まず、学習済みVGG-16に画像を入力するために以下の処理が必要になる。

  • 画像サイズを224×224にリサイズする
  • 色情報の規格化

色情報の規格化というのは、学習済みVGGモデルが学習しているILSVRC2012データセットの条件に合わせて、RGBが平均(0.485, 0.456, 0.406), 標準偏差(0.229, 0.224, 0.225)になるように標準化する必要があるということである。

前処理は、行う一連の処理をtransforms.Composeを使って記述する。

ここでは短辺を224にリサイズ→画像中央を224×224に切り取り→Torchテンソルに変換→色情報の規格化 をかけている。

その後、PyTorchとPillowでは画像要素の順番が違うので、PyTorchの(色、高さ、幅)を PIL(高さ、幅、色)に変換している。

準備2. モデルの出力結果を予測ラベルにする後処理クラスILSVRCPredictorの実装

モデルの出力は、torch.Size([1, 1000])のtorchテンソルになっている。これをndarrayに変換し、最も予測確率が高かったインデックスを取り出し、JSONからラベル名を取り出す。

ここで、.detach()の説明について、「出力値をネットワークから切り離す」とあり、???と思った。

どうやら、numpy()はコピーされたデータではなく元のデータを参照しにいってしまうので、.detach()を使わずに直接numpy()を使うとPyTorch内のパラメータを与えたことになってしまい、変数をnumpyに与えることはできないというエラーが出る、ということらしい。

参考
https://rightcode.co.jp/blog/information-technology/pytorch-automatic-differential-linear-regression-complete

画像分類

いよいよ本題の画像分類を実施。

  1. 学習済みVGG16をロードし、推論モードにする。

  2. JSONをロードし、ラベルリストの辞書型変数を生成

  3. 2のリストを後処理クラスILSVRCPredictorに与え、インスタンスpredictorを生成

  4. 入力画像をPillowで読み込む

  5. 前処理クラスBaseTransformのインスタンスを生成し、画像を前処理する

    この前処理インスタンスで返ってきているもの(img_transformedの中身)は、transforms.Composeの中の前処理関数の最終的な返り値。つまり、画像の配列である。

    https://pytorch.org/vision/stable/transforms.html#compositions-of-transforms

  6. 5で前処理をかけた画像に、ミニバッチの次元を追加し、変数inputsに格納

  7. inputsをモデル(あらかじめロードした学習済みVGG16、推論モード)に入力

  8. 出力を、上記で生成した後処理インスタンスpredictor.predict_maxにかけて予測したラベルを取得

「入力画像の予測結果: golden_retriever」と表示され、モデルがちゃんとゴールデンレトリーバーを分かってくれた。