Last active: May 3, 2018 14:06
-
-
Save NTT123/746e11da179674b410b20d03516caa94 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
### usage: python ELF2LeelaZ.py [checkpoint.bin] > weights.txt
# Convert an ELF OpenGo PyTorch checkpoint into a Leela Zero text weights file
# (written to stdout, one tensor per line).
import sys

import torch

# Checkpoint path: first CLI argument, defaulting to the original hard-coded name.
CHECKPOINT = sys.argv[1] if len(sys.argv) > 1 else "pretrained-go-19x19-v0.bin"

# NOTE(review): torch.load unpickles arbitrary Python objects -- only run this
# on checkpoints obtained from a trusted source.
# map_location="cpu" lets a checkpoint whose tensors were saved on a GPU load
# on a machine without CUDA.
m = torch.load(CHECKPOINT, map_location="cpu")
opt = m["options"]
num_block = opt['num_block']  # number of residual blocks to emit
dim = opt['dim']              # residual tower width (kept for reference)
bn = opt['bn']
leaky_relu = opt['leaky_relu']
wei = m["state_dict"]
# Per the original author's inspection, "pi_linear.weight" is 362 x 722.

# The output format only carries conv weights/biases plus batch-norm
# statistics, and Leela Zero evaluates it with plain ReLU. Refuse models it
# cannot express instead of silently writing a wrong network.
if not bn:
    sys.exit("ELF2LeelaZ: checkpoint was built without batch norm "
             "(options['bn'] is false); cannot convert")
if leaky_relu:
    sys.exit("ELF2LeelaZ: checkpoint uses leaky ReLU "
             "(options['leaky_relu'] is true); Leela Zero only supports ReLU")
def print_tensor(wei, name):
    """Print ``wei[name]`` flattened to one line of space-separated floats.

    Elements are emitted in row-major (C) order, as ``reshape(-1)`` yields them.

    Args:
        wei: state dict mapping parameter names to tensors.
        name: key of the tensor to print.
    """
    # reshape (unlike view) also accepts non-contiguous tensors, and a single
    # tolist() call avoids materialising a 0-d tensor per element.
    values = wei[name].reshape(-1).tolist()
    print(" ".join(map(str, values)))
def print_conv(wei, name, eps=1e-5):
    """Print a conv + batch-norm layer as four lines: weights, biases, means, variances.

    The checkpoint stores the convolution under ``<name>.0`` and its batch
    norm under ``<name>.1``. The four-line layout has no slot for the batch-norm
    scale (``.1.weight``, gamma) or shift (``.1.bias``, beta). They are
    therefore folded into the other tensors so that

        gamma * (W*x + b - mean) / sqrt(var + eps) + beta
            == (W'*x + b' - mean') / sqrt(var + eps)

    with W' = gamma*W, b' = gamma*b + beta*sqrt(var + eps), mean' = gamma*mean,
    and var unchanged. When the batch norm has no affine parameters, gamma=1
    and beta=0 are used, and the raw tensors are printed unchanged.

    Args:
        wei: state dict mapping parameter names to tensors.
        name: layer prefix, e.g. ``"init_conv"``.
        eps: batch-norm epsilon used in the fold. NOTE(review): 1e-5 is
            PyTorch's BatchNorm default and is assumed to match the epsilon of
            the engine loading these weights -- confirm.
    """
    # Do the arithmetic in float64 on the CPU so folding adds no rounding error.
    weight = wei[name + ".0.weight"].cpu().double()
    bias = wei[name + ".0.bias"].cpu().double()
    mean = wei[name + ".1.running_mean"].cpu().double()
    var = wei[name + ".1.running_var"].cpu().double()

    gamma_key, beta_key = name + ".1.weight", name + ".1.bias"
    gamma = wei[gamma_key].cpu().double() if gamma_key in wei else mean.new_ones(mean.size())
    beta = wei[beta_key].cpu().double() if beta_key in wei else mean.new_zeros(mean.size())

    # Broadcast each output channel's gamma over its whole filter.
    scale = gamma.reshape(-1, *([1] * (weight.dim() - 1)))
    folded = (
        weight * scale,
        gamma * bias + beta * (var + eps).sqrt(),
        gamma * mean,
        var,
    )
    for tensor in folded:
        print(" ".join(map(str, tensor.reshape(-1).tolist())))
def print_resnet(wei, idx):
    """Print residual block ``idx``: its ``conv_lower`` layer, then its ``conv_upper`` layer."""
    prefix = "resnet.module.resnet.%d." % idx
    for part in ("conv_lower", "conv_upper"):
        print_conv(wei, prefix + part)
# Emit the Leela Zero text weights file to stdout, one tensor per line.
#
# Header: format version. Version "2" marks a network whose value head scores
# the position for black rather than for the side to move. Leela Zero's loader
# documents this as the ELF OpenGo convention. Declaring "1" makes Leela Zero
# misread every evaluation when white is to move.
# NOTE(review): Leela Zero builds that predate ELF support only accept "1".
print("2")  # version
print_conv(wei, 'init_conv')
for i in range(num_block):
    print_resnet(wei, i)

# Policy head.
print_conv(wei, 'pi_final_conv')
print_tensor(wei, "pi_linear.weight")
print_tensor(wei, "pi_linear.bias")

# Value head: conv layer, then two fully connected layers (weight, bias each).
print_conv(wei, 'value_final_conv')
for layer in ("value_linear1", "value_linear2"):
    print_tensor(wei, layer + ".weight")
    print_tensor(wei, layer + ".bias")
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment
What are the issues?