テクノロジー

pytorchを用いて学習曲線をプロットするwebアプリケーションをつくる(第2回)

前回まで

こんにちは、D2Cのビジネスエンジニアリング部 データソリューション リサーチチーム所属 吉田です。
前回に引き続き、学習曲線をプロットするwebアプリケーションをつくっていきます。
【前回の記事】
pytorchを用いて学習曲線をプロットするwebアプリケーションをつくる(第1回)

前回はこんなシステム構成図で、

こんなプロットをしました。

前回は開発時に参加していたkaggleのコンペにフィッティングした簡易的な構築でした。
今回はより汎用的にpytorchで活用できるものにしていきたいと思います。

今回のゴール

前回との差分を踏まえて、今回のゴールはこちら。

  • ログ記録用のモジュールの実装(前回はshellで部分的に用意していたが、今回は機能の一部として実装)
  • 過去ログとの比較機能(新機能)
  • リアルタイムplot(前回はページをリロードしないと更新できなかった)
  • mnistを用いた検証

【ソースコード】
mnist:MNIST_pytorch
可視化webアプリケーション:ML_VISUALIZER

環境

開発環境はEC2

  • Ubuntu 16.04.5
  • Python 3.6.6
  • Flask 1.0.2
  • gevent-websocket 0.10.1
  • jquery 3.3.1
  • D3.js v5
  • C3.js 0.6.8

システム概要図

図を少し変更しましたが、前回から構成は大きくは変更していません。

mnistの実装

ネットに転がっているものを参考に、Accuracy, Loss, LearningRateが見れるように改良しました。
【再掲】
MNIST_pytorch

いくつか今回のwebアプリケーションに関連した箇所を説明します。

importする

検証環境では/data/visualizer/にアプリケーションが存在するので、こちらをインポート

ハイパーパラメーターの記録

paramをdict型で保存してmodule.set_hyperparameterに渡す(中身については後ほど説明)

ログの記録

保存したいログとファイル名(module.set_hyperparameterの戻り値), 保存するログのmodeをmodule.set_logに渡す(中身については後ほど説明)

ログ記録用のモジュールの実装

ではmnistのソースコード内で見たモジュールについて解説します。

ハイパーパラメーターの記録

【処理の流れ】
1. title(file名)の決定(titleは指定できるが、何も指定しないと日時になる)
2. 決定したtitleと引数で受け取ったパラメーター(param)をjson形式で保存
3. titleを返す

ログの記録

基本的にepochが回るたびにログを追記。
今回はmodeとしてloss, metric, lrの3種類を上限としている。
残念ながら、複数のlossを扱うような学習にはまだ対応していない。
また、手軽な導入を考慮してDBなどでデータを扱っていないため、ハイパーパラメータと各ログが同一の学習であったことへの手がかりはファイル名となっている。そのため、module.set_hyperparameterの戻り値であるタイトルを受け取る仕様。

【処理の流れ】
1. ログの形式(loss, metric, lr)を確認
2. ファイル名を決定
3. 決定したファイル名にログを書き込んでいく

過去ログとの比較機能の実装

logファイルを保存しておくディレクトリを切って、そこに溜めていきます。
ブラウザアクセス時は、ごっそりデータがフロント側に送られる仕様です。
ユーザはハイパーパラメータのリストをもとにcheckboxで可視化したいグラフをプロットさせることができます。

ハイパーパラメータ

学習単位で個別に溜めたjsonファイルをまとめて送る

【処理の流れ】
1. jsonのファイルリストを取得する
2. それぞれpandasで読み込んで結合、DataFrameを作る
3. nullを埋めてから、もう一度jsonに変換して返す

これをflaskでxxx.xxx.xxx.xxx:oooo/parametersに送る
フロント側では、上記URLにGETのリクエストを投げることでいつでも取得できる状態

ログ

こちらは前回の仕様同様flaskでwebページをレンダリングする際に、整形した配列形式のログをフロント側へ送り、jinja2で受け取る。

【処理の流れ】
1. .logのファイルリストを取得する
2. それぞれ配列に格納して返す

これをflaskで送る。送る際にadjust_c3で今回の使用する可視化ライブラリのC3.jsのデータ形式にしておく。
例) [{‘hoge.log’: [0,1,2,3]}, {‘fuga.log’: [1,2,3,4]}]
-> [[‘hoge.log’,0,1,2,3], [‘fuga.log’,1,2,3,4]]

リアルタイムplotの実装

前回は可視化ライブラリにchart.jsを用いました。今回はより拡張性のあるライブラリを使用したかったのと、chart.jsのポップなデザインが個人的にあまり好ましくなかったので、D3.jsを検討しました。
ただ、D3.jsはフレキシブルすぎる故に今回のような単純なplotを行うアプリケーションにはオーバースペックと判断し、グラフに特化したD3.jsのラッパーライブラリであるC3.jsを採用しました。

【処理の流れ】
1. [バックエンド]初回アクセス時はflaskのrender_templateでログを送る
2. [フロント]jinja2で受け取り、c3.jsで可視化
3. [バックエンド]1秒おき(config.pyで間隔を指定可能)にログファイルが格納されているディレクトリを監視
4. [バックエンド]ログの追記を検知すると、websocketを用いてフロント側へログを送る
5. [フロント]websocketで受け取ってc3.jsで再描画

websocketで送る

flask側は下記のようなコード

1秒おきに監視というのはwhileの無限ループで1秒sleepさせる事によって実装できる。
ただ、ここでは単に1秒スリープさせるのではなく、フロント側のリクエストを1秒だけ構える実装を行った。
ws.receive()はリクエストを受け取るまで構え続けるので、with Timeout(frequency, False)frequency秒のリミットを設ける。
フロント側のページで、リロードやページを閉じた際にメッセージを受け取り、flaskがsocket通信のclosedを検知する役割がある。

下記で1つ前のループ時のログと比較を行い、

差分があった場合のみ、ログをsocketで送信する。

websocketで受け取る

websocketの初期化。ホストとポートはec2のIPとconf.pyで指定したportと同一なので、flaskでレンダリングする際に情報を送ってしまい、jinja2で受け取る。

ページから離れる際にはバックエンドにclosedを送信して、socketを切る。

下記でsocketによって送られてきたデータを扱うことができる。

下記でc3.jsの再描画を行う。
unloadでは再描画させるデータのラベルリストを、columnsではデータ自体を指定する。

実行結果

まとめ

画面のレイアウトやデザインはともかく、より汎用的な学習曲線をプロットするwebアプリケーションの開発ができました。
個人的にはwebsocketのクローズドの処理に少し苦戦しました。

ただ実践で使うとなると要望はきっとまだ出てくるでしょう。
ページも格好悪いので直したいですし、エラーハンドリングなんかはさらっとしすぎてます。
次回はもう少しブラッシュアップできればと思います。

最後まで読んでくださり、ありがとうございました。

D2C 吉田


関連タグ