Skip to content

Instantly share code, notes, and snippets.

@driazati
Created March 7, 2019 22:15
Show Gist options
  • Save driazati/5754b7658da29c4286dee7532f15a957 to your computer and use it in GitHub Desktop.
Save driazati/5754b7658da29c4286dee7532f15a957 to your computer and use it in GitHub Desktop.
#include <torch/script.h>
#include <iostream>
#include <memory>
/// Debug driver: loads the serialized ScriptModule "msm.pt" and prints the
/// addresses of the `forward` method's initial IValues and of the module's
/// attribute and parameter slots. It then reports, for each initial IValue,
/// whether it is the slot of a parameter, of an attribute ("buffer"), or of
/// neither.
///
/// @return 0 on success, 1 if the loaded module has no `forward` method.
int main() {
  std::shared_ptr<torch::jit::script::Module> module =
      torch::jit::load("msm.pt");

  // find_method() yields a null pointer when the method does not exist; look
  // it up once and check it rather than dereferencing it twice unchecked.
  auto forward = module->find_method("forward");
  if (forward == nullptr) {
    std::cerr << "Module has no 'forward' method\n";
    return 1;
  }

  for (const auto& ivalue : forward->initial_ivalues()) {
    std::cout << "IValue:" << ivalue << "\n";
  }
  for (auto& named_ivalue : module->get_attributes()) {
    std::cout << "Slot: " << named_ivalue.value().slot() << "\n";
  }
  for (auto& named_ivalue : module->get_parameters()) {
    std::cout << "Parameter: " << named_ivalue.value().slot() << "\n";
  }

  // True if `slot` is identical (same address) to the slot of any entry in
  // `named_slots`.
  const auto contains_slot = [](auto&& named_slots, const auto& slot) {
    for (auto& item : named_slots) {
      if (slot == item->slot()) {
        return true;
      }
    }
    return false;
  };

  // Classify each initial IValue; parameters take precedence over attributes,
  // matching the original search order.
  for (const auto& initial_ivalue : forward->initial_ivalues()) {
    if (contains_slot(module->get_parameters(), initial_ivalue)) {
      std::cout << "Found parameter\n";
    } else if (contains_slot(module->get_attributes(), initial_ivalue)) {
      std::cout << "Found buffer\n";
    } else {
      std::cout << "Not found\n";
    }
  }
  return 0;
}
import torch
class MyModule(torch.jit.ScriptModule):
    """Legacy-API TorchScript module with one parameter and one buffer.

    Args:
        N: number of rows of both ``weight`` and ``my_buffer``.
        M: number of columns of both ``weight`` and ``my_buffer``.
    """

    def __init__(self, N, M):
        super().__init__()
        # Trainable parameter of shape (N, M).
        self.weight = torch.nn.Parameter(torch.rand(N, M))
        # Non-trainable buffer. Its shape is (N, M) so that ``forward`` can add
        # it to ``weight``; an (N, N) buffer fails to broadcast when N != M.
        self.register_buffer('my_buffer', torch.randn(N, M))

    @torch.jit.script_method
    def forward(self, input):
        # ``input`` is accepted but unused; the result is weight + my_buffer.
        return self.weight + self.my_buffer
# Build a module with N=2 rows and M=3 columns and serialize it to "msm.pt".
my_script_module = MyModule(N=2, M=3)
my_script_module.save("msm.pt")
@driazati
Copy link
Author

driazati commented Mar 7, 2019

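C++ output (the first initial IValue, 0x79a0a0, is the parameter's slot; the second, 0x7abad0, is the attribute/buffer slot):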

IValue:0x79a0a0
IValue:0x7abad0
Slot: 0x7abad0
Parameter: 0x79a0a0
Found parameter
Found buffer

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.