Skip to content

Instantly share code, notes, and snippets.

@AndreasMadsen
Last active November 1, 2016 14:54
Show Gist options
  • Save AndreasMadsen/4335215cd4293daad3cad745bbeae82a to your computer and use it in GitHub Desktop.
Save AndreasMadsen/4335215cd4293daad3cad745bbeae82a to your computer and use it in GitHub Desktop.
l2loss GPU Eigen::half
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/nn_ops.cc.
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "l2loss_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/numeric_op.h"
namespace tensorflow {
// Op registration for "CustomL2Loss".
//
// Interface: a single input tensor `t` and a single output `output`, both of
// element type `T`; the shape function declares the output to be a scalar.
//
// `T` is limited to half, float and double because those are the only element
// types this file registers kernels for (on both CPU and GPU). Advertising
// every numeric type would let unsupported dtypes through op construction and
// only fail later, when no matching kernel can be found.
REGISTER_OP("CustomL2Loss")
    .Input("t: T")
    .Output("output: T")
    .Attr("T: {half, float, double}")
    .SetShapeFn(shape_inference::ScalarShape)
    .Doc(R"doc(
L2 Loss.
Computes half the L2 norm of a tensor without the `sqrt`:
output = sum(t ** 2) / 2
t: Typically 2-D, but may have any dimensions.
output: 0-D.
)doc");
// Short names for the two Eigen device types the kernels are templated on.
using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;
template <typename Device, typename T>
class CustomL2LossOp : public OpKernel {
public:
explicit CustomL2LossOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
// The input tensor can be of any number of dimensions, even though it's
// 2D in most typical applications.
const Tensor& input = context->input(0);
// The output is a single number.
Tensor* output = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, TensorShape({}), &output));
functor::CustomL2Loss<Device, T>()(context->eigen_device<Device>(),
input.flat<T>(), output->scalar<T>());
}
};
// Registration of the CPU implementations: one CustomL2LossOp<CPUDevice, T>
// kernel per supported element type, each constrained on the op's "T" attr.
// The helper macro is local to this section and is undefined again right
// after its last use.
#define REGISTER_CPU_KERNEL(T)                                        \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("CustomL2Loss").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      CustomL2LossOp<CPUDevice, T>);
REGISTER_CPU_KERNEL(float);
REGISTER_CPU_KERNEL(double);
REGISTER_CPU_KERNEL(Eigen::half);
#undef REGISTER_CPU_KERNEL
#if GOOGLE_CUDA
// GPU support: this whole section is compiled only when GOOGLE_CUDA is defined
// to a non-zero value.
//
// Forward declarations of the functor specializations for GPU.
// For each element type T, DECLARE_GPU_SPEC
//   1. declares an explicit specialization of
//      CustomL2Loss<GPUDevice, T>::operator(), and
//   2. declares the class template instantiation `extern`, so this
//      translation unit does not instantiate the functor itself.
// NOTE(review): the definitions are expected to come from the explicit
// instantiations in the CUDA source file (l2loss_op.cu.cc); the type list
// there must stay in sync with the one below.
namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
void CustomL2Loss<GPUDevice, T>::operator()(const GPUDevice& d, \
typename TTypes<T>::ConstTensor input, \
typename TTypes<T>::Scalar output); \
extern template struct CustomL2Loss<GPUDevice, T>;
DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(double);
DECLARE_GPU_SPEC(Eigen::half);
#undef DECLARE_GPU_SPEC
} // namespace functor
// Registration of the GPU implementations.
// One CustomL2LossOp<GPUDevice, T> kernel is registered per element type,
// constrained on the op's "T" attr; the types mirror the CPU registrations.
#define REGISTER_GPU_KERNEL(T) \
REGISTER_KERNEL_BUILDER( \
Name("CustomL2Loss").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
CustomL2LossOp<GPUDevice, T>);
REGISTER_GPU_KERNEL(float);
REGISTER_GPU_KERNEL(double);
REGISTER_GPU_KERNEL(Eigen::half);
#undef REGISTER_GPU_KERNEL
#endif // GOOGLE_CUDA
} // namespace tensorflow
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "l2loss_op.h"
#include "tensorflow/core/framework/register_types.h"
namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
template struct functor::CustomL2Loss<GPUDevice, float>;
template struct functor::CustomL2Loss<GPUDevice, double>;
template struct functor::CustomL2Loss<GPUDevice, Eigen::half>;
} // namespace tensorflow
#endif // GOOGLE_CUDA
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_KERNELS_CustomL2Loss_OP_H_
#define TENSORFLOW_KERNELS_CustomL2Loss_OP_H_
// Functor definition for CustomL2LossOp, must be compilable by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
namespace tensorflow {
namespace functor {
// Device-generic functor that performs the computation for CustomL2LossOp.
//
// Device: Eigen device type the expression is evaluated on.
// T:      element type of both the input and the result.
template <typename Device, typename T>
struct CustomL2Loss {
  // Evaluates sum(input ** 2 * 0.5) on device `d` and stores it in `output`.
  //
  // d:      device used to evaluate the Eigen expression.
  // input:  the flattened input tensor.
  // output: scalar that receives the result.
  //
  // The whole expression, reduction included, is formed in type T; no
  // widening cast is applied here.
  void operator()(const Device& d, typename TTypes<T>::ConstTensor input,
                  typename TTypes<T>::Scalar output) {
    const T one_half = static_cast<T>(0.5);
    // Each squared element is scaled first, then everything is summed along
    // dimension 0 of the flattened input, leaving a single number.
    output.device(d) = (input.square() * one_half).sum();
  }
};
} // namespace functor
} // namespace tensorflow
#endif // TENSORFLOW_KERNELS_CustomL2Loss_OP_H_
# Builds l2loss_op.so, the shared library holding the CustomL2Loss op, from the
# host source (l2loss_op.cc, compiled with $(CXX)) and the CUDA source
# (l2loss_op.cu.cc, compiled with $(NVCC)).

# `clean` never produces a file named "clean", so it must be phony; otherwise a
# stray file with that name would silently disable it. `test` and `lint` are
# kept for backward compatibility although this Makefile has no rules for them.
.PHONY: test lint clean

NVCC=nvcc
CXX=g++
CXXFLAGS=-std=c++11
CFLAGS=-fPIC -O2 -Wall -D GOOGLE_CUDA=1
NVCCFLAGS=-x cu -Xcompiler -fPIC -Xcompiler -Wall -D GOOGLE_CUDA=1
# Header search path is taken from the installed TensorFlow Python package.
CPPFLAGS=-isystem $(shell python3 -c 'import tensorflow as tf; print(tf.sysconfig.get_include())') -D_GLIBCXX_USE_CXX11_ABI=0
LDFLAGS=-L /appl/cuda/8.0/lib64 -L /appl/cudnn/v5.1-prod/lib64
# Libraries are kept apart from LDFLAGS so they can follow the object files on
# the link line; a library listed before the objects that need it can be
# discarded by linkers that resolve symbols left to right or use --as-needed.
LDLIBS=-lcudart

# l2loss targets
l2loss_op.so: l2loss_op.o l2loss_op.cu.o
	$(CXX) $(LDFLAGS) -shared $^ $(LDLIBS) -o $@

l2loss_op.o: l2loss_op.cc l2loss_op.h
	$(CXX) $(CXXFLAGS) $(CPPFLAGS) $(CFLAGS) -c -o $@ $<

l2loss_op.cu.o: l2loss_op.cu.cc l2loss_op.h
	$(NVCC) $(CXXFLAGS) $(CPPFLAGS) $(NVCCFLAGS) -c -o $@ $<

clean:
	rm -f *.o
	rm -f *.so
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment