drhead / channels_last_group_norm.py
Created May 18, 2024 00:29
Naive group norm implementation that somehow beats the native kernel? Written for channels-last tensors, but for some reason it's also faster on channels-first.
import torch

def cl_weight_hook(state_dict, *args, **kwargs):
    # Reshape incoming 1-D affine params to (1, C, 1, 1) and store them channels-last.
    for key in state_dict.keys():
        state_dict[key] = state_dict[key].reshape(1, -1, 1, 1).to(memory_format=torch.channels_last)

class CLGroupNorm(torch.nn.GroupNorm):
    def __init__(self, num_groups: int, num_channels: int, eps: float = 0.00001, affine: bool = True, device=None, dtype=None) -> None:
        super().__init__(num_groups, num_channels, eps, affine, device, dtype)
        # Keep the affine params in (1, C, 1, 1) channels-last form so they broadcast cheaply.
        if self.weight.ndim == 1:
            self.weight.data = self.weight.data.reshape(1, -1, 1, 1).to(memory_format=torch.channels_last)
        if self.bias.ndim == 1:
            self.bias.data = self.bias.data.reshape(1, -1, 1, 1).to(memory_format=torch.channels_last)
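As a usage sketch (not part of the gist; the shapes and names are illustrative): the hook's (state_dict, *args, **kwargs) signature is compatible with PyTorch's load_state_dict pre-hooks, but it can also just be applied to a checkpoint dict by hand before loading it into a CLGroupNorm.

# Illustrative only: reshape a checkpoint's 1-D GroupNorm params ahead of loading.
sd = torch.nn.GroupNorm(32, 320).state_dict()   # {'weight': (320,), 'bias': (320,)}
cl_weight_hook(sd)
print(sd['weight'].shape)                       # torch.Size([1, 320, 1, 1])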
drhead / loss_mlp.py
Last active April 20, 2024 02:52
Loss weighting MLP prototype
import numpy as np
import torch
import torch.nn as nn

def normalize(x: torch.Tensor, dim=None, eps=1e-4) -> torch.Tensor:
    # Magnitude-preserving normalization: rescale so each element of x ends up at
    # roughly unit magnitude, rather than giving each vector unit norm.
    if dim is None:
        dim = list(range(1, x.ndim))
    norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)  # type: torch.Tensor
    norm = torch.add(eps, norm, alpha=np.sqrt(norm.numel() / x.numel()))
    return x / norm.to(x.dtype)

class MPFourier(nn.Module):
    def __init__(self, num_channels, bandwidth=1):
        super().__init__()
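For a quick sense of what normalize does (an illustrative check, not from the gist, assuming the imports above): it rescales so that individual elements, rather than whole vectors, come out at roughly unit magnitude.

# Illustrative only: per-element RMS is ~1 after normalize, regardless of input scale.
x = torch.randn(8, 256) * 50.0
y = normalize(x)
print(y.pow(2).mean(dim=1).sqrt())          # each row's RMS is close to 1
print(torch.linalg.vector_norm(y, dim=1))   # each row's norm is close to sqrt(256) = 16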