Skip to content

Instantly share code, notes, and snippets.

@thecharlieblake
Last active September 6, 2022 10:50
Show Gist options
  • Save thecharlieblake/2d443c0def7eb8b3dd5703aa217b41bc to your computer and use it in GitHub Desktop.
Save thecharlieblake/2d443c0def7eb8b3dd5703aa217b41bc to your computer and use it in GitHub Desktop.
import numpy as np
class Matmul:
    """A bias-free linear layer ``y = x @ w`` with hand-written gradients.

    Weights are sampled once at construction from a zero-mean normal
    distribution with Glorot (Xavier) scaling,
    ``std = sqrt(2 / (in_dim + out_dim))``.

    The layer is stateful: ``fwd`` caches its input, and ``grad_w`` reads
    that cache. ``fwd`` must therefore be called before ``grad_w``.

    Args:
        batch_sz: Expected number of rows in ``x``. It is only stored by
            this class; no computation here reads it.
        in_dim: Input feature width (number of rows of ``w``).
        out_dim: Output feature width (number of columns of ``w``).
    """

    def __init__(self, batch_sz, in_dim, out_dim):
        self.batch_sz, self.in_dim, self.out_dim = batch_sz, in_dim, out_dim
        # Glorot/Xavier normal init: Var(w) = 2 / (fan_in + fan_out). Using
        # the *sum* of the fans balances the forward variance (which grows
        # with in_dim) against the backward variance (which grows with
        # out_dim). A product of the fans would make the weights far too small.
        glorot = np.sqrt(2 / (in_dim + out_dim))
        self._init_w(scale=glorot)

    def _init_w(self, scale):
        """(Re)sample ``self.w`` with shape ``(in_dim, out_dim)`` and std ``scale``."""
        self.w = np.random.normal(0, scale, (self.in_dim, self.out_dim))

    def fwd(self, x):
        """Return ``x @ w`` and cache ``x`` for a later ``grad_w`` call.

        Args:
            x: Input array whose last dimension is ``in_dim``
               (assumed ``(batch, in_dim)``, since ``grad_w`` transposes it).
        """
        self.x = x
        return x @ self.w

    def grad_x(self, d_y):
        """Return the gradient w.r.t. the input: ``d_y @ w.T``.

        Args:
            d_y: Gradient w.r.t. the output, last dimension ``out_dim``.
        """
        return d_y @ self.w.T

    def grad_w(self, d_y):
        """Return the gradient w.r.t. the weights: ``x.T @ d_y``.

        Uses the input cached by the most recent ``fwd`` call. If ``fwd``
        has never been called, this raises ``AttributeError``.

        Args:
            d_y: Gradient w.r.t. the output, shape ``(batch, out_dim)``.
        """
        return self.x.T @ d_y
class UnitScaleMatmul(Matmul):
    """Matmul variant that moves scaling out of the weights and into the ops.

    Weights are drawn with unit standard deviation. Each product is then
    divided by the square root of the length of the sum it performs:

    * ``fwd``    sums over ``in_dim``    -> divided by ``sqrt(in_dim)``
    * ``grad_x`` sums over ``out_dim``   -> divided by ``sqrt(out_dim)``
    * ``grad_w`` sums over the batch     -> divided by ``sqrt(batch_sz)``

    For independent, unit-variance operands this keeps each result at
    roughly unit variance. That input property is assumed, not enforced.
    """

    def __init__(self, batch_sz, in_dim, out_dim):
        super().__init__(batch_sz, in_dim, out_dim)
        # The base constructor has already sampled Glorot-scaled weights;
        # overwrite them with a fresh unit-scale sample.
        self._init_w(scale=1)

    def fwd(self, x):
        """Return ``(x @ w) / sqrt(in_dim)``; the base call caches ``x``."""
        unscaled = super().fwd(x)
        return unscaled / np.sqrt(self.in_dim)

    def grad_x(self, d_y):
        """Return ``(d_y @ w.T) / sqrt(out_dim)``."""
        unscaled = super().grad_x(d_y)
        return unscaled / np.sqrt(self.out_dim)

    def grad_w(self, d_y):
        """Return ``(x.T @ d_y) / sqrt(batch_sz)`` using the cached input."""
        unscaled = super().grad_w(d_y)
        return unscaled / np.sqrt(self.batch_sz)
if __name__ == "__main__":
    # Toy problem: a batch of 32 rows, 64 input features, 128 outputs.
    batch_sz, d_in, d_out = 32, 64, 128
    dims = (batch_sz, d_in, d_out)

    # Unit-variance activations and incoming output gradients. They are
    # drawn before either layer is built, so the global RNG is consumed in
    # the same order as before: x, d_y, then each layer's weights.
    x = np.random.normal(0, 1, (batch_sz, d_in))
    d_y = np.random.normal(0, 1, (batch_sz, d_out))

    regular = Matmul(*dims)
    unit = UnitScaleMatmul(*dims)

    print("Regular Scale | Unit Scale")

    w_reg, w_unit = regular.w.var(), unit.w.var()
    print(f"w: {w_reg:.2} "
          f"| {w_unit:.2}")

    # fwd() must run before grad_w(): grad_w() reuses the input that
    # fwd() cached on each layer.
    fwd_reg, fwd_unit = regular.fwd(x).var(), unit.fwd(x).var()
    print(f"fwd: {fwd_reg:.2} "
          f"| {fwd_unit:.2}")

    gx_reg, gx_unit = regular.grad_x(d_y).var(), unit.grad_x(d_y).var()
    print(f"grad_x: {gx_reg:.2} "
          f"| {gx_unit:.2}")

    gw_reg, gw_unit = regular.grad_w(d_y).var(), unit.grad_w(d_y).var()
    print(f"grad_w: {gw_reg:.2} "
          f"| {gw_unit:.2}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment