@gordinmitya
Created October 22, 2020 16:27
def env_info():
    import sys
    print('python version:', sys.version)
    import torch
    print('PyTorch version:', torch.__version__)
    import coremltools
    print('coremltools version:', coremltools.__version__)

# make the simplest possible super-resolution model
def export():
    print('='*24, 'export', '='*24)
    import torch
    from torch import nn
    from torch.nn import functional as F

    class SuperNet(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return F.interpolate(x, scale_factor=(2, 2))

    model = SuperNet().eval()
    # the image size used for tracing doesn't matter here, it can be anything
    traced = torch.jit.trace(model, torch.rand((1, 3, 1, 1)))
    torch.jit.save(traced, 'traced.pt')

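# Not part of the original gist: a minimal sanity check for the traced model.
# Call it manually after export() if needed; F.interpolate with scale_factor=(2, 2)
# should simply double both spatial dimensions of any input.
def check_traced():
    import torch
    m = torch.jit.load('traced.pt')
    print(m(torch.rand(1, 3, 64, 48)).shape)  # expected: torch.Size([1, 3, 128, 96])
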
INPUT_NODE = 'input'
OUTPUT_NODE = '30'  # placeholder; overwritten in convert() with the real output name from the spec
MODEL_DEFAULT_INPUT_SIZE = (256, 256)
IMAGE_SIZE = (513, 517)  # an arbitrary size different from the default, to exercise the flexible input shape

def convert():
    global OUTPUT_NODE
    print('='*24, 'convert', '='*24)
    import torch
    import coremltools as ct
    from coremltools.models.neural_network import flexible_shape_utils
    import coremltools.proto.FeatureTypes_pb2 as ft
    from coremltools.models.neural_network.builder import NeuralNetworkBuilder

    traced = torch.jit.load('traced.pt')
    # The approach from the documentation just throws:
    #   TypeError: unsupported operand type(s) for +: 'NoneType' and 'float'
    # input_shape = ct.Shape(shape=[1, 3, ct.RangeDim(256, 1024, 256, 'w'), ct.RangeDim(256, 1024, 256, 'h')])
    # OK, let's convert with a fixed shape and add the flexibility afterwards
    input_shape = ct.Shape(shape=[1, 3, 256, 256])
    converted = ct.convert(
        traced,
        inputs=[ct.TensorType(name=INPUT_NODE, shape=input_shape)])

    # it doesn't matter how the spec is obtained, the result is the same = crash
    # converted.save('coreml.mlmodel')
    # spec = ct.utils.load_spec('coreml.mlmodel')
    spec = converted.get_spec()
    OUTPUT_NODE = spec.description.output[0].name

    # make the input flexible
    # size_range = flexible_shape_utils.NeuralNetworkImageSizeRange()
    # size_range.add_height_range((128, 512))  # tried with upper_bound=-1, doesn't work
    # size_range.add_width_range((128, 512))
    # flexible_shape_utils.update_image_size_range(spec, INPUT_NODE, size_range=size_range)
    size_l = 128
    size_u = 2048
    flexible_shape_utils.set_multiarray_ndshape_range(
        spec, INPUT_NODE, [1, 3, size_l, size_l], [1, 3, size_u, size_u])

    # The only working way to get flexible sizes is to enumerate them all,
    # but that doesn't make much sense for a super-resolution task in our case
    # flexible_shape_utils.add_enumerated_image_sizes(spec, INPUT_NODE, sizes=[
    #     flexible_shape_utils.NeuralNetworkImageSize(x, x) for x in [128, 256, 512]
    # ])
    # make the output an image and make it flexible as well
    feature = flexible_shape_utils._get_feature(spec, OUTPUT_NODE)
    feature.type.imageType.colorSpace = ft.ImageFeatureType.RGB
    # the output size can be specified explicitly, but it doesn't change anything
    feature.type.imageType.width = 512
    feature.type.imageType.height = 512
    size_range = flexible_shape_utils.NeuralNetworkImageSizeRange()
    size_range.add_height_range((256, -1))  # -1 is intended to mean "no upper bound"
    size_range.add_width_range((256, -1))
    flexible_shape_utils.update_image_size_range(spec, feature_name=OUTPUT_NODE, size_range=size_range)

    # print the model
    builder = NeuralNetworkBuilder(spec=spec)
    print('inputs')
    builder.inspect_input_features()
    print('model')
    builder.inspect_layers(verbose=True)
    print('outputs')
    builder.inspect_output_features()

    updated = ct.models.MLModel(spec)
    updated.save('coreml.mlmodel')

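# Not part of the original gist: a small helper to inspect the flexibility
# information actually written into the saved spec. The proto field names
# (multiArrayType.shapeRange, imageType.imageSizeRange) come from FeatureTypes.proto.
def inspect_flexibility():
    import coremltools as ct
    spec = ct.utils.load_spec('coreml.mlmodel')
    print('input shape range:', spec.description.input[0].type.multiArrayType.shapeRange)
    print('output size range:', spec.description.output[0].type.imageType.imageSizeRange)
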
def test():
    print('='*24, 'test', '='*24)
    import coremltools as ct
    from PIL import Image
    import numpy as np

    print('input image size=', IMAGE_SIZE)
    arr = np.zeros([IMAGE_SIZE[0], IMAGE_SIZE[1], 3], dtype=np.uint8)
    img = Image.fromarray(arr)
    data = np.array(img, dtype=np.float32)  # np.float is deprecated in NumPy; use an explicit dtype
    data = np.expand_dims(data, 0)
    data = np.transpose(data, (0, 3, 1, 2))
    print(data.shape)

    model = ct.models.MLModel('coreml.mlmodel')
    res = model.predict({INPUT_NODE: data})[OUTPUT_NODE]
    print('result image size=', res.size)

env_info()
export()
convert()
test()
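
# For reference: with the 2x upscaling model, the 513x517 test input should come
# back as a PIL image of size (1034, 1026) (width, height), assuming prediction
# with the flexible shapes succeeds at all, which is the issue this gist probes.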