This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import coremltools | |
model = coremltools.converters.keras.convert( | |
model.h5', | |
input_names=['image'], | |
output_names=['output'], | |
class_labels=["0", "1"], | |
image_input_names='image', | |
red_bias = -1, | |
green_bias = -1, | |
blue_bias = -1, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
let model = model() | |
let inputImage = input.resize(to: CGSize(width: 224, height: 224)) | |
guard let features = inputImage?.pixelBuffer() else { | |
complete(nil, "Error while creating the pixel buffer.") | |
return | |
} | |
guard let result = try? model.prediction(image: features) else { | |
complete(nil, "Error while performing the prediction.") | |
return | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Install the Python packages used by the PyTorch -> ONNX -> TensorFlow conversion snippets.
pip3 install torch torchvision tensorflow onnx onnx-tf
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.optim as optim | |
from torchvision import datasets, transforms | |
class Net(nn.Module): | |
def __init__(self): | |
super(Net, self).__init__() | |
self.conv1 = nn.Conv2d(1, 20, 5, 1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch.autograd import Variable | |
# Rebuild the network and restore the trained weights from disk.
trained_model = Net()
trained_model.load_state_dict(torch.load('output/mnist_cnn.pt'))
# Put the module in inference mode before tracing so that any train-only
# layers (e.g. dropout, if the network has them) do not affect the export.
trained_model.eval()
# Shape of input to the model. The exporter only needs an example tensor of
# this shape; a plain tensor works, the deprecated Variable wrapper is not needed.
dummy_input = torch.randn(1, 1, 28, 28)
# Export the trained model to ONNX
torch.onnx.export(trained_model, dummy_input, "output/mnist.onnx")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import onnx | |
from onnx_tf.backend import prepare | |
# Load the serialized ONNX model from disk.
model = onnx.load('output/mnist.onnx')
# Hand the ONNX model to the onnx-tf backend; the returned object is what
# inference is run against.
tf_rep = prepare(model)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from IPython.display import display | |
from PIL import Image | |
# Load the sample image, resize it to 28x28 and convert it to single-channel
# greyscale ('L' mode).
img = Image.open('images/five.png').resize((28, 28)).convert('L')
display(img)
# Prepend two axes so the array has shape (1, 1, 28, 28) before running it.
# NOTE(review): pixel values are passed as raw float32 without scaling or
# normalisation -- confirm this matches the preprocessing used during training.
output = tf_rep.run(np.asarray(img, dtype=np.float32)[np.newaxis, np.newaxis, :, :])
# argmax over the returned scores gives the predicted class index.
print('The digit is classified as ', np.argmax(output))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Create an inference session, load the exported model file into it, then
// run the model on the prepared input tensor.
const session = new onnx.InferenceSession();
await session.loadModel("model.onnx");
// run() receives the model inputs as an array.
// NOTE(review): the resolved value is presumably a collection of output
// tensors rather than a single class label -- confirm against the ONNX.js API.
const prediction = await session.run([inputTensor]);
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Instantiate the network with Bottleneck blocks in a [3, 4, 6, 3] layout.
model = FixResNet50(models.resnet.Bottleneck, [3, 4, 6, 3])
from torch.utils.model_zoo import load_url as load_state_dict_from_url
# Download the weights from model_url and load them into the model.
state_dict = load_state_dict_from_url(model_url,
                                      progress=True)
model.load_state_dict(state_dict)
# Put the module in inference mode before tracing so train-only behaviour
# (e.g. batch-norm statistics updates) does not affect the export.
model.eval()
# Import retained for any other code that still references Variable; the
# export below no longer needs the deprecated wrapper.
from torch.autograd import Variable
# Example input of shape (1, 3, 224, 224) used to trace the model.
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, dummy_input, "model.onnx")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Decode a base64-encoded image and write it to disk.
 *
 * Assumes `writeFile` in scope is a promise-returning `fs.writeFile`
 * (e.g. from `fs/promises` or `util.promisify`).
 *
 * @param {string} b64string - Base64-encoded file contents.
 * @param {string} [imgUrl="/tmp/out.jpg"] - Destination file path. The default
 *   is a fixed location, so concurrent callers that rely on it overwrite each
 *   other's output; pass a unique path to avoid that.
 * @returns {Promise<string>} Resolves with the path the image was written to.
 */
async function base64ToImg(b64string, imgUrl = "/tmp/out.jpg") {
  // Write the decoded bytes directly: converting them to a 'binary' string and
  // back again produces the same file but costs an extra copy.
  const imageBytes = Buffer.from(b64string, 'base64');
  await writeFile(imgUrl, imageBytes);
  return imgUrl;
}