@aminnj
Created April 10, 2019 19:36
nn capacity
import numpy as np
# number of nodes per layer
nodes = np.array([2, 3, 4])
# upper bound on capacity according to eq 1.1 of https://arxiv.org/pdf/1901.00434.pdf:
# sum over consecutive layer pairs (k, k+1) of nodes[k] * nodes[k+1],
# each term weighted by the minimum layer width up to and including layer k
capacity = np.sum(np.minimum.accumulate(nodes)[:-1] * nodes[:-1] * nodes[1:])
print(capacity)  # 36 for nodes = [2, 3, 4]
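# a fully connected arch with the same number of nodes n in each of L layers has
# capacity (L-1)*n**3, which is n times the (L-1)*n**2 multiply-adds of one forward
# pass on a single input, so at fixed width the capacity scales linearly with the
# total matrix multiplication time (is this true more generally?)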
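
The closing question can be probed numerically. The sketch below wraps the same formula in a function and compares it with a simple cost model. The names `capacity_bound` and `forward_macs` are hypothetical helpers, not part of the gist. The cost model is an assumption: one multiply-add per weight for a single input vector, ignoring biases and activations.

```python
import numpy as np


def capacity_bound(nodes):
    """Same formula as above: sum_k min(n_1..n_k) * n_k * n_{k+1}."""
    nodes = np.asarray(nodes)
    running_min = np.minimum.accumulate(nodes)
    return int(np.sum(running_min[:-1] * nodes[:-1] * nodes[1:]))


def forward_macs(nodes):
    """Assumed cost model: multiply-adds for one input vector through dense layers."""
    nodes = np.asarray(nodes)
    return int(np.sum(nodes[:-1] * nodes[1:]))


# Equal-width networks: vary the depth at two different widths.
for width in (8, 16):
    for depth in (2, 3, 5, 9):
        nodes = [width] * depth
        cap = capacity_bound(nodes)
        macs = forward_macs(nodes)
        # The ratio stays constant in depth and equals the width.
        print(width, depth, cap, macs, cap / macs)
```

Under this cost model the ratio of capacity to multiply-adds is constant in depth but equals the layer width, so the linear relationship holds only when comparing networks of the same width.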