@HudsonGraeme
Last active February 1, 2024 20:40
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/HudsonGraeme/5fe39f26d68626586dac2fec609b30ea/notebook.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VOhZf5ompktU"
},
"source": [
"### Install required dependencies"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "m-28wslFI2wX"
},
"outputs": [],
"source": [
"!curl https://raw.githubusercontent.com/zkonduit/ezkl/main/install_ezkl_cli.sh | bash\n",
"!pip uninstall -y tensorflow\n",
"!pip install dataclasses\n",
"!pip install matplotlib\n",
"!pip install torch\n",
"!pip install numpy\n",
"!pip install requests\n",
"!pip install onnxruntime\n",
"!pip install onnx\n",
"!pip install torchvision\n",
"!pip uninstall -y typing_extensions\n",
"!pip install typing_extensions\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "P2iv9gZLpktV"
},
"source": [
"### Model Definition"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-61wEcCP1OZ5"
},
"outputs": [],
"source": [
"\"\"\"\n",
"Reference: https://github.com/karpathy/nanoGPT\n",
"\"\"\"\n",
"\n",
"import json\n",
"import math\n",
"from dataclasses import dataclass\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.nn import functional as F\n",
"import sys\n",
"import os\n",
"\n",
"model_dir = \"./model/\"\n",
"\n",
"# Constructing paths using os.path.join\n",
"onnx_path = os.path.join(model_dir, \"network.onnx\")\n",
"input_path = os.path.join(model_dir, \"input.json\")\n",
"settings_path = os.path.join(model_dir, \"settings.json\")\n",
"srs_path = os.path.join(model_dir, \"kzg.srs\")\n",
"ezkl_path = os.path.join(model_dir, \"network.ezkl\")\n",
"pk_path = os.path.join(model_dir, \"pk.key\")\n",
"vk_path = os.path.join(model_dir, \"vk.key\")\n",
"witness_path = os.path.join(model_dir, \"witness.json\")\n",
"proof_path = os.path.join(model_dir, \"proof.proof\")\n",
"sol_path = os.path.join(model_dir, \"verif.sol\")\n",
"abi_path = os.path.join(model_dir, \"verif.abi\")\n",
"ezkl_binary = '/' + os.path.join('root', '.ezkl', 'ezkl')\n",
"\n",
"\n",
"def remove_non_ascii(s):\n",
" regex = re.compile(r\"\\x1b\\[([0-9]*;?[0-9]+)?[m|K|h]\")\n",
" return regex.sub(\"\", s)\n",
"\n",
"\n",
"def new_gelu(x):\n",
" \"\"\"\n",
" Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT).\n",
" Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415\n",
" \"\"\"\n",
" return (\n",
" 0.5\n",
" * x\n",
" * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x * x * x)))\n",
" )\n",
"\n",
"\n",
"\n",
"class LayerNorm(nn.Module):\n",
" \"\"\" LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False \"\"\"\n",
"\n",
" def __init__(self, ndim, bias):\n",
" super().__init__()\n",
" self.weight = nn.Parameter(torch.ones(ndim))\n",
" self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None\n",
"\n",
" def forward(self, input):\n",
" return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)\n",
"\n",
"\n",
"class CausalSelfAttention(nn.Module):\n",
"\n",
" def __init__(self, config):\n",
" super().__init__()\n",
" assert config.n_embd % config.n_head == 0\n",
" # key, query, value projections for all heads, but in a batch\n",
" self.c_attn = nn.Linear(\n",
" config.n_embd, 3 * config.n_embd, bias=config.bias)\n",
" # output projection\n",
" self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)\n",
" # regularization\n",
" self.attn_dropout = nn.Dropout(config.dropout)\n",
" self.resid_dropout = nn.Dropout(config.dropout)\n",
" self.n_head = config.n_head\n",
" self.n_embd = config.n_embd\n",
" self.dropout = config.dropout\n",
"\n",
" # causal mask to ensure that attention is only applied to the left in the input sequence\n",
" self.register_buffer(\"bias\", torch.tril(torch.ones(config.block_size, config.block_size))\n",
" .view(1, 1, config.block_size, config.block_size))\n",
"\n",
" def forward(self, x):\n",
" B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)\n",
" # calculate query, key, values for all heads in batch and move head forward to be the batch dim\n",
" q, k, v = self.c_attn(x).split(self.n_embd, dim=2)\n",
" k = k.view(B, T, self.n_head, C //\n",
" self.n_head).transpose(1, 2) # (B, nh, T, hs)\n",
" q = q.view(B, T, self.n_head, C //\n",
" self.n_head).transpose(1, 2) # (B, nh, T, hs)\n",
" v = v.view(B, T, self.n_head, C //\n",
" self.n_head).transpose(1, 2) # (B, nh, T, hs)\n",
"\n",
" # manual implementation of attention\n",
" # q shape:(B, nh, T, hs), k transpose shape (B, nh, hs, T) -> (B, nh, T, T)\n",
" att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))\n",
" att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float(-10))\n",
" att = F.softmax(att, dim=-1)\n",
" att = self.attn_dropout(att)\n",
" y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)\n",
" # re-assemble all head outputs side by side\n",
" y = y.transpose(1, 2).contiguous().view(B, T, C)\n",
"\n",
" y = self.resid_dropout(self.c_proj(y))\n",
" return y\n",
"\n",
"\n",
"class MLP(nn.Module):\n",
"\n",
" def __init__(self, config):\n",
" super().__init__()\n",
" self.c_fc = nn.Linear(\n",
" config.n_embd, 4 * config.n_embd, bias=config.bias)\n",
" self.c_proj = nn.Linear(\n",
" 4 * config.n_embd, config.n_embd, bias=config.bias)\n",
" self.dropout = nn.Dropout(config.dropout)\n",
"\n",
" def forward(self, x):\n",
" x = self.c_fc(x)\n",
" x = new_gelu(x)\n",
" x = self.c_proj(x)\n",
" x = self.dropout(x)\n",
" return x\n",
"\n",
"\n",
"class Block(nn.Module):\n",
"\n",
" def __init__(self, config):\n",
" super().__init__()\n",
" self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)\n",
" self.attn = CausalSelfAttention(config)\n",
" self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)\n",
" self.mlp = MLP(config)\n",
"\n",
" def forward(self, x):\n",
" x = x + self.attn(self.ln_1(x))\n",
" x = x + self.mlp(self.ln_2(x))\n",
" return x\n",
"\n",
"\n",
"@dataclass\n",
"class GPTConfig:\n",
" block_size: int = 1024\n",
" # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency\n",
" vocab_size: int = 50304\n",
" n_layer: int = 12\n",
" n_head: int = 12\n",
" n_embd: int = 768\n",
" dropout: float = 0.0\n",
" # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster\n",
" bias: bool = True\n",
"\n",
"\n",
"class GPT(nn.Module):\n",
"\n",
" def __init__(self, config):\n",
" super().__init__()\n",
" assert config.vocab_size is not None\n",
" assert config.block_size is not None\n",
" self.config = config\n",
"\n",
" self.transformer = nn.ModuleDict(dict(\n",
" wte=nn.Embedding(config.vocab_size, config.n_embd),\n",
" wpe=nn.Embedding(config.block_size, config.n_embd),\n",
" drop=nn.Dropout(config.dropout),\n",
" h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),\n",
" ln_f=LayerNorm(config.n_embd, bias=config.bias),\n",
" ))\n",
"\n",
" self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)\n",
"\n",
" # weight-tying\n",
" # https://paperswithcode.com/method/weight-tying\n",
" self.transformer.wte.weight = self.lm_head.weight\n",
" self.block = Block(config)\n",
" # init all weights\n",
" self.apply(self._init_weights)\n",
" # apply special scaled init to the residual projections, per GPT-2 paper\n",
" for pn, p in self.named_parameters():\n",
" if pn.endswith('c_proj.weight'):\n",
" torch.nn.init.normal_(\n",
" p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))\n",
"\n",
" # report number of parameters\n",
" print(\"number of parameters: %.2fM\" % (self.get_num_params()/1e6,))\n",
"\n",
" def get_num_params(self, non_embedding=True):\n",
" \"\"\"\n",
" Return the number of parameters in the model.\n",
" For non-embedding count (default), the position embeddings get subtracted.\n",
" The token embeddings would too, except due to the parameter sharing these\n",
" params are actually used as weights in the final layer, so we include them.\n",
" \"\"\"\n",
" n_params = sum(p.numel() for p in self.parameters())\n",
" if non_embedding:\n",
" n_params -= self.transformer.wpe.weight.numel()\n",
" return n_params\n",
"\n",
" def _init_weights(self, module):\n",
" if isinstance(module, nn.Linear):\n",
" torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n",
" if module.bias is not None:\n",
" torch.nn.init.zeros_(module.bias)\n",
" elif isinstance(module, nn.Embedding):\n",
" torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n",
"\n",
" def forward(self, idx, targets=None):\n",
" device = idx.device\n",
" b, t = idx.size()\n",
" assert t <= self.config.block_size, f\"Cannot forward sequence of length {t}, block size is only {self.config.block_size}\"\n",
" pos = torch.arange(0, t, dtype=torch.long,\n",
" device=device).unsqueeze(0) # shape (1, t)\n",
"\n",
" # # # forward the GPT model itself\n",
" # token embeddings of shape (b, t, n_embd), idx -> token_emb\n",
" idx = self.transformer.wte(idx)\n",
" # position embeddings of shape (1, t, n_embd)\n",
" pos_emb = self.transformer.wpe(pos)\n",
" idx = self.transformer.drop(idx + pos_emb)\n",
"\n",
" for block in self.transformer.h:\n",
" idx = block(idx)\n",
"\n",
" idx = self.transformer.ln_f(idx)\n",
" idx = self.lm_head(idx)\n",
"\n",
" return idx\n",
"\n",
"\n",
"gptconf = GPTConfig(block_size=32, vocab_size=65, n_layer=2,\n",
" n_head=2, n_embd=32, dropout=0.0, bias=False)\n",
"model = GPT(gptconf)\n",
"model.get_num_params()\n",
"\n",
"\n",
"\n",
"shape = [1, 8]\n",
"x = torch.randint(8, (1, 8))\n",
"torch_out = model(x)\n",
"\n",
"torch.onnx.export(\n",
" model,\n",
" x,\n",
" onnx_path,\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=11, # the ONNX version to export the model to\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names=[\"input\"], # the model's input names\n",
" output_names=[\"output\"], # the model's output names\n",
" dynamic_axes={\n",
" \"input\": {0: \"batch_size\"}, # variable length axes\n",
" \"output\": {0: \"batch_size\"},\n",
" },\n",
")\n",
"\n",
"d = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(\n",
" input_shapes=[shape],\n",
" input_data=[d],\n",
" output_data=[((o).detach().numpy()).reshape([-1]).tolist() for o in torch_out],\n",
")\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(\"./model/input.json\", \"w\"))\n",
"print(\"Success\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6ygqdFSOpktX"
},
"source": [
"### Circuitize the model\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7u1l91AWpktX",
"scrolled": false
},
"outputs": [],
"source": [
"# Define the base directory\n",
"import os\n",
"import subprocess\n",
"import time\n",
"\n",
"# Printing initial message\n",
"print(f\"Creating circuit for {onnx_path}...\")\n",
"!{ezkl_binary} --version\n",
"start = time.time()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "T6KeIgzX2Jaa"
},
"outputs": [],
"source": [
"!{ezkl_binary} gen-settings -M {onnx_path} --param-visibility fixed --input-visibility public --output-visibility public -O {settings_path}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vv4jYbIw2Jaa"
},
"outputs": [],
"source": [
"!{ezkl_binary} calibrate-settings -M {onnx_path} -D {input_path} -O {settings_path} --target accuracy"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jgwDg0F42Jaa"
},
"outputs": [],
"source": [
"!{ezkl_binary} get-srs -S {settings_path} --srs-path {srs_path}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mD8o1iIp2Jaa"
},
"outputs": [],
"source": [
"!{ezkl_binary} compile-circuit -M {onnx_path} -S {settings_path} --compiled-circuit {ezkl_path}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9jjCdsHc2Jaa"
},
"outputs": [],
"source": [
"!{ezkl_binary} setup -M {ezkl_path} --vk-path {vk_path} --pk-path {pk_path} --srs-path {srs_path}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5sVuobkm2Jaa"
},
"outputs": [],
"source": [
"end = time.time()\n",
"print(\"Time taken to circuitize model: \", end - start)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JLo-BlfspktX"
},
"source": [
"### Reproduction of the issue\n",
"\n",
"When attempting to generate a proof for this model, an OOR error is thrown."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NUW1u_v0pktX",
"scrolled": true
},
"outputs": [],
"source": [
"print(\"Running\");\n",
"import re\n",
"import onnxruntime as ort\n",
"\n",
"%set_env RUST_LOG=debug\n",
"!{ezkl_binary} gen-witness -D {input_path} -M {ezkl_path} -O {witness_path}\n",
"!{ezkl_binary} prove -M {ezkl_path} -W {witness_path} --pk-path {pk_path} --proof-path {proof_path} --srs-path {srs_path}\n",
"\n",
"def remove_non_ascii(s):\n",
" regex = re.compile(r\"\\x1b\\[([0-9]*;?[0-9]+)?[m|K|h]\")\n",
" return regex.sub(\"\", s)\n",
"\n",
"\n",
"\n",
"proof_in_hex = subprocess.run(\n",
" [ ezkl_binary,\"print-proof-hex\", \"--proof-path\", proof_path],\n",
" text=True,\n",
" stdout=subprocess.PIPE,\n",
" stderr=subprocess.PIPE,\n",
").stdout\n",
"proof_in_hex = remove_non_ascii(proof_in_hex)\n",
"\n",
"split_by_command = proof_in_hex.split(\"| }\")[1]\n",
"instances_string = split_by_command.split(\"[*] [0s, ezkl::execute] - \")[0].split(\"\\n\")[\n",
" 1\n",
"]\n",
"\n",
"instances = []\n",
"instances_string = instances_string.replace(\"[\", \"\").replace(\"]\", \"\")\n",
"for hex_string in instances_string.split(\",\"):\n",
" try:\n",
" instances.append(int(hex_string.strip(), 16))\n",
" except ValueError:\n",
" continue\n",
"\n",
"print(\"Instances from the circuit:\", instances)\n",
"\n",
"print(\"Output tokens from PyTorch: \", torch_out)\n",
"\n",
"\n",
"\n",
"session = ort.InferenceSession(onnx_path, providers=[\"CPUExecutionProvider\"])\n",
"results = session.run(\n",
" None,\n",
" {k.name: [v.numpy().tolist()] for k, v in zip(session.get_inputs(), x)},\n",
")\n",
"print(\"Output tokens from ONNX\", results)\n",
"\n"
]
}
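,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Verify (for reference)\n",
"\n",
"A minimal sketch of the verification step that would follow a successful prove, assuming the same ezkl CLI flags used above; it is not reached while the prove step fails with the OOR error."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# check the proof against the verification key and settings (sketch; only meaningful once prove succeeds)\n",
"!{ezkl_binary} verify --proof-path {proof_path} --settings-path {settings_path} --vk-path {vk_path} --srs-path {srs_path}"
]
}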
],
"metadata": {
"colab": {
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}