Last active
June 27, 2025 02:39
-
-
Save segyges/1886759fa65a6ff9d63d843e9c0ea49e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # N log N approximation for a multiplication by a random orthogonal matrix | |
| # Mostly untested, depends on having fast_hadamard_transform installed via git | |
try:
    # Optional CUDA-accelerated kernel (pip-installed from git); we fall back
    # to the pure-PyTorch transform below when it is unavailable.
    from fast_hadamard_transform import hadamard_transform as _fast_hadamard
except ImportError:
    _fast_hadamard = None


def _fwht(x):
    """Unnormalized fast Walsh-Hadamard transform along the last dimension.

    Matches the (scale=1) convention of ``fast_hadamard_transform``; callers
    are responsible for any 1/sqrt(n) normalization. Requires the last
    dimension to be a power of two. O(n log n) per vector.
    """
    size = x.size(-1)
    shape = x.shape
    x = x.reshape(-1, size)
    h = 1
    while h < size:
        # Butterfly step: pair up blocks of length h and take sums/differences.
        x = x.view(-1, size // (2 * h), 2, h)
        a, b = x[:, :, 0, :], x[:, :, 1, :]
        x = torch.stack((a + b, a - b), dim=2).reshape(-1, size)
        h *= 2
    return x.view(shape)


class FastFood(nn.Module):
    """O(n log n) approximation of multiplication by a random orthogonal matrix.

    Implements the Fastfood construction (Le, Sarlos, Smola 2013):
    x -> (1/n) * H @ diag(G) @ P @ H @ diag(B) @ x, where H is a Hadamard
    matrix, B is a random +/-1 sign vector, P a random permutation, and G a
    standard-Gaussian scaling vector. The gaussian scaling here has
    var = std = 1, so no extra sigma scaling is applied; a different scaling
    could in principle produce a non-RBF kernel.

    Per the original paper the random-feature transform itself contains no
    learned linear layer; apply one *after* this module if needed.
    """

    def __init__(self, size):
        """Args:
            size: dimensionality of the transform; must be a power of two.

        Raises:
            ValueError: if ``size`` is not a positive power of two.
        """
        super().__init__()
        # 0 would pass the bit-trick alone (0 & -1 == 0), so also require > 0.
        if size <= 0 or size & (size - 1) != 0:
            raise ValueError("Size must be a power of two")
        self.size = size
        # Registered as buffers so they follow .to(device)/.cuda() and are
        # persisted in state_dict (plain attributes would be left on CPU).
        self.register_buffer("binary_scaling", torch.randint(0, 2, (size,)) * 2 - 1)
        self.register_buffer("permutation_matrix", torch.randperm(size))
        self.register_buffer("gaussian_scaling", torch.randn(size))
        self.linear = nn.Identity()  # placeholder kept for interface/back-compat

    @staticmethod
    def _hadamard(x):
        """Dispatch to the fast CUDA kernel when installed, else pure torch."""
        if _fast_hadamard is not None:
            return _fast_hadamard(x)
        return _fwht(x)

    def forward(self, input):
        """Apply the Fastfood transform along the last dimension.

        Inputs whose last dimension is smaller than ``size`` are zero-padded;
        larger inputs raise. All leading (batch) dimensions are preserved.
        """
        x = input
        input_size = x.size(-1)
        if input_size != self.size:
            padding_needed = self.size - input_size
            if padding_needed < 0:
                # Handle cases where input is larger than expected
                raise ValueError(f"Input size ({input_size}) is larger than the expected size ({self.size}). Truncation is not implemented.")
            # Pad the last dimension with zeros
            x = F.pad(x, (0, padding_needed))
        # Random sign flip, then Hadamard + 1/sqrt(n) normalization.
        x = x * self.binary_scaling.to(x.device)
        x = self._hadamard(x) / math.sqrt(self.size)
        # Vectorized permutation of the last dimension. This replaces a
        # per-row Python loop and also preserves arbitrary leading dims
        # (the old view(-1, size) path silently flattened ndim > 2 inputs).
        x = x[..., self.permutation_matrix]
        # Gaussian scaling, then second Hadamard + normalization.
        x = x * self.gaussian_scaling.to(x.device)
        x = self._hadamard(x) / math.sqrt(self.size)
        return x
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment