Skip to content

Instantly share code, notes, and snippets.

@mjjimenez
Last active January 14, 2020 07:41
Show Gist options
  • Save mjjimenez/900e7356766e988179bc47fe48128f9d to your computer and use it in GitHub Desktop.
Save mjjimenez/900e7356766e988179bc47fe48128f9d to your computer and use it in GitHub Desktop.
torch::jit::script::Module forwardModel = torch::jit::load(forwardFilePath.UTF8String);
torch::jit::script::Module backwardModel = torch::jit::load(backwardFilePath.UTF8String);
auto w = torch::IValue(torch::full({10}, 1.0f));
torch::IValue loss;
//Create train x
std::vector<torch::jit::IValue> train_x;
for (int i = 0; i < 500; i++) {
train_x.push_back(torch::rand({10}));
}
//Create train y
std::vector<torch::jit::IValue> train_y;
for (int i = 0; i < 500; i++) {
train_y.push_back(torch::rand({1}));
}
for (int i = 0; i < 500; i++) {
for(unsigned i = 0; i < train_x.size(); ++i) {
auto x = train_x[i];
auto y = train_y[i];
torch::autograd::AutoGradMode guard(false);
at::AutoNonVariableTypeMode non_var_type_mode(true);
std::vector<torch::jit::IValue> forwardArgs;
forwardArgs.push_back(x);
forwardArgs.push_back(y);
forwardArgs.push_back(w);
loss = forwardModel.forward(forwardArgs);
auto tensor = loss.toTensor();
auto tensor_a = tensor.accessor<float, 1>();
auto loss_value = tensor_a[0];
std::cout << loss_value << std::endl;
torch::IValue loss_param = torch::IValue(torch::tensor(loss_value));
w = backwardModel.forward({w, loss_param, x});
}
}
class ForwardModule(torch.nn.Module):
    """Residual of a linear model: weighted sum of ``x`` minus the target ``y``."""

    def forward(self, x, y, W):
        """Compute ``sum(W * x, dim=0) - y``.

        Args:
            x: input tensor, multiplied elementwise by ``W``.
            y: target tensor subtracted from the reduced product.
            W: weight tensor.

        Returns:
            The residual tensor: the elementwise product summed over dim 0,
            minus ``y``.
        """
        weighted = W * x
        return weighted.sum(dim=0) - y
class BackwardModule(torch.nn.Module):
    """One gradient-descent step on the weights of a linear residual model.

    Assuming ``loss`` is the residual ``sum(W * x) - y`` (as computed by
    ``ForwardModule``), the gradient of the squared residual with respect to
    ``W`` is ``2 * loss * x``; this module returns ``W`` moved one step against it.

    Args:
        lr: learning rate (step size). Defaults to ``1e-5``, the value that was
            previously hard-coded, so ``BackwardModule()`` behaves as before.
    """

    def __init__(self, lr=1e-5):
        super().__init__()
        # A plain Python float: torch.jit.trace records it as a constant in the
        # traced graph, so the exported module uses the value set here.
        self.lr = lr

    def forward(self, W, loss, x):
        """Return the updated weights ``W - lr * (2 * loss * x)``.

        Args:
            W: current weight tensor.
            loss: residual tensor (broadcast against ``x``).
            x: input tensor the residual was computed from.

        Returns:
            The weight tensor after one update step.
        """
        g = 2 * loss * x
        return W - self.lr * g
# Export script: trace both modules with random example inputs and save them as
# TorchScript archives (the C++ side loads them with torch::jit::load).
forward_model = ForwardModule()
backward_model = BackwardModule()

# Example inputs only drive the tracer. They are drawn in the order
# loss, W, x, y so a seeded run yields the same tensors.
loss, W, x, y = (torch.rand(size) for size in (1, 10, 10, 1))

# Example inputs are given as tuples, the form documented by torch.jit.trace.
forward_model = torch.jit.trace(forward_model, (x, y, W))
forward_model.save("forward_model.pt")

backward_model = torch.jit.trace(backward_model, (W, loss, x))
backward_model.save("backward_model.pt")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment