会社でも機械学習でちょっとやってみて、みたいなことを言われます。
それでWebを見ながらそれなりにやれちゃうわけですが、それをプレゼンするときに機械学習モデルのネットワーク構成図をいい感じに見せたくなることがあるはずです。
そんなときに活躍するgraphvizの記事になります。
graphvizで機械学習モデルのネットワーク層を視覚化する
今回は、tensorflow/kerasのモデルを視覚化します。
PyTorchはいつか記事にしたいです。
graphvizとpydotのインストールが必要です。
インストールしないとImportErrorになります。
ImportError: You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) for plot_model/model_to_dot to work.
graphviz 公式ドキュメントに沿って、graphvizをインストールをしましょう。
pydotはpip install pydotです。
Windowsでgraphviz
Windowsにgraphvizをインストールします。
下記サイトからWIndows用のインストーラーをダウンロードします。
windows_10_cmake_Release_graphviz-install-5.0.0-win64.exe という実行ファイルをダブルクリックします。
私の環境はWindows11なのですが、やってみます。
ウィザードに沿ってインストールするだけなので難しくないです。
視覚化するモデルは、はじめてのニューラルネットワーク:分類問題の初歩で使われている画像分類モデルのネットワーク層とします。
Fashion MNIST の 28×28 の白黒画像を入力とし、一旦一次元配列にし、128の出力がある全結合層を通し、最後に10クラスに分類するネットワーク層となっています。
import tensorflow as tf model = tf.keras.Sequential([ tf.keras.layers.Flatten(input_shape=(28, 28)), tf.keras.layers.Dense(128, activation='relu'), tf.keras.layers.Dense(10, activation='softmax') ]) model.summary()
_________________________________________________________________ Layer (type) Output Shape Param # ================================================================= flatten (Flatten) (None, 784) 0 dense (Dense) (None, 128) 100480 dense_1 (Dense) (None, 10) 1290 =================================================================
それでは、モデルの視覚化をします。
keras公式ページの モデルの可視化 にもコードの記載があります。
filename = '{}_graphviz_windows.png' .format(datetime.now().strftime('%Y%m%d')) file_path = os.path.join(os.getcwd(), filename) tf.keras.utils.plot_model( model, show_shapes=True, to_file=file_path, )
Linux(WSL)でgraphviz
Linux(WSL)の場合は、graphvizのインストール方法が異なるだけです。
graphviz 公式ドキュメントを見てください。
$ sudo apt install graphviz
最初、このコマンドを打つとエラーになってしまいました。
E: Failed to fetch http://security.ubuntu.com/ubuntu/pool/main/t/tiff/libtiff5_4.1.0+git191117-2ubuntu0.20.04.2_amd64.deb 404 Not Found [IP: 91.189.91.38 80] E: Unable to fetch some archives, maybe run apt-get update or try with --fix-missing?
これを実行し、解決しました。
$ sudo apt-get update
再度インストールコマンドを打ち、インストール完了です。
$ sudo apt install graphviz
pydotもpip install pydotでインストールします。
ネットワーク構成図を視覚化してみます。
同じネットワーク層なのに、WindowsとLinuxでは見た目がちょっと違いました。
import tensorflow as tf model = tf.keras.Sequential([ tf.keras.layers.Flatten(input_shape=(28, 28)), tf.keras.layers.Dense(128, activation='relu'), tf.keras.layers.Dense(10, activation='softmax') ]) model.summary() filename = '{}_graphviz_linux.png' .format(datetime.now().strftime('%Y%m%d')) file_path = os.path.join(os.getcwd(), filename) tf.keras.utils.plot_model( model, show_shapes=True, to_file=file_path, )