Skip to content

Instantly share code, notes, and snippets.

@AndrovT
Created June 12, 2023 18:02
Show Gist options
  • Save AndrovT/9e3fbaebb7082734dc84d27e02094cb3 to your computer and use it in GitHub Desktop.
Save AndrovT/9e3fbaebb7082734dc84d27e02094cb3 to your computer and use it in GitHub Desktop.
import torch
def permute_ft_output(nnue, permutation):
    """Reorder the feature-transformer outputs of *nnue* in place.

    Let ``l1_size = nnue.layer_stacks.l1.in_features`` and
    ``half = l1_size // 2``. ``permutation`` gives the new order of outputs
    ``[0, half)``. The same reordering, shifted by ``half``, is applied to
    outputs ``[half, l1_size)``. Outputs from ``l1_size`` up to
    ``nnue.input.num_outputs`` keep their positions.

    The tensors are reindexed as follows:
      * columns (dim 1) of ``nnue.input.weight``
      * entries of ``nnue.input.bias``
      * columns (dim 1) of ``nnue.layer_stacks.l1.weight``

    Args:
        nnue: Model exposing ``input`` (with ``weight``, ``bias`` and
            ``num_outputs``) and ``layer_stacks.l1`` (with ``in_features``
            and ``weight``). Modified in place.
        permutation: Iterable of ints that must be a permutation of
            ``range(half)``. It is copied, never modified.

    Raises:
        ValueError: If ``permutation`` does not have exactly
            ``l1_size / 2`` entries, or is not a permutation of
            ``range(half)``.
    """
    l1_size = nnue.layer_stacks.l1.in_features
    half = l1_size // 2

    # Work on a private copy. Extending the caller's list in place (as the
    # previous version did) corrupted it and made a repeat call fail.
    first_half = list(permutation)

    # Validate with exceptions rather than `assert`, which `python -O` strips.
    if len(first_half) * 2 != l1_size:
        raise ValueError(
            f"permutation has {len(first_half)} entries; expected "
            f"l1 in_features / 2 (in_features={l1_size})"
        )
    # Duplicate or out-of-range indices would silently scramble the weights.
    if sorted(first_half) != list(range(half)):
        raise ValueError(f"permutation is not a permutation of range({half})")

    # Mirror the reordering onto the second half so output i and output
    # i + half move together.
    # NOTE(review): presumably the two halves are combined pairwise
    # downstream, which is why they must stay aligned — confirm.
    l1_permutation = first_half + [i + half for i in first_half]

    # Feature-transformer outputs at index >= l1_size (if any) stay in place.
    ft_permutation = l1_permutation + list(range(l1_size, nnue.input.num_outputs))

    nnue.input.weight.data = nnue.input.weight.data[:, ft_permutation]
    nnue.input.bias.data = nnue.input.bias.data[ft_permutation]
    nnue.layer_stacks.l1.weight.data = nnue.layer_stacks.l1.weight.data[:, l1_permutation]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment