Skip to content

Instantly share code, notes, and snippets.

@jamesr66a
Created July 23, 2018 22:48
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jamesr66a/a960cd47be0ae7cf1a554417ddf345fa to your computer and use it in GitHub Desktop.
commit 96d6beb5d300da71c2e5eee0eec9a012480af8f0
Author: James Reed <jamesreed@fb.com>
Date: Mon Jul 23 15:47:43 2018 -0700
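Bugfix for stateful module export

In ModuleDecoder::decode, param_map stored the address of the loop-local `param`, which dangles once each loop iteration ends. It now stores the slot the module itself holds for the registered parameter (`parent_module->parameter_slot(name)`). The export/import test now gives its module a parameter, so this path is exercised.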
diff --git a/test/test_jit.py b/test/test_jit.py
index 9363c0954..39deb0f74 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -3036,6 +3036,7 @@ def func(t):
class M(torch.jit.ScriptModule):
def __init__(self):
super(M, self).__init__(False)
+ self.param = torch.nn.Parameter(torch.rand(2, 2, dtype=torch.float))
@torch.jit.script_method
def foo(self):
@@ -3043,7 +3044,7 @@ def func(t):
@torch.jit.script_method
def forward(self, input):
- return input + torch.ones([2, 2])
+ return input + torch.ones([2, 2]) + self.param
m_orig = M()
m_import = torch.jit.ScriptModule()
@@ -3053,7 +3054,7 @@ def func(t):
for m in [m_orig, m_import]:
input = torch.ones([2, 2], dtype=torch.float)
o = m(input)
- self.assertEqual(o, input + torch.ones([2, 2], dtype=torch.float))
+ self.assertEqual(o, input + torch.ones([2, 2], dtype=torch.float) + m.param)
def test_script_module_file_export(self):
from torch.onnx import OperatorExportTypes, ExportTypes
diff --git a/torch/csrc/jit/import.cpp b/torch/csrc/jit/import.cpp
index 94858a3ca..2510548c1 100644
--- a/torch/csrc/jit/import.cpp
+++ b/torch/csrc/jit/import.cpp
@@ -256,8 +256,8 @@ std::shared_ptr<script::Module> ModuleDecoder::decode(
std::tie(parent_module, name) = parseFullName(root_module, tensor_proto.name());
auto param = buildParameter(tensor_proto);
- param_map[tensor_proto.name()] = &param;
parent_module->register_parameter(name, param, tensor_proto.int64_data(0));
+ param_map[tensor_proto.name()] = parent_module->parameter_slot(name);
}
for (auto &node_proto : graph_proto.node()) {
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment