Last active
March 1, 2024 18:48
-
-
Save ruihe774/62fed775b04cb18e59b1c0fbbdc095c6 to your computer and use it in GitHub Desktop.
Convert a guesslang language-detection model to ONNX.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# %%
import tensorflow as tf

# Force CPU execution; the conversion does not need a GPU.
tf.config.set_visible_devices([], 'GPU')
# %%
# Load the original guesslang SavedModel and keep its "predict" signature
# around so the converted model can be validated against it later.
model = tf.saved_model.load("model")
signatures = model.signatures
predict = signatures["predict"]
# %%
class HyperParameter:
    """Model hyper parameters"""
    # NOTE(review): these values feed the estimator/preprocessing below and
    # presumably must match what the original guesslang model was trained
    # with — confirm against guesslang's config before changing any of them.
    # Examples per training batch.
    BATCH_SIZE = 100
    # Number of byte-bigram tokens kept per document (see preprocess_texts).
    NB_TOKENS = 10000
    # Hash-bucket count for the categorical 'content' column.
    VOCABULARY_SIZE = 5000
    # Embedding width: sqrt of vocabulary size, floored at 10 (= 70 here).
    EMBEDDING_SIZE = max(10, int(VOCABULARY_SIZE**0.5))
    # Hidden-layer sizes of the DNN tower.
    DNN_HIDDEN_UNITS = [512, 32]
    # Dropout rate applied to the DNN tower during training.
    DNN_DROPOUT = 0.5
    # Bigram order (pairs of consecutive bytes).
    N_GRAM = 2
# Class labels, in training order; n_classes and the checkpoint's output
# ordering depend on this exact 54-entry list.
# NOTE(review): 'r' appears twice and 'matla' looks like a truncated
# 'matlab' — kept as-is because changing the list would break compatibility
# with the trained checkpoint; verify against guesslang's language list.
labels = ['asm', 'bat', 'c', 'cs', 'cpp', 'clj', 'cmake', 'cbl',
          'coffee', 'css', 'csv', 'dart', 'dm', 'dockerfile', 'ex',
          'erl', 'f90', 'go', 'groovy', 'hs', 'html', 'ini',
          'java', 'js', 'json', 'jl', 'kt', 'lisp', 'lua',
          'makefile', 'md', 'matla', 'mm', 'ml', 'pas', 'pm',
          'php', 'ps1', 'prolog', 'py', 'r', 'r', 'rs', 'scala',
          'sh', 'sql', 'swift', 'tex', 'toml', 'ts', 'v', 'vba',
          'xml', 'yaml']
# Hash the raw bigram strings into a fixed-size bucket space.
categorical_column = tf.feature_column.categorical_column_with_hash_bucket(
    key='content',
    hash_bucket_size=HyperParameter.VOCABULARY_SIZE,
)
# Dense embedding of the hashed tokens, consumed by the DNN tower.
dense_column = tf.feature_column.embedding_column(
    categorical_column=categorical_column,
    dimension=HyperParameter.EMBEDDING_SIZE,
)
# Wide (linear over hashed tokens) + deep (embedding + DNN) classifier.
# NOTE(review): architecture presumably mirrors the one the guesslang
# checkpoint was trained with — its variables are spliced in below.
estimator = tf.estimator.DNNLinearCombinedClassifier(
    linear_feature_columns=[categorical_column],
    dnn_feature_columns=[dense_column],
    dnn_hidden_units=HyperParameter.DNN_HIDDEN_UNITS,
    dnn_dropout=HyperParameter.DNN_DROPOUT,
    n_classes=len(labels),
)
# %%
def preprocess_texts(data):
    """Turn a batch of strings into space-joined byte-bigram strings.

    Each input string is split into individual bytes, padded/truncated to
    NB_TOKENS + 1, and every pair of consecutive bytes is joined with a
    space, yielding NB_TOKENS bigram tokens per document.
    """
    chars = tf.strings.bytes_split(data)
    chars = chars.to_tensor(shape=(tf.size(data), HyperParameter.NB_TOKENS + 1))
    first = tf.slice(chars, [0, 0], [-1, HyperParameter.NB_TOKENS])
    second = tf.slice(chars, [0, 1], [-1, HyperParameter.NB_TOKENS])
    pairs = tf.stack([first, second], -1)
    bigrams = tf.strings.reduce_join(pairs, 2, separator=' ')
    # A bigram made purely of padding joins to a single space; blank it out.
    return tf.where(tf.equal(bigrams, ' '), '', bigrams)
# %%
# Read this script's own source as the Python validation sample.
# Fix: the original `open(__file__).read()` leaked the file handle; use a
# context manager and an explicit encoding (Python source is UTF-8).
with open(__file__, encoding="utf-8") as source_file:
    py_code = source_file.read()
# Go validation sample: a 24-game solver (picks 4 random digits and searches
# for an arithmetic expression evaluating to 24). Used both to give the
# estimator one training step and to cross-check the converted model.
go_code = "package main\n\nimport (\n\t\"fmt\"\n\t\"math/rand\"\n\t\"time\"\n)\n\nconst (\n\top_num = iota\n\top_add\n\top_sub\n\top_mul\n\top_div\n)\n\ntype frac struct {\n\tnum, denom int\n}\n\n// Expression: can either be a single number, or a result of binary\n// operation from left and right node\ntype Expr struct {\n\top int\n\tleft, right *Expr\n\tvalue frac\n}\n\nvar n_cards = 4\nvar goal = 24\nvar digit_range = 9\n\nfunc (x *Expr) String() string {\n\tif x.op == op_num {\n\t\treturn fmt.Sprintf(\"%d\", x.value.num)\n\t}\n\n\tvar bl1, br1, bl2, br2, opstr string\n\tswitch {\n\tcase x.left.op == op_num:\n\tcase x.left.op >= x.op:\n\tcase x.left.op == op_add && x.op == op_sub:\n\t\tbl1, br1 = \"\", \"\"\n\tdefault:\n\t\tbl1, br1 = \"(\", \")\"\n\t}\n\n\tif x.right.op == op_num || x.op < x.right.op {\n\t\tbl2, br2 = \"\", \"\"\n\t} else {\n\t\tbl2, br2 = \"(\", \")\"\n\t}\n\n\tswitch {\n\tcase x.op == op_add:\n\t\topstr = \" + \"\n\tcase x.op == op_sub:\n\t\topstr = \" - \"\n\tcase x.op == op_mul:\n\t\topstr = \" * \"\n\tcase x.op == op_div:\n\t\topstr = \" / \"\n\t}\n\n\treturn bl1 + x.left.String() + br1 + opstr +\n\t\tbl2 + x.right.String() + br2\n}\n\nfunc expr_eval(x *Expr) (f frac) {\n\tif x.op == op_num {\n\t\treturn x.value\n\t}\n\n\tl, r := expr_eval(x.left), expr_eval(x.right)\n\n\tswitch x.op {\n\tcase op_add:\n\t\tf.num = l.num*r.denom + l.denom*r.num\n\t\tf.denom = l.denom * r.denom\n\t\treturn\n\n\tcase op_sub:\n\t\tf.num = l.num*r.denom - l.denom*r.num\n\t\tf.denom = l.denom * r.denom\n\t\treturn\n\n\tcase op_mul:\n\t\tf.num = l.num * r.num\n\t\tf.denom = l.denom * r.denom\n\t\treturn\n\n\tcase op_div:\n\t\tf.num = l.num * r.denom\n\t\tf.denom = l.denom * r.num\n\t\treturn\n\t}\n\treturn\n}\n\nfunc solve(ex_in []*Expr) bool {\n\t// only one expression left, meaning all numbers are arranged into\n\t// a binary tree, so evaluate and see if we get 24\n\tif len(ex_in) == 1 {\n\t\tf := expr_eval(ex_in[0])\n\t\tif f.denom != 0 && f.num == f.denom*goal {\n\t\t\tfmt.Println(ex_in[0].String())\n\t\t\treturn true\n\t\t}\n\t\treturn false\n\t}\n\n\tvar node Expr\n\tex := make([]*Expr, len(ex_in)-1)\n\n\t// try to combine a pair of expressions into one, thus reduce\n\t// the list length by 1, and recurse down\n\tfor i := range ex {\n\t\tcopy(ex[i:len(ex)], ex_in[i+1:len(ex_in)])\n\n\t\tex[i] = &node\n\t\tfor j := i + 1; j < len(ex_in); j++ {\n\t\t\tnode.left = ex_in[i]\n\t\t\tnode.right = ex_in[j]\n\n\t\t\t// try all 4 operators\n\t\t\tfor o := op_add; o <= op_div; o++ {\n\t\t\t\tnode.op = o\n\t\t\t\tif solve(ex) {\n\t\t\t\t\treturn true\n\t\t\t\t}\n\t\t\t}\n\n\t\t\t// also - and / are not commutative, so swap arguments\n\t\t\tnode.left = ex_in[j]\n\t\t\tnode.right = ex_in[i]\n\n\t\t\tnode.op = op_sub\n\t\t\tif solve(ex) {\n\t\t\t\treturn true\n\t\t\t}\n\n\t\t\tnode.op = op_div\n\t\t\tif solve(ex) {\n\t\t\t\treturn true\n\t\t\t}\n\n\t\t\tif j < len(ex) {\n\t\t\t\tex[j] = ex_in[j]\n\t\t\t}\n\t\t}\n\t\tex[i] = ex_in[i]\n\t}\n\treturn false\n}\n\nfunc main() {\n\tcards := make([]*Expr, n_cards)\n\trand.Seed(time.Now().Unix())\n\n\tfor k := 0; k < 10; k++ {\n\t\tfor i := 0; i < n_cards; i++ {\n\t\t\tcards[i] = &Expr{op_num, nil, nil,\n\t\t\t\tfrac{rand.Intn(digit_range-1) + 1, 1}}\n\t\t\tfmt.Printf(\" %d\", cards[i].value.num)\n\t\t}\n\t\tfmt.Print(\": \")\n\t\tif !solve(cards) {\n\t\t\tfmt.Println(\"No solution\")\n\t\t}\n\t}\n}\n"
def input_fn():
    """Single-batch training input: the Go sample, labelled as class 1."""
    features = {"content": preprocess_texts(tf.constant([go_code]))}
    targets = tf.constant([1])
    return features, targets
# One training step, just to materialise the estimator's graph/variables.
estimator.train(input_fn, steps=1)
# %%
def serving_input_receiver_fn() -> tf.estimator.export.ServingInputReceiver:
    """Serving receiver: accepts raw strings, feeds preprocessed bigrams."""
    content = tf.compat.v1.placeholder(tf.string, [None])
    return tf.estimator.export.ServingInputReceiver(
        receiver_tensors={'content': content},
        features={'content': preprocess_texts(content)},
    )
# %%
import shutil

# Export the freshly-built estimator graph, then splice the original trained
# model's variables into the exported SavedModel in place of the new ones.
shutil.rmtree("new", ignore_errors=True)
path = estimator.export_saved_model("new", serving_input_receiver_fn).decode()
shutil.rmtree(path + "/variables")
shutil.copytree("model/variables", path + "/variables")
# %%
# Reload the spliced SavedModel and grab its predict signature.
new_model = tf.saved_model.load(path)
new_predict = new_model.signatures["predict"]
new_predict
# %%
# Reference predictions from the original model on both samples.
stdans = predict(tf.constant([py_code, go_code]))
# %%
@tf.function
def f(content):
    """Score one document: scalar string in, probability vector out."""
    batched = tf.expand_dims(content, 0)
    return new_predict(batched)["probabilities"][0]
# %%
# The spliced model must agree with the original on both samples.
expected_ids = stdans["class_ids"]
assert (tf.argmax(f(tf.constant(py_code))) == expected_ids[0][0]).numpy()
# %%
assert (tf.argmax(f(tf.constant(go_code))) == expected_ids[1][0]).numpy()
# %%
import tf2onnx.convert
from tf2onnx.utils import make_opsetid
# %%
# Export the scoring tf.function to ONNX. The extra "ai.onnx.contrib" opset
# is the custom-op domain used for TF string ops (onnxruntime-extensions
# provides their kernels at inference time).
tf2onnx.convert.from_function(
    f,
    input_signature=[tf.TensorSpec((), tf.string)],
    output_path="model.onnx",
    opset=18,
    extra_opset=[make_opsetid("ai.onnx.contrib", 1)],
)
# %%
import onnxruntime as ort
so = ort.SessionOptions()
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
# Side effect: constructing the session below writes the optimized graph
# to this path; `session` itself is not used afterwards.
so.optimized_model_filepath = "optimized.onnx"
# Machine-specific path to the onnxruntime-extensions library that supplies
# the ai.onnx.contrib string-op kernels.
so.register_custom_ops_library("/opt/homebrew/lib/libortextensions.0.9.0.dylib")
session = ort.InferenceSession("model.onnx", so)
so
# %%
import onnx
import onnx.helper

# Reload the ORT-optimized graph and dump a human-readable protobuf text
# rendering of it for inspection.
model_proto = onnx.load_model("optimized.onnx")
# Fix: the original used `as f`, shadowing the module-level tf.function `f`.
with open("optimized.graph", "w") as graph_file:
    graph_file.write(repr(model_proto))
# %%
def quantize(initializer):
    """Quantize a float32 ONNX initializer to int8 over its min/max range.

    Returns (int8 values, scale, zero_point) shaped for use as the inputs
    of a DequantizeLinear node.
    """
    data = tf.constant(initializer.raw_data)
    # Reinterpret the initializer's raw bytes as a flat float32 tensor.
    data = tf.io.decode_raw(data, tf.float32)
    q = tf.quantization.quantize(data, tf.reduce_min(data), tf.reduce_max(data), tf.qint8)
    output = tf.cast(q.output, tf.int8)
    # 255 steps across the range the quantizer actually used.
    scale = (q.output_max - q.output_min) / 255
    # NOTE(review): -min/scale + 128 generally exceeds the int8 range, and
    # the cast wraps it around; the wrapped value coincides (mod 256) with
    # the conventional zero point -128 - min/scale. This relies on
    # two's-complement wrapping in tf.cast — fragile, confirm numerically.
    zero_point = tf.cast(tf.round(-q.output_min / scale + 128), tf.int8)
    return output, scale, zero_point
# %%
# Rewrite every float32 matrix initializer in place as int8 data plus a
# DequantizeLinear node that restores float32 for the original consumers.
new_initializer = []
for initializer in model_proto.graph.initializer:
    # data_type 1 == TensorProto.FLOAT; only touch rank>=2 tensors
    # (weights/embeddings), leaving scalars and bias vectors alone.
    if initializer.data_type == 1 and len(initializer.dims) > 1:
        output, scale, zero_point = quantize(initializer)
        # NOTE(review): bytes() on an eager tensor — presumably serialises
        # the int8 payload via the buffer protocol; confirm round-trip.
        initializer.raw_data = bytes(output)
        initializer.data_type = onnx.TensorProto.INT8
        new_name = initializer.name + "/Quantized"
        new_initializer.append(onnx.helper.make_tensor(initializer.name + "/Scale", onnx.TensorProto.FLOAT, [], tf.expand_dims(scale, 0)))
        # Omit the zero point when it is 0 (DequantizeLinear's default).
        if zero_point:
            new_initializer.append(onnx.helper.make_tensor(initializer.name + "/ZeroPoint", onnx.TensorProto.INT8, [], tf.expand_dims(zero_point, 0)))
        # The dequantize node takes over the initializer's original name so
        # existing consumers are rewired without editing them.
        # NOTE(review): nodes are appended after their consumers; ONNX
        # expects topological order — ORT apparently tolerates this here,
        # but verify with a checker/run.
        model_proto.graph.node.append(onnx.helper.make_node(
            "DequantizeLinear",
            [new_name, initializer.name + "/Scale"] + ([initializer.name + "/ZeroPoint"] if zero_point else []),
            [initializer.name]
        ))
        initializer.name = new_name
model_proto.graph.initializer.extend(new_initializer)
# %%
# Dump the quantized graph for inspection and save the quantized model.
# Fix: the original used `as f`, shadowing the module-level tf.function `f`.
with open("quantized.graph", "w") as graph_file:
    graph_file.write(repr(model_proto))
onnx.save_model(model_proto, "quantized.onnx")
# %%
from onnxruntime.tools.convert_onnx_models_to_ort import convert_onnx_models_to_ort, OptimizationStyle
from pathlib import Path

# Produce the final .ort bundle; the custom-ops library resolves the
# ai.onnx.contrib string operators during conversion.
convert_onnx_models_to_ort(
    Path("quantized.onnx"),
    Path("ort"),
    [OptimizationStyle.Fixed],
    Path("/opt/homebrew/lib/libortextensions.0.9.0.dylib"),
    enable_type_reduction=True,
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment