Skip to content

Instantly share code, notes, and snippets.

@jackersson
Created October 15, 2019 14:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jackersson/cea7ae76bcb980bba1ab420f0f696be3 to your computer and use it in GitHub Desktop.
Save jackersson/cea7ae76bcb980bba1ab420f0f696be3 to your computer and use it in GitHub Desktop.
import os
import tensorflow as tf
from spyglass.model import model_loader
from sales_zone.tf_models import ObjectDetector
def get_model_flops(model_cls, model_config: dict) -> int:
    """Return the total float-operation count of a model's TensorFlow graph.

    The model is loaded through ``model_loader`` and its graph is handed to
    the TensorFlow 1.x profiler restricted to float operations.

    Args:
        model_cls: Model class passed to ``model_loader``.
        model_config: Configuration forwarded to ``model_loader`` as
            ``model_config``.

    Returns:
        The ``total_float_ops`` value reported by ``tf.profiler.profile``.
    """
    with model_loader(model_cls, model_config=model_config) as model:
        # NOTE(review): relies on the private ``_session`` attribute of the
        # loaded model; presumably a ``tf.Session`` -- confirm with the
        # model implementation.
        graph = model._session.graph
        profile = tf.profiler.profile(
            graph,
            run_meta=tf.RunMetadata(),
            cmd='op',
            options=tf.profiler.ProfileOptionBuilder.float_operation(),
        )
        # Read the counter while the loader context is still open.
        return profile.total_float_ops
def format_model_flops(model_config: dict, flops: int) -> str:
    """Build a one-line report of a model's float-operation count.

    Args:
        model_config: Model configuration; its ``'weights'`` entry is a path
            whose final component is used as the model name.
        flops: Float-operation count to report.

    Returns:
        A string of the form ``"<model_name> : N flops"``.

    Raises:
        KeyError: If ``model_config`` has no ``'weights'`` entry.
    """
    # Only the file name is shown, not the directories leading to it.
    model_name = os.path.basename(model_config['weights'])
    return f"<{model_name}> : {flops} flops"
# Report the float-operation count of each detector, one line per model.
_WEIGHTS_PATHS = (
    "data/models/person_detector_from_above/resnet-fpn-768x512_top-view-obi_fp32.pb",
    "data/models/person_detector_from_side/mobilenet-fpn-768x512_ms-coco_tf32.pb",
)
for _weights_path in _WEIGHTS_PATHS:
    model_config = {"weights": _weights_path}
    _flops = get_model_flops(ObjectDetector, model_config)
    print(format_model_flops(model_config, flops=_flops))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment