Skip to content

Instantly share code, notes, and snippets.

@springkim
Created March 10, 2020 08:14
Show Gist options
  • Save springkim/fb1ec7cda06fc85e2c414192f96b8bb7 to your computer and use it in GitHub Desktop.
Save springkim/fb1ec7cda06fc85e2c414192f96b8bb7 to your computer and use it in GitHub Desktop.
CIFAR10 Reader for libtorch 1.4.0
#ifndef LIBTORCH_CIFAR10_H
#define LIBTORCH_CIFAR10_H
#include <cstdint>
#include <fstream>
#include <string>
#include <utility>
#include <vector>

#include <torch/torch.h>
namespace {
constexpr uint32_t kTrainSize = 50000;
constexpr uint32_t kTestSize = 10000;
const std::vector<std::string> kTrainImagesFilename = { "data_batch_1.bin","data_batch_2.bin","data_batch_3.bin","data_batch_4.bin","data_batch_5.bin" };
const std::vector<std::string> kTestImagesFilename = {"test_batch.bin"};
constexpr uint32_t kImageRows = 32;
constexpr uint32_t kImageColumns = 32;
std::vector<std::string> join_paths(std::string head, const std::vector<std::string>& tail) {
if (head.back() != '/') {
head.push_back('/');
}
std::vector<std::string> ret;
for (auto&e : tail) {
ret.push_back(head + e);
}
return ret;
}
std::pair<torch::Tensor,torch::Tensor> read_images(const std::string& root, bool train) {
const auto paths = join_paths(root, train ? kTrainImagesFilename : kTestImagesFilename);
const int count = 10000 * paths.size();
auto images = torch::empty({ count,3,kImageRows,kImageColumns }, torch::kByte);
auto targets = torch::empty({ count }, torch::kByte);
int fcnt = 0;
const int data_size = 3 * kImageRows*kImageColumns;
for (auto&path : paths) {
std::ifstream datas(path, std::ios::binary);
TORCH_CHECK(datas, "Error opening images file at ", path);
for (int i = 0; i < 10000; i++) {
datas >> *(reinterpret_cast<char*>(targets.data_ptr()) + (fcnt * 10000 + i));
datas.read(reinterpret_cast<char*>(images.data_ptr()) + (data_size*i + fcnt*10000* data_size), data_size);
}
fcnt++;
}
return std::make_pair(images.to(torch::kFloat32).div_(255), targets.to(torch::kInt64));
}
}
/// torch::data::Dataset over the CIFAR-10 binary batch files.
///
/// The whole split is loaded eagerly in the constructor via read_images().
/// images() then has shape {N, 3, kImageRows, kImageColumns}, and targets()
/// has shape {N}.
class CIFAR10 : public torch::data::Dataset<CIFAR10> {
public:
    /// Selects which split is loaded.
    enum class Mode { kTrain, kTest };

    /// Loads the split selected by `mode` from the batch files under `root`.
    /// @param root Directory containing the CIFAR-10 .bin batch files.
    /// @param mode Mode::kTrain (default) or Mode::kTest.
    explicit CIFAR10(const std::string& root, Mode mode = Mode::kTrain) {
        const bool want_train = (mode == Mode::kTrain);
        const auto split = read_images(root, want_train);
        images_ = split.first;
        targets_ = split.second;
    }

    /// Returns the (image, target) pair stored at row `index`.
    torch::data::Example<> get(size_t index) override {
        torch::data::Example<> example{ images_[index], targets_[index] };
        return example;
    }

    /// Number of records held by this dataset (first dimension of images()).
    c10::optional<size_t> size() const override {
        const auto rows = images_.size(0);
        return rows;
    }

    /// True when the number of loaded records equals kTrainSize.
    /// This is inferred from the data size, not from the constructor's mode.
    bool is_train() const noexcept {
        const auto rows = images_.size(0);
        return rows == kTrainSize;
    }

    /// All images, as produced by read_images().
    const torch::Tensor& images() const {
        return images_;
    }

    /// All targets, as produced by read_images().
    const torch::Tensor& targets() const {
        return targets_;
    }

private:
    torch::Tensor images_;
    torch::Tensor targets_;
};
#endif
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment