Last active
January 29, 2021 11:30
-
-
Save vmarkovtsev/72b26e9ad1f212bc05aa4cf2fd4812e0 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Remember to remove the old model!
# rm *.tflite
import json
import urllib.request
def generate_edgetpu_model(log: logging.Logger, images_shape: Tuple[int, ...], func: callable, name: str):
    """Convert a tf.function to a quantized, Edge TPU-compatible TFLite model.

    Pipeline:
      1. Quantize ``func`` to a uint8 TFLite flatbuffer with TFLiteConverter,
         using all-zeros / all-255 images as the representative dataset.
      2. Round-trip the flatbuffer through JSON (via ``flatc``) to delete every
         opcode except DEPTHWISE_CONV_2D, flip int8 tensors to uint8, and
         collapse per-channel quantization to per-tensor — patches required
         for the Edge TPU compiler to accept the model.
      3. Compile the patched model with ``edgetpu_compiler`` and remove the
         intermediate JSON and the compiler log.

    Side effects: writes "<name>.tflite" (and the compiled Edge TPU variant)
    to the current directory; downloads "schema.fbs" if it is not present.

    :param log: Logger receiving progress messages.
    :param images_shape: Shape of the float32 image tensor fed to ``func``.
    :param func: A tf.function exposing ``get_concrete_function()``.
    :param name: Base name for the output files ("<name>.tflite", ...).
    :raises ValueError: If the converted model contains no DEPTHWISE_CONV_2D op.
    """
    def gen_input_samples():
        # Representative inputs covering the full uint8 dynamic range [0, 255].
        yield [np.zeros(images_shape, np.float32)]
        yield [np.ones(images_shape, np.float32) * 255]

    log.info("Generating the quantized TensorFlow Lite model")
    converter = tf.lite.TFLiteConverter.from_concrete_functions([func.get_concrete_function()])
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = gen_input_samples
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    converter.inference_input_type = tf.uint8
    converter.inference_output_type = tf.uint8
    tflite_model = converter.convert()
    fn = "%s.tflite" % name
    with open(fn, "wb") as fout:
        fout.write(tflite_model)
    log.info("Wrote it to %s", fn)
    if not Path("schema.fbs").exists():
        # flatc needs the TFLite flatbuffers schema to (de)serialize the model.
        log.info("schema.fbs was not found, downloading")
        urllib.request.urlretrieve(
            "https://github.com/tensorflow/tensorflow/raw/master/tensorflow/lite/schema/schema.fbs",
            "schema.fbs")
        log.info("Downloaded schema.fbs")
    log.info("Converting the model from binary flatbuffers to JSON")
    echo_run("flatc", "-t", "--strict-json", "--defaults-json", "schema.fbs", "--", fn)
    log.info("Patching the model in JSON")
    fn_json = str(Path(fn).with_suffix(".json"))
    with open(fn_json) as fin:
        model = json.load(fin)
    # Let the hell begin. Erase all opcodes except DEPTHWISE_CONV_2D.
    conv_opcode = -1
    new_opcodes = []
    for i, c in enumerate(model["operator_codes"]):
        if c["builtin_code"] == "DEPTHWISE_CONV_2D":
            new_opcodes.append(c)
            conv_opcode = i
    if conv_opcode < 0:
        # Explicit raise instead of `assert`: asserts are stripped under -O.
        raise ValueError("the converted model contains no DEPTHWISE_CONV_2D opcode")
    model["operator_codes"] = new_opcodes
    # Fix the tensor dtypes which are int8 instead of uint8.
    # Also remove the multi-channel quantization which is not supported on Edge TPU.
    graph = model["subgraphs"][0]
    new_tensors = []
    index_map = {}  # old tensor index -> new index after dropping float tensors
    for i, t in enumerate(graph["tensors"]):
        if t["type"] == "FLOAT32":
            # Float tensors belong to the ops being deleted; drop them.
            continue
        if t["type"] == "INT8":
            t["type"] = "UINT8"
            t["quantization"]["zero_point"][0] = 0
        # NOTE(review): per the flattened original, the per-tensor collapse below
        # applies to every kept tensor (incl. int32 biases), not only INT8 ones
        # — confirm against the original gist's indentation.
        t["quantization"]["scale"] = [t["quantization"]["scale"][0]]
        t["quantization"]["zero_point"] = [t["quantization"]["zero_point"][0]]
        t["quantization"]["quantized_dimension"] = 0
        index_map[i] = len(new_tensors)
        new_tensors.append(t)
    graph["tensors"] = new_tensors
    # Update the tensor indexes in the ops.
    new_ops = []
    for op in graph["operators"]:
        if op["opcode_index"] != conv_opcode:
            continue
        op["outputs"] = [index_map[i] for i in op["outputs"]]
        op["inputs"] = [index_map[i] for i in op["inputs"]]
        new_ops.append(op)
    graph["operators"] = new_ops
    # Update the global input and output tensor indexes.
    graph["inputs"][0] = new_ops[0]["inputs"][0]
    graph["outputs"][0] = new_ops[0]["outputs"][0]
    model["subgraphs"][0] = graph
    with open(fn_json, "w") as fout:
        json.dump(model, fout, indent=4)
    log.info("Generating the binary flatbuffers model from JSON")
    echo_run("flatc", "-b", "schema.fbs", fn_json)
    log.info("Compiling the Edge TPU model")
    echo_run("edgetpu_compiler", "-s", fn)
    # Remove the intermediate JSON and the Edge TPU compiler log.
    Path(fn_json).unlink()
    Path(fn).with_name(Path(fn).stem + "_edgetpu.log").unlink()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment