Skip to content

Instantly share code, notes, and snippets.

@jonpsy
Last active November 30, 2021 06:50
Show Gist options
  • Save jonpsy/77737780f90a1ebf52507d3622ea6ed7 to your computer and use it in GitHub Desktop.
Save jonpsy/77737780f90a1ebf52507d3622ea6ed7 to your computer and use it in GitHub Desktop.
from tflite_support import flatbuffers
from tflite_support import metadata as _metadata
from tflite_support import metadata_schema_py_generated as _metadata_fb
# Creates model info: the top-level, human-readable description of the model
# (name, description, version, author and license).
model_meta = _metadata_fb.ModelMetadataT()
model_meta.name = "Enhanced Super Resolution GAN for super resolution."
# Adjacent string literals are concatenated with no separator, so the first
# fragment must end with a space to keep the two sentences apart.
model_meta.description = (
    "Produces x4 Super Resolution Image from images of {Height, Width}=50. "
    "Works best on Bicubically downsampled images.")
model_meta.version = "v1"
model_meta.author = "TensorFlow"
model_meta.license = (
    "Apache License. Version 2.0 "
    "http://www.apache.org/licenses/LICENSE-2.0.")
def _image_tensor_metadata(name, description):
    """Return a TensorMetadataT for an RGB image tensor.

    The tensor is tagged with image content properties (RGB color space),
    a single normalization process unit with mean 0 and std 1, and value
    statistics of min 0 / max 255.

    Args:
        name: Text stored in the tensor metadata's ``name`` field.
        description: Text stored in the ``description`` field.

    Returns:
        A freshly built ``_metadata_fb.TensorMetadataT``; no objects are
        shared between calls.
    """
    tensor_meta = _metadata_fb.TensorMetadataT()
    tensor_meta.name = name
    tensor_meta.description = description

    # Content: an image in the RGB color space.
    image_props = _metadata_fb.ImagePropertiesT()
    image_props.colorSpace = _metadata_fb.ColorSpaceType.RGB
    content = _metadata_fb.ContentT()
    content.contentProperties = image_props
    content.contentPropertiesType = (
        _metadata_fb.ContentProperties.ImageProperties)
    tensor_meta.content = content

    # Process unit: normalization with mean 0 and std 1.
    norm_options = _metadata_fb.NormalizationOptionsT()
    norm_options.mean = [0.]
    norm_options.std = [1.]
    normalization = _metadata_fb.ProcessUnitT()
    normalization.optionsType = (
        _metadata_fb.ProcessUnitOptions.NormalizationOptions)
    normalization.options = norm_options
    tensor_meta.processUnits = [normalization]

    # Stats: declared value range of the tensor.
    value_range = _metadata_fb.StatsT()
    value_range.max = [255]
    value_range.min = [0]
    tensor_meta.stats = value_range

    return tensor_meta


# Creates input info.
input_meta = _image_tensor_metadata(
    "Input image.",
    "Input image to be transformed. The expected image should be of {0} x {1}, with "
    "three channels (red, blue, and green) per pixel. Each value in the "
    "tensor is a single byte between 0 and 255.".format(50, 50))
# Creates output info.
output_meta = _image_tensor_metadata("Image.", "Resolution enhanced image.")
# Creates subgraph info: a single subgraph with one input and one output
# tensor, attached to the model-level metadata.
subgraph = _metadata_fb.SubGraphMetadataT()
subgraph.inputTensorMetadata = [input_meta]
subgraph.outputTensorMetadata = [output_meta]
model_meta.subgraphMetadata = [subgraph]

# Packs the metadata into a flatbuffer, finished with the metadata file
# identifier, and takes the resulting buffer.
builder = flatbuffers.Builder(0)
packed_root = model_meta.Pack(builder)
builder.Finish(
    packed_root,
    _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
metadata_buf = builder.Output()

# Hands the buffer to a populator created for the model file and runs it.
# NOTE(review): populate() presumably updates the .tflite file in place;
# confirm against the tflite_support documentation.
model_file = "esrgan_with_input_and_output_metadata.tflite"
populator = _metadata.MetadataPopulator.with_model_file(model_file)
populator.load_metadata_buffer(metadata_buf)
populator.populate()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment