graphvizで機械学習モデルを視覚化する

会社でも機械学習でちょっとやってみて、みたいなことを言われます。

それでWebを見ながらそれなりにやれちゃうわけですが、それをプレゼンするときに機械学習モデルのネットワーク構成図をいい感じに見せたくなることがあるはずです。

そんなときに活躍するgraphvizの記事になります。

graphvizで機械学習モデルのネットワーク層を視覚化する

今回は、tensorflow/kerasのモデルを視覚化します。
PyTorchはいつか記事にしたいです。

graphvizpydotのインストールが必要です。
インストールしないとImportErrorになります。

ImportError: You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) for plot_model/model_to_dot to work.

graphviz 公式ドキュメントに沿って、graphvizをインストールをしましょう。

pydotはpip install pydotです。

Windowsでgraphviz

Windowsにgraphvizをインストールします。

下記サイトからWIndows用のインストーラーをダウンロードします。

windows_10_cmake_Release_graphviz-install-5.0.0-win64.exe という実行ファイルをダブルクリックします。
私の環境はWindows11なのですが、やってみます。
ウィザードに沿ってインストールするだけなので難しくないです。

視覚化するモデルは、はじめてのニューラルネットワーク:分類問題の初歩で使われている画像分類モデルのネットワーク層とします。

Fashion MNIST の 28×28 の白黒画像を入力とし、一旦一次元配列にし、128の出力がある全結合層を通し、最後に10クラスに分類するネットワーク層となっています。

import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])
model.summary()
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
flatten (Flatten) (None, 784) 0
dense (Dense) (None, 128) 100480
dense_1 (Dense) (None, 10) 1290
=================================================================

それでは、モデルの視覚化をします。
keras公式ページの モデルの可視化 にもコードの記載があります。

filename = '{}_graphviz_windows.png' .format(datetime.now().strftime('%Y%m%d'))
file_path = os.path.join(os.getcwd(), filename)
tf.keras.utils.plot_model(
    model,
    show_shapes=True,
    to_file=file_path,
    )
ネットワーク構成図

Linux(WSL)でgraphviz

Linux(WSL)の場合は、graphvizのインストール方法が異なるだけです。
graphviz 公式ドキュメントを見てください。

$ sudo apt install graphviz

最初、このコマンドを打つとエラーになってしまいました。

E: Failed to fetch http://security.ubuntu.com/ubuntu/pool/main/t/tiff/libtiff5_4.1.0+git191117-2ubuntu0.20.04.2_amd64.deb 404 Not Found [IP: 91.189.91.38 80]
E: Unable to fetch some archives, maybe run apt-get update or try with --fix-missing?

これを実行し、解決しました。

$ sudo apt-get update

再度インストールコマンドを打ち、インストール完了です。

$ sudo apt install graphviz

pydotもpip install pydotでインストールします。

ネットワーク構成図を視覚化してみます。
同じネットワーク層なのに、WindowsとLinuxでは見た目がちょっと違いました。

import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])
model.summary()

filename = '{}_graphviz_linux.png' .format(datetime.now().strftime('%Y%m%d'))
file_path = os.path.join(os.getcwd(), filename)
tf.keras.utils.plot_model(
    model,
    show_shapes=True,
    to_file=file_path,
    )
ネットワーク構成図