Last active
January 6, 2019 08:10
-
-
Save bnsh/7c5cdc8ec13b7b222fda4ac77885c899 to your computer and use it in GitHub Desktop.
Problems compiling LLTM demo module from PyTorch.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <torch/torch.h>
// pybind11 must be included explicitly: older <torch/torch.h> does not pull
// it in transitively, and PYBIND11_MODULE / TORCH_EXTENSION_NAME below
// depend on it.
#include <pybind11/pybind11.h>

#include <iostream>
#include <vector>
// Derivative of the logistic sigmoid: sigmoid'(z) = sigmoid(z) * (1 - sigmoid(z)).
// Evaluates sigmoid once and reuses it for both factors.
at::Tensor d_sigmoid(at::Tensor z) {
  const auto sig = at::sigmoid(z);
  return sig * (1 - sig);
}
// Forward pass of the LLTM cell (the demo module from the PyTorch C++
// extension tutorial).
//
// Args:
//   input:    current input batch (presumably (batch, input_features) — TODO confirm)
//   weights:  fused gate weight matrix, applied to the concatenation [old_h, input]
//   bias:     fused gate bias
//   old_h:    previous hidden state
//   old_cell: previous cell state
//
// Returns {new_h, new_cell, input_gate, candidate_cell, X, gate_weights};
// the last four tensors are cached intermediates for lltm_backward.
std::vector<at::Tensor> lltm_forward(
    at::Tensor input,
    at::Tensor weights,
    at::Tensor bias,
    at::Tensor old_h,
    at::Tensor old_cell
) {
  // Concatenate previous hidden state and input along the feature dimension.
  auto X = at::cat({old_h, input}, /*dim=*/1);
  // One fused affine transform for all three gates: bias + X * weights^T.
  auto gate_weights = at::addmm(bias, X, weights.transpose(0, 1));
  // Split the fused pre-activations into the three gates.
  auto gates = gate_weights.chunk(3, /*dim=*/1);
  auto input_gate = at::sigmoid(gates[0]);
  auto output_gate = at::sigmoid(gates[1]);
  // LLTM variant: the candidate cell uses ELU instead of the classic tanh.
  auto candidate_cell = at::elu(gates[2], /*alpha=*/1.0);
  // Note: unlike a standard LSTM there is no forget gate — the old cell
  // state is carried over unscaled.
  auto new_cell = old_cell + candidate_cell * input_gate;
  auto new_h = at::tanh(new_cell) * output_gate;
  return {
    new_h,
    new_cell,
    input_gate,
    candidate_cell,
    X,
    gate_weights
  };
}
// Derivative of tanh: tanh'(z) = 1 - tanh(z)^2.
at::Tensor d_tanh(at::Tensor z) {
  const auto t = z.tanh();
  return 1 - t.pow(2);
}
// Elementwise derivative of ELU: 1 where z > 0, alpha * exp(z) where the
// ELU output alpha * (exp(z) - 1) is negative. Built from masks so the
// result has the same dtype as z.
at::Tensor d_elu(at::Tensor z, at::Scalar alpha = 1.0) {
  const auto exp_z = z.exp();
  const auto pos_region = (z > 0).type_as(z);
  const auto neg_region = ((alpha * (exp_z - 1)) < 0).type_as(z);
  return pos_region + neg_region * (alpha * exp_z);
}
// Backward pass of the LLTM cell: given gradients w.r.t. the new hidden
// state (grad_h) and new cell state (grad_cell), back-propagates through
// the forward computation using the intermediates cached by lltm_forward
// (input_gate, output_gate, candidate_cell, X, gate_weights).
//
// Returns {d_old_h, d_input, d_weights, d_bias, d_old_cell}.
std::vector<at::Tensor> lltm_backward(
    at::Tensor grad_h,
    at::Tensor grad_cell,
    at::Tensor new_cell,
    at::Tensor input_gate,
    at::Tensor output_gate,
    at::Tensor candidate_cell,
    at::Tensor X,
    at::Tensor gate_weights,
    at::Tensor weights
) {
  // new_h = tanh(new_cell) * output_gate  =>  product rule for both factors.
  auto d_output_gate = at::tanh(new_cell) * grad_h;
  auto d_tanh_new_cell = output_gate * grad_h;
  // Gradient into new_cell: through tanh, plus the direct grad_cell path.
  auto d_new_cell = d_tanh(new_cell) * d_tanh_new_cell + grad_cell;
  // new_cell = old_cell + candidate_cell * input_gate.
  auto d_old_cell = d_new_cell;
  auto d_candidate_cell = input_gate * d_new_cell;
  auto d_input_gate = candidate_cell * d_new_cell;
  // Back through the gate activations, using the cached pre-activations.
  // NOTE: the in-place *= below must happen after the products above.
  auto gates = gate_weights.chunk(3, /*dim=*/1);
  d_input_gate *= d_sigmoid(gates[0]);
  d_output_gate *= d_sigmoid(gates[1]);
  d_candidate_cell *= d_elu(gates[2]);
  // Reassemble the gradient of the fused gate pre-activation matrix.
  auto d_gates =
      at::cat({d_input_gate, d_output_gate, d_candidate_cell}, /*dim=*/1);
  // gate_weights = addmm(bias, X, weights^T)  =>  standard matmul gradients.
  auto d_weights = d_gates.t().mm(X);
  auto d_bias = d_gates.sum(/*dim=*/0, /*keepdim=*/true);
  auto d_X = d_gates.mm(weights);
  // Split d_X back into the [old_h, input] halves concatenated in forward;
  // old_h occupies the first state_size feature columns.
  const auto state_size = grad_h.size(1);
  auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);
  auto d_input = d_X.slice(/*dim=*/1, state_size);
  return {d_old_h, d_input, d_weights, d_bias, d_old_cell};
}
// Python bindings for the extension. NOTE: PYBIND11_MODULE and
// TORCH_EXTENSION_NAME are pybind11/torch macros that require the pybind11
// headers; with older <torch/torch.h> they are not pulled in transitively,
// so <pybind11/pybind11.h> must be included explicitly at the top of the
// file — that missing include is what made this module fail to compile.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("backward", &lltm_backward, "LLTM backward");
  m.def("forward", &lltm_forward, "LLTM forward");
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#! /usr/bin/env python3 | |
"""This is the LLTM module from https://pytorch.org/tutorials/advanced/cpp_extension.html""" | |
from setuptools import setup | |
from torch.utils.cpp_extension import CppExtension, BuildExtension | |
def main():
    """Build the ``lltm`` C++ extension with torch's BuildExtension helper."""
    setup(
        name="lltm",
        ext_modules=[
            CppExtension(
                "lltm",
                ["lltm.cpp"],
                extra_compile_args=["-Wno-c++0x-compat", "-std=c++11"],
            )
        ],
        # BUG FIX: the setuptools keyword is `cmdclass`, not `cmd_class`.
        # The misspelled key was silently ignored (unknown distribution
        # option), so BuildExtension never ran and the torch-specific
        # compile/link setup was skipped.
        cmdclass={"build_ext": BuildExtension},
    )

if __name__ == "__main__":
    main()
(Just in case anyone else comes across this... The problem was that I needed to add
#include <pybind11/pybind11.h>
to lltm.cpp.)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
I'm having problems compiling this demo module for PyTorch. When I run
python3 ./setup.py build
I get this error output. What am I doing wrong?