LRU with self-attention
Gist by @harpone, created March 23, 2023
import math

import torch


def forward_sequential(h, xs, U, W, nu, theta):
    """Forward pass through the network sequentially over an input `xs` of any length.

    NOTE: has no batch dimension; to be batched with `vmap`.

    Args:
        h (torch.Tensor): shape [D_h]; initial state
        xs (torch.Tensor): shape [T, D_x]; input sequence
        U (torch.Tensor): parameter matrix of shape [D_h, D_x]
        W (torch.Tensor): parameter matrix of shape [D_h, D_x]
        nu (torch.Tensor): parameter vector of shape [D_h]
        theta (torch.Tensor): parameter vector of shape [D_h]

    Returns:
        hs (torch.Tensor): shape [T, D_h]; output sequence
    """
    T = xs.shape[0]
    D_h = h.shape[0]
    # The recurrence is complex-valued, so the output buffer must be complex as well.
    hs = torch.zeros(T, D_h, dtype=torch.cfloat, device=xs.device)
    for t in range(T):
        # h_t = exp(U x_t - nu - i * theta) * h_{t-1} + W x_t
        h = torch.exp(U @ xs[t] - nu - theta * 1j) * h + W @ xs[t]
        hs[t] = h
    return hs.real
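

# Minimal sanity check of a single recurrence step (an added sketch, not part of the
# original gist; the underscore-prefixed names are only for this check). With T = 1 the
# output should equal the real part of exp(U x_0 - nu - i * theta) * h + W x_0.
_h = torch.randn(4)
_x = torch.randn(1, 3)
_U, _W = torch.randn(4, 3), torch.randn(4, 3)
_nu, _theta = torch.rand(4), torch.rand(4)
_step = (torch.exp(_U @ _x[0] - _nu - _theta * 1j) * _h + _W @ _x[0]).real
assert torch.allclose(forward_sequential(_h, _x, _U, _W, _nu, _theta)[0], _step)
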
def forward_parallel(h, xs, U, W, nu, theta):
    """Forward pass through the network in parallel over an input `xs` of any length,
    using the exact solution of the recurrence relation.

    NOTE: has no batch dimension; to be batched with `vmap`.

    Args:
        h (torch.Tensor): shape [D_h]; initial state
        xs (torch.Tensor): shape [T, D_x]; input sequence
        U (torch.Tensor): parameter matrix of shape [D_h, D_x]
        W (torch.Tensor): parameter matrix of shape [D_h, D_x]
        nu (torch.Tensor): parameter vector of shape [D_h]
        theta (torch.Tensor): parameter vector of shape [D_h]

    Returns:
        hs (torch.Tensor): shape [T, D_h]; output sequence
    """
    gammas = torch.cumsum(torch.matmul(xs, U.T) - nu - theta * 1j, dim=0)  # [T, D_h]
    betas = torch.matmul(xs, W.T)  # [T, D_h]
    source = torch.cumsum(torch.exp(-gammas) * betas, dim=0)  # [T, D_h]
    hs = torch.exp(gammas) * (h[None] + source)
    return hs.real
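

# Why the parallel form works (a sketch of the reasoning, not part of the original gist):
# with lambda_t = exp(U x_t - nu - i * theta) and gamma_t = sum_{s <= t} (U x_s - nu - i * theta),
# unrolling h_t = lambda_t * h_{t-1} + W x_t gives
#     h_t = exp(gamma_t) * (h_0 + sum_{s <= t} exp(-gamma_s) * W x_s),
# which is exactly the pair of cumulative sums computed in `forward_parallel`.
#
# Quick equivalence check between the two implementations (an added sketch; the small
# scales and short length are chosen so exp(cumsum(...)) stays well within float32
# range, otherwise the comparison is dominated by overflow).
_h = torch.randn(8)
_xs = torch.randn(16, 4)
_U, _W = 0.1 * torch.randn(8, 4), torch.randn(8, 4)
_nu = torch.linspace(0.001, 0.5, 8)
_theta = torch.linspace(0.0, 2.0, 8)
_hs_seq = forward_sequential(_h, _xs, _U, _W, _nu, _theta)
_hs_par = forward_parallel(_h, _xs, _U, _W, _nu, _theta)
print(torch.allclose(_hs_seq, _hs_par, atol=1e-3))  # should print True (float32 round-off)
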
#### Benchmark code:
device = torch.device("cuda")

D_h = 256
D_x = 64
U = torch.randn(D_h, D_x, device=device)
W = torch.randn(D_h, D_x, device=device)
nu = torch.linspace(0.001, 0.5, D_h, device=device)
theta = torch.linspace(0, 2 * math.pi * (D_h - 1) / D_h, D_h, device=device)

T = 1024
xs = torch.randn(T, D_x, device=device)
h = torch.randn(D_h, device=device)


def sequential_timer():
    hs_seq = forward_sequential(h, xs, U, W, nu, theta)
    torch.cuda.synchronize()


def parallel_timer():
    hs_par = forward_parallel(h, xs, U, W, nu, theta)
    torch.cuda.synchronize()
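

# How this might be batched and timed (an added sketch; the use of torch.func.vmap and
# torch.utils.benchmark here is an assumption on top of the original gist).
# `forward_parallel` is purely functional, so it maps cleanly over a leading batch
# dimension; the in-place write in `forward_sequential` makes it less vmap-friendly.
from torch.func import vmap
from torch.utils import benchmark

batched_parallel = vmap(forward_parallel, in_dims=(0, 0, None, None, None, None))
hs_batch = batched_parallel(h.expand(32, -1), xs.expand(32, -1, -1), U, W, nu, theta)  # [32, T, D_h]

t_seq = benchmark.Timer(stmt="sequential_timer()", globals={"sequential_timer": sequential_timer})
t_par = benchmark.Timer(stmt="parallel_timer()", globals={"parallel_timer": parallel_timer})
print(t_seq.timeit(10))
print(t_par.timeit(10))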