Created
March 10, 2020 08:14
-
-
Save springkim/fb1ec7cda06fc85e2c414192f96b8bb7 to your computer and use it in GitHub Desktop.
CIFAR-10 reader for libtorch 1.4.0
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#ifndef LIBTORCH_CIFAR10_H
#define LIBTORCH_CIFAR10_H
#include <cstdint>
#include <fstream>
#include <string>
#include <utility>
#include <vector>

#include <torch/torch.h>
namespace { | |
constexpr uint32_t kTrainSize = 50000; | |
constexpr uint32_t kTestSize = 10000; | |
const std::vector<std::string> kTrainImagesFilename = { "data_batch_1.bin","data_batch_2.bin","data_batch_3.bin","data_batch_4.bin","data_batch_5.bin" }; | |
const std::vector<std::string> kTestImagesFilename = {"test_batch.bin"}; | |
constexpr uint32_t kImageRows = 32; | |
constexpr uint32_t kImageColumns = 32; | |
std::vector<std::string> join_paths(std::string head, const std::vector<std::string>& tail) { | |
if (head.back() != '/') { | |
head.push_back('/'); | |
} | |
std::vector<std::string> ret; | |
for (auto&e : tail) { | |
ret.push_back(head + e); | |
} | |
return ret; | |
} | |
std::pair<torch::Tensor,torch::Tensor> read_images(const std::string& root, bool train) { | |
const auto paths = join_paths(root, train ? kTrainImagesFilename : kTestImagesFilename); | |
const int count = 10000 * paths.size(); | |
auto images = torch::empty({ count,3,kImageRows,kImageColumns }, torch::kByte); | |
auto targets = torch::empty({ count }, torch::kByte); | |
int fcnt = 0; | |
const int data_size = 3 * kImageRows*kImageColumns; | |
for (auto&path : paths) { | |
std::ifstream datas(path, std::ios::binary); | |
TORCH_CHECK(datas, "Error opening images file at ", path); | |
for (int i = 0; i < 10000; i++) { | |
datas >> *(reinterpret_cast<char*>(targets.data_ptr()) + (fcnt * 10000 + i)); | |
datas.read(reinterpret_cast<char*>(images.data_ptr()) + (data_size*i + fcnt*10000* data_size), data_size); | |
} | |
fcnt++; | |
} | |
return std::make_pair(images.to(torch::kFloat32).div_(255), targets.to(torch::kInt64)); | |
} | |
} | |
/// A torch::data::Dataset over one CIFAR-10 split, held entirely in memory.
///
/// The tensors come from read_images(). Each sample's image is a float32
/// tensor of shape [3, kImageRows, kImageColumns] with values in [0, 1].
/// Each target is an int64 scalar tensor.
class CIFAR10 : public torch::data::Dataset<CIFAR10> {
public:
    /// Selects which split is loaded.
    enum class Mode { kTrain, kTest };

    /// Loads the split chosen by `mode` from the batch files in directory `root`.
    explicit CIFAR10(const std::string& root, Mode mode = Mode::kTrain) {
        const bool want_train = (mode == Mode::kTrain);
        const auto loaded = read_images(root, want_train);
        images_ = loaded.first;
        targets_ = loaded.second;
    }

    /// Returns the (image, target) sample stored at position `index`.
    torch::data::Example<> get(size_t index) override {
        torch::data::Example<> sample{ images_[index], targets_[index] };
        return sample;
    }

    /// Number of samples held (the leading dimension of the image tensor).
    c10::optional<size_t> size() const override {
        return images_.size(0);
    }

    /// True when the loaded split has exactly kTrainSize samples.
    bool is_train() const noexcept {
        return kTrainSize == images_.size(0);
    }

    /// All images, with shape [N, 3, kImageRows, kImageColumns].
    const torch::Tensor& images() const {
        return images_;
    }

    /// All targets, with shape [N].
    const torch::Tensor& targets() const {
        return targets_;
    }

private:
    torch::Tensor images_, targets_;
};
#endif |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment