graphvizで機械学習モデルのネットワーク層を視覚化する
今は、会社でも、機械学習でちょっとやってみて、みたいなことを言われる時代です。
それで、Web を見ながら、それなりにやれちゃうわけですが、それをプレゼンするときに、機械学習モデルのネットワーク構成図をいい感じに見せたくなることがあるはずです。
まずは、graphviz と pydot のインストールが必要です。
何もせずにやると、ImportError になります。
graphviz 公式ドキュメント に沿って、まずは graphviz をインストールをしましょう。
pydot は pip install pydot です。
今回は、tensorflow/keras のモデルを視覚化します。
PyTorch の方は、すいません。
Windowsでgraphviz
まずは、Windows に graphviz をインストールします。
ここから、WIndows 用のインストーラーをダウンロードします。
windows_10_cmake_Release_graphviz-install-5.0.0-win64.exe という実行ファイルをダブルクリックします。
私の環境は Windows11 なのですが、やってみます。
ウィザードに沿って、インストールするだけなので、難しくないはずです。
視覚化するモデルは、はじめてのニューラルネットワーク:分類問題の初歩 で使われている画像分類モデルのネットワーク層とします。
Fashion MNIST の 28x28 の白黒画像を入力とし、一旦一次元配列にし、128の出力がある全結合層を通し、最後に10クラスに分類するネットワーク層となっています。
import tensorflow as tf
model = tf.keras.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')
])
model.summary()
Layer (type) Output Shape Param #
=================================================================
flatten (Flatten) (None, 784) 0
dense (Dense) (None, 128) 100480
dense_1 (Dense) (None, 10) 1290
=================================================================
それでは、モデルの視覚化をします。
keras 公式ページの モデルの可視化 にもコードの記載があります。
filename = '{}_graphviz_windows.png' .format(datetime.now().strftime('%Y%m%d'))
file_path = os.path.join(os.getcwd(), filename)
tf.keras.utils.plot_model(
model,
show_shapes=True,
to_file=file_path,
)
Linux(WSL)でgraphviz
Linux (WSL) の場合は、graphviz のインストール方法が異なるだけです。
graphviz 公式ドキュメント を見てください。
最初、このコマンドを打つと、エラーになってしまいました。
これを実行し、解決しました。
再度、インストールコマンドを打ち、インストール完了です。
pydot も pip install pydot でインストールします。
ネットワーク構成図を視覚化してみます。
同じネットワーク層なのに、Windows と Linux では、見た目がちょっと違いました。
import tensorflow as tf
model = tf.keras.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')
])
model.summary()
filename = '{}_graphviz_linux.png' .format(datetime.now().strftime('%Y%m%d'))
file_path = os.path.join(os.getcwd(), filename)
tf.keras.utils.plot_model(
model,
show_shapes=True,
to_file=file_path,
)