Skip to content

Instantly share code, notes, and snippets.

@segyges
Last active June 27, 2025 02:39
Show Gist options

  • Select an option

  • Save segyges/1886759fa65a6ff9d63d843e9c0ea49e to your computer and use it in GitHub Desktop.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

import fast_hadamard_transform  # third-party; install from git (see note below)
# N log N approximation for a multiplication by a random orthogonal matrix
# Mostly untested, depends on having fast_hadamard_transform installed via git
class FastFood(nn.Module):
    """O(n log n) approximation to multiplication by a random Gaussian matrix.

    Implements the Fastfood transform: the input is multiplied by
    (1/n) * H * G * P * H * B, where B is a random +/-1 diagonal, H is the
    Walsh-Hadamard transform, P is a random permutation, and G is a random
    Gaussian diagonal. Requires the `fast_hadamard_transform` package.

    Mostly untested (per the original author's note).
    """

    def __init__(self, size):
        """Build the random diagonals and permutation for a transform of width `size`.

        Args:
            size: width of the transform; must be a positive power of two
                (the Hadamard transform is only defined for those widths).

        Raises:
            ValueError: if `size` is not a positive power of two.
        """
        super().__init__()
        # `size & (size - 1) == 0` alone would accept size == 0, so also
        # require size > 0.
        if size <= 0 or size & (size - 1) != 0:
            raise ValueError("Size must be a power of two")
        self.size = size
        # Register the random tensors as buffers so they follow the module
        # across .to(device) calls and are persisted in state_dict().
        # Vector of random values drawn from -1 and +1, of length `size`.
        self.register_buffer("binary_scaling", torch.randint(0, 2, (size,)) * 2 - 1)
        self.register_buffer("permutation_matrix", torch.randperm(size))
        self.register_buffer("gaussian_scaling", torch.randn(size))
        self.linear = nn.Identity()  # placeholder kept for interface compatibility

    def forward(self, input):
        """Apply the Fastfood transform along the last dimension of `input`.

        Inputs whose last dimension is smaller than `self.size` are
        zero-padded up to it; larger inputs raise (truncation is not
        implemented). All leading dimensions are preserved.
        """
        x = input
        input_size = x.size(-1)
        if input_size != self.size:
            padding_needed = self.size - input_size
            if padding_needed < 0:
                # Handle cases where input is larger than expected
                raise ValueError(f"Input size ({input_size}) is larger than the expected size ({self.size}). Truncation is not implemented.")
            # Zero-pad only the last dimension up to self.size.
            x = F.pad(x, (0, padding_needed))
        # B: random sign flip.
        x = x * self.binary_scaling.to(x.device)
        # H: first Hadamard transform; divide by sqrt(size) to keep it orthonormal.
        x = fast_hadamard_transform.hadamard_transform(x)
        x = x / math.sqrt(self.size)
        # P: permute the last dimension with a single vectorized gather.
        # Flatten leading dims, index, then restore the original shape —
        # this preserves arbitrary batch shapes (the flat-then-stack version
        # only restored 1-D inputs).
        original_shape = x.shape
        x = x.reshape(-1, self.size)
        x = x[:, self.permutation_matrix]
        x = x.reshape(original_shape)
        # G: random Gaussian scaling.
        x = x * self.gaussian_scaling.to(x.device)
        # H: second Hadamard transform, again orthonormalized.
        x = fast_hadamard_transform.hadamard_transform(x)
        x = x / math.sqrt(self.size)
        # We would have to scale by sigma, according to the paper.
        # Our Gaussian scaling has std 1, so no extra factor is applied here;
        # a different scaling would yield a non-RBF kernel.
        # Per the paper, no linear layer belongs inside the transform itself —
        # apply one *after* this module in the surrounding network if needed.
        return x
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment