Last active
September 12, 2022 10:05
-
-
Save exceedsystem/9ce8e4f58fae34aaecb0d8794efb0623 to your computer and use it in GitHub Desktop.
How to implement an XOR gate neural network with torch in .NET 6
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// EXCEEDSYSTEM sample: an XOR gate neural network built with TorchSharp
// https://www.exceedsystem.net/2022/09/12/how-to-implement-xor-gate-neural-network-with-torch-in-dotnet-6
// License: MIT
using TorchSharp;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;
using static TorchSharp.torch.nn.functional;

// Fix the random seed so the initial weights and biases (and therefore the
// whole training run) are reproducible from one execution to the next.
torch.random.manual_seed(1);

// Training inputs (xa, xb): every combination of two binary inputs.
// A 4x2 array, i.e. one row per sample.
var trainData = new float[,]
{
    { 0, 0 },
    { 1, 0 },
    { 0, 1 },
    { 1, 1 },
};

// Training targets (t = xa XOR xb), aligned row-by-row with trainData.
// A 4x1 array: one single-valued target per sample.
var trainLabel = new float[,]
{
    { 0 },
    { 1 },
    { 1 },
    { 0 },
};
// Create a neural network: 2 inputs -> 2 hidden units (sigmoid) -> 1 output.
// XOR is not linearly separable, so the non-linear hidden layer is required.
// NOTE(review): the local name `nn` shadows the `torch.nn` class imported via
// `using static TorchSharp.torch`; consider renaming it (e.g. `model`) file-wide.
var nn = Sequential(
    // Linear transformation 1: (xa, xb) -> 2 hidden pre-activations
    Linear(2, 2),
    // Non-linear transformation applied element-wise to the hidden layer
    Sigmoid(),
    // Linear transformation 2: 2 hidden activations -> 1 output
    Linear(2, 1)
);

// Put the modules in training mode before optimizing them.
nn.train();
// Create tensors for the training inputs and targets
// (built from the 4x2 and 4x1 arrays above).
var x = tensor(trainData);
var y = tensor(trainLabel);

// Number of training epochs; each epoch is one pass over all four samples.
const int epochs = 10000;

// Learning rate for the SGD optimizer.
const double lr = 0.1;

// Loss function: squared error, summed (not averaged) over the batch.
// NOTE(review): this `mse_loss(Reduction)` overload returning a loss delegate is
// from older TorchSharp releases; newer ones expose `MSELoss(Reduction.Sum)` instead — confirm
// against the package version in use.
var criterion = mse_loss(Reduction.Sum);

// Optimizer: stochastic gradient descent over all of the network's parameters.
var optimizer = optim.SGD(nn.parameters(), lr);
// Training: every epoch feeds the whole 4-sample batch through the network.
for (var epoch = 1; epoch <= epochs; ++epoch)
{
    // Forward pass: predictions for every training sample.
    // `using` frees the tensor's native storage at the end of each iteration
    // instead of leaving thousands of them for the garbage collector/finalizer.
    // (Intermediate tensors created inside forward() are not covered; newer
    // TorchSharp releases offer torch.NewDisposeScope() for that.)
    using var eval = nn.forward(x);

    // Summed MSE loss between predictions and targets.
    using var loss = criterion(eval, y);

    // Clear gradients left over from the previous iteration (they accumulate).
    optimizer.zero_grad();

    // Back-propagate to compute gradients for every parameter.
    loss.backward();

    // Apply one SGD update to the parameters.
    optimizer.step();

    // Report progress every 100 epochs.
    if (epoch % 100 == 0)
    {
        Console.WriteLine($"Epoch:{epoch} Loss:{loss.ToSingle()}");
    }
}
// Set the modules in evaluation mode.
nn.eval();

// Inputs to evaluate with the trained model, in the same order as training.
float[][] xorInputs =
{
    new float[] { 0, 0 }, // 0 ⊻ 0
    new float[] { 1, 0 }, // 1 ⊻ 0
    new float[] { 0, 1 }, // 0 ⊻ 1
    new float[] { 1, 1 }, // 1 ⊻ 1
};

// Make predictions using the trained model. Gradient tracking is disabled
// because no backward pass follows. Each tensor is disposed once printed.
using (torch.no_grad())
{
    foreach (var pair in xorInputs)
    {
        using var inputTensor = torch.tensor(pair);
        using var prediction = nn.forward(inputTensor);
        prediction.print();
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment