Skip to content

Instantly share code, notes, and snippets.

@goldsborough
Created September 19, 2018 00:46
Show Gist options
  • Save goldsborough/4fbae0d83c49a26faf09cf8a1b32ef26 to your computer and use it in GitHub Desktop.
Save goldsborough/4fbae0d83c49a26faf09cf8a1b32ef26 to your computer and use it in GitHub Desktop.
Stream/Random policy data loader
namespace torch {
namespace data {
// A data/label pair. Both members default to torch::Tensor.
template <typename D = torch::Tensor, typename L = torch::Tensor>
struct Example {
  D data;
  L label;
};

// Unlabeled variant: choosing `void` as the label type removes the `label`
// member, leaving only `data`.
template <typename D>
struct Example<D, void> {
  D data;
};
namespace datasets {
// Access policies are tag *types*: they are passed as template type arguments
// (e.g. to pick a MapImpl specialization) and can form a hierarchy.
namespace access_policy {
// Stream: the dataset provides next(size_t count).
struct Stream {};
// Random: the dataset provides next(std::vector<size_t>&& indices).
// Inherits from Stream.
struct Random : public Stream {};
} // namespace access_policy
template <typename S, typename T>
struct Map;

// CRTP trait base every dataset derives from.
//   S - the concrete dataset type (the derived class itself)
//   B - the batch type the dataset produces
//   A - the access policy tag (see access_policy)
template <
    typename S,
    typename B = std::vector<Example<>>,
    typename A = access_policy::Random>
struct Dataset {
  using AccessPolicy = A;
  using BatchType = B;
  using Self = S;

  // Consumes this dataset (rvalue-qualified) and wraps it in a Map together
  // with a TransformType built from `args`.
  template <typename TransformType, typename... Args>
  Map<Self, TransformType> map(Args&&... args) &&;
};
// Map
// State shared by every MapImpl specialization: owns the wrapped dataset `S`
// and the transform `T` applied to its batches.
// The resulting dataset's batch type is T::OutputType. Its access policy is
// the `AccessPolicy` argument; MapImpl passes the policy it was specialized
// for, so the template parameter is now actually used instead of being
// silently re-derived from S.
template <typename S, typename T, typename AccessPolicy>
struct MapBase : Dataset<Map<S, T>, typename T::OutputType, AccessPolicy> {
  // Takes ownership of both the dataset and the transform.
  MapBase(S&& dataset, T&& transform)
      : dataset(std::move(dataset)), transform(std::move(transform)) {}
  S dataset;
  T transform;
};
// Selects the `next` overload that Map exposes, based on AccessPolicy.
// Only the Stream and Random specializations below are defined.
template <typename S, typename T, typename AccessPolicy>
struct MapImpl;

// Stream access: fetch the next `count` items from the wrapped dataset and
// run them through the transform.
template <typename S, typename T>
struct MapImpl<S, T, access_policy::Stream> : MapBase<S, T, access_policy::Stream> {
  using MapBase<S, T, access_policy::Stream>::MapBase;
  typename T::OutputType next(size_t count) {
    // Transforms expose their batch operation as apply(), exactly as the
    // Random specialization below uses it. Calling operator() here would not
    // compile for transforms that only define apply().
    return this->transform.apply(this->dataset.next(count));
  }
};

// Random access: fetch the items at `indices` from the wrapped dataset and
// run them through the transform.
template <typename S, typename T>
struct MapImpl<S, T, access_policy::Random> : MapBase<S, T, access_policy::Random> {
  using MapBase<S, T, access_policy::Random>::MapBase;
  typename T::OutputType next(std::vector<size_t>&& indices) {
    return this->transform.apply(this->dataset.next(std::move(indices)));
  }
};
// Dataset returned by Dataset::map: wraps dataset S with transform T.
// The base is the MapImpl specialization chosen by S's access policy, which
// supplies the matching next() overload. Constructors are inherited from it.
template <typename S, typename T>
struct Map : public MapImpl<S, T, typename S::AccessPolicy> {
  using MapImpl<S, T, typename S::AccessPolicy>::MapImpl;
};
// End Map
// Moves the derived dataset out of *this and pairs it with a TransformType
// constructed in place from `args`. The result is a Map<S, TransformType>.
template <typename S, typename B, typename A>
template <typename TransformType, typename... Args>
Map<S, TransformType> Dataset<S, B, A>::map(Args&&... args) && {
  // NOTE(review): a batch/input type check is intentionally left disabled:
  //   static_assert(
  //       std::is_same<B, typename TransformType::InputType>::value,
  //       "Batch type of dataset does not match input type of transform");
  S& derived = static_cast<S&>(*this);
  return {std::move(derived), TransformType(std::forward<Args>(args)...)};
}
// Random-access dataset of Example<> items.
class MNIST : public Dataset<MNIST> {
 public:
  // NOTE(review): placeholder implementation. `root_path` and `train` are
  // currently ignored, and the dataset holds 100 default-constructed examples.
  explicit MNIST(const std::string& root_path, bool train = true) : data_(100) {}

  // Returns copies of the examples at `indices`, in the order given.
  // Throws std::out_of_range if any index is >= size().
  std::vector<Example<>> next(std::vector<size_t>&& indices) {
    std::vector<Example<>> examples;
    examples.reserve(indices.size());
    for (const auto index : indices) {
      // at() instead of operator[]: a bad index from the caller must not read
      // past the end of data_ (undefined behavior).
      examples.push_back(data_.at(index));
    }
    return examples;
  }

  // Number of examples held.
  size_t size() const noexcept {
    return data_.size();
  }

 private:
  std::vector<Example<>> data_;
};
// Minimal batch type: only records a row count.
struct RowBatch {
  size_t count;
};
// Stream-access dataset that yields RowBatch values.
class HiveDataset : public Dataset<HiveDataset, RowBatch, access_policy::Stream> {
 public:
  HiveDataset() = default;

  // Returns a batch recording that `count` rows were requested.
  RowBatch next(size_t count) {
    RowBatch batch = {count};
    return batch;
  }

  // Hard-coded placeholder size.
  size_t size() const noexcept { return 12345; }
};
} // namespace datasets
namespace transforms {
// Base for transforms. It only records the batch type consumed (InputType)
// and the batch type produced (OutputType).
template <typename I, typename O>
struct Transform {
  using OutputType = O;
  using InputType = I;
};
// Base for transforms that rewrite the `data` tensor of every example in a
// batch and leave labels untouched.
// Subclasses implement the per-tensor apply(const Tensor&). The batch overload
// maps it over a whole std::vector of examples, matching the InputType and
// OutputType this class declares.
//   L - label type of the examples.
template<typename L = torch::Tensor>
struct TensorTransform : Transform<std::vector<Example<torch::Tensor, L>>, std::vector<Example<torch::Tensor, L>>> {
  virtual ~TensorTransform() = default;

  // Transforms a single tensor and returns the result.
  virtual Tensor apply(const Tensor& tensor) = 0;

  // Replaces each example's `data` with apply(data) and returns the batch.
  // The previous version had three defects:
  //   - it took a single Example yet iterated over it;
  //   - it was const while calling the non-const virtual apply();
  //   - it discarded each result.
  // None of those could compile or have an effect.
  std::vector<Example<torch::Tensor, L>> apply(std::vector<Example<torch::Tensor, L>>&& batch) {
    for (auto& example : batch) {
      example.data = apply(example.data);
    }
    // `batch` is a named rvalue reference (an lvalue here); move it out
    // explicitly to avoid a copy.
    return std::move(batch);
  }
};
// Normalizes tensors as (x - mean) / stddev.
struct Normalize : TensorTransform<> {
  Normalize(double mean, double stddev) : mean(mean), stddev(stddev) {}

  // Shifts `tensor` by `mean`, then divides by `stddev`.
  Tensor apply(const Tensor& tensor) override {
    auto centered = tensor - mean;
    return centered / stddev;
  }

  // Normalizes the `data` of every example in `batch` in place and returns
  // the batch. Labels are left untouched.
  template<typename L = torch::Tensor>
  std::vector<Example<torch::Tensor, L>> apply(std::vector<Example<torch::Tensor, L>>&& batch) {
    for (size_t i = 0; i < batch.size(); ++i) {
      batch[i].data = apply(batch[i].data);
    }
    return std::move(batch);
  }

  double mean{0};
  double stddev{0};
};
} // namespace transforms
} // namespace data
} // namespace torch
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment