Skip to content

Instantly share code, notes, and snippets.

@jamesr66a
Created July 23, 2018 20:59
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jamesr66a/afc084db32116a05a0a90a5f9c51ccb7 to your computer and use it in GitHub Desktop.
commit bd1b4dbe99e6040ded6c4ab2a59798536a2d6a45
Author: James Reed <jamesreed@fb.com>
Date: Mon Jul 23 13:59:34 2018 -0700
Fix zipfile export
diff --git a/test/expect/TestScript.test_script_module_file_export.expect b/test/expect/TestScript.test_script_module_file_export.expect
new file mode 100644
index 000000000..1d7ce966b
--- /dev/null
+++ b/test/expect/TestScript.test_script_module_file_export.expect
@@ -0,0 +1 @@
+['__MODEL_PROTO', '$0', '$9', '$6', '$1', '$2', '$3', '$7', '$4', '$5', '$8', '$10']
\ No newline at end of file
diff --git a/test/test_jit.py b/test/test_jit.py
index 4f9c35553..9363c0954 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -3055,6 +3055,31 @@ def func(t):
o = m(input)
self.assertEqual(o, input + torch.ones([2, 2], dtype=torch.float))
+ def test_script_module_file_export(self):
+ from torch.onnx import OperatorExportTypes, ExportTypes
+ class M(torch.jit.ScriptModule):
+ def __init__(self):
+ super(M, self).__init__(False)
+
+ @torch.jit.script_method
+ def foo(self):
+ return torch.ones([2, 2])
+
+ @torch.jit.script_method
+ def forward(self, input):
+ return input + torch.ones([2, 2])
+
+ m_orig = M()
+ m_import = torch.jit.ScriptModule()
+ import io
+ f = io.BytesIO()
+ torch.onnx._export_module(m_orig, f, OperatorExportTypes.RAW, ExportTypes.ZIP_ARCHIVE)
+ f.seek(0)
+ import zipfile
+ with zipfile.ZipFile(f, 'r', compression=zipfile.ZIP_STORED) as z:
+ self.assertExpected(str([file.filename for file in z.infolist()]))
+
+
def test_onnx_export_script_module(self):
class ModuleToExport(torch.jit.ScriptModule):
def __init__(self):
diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py
index 1807b711a..ee52129ac 100644
--- a/torch/onnx/__init__.py
+++ b/torch/onnx/__init__.py
@@ -35,6 +35,10 @@ def _export_to_pretty_string(*args, **kwargs):
from torch.onnx import utils
return utils._export_to_pretty_string(*args, **kwargs)
+def _export_module(*args, **kwargs):
+ from torch.onnx import utils
+ return utils._export_module(*args, **kwargs)
+
def _optimize_trace(trace, operator_export_type):
from torch.onnx import utils
diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py
index c5d6f07ba..2d3ee45e0 100644
--- a/torch/onnx/utils.py
+++ b/torch/onnx/utils.py
@@ -265,7 +265,7 @@ def _export(model, args, f, export_params=True, verbose=False, training=False,
return torch_out
-def _export_module(module, operator_export_type=OperatorExportTypes.ONNX, export_type=ExportTypes.PROTOBUF_FILE):
+def _export_module(module, f, operator_export_type=OperatorExportTypes.ONNX, export_type=ExportTypes.PROTOBUF_FILE):
# TODO: Don't allocate a in-memory string for the protobuf
from torch.onnx.symbolic import _onnx_opset_version
@@ -278,7 +278,7 @@ def _export_module(module, operator_export_type=OperatorExportTypes.ONNX, export
else zipfile.ZIP_STORED
with zipfile.ZipFile(f, 'w', compression=compression) as z:
z.writestr('__MODEL_PROTO', proto)
- for k, v in export_map.items():
+ for k, v in storage_map.items():
z.writestr(k, v)
elif export_type in [ExportTypes.DIRECTORY, ExportTypes.PROTOBUF_FILE]:
raise RuntimeError('Unsupported export type')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment