#include <torch/torch.h>
#include "mnist_reader.h"
#include <fstream>
#include <iomanip>
#include <iostream>
#include <string>
#include <vector>
using namespace torch;
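// DCGAN-style weight initialization, applied to every submodule below:
// convolution weights are drawn from N(0, 0.02), batch-norm weights from
// N(1, 0.02) with zero bias. Dispatch is by substring match on the module's
// type name.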
void initialize_weights(nn::Module& module) {
  if (module.name().find("Conv2d") != std::string::npos) {
    module.parameters()["weight"].data().normal_(0.0, 0.02);
  } else if (module.name().find("BatchNorm") != std::string::npos) {
    auto parameters = module.parameters();
    parameters["weight"].data().normal_(1.0, 0.02);
    parameters["bias"].data().fill_(0);
  }
}
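// Writes a tensor as a single row of comma-separated floats to out.csv,
// overwriting the file on every call. Used below to dump generator samples.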
void store_csv_tensor(Tensor tensor) {
  std::ofstream out("out.csv");
  auto flat = tensor.flatten();
  for (int64_t i = 0; i < flat.numel(); ++i) {
    out << flat[i].toCFloat() << ",";
  }
}
auto main() -> int {
  const int64_t kNoiseSize = 100;
  const int64_t kNumberOfEpochs = 30;
  const int64_t kBatchSize = 60;
  const int64_t kSampleEvery = 100;
  const at::Device device(at::kCUDA, 0);
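  // Generator: four transposed convolutions map a kNoiseSize x 1 x 1 noise
  // vector up to a single-channel 28x28 image (1x1 -> 4x4 -> 7x7 -> 14x14 ->
  // 28x28), with batch norm and ReLU between layers and tanh at the output.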
  nn::Sequential generator(
      // Layer 1
      nn::Conv2d(nn::Conv2dOptions(kNoiseSize, 256, 4).with_bias(false).transposed(true)),
      nn::BatchNorm(256),
      nn::Functional(at::relu),
      // Layer 2
      nn::Conv2d(nn::Conv2dOptions(256, 128, 3).stride(2).padding(1).with_bias(false).transposed(true)),
      nn::BatchNorm(128),
      nn::Functional(at::relu),
      // Layer 3
      nn::Conv2d(nn::Conv2dOptions(128, 64, 4).stride(2).padding(1).with_bias(false).transposed(true)),
      nn::BatchNorm(64),
      nn::Functional(at::relu),
      // Layer 4
      nn::Conv2d(nn::Conv2dOptions(64, 1, 4).stride(2).padding(1).with_bias(false).transposed(true)),
      nn::Functional(at::tanh));
  generator.to(device);
  generator.modules().apply(initialize_weights);
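  // Discriminator: a mirror-image stack of strided convolutions with leaky
  // ReLUs (28x28 -> 14x14 -> 7x7 -> 3x3 -> 1x1), ending in a sigmoid that
  // scores each image as real or fake.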
  nn::Sequential discriminator(
      // Layer 1
      nn::Conv2d(nn::Conv2dOptions(1, 64, 4).stride(2).padding(1).with_bias(false)),
      nn::Functional(at::leaky_relu, 0.2),
      // Layer 2
      nn::Conv2d(nn::Conv2dOptions(64, 128, 4).stride(2).padding(1).with_bias(false)),
      nn::BatchNorm(128),
      nn::Functional(at::leaky_relu, 0.2),
      // Layer 3
      nn::Conv2d(nn::Conv2dOptions(128, 256, 4).stride(2).padding(1).with_bias(false)),
      nn::BatchNorm(256),
      nn::Functional(at::leaky_relu, 0.2),
      // Layer 4
      nn::Conv2d(nn::Conv2dOptions(256, 1, 3).stride(1).padding(0).with_bias(false)),
      nn::Functional(at::sigmoid));
  discriminator.to(device);
  discriminator.modules().apply(initialize_weights);
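  // Both networks use Adam with beta1 = 0.5, as in the original DCGAN setup;
  // the discriminator is given a higher learning rate here.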
  optim::Adam generator_optimizer(
      generator.parameters(), optim::AdamOptions(2e-4).beta1(0.5));
  optim::Adam discriminator_optimizer(
      discriminator.parameters(), optim::AdamOptions(5e-4).beta1(0.5));
  // Load the raw MNIST training images, rescale them to [-1, 1] (assuming the
  // reader returns pixel values in [0, 1]) to match the generator's tanh
  // output, and split them into batches of kBatchSize.
  auto examples = read_mnist_examples("test/cpp/api/mnist/train-images-idx3-ubyte");
  examples = (examples * 2) - 1;
  examples = examples.reshape({examples.size(0) / kBatchSize, kBatchSize, 1, 28, 28});
  examples = examples.to(device);
  // Fixed noise vectors, so samples written during training are comparable
  // across iterations.
  const auto fixed_noise = torch::randn({kBatchSize, kNoiseSize, 1, 1}, device);
  std::cout << std::setprecision(4) << "\n";
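  // Standard GAN training loop: for each batch, first update the
  // discriminator on real and on detached fake images, then update the
  // generator so that its fakes are scored as real.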
  for (size_t epoch = 0; epoch < kNumberOfEpochs; ++epoch) {
    for (size_t i = 0; i < examples.size(0); ++i) {
      // Train discriminator with real images. Real labels are smoothed into
      // [0.8, 1.0] rather than being hard ones.
      discriminator.zero_grad();
      torch::Tensor real_images = examples[i].to(device);
      auto real_labels = torch::empty(kBatchSize, device).uniform_(0.8, 1.0);
      auto real_output = discriminator.forward(real_images);
      auto d_loss_real = at::binary_cross_entropy(real_output, real_labels);
      d_loss_real.backward();
      // Train discriminator with fake images. The fakes are detached so this
      // backward pass does not reach into the generator.
      auto noise = torch::randn({kBatchSize, kNoiseSize, 1, 1}, device);
      torch::Tensor fake_images = generator.forward(noise);
      auto fake_labels = torch::zeros(kBatchSize, device);
      auto fake_output = discriminator.forward(torch::Tensor(fake_images.detach()));
      auto d_loss_fake = at::binary_cross_entropy(fake_output, fake_labels);
      d_loss_fake.backward();
      auto d_loss = d_loss_real + d_loss_fake;
      discriminator_optimizer.step();
      // Train generator: push the discriminator's score on the fakes towards
      // "real" by using ones as the target labels.
      generator.zero_grad();
      fake_labels.fill_(1);
      fake_output = discriminator.forward(fake_images);
      auto g_loss = at::binary_cross_entropy(fake_output, fake_labels);
      g_loss.backward();
      generator_optimizer.step();
      std::cout << "\r[" << epoch << "/" << kNumberOfEpochs << "][" << i << "/"
                << examples.size(0) << "] D_loss: " << d_loss.toCFloat()
                << " | G_loss: " << g_loss.toCFloat() << std::flush;
      // Periodically write one generated image (rescaled back to [0, 1]) to
      // out.csv for inspection.
      if (i % kSampleEvery == 0) {
        auto fake_images = generator.forward(fixed_noise);
        auto image = (fake_images[0] + 1) / 2;
        store_csv_tensor(image);
        // std::cout << "\nWrote tensor to out.csv" << std::endl;
      }
    }
  }
  std::cout << std::endl;
}