Skip to content

Instantly share code, notes, and snippets.

@bnsh
Last active January 6, 2019 08:10
Show Gist options
  • Save bnsh/7c5cdc8ec13b7b222fda4ac77885c899 to your computer and use it in GitHub Desktop.
Save bnsh/7c5cdc8ec13b7b222fda4ac77885c899 to your computer and use it in GitHub Desktop.
Problems compiling LLTM demo module from PyTorch.
#include <torch/torch.h>
#include <iostream>
#include <vector>
at::Tensor d_sigmoid(at::Tensor z) {
auto s = at::sigmoid(z);
return (1-s) * s;
}
std::vector<at::Tensor> lltm_forward(
at::Tensor input,
at::Tensor weights,
at::Tensor bias,
at::Tensor old_h,
at::Tensor old_cell
) {
auto X = at::cat({old_h, input}, /*dim=*/1);
auto gate_weights = at::addmm(bias, X, weights.transpose(0, 1));
auto gates = gate_weights.chunk(3, /*dim=*/1);
auto input_gate = at::sigmoid(gates[0]);
auto output_gate = at::sigmoid(gates[1]);
auto candidate_cell = at::elu(gates[2], /*alpha=*/1.0);
auto new_cell = old_cell + candidate_cell * input_gate;
auto new_h = at::tanh(new_cell) * output_gate;
return {
new_h,
new_cell,
input_gate,
candidate_cell,
X,
gate_weights
};
}
at::Tensor d_tanh(at::Tensor z) {
return 1 - z.tanh().pow(2);
}
at::Tensor d_elu(at::Tensor z, at::Scalar alpha = 1.0) {
auto e = z.exp();
auto mask = (alpha * (e-1)) < 0;
return (z > 0).type_as(z) + mask.type_as(z) * (alpha * e);
}
std::vector<at::Tensor> lltm_backward(
at::Tensor grad_h,
at::Tensor grad_cell,
at::Tensor new_cell,
at::Tensor input_gate,
at::Tensor output_gate,
at::Tensor candidate_cell,
at::Tensor X,
at::Tensor gate_weights,
at::Tensor weights
) {
auto d_output_gate = at::tanh(new_cell) * grad_h;
auto d_tanh_new_cell = output_gate * grad_h;
auto d_new_cell = d_tanh(new_cell) * d_tanh_new_cell + grad_cell;
auto d_old_cell = d_new_cell;
auto d_candidate_cell = input_gate * d_new_cell;
auto d_input_gate = candidate_cell * d_new_cell;
auto gates = gate_weights.chunk(3, /*dim=*/1);
d_input_gate *= d_sigmoid(gates[0]);
d_output_gate *= d_sigmoid(gates[1]);
d_candidate_cell *= d_elu(gates[2]);
auto d_gates =
at::cat({d_input_gate, d_output_gate, d_candidate_cell}, /*dim=*/1);
auto d_weights = d_gates.t().mm(X);
auto d_bias = d_gates.sum(/*dim=*/0, /*keepdim=*/true);
auto d_X = d_gates.mm(weights);
const auto state_size = grad_h.size(1);
auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);
auto d_input = d_X.slice(/*dim=*/1, state_size);
return {d_old_h, d_input, d_weights, d_bias, d_old_cell};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("backward", &lltm_backward, "LLTM backward");
m.def("forward", &lltm_forward, "LLTM forward");
}
#! /usr/bin/env python3
"""This is the LLTM module from https://pytorch.org/tutorials/advanced/cpp_extension.html"""
from setuptools import setup
from torch.utils.cpp_extension import CppExtension, BuildExtension
def main():
setup(name="lltm",
ext_modules=[CppExtension("lltm", ["lltm.cpp"], extra_compile_args = ["-Wno-c++0x-compat", "-std=c++11"])],
cmd_class={"build_ext": BuildExtension})
if __name__ == "__main__":
main()
@bnsh
Copy link
Author

bnsh commented Jan 6, 2019

I'm having problems compiling this demo module for pyTorch.. when I run python3 ./setup.py build I get this:

/usr/lib/python3.5/distutils/dist.py:261: UserWarning: Unknown distribution option: 'cmd_class'
  warnings.warn(msg)
running build
running build_ext
building 'lltm' extension
x86_64-linux-gnu-gcc -pthread -DNDEBUG -g -fwrapv -O2 -Wall -Wstrict-prototypes -g -fstack-protector-strong -Wformat -Werror=format-security -Wdate-time -D_FORTIFY_SOURCE=2 -fPIC -I/usr/local/lib/python3.5/dist-packages/torch/lib/include -I/usr/local/lib/python3.5/dist-packages/torch/lib/include/torch/csrc/api/include -I/usr/local/lib/python3.5/dist-packages/torch/lib/include/TH -I/usr/local/lib/python3.5/dist-packages/torch/lib/include/THC -I/usr/include/python3.5m -c lltm.cpp -o build/temp.linux-x86_64-3.5/lltm.o -Wno-c++0x-compat -std=c++11
cc1plus: warning: command line option ‘-Wstrict-prototypes’ is valid for C/ObjC but not for C++
lltm.cpp:88:16: error: expected constructor, destructor, or type conversion before ‘(’ token
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
                ^
error: command 'x86_64-linux-gnu-gcc' failed with exit status 1

What am I doing wrong?

@bnsh
Copy link
Author

bnsh commented Jan 6, 2019

(Just in case anyone else comes across this... The problem was that I needed to add

#include <pybind11/pybind11.h>

to lltm.cpp.)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment