Skip to content

Instantly share code, notes, and snippets.

@9nut
Created March 26, 2021 21:04
Show Gist options
  • Save 9nut/f95bb4cbe9c223e9f73a9e06429f71ac to your computer and use it in GitHub Desktop.
Save 9nut/f95bb4cbe9c223e9f73a9e06429f71ac to your computer and use it in GitHub Desktop.
package main
import (
"log"
tf "github.com/galeone/tensorflow/tensorflow/go"
tg "github.com/galeone/tfgo"
"github.com/galeone/tfgo/image"
)
/*
$ saved_model_cli show --all --dir ~/TFModels/centernet_hourglass_512x512_kpts_1/ 2>/dev/null
MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
signature_def['__saved_model_init_op']:
The given SavedModel SignatureDef contains the following input(s):
The given SavedModel SignatureDef contains the following output(s):
outputs['__saved_model_init_op'] tensor_info:
dtype: DT_INVALID
shape: unknown_rank
name: NoOp
Method name is:
signature_def['serving_default']:
The given SavedModel SignatureDef contains the following input(s):
inputs['input_tensor'] tensor_info:
dtype: DT_UINT8
shape: (1, -1, -1, 3)
name: serving_default_input_tensor:0
The given SavedModel SignatureDef contains the following output(s):
outputs['detection_boxes'] tensor_info:
dtype: DT_FLOAT
shape: (1, 100, 4)
name: StatefulPartitionedCall:0
outputs['detection_classes'] tensor_info:
dtype: DT_FLOAT
shape: (1, 100)
name: StatefulPartitionedCall:1
outputs['detection_keypoint_scores'] tensor_info:
dtype: DT_FLOAT
shape: (1, 100, 17)
name: StatefulPartitionedCall:2
outputs['detection_keypoints'] tensor_info:
dtype: DT_FLOAT
shape: (1, 100, 17, 2)
name: StatefulPartitionedCall:3
outputs['detection_scores'] tensor_info:
dtype: DT_FLOAT
shape: (1, 100)
name: StatefulPartitionedCall:4
outputs['num_detections'] tensor_info:
dtype: DT_FLOAT
shape: (1)
name: StatefulPartitionedCall:5
Method name is: tensorflow/serving/predict
Defined Functions:
Function Name: '__call__'
Option #1
Callable with:
Argument #1
input_tensor: TensorSpec(shape=(1, None, None, 3), dtype=tf.uint8, name='input_tensor')
$
*/
// inferCenterNet loads the SavedModel in ./centernet_hourglass_512x512_kpts_1
// (tag "serve"), preprocesses D1.jpg to a 512x512 RGB batch of one image, runs
// the model and logs the "detection_boxes" output (StatefulPartitionedCall:0).
//
// The serving signature (see the saved_model_cli dump above) takes a DT_UINT8
// tensor of shape (1, H, W, 3). TensorFlow's ResizeArea op always produces
// float32, so the preprocessed batch is rescaled to [0, 255] in-graph and then
// converted to uint8 on the Go side before it is fed to the model. Feeding the
// float tensor directly would be rejected by the session as a dtype mismatch.
//
// Preprocessing or tensor-construction failures terminate the program via
// log.Fatal*.
func inferCenterNet() {
	model := tg.LoadModel("./centernet_hourglass_512x512_kpts_1", []string{"serve"}, nil)

	root := tg.NewRoot()
	rgbimg := image.Read(root, "D1.jpg", 3)
	// ResizeArea emits float32; Scale pins the values to [0, 255] so they can
	// be converted to the uint8 pixel encoding the model signature requires.
	rgbimg = rgbimg.Clone().ResizeArea(image.Size{Height: 512, Width: 512}).Scale(0, 255)
	imgin := tg.Batchify(root, []tf.Output{rgbimg.Value()})
	input := tg.Exec(root, []tf.Output{imgin}, nil, &tf.SessionOptions{})

	var tensor *tf.Tensor
	switch px := input[0].Value().(type) {
	case [][][][]uint8:
		// Already in the model's input encoding; feed it as-is.
		tensor = input[0]
	case [][][][]float32:
		t, err := tf.NewTensor(float32ToUint8Pixels(px))
		if err != nil {
			log.Fatalln(err)
		}
		tensor = t
	default:
		log.Fatalf("preprocessed image has type %T; want a rank-4 uint8 or float32 batch", px)
	}

	results := model.Exec(
		[]tf.Output{
			model.Op("StatefulPartitionedCall", 0), // detection_boxes
			// model.Op("StatefulPartitionedCall", 1), // detection_classes
			// model.Op("StatefulPartitionedCall", 4), // detection_scores
		},
		map[tf.Output]*tf.Tensor{model.Op("serving_default_input_tensor", 0): tensor},
	)
	log.Println(results[0].Value())
}

// float32ToUint8Pixels converts float pixel values, expected to lie in
// [0, 255], to uint8 by rounding to the nearest integer. Values below 0 and NaN
// map to 0; values at or above 255 map to 255. The nesting of src is preserved.
func float32ToUint8Pixels(src [][][][]float32) [][][][]uint8 {
	dst := make([][][][]uint8, len(src))
	for b, img := range src {
		dst[b] = make([][][]uint8, len(img))
		for y, row := range img {
			dst[b][y] = make([][]uint8, len(row))
			for x, px := range row {
				out := make([]uint8, len(px))
				for c, v := range px {
					switch {
					case !(v >= 0): // negative, or NaN (all comparisons with NaN are false)
						out[c] = 0
					case v >= 255:
						out[c] = 255
					default:
						out[c] = uint8(v + 0.5) // round half up; v+0.5 < 255.5 here
					}
				}
				dst[b][y][x] = out
			}
		}
	}
	return dst
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment