Skip to content

Instantly share code, notes, and snippets.

@Ryu1845
Last active September 23, 2023 12:03
Show Gist options
  • Save Ryu1845/09d51411f78252f5f98f03ae5527abae to your computer and use it in GitHub Desktop.
Save Ryu1845/09d51411f78252f5f98f03ae5527abae to your computer and use it in GitHub Desktop.
ZerO Initialization copied from the original repo (https://github.com/jiaweizzhao/ZerO-initialization/)
import math
import torch
def hadamard(n: int, dtype=torch.int8):
    """Build the ``n`` x ``n`` Hadamard matrix with Sylvester's construction.

    This function is a port of the one in scipy.linalg.

    Args:
        n: Order of the matrix. Must be a positive power of 2.
        dtype: Torch dtype of the returned tensor.

    Returns:
        A 2-D tensor of shape ``(n, n)`` whose entries are +1 or -1.

    Raises:
        ValueError: If ``n`` is smaller than 1 or is not a power of 2.
    """
    # math.log2 is usually more accurate than math.log(n, 2); the latter is a
    # quotient of two rounded logarithms, and truncating a result that lands
    # just below the true exponent would wrongly reject a valid power of 2.
    lg2 = 0 if n < 1 else int(math.log2(n))
    if 2 ** lg2 != n:
        raise ValueError("n must be a positive integer, and n must be "
                         "a power of 2")

    H = torch.tensor([[1]], dtype=dtype)

    # Sylvester's construction: each step doubles the order,
    # H -> [[H, H], [H, -H]].
    for _ in range(lg2):
        H = torch.vstack((torch.hstack((H, H)), torch.hstack((H, -H))))

    return H
# NOTE(review): compiling a run-once initializer may cost more than it saves —
# confirm the torch.compile decorator is intentional.
@torch.compile()
@torch.no_grad()
def linear_ZerO_init_(tensor: torch.Tensor):
    """Overwrite a 2-D tensor in place with the ZerO initialization.

    Implements Algorithm 1 in the paper. For a tensor of shape ``(m, n)``:

    * ``m <= n``: the tensor becomes the ``m`` x ``n`` (rectangular) identity.
    * ``m > n``: the tensor becomes the top-left ``m`` x ``n`` corner of the
      ``p`` x ``p`` Hadamard matrix divided by ``2 ** (ceil(log2(m)) / 2)``,
      where ``p = 2 ** ceil(log2(m))``.

    Args:
        tensor: The 2-D tensor to initialize; modified in place.

    Returns:
        None.

    Raises:
        AssertionError: If ``tensor`` is not 2-dimensional.
    """
    assert len(tensor.shape) == 2, "linear_ZerO_init_ only works on 2D tensors"
    m, n = tensor.shape
    if m <= n:
        tensor[:] = torch.nn.init.eye_(torch.empty(m, n))
    else:  # m > n
        clog_m = math.ceil(math.log2(m))
        p = 2 ** clog_m
        # I(m, p) @ H @ I(p, n) only keeps the first m rows and first n columns
        # of H, so slice H directly. This yields the same values without the
        # two dense matmuls, and avoids the dtype mismatch between
        # default-dtype identity factors and a Hadamard factor built in
        # tensor.dtype (which made matmul fail for e.g. float16/float64).
        tensor[:] = hadamard(p, dtype=tensor.dtype)[:m, :n] / (2 ** (clog_m / 2))
# NOTE(review): compiling a run-once initializer may cost more than it saves —
# confirm the torch.compile decorator is intentional.
@torch.compile()
@torch.no_grad()
def conv2d_ZerO_init_(tensor: torch.Tensor):
    """Write the ZerO initialization into the center tap of a 4-D kernel.

    Source: https://github.com/jiaweizzhao/ZerO-initialization/issues/1#issuecomment-1405598940

    For a tensor of shape ``(m, n, k, l)`` the ``m`` x ``n`` slice at the
    spatial center ``(k // 2, l // 2)`` is set to:

    * the ``m`` x ``n`` (rectangular) identity when ``m <= n``;
    * the top-left ``m`` x ``n`` corner of the ``p`` x ``p`` Hadamard matrix
      divided by ``2 ** (ceil(log2(m)) / 2)`` when ``m > n``, where
      ``p = 2 ** ceil(log2(m))``.

    Args:
        tensor: The 4-D tensor to initialize; modified in place.

    Returns:
        None.

    Raises:
        AssertionError: If ``tensor`` is not 4-dimensional.
    """
    assert len(tensor.shape) == 4, "conv2d_ZerO_init_ only works on 4D tensors"
    m, n, k, l = tensor.shape
    # Center each spatial axis with its own size. Reusing the index derived
    # from k for the last axis is off-center (or out of range) when the kernel
    # is not square; for square kernels the two indices are equal.
    center_k = k // 2
    center_l = l // 2
    # NOTE(review): only the center tap is written; every other kernel position
    # keeps whatever values the tensor already held — confirm callers zero the
    # tensor first if that is required.
    if m <= n:
        tensor[:, :, center_k, center_l] = torch.nn.init.eye_(torch.empty(m, n))
    else:  # m > n
        clog_m = math.ceil(math.log2(m))
        p = 2 ** clog_m
        # I(m, p) @ H @ I(p, n) only keeps the first m rows and first n columns
        # of H, so slice H directly. This yields the same values without the
        # two dense matmuls, and avoids the dtype mismatch between
        # default-dtype identity factors and a Hadamard factor built in
        # tensor.dtype (which made matmul fail for e.g. float16/float64).
        tensor[:, :, center_k, center_l] = (
            hadamard(p, dtype=tensor.dtype)[:m, :n] / (2 ** (clog_m / 2))
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment