Skip to content

Instantly share code, notes, and snippets.

@vmarkovtsev
Last active January 29, 2021 11:30
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save vmarkovtsev/72b26e9ad1f212bc05aa4cf2fd4812e0 to your computer and use it in GitHub Desktop.
Save vmarkovtsev/72b26e9ad1f212bc05aa4cf2fd4812e0 to your computer and use it in GitHub Desktop.
# Remember to remove the old model!
# rm *.tflite
import json
import urllib.request
def generate_edgetpu_model(log: logging.Logger, images_shape: Tuple[int, ...], func: callable,
                           name: str) -> None:
    """Convert a tf.function to an Edge TPU model.

    The pipeline is:

    1. Quantize ``func`` to a TensorFlow Lite flatbuffer and write it to
       ``<name>.tflite`` in the current working directory.
    2. Download ``schema.fbs`` into the current working directory if it is
       not already there.
    3. Dump the flatbuffer to JSON with ``flatc``, patch the JSON (see
       ``_patch_model_for_edgetpu``), and rebuild the binary flatbuffer.
    4. Run ``edgetpu_compiler`` on the patched ``<name>.tflite``.
    5. Delete the intermediate JSON file and the ``<name>_edgetpu.log`` file.

    :param log: Logger that receives the progress messages.
    :param images_shape: Shape of the single model input; used to build the \
                         representative dataset for quantization.
    :param func: Object exposing ``get_concrete_function()`` (a ``tf.function``).
    :param name: Base name (without extension) of the generated model files.

    ``echo_run`` is defined outside this block; it is called with the command \
    name followed by its arguments (presumably it prints and runs the \
    command - verify against its definition).
    """
    def gen_input_samples():
        # Two calibration samples covering both ends of the 0..255 input range.
        yield [np.zeros(images_shape, np.float32)]
        yield [np.ones(images_shape, np.float32) * 255]

    log.info("Generating the quantized TensorFlow Lite model")
    converter = tf.lite.TFLiteConverter.from_concrete_functions([func.get_concrete_function()])
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = gen_input_samples
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    converter.inference_input_type = tf.uint8
    converter.inference_output_type = tf.uint8
    tflite_model = converter.convert()
    fn = "%s.tflite" % name
    with open(fn, "wb") as fout:
        fout.write(tflite_model)
    log.info("Wrote it to %s", fn)

    # flatc needs the TFLite flatbuffers schema to translate between binary and JSON.
    if not Path("schema.fbs").exists():
        log.info("schema.fbs was not found, downloading")
        urllib.request.urlretrieve(
            "https://github.com/tensorflow/tensorflow/raw/master/tensorflow/lite/schema/schema.fbs",
            "schema.fbs")
        log.info("Downloaded schema.fbs")

    log.info("Converting the model from binary flatbuffers to JSON")
    echo_run("flatc", "-t", "--strict-json", "--defaults-json", "schema.fbs", "--", fn)

    log.info("Patching the model in JSON")
    fn_json = str(Path(fn).with_suffix(".json"))
    with open(fn_json) as fin:
        model = json.load(fin)
    model = _patch_model_for_edgetpu(model)
    with open(fn_json, "w") as fout:
        json.dump(model, fout, indent=4)

    log.info("Generating the binary flatbuffers model from JSON")
    echo_run("flatc", "-b", "schema.fbs", fn_json)
    log.info("Compiling the Edge TPU model")
    echo_run("edgetpu_compiler", "-s", fn)

    # Remove the intermediate artifacts.
    Path(fn_json).unlink()
    Path(fn).with_name(Path(fn).stem + "_edgetpu.log").unlink()


def _patch_model_for_edgetpu(model: dict) -> dict:
    """Strip a TFLite model (parsed flatc JSON) down to its DEPTHWISE_CONV_2D operators.

    The dict is modified in place and also returned.

    :param model: Model JSON as produced by ``flatc -t``; must contain \
                  ``operator_codes`` and at least one entry in ``subgraphs``.
    :return: The same ``model`` object, patched.
    :raises AssertionError: If the model has no DEPTHWISE_CONV_2D operator code.
    """
    # Erase all opcodes except DEPTHWISE_CONV_2D.
    conv_opcode = -1
    new_opcodes = []
    for i, c in enumerate(model["operator_codes"]):
        if c["builtin_code"] == "DEPTHWISE_CONV_2D":
            new_opcodes.append(c)
            conv_opcode = i
    assert conv_opcode >= 0, "the model does not contain a DEPTHWISE_CONV_2D operator code"
    # `opcode_index` of an operator indexes into `operator_codes`, so after the
    # table shrinks the surviving operators must point at the new position of
    # the retained entry instead of the old one.
    new_conv_opcode = len(new_opcodes) - 1
    model["operator_codes"] = new_opcodes

    # Fix the tensor dtypes which are int8 instead of uint8.
    # Also remove the multi-channel quantization which is not supported on Edge TPU.
    graph = model["subgraphs"][0]
    new_tensors = []
    index_map = {}  # old tensor index -> index in new_tensors
    for i, t in enumerate(graph["tensors"]):
        if t["type"] == "FLOAT32":
            # Float tensors are dropped from the patched graph.
            continue
        if t["type"] == "INT8":
            t["type"] = "UINT8"
            t["quantization"]["zero_point"][0] = 0
        # NOTE(review): the per-channel collapse below is applied to every kept
        # tensor, not only to the INT8 ones, so that other quantized tensors
        # (e.g. INT32) lose their per-channel parameters too - confirm that this
        # matches the intended nesting.
        t["quantization"]["scale"] = [t["quantization"]["scale"][0]]
        t["quantization"]["zero_point"] = [t["quantization"]["zero_point"][0]]
        t["quantization"]["quantized_dimension"] = 0
        index_map[i] = len(new_tensors)
        new_tensors.append(t)
    graph["tensors"] = new_tensors

    # Update the tensor indexes in the ops.
    new_ops = []
    for op in graph["operators"]:
        if op["opcode_index"] != conv_opcode:
            continue
        op["opcode_index"] = new_conv_opcode
        op["outputs"] = [index_map[i] for i in op["outputs"]]
        op["inputs"] = [index_map[i] for i in op["inputs"]]
        new_ops.append(op)
    graph["operators"] = new_ops

    # Update the global input and output tensor indexes: the graph now starts
    # and ends at the first retained operator.
    graph["inputs"][0] = new_ops[0]["inputs"][0]
    graph["outputs"][0] = new_ops[0]["outputs"][0]
    model["subgraphs"][0] = graph
    return model
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment