@goldsborough
Created October 2, 2018 18:10
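
```cpp
#include <torch/torch.h>

// Fragment: assumes `num_features` (the input width) and `dataset` (a
// torch::data dataset yielding torch::data::Example<> pairs) are defined.
torch::nn::Linear model(num_features, 1);
// SGD requires options; the learning rate 0.01 is an assumed value.
torch::optim::SGD optimizer(model->parameters(),
                            torch::optim::SGDOptions(/*lr=*/0.01));
// Stack<> collates individual examples into one batched Example<>.
auto data_loader = torch::data::make_data_loader(
    dataset.map(torch::data::transforms::Stack<>()));
for (size_t epoch = 0; epoch < 10; ++epoch) {
  for (auto& batch : *data_loader) {
    optimizer.zero_grad();  // clear gradients from the previous step
    auto prediction = model->forward(batch.data);
    // Mean squared error stands in for the unspecified `loss_function`.
    auto loss = torch::mse_loss(prediction, batch.target);
    loss.backward();
    optimizer.step();
  }
}
```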
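The training loop needs a concrete `dataset` and autograd before it can run. The block below is a dependency-free sketch in plain C++. It is not libtorch and not the author's code. It spells out what each call in the loop computes for a one-output linear model trained with mean squared error: `zero_grad`, `forward`, the loss, a hand-derived `backward`, and a plain SGD `step`. `LinearModel`, `Example`, `train`, and `make_dataset` are hypothetical names, and the noise-free data (from y = 2·x0 - 3·x1 + 1) are invented for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <initializer_list>
#include <vector>

// Hypothetical stand-in for torch::nn::Linear(num_features, 1): y = w.x + b,
// with gradient buffers that play the role autograd fills during backward().
struct LinearModel {
  std::vector<double> w, grad_w;
  double b = 0.0, grad_b = 0.0;

  explicit LinearModel(std::size_t num_features)
      : w(num_features, 0.0), grad_w(num_features, 0.0) {}

  double forward(const std::vector<double>& x) const {
    assert(x.size() == w.size());
    double y = b;
    for (std::size_t i = 0; i < w.size(); ++i) y += w[i] * x[i];
    return y;
  }

  void zero_grad() {
    std::fill(grad_w.begin(), grad_w.end(), 0.0);
    grad_b = 0.0;
  }
};

// One (input, target) pair, analogous to torch::data::Example<>.
struct Example {
  std::vector<double> data;
  double target;
};

// Mini-batch SGD with mean squared error; returns the loss of the final batch.
double train(LinearModel& model, const std::vector<Example>& dataset,
             std::size_t batch_size, double lr, std::size_t epochs) {
  assert(batch_size > 0);  // a zero batch size would never advance
  double loss = 0.0;
  for (std::size_t epoch = 0; epoch < epochs; ++epoch) {
    for (std::size_t start = 0; start < dataset.size(); start += batch_size) {
      const std::size_t end = std::min(start + batch_size, dataset.size());
      const double n = static_cast<double>(end - start);
      model.zero_grad();  // optimizer.zero_grad()
      loss = 0.0;
      for (std::size_t k = start; k < end; ++k) {
        const Example& ex = dataset[k];
        const double err = model.forward(ex.data) - ex.target;  // forward
        loss += err * err / n;                                   // MSE loss
        // backward(): d(err^2 / n)/dw_i = 2 * err * x_i / n, d/db = 2 * err / n
        for (std::size_t i = 0; i < model.w.size(); ++i)
          model.grad_w[i] += 2.0 * err * ex.data[i] / n;
        model.grad_b += 2.0 * err / n;
      }
      // optimizer.step(): plain SGD, param -= lr * grad
      for (std::size_t i = 0; i < model.w.size(); ++i)
        model.w[i] -= lr * model.grad_w[i];
      model.b -= lr * model.grad_b;
    }
  }
  return loss;
}

// Hypothetical noise-free data from y = 2*x0 - 3*x1 + 1: the four corners
// (+-s, +-s) for s in {0.5, 1.0, 1.5}, stored scale by scale.
std::vector<Example> make_dataset() {
  std::vector<Example> out;
  for (double s : {0.5, 1.0, 1.5})
    for (double sx : {-1.0, 1.0})
      for (double sy : {-1.0, 1.0}) {
        const double x0 = s * sx, x1 = s * sy;
        out.push_back({{x0, x1}, 2.0 * x0 - 3.0 * x1 + 1.0});
      }
  return out;
}
```

For example, `LinearModel model(2); train(model, make_dataset(), /*batch_size=*/4, /*lr=*/0.1, /*epochs=*/100);` should bring `w` close to (2, -3) and `b` close to 1, because the data are exact. A batch size of 4 keeps each batch to one scale's four corners, so within every batch the feature means and the x0·x1 cross term are zero. The hand-written gradients are what `loss.backward()` computes automatically in libtorch.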