Skip to content

Instantly share code, notes, and snippets.

@JoshVarty
Created December 19, 2018 20:30
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save JoshVarty/143aa35c0efc25d29d18ac523fbb597c to your computer and use it in GitHub Desktop.
Save JoshVarty/143aa35c0efc25d29d18ac523fbb597c to your computer and use it in GitHub Desktop.
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>
#include <torch/torch.h>
/// Builds a bias-free 3x3 convolution with padding 1.
/// @param inputChannels  channels of the incoming tensor
/// @param outputChannels channels produced by the convolution
/// @param stride         stride used by the convolution
/// @return a Conv2d module holder wrapping the new Conv2dImpl
torch::nn::Conv2d conv3x3(int64_t inputChannels, int64_t outputChannels, int64_t stride) {
    const auto convOptions =
        torch::nn::Conv2dOptions(inputChannels, outputChannels, /*kernel_size=*/3)
            .stride(stride)
            .padding(1)
            .with_bias(false);
    return torch::nn::Conv2d(std::make_shared<torch::nn::Conv2dImpl>(convOptions));
}
/// Builds a bias-free 1x1 convolution (no padding); used as a shortcut projection.
/// @param inputChannels  channels of the incoming tensor
/// @param outputChannels channels produced by the convolution
/// @param stride         stride used by the convolution
/// @return a Conv2d module holder wrapping the new Conv2dImpl
torch::nn::Conv2d conv1x1(int64_t inputChannels, int64_t outputChannels, int64_t stride) {
    const auto convOptions =
        torch::nn::Conv2dOptions(inputChannels, outputChannels, /*kernel_size=*/1)
            .stride(stride)
            .with_bias(false);
    return torch::nn::Conv2d(std::make_shared<torch::nn::Conv2dImpl>(convOptions));
}
struct BasicBlock : torch::nn::Module {
static const int64_t EXPANSION = 1;
torch::nn::Conv2d conv1;
torch::nn::BatchNorm bn1;
torch::nn::Conv2d conv2;
torch::nn::BatchNorm bn2;
torch::nn::Sequential downsample;
BasicBlock(int64_t inplanes, int64_t planes, int64_t stride, torch::nn::Sequential downsample)
: conv1(conv3x3(inplanes, planes, stride)),
bn1(torch::nn::BatchNorm(torch::nn::BatchNormOptions(planes))),
conv2(conv3x3(planes, planes, /*stride*/1)),
bn2(torch::nn::BatchNorm(torch::nn::BatchNormOptions(planes))),
downsample(downsample)
{
register_module("conv1", conv1);
register_module("bn1", bn1);
register_module("conv2", conv2);
register_module("bn2", bn2);
register_module("downsample", downsample);
}
torch::Tensor forward(torch::Tensor x) {
auto identity = x;
auto out = conv1->forward(x);
out = bn1->forward(out);
out = torch::relu(out);
out = conv2->forward(out);
out = bn2->forward(out);
if (this->downsample.get()->size() > 0) {
identity = this->downsample->forward(x);
}
out = out + identity;
out = torch::relu(out);
return out;
}
};
struct ResNet : torch::nn::Module {
int64_t inplanes = 64;
torch::nn::Conv2dOptions conv1Options;
torch::nn::Conv2d conv1;
torch::nn::BatchNorm bn1;
torch::nn::Sequential layer1;
torch::nn::Sequential layer2;
torch::nn::Sequential layer3;
torch::nn::Sequential layer4;
torch::nn::Linear fc1;
ResNet(int64_t inputDepth, int layers[])
:
conv1Options(torch::nn::Conv2dOptions(inputDepth, 64, /*kernel_size=*/7).stride(2).padding(3).with_bias(false)),
conv1(std::make_shared<torch::nn::Conv2dImpl>(conv1Options)),
bn1(std::make_shared<torch::nn::BatchNormImpl>(64)),
layer1(make_layer_basic(64, layers[0], /*stride=*/1)),
layer2(make_layer_basic(128, layers[1], /*stride=*/2)),
layer3(make_layer_basic(256, layers[2], /*stride=*/2)),
layer4(make_layer_basic(512, layers[3], /*stride=*/2)),
fc1(std::make_shared<torch::nn::LinearImpl>(512, 9))
{
register_module("conv1", conv1);
register_module("bn1", bn1);
register_module("layer1", layer1);
register_module("layer2", layer2);
register_module("layer3", layer3);
register_module("layer4", layer4);
register_module("fc1", fc1);
}
torch::nn::Sequential make_layer_basic(int64_t planes, int64_t blocks, int64_t stride) {
torch::nn::Sequential downsample = std::make_shared<torch::nn::SequentialImpl>();
if(stride != 1 || this->inplanes != planes * BasicBlock::EXPANSION) {
downsample = torch::nn::Sequential(
std::make_shared<torch::nn::SequentialImpl>(
conv1x1(this->inplanes, planes * BasicBlock::EXPANSION, stride),
torch::nn::BatchNorm(planes * BasicBlock::EXPANSION))
);
}
torch::nn::Sequential layers = std::make_shared<torch::nn::SequentialImpl>();
auto newBlock = std::make_shared<BasicBlock>(this->inplanes, planes, stride, downsample);
layers->push_back(newBlock);
this->inplanes = planes * BasicBlock::EXPANSION;
for(int64_t i = 0; i < blocks; i++) {
torch::nn::Sequential empty_downsample = std::make_shared<torch::nn::SequentialImpl>();
newBlock = std::make_shared<BasicBlock>(this->inplanes, planes, /*stride=*/1, empty_downsample);
layers->push_back(newBlock);
}
return layers;
}
torch::Tensor forward(torch::Tensor x) {
x = this->conv1->forward(x);
x = this->bn1->forward(x);
x = torch::relu(x);
x = torch::max_pool2d(x, /*kernel_size*/{3}, /*stride*/{2}, /*padding*/{1});
x = this->layer1->forward(x);
x = this->layer2->forward(x);
x = this->layer3->forward(x);
x = this->layer4->forward(x);
x = torch::adaptive_avg_pool2d(x, {1,1});
x = x.view({-1, 512});
auto logits = this->fc1->forward(x);
x = torch::softmax(logits, /*dim=*/1);
return x;
}
};
/// Configuration values for the MNIST training script.
struct Options {
    std::string data_root = "data";  // root directory handed to the MNIST dataset
    int32_t batch_size = 64;         // training batch size
    int32_t epochs = 10;             // number of train/test rounds
    double lr = 0.01;                // SGD learning rate
    double momentum = 0.5;           // SGD momentum
    bool no_cuda = false;            // true => stay on CPU even when CUDA is available
    int32_t seed = 1;                // intended random seed value
    int32_t test_batch_size = 1000;  // intended batch size for the test set
    int32_t log_interval = 10;       // log training loss every this many batches
};
template <typename DataLoader>
void train(
int32_t epoch,
const Options& options,
ResNet& model,
torch::Device device,
DataLoader& data_loader,
torch::optim::SGD& optimizer,
size_t dataset_size) {
model.train();
size_t batch_idx = 0;
for (auto& batch : data_loader) {
auto data = batch.data.to(device), targets = batch.target.to(device);
optimizer.zero_grad();
auto output = model.forward(data);
auto loss = torch::nll_loss(output, targets);
loss.backward();
optimizer.step();
if (batch_idx++ % options.log_interval == 0) {
std::cout << "Train Epoch: " << epoch << " ["
<< batch_idx * batch.data.size(0) << "/" << dataset_size
<< "]\tLoss: " << loss.template item<float>() << std::endl;
}
}
}
/// Evaluates `model` on every batch of `data_loader` with gradients disabled and
/// prints the summed NLL loss divided by `dataset_size`, plus the count of
/// correct argmax predictions out of `dataset_size`.
template <typename DataLoader>
void test(
    ResNet& model,
    torch::Device device,
    DataLoader& data_loader,
    size_t dataset_size) {
    torch::NoGradGuard no_grad;
    model.eval();

    double summed_loss = 0;
    int32_t hits = 0;
    for (const auto& batch : data_loader) {
        torch::Tensor inputs = batch.data.to(device);
        torch::Tensor labels = batch.target.to(device);
        torch::Tensor prediction = model.forward(inputs);

        summed_loss += torch::nll_loss(prediction, labels, /*weight=*/{}, Reduction::Sum)
                           .template item<float>();
        hits += prediction.argmax(1).eq(labels).sum().template item<int64_t>();
    }

    summed_loss /= dataset_size;
    std::cout << "Test set: Average loss: " << summed_loss
              << ", Accuracy: " << hits << "/" << dataset_size << std::endl;
}
/// Tensor transform that standardises its input as (input - mean) / stddev.
struct Normalize : public torch::data::transforms::TensorTransform<> {
    Normalize(float mean, float stddev)
        : mean_(torch::tensor(mean)), stddev_(torch::tensor(stddev)) {}

    /// Returns a new normalised tensor and leaves `input` untouched.
    /// Out-of-place on purpose: the dataset may hand out views of its own
    /// storage, in which case the previous in-place sub_/div_ rewrote the stored
    /// data and re-normalised it again on every epoch.
    torch::Tensor operator()(torch::Tensor input) override {
        return input.sub(mean_).div(stddev_);
    }

    torch::Tensor mean_, stddev_;
};
/// Trains a ResNet on MNIST with SGD and evaluates it on the test split after each epoch.
auto main(int argc, const char* argv[]) -> int {
    Options options;
    // Seed from the configuration instead of a hard-coded 0 so Options::seed takes effect.
    torch::manual_seed(options.seed);

    torch::DeviceType device_type;
    if (torch::cuda::is_available() && !options.no_cuda) {
        std::cout << "CUDA available! Training on GPU" << std::endl;
        device_type = torch::kCUDA;
    } else {
        std::cout << "Training on CPU" << std::endl;
        device_type = torch::kCPU;
    }
    torch::Device device(device_type);

    // Blocks per stage for the four residual stages; the network takes 1 input channel.
    int layers[4] = {2,2,2,2};
    auto model_shared = std::make_shared<ResNet>(1, layers);
    auto model = model_shared.get();
    model->to(device);

    auto train_dataset =
        torch::data::datasets::MNIST(
            options.data_root, torch::data::datasets::MNIST::Mode::kTrain)
            .map(Normalize(0.1307, 0.3081))
            .map(torch::data::transforms::Stack<>());
    // size() returns an optional; query it before the dataset is moved into the loader.
    const size_t train_dataset_size = train_dataset.size().value();
    auto train_loader = torch::data::make_data_loader(std::move(train_dataset), options.batch_size);

    auto test_dataset =
        torch::data::datasets::MNIST(
            options.data_root, torch::data::datasets::MNIST::Mode::kTest)
            .map(Normalize(0.1307, 0.3081))
            .map(torch::data::transforms::Stack<>());
    // Test metrics must be averaged over the test split's own size, not the training size.
    const size_t test_dataset_size = test_dataset.size().value();
    auto test_loader = torch::data::make_data_loader(std::move(test_dataset), options.test_batch_size);

    torch::optim::SGD optimizer(
        model->parameters(),
        torch::optim::SGDOptions(options.lr).momentum(options.momentum));

    // int32_t matches Options::epochs and train()'s epoch parameter (no sign mismatch/narrowing).
    for (int32_t epoch = 1; epoch <= options.epochs; ++epoch) {
        train(epoch, options, *model, device, *train_loader, optimizer, train_dataset_size);
        test(*model, device, *test_loader, test_dataset_size);
    }
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment