Skip to content

Instantly share code, notes, and snippets.

@ddobbelaere
Last active December 30, 2020 11:33
Show Gist options
  • Save ddobbelaere/ad4e2645828b3fbd771249e37cfc2d0a to your computer and use it in GitHub Desktop.
Save ddobbelaere/ad4e2645828b3fbd771249e37cfc2d0a to your computer and use it in GitHub Desktop.
Analyze Stockfish NNUE weights
import operator
import os
import struct
import sys
from functools import reduce

import matplotlib.pyplot as plt
import numpy as np
class NNUEReader:
    """Parse the header, feature transformer and first FC layer of an NNUE file.

    Everything is read eagerly in the constructor, and the file is closed
    before the constructor returns. All fields are decoded as little-endian.
    The parsed tensors are stored as attributes:

    * ``input_biases``  -- np.int16, shape ``(hidden,)``
    * ``input_weights`` -- np.int16, shape ``(num_inputs, hidden)``
    * ``l1_biases``     -- np.int32, shape ``(l1_size,)``
    * ``l1_weights``    -- np.int8,  shape ``(l1_size, 2 * hidden)``
    * ``version``       -- the first int32 of the header
    * ``description``   -- raw bytes of the header's network description

    The size parameters default to the sizes this script was written for
    (41024 inputs, 256 hidden units, 32 FC1 outputs). Layers after FC1 are
    not read.
    """

    def __init__(self, filename, *, num_inputs=41024, hidden=256, l1_size=32):
        """Open *filename* and read header, feature transformer and FC1.

        Raises ValueError on a malformed or truncated file. Raises
        struct.error if a 32-bit header field is cut short.
        """
        # `with` releases the handle even if parsing fails part-way.
        with open(filename, 'rb') as self.f:
            self.read_header()
            self.read_int32()  # Feature transformer hash
            self.input_biases = self.tensor(np.int16, (hidden,))
            self.input_weights = self.tensor(np.int16, (num_inputs, hidden))
            self.read_int32()  # FC layers hash
            self.l1_biases = self.tensor(np.int32, (l1_size,))
            # FC1 consumes two hidden-sized halves, hence 2 * hidden inputs.
            self.l1_weights = self.tensor(np.int8, (l1_size, 2 * hidden))

    def read_header(self):
        """Read the file header and keep version and description bytes."""
        self.version = self.read_int32()  # version
        self.read_int32()  # halfkp network hash
        desc_len = self.read_int32()  # Network definition
        if desc_len < 0:
            # f.read(-1) would silently consume the rest of the file.
            raise ValueError("Negative description length: %d" % desc_len)
        self.description = self.f.read(desc_len)

    def tensor(self, dtype, shape):
        """Read a little-endian tensor of *dtype* and *shape* from ``self.f``.

        Works with any binary stream (not just OS files). Returns a writable
        native-endian array. Raises ValueError if the stream ends early.
        """
        dtype = np.dtype(dtype)
        count = reduce(operator.mul, shape, 1)
        nbytes = count * dtype.itemsize
        buf = self.f.read(nbytes)
        if len(buf) != nbytes:
            raise ValueError(
                "Truncated file: expected %d bytes, got %d" % (nbytes, len(buf)))
        # Decode as little-endian (like read_int32's "<i"), then convert to
        # the native dtype; astype also copies, so the result is writable.
        d = np.frombuffer(buf, dtype=dtype.newbyteorder('<'), count=count)
        return d.astype(dtype).reshape(shape)

    def read_int32(self, expected=None):
        """Read one little-endian int32.

        If *expected* is given, raise ValueError when the value differs.
        """
        v = struct.unpack("<i", self.f.read(4))[0]
        if expected is not None and v != expected:
            raise ValueError("Expected: %x, got %x" % (expected, v))
        return v
def _weight_to_rgb(v):
    """Map a doubled, negated weight *v* to an (r, g, b) triple in 0..255.

    The palette follows https://hxim.github.io/Stockfish-Evaluation-Guide/ .
    Each channel is piecewise linear in *v*. Raises ValueError when
    |v| >= 1020, which is outside the palette.
    """
    if 0 <= v < 1020:
        r = v if v <= 255 else (-v + 510 if v <= 510 else 0)
        g = v if v <= 255 else (255 if v <= 765 else -v + 1020)
        b = v if v <= 255 else (255 if v <= 510 else (-v + 765 if v <= 765 else 0))
    elif -1020 < v < 0:
        r = -v if v >= -255 else (255 if v >= -765 else v + 1020)
        g = -v if v >= -255 else (v + 510 if v >= -510 else 0)
        b = 0 if v >= -510 else (-v - 510 if v >= -765 else v + 1020)
    else:
        raise ValueError("error weight too big: {}".format(v))
    return r, g, b


def get_image(weights, reorder=False, use_rgba=True):
    """Arrange feature-transformer weights as one large image for plotting.

    *weights* is flattened and read as ``(64 * 641, 256)`` row-major. Flat
    index ``j`` belongs to FT neuron ``j % 256`` and feature ``j // 256``.
    Each of the 256 neurons gets its own 128 x 304 pixel tile, laid out as
    32 tiles per row and 8 rows. Inside a tile, the pixel position encodes
    the king square, the piece square and the piece index.

    Args:
        weights: integer array of FT weights. A prefix of the full
            ``64 * 641 * 256`` entries is accepted; missing pixels stay 0.
        reorder: stub. The neuron permutation it was meant to apply is not
            available here, so ``True`` draws every neuron into tile 0.
        use_rgba: if True, return an int16 RGBA image of shape
            ``(2564, 4096, 4)``. Otherwise return an int32 grayscale image
            of shape ``(2564, 4096)`` holding the negated weights.

    Raises:
        ValueError: in RGBA mode, if a drawn weight has magnitude >= 510.
    """
    weights = np.asarray(weights).ravel()
    hd = 256  # FT neurons, one tile each
    hdim = 4096  # image width: 32 tiles x 128 columns
    totaldim = hdim * ((hd * 64 * 641) // hdim)
    if use_rgba:
        totaldim *= 4
    # Grayscale uses int32 because negating an int16 weight can give 32768.
    img = np.zeros(totaldim, dtype=np.int16 if use_rgba else np.int32)
    for j in range(weights.size):
        # The -1 shift maps the first index of each 641-block to pi == 640,
        # which is skipped below.
        feature = j // hd - 1
        pi = feature % 641  # piece-square index within the king block
        ki = feature // 641  # king square
        piece = pi // 64
        rank = (pi % 64) // 8
        # Skip pi == 640 and pieces 0/1 (presumably pawns) on ranks 0 and 7.
        if pi == 640 or (rank in (0, 7) and piece in (0, 1)):
            continue
        # Python int: avoids int16 wrap-around when negating/doubling.
        v = -int(weights[j]) if ki != pi % 64 else 0
        if use_rgba:
            # Same logic/colors as https://hxim.github.io/Stockfish-Evaluation-Guide/
            v *= 2
            r, g, b = _weight_to_rgb(v)
        kipos = [ki % 8, ki // 8]
        pipos = [pi % 8, rank]
        inpos = [(7 - kipos[0]) + pipos[0] * 8,
                 kipos[1] + (7 - pipos[1]) * 8]
        # Vertical band for the piece index: pieces 0/1 only use six ranks
        # (rows 0..47); every later pair of pieces gets a 64-row band.
        d = -8 if piece < 2 else 48 + (piece // 2 - 1) * 64
        jhd = 0 if reorder else j % hd
        x = inpos[0] + 128 * (jhd % 32) + (piece % 2) * 64
        y = inpos[1] + d + 304 * (jhd // 32)
        ii = x + y * hdim
        if use_rgba:
            img[4 * ii:4 * ii + 4] = (r, g, b, 255)
        else:
            img[ii] = v
    if use_rgba:
        return img.reshape((totaldim // (4 * hdim), hdim, 4))
    return img.reshape((totaldim // hdim, hdim))
# Select the net: pass a path as the first command-line argument, or edit the
# default file name below.
# filename = "nn-62ef826d1a6d.nnue" # master
filename = "nn-ddbf15bd12bd.nnue"  # vdv
# filename = "nn-64fc1e0029b5.nnue" # noob
if len(sys.argv) > 1:
    net_path = sys.argv[1]
    filename = os.path.basename(net_path)
else:
    net_path = "/home/dieter/Downloads/nets/" + filename
nnue = NNUEReader(net_path)

# Widen before abs/square: int8/int16 arithmetic wraps (e.g. 12**2 in int8,
# abs(-128) in int8), which corrupted the statistics below.
ft_weights = nnue.input_weights.astype(np.int32)
l1_weights = nnue.l1_weights.astype(np.int32)
print("Net {}".format(filename))
print("mean(abs(FT weights)) = {}".format(np.mean(np.abs(ft_weights))))
print("mean(abs(FC1 weights)) = {}".format(np.mean(np.abs(l1_weights))))
print("rms(FC1 weights) = {}".format(np.sqrt(np.mean(l1_weights**2))))
del ft_weights  # large copy, no longer needed

# Feature-transformer weights: one 128 x 304 tile per FT neuron,
# 32 tiles per row, 8 rows (red grid lines mark the tiles).
hd = 256
img = get_image(nnue.input_weights, use_rgba=False)
plt.matshow(np.abs(img), vmin=0, vmax=64)
plt.colorbar()
for i in range(hd // 8):
    plt.axvline(x=128 * i - 0.5, color='red')
for j in range(8):
    plt.axhline(y=304 * j - 0.5, color='red')
plt.xlim([0, 4096])
plt.ylim([8 * 304, 0])
plt.title("{} (weights FT)".format(filename))

# FT biases in the same layout as the tiles: row j, column i -> i + 32 * j.
plt.matshow(np.reshape(nnue.input_biases, (8, 32)))
plt.colorbar()
plt.title("{} (biases FT)".format(filename))

# Candidate "dead" FT neurons, labelled as such in the bias plot below.
# NOTE(review): this tests the signed maximum of the plotted (negated)
# weights, i.e. min(weight) >= -10, not max(|weight|) <= 10 -- confirm
# which criterion is intended.
indices = []
for j in range(8):
    for i in range(hd // 8):
        s = img[304 * j:304 * (j + 1), 128 * i:128 * (i + 1)]
        if np.max(s) <= 10:
            print(j, i, np.min(s), np.max(s), np.sum(np.abs(s)),
                  nnue.input_biases[i + 32 * j])
            indices.append(i + 32 * j)
plt.figure()
plt.plot(nnue.input_biases)
plt.plot(indices, nnue.input_biases[indices], 'x', label='dead features')
plt.legend()
plt.title("{} (biases FT)".format(filename))

plt.figure()
plt.plot(nnue.l1_biases)
plt.title("{} (biases FC1)".format(filename))

# Set to True to additionally plot the two 256-column halves of FC1.
SHOW_FC1_HALVES = False
if SHOW_FC1_HALVES:
    for i in range(2):
        plt.matshow(
            np.abs(l1_weights[:, 256 * i:256 * (i + 1)]), vmin=0, vmax=32)
        plt.colorbar()
        plt.title("{} (weights FC1)".format(filename))
l1_rows = np.reshape(l1_weights, (64, 256))
plt.matshow(np.abs(l1_rows), vmin=0, vmax=32)
plt.colorbar()
plt.title("{} (weights FC1)".format(filename))
plt.figure()
plt.plot(np.mean(np.abs(l1_rows), axis=0))
plt.title("{} (weights FC1)".format(filename))

plt.show(block=False)
input()  # keep the non-blocking figure windows open until Enter
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment