Skip to content

Instantly share code, notes, and snippets.

@ruihe774
Last active March 1, 2024 18:48
Show Gist options
  • Save ruihe774/62fed775b04cb18e59b1c0fbbdc095c6 to your computer and use it in GitHub Desktop.
Convert guesslang model to onnx
# %%
import tensorflow as tf
# Hide all GPUs: the conversion only needs CPU, and this sidesteps any
# GPU/CUDA initialization on the machine running the notebook.
tf.config.set_visible_devices([], 'GPU')
# %%
# Load the original guesslang SavedModel from ./model and keep its
# "predict" serving signature as the reference implementation.
model = tf.saved_model.load("model")
predict = model.signatures["predict"]
# %%
class HyperParameter:
    """Model hyper parameters (mirrors guesslang's training configuration)."""

    # Training batch size (only used for the single bootstrap train step).
    BATCH_SIZE = 100
    # Maximum number of byte tokens kept per input text.
    NB_TOKENS = 10000
    # Size of the hash-bucket vocabulary for the 'content' feature.
    VOCABULARY_SIZE = 5000
    # Embedding width: sqrt of the vocabulary, but never below 10.
    EMBEDDING_SIZE = max(10, int(VOCABULARY_SIZE ** 0.5))
    # Hidden-layer widths and dropout of the DNN part of the classifier.
    DNN_HIDDEN_UNITS = [512, 32]
    DNN_DROPOUT = 0.5
    # n-gram length used by the text preprocessing.
    N_GRAM = 2
# Class-id -> file-extension label mapping, in guesslang's training order.
# Two transcription defects fixed: 'matla' -> 'matlab', and the duplicated
# 'r' -> 'rb' (guesslang distinguishes R from Ruby; each label must be
# unique for the class ids to be meaningful).
labels = ['asm', 'bat', 'c', 'cs', 'cpp', 'clj', 'cmake', 'cbl',
          'coffee', 'css', 'csv', 'dart', 'dm', 'dockerfile', 'ex',
          'erl', 'f90', 'go', 'groovy', 'hs', 'html', 'ini',
          'java', 'js', 'json', 'jl', 'kt', 'lisp', 'lua',
          'makefile', 'md', 'matlab', 'mm', 'ml', 'pas', 'pm',
          'php', 'ps1', 'prolog', 'py', 'r', 'rb', 'rs', 'scala',
          'sh', 'sql', 'swift', 'tex', 'toml', 'ts', 'v', 'vba',
          'xml', 'yaml']
# Rebuild the exact feature-column / estimator architecture guesslang used,
# so the exported graph is compatible with the original model's checkpoint.
# NOTE(review): tf.feature_column and tf.estimator are deprecated TF1-era
# APIs; this script assumes a TF 2.x version that still ships them.
categorical_column = tf.feature_column.categorical_column_with_hash_bucket(
    key='content',
    hash_bucket_size=HyperParameter.VOCABULARY_SIZE,
)
dense_column = tf.feature_column.embedding_column(
    categorical_column=categorical_column,
    dimension=HyperParameter.EMBEDDING_SIZE,
)
# Wide (linear on hashed n-grams) + deep (embedding -> DNN) classifier.
estimator = tf.estimator.DNNLinearCombinedClassifier(
    linear_feature_columns=[categorical_column],
    dnn_feature_columns=[dense_column],
    dnn_hidden_units=HyperParameter.DNN_HIDDEN_UNITS,
    dnn_dropout=HyperParameter.DNN_DROPOUT,
    n_classes=len(labels),
)
# %%
def preprocess_texts(data):
    """Turn a batch of strings into space-joined byte-2-gram strings.

    Mirrors guesslang's preprocessing: each text is split into individual
    bytes, padded/truncated to NB_TOKENS + 1, and every pair of adjacent
    bytes is joined with a space to form a 2-gram token.
    """
    chars = tf.strings.bytes_split(data)
    # Ragged -> dense, one extra column so NB_TOKENS overlapping pairs exist.
    chars = chars.to_tensor(shape=(tf.size(data), HyperParameter.NB_TOKENS + 1))
    first = tf.slice(chars, [0, 0], [-1, HyperParameter.NB_TOKENS])
    second = tf.slice(chars, [0, 1], [-1, HyperParameter.NB_TOKENS])
    pairs = tf.stack([first, second], -1)
    grams = tf.strings.reduce_join(pairs, 2, separator=' ')
    # Two padding bytes join into a lone ' '; map those back to ''.
    grams = tf.where(tf.equal(grams, ' '), '', grams)
    return grams
# %%
# %%
# Two sample programs used later to check the conversion: this script
# itself (Python) and a Go snippet (a "24 game" solver).
py_code = open(__file__).read()
go_code = "package main\n\nimport (\n\t\"fmt\"\n\t\"math/rand\"\n\t\"time\"\n)\n\nconst (\n\top_num = iota\n\top_add\n\top_sub\n\top_mul\n\top_div\n)\n\ntype frac struct {\n\tnum, denom int\n}\n\n// Expression: can either be a single number, or a result of binary\n// operation from left and right node\ntype Expr struct {\n\top int\n\tleft, right *Expr\n\tvalue frac\n}\n\nvar n_cards = 4\nvar goal = 24\nvar digit_range = 9\n\nfunc (x *Expr) String() string {\n\tif x.op == op_num {\n\t\treturn fmt.Sprintf(\"%d\", x.value.num)\n\t}\n\n\tvar bl1, br1, bl2, br2, opstr string\n\tswitch {\n\tcase x.left.op == op_num:\n\tcase x.left.op >= x.op:\n\tcase x.left.op == op_add && x.op == op_sub:\n\t\tbl1, br1 = \"\", \"\"\n\tdefault:\n\t\tbl1, br1 = \"(\", \")\"\n\t}\n\n\tif x.right.op == op_num || x.op < x.right.op {\n\t\tbl2, br2 = \"\", \"\"\n\t} else {\n\t\tbl2, br2 = \"(\", \")\"\n\t}\n\n\tswitch {\n\tcase x.op == op_add:\n\t\topstr = \" + \"\n\tcase x.op == op_sub:\n\t\topstr = \" - \"\n\tcase x.op == op_mul:\n\t\topstr = \" * \"\n\tcase x.op == op_div:\n\t\topstr = \" / \"\n\t}\n\n\treturn bl1 + x.left.String() + br1 + opstr +\n\t\tbl2 + x.right.String() + br2\n}\n\nfunc expr_eval(x *Expr) (f frac) {\n\tif x.op == op_num {\n\t\treturn x.value\n\t}\n\n\tl, r := expr_eval(x.left), expr_eval(x.right)\n\n\tswitch x.op {\n\tcase op_add:\n\t\tf.num = l.num*r.denom + l.denom*r.num\n\t\tf.denom = l.denom * r.denom\n\t\treturn\n\n\tcase op_sub:\n\t\tf.num = l.num*r.denom - l.denom*r.num\n\t\tf.denom = l.denom * r.denom\n\t\treturn\n\n\tcase op_mul:\n\t\tf.num = l.num * r.num\n\t\tf.denom = l.denom * r.denom\n\t\treturn\n\n\tcase op_div:\n\t\tf.num = l.num * r.denom\n\t\tf.denom = l.denom * r.num\n\t\treturn\n\t}\n\treturn\n}\n\nfunc solve(ex_in []*Expr) bool {\n\t// only one expression left, meaning all numbers are arranged into\n\t// a binary tree, so evaluate and see if we get 24\n\tif len(ex_in) == 1 {\n\t\tf := expr_eval(ex_in[0])\n\t\tif f.denom != 0 && f.num == f.denom*goal {\n\t\t\tfmt.Println(ex_in[0].String())\n\t\t\treturn true\n\t\t}\n\t\treturn false\n\t}\n\n\tvar node Expr\n\tex := make([]*Expr, len(ex_in)-1)\n\n\t// try to combine a pair of expressions into one, thus reduce\n\t// the list length by 1, and recurse down\n\tfor i := range ex {\n\t\tcopy(ex[i:len(ex)], ex_in[i+1:len(ex_in)])\n\n\t\tex[i] = &node\n\t\tfor j := i + 1; j < len(ex_in); j++ {\n\t\t\tnode.left = ex_in[i]\n\t\t\tnode.right = ex_in[j]\n\n\t\t\t// try all 4 operators\n\t\t\tfor o := op_add; o <= op_div; o++ {\n\t\t\t\tnode.op = o\n\t\t\t\tif solve(ex) {\n\t\t\t\t\treturn true\n\t\t\t\t}\n\t\t\t}\n\n\t\t\t// also - and / are not commutative, so swap arguments\n\t\t\tnode.left = ex_in[j]\n\t\t\tnode.right = ex_in[i]\n\n\t\t\tnode.op = op_sub\n\t\t\tif solve(ex) {\n\t\t\t\treturn true\n\t\t\t}\n\n\t\t\tnode.op = op_div\n\t\t\tif solve(ex) {\n\t\t\t\treturn true\n\t\t\t}\n\n\t\t\tif j < len(ex) {\n\t\t\t\tex[j] = ex_in[j]\n\t\t\t}\n\t\t}\n\t\tex[i] = ex_in[i]\n\t}\n\treturn false\n}\n\nfunc main() {\n\tcards := make([]*Expr, n_cards)\n\trand.Seed(time.Now().Unix())\n\n\tfor k := 0; k < 10; k++ {\n\t\tfor i := 0; i < n_cards; i++ {\n\t\t\tcards[i] = &Expr{op_num, nil, nil,\n\t\t\t\tfrac{rand.Intn(digit_range-1) + 1, 1}}\n\t\t\tfmt.Printf(\" %d\", cards[i].value.num)\n\t\t}\n\t\tfmt.Print(\": \")\n\t\tif !solve(cards) {\n\t\t\tfmt.Println(\"No solution\")\n\t\t}\n\t}\n}\n"
def input_fn():
    """One-example batch (the Go snippet) with a dummy label of 1."""
    return {"content": preprocess_texts(tf.constant([go_code]))}, tf.constant([1])
# NOTE(review): a single train step on one sample presumably just forces the
# estimator to build its graph and write a checkpoint; the real weights are
# copied back from the original model after export (see the variables copy
# later in this script).
estimator.train(input_fn, steps=1)
# %%
def serving_input_receiver_fn() -> tf.estimator.export.ServingInputReceiver:
    """Serving receiver: raw strings in, byte-2-gram features out."""
    raw_content = tf.compat.v1.placeholder(tf.string, [None])
    return tf.estimator.export.ServingInputReceiver(
        receiver_tensors={'content': raw_content},
        features={'content': preprocess_texts(raw_content)},
    )
# %%
# %%
import shutil
# Drop any previous export, then export the estimator's serving graph.
shutil.rmtree("new", ignore_errors=True)
path = estimator.export_saved_model("new", serving_input_receiver_fn).decode()
# Swap the freshly-initialized variables for the original trained ones.
shutil.rmtree(path + "/variables")
shutil.copytree("model/variables", path + "/variables")
# %%
# %%
# Reload the re-exported model (now holding the original trained weights)
# and grab its "predict" signature.
new_model = tf.saved_model.load(path)
new_predict = new_model.signatures["predict"]
new_predict
# %%
# Reference predictions from the ORIGINAL model on both sample snippets.
stdans = predict(tf.constant([py_code, go_code]))
# %%
@tf.function
def f(content):
    """Classify one code string; return its class-probability vector."""
    batched = tf.expand_dims(content, 0)
    probabilities = new_predict(batched)["probabilities"]
    return probabilities[0]
# %%
# %%
# Sanity check: the re-exported model's top class must match the original
# model's class id for both sample programs.
assert (tf.argmax(f(tf.constant(py_code))) == stdans["class_ids"][0][0]).numpy()
# %%
assert (tf.argmax(f(tf.constant(go_code))) == stdans["class_ids"][1][0]).numpy()
# %%
# %%
import tf2onnx.convert
from tf2onnx.utils import make_opsetid
# %%
# Convert the scalar-string tf.function to ONNX.  The string ops land in
# the "ai.onnx.contrib" custom domain (provided at runtime by
# onnxruntime-extensions), hence the extra opset declaration.
tf2onnx.convert.from_function(
    f,
    input_signature=[tf.TensorSpec((), tf.string)],
    output_path="model.onnx",
    opset=18,
    extra_opset=[make_opsetid("ai.onnx.contrib", 1)],
)
# %%
# %%
import onnxruntime as ort
# Creating a session both validates model.onnx and, via
# optimized_model_filepath, writes the graph-optimized copy to disk.
so = ort.SessionOptions()
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
so.optimized_model_filepath = "optimized.onnx"
# Native library implementing the ai.onnx.contrib string ops.
# NOTE(review): hard-coded macOS/Homebrew path — adjust for other platforms.
so.register_custom_ops_library("/opt/homebrew/lib/libortextensions.0.9.0.dylib")
session = ort.InferenceSession("model.onnx", so)
so
# %%
# %%
import onnx
import onnx.helper
# Load the graph-optimized model and dump a human-readable text form for
# inspection.
model_proto = onnx.load_model("optimized.onnx")
# Use a dedicated handle name: `f` is already bound to the tf.function
# defined above, and shadowing it here is needlessly confusing.
with open("optimized.graph", "w") as graph_file:
    graph_file.write(repr(model_proto))
# %%
def quantize(initializer):
    """Quantize a float32 ONNX initializer to asymmetric int8.

    Returns (values, scale, zero_point) following DequantizeLinear's
    convention: float ~= (int8_value - zero_point) * scale.
    """
    raw = tf.constant(initializer.raw_data)
    floats = tf.io.decode_raw(raw, tf.float32)
    # Quantize over the tensor's full observed [min, max] range.
    quantized = tf.quantization.quantize(
        floats, tf.reduce_min(floats), tf.reduce_max(floats), tf.qint8)
    values = tf.cast(quantized.output, tf.int8)
    scale = (quantized.output_max - quantized.output_min) / 255
    zero_point = tf.cast(
        tf.round(-quantized.output_min / scale + 128), tf.int8)
    return values, scale, zero_point
# %%
# %%
# Quantize every float32 initializer with rank > 1 (the weight matrices;
# scalars and 1-D biases stay float32).  Each tensor's data is replaced by
# int8 under a new "<name>/Quantized" name, and a DequantizeLinear node
# restores the original name so the rest of the graph is untouched.
new_initializer = []
for initializer in model_proto.graph.initializer:
    # data_type == 1 is onnx.TensorProto.FLOAT.
    if initializer.data_type == 1 and len(initializer.dims) > 1:
        output, scale, zero_point = quantize(initializer)
        initializer.raw_data = bytes(output)
        initializer.data_type = onnx.TensorProto.INT8
        new_name = initializer.name + "/Quantized"
        # Scale/ZeroPoint tensors are named from the ORIGINAL name; the
        # rename to new_name deliberately happens only at the very end.
        new_initializer.append(onnx.helper.make_tensor(initializer.name + "/Scale", onnx.TensorProto.FLOAT, [], tf.expand_dims(scale, 0)))
        # NOTE(review): truthiness of a scalar tf tensor — when zero_point
        # is 0 the ZeroPoint input is omitted (DequantizeLinear defaults
        # its zero point to 0 in that case).
        if zero_point:
            new_initializer.append(onnx.helper.make_tensor(initializer.name + "/ZeroPoint", onnx.TensorProto.INT8, [], tf.expand_dims(zero_point, 0)))
        model_proto.graph.node.append(onnx.helper.make_node(
            "DequantizeLinear",
            [new_name, initializer.name + "/Scale"] + ([initializer.name + "/ZeroPoint"] if zero_point else []),
            [initializer.name]
        ))
        initializer.name = new_name
model_proto.graph.initializer.extend(new_initializer)
# %%
# %%
# Persist the quantized model plus a human-readable graph dump.
with open("quantized.graph", "w") as f:
    f.write(repr(model_proto))
onnx.save_model(model_proto, "quantized.onnx")
# %%
# Finally convert to the ORT format used by onnxruntime minimal/mobile
# builds, with fixed-style optimization and type reduction enabled.
from onnxruntime.tools.convert_onnx_models_to_ort import convert_onnx_models_to_ort, OptimizationStyle
from pathlib import Path
convert_onnx_models_to_ort(
    Path("quantized.onnx"),
    Path("ort"),
    [OptimizationStyle.Fixed],
    # Custom-ops library again, so the converter can resolve contrib ops.
    Path("/opt/homebrew/lib/libortextensions.0.9.0.dylib"),
    enable_type_reduction=True
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment