@Mk-Chan
Last active September 23, 2020 00:20
Runs fine with --release-fast; fails with --release-safe or a normal debug build (zig build-exe --release-fast nnue.zig).
const std = @import("std");
const assert = std.debug.assert;

pub const ActivationFunction = enum {
    ReLU,
    Identity,
};
pub fn Layer(
    comptime InputType_: type,
    comptime inputs_size_: usize,
    comptime OutputType_: type,
    comptime outputs_size_: usize,
    comptime activation_function: ActivationFunction,
) type {
    return struct {
        const SelfType = @This();
        const InputType: type = InputType_;
        const inputs_size: usize = inputs_size_;
        const OutputType: type = OutputType_;
        const outputs_size: usize = outputs_size_;

        weights: [outputs_size_][inputs_size_]SelfType.InputType = undefined,
        // Note: one bias per *input* rather than the conventional one per
        // output neuron; feedForward below adds it to every product term.
        biases: [inputs_size_]SelfType.InputType = undefined,

        pub fn feedForward(self: *const SelfType, inputs: [*]SelfType.InputType, outputs: [*]SelfType.OutputType) void {
            var neuron_index: usize = 0;
            while (neuron_index < outputs_size_) : (neuron_index += 1) {
                var input_index: usize = 0;
                var neuron_result: SelfType.OutputType = 0;
                while (input_index < inputs_size_) : (input_index += 1) {
                    // Both this expression and the accumulation below can
                    // overflow the narrow integer types; debug and
                    // --release-safe builds trap on that, --release-fast
                    // does not check.
                    var input_result: SelfType.InputType = self.weights[neuron_index][input_index] * inputs[input_index] + self.biases[input_index];
                    neuron_result += switch (activation_function) {
                        // The @intCast runs before the clamp to maxInt, so an
                        // out-of-range value trips the cast's safety check in
                        // checked builds rather than being saturated.
                        .ReLU => std.math.min(std.math.maxInt(SelfType.OutputType), @intCast(SelfType.OutputType, std.math.max(0, input_result))),
                        .Identity => std.math.max(std.math.minInt(SelfType.OutputType), std.math.min(std.math.maxInt(SelfType.OutputType), @intCast(SelfType.OutputType, input_result))),
                    };
                }
                outputs[neuron_index] = neuron_result;
            }
        }
    };
}
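
// A minimal usage sketch, not part of the original gist: a 2-input,
// 1-output Identity layer evaluated by hand. The weight and bias values
// are invented for illustration and kept small enough that no safety
// check fires. With this code's formula, each term contributes
// weights[o][i] * inputs[i] + biases[i].
test "Layer: tiny Identity layer" {
    const TinyLayer = Layer(i16, 2, i16, 1, .Identity);
    var layer = TinyLayer{
        .weights = [1][2]i16{[2]i16{ 2, 3 }},
        .biases = [2]i16{ 1, 1 },
    };
    var inputs = [2]i16{ 10, 20 };
    var outputs: [1]i16 = undefined;
    layer.feedForward(&inputs, &outputs);
    // (2 * 10 + 1) + (3 * 20 + 1) = 21 + 61 = 82
    assert(outputs[0] == 82);
}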
pub fn Network(comptime layer_list: anytype) type {
    return struct {
        const SelfType = @This();
        const InputType = @TypeOf(layer_list[0]).InputType;
        const inputs_size = @TypeOf(layer_list[0]).inputs_size;
        const OutputType = @TypeOf(layer_list[layer_list.len - 1]).OutputType;
        const outputs_size = @TypeOf(layer_list[layer_list.len - 1]).outputs_size;

        layers: @TypeOf(layer_list) = layer_list,

        pub fn feedForward(self: *const SelfType, inputs: [*]SelfType.InputType, outputs: [*]SelfType.OutputType) void {
            comptime assert(self.layers.len > 0);
            self.feedForwardHelper(0, inputs, outputs);
        }

        // Recurses through the layer tuple at compile time, allocating a
        // stack buffer for each intermediate layer's outputs and writing the
        // last layer's results directly into final_outputs.
        fn feedForwardHelper(self: *const SelfType, comptime layer_index: usize, layer_inputs: anytype, final_outputs: [*]SelfType.OutputType) void {
            const layer = &self.layers[layer_index];
            if (layer_index + 1 >= self.layers.len) {
                layer.feedForward(layer_inputs, final_outputs);
                return;
            }
            const LayerType = @TypeOf(layer.*);
            var layer_outputs: [LayerType.outputs_size]LayerType.OutputType = undefined;
            layer.feedForward(layer_inputs, &layer_outputs);
            self.feedForwardHelper(layer_index + 1, &layer_outputs, final_outputs);
        }
    };
}
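
// A hedged sketch, not from the original gist: two tiny Identity layers
// chained through Network. feedForwardHelper allocates the [2]i16 buffer
// between them on the stack. All values are invented for illustration.
test "Network: two chained layers" {
    const FirstLayer = Layer(i16, 1, i16, 2, .Identity);
    const SecondLayer = Layer(i16, 2, i16, 1, .Identity);
    const net = Network(.{
        FirstLayer{ .weights = [2][1]i16{ [1]i16{2}, [1]i16{3} }, .biases = [1]i16{0} },
        SecondLayer{ .weights = [1][2]i16{[2]i16{ 1, 1 }}, .biases = [2]i16{ 0, 0 } },
    }){};
    var inputs = [1]i16{5};
    var outputs: [1]i16 = undefined;
    net.feedForward(&inputs, &outputs);
    // First layer: (2*5, 3*5) = (10, 15); second layer: 10 + 15 = 25.
    assert(outputs[0] == 25);
}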
pub fn ParallelNetworkGroup(comptime network_list: anytype) type {
    return struct {
        const SelfType = @This();
        const InputType = @TypeOf(network_list[0]).InputType;
        const inputs_size = comptime calculateInputsSize();
        const OutputType = @TypeOf(network_list[0]).OutputType;
        const outputs_size = comptime calculateOutputsSize();

        networks: @TypeOf(network_list) = network_list,

        // Feeds each member network a disjoint slice of the concatenated
        // input buffer and writes into a disjoint slice of the output
        // buffer, advancing both offsets at compile time.
        pub fn feedForward(self: *const SelfType, inputs: [*]SelfType.InputType, outputs: [*]SelfType.OutputType) void {
            comptime assert(self.networks.len > 0);
            comptime var inputs_index = 0;
            comptime var outputs_index = 0;
            comptime var network_index: usize = 0;
            inline while (network_index < network_list.len) : (network_index += 1) {
                const network = &self.networks[network_index];
                comptime assert(@TypeOf(network.*).InputType == SelfType.InputType);
                comptime assert(@TypeOf(network.*).OutputType == SelfType.OutputType);
                network.feedForward(inputs + inputs_index, outputs + outputs_index);
                inputs_index += @TypeOf(network.layers[0]).inputs_size;
                outputs_index += @TypeOf(network.layers[network.layers.len - 1]).outputs_size;
            }
        }

        fn calculateInputsSize() usize {
            var inputs_size_counter: usize = 0;
            comptime var network_index = 0;
            inline while (network_index < network_list.len) : (network_index += 1) {
                inputs_size_counter += @TypeOf(network_list[network_index]).inputs_size;
            }
            return inputs_size_counter;
        }

        fn calculateOutputsSize() usize {
            var outputs_size_counter: usize = 0;
            comptime var network_index: usize = 0;
            inline while (network_index < network_list.len) : (network_index += 1) {
                outputs_size_counter += @TypeOf(network_list[network_index]).outputs_size;
            }
            return outputs_size_counter;
        }
    };
}
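
// A sketch, not from the original gist, of the parallel group's slicing:
// two single-layer networks of 1 input -> 1 output each, so the group sees
// a concatenated [2]-element input and writes a concatenated [2]-element
// output. Weight values are invented for illustration.
test "ParallelNetworkGroup: concatenated inputs and outputs" {
    const Tiny = Layer(i16, 1, i16, 1, .Identity);
    const net_a = Network(.{Tiny{ .weights = [1][1]i16{[1]i16{2}}, .biases = [1]i16{0} }}){};
    const net_b = Network(.{Tiny{ .weights = [1][1]i16{[1]i16{3}}, .biases = [1]i16{0} }}){};
    const group = ParallelNetworkGroup(.{ net_a, net_b }){};
    var inputs = [2]i16{ 10, 10 };
    var outputs: [2]i16 = undefined;
    group.feedForward(&inputs, &outputs);
    assert(outputs[0] == 20); // net_a reads inputs[0]: 2 * 10
    assert(outputs[1] == 30); // net_b reads inputs[1]: 3 * 10
}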
pub fn SerialNetworkGroup(comptime network_list: anytype) type {
    return struct {
        const SelfType = @This();
        const FirstNetworkType = @TypeOf(network_list[0]);
        const InputType = FirstNetworkType.InputType;
        const inputs_size = FirstNetworkType.inputs_size;
        const LastNetworkType = @TypeOf(network_list[network_list.len - 1]);
        const OutputType = LastNetworkType.OutputType;
        const outputs_size = LastNetworkType.outputs_size;

        networks: @TypeOf(network_list) = network_list,

        pub fn feedForward(self: *const SelfType, inputs: [*]SelfType.InputType, outputs: [*]SelfType.OutputType) void {
            comptime assert(self.networks.len > 0);
            self.feedForwardHelper(0, inputs, outputs);
        }

        fn feedForwardHelper(self: *const SelfType, comptime network_index: usize, network_inputs: anytype, final_outputs: [*]SelfType.OutputType) void {
            const network = &self.networks[network_index];
            if (network_index + 1 >= self.networks.len) {
                network.feedForward(network_inputs, final_outputs);
                return;
            }
            const NetworkType = @TypeOf(network.*);
            var network_outputs: [NetworkType.outputs_size]NetworkType.OutputType = undefined;
            network.feedForward(network_inputs, &network_outputs);
            self.feedForwardHelper(network_index + 1, &network_outputs, final_outputs);
        }
    };
}
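
// A sketch, not from the original gist: two single-layer networks chained
// serially, so the first network's [1]-element output buffer feeds the
// second. Values are invented for illustration.
test "SerialNetworkGroup: chained networks" {
    const Tiny = Layer(i16, 1, i16, 1, .Identity);
    const first = Network(.{Tiny{ .weights = [1][1]i16{[1]i16{2}}, .biases = [1]i16{0} }}){};
    const second = Network(.{Tiny{ .weights = [1][1]i16{[1]i16{5}}, .biases = [1]i16{0} }}){};
    const group = SerialNetworkGroup(.{ first, second }){};
    var inputs = [1]i16{3};
    var outputs: [1]i16 = undefined;
    group.feedForward(&inputs, &outputs);
    assert(outputs[0] == 30); // 3 * 2 = 6, then 6 * 5 = 30
}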
pub fn main() void {
    const possible_king_squares = 64;
    const possible_non_king_piece_color_squares = 5 * 2 * 64; // No +1 for the captured piece from the Shogi NNUE implementation
    const halfkp_size = possible_king_squares * possible_non_king_piece_color_squares;

    const WhiteInputLayer = Layer(i16, halfkp_size, i16, halfkp_size, .Identity);
    const WhiteAffineLayer = Layer(i16, halfkp_size, i8, 256, .Identity);
    const white_input_network = Network(.{ WhiteInputLayer{}, WhiteAffineLayer{} }){};

    const BlackInputLayer = Layer(i16, halfkp_size, i16, halfkp_size, .Identity);
    const BlackAffineLayer = Layer(i16, halfkp_size, i8, 256, .Identity);
    const black_input_network = Network(.{ BlackInputLayer{}, BlackAffineLayer{} }){};

    const board_input_network = ParallelNetworkGroup(.{ white_input_network, black_input_network }){};

    const HiddenLayer1 = Layer(i8, 2 * 256, i8, 32 * 32, .ReLU);
    const HiddenLayer2 = Layer(i8, 32 * 32, i8, 32, .ReLU);
    const OutputLayer = Layer(i8, 32, i16, 1, .Identity);
    const evaluation_hidden_network = Network(.{ HiddenLayer1{}, HiddenLayer2{}, OutputLayer{} }){};

    const halfkp_2x256_32_32_network = SerialNetworkGroup(.{ board_input_network, evaluation_hidden_network }){};
    // Two halfkp-sized blocks: the parallel group feeds inputs[0..halfkp_size]
    // to the white network and the next halfkp_size elements to the black one.
    var inputs = [_]i16{0} ** (2 * halfkp_size);
    var outputs: [1]i16 = undefined;
    halfkp_2x256_32_32_network.feedForward(&inputs, &outputs);
    std.io.getStdOut().writer().print("Output: {}\n", .{outputs[0]}) catch {};
}
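
Why the build mode matters (an inference, not stated in the gist): every layer's weights and biases are left undefined, so feedForward multiplies and accumulates garbage values. Debug and --release-safe builds enable runtime safety checks, and the narrow-integer additions and @intCast calls in Layer.feedForward are likely to trap on integer overflow or an out-of-range cast almost immediately; --release-fast omits those checks, so the same arithmetic goes unchecked and the program runs to completion.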