Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save jakelevi1996/94a77c6eb080ee6296fddb69954131c3 to your computer and use it in GitHub Desktop.
Save jakelevi1996/94a77c6eb080ee6296fddb69954131c3 to your computer and use it in GitHub Desktop.
Comparing normalisation for weight initialisation
from jutility import plotting, util, transform
import numpy as np
def main():
    """Compare output variance of a random linear map with and without
    1 / sqrt(fan-in) weight scaling, and plot the results.

    For each setting of ``normalise`` (True, then False):

    - Measure the output variance returned by ``get_output_variance`` for 50
      hidden dimensions spaced evenly between 2 and ``max_hidden_dim``.
    - Repeat that sweep ``num_repeats`` times.
    - Fit an affine least-squares model of variance against hidden dimension,
      draw it in red over the noisy data, and print its ``w`` and ``b``
      coefficients.

    The two subplots are combined into a single figure, which is saved in the
    current directory.
    """
    num_trials = 100
    num_repeats = 15
    min_hidden_dim = 2
    max_hidden_dim = 1000
    num_hidden_dims = 50  # same as np.linspace's default ``num``
    rng = np.random.default_rng(0)

    # The sweep is identical for every repeat and every condition, so build
    # it once. ``astype(int)`` truncates the evenly spaced floats.
    hidden_dim_list = np.linspace(
        min_hidden_dim, max_hidden_dim, num_hidden_dims,
    ).astype(int)
    # The fitted line is drawn over the full range of the sweep; sharing
    # ``max_hidden_dim`` keeps the two in sync.
    x_fit = [0, max_hidden_dim]

    subplots = []
    for normalise in [True, False]:
        noisy_data = util.NoisyData()
        for _ in range(num_repeats):
            for h in hidden_dim_list:
                s = get_output_variance(
                    num_trials, h, rng, normalise=normalise,
                )
                noisy_data.update(h, s)

        x_list, y_list = noisy_data.get_all_data()
        f = transform.least_squares_affine(
            np.array(x_list).reshape(1, -1),
            np.array(y_list).reshape(1, -1),
        )
        title = f"Normalise = {normalise}"
        sp = plotting.Subplot(
            *plotting.get_noisy_data_lines(noisy_data),
            plotting.Line(x_fit, f([x_fit]).reshape(-1), c="r", zorder=30),
            axis_properties=plotting.AxisProperties(
                xlabel="Hidden dimension",
                ylabel="Output variance",
                title=title,
            ),
        )
        subplots.append(sp)
        # Label the coefficients so the two printed lines can be told apart.
        print(f"{title}: w = {f.w.item()}, b = {f.b.item()}")

    title = "Comparing normalisation for weight initialisation"
    plotting.set_latex_params()
    mp = plotting.MultiPlot(
        *subplots,
        figure_properties=plotting.FigureProperties(
            title=title,
            top_space=0.15,
        ),
    )
    mp.save(title, dir_name=".")
def get_output_variance(num_trials, hidden_dim, rng, normalise=True):
    """Estimate the variance of the outputs of a random linear map.

    Draws a ``(num_trials, hidden_dim)`` matrix of standard-normal weights,
    divided by ``sqrt(hidden_dim)`` when ``normalise`` is true. It then draws
    a ``(hidden_dim, num_trials)`` matrix of inputs uniform on [0, 1).

    Args:
        num_trials: number of weight rows and of input columns to sample.
        hidden_dim: inner (fan-in) dimension of the product.
        rng: random generator providing ``normal`` and ``uniform``.
        normalise: whether to divide the weights by ``sqrt(hidden_dim)``.

    Returns:
        The variance (``ddof=0``) over all ``num_trials ** 2`` entries of
        ``weights @ inputs``.
    """
    weights = rng.normal(size=[num_trials, hidden_dim])
    if normalise:
        # Fan-in scaling; the experiment checks whether this keeps the
        # output variance flat as hidden_dim grows.
        weights /= np.sqrt(hidden_dim)
    inputs = rng.uniform(0, 1, [hidden_dim, num_trials])
    # Each entry of the product is treated as one output sample. Entries in
    # the same row share weights, and entries in the same column share inputs.
    return np.var(weights @ inputs)
if __name__ == "__main__":
    # Run the experiment only when this file is executed as a script,
    # not when it is imported as a module.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment