

@martinsotir
Last active March 1, 2021 12:21
A quick function to display a PyTorch model's parameters in a Plotly treemap.

PyTorch Model Parameters in a Plotly Treemap

Requirements: plotly, pandas, and PyTorch.

import pandas as pd
import plotly.express as px

def plot_weights_treemap(model, max_levels=10):
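    """Display a PyTorch module's parameter hierarchy as a treemap diagram.

    The area of each cell is proportional to the number of elements in the
    corresponding parameter tensor. `max_levels` caps the number of splits on
    the dotted parameter names; any deeper remainder stays in the last level.
    """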

    # One row per parameter tensor: dotted name, element count, and shape.
    param_df = pd.DataFrame(
        [{'name': name, 'n_weights': param.shape.numel(), 'shape': param.shape}
         for name, param in model.named_parameters()])

    # Split dotted names ("encoder.layer.0.weight") into one column per hierarchy level;
    # shorter names are padded with None. Recent pandas versions require `n` as a keyword.
    paths = param_df['name'].str.split('.', n=max_levels, expand=True)
    paths_cols = [f"path_{i}" for i in range(paths.shape[1])]
    paths.columns = paths_cols

    paths['n_weights'] = param_df['n_weights']
    # Hover label: the tensor shape formatted as e.g. "768x3072".
    paths['names'] = param_df['shape'].map(lambda shape: 'x'.join(str(s) for s in shape))

    return px.treemap(paths, path=paths_cols, values='n_weights', hover_name='names')

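model = ...  # Any PyTorch module (anything with a `.named_parameters()` method).

fig = plot_weights_treemap(model)
fig.show()  # In a Jupyter notebook, returning `fig` from the last cell also displays it.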
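The core of the function is turning dotted parameter names into one column per hierarchy level, right-padded with `None` where a name is shorter, so that `px.treemap` can place shorter names as leaves higher up in the tree. The stdlib-only sketch below is intended to mimic what `str.split('.', n=max_levels, expand=True)` produces, so the layout of the table handed to Plotly can be inspected without installing anything. The helper `split_paths` and the sample names are hypothetical and not part of the gist.

```python
from typing import List, Optional


def split_paths(names: List[str], max_splits: int = 10) -> List[List[Optional[str]]]:
    """Split dotted names into equal-length rows, right-padded with None.

    Intended to mirror pandas' ``Series.str.split('.', n=max_splits, expand=True)``:
    at most ``max_splits`` splits per name, so any deeper remainder stays in the last cell.
    """
    rows = [name.split('.', max_splits) for name in names]
    width = max((len(row) for row in rows), default=0)
    # Pad on the right, so None only ever follows the last real path segment.
    return [row + [None] * (width - len(row)) for row in rows]


# Hypothetical parameter names, shaped like the output of model.named_parameters().
names = ["embed.weight", "layers.0.attn.weight", "layers.0.attn.bias", "head.weight"]
for row in split_paths(names):
    print(row)
# ['embed', 'weight', None, None]
# ['layers', '0', 'attn', 'weight']
# ['layers', '0', 'attn', 'bias']
# ['head', 'weight', None, None]
```

Padding on the right (rather than the left) keeps the top-level module in the first column for every row, which is what lets the treemap group all `layers.*` tensors under a single `layers` cell.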

Example with the Hugging Face Transformer-XL model:

[Figure: parameter treemap of the Transformer-XL model]

And with the GPT-2 model:

[Figure: parameter treemap of the GPT-2 model]
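In these treemaps, each leaf's area is the element count of one parameter tensor (`param.shape.numel()` in the function), and `px.treemap` sizes every parent cell as the sum of its children, so a module's cell reflects its total parameter count. The stdlib-only sketch below reproduces that roll-up for a few hypothetical `(name, shape)` pairs, which may help when sanity-checking a figure against known model sizes. `subtree_sizes` and the sample shapes are illustrative assumptions, not part of the gist.

```python
from collections import defaultdict
from math import prod
from typing import Dict, Iterable, Tuple


def subtree_sizes(params: Iterable[Tuple[str, Tuple[int, ...]]]) -> Dict[str, int]:
    """Total element count under every dotted prefix (module or tensor) of the names."""
    totals: Dict[str, int] = defaultdict(int)
    for name, shape in params:
        n = prod(shape)  # like torch.Size.numel(); prod(()) == 1 for a scalar
        parts = name.split('.')
        # Credit the tensor's size to each ancestor prefix and to the tensor itself.
        for depth in range(1, len(parts) + 1):
            totals['.'.join(parts[:depth])] += n
    return dict(totals)


# Hypothetical (name, shape) pairs standing in for model.named_parameters().
params = [
    ("embed.weight", (1000, 64)),
    ("layers.0.attn.weight", (64, 64)),
    ("layers.0.attn.bias", (64,)),
    ("head.weight", (10, 64)),
]
sizes = subtree_sizes(params)
print(sizes["layers"], sizes["embed"], sum(prod(s) for _, s in params))  # 4160 64000 68800
```

For a real model, the same pairs could be built with `[(n, tuple(p.shape)) for n, p in model.named_parameters()]`.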


Copyright (c) 2021 Martin Sotir. All rights reserved. This work is licensed under the terms of the MIT license. For a copy, see https://opensource.org/licenses/MIT.
