Created April 2, 2021 07:11
-
-
Save take-cheeze/4a456ed04aaf7ca0bb24ff7ad978d815 to your computer and use it in GitHub Desktop.
onnx.diff
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
modified onnx/checker.cc | |
@@ -795,13 +795,14 @@ void check_model(const ModelProto& model, CheckerContext& ctx) { | |
} | |
void check_model(const std::string& model_path) { | |
- ModelProto model; | |
+ google::protobuf::Arena arena; | |
+ ModelProto* model = google::protobuf::Arena::CreateMessage<ModelProto>(&arena); | |
std::fstream model_stream(model_path, std::ios::in | std::ios::binary); | |
if (!model_stream.good()) { | |
fail_check("Unable to open model file:", model_path, ". Please check if it is a valid file."); | |
} | |
std::string data{std::istreambuf_iterator<char>{model_stream}, std::istreambuf_iterator<char>{}}; | |
- if (!ParseProtoFromBytes(&model, data.c_str(), data.size())) { | |
+ if (!ParseProtoFromBytes(model, data.c_str(), data.size())) { | |
fail_check( | |
"Unable to parse model from file:", model_path, ". Please check if it is a valid protobuf file of model."); | |
} | |
@@ -813,7 +814,7 @@ void check_model(const std::string& model_path) { | |
model_dir = model_path.substr(0, pos + 1); | |
} | |
ctx.set_model_dir(model_dir); | |
- check_model(model, ctx); | |
+ check_model(*model, ctx); | |
} | |
void check_model(const ModelProto& model) { | |
modified onnx/cpp2py_export.cc | |
@@ -304,13 +304,14 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) { | |
shape_inference.def("infer_shapes", [](const py::bytes& bytes, bool check_type, bool strict_mode) { | |
- ModelProto proto{}; | |
- ParseProtoFromPyBytes(&proto, bytes); | |
- shape_inference::InferShapes(proto, check_type, | |
+ google::protobuf::Arena arena; | |
+ ModelProto* proto = google::protobuf::Arena::CreateMessage<ModelProto>(&arena); | |
+ ParseProtoFromPyBytes(proto, bytes); | |
+ shape_inference::InferShapes(*proto, check_type, | |
OpSchemaRegistry::Instance(), | |
strict_mode == true ? 1 : 0); | |
std::string out; | |
- proto.SerializeToString(&out); | |
+ proto->SerializeToString(&out); | |
return py::bytes(out); | |
}, "bytes"_a, "check_type"_a = false, "strict_mode"_a = false); | |
modified onnx/shape_inference/implementation.cc | |
@@ -399,23 +399,24 @@ void InferShapes( | |
const std::string& save_path, | |
const ISchemaRegistry* schema_registry, | |
const int error_mode) { | |
- ModelProto model; | |
+ google::protobuf::Arena arena; | |
+ ModelProto* model = google::protobuf::Arena::CreateMessage<ModelProto>(&arena); | |
std::fstream model_stream(model_path, std::ios::in | std::ios::binary); | |
if (!model_stream.good()) { | |
fail_check("Unable to open model file:", model_path, ". Please check if it is a valid file."); | |
} | |
std::string data{std::istreambuf_iterator<char>{model_stream}, std::istreambuf_iterator<char>{}}; | |
- if (!ParseProtoFromBytes(&model, data.c_str(), data.size())) { | |
+ if (!ParseProtoFromBytes(model, data.c_str(), data.size())) { | |
fail_check( | |
"Unable to parse model from file:", model_path, ". Please check if it is a valid protobuf file of model."); | |
} | |
- InferShapes(model, check_type, schema_registry, error_mode); | |
+ InferShapes(*model, check_type, schema_registry, error_mode); | |
// Save the inferred model to the original model path | |
// Use SerializeToString instead of SerializeToOstream due to LITE_PROTO | |
std::fstream output(save_path, std::ios::out | std::ios::trunc | std::ios::binary); | |
std::string model_string; | |
ONNX_TRY { | |
- model.SerializeToString(&model_string); | |
+ model->SerializeToString(&model_string); | |
output << model_string; | |
} | |
ONNX_CATCH(...) { |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.