Created
September 29, 2019 14:39
-
-
Save knsong/302efa12ea6647893b9ca41fae5eecc6 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// Builds the inference pipeline: imports the graph stored in "./models.pb",
/// looks up the three input operations and the output operation, creates a
/// session configured with `session_config`, and runs one warm-up inference.
///
/// Failures are reported through the log only. After a failure in graph import,
/// operation lookup or session creation, `tf_graph_` and `tf_session_` are both
/// left as nullptr so that the destructor never touches a handle that was
/// already released here. A failed warm-up run keeps the session alive.
Sample::Sample() {
    // Start from a well-defined state so every early return below leaves the
    // object safe to destroy.
    tf_graph_ = nullptr;
    tf_session_ = nullptr;

    // Wall-clock difference between two gettimeofday() samples, in seconds.
    const auto elapsed_secs = [](const struct timeval& from,
                                 const struct timeval& to) -> float {
        return (to.tv_sec - from.tv_sec) +
               (static_cast<float>(to.tv_usec) - from.tv_usec) / 1000000;
    };
    // Releases the graph and clears the member so it is not deleted twice.
    const auto release_graph = [this]() {
        TF_DeleteGraph(tf_graph_);
        tf_graph_ = nullptr;
    };

    LOG(INFO) << "Initializing model";
    struct timeval begint, endt;
    gettimeofday(&begint, NULL);

    // NOTE(review): read_file is assumed to return nullptr when the file
    // cannot be read - confirm against its definition.
    TF_Buffer* graph_def = read_file("./models.pb");
    if (graph_def == nullptr) {
        LOG(INFO) << "ERROR: Unable to read graph file ./models.pb";
        return;
    }

    // Import graph_def into graph
    tf_graph_ = TF_NewGraph();
    TF_Status* status = TF_NewStatus();
    TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
    TF_GraphImportGraphDef(tf_graph_, graph_def, opts, status);
    TF_DeleteImportGraphDefOptions(opts);
    // The serialized bytes are no longer needed once the import call returned.
    TF_DeleteBuffer(graph_def);
    if (TF_GetCode(status) != TF_OK) {
        LOG(INFO) << "ERROR: Unable to import graph " << TF_Message(status);
        TF_DeleteStatus(status);
        release_graph();
        return;
    }
    LOG(INFO) << "Successfully imported graph";
    gettimeofday(&endt, NULL);
    float interval = elapsed_secs(begint, endt);
    LOG(INFO) << "Imported graph time is "<< interval << " secs";

    // Resolve the graph endpoints used by every TF_SessionRun call.
    const char* const input_names[3] = {"input0", "input1", "input2"};
    inputs_.resize(3);
    bool endpoints_found = true;
    for (int i = 0; i < 3; ++i) {
        inputs_[i].oper = TF_GraphOperationByName(tf_graph_, input_names[i]);
        inputs_[i].index = 0;
        if (inputs_[i].oper == nullptr) {
            LOG(INFO) << "ERROR: Unable to find operation " << input_names[i];
            endpoints_found = false;
        }
    }
    output_.oper = TF_GraphOperationByName(tf_graph_, "output");
    output_.index = 0;
    if (output_.oper == nullptr) {
        LOG(INFO) << "ERROR: Unable to find operation output";
        endpoints_found = false;
    }
    if (!endpoints_found) {
        // Running the session with a missing endpoint would dereference null.
        TF_DeleteStatus(status);
        release_graph();
        return;
    }

    LOG(INFO) << "Creating session...";
    gettimeofday(&begint, NULL);
    TF_SessionOptions* options = TF_NewSessionOptions();
    TF_SetConfig(options, session_config, strlen(session_config), status);
    if (TF_GetCode(status) != TF_OK) {
        LOG(INFO) << "ERROR: Unable to read session config: " << TF_Message(status);
        TF_DeleteSessionOptions(options);
        TF_DeleteStatus(status);
        release_graph();
        return;
    }
    tf_session_ = TF_NewSession(tf_graph_, options, status);
    TF_DeleteSessionOptions(options);
    if (TF_GetCode(status) != TF_OK) {
        LOG(INFO) << "ERROR: Unable to create session: " << TF_Message(status);
        // No session exists on failure, so there is nothing to close or delete.
        tf_session_ = nullptr;
        TF_DeleteStatus(status);
        release_graph();
        return;
    }
    gettimeofday(&endt, NULL);
    interval = elapsed_secs(begint, endt);
    LOG(INFO) << "Create sessioin time is "<< interval << " secs";

    // warm-up session
    // Tensors are created with a null data pointer; what create_tensor does
    // with it is defined elsewhere (presumably allocate-only - TODO confirm).
    TF_Tensor* input_values[3];
    std::int64_t local_feature_shape[3] = {1, 1000, 64};
    std::int64_t region_feature_shape[4] = {1, 125, 125, 256};
    std::int64_t kpts_shape[3] = {1, 1000, 2};
    input_values[0] = create_tensor(TF_FLOAT, local_feature_shape, 3,
                                    nullptr, 1000 * 64 * sizeof(float));
    input_values[1] = create_tensor(TF_FLOAT, region_feature_shape, 4,
                                    nullptr, 125 * 125 * 256 * sizeof(float));
    input_values[2] = create_tensor(TF_FLOAT, kpts_shape, 3,
                                    nullptr, 1000 * 2 * sizeof(float));
    TF_Tensor* output_value = nullptr;

    LOG(INFO) << "tf session warmming up...";
    gettimeofday(&begint, NULL);
    TF_SessionRun(tf_session_,
                  nullptr,                          // Run options.
                  &inputs_[0], &input_values[0], 3, // Input tensors, input tensor values, number of inputs.
                  &output_, &output_value, 1,       // Output tensors, output tensor values, number of outputs.
                  nullptr, 0,                       // Target operations, number of targets.
                  nullptr,                          // Run metadata.
                  status                            // Output status.
                  );
    gettimeofday(&endt, NULL);
    interval = elapsed_secs(begint, endt);
    LOG(INFO) << "tf session run time is "<< interval << " secs";

    for (int i = 0; i < 3; ++i) {
        delete_tensor(input_values[i]);
    }
    const bool warmed_up = (TF_GetCode(status) == TF_OK);
    TF_DeleteStatus(status);
    delete_tensor(output_value);
    if (!warmed_up) {
        LOG(INFO) << "tf session warm up failed";
        return;
    }
    LOG(INFO) << "tf session warm up success";
}
/// Closes and deletes the TensorFlow session (if one exists) and then releases
/// the graph (if one exists), logging the outcome of the close and delete calls.
Sample::~Sample()
{
    // A null session is skipped entirely: TF_CloseSession must not be handed a
    // null handle, and there is nothing to report in that case.
    if (tf_session_) {
        TF_Status* status = TF_NewStatus();
        TF_CloseSession(tf_session_, status);
        if (TF_GetCode(status) != TF_OK)
            LOG(INFO) << "ERROR: Unable to close session: " << TF_Message(status);
        else
            LOG(INFO) << "close session successfully";
        TF_DeleteSession(tf_session_, status);
        if (TF_GetCode(status) != TF_OK)
            LOG(INFO) << "ERROR: Unable to delete session: " << TF_Message(status);
        else
            LOG(INFO) << "delete session successfully";
        TF_DeleteStatus(status);
        tf_session_ = nullptr;
    }
    // Release the graph after the session that was created from it.
    if (tf_graph_) {
        TF_DeleteGraph(tf_graph_);
        tf_graph_ = nullptr;
    }
}
/// Runs one inference on the session using three float input tensors built
/// from the prepared data and inspects the resulting output tensor.
///
/// @param cv_img      Source image; consumed by the elided preparation code.
/// @param keypoints   Flat list of keypoint values; `keypoints.size() / 2` is
///                    used as the keypoint count.
/// @param descriptors Output argument. NOTE(review): the visible code never
///                    writes to it - confirm against the elided part.
///
/// Returns early (after logging) when no session is available or when the
/// session run fails. All TensorFlow objects created here are released on
/// every return path.
void Sample::compute(const cv::Mat& cv_img,
                     const std::vector<float>& keypoints,
                     UMat& descriptors)
{
    // Without a session there is nothing to run.
    if (tf_session_ == nullptr) {
        LOG(INFO) << "ERROR: tf session is not available";
        return;
    }
    // code prepare input data
    ...
    TF_Status* status = TF_NewStatus();
    size_t kpts_num = keypoints.size() / 2;
    TF_Tensor* input_values[3];
    // NOTE(review): the shapes are fixed at 1000 keypoints while the byte
    // counts below are derived from kpts_num - confirm the caller always
    // supplies exactly 1000 keypoints.
    std::int64_t local_feature_shape[3] = {1, 1000, 64};
    std::int64_t region_feature_shape[4] = {1, 125, 125, 256};
    std::int64_t kpts_shape[3] = {1, 1000, 2};
    gettimeofday(&begint, NULL);
    input_values[0] = create_tensor(TF_FLOAT, local_feature_shape, 3,
                                    data_ptr1, kpts_num * 64 * sizeof(float));
    input_values[1] = create_tensor(TF_FLOAT, region_feature_shape, 4,
                                    data_ptr2, output_shape.count() * sizeof(float));
    input_values[2] = create_tensor(TF_FLOAT, kpts_shape, 3,
                                    data_ptr3, keypoints.size() * sizeof(float));
    gettimeofday(&endt, NULL);
    interval = (endt.tv_sec - begint.tv_sec) + ((float)endt.tv_usec - begint.tv_usec) / 1000000;
    LOG(INFO) << "prepare input time is "<< interval << " secs";

    TF_Tensor* output_value = nullptr;
    LOG(INFO) << "tf session running...";
    gettimeofday(&begint, NULL);
    TF_SessionRun(tf_session_,
                  nullptr,                          // Run options.
                  &inputs_[0], &input_values[0], 3, // Input tensors, input tensor values, number of inputs.
                  &output_, &output_value, 1,       // Output tensors, output tensor values, number of outputs.
                  nullptr, 0,                       // Target operations, number of targets.
                  nullptr,                          // Run metadata.
                  status                            // Output status.
                  );
    gettimeofday(&endt, NULL);
    interval = (endt.tv_sec - begint.tv_sec) + ((float)endt.tv_usec - begint.tv_usec) / 1000000;
    LOG(INFO) << "tf session run time is "<< interval << " secs";
    LOG(INFO) << "tf session finish...";

    // The input tensors are no longer needed once the run has returned.
    delete_tensor(input_values[0]);
    delete_tensor(input_values[1]);
    delete_tensor(input_values[2]);

    if (TF_GetCode(status) != TF_OK) {
        LOG(INFO) << "ERROR: tf session run failed: " << TF_Message(status);
        TF_DeleteStatus(status);
        delete_tensor(output_value);
        return;
    }
    // The status object is only needed for the check above.
    TF_DeleteStatus(status);

    CHECK_EQ(TF_TensorType(output_value), TF_FLOAT) << "output type error";
    LOG(INFO) << "result num dims:" << TF_NumDims(output_value);
    // Release the output tensor on the success path as well.
    delete_tensor(output_value);
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment