{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/HudsonGraeme/5fe39f26d68626586dac2fec609b30ea/notebook.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VOhZf5ompktU"
},
"source": [
"### Install required dependencies"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "m-28wslFI2wX"
},
"outputs": [],
"source": [
"!curl https://raw.githubusercontent.com/zkonduit/ezkl/main/install_ezkl_cli.sh | bash\n",
"!pip uninstall -y tensorflow\n",
"!pip install dataclasses\n",
"!pip install matplotlib\n",
"!pip install torch\n",
"!pip install numpy\n",
"!pip install requests\n",
"!pip install onnxruntime\n",
"!pip install onnx\n",
"!pip install torchvision\n",
"!pip uninstall -y typing_extensions\n",
"!pip install typing_extensions\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "P2iv9gZLpktV"
},
"source": [
"### Model Definition"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-61wEcCP1OZ5"
},
"outputs": [],
"source": [
"\"\"\"\n",
"Reference: https://github.com/karpathy/nanoGPT\n",
"\"\"\"\n",
"\n",
"import json\n",
"import math\n",
"import re\n",
"from dataclasses import dataclass\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.nn import functional as F\n",
"import sys\n",
"import os\n",
"\n",
"model_dir = \"./model/\"\n",
"# Ensure the output directory exists before the ONNX export writes into it\n",
"os.makedirs(model_dir, exist_ok=True)\n",
"\n",
"# Construct artifact paths with os.path.join\n",
"onnx_path = os.path.join(model_dir, \"network.onnx\")\n",
"input_path = os.path.join(model_dir, \"input.json\")\n",
"settings_path = os.path.join(model_dir, \"settings.json\")\n",
"srs_path = os.path.join(model_dir, \"kzg.srs\")\n",
"ezkl_path = os.path.join(model_dir, \"network.ezkl\")\n",
"pk_path = os.path.join(model_dir, \"pk.key\")\n",
"vk_path = os.path.join(model_dir, \"vk.key\")\n",
"witness_path = os.path.join(model_dir, \"witness.json\")\n",
"proof_path = os.path.join(model_dir, \"proof.proof\")\n",
"sol_path = os.path.join(model_dir, \"verif.sol\")\n",
"abi_path = os.path.join(model_dir, \"verif.abi\")\n",
"ezkl_binary = os.path.join(\"/root\", \".ezkl\", \"ezkl\")\n",
"\n",
"\n", | |
"def remove_non_ascii(s):\n", | |
" regex = re.compile(r\"\\x1b\\[([0-9]*;?[0-9]+)?[m|K|h]\")\n", | |
" return regex.sub(\"\", s)\n", | |
"\n", | |
"\n", | |
"def new_gelu(x):\n", | |
" \"\"\"\n", | |
" Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT).\n", | |
" Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415\n", | |
" \"\"\"\n", | |
" return (\n", | |
" 0.5\n", | |
" * x\n", | |
" * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x * x * x)))\n", | |
" )\n", | |
"\n", | |
"\n", | |
"\n", | |
"class LayerNorm(nn.Module):\n", | |
" \"\"\" LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False \"\"\"\n", | |
"\n", | |
" def __init__(self, ndim, bias):\n", | |
" super().__init__()\n", | |
" self.weight = nn.Parameter(torch.ones(ndim))\n", | |
" self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None\n", | |
"\n", | |
" def forward(self, input):\n", | |
" return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)\n", | |
"\n", | |
"\n", | |
"class CausalSelfAttention(nn.Module):\n", | |
"\n", | |
" def __init__(self, config):\n", | |
" super().__init__()\n", | |
" assert config.n_embd % config.n_head == 0\n", | |
" # key, query, value projections for all heads, but in a batch\n", | |
" self.c_attn = nn.Linear(\n", | |
" config.n_embd, 3 * config.n_embd, bias=config.bias)\n", | |
" # output projection\n", | |
" self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)\n", | |
" # regularization\n", | |
" self.attn_dropout = nn.Dropout(config.dropout)\n", | |
" self.resid_dropout = nn.Dropout(config.dropout)\n", | |
" self.n_head = config.n_head\n", | |
" self.n_embd = config.n_embd\n", | |
" self.dropout = config.dropout\n", | |
"\n", | |
" # causal mask to ensure that attention is only applied to the left in the input sequence\n", | |
" self.register_buffer(\"bias\", torch.tril(torch.ones(config.block_size, config.block_size))\n", | |
" .view(1, 1, config.block_size, config.block_size))\n", | |
"\n", | |
" def forward(self, x):\n", | |
" B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)\n", | |
" # calculate query, key, values for all heads in batch and move head forward to be the batch dim\n", | |
" q, k, v = self.c_attn(x).split(self.n_embd, dim=2)\n", | |
" k = k.view(B, T, self.n_head, C //\n", | |
" self.n_head).transpose(1, 2) # (B, nh, T, hs)\n", | |
" q = q.view(B, T, self.n_head, C //\n", | |
" self.n_head).transpose(1, 2) # (B, nh, T, hs)\n", | |
" v = v.view(B, T, self.n_head, C //\n", | |
" self.n_head).transpose(1, 2) # (B, nh, T, hs)\n", | |
"\n", | |
" # manual implementation of attention\n", | |
" # q shape:(B, nh, T, hs), k transpose shape (B, nh, hs, T) -> (B, nh, T, T)\n", | |
" att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))\n", | |
" att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float(-10))\n", | |
" att = F.softmax(att, dim=-1)\n", | |
" att = self.attn_dropout(att)\n", | |
" y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)\n", | |
" # re-assemble all head outputs side by side\n", | |
" y = y.transpose(1, 2).contiguous().view(B, T, C)\n", | |
"\n", | |
" y = self.resid_dropout(self.c_proj(y))\n", | |
" return y\n", | |
"\n", | |
"\n", | |
"class MLP(nn.Module):\n", | |
"\n", | |
" def __init__(self, config):\n", | |
" super().__init__()\n", | |
" self.c_fc = nn.Linear(\n", | |
" config.n_embd, 4 * config.n_embd, bias=config.bias)\n", | |
" self.c_proj = nn.Linear(\n", | |
" 4 * config.n_embd, config.n_embd, bias=config.bias)\n", | |
" self.dropout = nn.Dropout(config.dropout)\n", | |
"\n", | |
" def forward(self, x):\n", | |
" x = self.c_fc(x)\n", | |
" x = new_gelu(x)\n", | |
" x = self.c_proj(x)\n", | |
" x = self.dropout(x)\n", | |
" return x\n", | |
"\n", | |
"\n", | |
"class Block(nn.Module):\n", | |
"\n", | |
" def __init__(self, config):\n", | |
" super().__init__()\n", | |
" self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)\n", | |
" self.attn = CausalSelfAttention(config)\n", | |
" self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)\n", | |
" self.mlp = MLP(config)\n", | |
"\n", | |
" def forward(self, x):\n", | |
" x = x + self.attn(self.ln_1(x))\n", | |
" x = x + self.mlp(self.ln_2(x))\n", | |
" return x\n", | |
"\n", | |
"\n", | |
"@dataclass\n", | |
"class GPTConfig:\n", | |
" block_size: int = 1024\n", | |
" # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency\n", | |
" vocab_size: int = 50304\n", | |
" n_layer: int = 12\n", | |
" n_head: int = 12\n", | |
" n_embd: int = 768\n", | |
" dropout: float = 0.0\n", | |
" # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster\n", | |
" bias: bool = True\n", | |
"\n", | |
"\n", | |
"class GPT(nn.Module):\n", | |
"\n", | |
" def __init__(self, config):\n", | |
" super().__init__()\n", | |
" assert config.vocab_size is not None\n", | |
" assert config.block_size is not None\n", | |
" self.config = config\n", | |
"\n", | |
" self.transformer = nn.ModuleDict(dict(\n", | |
" wte=nn.Embedding(config.vocab_size, config.n_embd),\n", | |
" wpe=nn.Embedding(config.block_size, config.n_embd),\n", | |
" drop=nn.Dropout(config.dropout),\n", | |
" h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),\n", | |
" ln_f=LayerNorm(config.n_embd, bias=config.bias),\n", | |
" ))\n", | |
"\n", | |
" self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)\n", | |
"\n", | |
" # weight-tying\n", | |
" # https://paperswithcode.com/method/weight-tying\n", | |
" self.transformer.wte.weight = self.lm_head.weight\n", | |
" self.block = Block(config)\n", | |
" # init all weights\n", | |
" self.apply(self._init_weights)\n", | |
" # apply special scaled init to the residual projections, per GPT-2 paper\n", | |
" for pn, p in self.named_parameters():\n", | |
" if pn.endswith('c_proj.weight'):\n", | |
" torch.nn.init.normal_(\n", | |
" p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))\n", | |
"\n", | |
" # report number of parameters\n", | |
" print(\"number of parameters: %.2fM\" % (self.get_num_params()/1e6,))\n", | |
"\n", | |
" def get_num_params(self, non_embedding=True):\n", | |
" \"\"\"\n", | |
" Return the number of parameters in the model.\n", | |
" For non-embedding count (default), the position embeddings get subtracted.\n", | |
" The token embeddings would too, except due to the parameter sharing these\n", | |
" params are actually used as weights in the final layer, so we include them.\n", | |
" \"\"\"\n", | |
" n_params = sum(p.numel() for p in self.parameters())\n", | |
" if non_embedding:\n", | |
" n_params -= self.transformer.wpe.weight.numel()\n", | |
" return n_params\n", | |
"\n", | |
" def _init_weights(self, module):\n", | |
" if isinstance(module, nn.Linear):\n", | |
" torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n", | |
" if module.bias is not None:\n", | |
" torch.nn.init.zeros_(module.bias)\n", | |
" elif isinstance(module, nn.Embedding):\n", | |
" torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n", | |
"\n", | |
" def forward(self, idx, targets=None):\n", | |
" device = idx.device\n", | |
" b, t = idx.size()\n", | |
" assert t <= self.config.block_size, f\"Cannot forward sequence of length {t}, block size is only {self.config.block_size}\"\n", | |
" pos = torch.arange(0, t, dtype=torch.long,\n", | |
" device=device).unsqueeze(0) # shape (1, t)\n", | |
"\n", | |
" # # # forward the GPT model itself\n", | |
" # token embeddings of shape (b, t, n_embd), idx -> token_emb\n", | |
" idx = self.transformer.wte(idx)\n", | |
" # position embeddings of shape (1, t, n_embd)\n", | |
" pos_emb = self.transformer.wpe(pos)\n", | |
" idx = self.transformer.drop(idx + pos_emb)\n", | |
"\n", | |
" for block in self.transformer.h:\n", | |
" idx = block(idx)\n", | |
"\n", | |
" idx = self.transformer.ln_f(idx)\n", | |
" idx = self.lm_head(idx)\n", | |
"\n", | |
" return idx\n", | |
"\n", | |
"\n", | |
"gptconf = GPTConfig(block_size=32, vocab_size=65, n_layer=2,\n", | |
" n_head=2, n_embd=32, dropout=0.0, bias=False)\n", | |
"model = GPT(gptconf)\n", | |
"model.get_num_params()\n", | |
"\n", | |
"\n", | |
"\n", | |
"shape = [1, 8]\n", | |
"x = torch.randint(8, (1, 8))\n", | |
"torch_out = model(x)\n", | |
"\n", | |
"torch.onnx.export(\n", | |
" model,\n", | |
" x,\n", | |
" onnx_path,\n", | |
" export_params=True, # store the trained parameter weights inside the model file\n", | |
" opset_version=11, # the ONNX version to export the model to\n", | |
" do_constant_folding=True, # whether to execute constant folding for optimization\n", | |
" input_names=[\"input\"], # the model's input names\n", | |
" output_names=[\"output\"], # the model's output names\n", | |
" dynamic_axes={\n", | |
" \"input\": {0: \"batch_size\"}, # variable length axes\n", | |
" \"output\": {0: \"batch_size\"},\n", | |
" },\n", | |
")\n", | |
"\n", | |
"d = ((x).detach().numpy()).reshape([-1]).tolist()\n", | |
"\n", | |
"data = dict(\n", | |
" input_shapes=[shape],\n", | |
" input_data=[d],\n", | |
" output_data=[((o).detach().numpy()).reshape([-1]).tolist() for o in torch_out],\n", | |
")\n", | |
"\n", | |
"# Serialize data into file:\n", | |
"json.dump(data, open(\"./model/input.json\", \"w\"))\n", | |
"print(\"Success\")" | |
] | |
}, | |
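{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Sanity-check the exported ONNX graph\n",
"\n",
"A minimal sketch added for this writeup (not part of the original repro): `onnx.checker` validates the exported proto structurally; it does not run inference."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import onnx\n",
"\n",
"# Structural validation only; raises onnx.checker.ValidationError on a malformed proto\n",
"onnx_model = onnx.load(onnx_path)\n",
"onnx.checker.check_model(onnx_model)\n",
"print(\"ONNX model check passed\")"
]
},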
{
"cell_type": "markdown",
"metadata": {
"id": "6ygqdFSOpktX"
},
"source": [
"### Circuitize the model\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7u1l91AWpktX",
"scrolled": false
},
"outputs": [],
"source": [
"# Imports and timing for the circuitization and proving steps\n",
"import os\n",
"import subprocess\n",
"import time\n",
"\n",
"print(f\"Creating circuit for {onnx_path}...\")\n",
"!{ezkl_binary} --version\n",
"start = time.time()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "T6KeIgzX2Jaa"
},
"outputs": [],
"source": [
"!{ezkl_binary} gen-settings -M {onnx_path} --param-visibility fixed --input-visibility public --output-visibility public -O {settings_path}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vv4jYbIw2Jaa"
},
"outputs": [],
"source": [
"!{ezkl_binary} calibrate-settings -M {onnx_path} -D {input_path} -O {settings_path} --target accuracy"
]
},
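{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick look at the calibrated settings (added sketch): `settings.json` is plain JSON, so this just pretty-prints whatever `calibrate-settings` wrote, without assuming particular field names."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"# Pretty-print the calibrated settings to eyeball the chosen scales / circuit size\n",
"with open(settings_path) as f:\n",
"    print(json.dumps(json.load(f), indent=2))"
]
},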
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jgwDg0F42Jaa"
},
"outputs": [],
"source": [
"!{ezkl_binary} get-srs -S {settings_path} --srs-path {srs_path}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mD8o1iIp2Jaa"
},
"outputs": [],
"source": [
"!{ezkl_binary} compile-circuit -M {onnx_path} -S {settings_path} --compiled-circuit {ezkl_path}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9jjCdsHc2Jaa"
},
"outputs": [],
"source": [
"!{ezkl_binary} setup -M {ezkl_path} --vk-path {vk_path} --pk-path {pk_path} --srs-path {srs_path}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5sVuobkm2Jaa"
},
"outputs": [],
"source": [
"end = time.time()\n",
"print(\"Time taken to circuitize model: \", end - start)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JLo-BlfspktX"
},
"source": [
"### Reproduction of the issue\n",
"\n",
"When attempting to generate a proof for this model, an out-of-range (OOR) error is thrown."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NUW1u_v0pktX",
"scrolled": true
},
"outputs": [],
"source": [
"print(\"Running\")\n",
"import re\n",
"import subprocess\n",
"import onnxruntime as ort\n",
"\n",
"%set_env RUST_LOG=debug\n",
"!{ezkl_binary} gen-witness -D {input_path} -M {ezkl_path} -O {witness_path}\n",
"!{ezkl_binary} prove -M {ezkl_path} -W {witness_path} --pk-path {pk_path} --proof-path {proof_path} --srs-path {srs_path}\n", | |
"\n", | |
"def remove_non_ascii(s):\n", | |
" regex = re.compile(r\"\\x1b\\[([0-9]*;?[0-9]+)?[m|K|h]\")\n", | |
" return regex.sub(\"\", s)\n", | |
"\n", | |
"\n", | |
"\n", | |
"proof_in_hex = subprocess.run(\n", | |
" [ ezkl_binary,\"print-proof-hex\", \"--proof-path\", proof_path],\n", | |
" text=True,\n", | |
" stdout=subprocess.PIPE,\n", | |
" stderr=subprocess.PIPE,\n", | |
").stdout\n", | |
"proof_in_hex = remove_non_ascii(proof_in_hex)\n", | |
"\n", | |
"split_by_command = proof_in_hex.split(\"| }\")[1]\n", | |
"instances_string = split_by_command.split(\"[*] [0s, ezkl::execute] - \")[0].split(\"\\n\")[\n", | |
" 1\n", | |
"]\n", | |
"\n", | |
"instances = []\n", | |
"instances_string = instances_string.replace(\"[\", \"\").replace(\"]\", \"\")\n", | |
"for hex_string in instances_string.split(\",\"):\n", | |
" try:\n", | |
" instances.append(int(hex_string.strip(), 16))\n", | |
" except ValueError:\n", | |
" continue\n", | |
"\n", | |
"print(\"Instances from the circuit:\", instances)\n", | |
"\n", | |
"print(\"Output tokens from PyTorch: \", torch_out)\n", | |
"\n", | |
"\n", | |
"\n", | |
"session = ort.InferenceSession(onnx_path, providers=[\"CPUExecutionProvider\"])\n", | |
"results = session.run(\n", | |
" None,\n", | |
" {k.name: [v.numpy().tolist()] for k, v in zip(session.get_inputs(), x)},\n", | |
")\n", | |
"print(\"Output tokens from ONNX\", results)\n", | |
"\n" | |
] | |
}
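,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Verify the proof\n",
"\n",
"Added sketch: a final verification pass for completeness; the `verify` flag names are assumed from the same ezkl CLI version installed above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!{ezkl_binary} verify --proof-path {proof_path} --settings-path {settings_path} --vk-path {vk_path} --srs-path {srs_path}"
]
}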
],
"metadata": {
"colab": {
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
} |