Skip to content

Instantly share code, notes, and snippets.

@lostmsu
Created November 7, 2022 18:22
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lostmsu/7f1c5d52e858d410911101a2664f381a to your computer and use it in GitHub Desktop.
Save lostmsu/7f1c5d52e858d410911101a2664f381a to your computer and use it in GitHub Desktop.
PseudoLinear runs at nearly the same speed as Linear, despite performing ~160x less computation.
class PseudoLinear(nn.Module):
    """Cheap stand-in for a square ``nn.Linear``: elementwise scale, shift and skip.

    Holds a ``weight`` and a ``bias`` vector, each of length ``features`` and
    drawn from ``torch.randn``. The layer computes ``x * weight + bias + x``.
    Each output element therefore costs one multiply and two adds, instead of
    a ``features``-long dot product.

    The constructor accepts ``(features, device=None)``, so it can be passed
    anywhere a ``linear(width, device=...)`` factory is expected.

    NOTE(review): ``torch`` and ``nn`` are not imported in this snippet;
    presumably they come from the surrounding notebook/session — confirm.
    """

    def __init__(self, features, device=None):
        super().__init__()
        shape = (features,)
        # weight is drawn before bias; keep this order so seeded runs reproduce.
        self.weight = nn.Parameter(torch.randn(shape, device=device))
        self.bias = nn.Parameter(torch.randn(shape, device=device))

    def forward(self, x):
        # The 1-D parameters broadcast along the last dimension of ``x``.
        scaled = x * self.weight
        shifted = scaled + self.bias
        return shifted + x  # residual/skip term
def make_linear(features, device=None):
    """Build a square ``nn.Linear(features, features)`` on ``device``.

    ``device`` defaults to ``None`` (PyTorch's default device). This gives the
    factory the same ``(features, device=None)`` signature as
    ``PseudoLinear``, so either can be passed wherever a layer factory is
    expected, with or without an explicit device.
    """
    return nn.Linear(features, features, device=device)
# Feature width used for every layer in the stacks and for the benchmark inputs.
WIDTH = 160


def make_net(depth, linear, device=None):
    """Stack ``depth`` (layer, ReLU) pairs into an ``nn.Sequential``.

    ``linear`` is a layer factory, called once per pair as
    ``linear(WIDTH, device=device)``. A ReLU follows every layer, including
    the last one. ``depth == 0`` yields an empty ``nn.Sequential``.
    """
    modules = []
    for _ in range(depth):
        # The layer is constructed before its ReLU, one pair at a time.
        modules += [linear(WIDTH, device=device), nn.ReLU()]
    return nn.Sequential(*modules)
# Hard-coded: running this requires a CUDA-enabled PyTorch build and a GPU.
device = "cuda"
# Two 40-deep stacks that differ only in the per-layer module. `real` is
# built first, so it consumes the RNG before `fake` does.
real, fake = (make_net(40, factory, device) for factory in (make_linear, PseudoLinear))
def test(net, warmup=10, iters=400):
    """Return the seconds spent on ``iters`` timed forward+backward steps of ``net``.

    Each step:
    - draws fresh random ``x`` and ``y`` of shape ``(512, 192, WIDTH)`` on
      ``device``;
    - runs ``net(x)``;
    - reduces ``(out - y).mean()`` to a scalar and calls ``backward()`` on it.

    Input generation is part of every step, so it is billed equally to
    whichever net is measured. There is no ``zero_grad`` here, so parameter
    gradients accumulate across steps. ``warmup`` untimed steps run first.
    ``warmup=10, iters=400`` reproduce the original fixed counts.

    NOTE(review): ``torch``, ``tqdm``, ``device`` and ``WIDTH`` are read from
    module scope. ``torch`` and ``tqdm`` are not imported in this snippet;
    presumably they come from the surrounding notebook/session — confirm.
    """
    import time

    def sync():
        # CUDA kernels launch asynchronously. Without a barrier, the timer sees
        # queued warmup work and misses kernels still running at the end.
        if torch.device(device).type == "cuda":
            torch.cuda.synchronize()

    def step():
        x = torch.randn(512, 192, WIDTH, device=device)
        y = torch.randn(512, 192, WIDTH, device=device)
        out = net(x)
        loss = (out - y).mean()
        loss.backward()

    for _ in range(warmup):
        step()
    sync()  # drain warmup work before starting the clock
    start = time.perf_counter()  # monotonic, high-resolution interval timer
    for _ in tqdm(range(iters)):
        step()
    sync()  # wait for every timed kernel to finish before stopping the clock
    return time.perf_counter() - start
# Benchmark each stack in turn (real first) and report its wall-clock seconds.
for label, model in (("real", real), ("fake", fake)):
    print(f"{label}: {test(model)}s")
# NOTE(review): `exit` is the interactive helper installed by `site`;
# `sys.exit(0)` is the script-safe form if this never runs in a notebook.
exit(0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment