Created April 10, 2019 19:36
Gist aminnj/3eb1a8e88d9e5b82d7ccb1f3d3e8d0b1
nn capacity
```python
import numpy as np

# number of nodes per layer, from input to output
nodes = np.array([2, 3, 4])

# upper bound on capacity according to eq. 1.1 of https://arxiv.org/pdf/1901.00434.pdf:
# sum over consecutive layer pairs (k, k+1) of n_k * n_{k+1},
# each weighted by the minimum layer width up to and including layer k
capacity = np.sum(np.minimum.accumulate(nodes)[:-1] * nodes[:-1] * nodes[1:])
print(capacity)  # 36

# a fully connected architecture with the same number of nodes in every layer has a
# capacity that scales linearly with the total matrix multiplication time (is this true?)
```
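The open question in the last comment can be checked with a little arithmetic. With L layers of equal width n, every running minimum equals n, so each of the L − 1 terms is n · n · n and the sum becomes (L − 1)n³. One forward pass of a single input through the same network costs (L − 1)n² multiply-accumulates (MACs). The ratio is therefore n. At fixed width, the expression grows linearly with MAC count as depth changes. At fixed depth, it grows like MACs^(3/2) as width changes. The sketch below is a minimal illustration that assumes matrix multiplication time is proportional to the MAC count for one input vector and ignores biases. The helper names `capacity_estimate` and `matmul_macs` are hypothetical and are not taken from the paper.

```python
import numpy as np


def capacity_estimate(nodes):
    """Sum over consecutive layer pairs of min(n_1..n_k) * n_k * n_{k+1}.

    Same expression as the snippet above; `nodes` lists layer widths
    from input to output.
    """
    n = np.asarray(nodes, dtype=np.int64)
    running_min = np.minimum.accumulate(n)[:-1]  # min(n_1..n_k) for k = 1..L-1
    return int(np.sum(running_min * n[:-1] * n[1:]))


def matmul_macs(nodes):
    """Multiply-accumulates for one forward pass of a single input vector.

    Assumption: used here as a proxy for matrix multiplication time; biases ignored.
    """
    n = np.asarray(nodes, dtype=np.int64)
    return int(np.sum(n[:-1] * n[1:]))


if __name__ == "__main__":
    print(capacity_estimate([2, 3, 4]))  # 36, matching the snippet above

    # Equal width n and depth L: capacity = (L-1) n^3, MACs = (L-1) n^2, ratio = n.
    for width in (8, 16):
        for depth in (2, 4, 8):
            arch = [width] * depth
            c, m = capacity_estimate(arch), matmul_macs(arch)
            print(f"width={width:2d} depth={depth} capacity={c:6d} macs={m:5d} ratio={c / m:.0f}")
```

The printed ratio stays at 8 for every depth when the width is 8, and at 16 when the width is 16. Under this proxy, the linear scaling holds across depths only when the layer width is held fixed.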