@bwasti
Created April 5, 2019 06:12
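
A patch against TVM adding a zero-copy input path to GraphRuntime: DLPack DLManagedTensor handles travel through the PackedFunc calling convention via a new kManagedArrayHandle type code, a TVMPODValue_ conversion operator, a TVMArgsSetter overload, and a set_input_zero_copy packed function.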
diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h
index 1a1a8da6..2e15b179 100644
--- a/include/tvm/runtime/c_runtime_api.h
+++ b/include/tvm/runtime/c_runtime_api.h
@@ -85,6 +85,7 @@ typedef enum {
   kStr = 11U,
   kBytes = 12U,
   kNDArrayContainer = 13U,
+  kManagedArrayHandle = 14U,
   // Extension codes for other frameworks to integrate TVM PackedFunc.
   // To make sure each framework's id do not conflict, use first and
   // last sections to mark ranges.
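
For reference, the handle tagged by the new kManagedArrayHandle code is DLPack's standard managed-tensor struct as defined in dlpack/dlpack.h: a plain DLTensor plus an opaque producer context and a deleter the consumer calls when it releases the memory.

typedef struct DLManagedTensor {
  DLTensor dl_tensor;   // the tensor view: data pointer, shape, dtype, strides
  void* manager_ctx;    // opaque context owned by the producing framework
  void (*deleter)(struct DLManagedTensor* self);  // invoked by the consumer on release
} DLManagedTensor;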
diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h
index 9e4dbd0a..d5989eb7 100644
--- a/include/tvm/runtime/packed_func.h
+++ b/include/tvm/runtime/packed_func.h
@@ -447,6 +447,17 @@ class TVMPODValue_ {
       return nullptr;
     }
   }
+  operator DLManagedTensor*() const {
+    if (type_code_ == kManagedArrayHandle) {
+      return static_cast<DLManagedTensor*>(value_.v_handle);
+    } else {
+      if (type_code_ == kNull) return nullptr;
+      LOG(FATAL) << "Expected "
+                 << "DLManagedTensor* but got "
+                 << TypeCode2Str(type_code_);
+      return nullptr;
+    }
+  }
   operator NDArray() const {
     if (type_code_ == kNull) return NDArray();
     TVM_CHECK_TYPE_CODE(type_code_, kNDArrayContainer);
@@ -881,6 +892,7 @@ inline const char* TypeCode2Str(int type_code) {
     case kNull: return "NULL";
     case kNodeHandle: return "NodeHandle";
     case kArrayHandle: return "ArrayHandle";
+    case kManagedArrayHandle: return "ManagedArrayHandle";
     case kTVMType: return "TVMType";
     case kTVMContext: return "TVMContext";
     case kFuncHandle: return "FunctionHandle";
@@ -1052,6 +1064,10 @@ class TVMArgsSetter {
     values_[i].v_handle = value;
     type_codes_[i] = kArrayHandle;
   }
+  void operator()(size_t i, DLManagedTensor* value) const {
+    values_[i].v_handle = value;
+    type_codes_[i] = kManagedArrayHandle;
+  }
   void operator()(size_t i, TVMContext value) const {
     values_[i].v_ctx = value;
     type_codes_[i] = kTVMContext;
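
Together, the conversion operator and the TVMArgsSetter overload above let a DLManagedTensor* round-trip through a PackedFunc without a copy. A minimal sketch of the mechanism; the consume/demo names are illustrative, not part of the patch:

#include <dlpack/dlpack.h>
#include <tvm/runtime/packed_func.h>

using tvm::runtime::PackedFunc;
using tvm::runtime::TVMArgs;
using tvm::runtime::TVMRetValue;

// The callee recovers the pointer via the new conversion operator.
PackedFunc consume([](TVMArgs args, TVMRetValue* rv) {
  DLManagedTensor* t = args[0];  // kManagedArrayHandle -> DLManagedTensor*
  // ... read t->dl_tensor in place, no copy ...
});

void demo(DLManagedTensor* t) {
  consume(t);  // TVMArgsSetter tags the argument as kManagedArrayHandle
}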
diff --git a/src/runtime/graph/graph_runtime.cc b/src/runtime/graph/graph_runtime.cc
index 13146d79..39b8c8c6 100644
--- a/src/runtime/graph/graph_runtime.cc
+++ b/src/runtime/graph/graph_runtime.cc
@@ -77,6 +77,13 @@ void GraphRuntime::SetInput(int index, DLTensor* data_in) {
   uint32_t eid = this->entry_id(input_nodes_[index], 0);
   data_entry_[eid].CopyFrom(data_in);
 }
+
+void GraphRuntime::SetInputZC(int index, DLManagedTensor* data_in) {
+  CHECK_LT(static_cast<size_t>(index), input_nodes_.size());
+  uint32_t eid = this->entry_id(input_nodes_[index], 0);
+  data_entry_[eid] = NDArray::FromDLPack(data_in);
+}
+
 /*!
  * \brief Get the number of outputs
  *
@@ -322,6 +329,15 @@ PackedFunc GraphRuntime::GetFunction(
           this->SetInput(args[0], args[1]);
         }
       });
+  } else if (name == "set_input_zero_copy") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+        if (args[0].type_code() == kStr) {
+          int in_idx = this->GetInputIndex(args[0]);
+          if (in_idx >= 0) this->SetInputZC(in_idx, args[1]);
+        } else {
+          this->SetInputZC(args[0], args[1]);
+        }
+      });
   } else if (name == "get_output") {
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
         if (args.num_args == 2) {
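
With the packed function registered, a caller can bind a graph input without copying. A hedged usage sketch, assuming a loaded runtime module gmod and a DLManagedTensor* dlm wrapping an existing buffer (the input name "data" is illustrative):

tvm::runtime::PackedFunc set_input_zc =
    gmod.GetFunction("set_input_zero_copy");
set_input_zc("data", dlm);  // by name; an int index works as well
// NDArray::FromDLPack takes ownership of dlm and invokes dlm->deleter
// once the runtime drops the array, so the caller must not free it.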
diff --git a/src/runtime/graph/graph_runtime.h b/src/runtime/graph/graph_runtime.h
index cdd236a0..db6b402f 100644
--- a/src/runtime/graph/graph_runtime.h
+++ b/src/runtime/graph/graph_runtime.h
@@ -92,6 +92,7 @@ class GraphRuntime : public ModuleNode {
    * \param data_in The input data.
    */
   void SetInput(int index, DLTensor* data_in);
+  void SetInputZC(int index, DLManagedTensor* data_in);
   /*!
    * \brief Get the number of outputs
    *
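
For completeness, one way to produce the DLManagedTensor itself, wrapping a caller-owned CPU buffer. This sketch is not part of the patch; the dtype, device, and ownership choices are illustrative assumptions:

#include <dlpack/dlpack.h>

// Wrap an existing float32 buffer. Only the wrapper is heap-allocated;
// the underlying buffer stays owned by the caller.
DLManagedTensor* WrapBuffer(float* data, int64_t* shape, int ndim) {
  DLManagedTensor* dlm = new DLManagedTensor();
  dlm->dl_tensor.data = data;
  dlm->dl_tensor.ctx = DLContext{kDLCPU, 0};
  dlm->dl_tensor.ndim = ndim;
  dlm->dl_tensor.dtype = DLDataType{kDLFloat, 32, 1};
  dlm->dl_tensor.shape = shape;
  dlm->dl_tensor.strides = nullptr;  // nullptr means compact row-major
  dlm->dl_tensor.byte_offset = 0;
  dlm->manager_ctx = nullptr;
  // Called by the consumer (here, TVM's NDArray) on release; frees the
  // wrapper but deliberately not the caller-owned buffer.
  dlm->deleter = [](DLManagedTensor* self) { delete self; };
  return dlm;
}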