Skip to content

Instantly share code, notes, and snippets.

@asimshankar
Created November 17, 2017 20:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save asimshankar/f11e79a5b7947e716ae2242387162ba2 to your computer and use it in GitHub Desktop.
Save asimshankar/f11e79a5b7947e716ae2242387162ba2 to your computer and use it in GitHub Desktop.
TensorFlow Custom Datasets
# Compiler/linker flags for the installed TensorFlow pip package: its header
# directory, the nsync headers bundled under it, and the library directory
# that holds libtensorflow_framework.
PATHS=$(python -c "import tensorflow as tf; print('-I{} -I{}/external/nsync/public -L{}'.format(tf.sysconfig.get_include(), tf.sysconfig.get_include(), tf.sysconfig.get_lib()))")
# Unfortunately, not all header files are included in the pip package yet
# (presumably including tensorflow/core/kernels/dataset.h, which dataset.cc
# needs). So for now, clone the TensorFlow repository, check out the branch
# matching the installed version, and point TF_SRC at that checkout
# (default: /tmp/tensorflow_src). Its root is added to the include path below.
TF_SRC=${TF_SRC:-/tmp/tensorflow_src}
# NOTE(review): if you compile with g++ >= 5 against a pip package built with
# the old C++ ABI, you may also need -D_GLIBCXX_USE_CXX11_ABI=0 (see the
# TensorFlow "Adding a New Op" guide) — confirm for your TF build.
# ${PATHS} is deliberately unquoted so it splits into separate -I/-L flags.
g++ -shared -fPIC ${PATHS} -I"${TF_SRC}" -std=c++11 dataset.cc -ltensorflow_framework -olibmydataset.so
#include <string>
#include <vector>
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/dataset.h"
using std::string;
using tensorflow::Status;
using tensorflow::Tensor;
// Declares the "MyDataset" op: one int32 input named `value`, one output
// `handle` of type variant whose inferred shape is a scalar (ScalarShape).
// The op is registered as stateful.
REGISTER_OP("MyDataset")
    .Input("value: int32")
    .Output("handle: variant")
    .SetShapeFn(tensorflow::shape_inference::ScalarShape)
    .SetIsStateful()
    .Doc(R"doc(
Silly dataset that produces 'value' once.
)doc");
// Kernel behind the "MyDataset" op. Each invocation snapshots the tensor(s)
// bound to the "value" input into a Dataset. Every iterator created from that
// Dataset emits the snapshot as exactly one element and then reports the end
// of the sequence.
class MyDatasetOp : public tensorflow::DatasetOpKernel {
 public:
  explicit MyDatasetOp(tensorflow::OpKernelConstruction* ctx)
      : tensorflow::DatasetOpKernel(ctx) {}

  // Builds the Dataset for this invocation and stores it in `*output`.
  // If reading the "value" input fails, OP_REQUIRES_OK reports the error on
  // `ctx` and returns early, leaving `*output` unset.
  void MakeDataset(tensorflow::OpKernelContext* ctx,
                   tensorflow::DatasetBase** output) override {
    tensorflow::OpInputList value_list;
    OP_REQUIRES_OK(ctx, ctx->input_list("value", &value_list));

    std::vector<Tensor> snapshot;
    snapshot.reserve(value_list.size());
    for (int i = 0; i < value_list.size(); ++i) {
      snapshot.push_back(value_list[i]);
    }
    // The raw pointer is handed back through `output`; the DatasetOpKernel
    // base is expected to take ownership of it.
    *output = new Dataset(std::move(snapshot));
  }

 private:
  // Holds the captured tensors plus the dtype/shape metadata derived from
  // them. None of these members is modified after construction.
  class Dataset : public tensorflow::DatasetBase {
   public:
    explicit Dataset(std::vector<Tensor> tensors)
        : tensors_(std::move(tensors)) {
      // One dtype and one shape entry per captured component, in order.
      for (const Tensor& component : tensors_) {
        dtypes_.push_back(component.dtype());
        shapes_.emplace_back(component.shape().dim_sizes());
      }
    }

    // Returns a fresh Iterator with its own `produced_` flag, so iterators
    // from the same Dataset are independent. The iterator's prefix is the
    // caller's prefix with "::MyDataset" appended.
    std::unique_ptr<tensorflow::IteratorBase> MakeIterator(
        const string& prefix) const override {
      const string iterator_prefix = prefix + "::MyDataset";
      return std::unique_ptr<tensorflow::IteratorBase>(
          new Iterator({this, iterator_prefix}));
    }

    // Dtypes of the captured components.
    const tensorflow::DataTypeVector& output_dtypes() const override {
      return dtypes_;
    }

    // Shapes of the captured components, taken from the input tensors.
    const std::vector<tensorflow::PartialTensorShape>& output_shapes()
        const override {
      return shapes_;
    }

    string DebugString() override { return "MyDatasetOp::Dataset"; }

   private:
    // Single-shot iterator: the first GetNext yields the captured tensors,
    // every later call reports end of sequence.
    class Iterator : public tensorflow::DatasetIterator<Dataset> {
     public:
      explicit Iterator(const Params& params)
          : tensorflow::DatasetIterator<Dataset>(params), produced_(false) {}

      Status GetNextInternal(tensorflow::IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        // Holding mu_ makes the check-and-set of produced_ atomic with
        // respect to other GetNextInternal calls on this iterator.
        tensorflow::mutex_lock lock(mu_);
        const bool already_produced = produced_;
        if (!already_produced) {
          *out_tensors = dataset()->tensors_;
          produced_ = true;
        }
        *end_of_sequence = already_produced;
        return Status::OK();
      }

     private:
      tensorflow::mutex mu_;
      bool produced_ GUARDED_BY(mu_);  // Set once the element was emitted.
    };

    const std::vector<Tensor> tensors_;
    tensorflow::DataTypeVector dtypes_;
    std::vector<tensorflow::PartialTensorShape> shapes_;
  };
};
// Registers MyDatasetOp as the CPU kernel for the "MyDataset" op.
// NOTE(review): this sits inside namespace tensorflow, presumably so that
// unqualified identifiers in the builder expression (e.g. DEVICE_CPU) resolve
// when the macro expands. Keep the macro argument tokens as-is unless the
// macro's expansion has been checked.
namespace tensorflow {
REGISTER_KERNEL_BUILDER(Name("MyDataset").Device(DEVICE_CPU), MyDatasetOp);
} // namespace tensorflow
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment