Skip to content

Instantly share code, notes, and snippets.

@exceedsystem
Last active September 12, 2022 10:05
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save exceedsystem/9ce8e4f58fae34aaecb0d8794efb0623 to your computer and use it in GitHub Desktop.
Save exceedsystem/9ce8e4f58fae34aaecb0d8794efb0623 to your computer and use it in GitHub Desktop.
How to implement an XOR gate neural network with torch in .NET 6
// EXCEEDSYSTEM sample: an XOR gate neural network built with TorchSharp
// https://www.exceedsystem.net/2022/09/12/how-to-implement-xor-gate-neural-network-with-torch-in-dotnet-6
// License: MIT
using TorchSharp;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;
using static TorchSharp.torch.nn.functional;
// Seed the RNG so the initial weights and biases are the same on every run.
torch.random.manual_seed(1);

// XOR truth table.
// Row i of trainData is the input pair (xa, xb); row i of trainLabel is the
// expected output t for that pair.
float[,] trainData =
{
    { 0, 0 },   // 0 ⊻ 0
    { 1, 0 },   // 1 ⊻ 0
    { 0, 1 },   // 0 ⊻ 1
    { 1, 1 },   // 1 ⊻ 1
};
float[,] trainLabel =
{
    { 0 },
    { 1 },
    { 1 },
    { 0 },
};
// Hyperparameters: number of training passes and SGD step size.
const int epochs = 10000;
const double lr = 0.1;

// Neural network: Linear(2 -> 2) -> Sigmoid -> Linear(2 -> 1).
// Built after the seed above, so its parameter initialization is reproducible.
var nn = Sequential(
    Linear(2, 2),   // first linear transformation
    Sigmoid(),      // non-linear activation between the two linear layers
    Linear(2, 1)    // second linear transformation producing the single output
);
// Put the modules into training mode.
nn.train();

// Tensors holding the whole training set (inputs and labels).
var x = tensor(trainData);
var y = tensor(trainLabel);

// Loss: MSE with Reduction.Sum (summed over all samples rather than averaged).
var criterion = mse_loss(Reduction.Sum);
// Optimizer: plain SGD over every parameter of the network.
var optimizer = optim.SGD(nn.parameters(), lr);
// Training: every epoch runs one forward/backward pass over the full batch x.
for (var epoch = 1; epoch <= epochs; ++epoch)
{
    // Tensors created during this iteration (forward activations, the loss and
    // any temporaries) are registered with this scope. Their native memory is
    // released when the scope is disposed at the end of the iteration, instead
    // of piling up for 10 000 epochs until the GC finalizes them.
    // Parameters, x and y were created outside the scope and are unaffected.
    using var scope = torch.NewDisposeScope();

    // Compute predictions for all training samples
    var prediction = nn.forward(x);
    // Compute the MSE loss against the labels
    var loss = criterion(prediction, y);
    // Clear gradients left over from the previous iteration
    optimizer.zero_grad();
    // Back-propagate to compute parameter gradients
    loss.backward();
    // Apply one SGD update to the parameters
    optimizer.step();

    // Report progress every 100 epochs (read the loss before the scope disposes it)
    if (epoch % 100 == 0)
    {
        Console.WriteLine($"Epoch:{epoch} Loss:{loss.ToSingle()}");
    }
}
// Set the modules in evaluation mode
nn.eval();
// Make predictions using the trained model.
// eval() alone does not stop autograd from recording operations, so the
// predictions run under no_grad(): no backward pass follows, and no graph is
// needed. The dispose scope releases the input and output tensors created here
// when the block ends.
using (torch.no_grad())
using (torch.NewDisposeScope())
{
    // 0 ⊻ 0
    nn.forward(torch.tensor(new float[] { 0, 0 })).print();
    // 1 ⊻ 0
    nn.forward(torch.tensor(new float[] { 1, 0 })).print();
    // 0 ⊻ 1
    nn.forward(torch.tensor(new float[] { 0, 1 })).print();
    // 1 ⊻ 1
    nn.forward(torch.tensor(new float[] { 1, 1 })).print();
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment