
@take-cheeze
Created April 2, 2021 07:11
onnx.diff
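
This patch switches three ONNX entry points — model checking in checker.cc, the Python shape-inference binding in cpp2py_export.cc, and file-based shape inference in implementation.cc — from a stack-allocated ModelProto to one allocated on a google::protobuf::Arena. Parsing a large model builds a deep tree of sub-messages; with an arena, all of that memory is released in one step when the arena goes out of scope rather than through a per-message destructor walk.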
modified onnx/checker.cc
@@ -795,13 +795,14 @@ void check_model(const ModelProto& model, CheckerContext& ctx) {
}
void check_model(const std::string& model_path) {
- ModelProto model;
+ google::protobuf::Arena arena;
+ ModelProto* model = google::protobuf::Arena::CreateMessage<ModelProto>(&arena);
std::fstream model_stream(model_path, std::ios::in | std::ios::binary);
if (!model_stream.good()) {
fail_check("Unable to open model file:", model_path, ". Please check if it is a valid file.");
}
std::string data{std::istreambuf_iterator<char>{model_stream}, std::istreambuf_iterator<char>{}};
- if (!ParseProtoFromBytes(&model, data.c_str(), data.size())) {
+ if (!ParseProtoFromBytes(model, data.c_str(), data.size())) {
fail_check(
"Unable to parse model from file:", model_path, ". Please check if it is a valid protobuf file of model.");
}
@@ -813,7 +814,7 @@ void check_model(const std::string& model_path) {
model_dir = model_path.substr(0, pos + 1);
}
ctx.set_model_dir(model_dir);
- check_model(model, ctx);
+ check_model(*model, ctx);
}
void check_model(const ModelProto& model) {
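
The replacement above (and in the two files below) is the standard protobuf arena idiom. As a minimal self-contained sketch — MyMessage stands in for any generated message type and is a hypothetical name, not taken from this patch:

#include <string>
#include <google/protobuf/arena.h>
#include "my_message.pb.h"  // hypothetical generated header for MyMessage

void parse_on_arena(const std::string& data) {
  google::protobuf::Arena arena;
  // CreateMessage allocates the message on the arena; the arena owns it,
  // so it must never be deleted manually.
  MyMessage* msg =
      google::protobuf::Arena::CreateMessage<MyMessage>(&arena);
  msg->ParseFromArray(data.data(), static_cast<int>(data.size()));
  // ... use *msg while `arena` is in scope ...
}  // arena destructor frees msg and every sub-message in one shot

This is also why the patched call sites pass `model` rather than `&model` to ParseProtoFromBytes and dereference with `*model` afterwards: the variable is now a pointer into the arena.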
modified onnx/cpp2py_export.cc
@@ -304,13 +304,14 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) {
shape_inference.def("infer_shapes", [](const py::bytes& bytes, bool check_type, bool strict_mode) {
- ModelProto proto{};
- ParseProtoFromPyBytes(&proto, bytes);
- shape_inference::InferShapes(proto, check_type,
+ google::protobuf::Arena arena;
+ ModelProto* proto = google::protobuf::Arena::CreateMessage<ModelProto>(&arena);
+ ParseProtoFromPyBytes(proto, bytes);
+ shape_inference::InferShapes(*proto, check_type,
OpSchemaRegistry::Instance(),
strict_mode == true ? 1 : 0);
std::string out;
- proto.SerializeToString(&out);
+ proto->SerializeToString(&out);
return py::bytes(out);
}, "bytes"_a, "check_type"_a = false, "strict_mode"_a = false);
modified onnx/shape_inference/implementation.cc
@@ -399,23 +399,24 @@ void InferShapes(
const std::string& save_path,
const ISchemaRegistry* schema_registry,
const int error_mode) {
- ModelProto model;
+ google::protobuf::Arena arena;
+ ModelProto* model = google::protobuf::Arena::CreateMessage<ModelProto>(&arena);
std::fstream model_stream(model_path, std::ios::in | std::ios::binary);
if (!model_stream.good()) {
fail_check("Unable to open model file:", model_path, ". Please check if it is a valid file.");
}
std::string data{std::istreambuf_iterator<char>{model_stream}, std::istreambuf_iterator<char>{}};
- if (!ParseProtoFromBytes(&model, data.c_str(), data.size())) {
+ if (!ParseProtoFromBytes(model, data.c_str(), data.size())) {
fail_check(
"Unable to parse model from file:", model_path, ". Please check if it is a valid protobuf file of model.");
}
- InferShapes(model, check_type, schema_registry, error_mode);
+ InferShapes(*model, check_type, schema_registry, error_mode);
// Save the inferred model to the original model path
// Use SerializeToString instead of SerializeToOstream due to LITE_PROTO
std::fstream output(save_path, std::ios::out | std::ios::trunc | std::ios::binary);
std::string model_string;
ONNX_TRY {
- model.SerializeToString(&model_string);
+ model->SerializeToString(&model_string);
output << model_string;
}
ONNX_CATCH(...) {
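
Both file-based entry points, check_model and InferShapes, now share the same read-then-parse-on-arena shape. Factored out as a sketch — ParseFileOnArena is a hypothetical helper, not part of the patch:

#include <fstream>
#include <iterator>
#include <string>
#include <google/protobuf/arena.h>

// Hypothetical helper mirroring the pattern in check_model/InferShapes above.
// Returns an arena-owned message; `arena` must outlive the returned pointer.
template <typename Msg>
Msg* ParseFileOnArena(const std::string& path, google::protobuf::Arena* arena) {
  std::fstream stream(path, std::ios::in | std::ios::binary);
  if (!stream.good()) return nullptr;
  std::string data{std::istreambuf_iterator<char>{stream},
                   std::istreambuf_iterator<char>{}};
  Msg* msg = google::protobuf::Arena::CreateMessage<Msg>(arena);
  if (!msg->ParseFromArray(data.data(), static_cast<int>(data.size())))
    return nullptr;
  return msg;
}

Keeping the arena in the caller's stack frame, as both patched functions do, makes the parsed message's lifetime explicit at the call site.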