graphvizで機械学習モデルを視覚化する。WindowsとLinux(WSL)で。

graphvizで機械学習モデルのネットワーク層を視覚化する

今は、会社でも、機械学習でちょっとやってみて、みたいなことを言われる時代です。
それで、Web を見ながら、それなりにやれちゃうわけですが、それをプレゼンするときに、機械学習モデルのネットワーク構成図をいい感じに見せたくなることがあるはずです。

まずは、graphviz と pydot のインストールが必要です。
何もせずにやると、ImportError になります。

ImportError: You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) for plot_model/model_to_dot to work.

graphviz 公式ドキュメント に沿って、まずは graphviz をインストールをしましょう。
pydot は pip install pydot です。
今回は、tensorflow/keras のモデルを視覚化します。
PyTorch の方は、すいません。

Windowsでgraphviz

まずは、Windows に graphviz をインストールします。
ここから、WIndows 用のインストーラーをダウンロードします。
windows_10_cmake_Release_graphviz-install-5.0.0-win64.exe という実行ファイルをダブルクリックします。
私の環境は Windows11 なのですが、やってみます。
ウィザードに沿って、インストールするだけなので、難しくないはずです。

視覚化するモデルは、はじめてのニューラルネットワーク:分類問題の初歩 で使われている画像分類モデルのネットワーク層とします。
Fashion MNIST の 28x28 の白黒画像を入力とし、一旦一次元配列にし、128の出力がある全結合層を通し、最後に10クラスに分類するネットワーク層となっています。

Python

    import tensorflow as tf

    model = tf.keras.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    model.summary()
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
flatten (Flatten) (None, 784) 0
dense (Dense) (None, 128) 100480
dense_1 (Dense) (None, 10) 1290
=================================================================

それでは、モデルの視覚化をします。
keras 公式ページの モデルの可視化 にもコードの記載があります。

Python

    filename = '{}_graphviz_windows.png' .format(datetime.now().strftime('%Y%m%d'))
    file_path = os.path.join(os.getcwd(), filename)
    tf.keras.utils.plot_model(
        model,
        show_shapes=True,
        to_file=file_path,
        )

ネットワーク構成図

Linux(WSL)でgraphviz

Linux (WSL) の場合は、graphviz のインストール方法が異なるだけです。
graphviz 公式ドキュメント を見てください。

$ sudo apt install graphviz

最初、このコマンドを打つと、エラーになってしまいました。

E: Failed to fetch http://security.ubuntu.com/ubuntu/pool/main/t/tiff/libtiff5_4.1.0+git191117-2ubuntu0.20.04.2_amd64.deb 404 Not Found [IP: 91.189.91.38 80] E: Unable to fetch some archives, maybe run apt-get update or try with --fix-missing?

これを実行し、解決しました。

$ sudo apt-get update

再度、インストールコマンドを打ち、インストール完了です。

$ sudo apt install graphviz

pydot も pip install pydot でインストールします。

ネットワーク構成図を視覚化してみます。
同じネットワーク層なのに、Windows と Linux では、見た目がちょっと違いました。

Python

    import tensorflow as tf

    model = tf.keras.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    model.summary()
    
    filename = '{}_graphviz_linux.png' .format(datetime.now().strftime('%Y%m%d'))
    file_path = os.path.join(os.getcwd(), filename)
    tf.keras.utils.plot_model(
        model,
        show_shapes=True,
        to_file=file_path,
        )

ネットワーク構成図