Skip to content

Instantly share code, notes, and snippets.

@sathyarr
Last active July 30, 2021 12:50
Show Gist options
  • Save sathyarr/58f5147d92f8b8168c5a0d0f8b245d2e to your computer and use it in GitHub Desktop.
Save sathyarr/58f5147d92f8b8168c5a0d0f8b245d2e to your computer and use it in GitHub Desktop.
Custom operation that replaces the py_func in google/seq2seq (beam_search.py#L90). Using this custom operation allows the model to be exported successfully for TensorFlow Serving.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/op_kernel.h"
using namespace tensorflow;
REGISTER_OP("GatherTreePyCustom")
.Input("values: int32")
.Input("parents: int32")
.Output("res: int32");
class GatherTreePyCustomOp : public OpKernel {
public:
explicit GatherTreePyCustomOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
// Grab the input tensor
const Tensor& input_tensor_values = context->input(0);
auto input_values = input_tensor_values.matrix<int32>();
// Grab the input tensor
const Tensor& input_tensor_parents = context->input(1);
auto input_parents = input_tensor_parents.matrix<int32>();
const TensorShape& input_tensor_values_shape = input_tensor_values.shape();
const TensorShape& input_tensor_parents_shape = input_tensor_parents.shape();
int beam_length = input_tensor_values_shape.dim_size(0);
int num_beams = input_tensor_values_shape.dim_size(1);
// Create an output tensor
Tensor* output_tensor = NULL;
TensorShape output_shape({input_tensor_values_shape.dim_size(0), input_tensor_values_shape.dim_size(1)});
OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output_tensor));
auto output = output_tensor->matrix<int32>();
// res in python code
output.setZero();
for(int i = 0; i < input_tensor_values_shape.dim_size(1); i++){
output(input_tensor_values_shape.dim_size(0) - 1, i) = input_values(input_tensor_values_shape.dim_size(0) - 1, i);
}
for(int beam_id = 0; beam_id < num_beams; beam_id++){
int parent = input_parents(input_tensor_parents_shape.dim_size(0) - 1, beam_id);
for(int level = beam_length - 2; level >= 0; level--){
output(level, beam_id) = input_values(level, parent);
parent = input_parents(level, parent);
}
}
}
};
// Binds GatherTreePyCustomOp as the CPU implementation of the
// GatherTreePyCustom op.
REGISTER_KERNEL_BUILDER(Name("GatherTreePyCustom")
                            .Device(DEVICE_CPU),
                        GatherTreePyCustomOp);
@kunalgoyal9
Copy link

kunalgoyal9 commented Jul 29, 2021

@sathyarr Thanks for this! How did you use it in the original code?

@sathyarr
Copy link
Author

@kunalgoyal9 It should be placed in an appropriate directory when building the TensorFlow server from source.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment