Last active
December 30, 2020 11:33
-
-
Save ddobbelaere/ad4e2645828b3fbd771249e37cfc2d0a to your computer and use it in GitHub Desktop.
Analyze Stockfish NNUE weights
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import struct | |
from functools import reduce | |
import operator | |
import matplotlib.pyplot as plt | |
class NNUEReader:
    """Parse the leading layers of a Stockfish NNUE network file.

    The file is read sequentially: header, feature transformer (biases then
    weights), then the first fully connected layer (biases then weights).
    Later layers are not read. Header integers are little-endian int32
    (``"<i"``); tensors are read with NumPy's native byte order.

    The default dimensions reproduce the layout this script was written for
    (41024 transformer inputs, 256 transformer outputs, 32 FC1 outputs).
    They can be overridden to read files with the same structure but
    different sizes.

    Attributes:
        f: the underlying file object; it is closed once ``__init__`` returns.
        description: raw bytes of the description string from the header.
        input_biases: int16 array, shape ``(ft_outputs,)``.
        input_weights: int16 array, shape ``(ft_inputs, ft_outputs)``.
        l1_biases: int32 array, shape ``(l1_outputs,)``.
        l1_weights: int8 array, shape ``(l1_outputs, 2 * ft_outputs)``.
    """

    def __init__(self, filename, ft_inputs=41024, ft_outputs=256,
                 l1_outputs=32):
        """Read the header, feature transformer and FC1 layer of *filename*.

        Args:
            filename: path of the ``.nnue`` file.
            ft_inputs: number of feature-transformer weight rows.
            ft_outputs: number of feature-transformer outputs.
            l1_outputs: number of outputs of the first FC layer.

        Raises:
            ValueError: if the file ends before a tensor is complete.
            struct.error: if the file ends inside a 32-bit header field.
        """
        # ``with`` closes the file even when parsing fails part-way.
        with open(filename, 'rb') as self.f:
            self.read_header()
            self.read_int32()  # Feature transformer hash
            self.input_biases = self.tensor(np.int16, (ft_outputs,))
            self.input_weights = self.tensor(np.int16,
                                             (ft_inputs, ft_outputs))
            self.read_int32()  # FC layers hash
            self.l1_biases = self.tensor(np.int32, (l1_outputs,))
            # FC1's input width is twice the transformer width
            # (512 = 2 * 256 with the defaults).
            self.l1_weights = self.tensor(np.int8,
                                          (l1_outputs, 2 * ft_outputs))
            # The remaining layers are not parsed.

    def read_header(self):
        """Read the file header and store its description string.

        Layout: int32 version, int32 network hash, int32 description length,
        followed by that many description bytes. Version and hash are read
        but not checked; the description is kept as raw bytes in
        ``self.description``.
        """
        self.read_int32()  # version
        self.read_int32()  # halfkp network hash
        desc_len = self.read_int32()  # Network definition
        self.description = self.f.read(desc_len)

    def tensor(self, dtype, shape):
        """Read ``prod(shape)`` values of *dtype* and reshape them to *shape*.

        Raises:
            ValueError: if fewer values remain in the file than requested.
        """
        count = reduce(operator.mul, shape, 1)
        d = np.fromfile(self.f, dtype, count)
        # np.fromfile silently returns a short array at EOF; report it
        # explicitly instead of failing later inside reshape.
        if d.size != count:
            raise ValueError(
                "Truncated file: expected {} values of {}, got {}".format(
                    count, np.dtype(dtype).name, d.size))
        return d.reshape(shape)

    def read_int32(self, expected=None):
        """Read one little-endian signed 32-bit integer.

        Args:
            expected: if not None, the value that must be read.

        Returns:
            The integer read.

        Raises:
            ValueError: if *expected* is given and the value differs.
            struct.error: if fewer than 4 bytes remain.
        """
        v = struct.unpack("<i", self.f.read(4))[0]
        if expected is not None and v != expected:
            raise ValueError("Expected: %x, got %x" % (expected, v))
        return v
def _weight_to_rgb(v):
    """Map an integer colour value *v* to an ``(r, g, b)`` triple in [0, 255].

    This is a piecewise-linear ramp over 255-wide segments:

    * for ``0 <= v < 1020`` the colour goes black -> white -> cyan ->
      green -> black;
    * for ``-1020 < v < 0`` it goes black -> yellow -> red -> magenta ->
      black.

    The colours are intended to match
    https://hxim.github.io/Stockfish-Evaluation-Guide/.

    Raises:
        ValueError: if ``v`` lies outside ``(-1020, 1020)``.
    """
    if 0 <= v < 1020:
        r = v if v <= 255 else (-v + 510 if v <= 510 else 0)
        g = v if v <= 255 else (255 if v <= 765 else -v + 1020)
        b = v if v <= 255 else (
            255 if v <= 510 else (-v + 765 if v <= 765 else 0))
    elif -1020 < v < 0:
        r = -v if v >= -255 else (255 if v >= -765 else v + 1020)
        g = -v if v >= -255 else (v + 510 if v >= -510 else 0)
        b = 0 if v >= -510 else (-v - 510 if v >= -765 else v + 1020)
    else:
        raise ValueError("error weight too big: {}".format(v))
    return r, g, b


def get_image(weights, reorder=False, use_rgba=True):
    """Render feature-transformer weights as one large 2-D image.

    *weights* is flattened and walked in order. Flat index ``j`` is split
    into a feature ``j // 256`` and a transformer neuron ``j % 256``. The
    feature index minus one is split into ``ki = (.) // 641`` and
    ``pi = (.) % 641``, with ``piece = pi // 64`` and square ``pi % 64``.
    ``ki`` is used as a king square and ``pi % 64`` as a piece square.

    Layout:

    * Neuron ``n`` gets a 128x304 tile at tile column ``n % 32`` and tile
      row ``n // 32`` of a 4096-pixel-wide image.
    * Inside a tile, even and odd ``piece`` indices use the left and right
      64 columns.
    * Each ``piece // 2`` group uses its own vertical band.
    * Within a band, the piece square picks an 8x8 block (rank 7 at the
      top) and the king square picks the pixel inside it (file mirrored).

    Some entries are skipped and left 0: ``pi == 640``, and pieces 0/1 on
    ranks 0 and 7. Those pieces are presumably pawns; confirm against the
    net layout. Entries whose king square equals the piece square are drawn
    with value 0.

    Args:
        weights: the feature-transformer weight matrix. The script passes
            the ``(41024, 256)`` ``NNUEReader.input_weights``; only the
            flattened order matters, and it must not exceed 41024 * 256
            values.
        reorder: per-neuron reordering is not implemented. With ``True``,
            every neuron is drawn into tile 0.
        use_rgba: if True, return colour pixels (alpha 255 on drawn pixels)
            using ``-2 * weight`` as the colour value. Otherwise return the
            negated weights directly.

    Returns:
        int16 array of shape ``(2564, 4096, 4)`` when *use_rgba* is true,
        otherwise ``(2564, 4096)``.

    Raises:
        ValueError: in RGBA mode, if ``|2 * weight| >= 1020`` for a drawn
            entry.
    """
    weights = np.ndarray.flatten(weights)
    hd = 256     # transformer neurons: one tile each, 32 tiles per row
    hdim = 4096  # image width in pixels (32 tiles * 128 columns)
    totaldim = hdim * ((hd * 64 * 641) // hdim)
    if use_rgba:
        totaldim *= 4
    img = np.zeros(totaldim, dtype=np.int16)
    for j in range(weights.size):
        feature = j // hd - 1
        pi = feature % 641
        ki = feature // 641
        piece = pi // 64
        rank = (pi % 64) // 8
        if pi == 640 or ((rank == 0 or rank == 7) and piece in (0, 1)):
            continue
        r = g = b = 0
        v = 0
        if ki != pi % 64:
            if use_rgba:
                # Double in Python ints: doubling in int16 wraps around and
                # can turn a huge weight into an in-range colour.
                r, g, b = _weight_to_rgb(-2 * int(weights[j]))
            else:
                v = -weights[j]
        kipos = [ki % 8, ki // 8]
        pipos = [pi % 8, rank]
        inpos = [(7 - kipos[0]) + pipos[0] * 8,
                 kipos[1] + (7 - pipos[1]) * 8]
        # Vertical offset of this piece group's band inside the tile; the
        # -8 compensates for pieces 0/1 never being drawn on rank 7.
        d = -8 if piece < 2 else 48 + (piece // 2 - 1) * 64
        jhd = 0 if reorder else j % hd
        x = inpos[0] + 128 * (jhd % 32) + (piece % 2) * 64
        y = inpos[1] + d + 304 * (jhd // 32)
        pixel = x + y * hdim
        if use_rgba:
            img[4 * pixel:4 * pixel + 4] = (r, g, b, 255)
        else:
            img[pixel] = v
    if use_rgba:
        return img.reshape((totaldim // (4 * hdim), hdim, 4))
    return img.reshape((totaldim // hdim, hdim))
# --- Analysis script --------------------------------------------------------
# Load one NNUE net and print/plot feature-transformer (FT) and first FC
# layer (FC1) statistics. NETS_DIR is the author's local directory; point it
# (and filename) at your own net file.
NETS_DIR = "/home/dieter/Downloads/nets/"
# Other nets compared with this script:
#   "nn-62ef826d1a6d.nnue"  (master)
#   "nn-64fc1e0029b5.nnue"  (noob)
filename = "nn-ddbf15bd12bd.nnue"  # vdv
nnue = NNUEReader(NETS_DIR + filename)
print("Net {}".format(filename))

# Widen before abs/square: the stored weights are int16 (FT) / int8 (FC1),
# and integer abs/square wrap in place (e.g. int8 127 ** 2 == 1), which
# corrupted the RMS value.
fc1_weights = nnue.l1_weights.astype(np.float64)
print("mean(abs(FT weights)) = {}".format(
    np.mean(np.abs(nnue.input_weights.astype(np.int32)))))
print("mean(abs(FC1 weights)) = {}".format(np.mean(np.abs(fc1_weights))))
print("rms(FC1 weights) = {}".format(np.sqrt(np.mean(fc1_weights ** 2))))

# FT weight image: one 128x304 tile per FT neuron, separated by red lines.
img = get_image(nnue.input_weights, use_rgba=False)
plt.matshow(np.abs(img), vmin=0, vmax=64)
plt.colorbar()
hd = 256  # number of FT neurons (tiles); 32 tile columns x 8 tile rows
for i in range(hd // 8):
    plt.axvline(x=128 * i - 0.5, color='red')
for j in range(8):
    plt.axhline(y=304 * j - 0.5, color='red')
plt.xlim([0, 4096])
plt.ylim([8 * 304, 0])
plt.title("{} (weights FT)".format(filename))

plt.matshow(np.reshape(nnue.input_biases, (8, 32)))
plt.colorbar()
plt.title("{} (biases FT)".format(filename))

# Flag FT neurons whose tile has no image value above 10 as "dead" and
# print their stats. Tile (row j, column i) is neuron i + 32 * j.
# NOTE(review): this tests the signed maximum of the image, which holds
# *negated* weights, not max(abs(...)). A tile whose weights are all large
# and positive also passes. Confirm whether np.abs was intended here.
indices = []
for j in range(8):
    for i in range(hd // 8):
        s = img[304 * j:304 * (j + 1), 128 * i:128 * (i + 1)]
        if np.max(s) <= 10:
            print(j, i, np.min(s), np.max(s), np.sum(np.abs(s)),
                  nnue.input_biases[i + 32 * j])
            indices.append(i + 32 * j)

plt.figure()
plt.plot(nnue.input_biases)
plt.plot(indices, nnue.input_biases[indices], 'x', label='dead features')
plt.legend()
plt.title("{} (biases FT)".format(filename))

plt.figure()
plt.plot(nnue.l1_biases)

# Optional extra view: FC1 weights split into the two 256-wide input halves.
SHOW_FC1_HALVES = False
if SHOW_FC1_HALVES:
    for i in range(2):
        plt.matshow(np.abs(fc1_weights[:, 256 * i:256 * (i + 1)]),
                    vmin=0, vmax=32)
        plt.colorbar()
        plt.title("{} (weights FC1)".format(filename))

# FC1 weights viewed as a (64, 256) grid: each 512-wide row becomes two
# consecutive rows.
fc1_grid = np.abs(np.reshape(fc1_weights, (64, 256)))
plt.matshow(fc1_grid, vmin=0, vmax=32)
plt.colorbar()
plt.title("{} (weights FC1)".format(filename))
plt.figure()
plt.plot(np.mean(fc1_grid, axis=0))
plt.title("{} (weights FC1)".format(filename))
plt.show(block=False)
input()  # show(block=False) returns at once; wait for Enter before exiting
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment