Created July 23, 2018 20:59
-
-
Save jamesr66a/afc084db32116a05a0a90a5f9c51ccb7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
commit bd1b4dbe99e6040ded6c4ab2a59798536a2d6a45 | |
Author: James Reed <jamesreed@fb.com> | |
Date: Mon Jul 23 13:59:34 2018 -0700 | |
Fix zipfile export | |
diff --git a/test/expect/TestScript.test_script_module_file_export.expect b/test/expect/TestScript.test_script_module_file_export.expect | |
new file mode 100644 | |
index 000000000..1d7ce966b | |
--- /dev/null | |
+++ b/test/expect/TestScript.test_script_module_file_export.expect | |
@@ -0,0 +1 @@ | |
+['__MODEL_PROTO', '$0', '$9', '$6', '$1', '$2', '$3', '$7', '$4', '$5', '$8', '$10'] | |
\ No newline at end of file | |
diff --git a/test/test_jit.py b/test/test_jit.py | |
index 4f9c35553..9363c0954 100644 | |
--- a/test/test_jit.py | |
+++ b/test/test_jit.py | |
@@ -3055,6 +3055,31 @@ def func(t): | |
o = m(input) | |
self.assertEqual(o, input + torch.ones([2, 2], dtype=torch.float)) | |
+ def test_script_module_file_export(self): | |
+ from torch.onnx import OperatorExportTypes, ExportTypes | |
+ class M(torch.jit.ScriptModule): | |
+ def __init__(self): | |
+ super(M, self).__init__(False) | |
+ | |
+ @torch.jit.script_method | |
+ def foo(self): | |
+ return torch.ones([2, 2]) | |
+ | |
+ @torch.jit.script_method | |
+ def forward(self, input): | |
+ return input + torch.ones([2, 2]) | |
+ | |
+ m_orig = M() | |
+ m_import = torch.jit.ScriptModule() | |
+ import io | |
+ f = io.BytesIO() | |
+ torch.onnx._export_module(m_orig, f, OperatorExportTypes.RAW, ExportTypes.ZIP_ARCHIVE) | |
+ f.seek(0) | |
+ import zipfile | |
+ with zipfile.ZipFile(f, 'r', compression=zipfile.ZIP_STORED) as z: | |
+ self.assertExpected(str([file.filename for file in z.infolist()])) | |
+ | |
+ | |
def test_onnx_export_script_module(self): | |
class ModuleToExport(torch.jit.ScriptModule): | |
def __init__(self): | |
diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py | |
index 1807b711a..ee52129ac 100644 | |
--- a/torch/onnx/__init__.py | |
+++ b/torch/onnx/__init__.py | |
@@ -35,6 +35,10 @@ def _export_to_pretty_string(*args, **kwargs): | |
from torch.onnx import utils | |
return utils._export_to_pretty_string(*args, **kwargs) | |
+def _export_module(*args, **kwargs): | |
+ from torch.onnx import utils | |
+ return utils._export_module(*args, **kwargs) | |
+ | |
def _optimize_trace(trace, operator_export_type): | |
from torch.onnx import utils | |
diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py | |
index c5d6f07ba..2d3ee45e0 100644 | |
--- a/torch/onnx/utils.py | |
+++ b/torch/onnx/utils.py | |
@@ -265,7 +265,7 @@ def _export(model, args, f, export_params=True, verbose=False, training=False, | |
return torch_out | |
-def _export_module(module, operator_export_type=OperatorExportTypes.ONNX, export_type=ExportTypes.PROTOBUF_FILE): | |
+def _export_module(module, f, operator_export_type=OperatorExportTypes.ONNX, export_type=ExportTypes.PROTOBUF_FILE): | |
# TODO: Don't allocate a in-memory string for the protobuf | |
from torch.onnx.symbolic import _onnx_opset_version | |
@@ -278,7 +278,7 @@ def _export_module(module, operator_export_type=OperatorExportTypes.ONNX, export | |
else zipfile.ZIP_STORED | |
with zipfile.ZipFile(f, 'w', compression=compression) as z: | |
z.writestr('__MODEL_PROTO', proto) | |
- for k, v in export_map.items(): | |
+ for k, v in storage_map.items(): | |
z.writestr(k, v) | |
elif export_type in [ExportTypes.DIRECTORY, ExportTypes.PROTOBUF_FILE]: | |
raise RuntimeError('Unsupported export type') |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.