Last active
March 22, 2023 11:50
-
-
Save maltempi/f2d30e513ef7046f3c0455a5dc4be3e0 to your computer and use it in GitHub Desktop.
[MO436] ResNet18 using JAX, Flax, Onnx
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "U5ChPfvrX9XQ" | |
}, | |
"source": [ | |
"# Project 4: Inference using ResNet18, Flax, JAX and Onnx\n", | |
"\n", | |
"--- \n", | |
"\n", | |
"**MO436 - Machine Learning under the hood** - IC Unicamp\n", | |
"\n", | |
"Written by [Thiago Maltempi](https://github.com/maltempi)\n", | |
"\n", | |
"July 2022\n", | |
"\n", | |
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/gist/maltempi/f2d30e513ef7046f3c0455a5dc4be3e0)", | |
"\n", | |
"--- \n", | |
"\n", | |
"This notebook contains an implementation of ResNet18 in Flax whose parameters are loaded from an ONNX file. After that, we experiment with JAX's transformations (`vmap` and `jit`) to get some performance analysis.\n", | |
"\n", | |
"At least one GPU is required to run this notebook. All results shown in the output cells were produced on Google Colab with a GPU.\n" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "GqeXojX0YQEZ" | |
}, | |
"source": [ | |
"## Installing and importing dependencies for this notebook" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"id": "BBQ9aX95X5tF" | |
}, | |
"outputs": [], | |
"source": [ | |
"!pip install jax flax onnx matplotlib numpy scipy requests opencv-python -q" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"id": "xIJwiPtDYWS9" | |
}, | |
"outputs": [], | |
"source": [ | |
"# Python's\n", | |
"import sys\n", | |
"import os\n", | |
"from os.path import exists\n", | |
"from datetime import datetime\n", | |
"\n", | |
"# Flax stuff\n", | |
"import flax.linen as nn\n", | |
"import optax\n", | |
"from flax.training.train_state import TrainState\n", | |
"from flax.core import freeze, unfreeze\n", | |
"from jax import vmap, pmap, jit, make_jaxpr, device_put\n", | |
"\n", | |
"# Jax stuff\n", | |
"import jax.numpy as jnp\n", | |
"import jax.scipy as jsp\n", | |
"import jax\n", | |
"\n", | |
"# Other dependencies\n", | |
"import numpy as np\n", | |
"import scipy as sp\n", | |
"import matplotlib.pyplot as plt\n", | |
"import requests\n", | |
"import onnx\n", | |
"from onnx import numpy_helper\n", | |
"import cv2" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "G-fustKcvn7t" | |
}, | |
"source": [ | |
"## Defining settings, downloading dataset, labels etc\n", | |
"\n", | |
"Nothing really important, just preparing the environment." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"id": "Tc1Xe8oRZqT0" | |
}, | |
"outputs": [], | |
"source": [ | |
"onnx_model_path = './resnet18.onnx'\n", | |
"onnx_url = 'https://raw.githubusercontent.com/maltempi/MO436-work/main/models/resnet18/resnet18.onnx'\n", | |
"dataset_url = 'https://raw.githubusercontent.com/maltempi/MO436-work/main/datasets/imagenet'\n", | |
"labels_url = 'https://raw.githubusercontent.com/maltempi/MO436-work/main/datasets/imagenet/synset_words.txt'\n", | |
"\n", | |
"INPUT_IMAGE_HEIGHT = 224\n", | |
"INPUT_IMAGE_WIDTH = 224\n", | |
"INPUT_IMAGE_CHANNELS = 3\n", | |
"FLAX_IMAGE_SHAPE = (224,224,3) # HWC\n", | |
"FLAX_INPUT_SHAPE = (1,224,224,3) # NHWC\n", | |
"\n", | |
"def download(url, path):\n", | |
" if not exists(path):\n", | |
" print('Downloading', url, 'in', path)\n", | |
" res = requests.get(url)\n", | |
" with open(path, 'wb') as file:\n", | |
" file.write(res.content)\n", | |
"\n", | |
"\n", | |
"def download_onnx_model():\n", | |
" download(onnx_url, onnx_model_path)\n", | |
"\n", | |
"\n", | |
"def download_dataset(repository_url=dataset_url):\n", | |
" import os \n", | |
" _dir = './imgs'\n", | |
" os.makedirs(_dir, exist_ok=True)\n", | |
"\n", | |
" for i in range(1000):\n", | |
" filename=f'''val_{str(i+1).rjust(8, '0')}.png'''\n", | |
" path = f'{_dir}/{filename}'\n", | |
" url = f'''{repository_url}/{filename}'''\n", | |
" download(url, path)\n", | |
"\n", | |
"\n", | |
"def get_labels(url=labels_url, path='./labels.txt'):\n", | |
" download(labels_url, path)\n", | |
" with open(path, 'r') as f:\n", | |
" return [int(l.rstrip()) for l in f]\n", | |
"\n", | |
"# Three sample images used for quick checks; the trailing comment is the expected label of each one.\n", | |
"download(f'{dataset_url}/val_00000097.png', './rabbit.png') # 188\n", | |
"download(f'{dataset_url}/val_00000211.png', './cat.png') # 95\n", | |
"download(f'{dataset_url}/val_00000217.png', './bird.png') # 392\n", | |
"\n", | |
"# Expected labels of the sample images. NOTE(review): these are strings, while `get_labels` yields integers.\n", | |
"RABBIT_LABEL = '188'\n", | |
"CAT_LABEL = '95'\n", | |
"BIRD_LABEL = '392'\n", | |
"\n", | |
"# Fetch everything the notebook needs; `download` skips files that already exist locally.\n", | |
"download_dataset()\n", | |
"download_onnx_model()\n", | |
"# Labels as a JAX array so they can be indexed with the result of `jnp.argmax`.\n", | |
"labels = jnp.array(get_labels())\n", | |
"\n", | |
"def plot_speedup(t, baseline='numpy'):\n", | |
" '''\n", | |
" Compute speedup between two approaches.\n", | |
" Original code from: https://github.com/MO436-MC934/notebooks/blob/7b99a7676c58b38dd428a95bc73acc07b1282838/05-JAX/util/speedup.py\n", | |
" '''\n", | |
" def avg(lst):\n", | |
" return sum(lst) / len(lst)\n", | |
"\n", | |
" # Prepare data\n", | |
" base = avg(t[baseline].all_runs)\n", | |
" name = t.keys()\n", | |
" time = [base / avg(v.all_runs) for v in t.values()]\n", | |
" # Plot speedups as bar graph\n", | |
" plt.bar(name, time, width=0.8)\n", | |
" plt.ylabel(\"Speedup\")\n", | |
" # Write speedup on top of bars\n", | |
" for (i, n), t in zip(enumerate(name), time):\n", | |
" plt.text(i, t, f\"{t:.2f}x\", horizontalalignment='center', fontweight='bold')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "AFqqViGnYyEY" | |
}, | |
"source": [ | |
"## Implementing Resnet18 using Flax\n", | |
"\n", | |
"For the implementation of ResNet18, I opened the file at `onnx_model_path` in [NetronApp](https://netron.app) and followed the graph, node by node. Notice that I gave every layer a `name` parameter identical to the name used in the ONNX file; this will be useful for matching the parameters' names.\n", | |
"\n", | |
"\n", | |
"### References for this section:\n", | |
"- [1] https://flax.readthedocs.io/en/latest/_autosummary/flax.linen.Conv.html\n", | |
"- [2] https://flax.readthedocs.io/en/latest/_autosummary/flax.linen.BatchNorm.html\n", | |
"- [3] https://flax.readthedocs.io/en/latest/_autosummary/flax.linen.relu.html\n", | |
"- [4] https://flax.readthedocs.io/en/latest/_autosummary/flax.linen.max_pool.html\n", | |
"- [5] https://github.com/KaimingHe/deep-residual-networks/issues/10#issuecomment-194037195\n", | |
"\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"id": "9lT0K0h5YwYW" | |
}, | |
"outputs": [], | |
"source": [ | |
"class ResNet18(nn.Module): \n", | |
" @nn.compact\n", | |
" def __call__(self, x):\n", | |
" momentum = 0.8999999761581421\n", | |
" epsilon = 0.000009999999747378752\n", | |
"\n", | |
" # why use_bias=false? See: [5]\n", | |
" use_bias = False\n", | |
"\n", | |
" x = nn.Conv(features=64, kernel_size=(7, 7), padding=(3,3), strides=2, use_bias=use_bias, name='resnetv15_conv0')(x)\n", | |
" x = nn.BatchNorm(use_running_average=True, momentum=momentum, epsilon=epsilon, dtype=jnp.float32, name='resnetv15_batchnorm0')(x)\n", | |
" x = nn.relu(x)\n", | |
" x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding=((1,1),(1,1)))\n", | |
"\n", | |
" # Stage 1\n", | |
" res = x\n", | |
" x = nn.Conv(features=64, kernel_size=(3, 3), padding=(1,1), strides=1, use_bias=use_bias, name='resnetv15_stage1_conv0')(x)\n", | |
" x = nn.BatchNorm(use_running_average=True, momentum=momentum, epsilon=epsilon, dtype=jnp.float32, name='resnetv15_stage1_batchnorm0')(x)\n", | |
" x = nn.relu(x)\n", | |
" x = nn.Conv(features=64, kernel_size=(3, 3), padding=(1,1), strides=1, use_bias=use_bias, name='resnetv15_stage1_conv1')(x)\n", | |
" x = nn.BatchNorm(use_running_average=True, momentum=momentum, epsilon=epsilon, dtype=jnp.float32, name='resnetv15_stage1_batchnorm1')(x)\n", | |
" x = nn.relu(x + res)\n", | |
"\n", | |
" res = x\n", | |
" x = nn.Conv(features=64, kernel_size=(3, 3), padding=(1,1), strides=1, use_bias=use_bias, name='resnetv15_stage1_conv2')(x)\n", | |
" x = nn.BatchNorm(use_running_average=True, momentum=momentum, epsilon=epsilon, dtype=jnp.float32, name='resnetv15_stage1_batchnorm2')(x)\n", | |
" x = nn.relu(x)\n", | |
" x = nn.Conv(features=64, kernel_size=(3, 3), padding=(1,1), strides=1, use_bias=use_bias, name='resnetv15_stage1_conv3')(x)\n", | |
" x = nn.BatchNorm(use_running_average=True, momentum=momentum, epsilon=epsilon, dtype=jnp.float32, name='resnetv15_stage1_batchnorm3')(x)\n", | |
" x = nn.relu(x + res)\n", | |
"\n", | |
"\n", | |
" # Stage 2\n", | |
" res = x\n", | |
" x = nn.Conv(features=128, kernel_size=(3, 3), padding=(1,1), strides=2, use_bias=use_bias, name='resnetv15_stage2_conv0')(x)\n", | |
" x = nn.BatchNorm(use_running_average=True, momentum=momentum, epsilon=epsilon, dtype=jnp.float32, name='resnetv15_stage2_batchnorm0')(x)\n", | |
" x = nn.relu(x)\n", | |
" x = nn.Conv(features=128, kernel_size=(3, 3), padding=(1,1), strides=1, use_bias=use_bias, name='resnetv15_stage2_conv1')(x)\n", | |
" x = nn.BatchNorm(use_running_average=True, momentum=momentum, epsilon=epsilon, dtype=jnp.float32, name='resnetv15_stage2_batchnorm1')(x)\n", | |
" res = nn.Conv(features=128, kernel_size=(1, 1), padding=(0,0), strides=2, use_bias=use_bias, name='resnetv15_stage2_conv2')(res)\n", | |
" res = nn.BatchNorm(use_running_average=True, momentum=momentum, epsilon=epsilon, dtype=jnp.float32, name='resnetv15_stage2_batchnorm2')(res)\n", | |
" x = nn.relu(x + res)\n", | |
"\n", | |
" res = x\n", | |
" x = nn.Conv(features=128, kernel_size=(3, 3), padding=(1,1), strides=1, use_bias=use_bias, name='resnetv15_stage2_conv3')(x)\n", | |
" x = nn.BatchNorm(use_running_average=True, momentum=momentum, epsilon=epsilon, dtype=jnp.float32, name='resnetv15_stage2_batchnorm3')(x)\n", | |
" x = nn.relu(x)\n", | |
" x = nn.Conv(features=128, kernel_size=(3, 3), padding=(1,1), strides=1, use_bias=use_bias, name='resnetv15_stage2_conv4')(x)\n", | |
" x = nn.BatchNorm(use_running_average=True, momentum=momentum, epsilon=epsilon, dtype=jnp.float32, name='resnetv15_stage2_batchnorm4')(x)\n", | |
" x = nn.relu(x + res)\n", | |
"\n", | |
" # Stage 3\n", | |
" res = x\n", | |
" x = nn.Conv(features=256, kernel_size=(3, 3), padding=(1,1), strides=2, use_bias=use_bias, name='resnetv15_stage3_conv0')(x)\n", | |
" x = nn.BatchNorm(use_running_average=True, momentum=momentum, epsilon=epsilon, dtype=jnp.float32, name='resnetv15_stage3_batchnorm0')(x)\n", | |
" x = nn.relu(x)\n", | |
" x = nn.Conv(features=256, kernel_size=(3, 3), padding=(1,1), strides=1, use_bias=use_bias, name='resnetv15_stage3_conv1')(x)\n", | |
" x = nn.BatchNorm(use_running_average=True, momentum=momentum, epsilon=epsilon, dtype=jnp.float32, name='resnetv15_stage3_batchnorm1')(x)\n", | |
" res = nn.Conv(features=256, kernel_size=(1, 1), padding=(0,0), strides=2, use_bias=use_bias, name='resnetv15_stage3_conv2')(res)\n", | |
" res = nn.BatchNorm(use_running_average=True, momentum=momentum, epsilon=epsilon, dtype=jnp.float32, name='resnetv15_stage3_batchnorm2')(res)\n", | |
" x = nn.relu(x + res)\n", | |
"\n", | |
" res = x\n", | |
" x = nn.Conv(features=256, kernel_size=(3, 3), padding=(1,1), strides=1, use_bias=use_bias, name='resnetv15_stage3_conv3')(x)\n", | |
" x = nn.BatchNorm(use_running_average=True, momentum=momentum, epsilon=epsilon, dtype=jnp.float32, name='resnetv15_stage3_batchnorm3')(x)\n", | |
" x = nn.relu(x)\n", | |
" x = nn.Conv(features=256, kernel_size=(3, 3), padding=(1,1), strides=1, use_bias=use_bias, name='resnetv15_stage3_conv4')(x)\n", | |
" x = nn.BatchNorm(use_running_average=True, momentum=momentum, epsilon=epsilon, dtype=jnp.float32, name='resnetv15_stage3_batchnorm4')(x)\n", | |
" x = nn.relu(x + res)\n", | |
"\n", | |
"\n", | |
" # Stage 4\n", | |
" res = x\n", | |
" x = nn.Conv(features=512, kernel_size=(3, 3), padding=(1,1), strides=2, use_bias=use_bias, name='resnetv15_stage4_conv0')(x)\n", | |
" x = nn.BatchNorm(use_running_average=True, momentum=momentum, epsilon=epsilon, dtype=jnp.float32, name='resnetv15_stage4_batchnorm0')(x)\n", | |
" x = nn.relu(x)\n", | |
" x = nn.Conv(features=512, kernel_size=(3, 3), padding=(1,1), strides=1, use_bias=use_bias, name='resnetv15_stage4_conv1')(x)\n", | |
" x = nn.BatchNorm(use_running_average=True, momentum=momentum, epsilon=epsilon, dtype=jnp.float32, name='resnetv15_stage4_batchnorm1')(x)\n", | |
" res = nn.Conv(features=512, kernel_size=(1, 1), padding=(0,0), strides=2, use_bias=use_bias, name='resnetv15_stage4_conv2')(res)\n", | |
" res = nn.BatchNorm(use_running_average=True, momentum=momentum, epsilon=epsilon, dtype=jnp.float32, name='resnetv15_stage4_batchnorm2')(res)\n", | |
" x = nn.relu(x + res)\n", | |
"\n", | |
" res = x\n", | |
" x = nn.Conv(features=512, kernel_size=(3, 3), padding=(1,1), strides=1, use_bias=use_bias, name='resnetv15_stage4_conv3')(x)\n", | |
" x = nn.BatchNorm(use_running_average=True, momentum=momentum, epsilon=epsilon, dtype=jnp.float32, name='resnetv15_stage4_batchnorm3')(x)\n", | |
" x = nn.relu(x)\n", | |
" x = nn.Conv(features=512, kernel_size=(3, 3), padding=(1,1), strides=1, use_bias=use_bias, name='resnetv15_stage4_conv4')(x)\n", | |
" x = nn.BatchNorm(use_running_average=True, momentum=momentum, epsilon=epsilon, dtype=jnp.float32, name='resnetv15_stage4_batchnorm4')(x)\n", | |
" x = nn.relu(x + res)\n", | |
"\n", | |
" # Global AVG pool\n", | |
" x = jnp.mean(x, axis=(1, 2))\n", | |
"\n", | |
" # Flatten\n", | |
" num_of_classes = 1000 # TODO: get this value from constructor or args\n", | |
" x = nn.Dense(num_of_classes, dtype=jnp.float32, name='resnetv15_dense0')(x)\n", | |
"\n", | |
" x = jnp.asarray(x, jnp.float32)\n", | |
"\n", | |
" return x" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "jwmAKMEKZl51" | |
}, | |
"source": [ | |
"# Taking a look in the model\n", | |
"Let's check what we have in the model. First, we initialize the model, and then we can take a look at `params` and `batch_stats`. These are the dictionaries we need to fill with the parameters brought from the ONNX file.\n", | |
"\n", | |
"Notice that the output of the cell below shows the tensor shapes expected by Flax. This will be useful when loading from ONNX." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "_isXLbsFZlSB", | |
"outputId": "de057b86-e16c-4955-83d3-a91727b83b5e" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Some examples of params:\n", | |
"resnetv15_conv0 \n", | |
" FrozenDict({\n", | |
" kernel: (7, 7, 3, 64),\n", | |
"}) \n", | |
"\n", | |
"resnetv15_batchnorm0 \n", | |
" FrozenDict({\n", | |
" bias: (64,),\n", | |
" scale: (64,),\n", | |
"}) \n", | |
"\n", | |
"resnetv15_stage1_conv0 \n", | |
" FrozenDict({\n", | |
" kernel: (3, 3, 64, 64),\n", | |
"}) \n", | |
"\n", | |
"resnetv15_stage1_batchnorm0 \n", | |
" FrozenDict({\n", | |
" bias: (64,),\n", | |
" scale: (64,),\n", | |
"}) \n", | |
"\n", | |
"resnetv15_dense0 \n", | |
" FrozenDict({\n", | |
" bias: (1000,),\n", | |
" kernel: (512, 1000),\n", | |
"}) \n", | |
"\n", | |
"Some examples of batch_stats:\n", | |
"resnetv15_batchnorm0 \n", | |
" FrozenDict({\n", | |
" mean: (64,),\n", | |
" var: (64,),\n", | |
"}) \n", | |
"\n", | |
"resnetv15_stage1_batchnorm0 \n", | |
" FrozenDict({\n", | |
" mean: (64,),\n", | |
" var: (64,),\n", | |
"}) \n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"cnn = ResNet18()\n", | |
"rng = jax.random.PRNGKey(0)\n", | |
"rng, init_rng = jax.random.split(rng)\n", | |
"init = cnn.init(rng, jnp.ones(FLAX_INPUT_SHAPE))\n", | |
"\n", | |
"def print_param(dic, name):\n", | |
" print(name, '\\n', jax.tree_map(lambda x: x.shape, dic[name]), '\\n')\n", | |
"\n", | |
"# Uncomment this to check all params\n", | |
"#print(jax.tree_map(lambda x: x.shape, init['params']))\n", | |
"\n", | |
"print('Some examples of params:')\n", | |
"print_param(init['params'], 'resnetv15_conv0')\n", | |
"print_param(init['params'], 'resnetv15_batchnorm0')\n", | |
"print_param(init['params'], 'resnetv15_stage1_conv0')\n", | |
"print_param(init['params'], 'resnetv15_stage1_batchnorm0')\n", | |
"print_param(init['params'], 'resnetv15_dense0')\n", | |
"\n", | |
"print('Some examples of batch_stats:')\n", | |
"print_param(init['batch_stats'],'resnetv15_batchnorm0')\n", | |
"print_param(init['batch_stats'],'resnetv15_stage1_batchnorm0')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "VQ78Lg0AfNwP" | |
}, | |
"source": [ | |
"# Loading parameters from ONNX\n", | |
"With the `onnx` library we can open the ONNX file and iterate through its graph. I'm parsing the tensor names and matching them with the names I gave to every layer in the Flax module definition.\n", | |
"\n", | |
"The trick here is to pay attention to the tensor shapes used in the ONNX file and to what Flax actually expects. Convolution kernels in ONNX are stored as [outC, inC, H, W] (matching its NCHW data layout), while Flax expects [H, W, inC, outC] (matching NHWC). Also, for fully connected layers, ONNX stores the weights as [outC, inC] while Flax uses [inC, outC]. We need to transpose them on the fly before pushing them into our params dictionary.\n", | |
"\n", | |
"Notice that some parameter names differ between ONNX and Flax; for example, `beta` in ONNX is called `bias` in Flax.\n", | |
"\n", | |
"## References for this section:\n", | |
"- [1] https://flax.readthedocs.io/en/latest/howtos/convert_pytorch_to_flax.html" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": { | |
"id": "FrsExnKhfMUp" | |
}, | |
"outputs": [], | |
"source": [ | |
"# Parse the ONNX file; its graph initializers hold the trained tensors.\n", | |
"model = onnx.load(onnx_model_path)\n", | |
"# Filled below, keyed by Flax layer name: `params` gets kernel/bias/scale, `batch_stats` gets mean/var.\n", | |
"params = {}\n", | |
"batch_stats = {}\n", | |
"\n", | |
"def initialize(d, key):\n", | |
" if not key in d.keys():\n", | |
" d[key] = {}\n", | |
" \n", | |
"for initializer in model.graph.initializer:\n", | |
" # Debug only\n", | |
" #print(f'- Tensor: {initializer.name!r:45} shape={initializer.dims}')\n", | |
"\n", | |
" if 'stage' in initializer.name:\n", | |
" key = '_'.join(initializer.name.split('_')[:3])\n", | |
" else:\n", | |
" key = '_'.join(initializer.name.split('_')[:2])\n", | |
"\n", | |
" # Debug only\n", | |
" #print('Loading', initializer.name, 'into', key)\n", | |
"\n", | |
" if '_mean' in initializer.name:\n", | |
" # BatchNorm uses _mean suffix (batch_stats)\n", | |
" initialize(batch_stats, key)\n", | |
" batch_stats[key]['mean'] = numpy_helper.to_array(initializer)\n", | |
"\n", | |
" elif '_var' in initializer.name:\n", | |
" # BatchNorm uses _var suffix (batch_stats)\n", | |
" initialize(batch_stats, key)\n", | |
" batch_stats[key]['var'] = numpy_helper.to_array(initializer)\n", | |
"\n", | |
" elif '_beta' in initializer.name:\n", | |
" # Batchnorm uses '_beta' suffix\n", | |
" initialize(params, key)\n", | |
" params[key]['bias'] = numpy_helper.to_array(initializer)\n", | |
" elif '_gamma' in initializer.name:\n", | |
" # Batchnorm uses _gamma suffix (batch_stats)\n", | |
" initialize(params, key)\n", | |
" params[key]['scale'] = numpy_helper.to_array(initializer)\n", | |
"\n", | |
" elif '_weight' in initializer.name:\n", | |
" # Conv uses _weight layer\n", | |
" initialize(params, key)\n", | |
"\n", | |
" weights = numpy_helper.to_array(initializer)\n", | |
"\n", | |
" if len(weights.shape) == 2:\n", | |
" weights = np.transpose(weights, (1, 0))\n", | |
" params[key]['kernel'] = weights\n", | |
" elif len(weights.shape) == 4:\n", | |
" weights = np.transpose(weights, (2, 3, 1, 0))\n", | |
" params[key]['kernel'] = weights\n", | |
" else:\n", | |
" print('WARNING: This should not happen...')\n", | |
"\n", | |
" elif '_bias' in initializer.name:\n", | |
" # Dense layer uses _bias suffix\n", | |
" initialize(params, key)\n", | |
" params[key]['bias'] = numpy_helper.to_array(initializer)\n", | |
" else:\n", | |
" print('WARNING: I dont know what I should do with this:', key, initializer.name)\n", | |
"\n", | |
"params = freeze(params)\n", | |
"batch_stats = freeze(batch_stats)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "rnPwahUmvPer" | |
}, | |
"source": [ | |
"Let's check whether `params` and `batch_stats` look like what we got from the initialization above." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "Yn5TV9rDjT74", | |
"outputId": "d0b55f1c-9f12-4bf7-843f-1c934b9cf120" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Some examples of params:\n", | |
"resnetv15_conv0 \n", | |
" FrozenDict({\n", | |
" kernel: (7, 7, 3, 64),\n", | |
"}) \n", | |
"\n", | |
"resnetv15_batchnorm0 \n", | |
" FrozenDict({\n", | |
" bias: (64,),\n", | |
" scale: (64,),\n", | |
"}) \n", | |
"\n", | |
"resnetv15_stage1_conv0 \n", | |
" FrozenDict({\n", | |
" kernel: (3, 3, 64, 64),\n", | |
"}) \n", | |
"\n", | |
"resnetv15_stage1_batchnorm0 \n", | |
" FrozenDict({\n", | |
" bias: (64,),\n", | |
" scale: (64,),\n", | |
"}) \n", | |
"\n", | |
"resnetv15_dense0 \n", | |
" FrozenDict({\n", | |
" bias: (1000,),\n", | |
" kernel: (512, 1000),\n", | |
"}) \n", | |
"\n", | |
"Some examples of batch_stats:\n", | |
"resnetv15_batchnorm0 \n", | |
" FrozenDict({\n", | |
" mean: (64,),\n", | |
" var: (64,),\n", | |
"}) \n", | |
"\n", | |
"resnetv15_stage1_batchnorm0 \n", | |
" FrozenDict({\n", | |
" mean: (64,),\n", | |
" var: (64,),\n", | |
"}) \n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"print('Some examples of params:')\n", | |
"print_param(params, 'resnetv15_conv0')\n", | |
"print_param(params, 'resnetv15_batchnorm0')\n", | |
"print_param(params, 'resnetv15_stage1_conv0')\n", | |
"print_param(params, 'resnetv15_stage1_batchnorm0')\n", | |
"print_param(params, 'resnetv15_dense0')\n", | |
"\n", | |
"print('Some examples of batch_stats:')\n", | |
"print_param(batch_stats,'resnetv15_batchnorm0')\n", | |
"print_param(batch_stats,'resnetv15_stage1_batchnorm0')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "wZ8s9rc4wzVU" | |
}, | |
"source": [ | |
"# Pre-processing images\n", | |
"\n", | |
"Here I'm defining a function that opens an image, resizes it to the expected input dimensions, and applies some color adjustments, following [1].\n", | |
"\n", | |
"\n", | |
"## References for this section\n", | |
"- [1] https://github.com/MO436-MC934/work/blob/bb5ef4e747011d5d045b7373e35492b836aae27e/models/resnet18/main.cpp#L62 \n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": { | |
"id": "4I0XYsw6ndIv" | |
}, | |
"outputs": [], | |
"source": [ | |
"def get_image(path, show=False):\n", | |
" img = cv2.imread(path)\n", | |
"\n", | |
" img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n", | |
" img = cv2.resize(img, (INPUT_IMAGE_HEIGHT, INPUT_IMAGE_WIDTH))\n", | |
" \n", | |
" if show:\n", | |
" plt.imshow(img)\n", | |
" plt.show()\n", | |
"\n", | |
" mean=[0.485, 0.456, 0.406]\n", | |
" std=[0.229, 0.224, 0.225]\n", | |
"\n", | |
" img = img/255.0\n", | |
" img[:,:,0] = (img[:,:,0] - mean[0]) / std[0]\n", | |
" img[:,:,1] = (img[:,:,1] - mean[1]) / std[1]\n", | |
" img[:,:,2] = (img[:,:,2] - mean[2]) / std[2]\n", | |
"\n", | |
" return img" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "NS9MkTl_wifT" | |
}, | |
"source": [ | |
"# Testing inference...\n", | |
"\n", | |
"Just running a very basic inference and checking if the result is correct." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 286 | |
}, | |
"id": "C0HyBNeVniKh", | |
"outputId": "001ddb65-59da-4806-f196-b7a82ab00b8f" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
], | |
"image/png": "\n" | |
}, | |
"metadata": { | |
"needs_background": "light" | |
} | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Correct value: 188; Predicted value: 188\n" | |
] | |
} | |
], | |
"source": [ | |
"def get_top1(activations):\n", | |
" '''\n", | |
" Get predicted labels... We don't care about performance here\n", | |
" '''\n", | |
" preds = []\n", | |
"\n", | |
" for activation in activations:\n", | |
" pred = jnp.argmax(activation)\n", | |
" preds.append(labels[pred])\n", | |
"\n", | |
" return preds\n", | |
"\n", | |
"def predict(input, params, batch_stats):\n", | |
" return ResNet18().apply({'params': params, 'batch_stats': batch_stats}, input)\n", | |
"\n", | |
"\n", | |
"def predict_single(image):\n", | |
" input = jnp.expand_dims(image, axis=0)\n", | |
" return ResNet18().apply({'params': params, 'batch_stats': batch_stats}, input)\n", | |
"\n", | |
"rabbit = get_image('./rabbit.png', show=True)\n", | |
"rabbit_pred_label = get_top1(predict_single(rabbit))[0]\n", | |
"print('Correct value: 188; Predicted value:', rabbit_pred_label)\n", | |
"assert 188 == rabbit_pred_label" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "BwYG04avTjn7" | |
}, | |
"source": [ | |
"# Performance time!\n", | |
"\n", | |
"The goal here is to predict 1000 images at once in Flax, so the input shape will be `1000x224x224x3`. \n", | |
"\n", | |
"JAX runs our code on the GPU in a very transparent way, so the first step will be to clearly define what runs on the GPU and what runs on the CPU.\n", | |
"\n", | |
"## References for this section:\n", | |
"- [1] https://jax.readthedocs.io/en/latest/notebooks/quickstart.html?highlight=vmap#jax-quickstart" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Compile `predict` explicitly for the CPU backend.\n", | |
"predict_cpu = jit(predict, backend='cpu')\n", | |
"# Trace the CPU version with a 1000-image batch and display the resulting jaxpr.\n", | |
"make_jaxpr(predict_cpu)(jnp.ones((1000, 224,224,3)), params, batch_stats)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "97NYg1jZhytz", | |
"outputId": "bc214383-cdde-4c83-9cd6-0ac2d11b98be" | |
}, | |
"execution_count": 10, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"{ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; a\u001b[35m:f32[1000,224,224,3]\u001b[39m b\u001b[35m:f32[64]\u001b[39m c\u001b[35m:f32[64]\u001b[39m d\u001b[35m:f32[7,7,3,64]\u001b[39m e\u001b[35m:f32[1000]\u001b[39m\n", | |
" f\u001b[35m:f32[512,1000]\u001b[39m g\u001b[35m:f32[64]\u001b[39m h\u001b[35m:f32[64]\u001b[39m i\u001b[35m:f32[64]\u001b[39m j\u001b[35m:f32[64]\u001b[39m k\u001b[35m:f32[64]\u001b[39m l\u001b[35m:f32[64]\u001b[39m m\u001b[35m:f32[64]\u001b[39m\n", | |
" n\u001b[35m:f32[64]\u001b[39m o\u001b[35m:f32[3,3,64,64]\u001b[39m p\u001b[35m:f32[3,3,64,64]\u001b[39m q\u001b[35m:f32[3,3,64,64]\u001b[39m r\u001b[35m:f32[3,3,64,64]\u001b[39m\n", | |
" s\u001b[35m:f32[128]\u001b[39m t\u001b[35m:f32[128]\u001b[39m u\u001b[35m:f32[128]\u001b[39m v\u001b[35m:f32[128]\u001b[39m w\u001b[35m:f32[128]\u001b[39m x\u001b[35m:f32[128]\u001b[39m y\u001b[35m:f32[128]\u001b[39m\n", | |
" z\u001b[35m:f32[128]\u001b[39m ba\u001b[35m:f32[128]\u001b[39m bb\u001b[35m:f32[128]\u001b[39m bc\u001b[35m:f32[3,3,64,128]\u001b[39m bd\u001b[35m:f32[3,3,128,128]\u001b[39m be\u001b[35m:f32[1,1,64,128]\u001b[39m\n", | |
" bf\u001b[35m:f32[3,3,128,128]\u001b[39m bg\u001b[35m:f32[3,3,128,128]\u001b[39m bh\u001b[35m:f32[256]\u001b[39m bi\u001b[35m:f32[256]\u001b[39m bj\u001b[35m:f32[256]\u001b[39m bk\u001b[35m:f32[256]\u001b[39m\n", | |
" bl\u001b[35m:f32[256]\u001b[39m bm\u001b[35m:f32[256]\u001b[39m bn\u001b[35m:f32[256]\u001b[39m bo\u001b[35m:f32[256]\u001b[39m bp\u001b[35m:f32[256]\u001b[39m bq\u001b[35m:f32[256]\u001b[39m br\u001b[35m:f32[3,3,128,256]\u001b[39m\n", | |
" bs\u001b[35m:f32[3,3,256,256]\u001b[39m bt\u001b[35m:f32[1,1,128,256]\u001b[39m bu\u001b[35m:f32[3,3,256,256]\u001b[39m bv\u001b[35m:f32[3,3,256,256]\u001b[39m\n", | |
" bw\u001b[35m:f32[512]\u001b[39m bx\u001b[35m:f32[512]\u001b[39m by\u001b[35m:f32[512]\u001b[39m bz\u001b[35m:f32[512]\u001b[39m ca\u001b[35m:f32[512]\u001b[39m cb\u001b[35m:f32[512]\u001b[39m cc\u001b[35m:f32[512]\u001b[39m\n", | |
" cd\u001b[35m:f32[512]\u001b[39m ce\u001b[35m:f32[512]\u001b[39m cf\u001b[35m:f32[512]\u001b[39m cg\u001b[35m:f32[3,3,256,512]\u001b[39m ch\u001b[35m:f32[3,3,512,512]\u001b[39m ci\u001b[35m:f32[1,1,256,512]\u001b[39m\n", | |
" cj\u001b[35m:f32[3,3,512,512]\u001b[39m ck\u001b[35m:f32[3,3,512,512]\u001b[39m cl\u001b[35m:f32[64]\u001b[39m cm\u001b[35m:f32[64]\u001b[39m cn\u001b[35m:f32[64]\u001b[39m co\u001b[35m:f32[64]\u001b[39m\n", | |
" cp\u001b[35m:f32[64]\u001b[39m cq\u001b[35m:f32[64]\u001b[39m cr\u001b[35m:f32[64]\u001b[39m cs\u001b[35m:f32[64]\u001b[39m ct\u001b[35m:f32[64]\u001b[39m cu\u001b[35m:f32[64]\u001b[39m cv\u001b[35m:f32[128]\u001b[39m\n", | |
" cw\u001b[35m:f32[128]\u001b[39m cx\u001b[35m:f32[128]\u001b[39m cy\u001b[35m:f32[128]\u001b[39m cz\u001b[35m:f32[128]\u001b[39m da\u001b[35m:f32[128]\u001b[39m db\u001b[35m:f32[128]\u001b[39m dc\u001b[35m:f32[128]\u001b[39m\n", | |
" dd\u001b[35m:f32[128]\u001b[39m de\u001b[35m:f32[128]\u001b[39m df\u001b[35m:f32[256]\u001b[39m dg\u001b[35m:f32[256]\u001b[39m dh\u001b[35m:f32[256]\u001b[39m di\u001b[35m:f32[256]\u001b[39m dj\u001b[35m:f32[256]\u001b[39m\n", | |
" dk\u001b[35m:f32[256]\u001b[39m dl\u001b[35m:f32[256]\u001b[39m dm\u001b[35m:f32[256]\u001b[39m dn\u001b[35m:f32[256]\u001b[39m do\u001b[35m:f32[256]\u001b[39m dp\u001b[35m:f32[512]\u001b[39m dq\u001b[35m:f32[512]\u001b[39m\n", | |
" dr\u001b[35m:f32[512]\u001b[39m ds\u001b[35m:f32[512]\u001b[39m dt\u001b[35m:f32[512]\u001b[39m du\u001b[35m:f32[512]\u001b[39m dv\u001b[35m:f32[512]\u001b[39m dw\u001b[35m:f32[512]\u001b[39m dx\u001b[35m:f32[512]\u001b[39m\n", | |
" dy\u001b[35m:f32[512]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mdz\u001b[35m:f32[1000,1000]\u001b[39m = xla_call[\n", | |
" backend=cpu\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; ea\u001b[35m:f32[1000,224,224,3]\u001b[39m eb\u001b[35m:f32[64]\u001b[39m ec\u001b[35m:f32[64]\u001b[39m ed\u001b[35m:f32[7,7,3,64]\u001b[39m\n", | |
" ee\u001b[35m:f32[1000]\u001b[39m ef\u001b[35m:f32[512,1000]\u001b[39m eg\u001b[35m:f32[64]\u001b[39m eh\u001b[35m:f32[64]\u001b[39m ei\u001b[35m:f32[64]\u001b[39m ej\u001b[35m:f32[64]\u001b[39m\n", | |
" ek\u001b[35m:f32[64]\u001b[39m el\u001b[35m:f32[64]\u001b[39m em\u001b[35m:f32[64]\u001b[39m en\u001b[35m:f32[64]\u001b[39m eo\u001b[35m:f32[3,3,64,64]\u001b[39m ep\u001b[35m:f32[3,3,64,64]\u001b[39m\n", | |
" eq\u001b[35m:f32[3,3,64,64]\u001b[39m er\u001b[35m:f32[3,3,64,64]\u001b[39m es\u001b[35m:f32[128]\u001b[39m et\u001b[35m:f32[128]\u001b[39m eu\u001b[35m:f32[128]\u001b[39m\n", | |
" ev\u001b[35m:f32[128]\u001b[39m ew\u001b[35m:f32[128]\u001b[39m ex\u001b[35m:f32[128]\u001b[39m ey\u001b[35m:f32[128]\u001b[39m ez\u001b[35m:f32[128]\u001b[39m fa\u001b[35m:f32[128]\u001b[39m\n", | |
" fb\u001b[35m:f32[128]\u001b[39m fc\u001b[35m:f32[3,3,64,128]\u001b[39m fd\u001b[35m:f32[3,3,128,128]\u001b[39m fe\u001b[35m:f32[1,1,64,128]\u001b[39m ff\u001b[35m:f32[3,3,128,128]\u001b[39m\n", | |
" fg\u001b[35m:f32[3,3,128,128]\u001b[39m fh\u001b[35m:f32[256]\u001b[39m fi\u001b[35m:f32[256]\u001b[39m fj\u001b[35m:f32[256]\u001b[39m fk\u001b[35m:f32[256]\u001b[39m fl\u001b[35m:f32[256]\u001b[39m\n", | |
" fm\u001b[35m:f32[256]\u001b[39m fn\u001b[35m:f32[256]\u001b[39m fo\u001b[35m:f32[256]\u001b[39m fp\u001b[35m:f32[256]\u001b[39m fq\u001b[35m:f32[256]\u001b[39m fr\u001b[35m:f32[3,3,128,256]\u001b[39m\n", | |
" fs\u001b[35m:f32[3,3,256,256]\u001b[39m ft\u001b[35m:f32[1,1,128,256]\u001b[39m fu\u001b[35m:f32[3,3,256,256]\u001b[39m fv\u001b[35m:f32[3,3,256,256]\u001b[39m\n", | |
" fw\u001b[35m:f32[512]\u001b[39m fx\u001b[35m:f32[512]\u001b[39m fy\u001b[35m:f32[512]\u001b[39m fz\u001b[35m:f32[512]\u001b[39m ga\u001b[35m:f32[512]\u001b[39m gb\u001b[35m:f32[512]\u001b[39m\n", | |
" gc\u001b[35m:f32[512]\u001b[39m gd\u001b[35m:f32[512]\u001b[39m ge\u001b[35m:f32[512]\u001b[39m gf\u001b[35m:f32[512]\u001b[39m gg\u001b[35m:f32[3,3,256,512]\u001b[39m gh\u001b[35m:f32[3,3,512,512]\u001b[39m\n", | |
" gi\u001b[35m:f32[1,1,256,512]\u001b[39m gj\u001b[35m:f32[3,3,512,512]\u001b[39m gk\u001b[35m:f32[3,3,512,512]\u001b[39m gl\u001b[35m:f32[64]\u001b[39m\n", | |
" gm\u001b[35m:f32[64]\u001b[39m gn\u001b[35m:f32[64]\u001b[39m go\u001b[35m:f32[64]\u001b[39m gp\u001b[35m:f32[64]\u001b[39m gq\u001b[35m:f32[64]\u001b[39m gr\u001b[35m:f32[64]\u001b[39m gs\u001b[35m:f32[64]\u001b[39m\n", | |
" gt\u001b[35m:f32[64]\u001b[39m gu\u001b[35m:f32[64]\u001b[39m gv\u001b[35m:f32[128]\u001b[39m gw\u001b[35m:f32[128]\u001b[39m gx\u001b[35m:f32[128]\u001b[39m gy\u001b[35m:f32[128]\u001b[39m gz\u001b[35m:f32[128]\u001b[39m\n", | |
" ha\u001b[35m:f32[128]\u001b[39m hb\u001b[35m:f32[128]\u001b[39m hc\u001b[35m:f32[128]\u001b[39m hd\u001b[35m:f32[128]\u001b[39m he\u001b[35m:f32[128]\u001b[39m hf\u001b[35m:f32[256]\u001b[39m\n", | |
" hg\u001b[35m:f32[256]\u001b[39m hh\u001b[35m:f32[256]\u001b[39m hi\u001b[35m:f32[256]\u001b[39m hj\u001b[35m:f32[256]\u001b[39m hk\u001b[35m:f32[256]\u001b[39m hl\u001b[35m:f32[256]\u001b[39m\n", | |
" hm\u001b[35m:f32[256]\u001b[39m hn\u001b[35m:f32[256]\u001b[39m ho\u001b[35m:f32[256]\u001b[39m hp\u001b[35m:f32[512]\u001b[39m hq\u001b[35m:f32[512]\u001b[39m hr\u001b[35m:f32[512]\u001b[39m\n", | |
" hs\u001b[35m:f32[512]\u001b[39m ht\u001b[35m:f32[512]\u001b[39m hu\u001b[35m:f32[512]\u001b[39m hv\u001b[35m:f32[512]\u001b[39m hw\u001b[35m:f32[512]\u001b[39m hx\u001b[35m:f32[512]\u001b[39m\n", | |
" hy\u001b[35m:f32[512]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mhz\u001b[35m:f32[1000,112,112,64]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1000, 224, 224, 3)\n", | |
" padding=((3, 3), (3, 3))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(7, 7, 3, 64)\n", | |
" window_strides=(2, 2)\n", | |
" ] ea ed\n", | |
" ia\u001b[35m:f32[1,1,1,64]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 64)] gl\n", | |
" ib\u001b[35m:f32[1,1,1,64]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 64)] gm\n", | |
" ic\u001b[35m:f32[1000,112,112,64]\u001b[39m = sub hz ia\n", | |
" id\u001b[35m:f32[1,1,1,64]\u001b[39m = add ib 9.999999747378752e-06\n", | |
" ie\u001b[35m:f32[1,1,1,64]\u001b[39m = rsqrt id\n", | |
" if\u001b[35m:f32[1,1,1,64]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 64)] ec\n", | |
" ig\u001b[35m:f32[1,1,1,64]\u001b[39m = mul ie if\n", | |
" ih\u001b[35m:f32[1000,112,112,64]\u001b[39m = mul ic ig\n", | |
" ii\u001b[35m:f32[1,1,1,64]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 64)] eb\n", | |
" ij\u001b[35m:f32[1000,112,112,64]\u001b[39m = add ih ii\n", | |
" ik\u001b[35m:f32[1000,112,112,64]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; il\u001b[35m:f32[1000,112,112,64]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mim\u001b[35m:f32[1000,112,112,64]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; in\u001b[35m:f32[1000,112,112,64]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mio\u001b[35m:f32[1000,112,112,64]\u001b[39m = max in 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(io,) }\n", | |
" name=relu\n", | |
" ] il\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(im,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1e9a10fa70>\n", | |
" num_consts=0\n", | |
" ] ij\n", | |
" ip\u001b[35m:f32[1000,56,56,64]\u001b[39m = reduce_window_max[\n", | |
" base_dilation=(1, 1, 1, 1)\n", | |
" padding=((0, 0), (1, 1), (1, 1), (0, 0))\n", | |
" window_dilation=(1, 1, 1, 1)\n", | |
" window_dimensions=(1, 3, 3, 1)\n", | |
" window_strides=(1, 2, 2, 1)\n", | |
" ] ik\n", | |
" iq\u001b[35m:f32[1000,56,56,64]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1000, 56, 56, 64)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 64, 64)\n", | |
" window_strides=(1, 1)\n", | |
" ] ip eo\n", | |
" ir\u001b[35m:f32[1,1,1,64]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 64)] gn\n", | |
" is\u001b[35m:f32[1,1,1,64]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 64)] go\n", | |
" it\u001b[35m:f32[1000,56,56,64]\u001b[39m = sub iq ir\n", | |
" iu\u001b[35m:f32[1,1,1,64]\u001b[39m = add is 9.999999747378752e-06\n", | |
" iv\u001b[35m:f32[1,1,1,64]\u001b[39m = rsqrt iu\n", | |
" iw\u001b[35m:f32[1,1,1,64]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 64)] eh\n", | |
" ix\u001b[35m:f32[1,1,1,64]\u001b[39m = mul iv iw\n", | |
" iy\u001b[35m:f32[1000,56,56,64]\u001b[39m = mul it ix\n", | |
" iz\u001b[35m:f32[1,1,1,64]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 64)] eg\n", | |
" ja\u001b[35m:f32[1000,56,56,64]\u001b[39m = add iy iz\n", | |
" jb\u001b[35m:f32[1000,56,56,64]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; jc\u001b[35m:f32[1000,56,56,64]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mjd\u001b[35m:f32[1000,56,56,64]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; je\u001b[35m:f32[1000,56,56,64]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mjf\u001b[35m:f32[1000,56,56,64]\u001b[39m = max je 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(jf,) }\n", | |
" name=relu\n", | |
" ] jc\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(jd,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1e9a10f440>\n", | |
" num_consts=0\n", | |
" ] ja\n", | |
" jg\u001b[35m:f32[1000,56,56,64]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1000, 56, 56, 64)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 64, 64)\n", | |
" window_strides=(1, 1)\n", | |
" ] jb ep\n", | |
" jh\u001b[35m:f32[1,1,1,64]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 64)] gp\n", | |
" ji\u001b[35m:f32[1,1,1,64]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 64)] gq\n", | |
" jj\u001b[35m:f32[1000,56,56,64]\u001b[39m = sub jg jh\n", | |
" jk\u001b[35m:f32[1,1,1,64]\u001b[39m = add ji 9.999999747378752e-06\n", | |
" jl\u001b[35m:f32[1,1,1,64]\u001b[39m = rsqrt jk\n", | |
" jm\u001b[35m:f32[1,1,1,64]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 64)] ej\n", | |
" jn\u001b[35m:f32[1,1,1,64]\u001b[39m = mul jl jm\n", | |
" jo\u001b[35m:f32[1000,56,56,64]\u001b[39m = mul jj jn\n", | |
" jp\u001b[35m:f32[1,1,1,64]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 64)] ei\n", | |
" jq\u001b[35m:f32[1000,56,56,64]\u001b[39m = add jo jp\n", | |
" jr\u001b[35m:f32[1000,56,56,64]\u001b[39m = add jq ip\n", | |
" js\u001b[35m:f32[1000,56,56,64]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; jt\u001b[35m:f32[1000,56,56,64]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mju\u001b[35m:f32[1000,56,56,64]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; jv\u001b[35m:f32[1000,56,56,64]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mjw\u001b[35m:f32[1000,56,56,64]\u001b[39m = max jv 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(jw,) }\n", | |
" name=relu\n", | |
" ] jt\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(ju,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1e9a10fdd0>\n", | |
" num_consts=0\n", | |
" ] jr\n", | |
" jx\u001b[35m:f32[1000,56,56,64]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1000, 56, 56, 64)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 64, 64)\n", | |
" window_strides=(1, 1)\n", | |
" ] js eq\n", | |
" jy\u001b[35m:f32[1,1,1,64]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 64)] gr\n", | |
" jz\u001b[35m:f32[1,1,1,64]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 64)] gs\n", | |
" ka\u001b[35m:f32[1000,56,56,64]\u001b[39m = sub jx jy\n", | |
" kb\u001b[35m:f32[1,1,1,64]\u001b[39m = add jz 9.999999747378752e-06\n", | |
" kc\u001b[35m:f32[1,1,1,64]\u001b[39m = rsqrt kb\n", | |
" kd\u001b[35m:f32[1,1,1,64]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 64)] el\n", | |
" ke\u001b[35m:f32[1,1,1,64]\u001b[39m = mul kc kd\n", | |
" kf\u001b[35m:f32[1000,56,56,64]\u001b[39m = mul ka ke\n", | |
" kg\u001b[35m:f32[1,1,1,64]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 64)] ek\n", | |
" kh\u001b[35m:f32[1000,56,56,64]\u001b[39m = add kf kg\n", | |
" ki\u001b[35m:f32[1000,56,56,64]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; kj\u001b[35m:f32[1000,56,56,64]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mkk\u001b[35m:f32[1000,56,56,64]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; kl\u001b[35m:f32[1000,56,56,64]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mkm\u001b[35m:f32[1000,56,56,64]\u001b[39m = max kl 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(km,) }\n", | |
" name=relu\n", | |
" ] kj\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(kk,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1e9a10f710>\n", | |
" num_consts=0\n", | |
" ] kh\n", | |
" kn\u001b[35m:f32[1000,56,56,64]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1000, 56, 56, 64)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 64, 64)\n", | |
" window_strides=(1, 1)\n", | |
" ] ki er\n", | |
" ko\u001b[35m:f32[1,1,1,64]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 64)] gt\n", | |
" kp\u001b[35m:f32[1,1,1,64]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 64)] gu\n", | |
" kq\u001b[35m:f32[1000,56,56,64]\u001b[39m = sub kn ko\n", | |
" kr\u001b[35m:f32[1,1,1,64]\u001b[39m = add kp 9.999999747378752e-06\n", | |
" ks\u001b[35m:f32[1,1,1,64]\u001b[39m = rsqrt kr\n", | |
" kt\u001b[35m:f32[1,1,1,64]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 64)] en\n", | |
" ku\u001b[35m:f32[1,1,1,64]\u001b[39m = mul ks kt\n", | |
" kv\u001b[35m:f32[1000,56,56,64]\u001b[39m = mul kq ku\n", | |
" kw\u001b[35m:f32[1,1,1,64]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 64)] em\n", | |
" kx\u001b[35m:f32[1000,56,56,64]\u001b[39m = add kv kw\n", | |
" ky\u001b[35m:f32[1000,56,56,64]\u001b[39m = add kx js\n", | |
" kz\u001b[35m:f32[1000,56,56,64]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; la\u001b[35m:f32[1000,56,56,64]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mlb\u001b[35m:f32[1000,56,56,64]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; lc\u001b[35m:f32[1000,56,56,64]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mld\u001b[35m:f32[1000,56,56,64]\u001b[39m = max lc 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(ld,) }\n", | |
" name=relu\n", | |
" ] la\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(lb,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1e9a10f4d0>\n", | |
" num_consts=0\n", | |
" ] ky\n", | |
" le\u001b[35m:f32[1000,28,28,128]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1000, 56, 56, 64)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 64, 128)\n", | |
" window_strides=(2, 2)\n", | |
" ] kz fc\n", | |
" lf\u001b[35m:f32[1,1,1,128]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 128)] gv\n", | |
" lg\u001b[35m:f32[1,1,1,128]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 128)] gw\n", | |
" lh\u001b[35m:f32[1000,28,28,128]\u001b[39m = sub le lf\n", | |
" li\u001b[35m:f32[1,1,1,128]\u001b[39m = add lg 9.999999747378752e-06\n", | |
" lj\u001b[35m:f32[1,1,1,128]\u001b[39m = rsqrt li\n", | |
" lk\u001b[35m:f32[1,1,1,128]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 128)] et\n", | |
" ll\u001b[35m:f32[1,1,1,128]\u001b[39m = mul lj lk\n", | |
" lm\u001b[35m:f32[1000,28,28,128]\u001b[39m = mul lh ll\n", | |
" ln\u001b[35m:f32[1,1,1,128]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 128)] es\n", | |
" lo\u001b[35m:f32[1000,28,28,128]\u001b[39m = add lm ln\n", | |
" lp\u001b[35m:f32[1000,28,28,128]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; lq\u001b[35m:f32[1000,28,28,128]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mlr\u001b[35m:f32[1000,28,28,128]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; ls\u001b[35m:f32[1000,28,28,128]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mlt\u001b[35m:f32[1000,28,28,128]\u001b[39m = max ls 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(lt,) }\n", | |
" name=relu\n", | |
" ] lq\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(lr,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1e9a10fd40>\n", | |
" num_consts=0\n", | |
" ] lo\n", | |
" lu\u001b[35m:f32[1000,28,28,128]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1000, 28, 28, 128)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 128, 128)\n", | |
" window_strides=(1, 1)\n", | |
" ] lp fd\n", | |
" lv\u001b[35m:f32[1,1,1,128]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 128)] gx\n", | |
" lw\u001b[35m:f32[1,1,1,128]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 128)] gy\n", | |
" lx\u001b[35m:f32[1000,28,28,128]\u001b[39m = sub lu lv\n", | |
" ly\u001b[35m:f32[1,1,1,128]\u001b[39m = add lw 9.999999747378752e-06\n", | |
" lz\u001b[35m:f32[1,1,1,128]\u001b[39m = rsqrt ly\n", | |
" ma\u001b[35m:f32[1,1,1,128]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 128)] ev\n", | |
" mb\u001b[35m:f32[1,1,1,128]\u001b[39m = mul lz ma\n", | |
" mc\u001b[35m:f32[1000,28,28,128]\u001b[39m = mul lx mb\n", | |
" md\u001b[35m:f32[1,1,1,128]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 128)] eu\n", | |
" me\u001b[35m:f32[1000,28,28,128]\u001b[39m = add mc md\n", | |
" mf\u001b[35m:f32[1000,28,28,128]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1000, 56, 56, 64)\n", | |
" padding=((0, 0), (0, 0))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(1, 1, 64, 128)\n", | |
" window_strides=(2, 2)\n", | |
" ] kz fe\n", | |
" mg\u001b[35m:f32[1,1,1,128]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 128)] gz\n", | |
" mh\u001b[35m:f32[1,1,1,128]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 128)] ha\n", | |
" mi\u001b[35m:f32[1000,28,28,128]\u001b[39m = sub mf mg\n", | |
" mj\u001b[35m:f32[1,1,1,128]\u001b[39m = add mh 9.999999747378752e-06\n", | |
" mk\u001b[35m:f32[1,1,1,128]\u001b[39m = rsqrt mj\n", | |
" ml\u001b[35m:f32[1,1,1,128]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 128)] ex\n", | |
" mm\u001b[35m:f32[1,1,1,128]\u001b[39m = mul mk ml\n", | |
" mn\u001b[35m:f32[1000,28,28,128]\u001b[39m = mul mi mm\n", | |
" mo\u001b[35m:f32[1,1,1,128]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 128)] ew\n", | |
" mp\u001b[35m:f32[1000,28,28,128]\u001b[39m = add mn mo\n", | |
" mq\u001b[35m:f32[1000,28,28,128]\u001b[39m = add me mp\n", | |
" mr\u001b[35m:f32[1000,28,28,128]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; ms\u001b[35m:f32[1000,28,28,128]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mmt\u001b[35m:f32[1000,28,28,128]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; mu\u001b[35m:f32[1000,28,28,128]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mmv\u001b[35m:f32[1000,28,28,128]\u001b[39m = max mu 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(mv,) }\n", | |
" name=relu\n", | |
" ] ms\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(mt,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1e9a10f5f0>\n", | |
" num_consts=0\n", | |
" ] mq\n", | |
" mw\u001b[35m:f32[1000,28,28,128]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1000, 28, 28, 128)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 128, 128)\n", | |
" window_strides=(1, 1)\n", | |
" ] mr ff\n", | |
" mx\u001b[35m:f32[1,1,1,128]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 128)] hb\n", | |
" my\u001b[35m:f32[1,1,1,128]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 128)] hc\n", | |
" mz\u001b[35m:f32[1000,28,28,128]\u001b[39m = sub mw mx\n", | |
" na\u001b[35m:f32[1,1,1,128]\u001b[39m = add my 9.999999747378752e-06\n", | |
" nb\u001b[35m:f32[1,1,1,128]\u001b[39m = rsqrt na\n", | |
" nc\u001b[35m:f32[1,1,1,128]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 128)] ez\n", | |
" nd\u001b[35m:f32[1,1,1,128]\u001b[39m = mul nb nc\n", | |
" ne\u001b[35m:f32[1000,28,28,128]\u001b[39m = mul mz nd\n", | |
" nf\u001b[35m:f32[1,1,1,128]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 128)] ey\n", | |
" ng\u001b[35m:f32[1000,28,28,128]\u001b[39m = add ne nf\n", | |
" nh\u001b[35m:f32[1000,28,28,128]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; ni\u001b[35m:f32[1000,28,28,128]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mnj\u001b[35m:f32[1000,28,28,128]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; nk\u001b[35m:f32[1000,28,28,128]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mnl\u001b[35m:f32[1000,28,28,128]\u001b[39m = max nk 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(nl,) }\n", | |
" name=relu\n", | |
" ] ni\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(nj,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1e9a10f7a0>\n", | |
" num_consts=0\n", | |
" ] ng\n", | |
" nm\u001b[35m:f32[1000,28,28,128]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1000, 28, 28, 128)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 128, 128)\n", | |
" window_strides=(1, 1)\n", | |
" ] nh fg\n", | |
" nn\u001b[35m:f32[1,1,1,128]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 128)] hd\n", | |
" no\u001b[35m:f32[1,1,1,128]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 128)] he\n", | |
" np\u001b[35m:f32[1000,28,28,128]\u001b[39m = sub nm nn\n", | |
" nq\u001b[35m:f32[1,1,1,128]\u001b[39m = add no 9.999999747378752e-06\n", | |
" nr\u001b[35m:f32[1,1,1,128]\u001b[39m = rsqrt nq\n", | |
" ns\u001b[35m:f32[1,1,1,128]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 128)] fb\n", | |
" nt\u001b[35m:f32[1,1,1,128]\u001b[39m = mul nr ns\n", | |
" nu\u001b[35m:f32[1000,28,28,128]\u001b[39m = mul np nt\n", | |
" nv\u001b[35m:f32[1,1,1,128]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 128)] fa\n", | |
" nw\u001b[35m:f32[1000,28,28,128]\u001b[39m = add nu nv\n", | |
" nx\u001b[35m:f32[1000,28,28,128]\u001b[39m = add nw mr\n", | |
" ny\u001b[35m:f32[1000,28,28,128]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; nz\u001b[35m:f32[1000,28,28,128]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22moa\u001b[35m:f32[1000,28,28,128]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; ob\u001b[35m:f32[1000,28,28,128]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22moc\u001b[35m:f32[1000,28,28,128]\u001b[39m = max ob 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(oc,) }\n", | |
" name=relu\n", | |
" ] nz\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(oa,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1e9a10f8c0>\n", | |
" num_consts=0\n", | |
" ] nx\n", | |
" od\u001b[35m:f32[1000,14,14,256]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1000, 28, 28, 128)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 128, 256)\n", | |
" window_strides=(2, 2)\n", | |
" ] ny fr\n", | |
" oe\u001b[35m:f32[1,1,1,256]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 256)] hf\n", | |
" of\u001b[35m:f32[1,1,1,256]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 256)] hg\n", | |
" og\u001b[35m:f32[1000,14,14,256]\u001b[39m = sub od oe\n", | |
" oh\u001b[35m:f32[1,1,1,256]\u001b[39m = add of 9.999999747378752e-06\n", | |
" oi\u001b[35m:f32[1,1,1,256]\u001b[39m = rsqrt oh\n", | |
" oj\u001b[35m:f32[1,1,1,256]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 256)] fi\n", | |
" ok\u001b[35m:f32[1,1,1,256]\u001b[39m = mul oi oj\n", | |
" ol\u001b[35m:f32[1000,14,14,256]\u001b[39m = mul og ok\n", | |
" om\u001b[35m:f32[1,1,1,256]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 256)] fh\n", | |
" on\u001b[35m:f32[1000,14,14,256]\u001b[39m = add ol om\n", | |
" oo\u001b[35m:f32[1000,14,14,256]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; op\u001b[35m:f32[1000,14,14,256]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22moq\u001b[35m:f32[1000,14,14,256]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; or\u001b[35m:f32[1000,14,14,256]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mos\u001b[35m:f32[1000,14,14,256]\u001b[39m = max or 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(os,) }\n", | |
" name=relu\n", | |
" ] op\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(oq,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1e9a10fcb0>\n", | |
" num_consts=0\n", | |
" ] on\n", | |
" ot\u001b[35m:f32[1000,14,14,256]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1000, 14, 14, 256)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 256, 256)\n", | |
" window_strides=(1, 1)\n", | |
" ] oo fs\n", | |
" ou\u001b[35m:f32[1,1,1,256]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 256)] hh\n", | |
" ov\u001b[35m:f32[1,1,1,256]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 256)] hi\n", | |
" ow\u001b[35m:f32[1000,14,14,256]\u001b[39m = sub ot ou\n", | |
" ox\u001b[35m:f32[1,1,1,256]\u001b[39m = add ov 9.999999747378752e-06\n", | |
" oy\u001b[35m:f32[1,1,1,256]\u001b[39m = rsqrt ox\n", | |
" oz\u001b[35m:f32[1,1,1,256]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 256)] fk\n", | |
" pa\u001b[35m:f32[1,1,1,256]\u001b[39m = mul oy oz\n", | |
" pb\u001b[35m:f32[1000,14,14,256]\u001b[39m = mul ow pa\n", | |
" pc\u001b[35m:f32[1,1,1,256]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 256)] fj\n", | |
" pd\u001b[35m:f32[1000,14,14,256]\u001b[39m = add pb pc\n", | |
" pe\u001b[35m:f32[1000,14,14,256]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1000, 28, 28, 128)\n", | |
" padding=((0, 0), (0, 0))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(1, 1, 128, 256)\n", | |
" window_strides=(2, 2)\n", | |
" ] ny ft\n", | |
" pf\u001b[35m:f32[1,1,1,256]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 256)] hj\n", | |
" pg\u001b[35m:f32[1,1,1,256]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 256)] hk\n", | |
" ph\u001b[35m:f32[1000,14,14,256]\u001b[39m = sub pe pf\n", | |
" pi\u001b[35m:f32[1,1,1,256]\u001b[39m = add pg 9.999999747378752e-06\n", | |
" pj\u001b[35m:f32[1,1,1,256]\u001b[39m = rsqrt pi\n", | |
" pk\u001b[35m:f32[1,1,1,256]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 256)] fm\n", | |
" pl\u001b[35m:f32[1,1,1,256]\u001b[39m = mul pj pk\n", | |
" pm\u001b[35m:f32[1000,14,14,256]\u001b[39m = mul ph pl\n", | |
" pn\u001b[35m:f32[1,1,1,256]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 256)] fl\n", | |
" po\u001b[35m:f32[1000,14,14,256]\u001b[39m = add pm pn\n", | |
" pp\u001b[35m:f32[1000,14,14,256]\u001b[39m = add pd po\n", | |
" pq\u001b[35m:f32[1000,14,14,256]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; pr\u001b[35m:f32[1000,14,14,256]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mps\u001b[35m:f32[1000,14,14,256]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; pt\u001b[35m:f32[1000,14,14,256]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mpu\u001b[35m:f32[1000,14,14,256]\u001b[39m = max pt 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(pu,) }\n", | |
" name=relu\n", | |
" ] pr\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(ps,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1e9a0de440>\n", | |
" num_consts=0\n", | |
" ] pp\n", | |
" pv\u001b[35m:f32[1000,14,14,256]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1000, 14, 14, 256)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 256, 256)\n", | |
" window_strides=(1, 1)\n", | |
" ] pq fu\n", | |
" pw\u001b[35m:f32[1,1,1,256]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 256)] hl\n", | |
" px\u001b[35m:f32[1,1,1,256]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 256)] hm\n", | |
" py\u001b[35m:f32[1000,14,14,256]\u001b[39m = sub pv pw\n", | |
" pz\u001b[35m:f32[1,1,1,256]\u001b[39m = add px 9.999999747378752e-06\n", | |
" qa\u001b[35m:f32[1,1,1,256]\u001b[39m = rsqrt pz\n", | |
" qb\u001b[35m:f32[1,1,1,256]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 256)] fo\n", | |
" qc\u001b[35m:f32[1,1,1,256]\u001b[39m = mul qa qb\n", | |
" qd\u001b[35m:f32[1000,14,14,256]\u001b[39m = mul py qc\n", | |
" qe\u001b[35m:f32[1,1,1,256]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 256)] fn\n", | |
" qf\u001b[35m:f32[1000,14,14,256]\u001b[39m = add qd qe\n", | |
" qg\u001b[35m:f32[1000,14,14,256]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; qh\u001b[35m:f32[1000,14,14,256]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mqi\u001b[35m:f32[1000,14,14,256]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; qj\u001b[35m:f32[1000,14,14,256]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mqk\u001b[35m:f32[1000,14,14,256]\u001b[39m = max qj 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(qk,) }\n", | |
" name=relu\n", | |
" ] qh\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(qi,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1e9a0ded40>\n", | |
" num_consts=0\n", | |
" ] qf\n", | |
" ql\u001b[35m:f32[1000,14,14,256]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1000, 14, 14, 256)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 256, 256)\n", | |
" window_strides=(1, 1)\n", | |
" ] qg fv\n", | |
" qm\u001b[35m:f32[1,1,1,256]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 256)] hn\n", | |
" qn\u001b[35m:f32[1,1,1,256]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 256)] ho\n", | |
" qo\u001b[35m:f32[1000,14,14,256]\u001b[39m = sub ql qm\n", | |
" qp\u001b[35m:f32[1,1,1,256]\u001b[39m = add qn 9.999999747378752e-06\n", | |
" qq\u001b[35m:f32[1,1,1,256]\u001b[39m = rsqrt qp\n", | |
" qr\u001b[35m:f32[1,1,1,256]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 256)] fq\n", | |
" qs\u001b[35m:f32[1,1,1,256]\u001b[39m = mul qq qr\n", | |
" qt\u001b[35m:f32[1000,14,14,256]\u001b[39m = mul qo qs\n", | |
" qu\u001b[35m:f32[1,1,1,256]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 256)] fp\n", | |
" qv\u001b[35m:f32[1000,14,14,256]\u001b[39m = add qt qu\n", | |
" qw\u001b[35m:f32[1000,14,14,256]\u001b[39m = add qv pq\n", | |
" qx\u001b[35m:f32[1000,14,14,256]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; qy\u001b[35m:f32[1000,14,14,256]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mqz\u001b[35m:f32[1000,14,14,256]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; ra\u001b[35m:f32[1000,14,14,256]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mrb\u001b[35m:f32[1000,14,14,256]\u001b[39m = max ra 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(rb,) }\n", | |
" name=relu\n", | |
" ] qy\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(qz,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1e9a0de050>\n", | |
" num_consts=0\n", | |
" ] qw\n", | |
" rc\u001b[35m:f32[1000,7,7,512]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1000, 14, 14, 256)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 256, 512)\n", | |
" window_strides=(2, 2)\n", | |
" ] qx gg\n", | |
" rd\u001b[35m:f32[1,1,1,512]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 512)] hp\n", | |
" re\u001b[35m:f32[1,1,1,512]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 512)] hq\n", | |
" rf\u001b[35m:f32[1000,7,7,512]\u001b[39m = sub rc rd\n", | |
" rg\u001b[35m:f32[1,1,1,512]\u001b[39m = add re 9.999999747378752e-06\n", | |
" rh\u001b[35m:f32[1,1,1,512]\u001b[39m = rsqrt rg\n", | |
" ri\u001b[35m:f32[1,1,1,512]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 512)] fx\n", | |
" rj\u001b[35m:f32[1,1,1,512]\u001b[39m = mul rh ri\n", | |
" rk\u001b[35m:f32[1000,7,7,512]\u001b[39m = mul rf rj\n", | |
" rl\u001b[35m:f32[1,1,1,512]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 512)] fw\n", | |
" rm\u001b[35m:f32[1000,7,7,512]\u001b[39m = add rk rl\n", | |
" rn\u001b[35m:f32[1000,7,7,512]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; ro\u001b[35m:f32[1000,7,7,512]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mrp\u001b[35m:f32[1000,7,7,512]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; rq\u001b[35m:f32[1000,7,7,512]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mrr\u001b[35m:f32[1000,7,7,512]\u001b[39m = max rq 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(rr,) }\n", | |
" name=relu\n", | |
" ] ro\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(rp,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1e9a0de4d0>\n", | |
" num_consts=0\n", | |
" ] rm\n", | |
" rs\u001b[35m:f32[1000,7,7,512]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1000, 7, 7, 512)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 512, 512)\n", | |
" window_strides=(1, 1)\n", | |
" ] rn gh\n", | |
" rt\u001b[35m:f32[1,1,1,512]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 512)] hr\n", | |
" ru\u001b[35m:f32[1,1,1,512]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 512)] hs\n", | |
" rv\u001b[35m:f32[1000,7,7,512]\u001b[39m = sub rs rt\n", | |
" rw\u001b[35m:f32[1,1,1,512]\u001b[39m = add ru 9.999999747378752e-06\n", | |
" rx\u001b[35m:f32[1,1,1,512]\u001b[39m = rsqrt rw\n", | |
" ry\u001b[35m:f32[1,1,1,512]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 512)] fz\n", | |
" rz\u001b[35m:f32[1,1,1,512]\u001b[39m = mul rx ry\n", | |
" sa\u001b[35m:f32[1000,7,7,512]\u001b[39m = mul rv rz\n", | |
" sb\u001b[35m:f32[1,1,1,512]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 512)] fy\n", | |
" sc\u001b[35m:f32[1000,7,7,512]\u001b[39m = add sa sb\n", | |
" sd\u001b[35m:f32[1000,7,7,512]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1000, 14, 14, 256)\n", | |
" padding=((0, 0), (0, 0))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(1, 1, 256, 512)\n", | |
" window_strides=(2, 2)\n", | |
" ] qx gi\n", | |
" se\u001b[35m:f32[1,1,1,512]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 512)] ht\n", | |
" sf\u001b[35m:f32[1,1,1,512]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 512)] hu\n", | |
" sg\u001b[35m:f32[1000,7,7,512]\u001b[39m = sub sd se\n", | |
" sh\u001b[35m:f32[1,1,1,512]\u001b[39m = add sf 9.999999747378752e-06\n", | |
" si\u001b[35m:f32[1,1,1,512]\u001b[39m = rsqrt sh\n", | |
" sj\u001b[35m:f32[1,1,1,512]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 512)] gb\n", | |
" sk\u001b[35m:f32[1,1,1,512]\u001b[39m = mul si sj\n", | |
" sl\u001b[35m:f32[1000,7,7,512]\u001b[39m = mul sg sk\n", | |
" sm\u001b[35m:f32[1,1,1,512]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 512)] ga\n", | |
" sn\u001b[35m:f32[1000,7,7,512]\u001b[39m = add sl sm\n", | |
" so\u001b[35m:f32[1000,7,7,512]\u001b[39m = add sc sn\n", | |
" sp\u001b[35m:f32[1000,7,7,512]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; sq\u001b[35m:f32[1000,7,7,512]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22msr\u001b[35m:f32[1000,7,7,512]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; ss\u001b[35m:f32[1000,7,7,512]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mst\u001b[35m:f32[1000,7,7,512]\u001b[39m = max ss 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(st,) }\n", | |
" name=relu\n", | |
" ] sq\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(sr,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1e9a0de7a0>\n", | |
" num_consts=0\n", | |
" ] so\n", | |
" su\u001b[35m:f32[1000,7,7,512]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1000, 7, 7, 512)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 512, 512)\n", | |
" window_strides=(1, 1)\n", | |
" ] sp gj\n", | |
" sv\u001b[35m:f32[1,1,1,512]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 512)] hv\n", | |
" sw\u001b[35m:f32[1,1,1,512]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 512)] hw\n", | |
" sx\u001b[35m:f32[1000,7,7,512]\u001b[39m = sub su sv\n", | |
" sy\u001b[35m:f32[1,1,1,512]\u001b[39m = add sw 9.999999747378752e-06\n", | |
" sz\u001b[35m:f32[1,1,1,512]\u001b[39m = rsqrt sy\n", | |
" ta\u001b[35m:f32[1,1,1,512]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 512)] gd\n", | |
" tb\u001b[35m:f32[1,1,1,512]\u001b[39m = mul sz ta\n", | |
" tc\u001b[35m:f32[1000,7,7,512]\u001b[39m = mul sx tb\n", | |
" td\u001b[35m:f32[1,1,1,512]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 512)] gc\n", | |
" te\u001b[35m:f32[1000,7,7,512]\u001b[39m = add tc td\n", | |
" tf\u001b[35m:f32[1000,7,7,512]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; tg\u001b[35m:f32[1000,7,7,512]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mth\u001b[35m:f32[1000,7,7,512]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; ti\u001b[35m:f32[1000,7,7,512]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mtj\u001b[35m:f32[1000,7,7,512]\u001b[39m = max ti 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(tj,) }\n", | |
" name=relu\n", | |
" ] tg\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(th,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1e9a0de200>\n", | |
" num_consts=0\n", | |
" ] te\n", | |
" tk\u001b[35m:f32[1000,7,7,512]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1000, 7, 7, 512)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 512, 512)\n", | |
" window_strides=(1, 1)\n", | |
" ] tf gk\n", | |
" tl\u001b[35m:f32[1,1,1,512]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 512)] hx\n", | |
" tm\u001b[35m:f32[1,1,1,512]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 512)] hy\n", | |
" tn\u001b[35m:f32[1000,7,7,512]\u001b[39m = sub tk tl\n", | |
" to\u001b[35m:f32[1,1,1,512]\u001b[39m = add tm 9.999999747378752e-06\n", | |
" tp\u001b[35m:f32[1,1,1,512]\u001b[39m = rsqrt to\n", | |
" tq\u001b[35m:f32[1,1,1,512]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 512)] gf\n", | |
" tr\u001b[35m:f32[1,1,1,512]\u001b[39m = mul tp tq\n", | |
" ts\u001b[35m:f32[1000,7,7,512]\u001b[39m = mul tn tr\n", | |
" tt\u001b[35m:f32[1,1,1,512]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 512)] ge\n", | |
" tu\u001b[35m:f32[1000,7,7,512]\u001b[39m = add ts tt\n", | |
" tv\u001b[35m:f32[1000,7,7,512]\u001b[39m = add tu sp\n", | |
" tw\u001b[35m:f32[1000,7,7,512]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; tx\u001b[35m:f32[1000,7,7,512]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mty\u001b[35m:f32[1000,7,7,512]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; tz\u001b[35m:f32[1000,7,7,512]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mua\u001b[35m:f32[1000,7,7,512]\u001b[39m = max tz 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(ua,) }\n", | |
" name=relu\n", | |
" ] tx\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(ty,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1e9a0de950>\n", | |
" num_consts=0\n", | |
" ] tv\n", | |
" ub\u001b[35m:f32[1000,512]\u001b[39m = reduce_sum[axes=(1, 2)] tw\n", | |
" uc\u001b[35m:f32[1000,512]\u001b[39m = div ub 49.0\n", | |
" ud\u001b[35m:f32[1000,1000]\u001b[39m = dot_general[\n", | |
" dimension_numbers=(((1,), (0,)), ((), ()))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" ] uc ef\n", | |
" ue\u001b[35m:f32[1,1000]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1000)] ee\n", | |
" uf\u001b[35m:f32[1000,1000]\u001b[39m = add ud ue\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(uf,) }\n", | |
" name=predict\n", | |
" ] a b c d e f g h i j k l m n o p q r s t u v w x y z ba bb bc bd be bf bg bh\n", | |
" bi bj bk bl bm bn bo bp bq br bs bt bu bv bw bx by bz ca cb cc cd ce cf cg\n", | |
" ch ci cj ck cl cm cn co cp cq cr cs ct cu cv cw cx cy cz da db dc dd de df\n", | |
" dg dh di dj dk dl dm dn do dp dq dr ds dt du dv dw dx dy\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(dz,) }" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 10 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"predict_gpu = jit(predict, backend='gpu')\n", | |
"make_jaxpr(predict_gpu)(jnp.ones((1000, 224,224,3)), params, batch_stats)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "tH4sIjp3h_oE", | |
"outputId": "0b10eec5-182c-44ef-cfba-abed0d1aff88" | |
}, | |
"execution_count": 11, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"{ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; a\u001b[35m:f32[1000,224,224,3]\u001b[39m b\u001b[35m:f32[64]\u001b[39m c\u001b[35m:f32[64]\u001b[39m d\u001b[35m:f32[7,7,3,64]\u001b[39m e\u001b[35m:f32[1000]\u001b[39m\n", | |
" f\u001b[35m:f32[512,1000]\u001b[39m g\u001b[35m:f32[64]\u001b[39m h\u001b[35m:f32[64]\u001b[39m i\u001b[35m:f32[64]\u001b[39m j\u001b[35m:f32[64]\u001b[39m k\u001b[35m:f32[64]\u001b[39m l\u001b[35m:f32[64]\u001b[39m m\u001b[35m:f32[64]\u001b[39m\n", | |
" n\u001b[35m:f32[64]\u001b[39m o\u001b[35m:f32[3,3,64,64]\u001b[39m p\u001b[35m:f32[3,3,64,64]\u001b[39m q\u001b[35m:f32[3,3,64,64]\u001b[39m r\u001b[35m:f32[3,3,64,64]\u001b[39m\n", | |
" s\u001b[35m:f32[128]\u001b[39m t\u001b[35m:f32[128]\u001b[39m u\u001b[35m:f32[128]\u001b[39m v\u001b[35m:f32[128]\u001b[39m w\u001b[35m:f32[128]\u001b[39m x\u001b[35m:f32[128]\u001b[39m y\u001b[35m:f32[128]\u001b[39m\n", | |
" z\u001b[35m:f32[128]\u001b[39m ba\u001b[35m:f32[128]\u001b[39m bb\u001b[35m:f32[128]\u001b[39m bc\u001b[35m:f32[3,3,64,128]\u001b[39m bd\u001b[35m:f32[3,3,128,128]\u001b[39m be\u001b[35m:f32[1,1,64,128]\u001b[39m\n", | |
" bf\u001b[35m:f32[3,3,128,128]\u001b[39m bg\u001b[35m:f32[3,3,128,128]\u001b[39m bh\u001b[35m:f32[256]\u001b[39m bi\u001b[35m:f32[256]\u001b[39m bj\u001b[35m:f32[256]\u001b[39m bk\u001b[35m:f32[256]\u001b[39m\n", | |
" bl\u001b[35m:f32[256]\u001b[39m bm\u001b[35m:f32[256]\u001b[39m bn\u001b[35m:f32[256]\u001b[39m bo\u001b[35m:f32[256]\u001b[39m bp\u001b[35m:f32[256]\u001b[39m bq\u001b[35m:f32[256]\u001b[39m br\u001b[35m:f32[3,3,128,256]\u001b[39m\n", | |
" bs\u001b[35m:f32[3,3,256,256]\u001b[39m bt\u001b[35m:f32[1,1,128,256]\u001b[39m bu\u001b[35m:f32[3,3,256,256]\u001b[39m bv\u001b[35m:f32[3,3,256,256]\u001b[39m\n", | |
" bw\u001b[35m:f32[512]\u001b[39m bx\u001b[35m:f32[512]\u001b[39m by\u001b[35m:f32[512]\u001b[39m bz\u001b[35m:f32[512]\u001b[39m ca\u001b[35m:f32[512]\u001b[39m cb\u001b[35m:f32[512]\u001b[39m cc\u001b[35m:f32[512]\u001b[39m\n", | |
" cd\u001b[35m:f32[512]\u001b[39m ce\u001b[35m:f32[512]\u001b[39m cf\u001b[35m:f32[512]\u001b[39m cg\u001b[35m:f32[3,3,256,512]\u001b[39m ch\u001b[35m:f32[3,3,512,512]\u001b[39m ci\u001b[35m:f32[1,1,256,512]\u001b[39m\n", | |
" cj\u001b[35m:f32[3,3,512,512]\u001b[39m ck\u001b[35m:f32[3,3,512,512]\u001b[39m cl\u001b[35m:f32[64]\u001b[39m cm\u001b[35m:f32[64]\u001b[39m cn\u001b[35m:f32[64]\u001b[39m co\u001b[35m:f32[64]\u001b[39m\n", | |
" cp\u001b[35m:f32[64]\u001b[39m cq\u001b[35m:f32[64]\u001b[39m cr\u001b[35m:f32[64]\u001b[39m cs\u001b[35m:f32[64]\u001b[39m ct\u001b[35m:f32[64]\u001b[39m cu\u001b[35m:f32[64]\u001b[39m cv\u001b[35m:f32[128]\u001b[39m\n", | |
" cw\u001b[35m:f32[128]\u001b[39m cx\u001b[35m:f32[128]\u001b[39m cy\u001b[35m:f32[128]\u001b[39m cz\u001b[35m:f32[128]\u001b[39m da\u001b[35m:f32[128]\u001b[39m db\u001b[35m:f32[128]\u001b[39m dc\u001b[35m:f32[128]\u001b[39m\n", | |
" dd\u001b[35m:f32[128]\u001b[39m de\u001b[35m:f32[128]\u001b[39m df\u001b[35m:f32[256]\u001b[39m dg\u001b[35m:f32[256]\u001b[39m dh\u001b[35m:f32[256]\u001b[39m di\u001b[35m:f32[256]\u001b[39m dj\u001b[35m:f32[256]\u001b[39m\n", | |
" dk\u001b[35m:f32[256]\u001b[39m dl\u001b[35m:f32[256]\u001b[39m dm\u001b[35m:f32[256]\u001b[39m dn\u001b[35m:f32[256]\u001b[39m do\u001b[35m:f32[256]\u001b[39m dp\u001b[35m:f32[512]\u001b[39m dq\u001b[35m:f32[512]\u001b[39m\n", | |
" dr\u001b[35m:f32[512]\u001b[39m ds\u001b[35m:f32[512]\u001b[39m dt\u001b[35m:f32[512]\u001b[39m du\u001b[35m:f32[512]\u001b[39m dv\u001b[35m:f32[512]\u001b[39m dw\u001b[35m:f32[512]\u001b[39m dx\u001b[35m:f32[512]\u001b[39m\n", | |
" dy\u001b[35m:f32[512]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mdz\u001b[35m:f32[1000,1000]\u001b[39m = xla_call[\n", | |
" backend=gpu\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; ea\u001b[35m:f32[1000,224,224,3]\u001b[39m eb\u001b[35m:f32[64]\u001b[39m ec\u001b[35m:f32[64]\u001b[39m ed\u001b[35m:f32[7,7,3,64]\u001b[39m\n", | |
" ee\u001b[35m:f32[1000]\u001b[39m ef\u001b[35m:f32[512,1000]\u001b[39m eg\u001b[35m:f32[64]\u001b[39m eh\u001b[35m:f32[64]\u001b[39m ei\u001b[35m:f32[64]\u001b[39m ej\u001b[35m:f32[64]\u001b[39m\n", | |
" ek\u001b[35m:f32[64]\u001b[39m el\u001b[35m:f32[64]\u001b[39m em\u001b[35m:f32[64]\u001b[39m en\u001b[35m:f32[64]\u001b[39m eo\u001b[35m:f32[3,3,64,64]\u001b[39m ep\u001b[35m:f32[3,3,64,64]\u001b[39m\n", | |
" eq\u001b[35m:f32[3,3,64,64]\u001b[39m er\u001b[35m:f32[3,3,64,64]\u001b[39m es\u001b[35m:f32[128]\u001b[39m et\u001b[35m:f32[128]\u001b[39m eu\u001b[35m:f32[128]\u001b[39m\n", | |
" ev\u001b[35m:f32[128]\u001b[39m ew\u001b[35m:f32[128]\u001b[39m ex\u001b[35m:f32[128]\u001b[39m ey\u001b[35m:f32[128]\u001b[39m ez\u001b[35m:f32[128]\u001b[39m fa\u001b[35m:f32[128]\u001b[39m\n", | |
" fb\u001b[35m:f32[128]\u001b[39m fc\u001b[35m:f32[3,3,64,128]\u001b[39m fd\u001b[35m:f32[3,3,128,128]\u001b[39m fe\u001b[35m:f32[1,1,64,128]\u001b[39m ff\u001b[35m:f32[3,3,128,128]\u001b[39m\n", | |
" fg\u001b[35m:f32[3,3,128,128]\u001b[39m fh\u001b[35m:f32[256]\u001b[39m fi\u001b[35m:f32[256]\u001b[39m fj\u001b[35m:f32[256]\u001b[39m fk\u001b[35m:f32[256]\u001b[39m fl\u001b[35m:f32[256]\u001b[39m\n", | |
" fm\u001b[35m:f32[256]\u001b[39m fn\u001b[35m:f32[256]\u001b[39m fo\u001b[35m:f32[256]\u001b[39m fp\u001b[35m:f32[256]\u001b[39m fq\u001b[35m:f32[256]\u001b[39m fr\u001b[35m:f32[3,3,128,256]\u001b[39m\n", | |
" fs\u001b[35m:f32[3,3,256,256]\u001b[39m ft\u001b[35m:f32[1,1,128,256]\u001b[39m fu\u001b[35m:f32[3,3,256,256]\u001b[39m fv\u001b[35m:f32[3,3,256,256]\u001b[39m\n", | |
" fw\u001b[35m:f32[512]\u001b[39m fx\u001b[35m:f32[512]\u001b[39m fy\u001b[35m:f32[512]\u001b[39m fz\u001b[35m:f32[512]\u001b[39m ga\u001b[35m:f32[512]\u001b[39m gb\u001b[35m:f32[512]\u001b[39m\n", | |
" gc\u001b[35m:f32[512]\u001b[39m gd\u001b[35m:f32[512]\u001b[39m ge\u001b[35m:f32[512]\u001b[39m gf\u001b[35m:f32[512]\u001b[39m gg\u001b[35m:f32[3,3,256,512]\u001b[39m gh\u001b[35m:f32[3,3,512,512]\u001b[39m\n", | |
" gi\u001b[35m:f32[1,1,256,512]\u001b[39m gj\u001b[35m:f32[3,3,512,512]\u001b[39m gk\u001b[35m:f32[3,3,512,512]\u001b[39m gl\u001b[35m:f32[64]\u001b[39m\n", | |
" gm\u001b[35m:f32[64]\u001b[39m gn\u001b[35m:f32[64]\u001b[39m go\u001b[35m:f32[64]\u001b[39m gp\u001b[35m:f32[64]\u001b[39m gq\u001b[35m:f32[64]\u001b[39m gr\u001b[35m:f32[64]\u001b[39m gs\u001b[35m:f32[64]\u001b[39m\n", | |
" gt\u001b[35m:f32[64]\u001b[39m gu\u001b[35m:f32[64]\u001b[39m gv\u001b[35m:f32[128]\u001b[39m gw\u001b[35m:f32[128]\u001b[39m gx\u001b[35m:f32[128]\u001b[39m gy\u001b[35m:f32[128]\u001b[39m gz\u001b[35m:f32[128]\u001b[39m\n", | |
" ha\u001b[35m:f32[128]\u001b[39m hb\u001b[35m:f32[128]\u001b[39m hc\u001b[35m:f32[128]\u001b[39m hd\u001b[35m:f32[128]\u001b[39m he\u001b[35m:f32[128]\u001b[39m hf\u001b[35m:f32[256]\u001b[39m\n", | |
" hg\u001b[35m:f32[256]\u001b[39m hh\u001b[35m:f32[256]\u001b[39m hi\u001b[35m:f32[256]\u001b[39m hj\u001b[35m:f32[256]\u001b[39m hk\u001b[35m:f32[256]\u001b[39m hl\u001b[35m:f32[256]\u001b[39m\n", | |
" hm\u001b[35m:f32[256]\u001b[39m hn\u001b[35m:f32[256]\u001b[39m ho\u001b[35m:f32[256]\u001b[39m hp\u001b[35m:f32[512]\u001b[39m hq\u001b[35m:f32[512]\u001b[39m hr\u001b[35m:f32[512]\u001b[39m\n", | |
" hs\u001b[35m:f32[512]\u001b[39m ht\u001b[35m:f32[512]\u001b[39m hu\u001b[35m:f32[512]\u001b[39m hv\u001b[35m:f32[512]\u001b[39m hw\u001b[35m:f32[512]\u001b[39m hx\u001b[35m:f32[512]\u001b[39m\n", | |
" hy\u001b[35m:f32[512]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mhz\u001b[35m:f32[1000,112,112,64]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1000, 224, 224, 3)\n", | |
" padding=((3, 3), (3, 3))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(7, 7, 3, 64)\n", | |
" window_strides=(2, 2)\n", | |
" ] ea ed\n", | |
" ia\u001b[35m:f32[1,1,1,64]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 64)] gl\n", | |
" ib\u001b[35m:f32[1,1,1,64]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 64)] gm\n", | |
" ic\u001b[35m:f32[1000,112,112,64]\u001b[39m = sub hz ia\n", | |
" id\u001b[35m:f32[1,1,1,64]\u001b[39m = add ib 9.999999747378752e-06\n", | |
" ie\u001b[35m:f32[1,1,1,64]\u001b[39m = rsqrt id\n", | |
" if\u001b[35m:f32[1,1,1,64]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 64)] ec\n", | |
" ig\u001b[35m:f32[1,1,1,64]\u001b[39m = mul ie if\n", | |
" ih\u001b[35m:f32[1000,112,112,64]\u001b[39m = mul ic ig\n", | |
" ii\u001b[35m:f32[1,1,1,64]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 64)] eb\n", | |
" ij\u001b[35m:f32[1000,112,112,64]\u001b[39m = add ih ii\n", | |
" ik\u001b[35m:f32[1000,112,112,64]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; il\u001b[35m:f32[1000,112,112,64]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mim\u001b[35m:f32[1000,112,112,64]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; in\u001b[35m:f32[1000,112,112,64]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mio\u001b[35m:f32[1000,112,112,64]\u001b[39m = max in 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(io,) }\n", | |
" name=relu\n", | |
" ] il\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(im,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1e9a2a84d0>\n", | |
" num_consts=0\n", | |
" ] ij\n", | |
" ip\u001b[35m:f32[1000,56,56,64]\u001b[39m = reduce_window_max[\n", | |
" base_dilation=(1, 1, 1, 1)\n", | |
" padding=((0, 0), (1, 1), (1, 1), (0, 0))\n", | |
" window_dilation=(1, 1, 1, 1)\n", | |
" window_dimensions=(1, 3, 3, 1)\n", | |
" window_strides=(1, 2, 2, 1)\n", | |
" ] ik\n", | |
" iq\u001b[35m:f32[1000,56,56,64]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1000, 56, 56, 64)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 64, 64)\n", | |
" window_strides=(1, 1)\n", | |
" ] ip eo\n", | |
" ir\u001b[35m:f32[1,1,1,64]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 64)] gn\n", | |
" is\u001b[35m:f32[1,1,1,64]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 64)] go\n", | |
" it\u001b[35m:f32[1000,56,56,64]\u001b[39m = sub iq ir\n", | |
" iu\u001b[35m:f32[1,1,1,64]\u001b[39m = add is 9.999999747378752e-06\n", | |
" iv\u001b[35m:f32[1,1,1,64]\u001b[39m = rsqrt iu\n", | |
" iw\u001b[35m:f32[1,1,1,64]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 64)] eh\n", | |
" ix\u001b[35m:f32[1,1,1,64]\u001b[39m = mul iv iw\n", | |
" iy\u001b[35m:f32[1000,56,56,64]\u001b[39m = mul it ix\n", | |
" iz\u001b[35m:f32[1,1,1,64]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 64)] eg\n", | |
" ja\u001b[35m:f32[1000,56,56,64]\u001b[39m = add iy iz\n", | |
" jb\u001b[35m:f32[1000,56,56,64]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; jc\u001b[35m:f32[1000,56,56,64]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mjd\u001b[35m:f32[1000,56,56,64]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; je\u001b[35m:f32[1000,56,56,64]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mjf\u001b[35m:f32[1000,56,56,64]\u001b[39m = max je 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(jf,) }\n", | |
" name=relu\n", | |
" ] jc\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(jd,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1e9a0de290>\n", | |
" num_consts=0\n", | |
" ] ja\n", | |
" jg\u001b[35m:f32[1000,56,56,64]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1000, 56, 56, 64)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 64, 64)\n", | |
" window_strides=(1, 1)\n", | |
" ] jb ep\n", | |
" jh\u001b[35m:f32[1,1,1,64]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 64)] gp\n", | |
" ji\u001b[35m:f32[1,1,1,64]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 64)] gq\n", | |
" jj\u001b[35m:f32[1000,56,56,64]\u001b[39m = sub jg jh\n", | |
" jk\u001b[35m:f32[1,1,1,64]\u001b[39m = add ji 9.999999747378752e-06\n", | |
" jl\u001b[35m:f32[1,1,1,64]\u001b[39m = rsqrt jk\n", | |
" jm\u001b[35m:f32[1,1,1,64]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 64)] ej\n", | |
" jn\u001b[35m:f32[1,1,1,64]\u001b[39m = mul jl jm\n", | |
" jo\u001b[35m:f32[1000,56,56,64]\u001b[39m = mul jj jn\n", | |
" jp\u001b[35m:f32[1,1,1,64]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 64)] ei\n", | |
" jq\u001b[35m:f32[1000,56,56,64]\u001b[39m = add jo jp\n", | |
" jr\u001b[35m:f32[1000,56,56,64]\u001b[39m = add jq ip\n", | |
" js\u001b[35m:f32[1000,56,56,64]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; jt\u001b[35m:f32[1000,56,56,64]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mju\u001b[35m:f32[1000,56,56,64]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; jv\u001b[35m:f32[1000,56,56,64]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mjw\u001b[35m:f32[1000,56,56,64]\u001b[39m = max jv 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(jw,) }\n", | |
" name=relu\n", | |
" ] jt\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(ju,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1e9a0def80>\n", | |
" num_consts=0\n", | |
" ] jr\n", | |
" jx\u001b[35m:f32[1000,56,56,64]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1000, 56, 56, 64)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 64, 64)\n", | |
" window_strides=(1, 1)\n", | |
" ] js eq\n", | |
" jy\u001b[35m:f32[1,1,1,64]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 64)] gr\n", | |
" jz\u001b[35m:f32[1,1,1,64]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 64)] gs\n", | |
" ka\u001b[35m:f32[1000,56,56,64]\u001b[39m = sub jx jy\n", | |
" kb\u001b[35m:f32[1,1,1,64]\u001b[39m = add jz 9.999999747378752e-06\n", | |
" kc\u001b[35m:f32[1,1,1,64]\u001b[39m = rsqrt kb\n", | |
" kd\u001b[35m:f32[1,1,1,64]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 64)] el\n", | |
" ke\u001b[35m:f32[1,1,1,64]\u001b[39m = mul kc kd\n", | |
" kf\u001b[35m:f32[1000,56,56,64]\u001b[39m = mul ka ke\n", | |
" kg\u001b[35m:f32[1,1,1,64]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 64)] ek\n", | |
" kh\u001b[35m:f32[1000,56,56,64]\u001b[39m = add kf kg\n", | |
" ki\u001b[35m:f32[1000,56,56,64]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; kj\u001b[35m:f32[1000,56,56,64]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mkk\u001b[35m:f32[1000,56,56,64]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; kl\u001b[35m:f32[1000,56,56,64]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mkm\u001b[35m:f32[1000,56,56,64]\u001b[39m = max kl 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(km,) }\n", | |
" name=relu\n", | |
" ] kj\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(kk,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1e9a0dedd0>\n", | |
" num_consts=0\n", | |
" ] kh\n", | |
" kn\u001b[35m:f32[1000,56,56,64]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1000, 56, 56, 64)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 64, 64)\n", | |
" window_strides=(1, 1)\n", | |
" ] ki er\n", | |
" ko\u001b[35m:f32[1,1,1,64]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 64)] gt\n", | |
" kp\u001b[35m:f32[1,1,1,64]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 64)] gu\n", | |
" kq\u001b[35m:f32[1000,56,56,64]\u001b[39m = sub kn ko\n", | |
" kr\u001b[35m:f32[1,1,1,64]\u001b[39m = add kp 9.999999747378752e-06\n", | |
" ks\u001b[35m:f32[1,1,1,64]\u001b[39m = rsqrt kr\n", | |
" kt\u001b[35m:f32[1,1,1,64]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 64)] en\n", | |
" ku\u001b[35m:f32[1,1,1,64]\u001b[39m = mul ks kt\n", | |
" kv\u001b[35m:f32[1000,56,56,64]\u001b[39m = mul kq ku\n", | |
" kw\u001b[35m:f32[1,1,1,64]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 64)] em\n", | |
" kx\u001b[35m:f32[1000,56,56,64]\u001b[39m = add kv kw\n", | |
" ky\u001b[35m:f32[1000,56,56,64]\u001b[39m = add kx js\n", | |
" kz\u001b[35m:f32[1000,56,56,64]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; la\u001b[35m:f32[1000,56,56,64]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mlb\u001b[35m:f32[1000,56,56,64]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; lc\u001b[35m:f32[1000,56,56,64]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mld\u001b[35m:f32[1000,56,56,64]\u001b[39m = max lc 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(ld,) }\n", | |
" name=relu\n", | |
" ] la\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(lb,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1a8bf7e4d0>\n", | |
" num_consts=0\n", | |
" ] ky\n", | |
" le\u001b[35m:f32[1000,28,28,128]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1000, 56, 56, 64)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 64, 128)\n", | |
" window_strides=(2, 2)\n", | |
" ] kz fc\n", | |
" lf\u001b[35m:f32[1,1,1,128]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 128)] gv\n", | |
" lg\u001b[35m:f32[1,1,1,128]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 128)] gw\n", | |
" lh\u001b[35m:f32[1000,28,28,128]\u001b[39m = sub le lf\n", | |
" li\u001b[35m:f32[1,1,1,128]\u001b[39m = add lg 9.999999747378752e-06\n", | |
" lj\u001b[35m:f32[1,1,1,128]\u001b[39m = rsqrt li\n", | |
" lk\u001b[35m:f32[1,1,1,128]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 128)] et\n", | |
" ll\u001b[35m:f32[1,1,1,128]\u001b[39m = mul lj lk\n", | |
" lm\u001b[35m:f32[1000,28,28,128]\u001b[39m = mul lh ll\n", | |
" ln\u001b[35m:f32[1,1,1,128]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 128)] es\n", | |
" lo\u001b[35m:f32[1000,28,28,128]\u001b[39m = add lm ln\n", | |
" lp\u001b[35m:f32[1000,28,28,128]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; lq\u001b[35m:f32[1000,28,28,128]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mlr\u001b[35m:f32[1000,28,28,128]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; ls\u001b[35m:f32[1000,28,28,128]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mlt\u001b[35m:f32[1000,28,28,128]\u001b[39m = max ls 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(lt,) }\n", | |
" name=relu\n", | |
" ] lq\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(lr,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1a8bf7e7a0>\n", | |
" num_consts=0\n", | |
" ] lo\n", | |
" lu\u001b[35m:f32[1000,28,28,128]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1000, 28, 28, 128)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 128, 128)\n", | |
" window_strides=(1, 1)\n", | |
" ] lp fd\n", | |
" lv\u001b[35m:f32[1,1,1,128]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 128)] gx\n", | |
" lw\u001b[35m:f32[1,1,1,128]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 128)] gy\n", | |
" lx\u001b[35m:f32[1000,28,28,128]\u001b[39m = sub lu lv\n", | |
" ly\u001b[35m:f32[1,1,1,128]\u001b[39m = add lw 9.999999747378752e-06\n", | |
" lz\u001b[35m:f32[1,1,1,128]\u001b[39m = rsqrt ly\n", | |
" ma\u001b[35m:f32[1,1,1,128]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 128)] ev\n", | |
" mb\u001b[35m:f32[1,1,1,128]\u001b[39m = mul lz ma\n", | |
" mc\u001b[35m:f32[1000,28,28,128]\u001b[39m = mul lx mb\n", | |
" md\u001b[35m:f32[1,1,1,128]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 128)] eu\n", | |
" me\u001b[35m:f32[1000,28,28,128]\u001b[39m = add mc md\n", | |
" mf\u001b[35m:f32[1000,28,28,128]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1000, 56, 56, 64)\n", | |
" padding=((0, 0), (0, 0))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(1, 1, 64, 128)\n", | |
" window_strides=(2, 2)\n", | |
" ] kz fe\n", | |
" mg\u001b[35m:f32[1,1,1,128]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 128)] gz\n", | |
" mh\u001b[35m:f32[1,1,1,128]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 128)] ha\n", | |
" mi\u001b[35m:f32[1000,28,28,128]\u001b[39m = sub mf mg\n", | |
" mj\u001b[35m:f32[1,1,1,128]\u001b[39m = add mh 9.999999747378752e-06\n", | |
" mk\u001b[35m:f32[1,1,1,128]\u001b[39m = rsqrt mj\n", | |
" ml\u001b[35m:f32[1,1,1,128]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 128)] ex\n", | |
" mm\u001b[35m:f32[1,1,1,128]\u001b[39m = mul mk ml\n", | |
" mn\u001b[35m:f32[1000,28,28,128]\u001b[39m = mul mi mm\n", | |
" mo\u001b[35m:f32[1,1,1,128]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 128)] ew\n", | |
" mp\u001b[35m:f32[1000,28,28,128]\u001b[39m = add mn mo\n", | |
" mq\u001b[35m:f32[1000,28,28,128]\u001b[39m = add me mp\n", | |
" mr\u001b[35m:f32[1000,28,28,128]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; ms\u001b[35m:f32[1000,28,28,128]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mmt\u001b[35m:f32[1000,28,28,128]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; mu\u001b[35m:f32[1000,28,28,128]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mmv\u001b[35m:f32[1000,28,28,128]\u001b[39m = max mu 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(mv,) }\n", | |
" name=relu\n", | |
" ] ms\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(mt,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1a8bf7e5f0>\n", | |
" num_consts=0\n", | |
" ] mq\n", | |
" mw\u001b[35m:f32[1000,28,28,128]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1000, 28, 28, 128)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 128, 128)\n", | |
" window_strides=(1, 1)\n", | |
" ] mr ff\n", | |
" mx\u001b[35m:f32[1,1,1,128]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 128)] hb\n", | |
" my\u001b[35m:f32[1,1,1,128]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 128)] hc\n", | |
" mz\u001b[35m:f32[1000,28,28,128]\u001b[39m = sub mw mx\n", | |
" na\u001b[35m:f32[1,1,1,128]\u001b[39m = add my 9.999999747378752e-06\n", | |
" nb\u001b[35m:f32[1,1,1,128]\u001b[39m = rsqrt na\n", | |
" nc\u001b[35m:f32[1,1,1,128]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 128)] ez\n", | |
" nd\u001b[35m:f32[1,1,1,128]\u001b[39m = mul nb nc\n", | |
" ne\u001b[35m:f32[1000,28,28,128]\u001b[39m = mul mz nd\n", | |
" nf\u001b[35m:f32[1,1,1,128]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 128)] ey\n", | |
" ng\u001b[35m:f32[1000,28,28,128]\u001b[39m = add ne nf\n", | |
" nh\u001b[35m:f32[1000,28,28,128]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; ni\u001b[35m:f32[1000,28,28,128]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mnj\u001b[35m:f32[1000,28,28,128]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; nk\u001b[35m:f32[1000,28,28,128]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mnl\u001b[35m:f32[1000,28,28,128]\u001b[39m = max nk 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(nl,) }\n", | |
" name=relu\n", | |
" ] ni\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(nj,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1a8bf7e830>\n", | |
" num_consts=0\n", | |
" ] ng\n", | |
" nm\u001b[35m:f32[1000,28,28,128]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1000, 28, 28, 128)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 128, 128)\n", | |
" window_strides=(1, 1)\n", | |
" ] nh fg\n", | |
" nn\u001b[35m:f32[1,1,1,128]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 128)] hd\n", | |
" no\u001b[35m:f32[1,1,1,128]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 128)] he\n", | |
" np\u001b[35m:f32[1000,28,28,128]\u001b[39m = sub nm nn\n", | |
" nq\u001b[35m:f32[1,1,1,128]\u001b[39m = add no 9.999999747378752e-06\n", | |
" nr\u001b[35m:f32[1,1,1,128]\u001b[39m = rsqrt nq\n", | |
" ns\u001b[35m:f32[1,1,1,128]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 128)] fb\n", | |
" nt\u001b[35m:f32[1,1,1,128]\u001b[39m = mul nr ns\n", | |
" nu\u001b[35m:f32[1000,28,28,128]\u001b[39m = mul np nt\n", | |
" nv\u001b[35m:f32[1,1,1,128]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 128)] fa\n", | |
" nw\u001b[35m:f32[1000,28,28,128]\u001b[39m = add nu nv\n", | |
" nx\u001b[35m:f32[1000,28,28,128]\u001b[39m = add nw mr\n", | |
" ny\u001b[35m:f32[1000,28,28,128]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; nz\u001b[35m:f32[1000,28,28,128]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22moa\u001b[35m:f32[1000,28,28,128]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; ob\u001b[35m:f32[1000,28,28,128]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22moc\u001b[35m:f32[1000,28,28,128]\u001b[39m = max ob 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(oc,) }\n", | |
" name=relu\n", | |
" ] nz\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(oa,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1a8bf7ecb0>\n", | |
" num_consts=0\n", | |
" ] nx\n", | |
" od\u001b[35m:f32[1000,14,14,256]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1000, 28, 28, 128)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 128, 256)\n", | |
" window_strides=(2, 2)\n", | |
" ] ny fr\n", | |
" oe\u001b[35m:f32[1,1,1,256]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 256)] hf\n", | |
" of\u001b[35m:f32[1,1,1,256]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 256)] hg\n", | |
" og\u001b[35m:f32[1000,14,14,256]\u001b[39m = sub od oe\n", | |
" oh\u001b[35m:f32[1,1,1,256]\u001b[39m = add of 9.999999747378752e-06\n", | |
" oi\u001b[35m:f32[1,1,1,256]\u001b[39m = rsqrt oh\n", | |
" oj\u001b[35m:f32[1,1,1,256]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 256)] fi\n", | |
" ok\u001b[35m:f32[1,1,1,256]\u001b[39m = mul oi oj\n", | |
" ol\u001b[35m:f32[1000,14,14,256]\u001b[39m = mul og ok\n", | |
" om\u001b[35m:f32[1,1,1,256]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 256)] fh\n", | |
" on\u001b[35m:f32[1000,14,14,256]\u001b[39m = add ol om\n", | |
" oo\u001b[35m:f32[1000,14,14,256]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; op\u001b[35m:f32[1000,14,14,256]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22moq\u001b[35m:f32[1000,14,14,256]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; or\u001b[35m:f32[1000,14,14,256]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mos\u001b[35m:f32[1000,14,14,256]\u001b[39m = max or 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(os,) }\n", | |
" name=relu\n", | |
" ] op\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(oq,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1a8bf7e9e0>\n", | |
" num_consts=0\n", | |
" ] on\n", | |
" ot\u001b[35m:f32[1000,14,14,256]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1000, 14, 14, 256)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 256, 256)\n", | |
" window_strides=(1, 1)\n", | |
" ] oo fs\n", | |
" ou\u001b[35m:f32[1,1,1,256]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 256)] hh\n", | |
" ov\u001b[35m:f32[1,1,1,256]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 256)] hi\n", | |
" ow\u001b[35m:f32[1000,14,14,256]\u001b[39m = sub ot ou\n", | |
" ox\u001b[35m:f32[1,1,1,256]\u001b[39m = add ov 9.999999747378752e-06\n", | |
" oy\u001b[35m:f32[1,1,1,256]\u001b[39m = rsqrt ox\n", | |
" oz\u001b[35m:f32[1,1,1,256]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 256)] fk\n", | |
" pa\u001b[35m:f32[1,1,1,256]\u001b[39m = mul oy oz\n", | |
" pb\u001b[35m:f32[1000,14,14,256]\u001b[39m = mul ow pa\n", | |
" pc\u001b[35m:f32[1,1,1,256]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 256)] fj\n", | |
" pd\u001b[35m:f32[1000,14,14,256]\u001b[39m = add pb pc\n", | |
" pe\u001b[35m:f32[1000,14,14,256]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1000, 28, 28, 128)\n", | |
" padding=((0, 0), (0, 0))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(1, 1, 128, 256)\n", | |
" window_strides=(2, 2)\n", | |
" ] ny ft\n", | |
" pf\u001b[35m:f32[1,1,1,256]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 256)] hj\n", | |
" pg\u001b[35m:f32[1,1,1,256]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 256)] hk\n", | |
" ph\u001b[35m:f32[1000,14,14,256]\u001b[39m = sub pe pf\n", | |
" pi\u001b[35m:f32[1,1,1,256]\u001b[39m = add pg 9.999999747378752e-06\n", | |
" pj\u001b[35m:f32[1,1,1,256]\u001b[39m = rsqrt pi\n", | |
" pk\u001b[35m:f32[1,1,1,256]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 256)] fm\n", | |
" pl\u001b[35m:f32[1,1,1,256]\u001b[39m = mul pj pk\n", | |
" pm\u001b[35m:f32[1000,14,14,256]\u001b[39m = mul ph pl\n", | |
" pn\u001b[35m:f32[1,1,1,256]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 256)] fl\n", | |
" po\u001b[35m:f32[1000,14,14,256]\u001b[39m = add pm pn\n", | |
" pp\u001b[35m:f32[1000,14,14,256]\u001b[39m = add pd po\n", | |
" pq\u001b[35m:f32[1000,14,14,256]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; pr\u001b[35m:f32[1000,14,14,256]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mps\u001b[35m:f32[1000,14,14,256]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; pt\u001b[35m:f32[1000,14,14,256]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mpu\u001b[35m:f32[1000,14,14,256]\u001b[39m = max pt 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(pu,) }\n", | |
" name=relu\n", | |
" ] pr\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(ps,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1e9a1e7f80>\n", | |
" num_consts=0\n", | |
" ] pp\n", | |
" pv\u001b[35m:f32[1000,14,14,256]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1000, 14, 14, 256)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 256, 256)\n", | |
" window_strides=(1, 1)\n", | |
" ] pq fu\n", | |
" pw\u001b[35m:f32[1,1,1,256]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 256)] hl\n", | |
" px\u001b[35m:f32[1,1,1,256]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 256)] hm\n", | |
" py\u001b[35m:f32[1000,14,14,256]\u001b[39m = sub pv pw\n", | |
" pz\u001b[35m:f32[1,1,1,256]\u001b[39m = add px 9.999999747378752e-06\n", | |
" qa\u001b[35m:f32[1,1,1,256]\u001b[39m = rsqrt pz\n", | |
" qb\u001b[35m:f32[1,1,1,256]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 256)] fo\n", | |
" qc\u001b[35m:f32[1,1,1,256]\u001b[39m = mul qa qb\n", | |
" qd\u001b[35m:f32[1000,14,14,256]\u001b[39m = mul py qc\n", | |
" qe\u001b[35m:f32[1,1,1,256]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 256)] fn\n", | |
" qf\u001b[35m:f32[1000,14,14,256]\u001b[39m = add qd qe\n", | |
" qg\u001b[35m:f32[1000,14,14,256]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; qh\u001b[35m:f32[1000,14,14,256]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mqi\u001b[35m:f32[1000,14,14,256]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; qj\u001b[35m:f32[1000,14,14,256]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mqk\u001b[35m:f32[1000,14,14,256]\u001b[39m = max qj 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(qk,) }\n", | |
" name=relu\n", | |
" ] qh\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(qi,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1a8bf7e710>\n", | |
" num_consts=0\n", | |
" ] qf\n", | |
" ql\u001b[35m:f32[1000,14,14,256]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1000, 14, 14, 256)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 256, 256)\n", | |
" window_strides=(1, 1)\n", | |
" ] qg fv\n", | |
" qm\u001b[35m:f32[1,1,1,256]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 256)] hn\n", | |
" qn\u001b[35m:f32[1,1,1,256]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 256)] ho\n", | |
" qo\u001b[35m:f32[1000,14,14,256]\u001b[39m = sub ql qm\n", | |
" qp\u001b[35m:f32[1,1,1,256]\u001b[39m = add qn 9.999999747378752e-06\n", | |
" qq\u001b[35m:f32[1,1,1,256]\u001b[39m = rsqrt qp\n", | |
" qr\u001b[35m:f32[1,1,1,256]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 256)] fq\n", | |
" qs\u001b[35m:f32[1,1,1,256]\u001b[39m = mul qq qr\n", | |
" qt\u001b[35m:f32[1000,14,14,256]\u001b[39m = mul qo qs\n", | |
" qu\u001b[35m:f32[1,1,1,256]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 256)] fp\n", | |
" qv\u001b[35m:f32[1000,14,14,256]\u001b[39m = add qt qu\n", | |
" qw\u001b[35m:f32[1000,14,14,256]\u001b[39m = add qv pq\n", | |
" qx\u001b[35m:f32[1000,14,14,256]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; qy\u001b[35m:f32[1000,14,14,256]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mqz\u001b[35m:f32[1000,14,14,256]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; ra\u001b[35m:f32[1000,14,14,256]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mrb\u001b[35m:f32[1000,14,14,256]\u001b[39m = max ra 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(rb,) }\n", | |
" name=relu\n", | |
" ] qy\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(qz,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1e9a1e73b0>\n", | |
" num_consts=0\n", | |
" ] qw\n", | |
" rc\u001b[35m:f32[1000,7,7,512]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1000, 14, 14, 256)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 256, 512)\n", | |
" window_strides=(2, 2)\n", | |
" ] qx gg\n", | |
" rd\u001b[35m:f32[1,1,1,512]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 512)] hp\n", | |
" re\u001b[35m:f32[1,1,1,512]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 512)] hq\n", | |
" rf\u001b[35m:f32[1000,7,7,512]\u001b[39m = sub rc rd\n", | |
" rg\u001b[35m:f32[1,1,1,512]\u001b[39m = add re 9.999999747378752e-06\n", | |
" rh\u001b[35m:f32[1,1,1,512]\u001b[39m = rsqrt rg\n", | |
" ri\u001b[35m:f32[1,1,1,512]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 512)] fx\n", | |
" rj\u001b[35m:f32[1,1,1,512]\u001b[39m = mul rh ri\n", | |
" rk\u001b[35m:f32[1000,7,7,512]\u001b[39m = mul rf rj\n", | |
" rl\u001b[35m:f32[1,1,1,512]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 512)] fw\n", | |
" rm\u001b[35m:f32[1000,7,7,512]\u001b[39m = add rk rl\n", | |
" rn\u001b[35m:f32[1000,7,7,512]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; ro\u001b[35m:f32[1000,7,7,512]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mrp\u001b[35m:f32[1000,7,7,512]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; rq\u001b[35m:f32[1000,7,7,512]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mrr\u001b[35m:f32[1000,7,7,512]\u001b[39m = max rq 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(rr,) }\n", | |
" name=relu\n", | |
" ] ro\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(rp,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1a8bf7e170>\n", | |
" num_consts=0\n", | |
" ] rm\n", | |
" rs\u001b[35m:f32[1000,7,7,512]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1000, 7, 7, 512)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 512, 512)\n", | |
" window_strides=(1, 1)\n", | |
" ] rn gh\n", | |
" rt\u001b[35m:f32[1,1,1,512]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 512)] hr\n", | |
" ru\u001b[35m:f32[1,1,1,512]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 512)] hs\n", | |
" rv\u001b[35m:f32[1000,7,7,512]\u001b[39m = sub rs rt\n", | |
" rw\u001b[35m:f32[1,1,1,512]\u001b[39m = add ru 9.999999747378752e-06\n", | |
" rx\u001b[35m:f32[1,1,1,512]\u001b[39m = rsqrt rw\n", | |
" ry\u001b[35m:f32[1,1,1,512]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 512)] fz\n", | |
" rz\u001b[35m:f32[1,1,1,512]\u001b[39m = mul rx ry\n", | |
" sa\u001b[35m:f32[1000,7,7,512]\u001b[39m = mul rv rz\n", | |
" sb\u001b[35m:f32[1,1,1,512]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 512)] fy\n", | |
" sc\u001b[35m:f32[1000,7,7,512]\u001b[39m = add sa sb\n", | |
" sd\u001b[35m:f32[1000,7,7,512]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1000, 14, 14, 256)\n", | |
" padding=((0, 0), (0, 0))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(1, 1, 256, 512)\n", | |
" window_strides=(2, 2)\n", | |
" ] qx gi\n", | |
" se\u001b[35m:f32[1,1,1,512]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 512)] ht\n", | |
" sf\u001b[35m:f32[1,1,1,512]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 512)] hu\n", | |
" sg\u001b[35m:f32[1000,7,7,512]\u001b[39m = sub sd se\n", | |
" sh\u001b[35m:f32[1,1,1,512]\u001b[39m = add sf 9.999999747378752e-06\n", | |
" si\u001b[35m:f32[1,1,1,512]\u001b[39m = rsqrt sh\n", | |
" sj\u001b[35m:f32[1,1,1,512]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 512)] gb\n", | |
" sk\u001b[35m:f32[1,1,1,512]\u001b[39m = mul si sj\n", | |
" sl\u001b[35m:f32[1000,7,7,512]\u001b[39m = mul sg sk\n", | |
" sm\u001b[35m:f32[1,1,1,512]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 512)] ga\n", | |
" sn\u001b[35m:f32[1000,7,7,512]\u001b[39m = add sl sm\n", | |
" so\u001b[35m:f32[1000,7,7,512]\u001b[39m = add sc sn\n", | |
" sp\u001b[35m:f32[1000,7,7,512]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; sq\u001b[35m:f32[1000,7,7,512]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22msr\u001b[35m:f32[1000,7,7,512]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; ss\u001b[35m:f32[1000,7,7,512]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mst\u001b[35m:f32[1000,7,7,512]\u001b[39m = max ss 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(st,) }\n", | |
" name=relu\n", | |
" ] sq\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(sr,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1a8bf7e0e0>\n", | |
" num_consts=0\n", | |
" ] so\n", | |
" su\u001b[35m:f32[1000,7,7,512]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1000, 7, 7, 512)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 512, 512)\n", | |
" window_strides=(1, 1)\n", | |
" ] sp gj\n", | |
" sv\u001b[35m:f32[1,1,1,512]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 512)] hv\n", | |
" sw\u001b[35m:f32[1,1,1,512]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 512)] hw\n", | |
" sx\u001b[35m:f32[1000,7,7,512]\u001b[39m = sub su sv\n", | |
" sy\u001b[35m:f32[1,1,1,512]\u001b[39m = add sw 9.999999747378752e-06\n", | |
" sz\u001b[35m:f32[1,1,1,512]\u001b[39m = rsqrt sy\n", | |
" ta\u001b[35m:f32[1,1,1,512]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 512)] gd\n", | |
" tb\u001b[35m:f32[1,1,1,512]\u001b[39m = mul sz ta\n", | |
" tc\u001b[35m:f32[1000,7,7,512]\u001b[39m = mul sx tb\n", | |
" td\u001b[35m:f32[1,1,1,512]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 512)] gc\n", | |
" te\u001b[35m:f32[1000,7,7,512]\u001b[39m = add tc td\n", | |
" tf\u001b[35m:f32[1000,7,7,512]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; tg\u001b[35m:f32[1000,7,7,512]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mth\u001b[35m:f32[1000,7,7,512]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; ti\u001b[35m:f32[1000,7,7,512]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mtj\u001b[35m:f32[1000,7,7,512]\u001b[39m = max ti 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(tj,) }\n", | |
" name=relu\n", | |
" ] tg\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(th,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1e9a1e75f0>\n", | |
" num_consts=0\n", | |
" ] te\n", | |
" tk\u001b[35m:f32[1000,7,7,512]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1000, 7, 7, 512)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 512, 512)\n", | |
" window_strides=(1, 1)\n", | |
" ] tf gk\n", | |
" tl\u001b[35m:f32[1,1,1,512]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 512)] hx\n", | |
" tm\u001b[35m:f32[1,1,1,512]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 512)] hy\n", | |
" tn\u001b[35m:f32[1000,7,7,512]\u001b[39m = sub tk tl\n", | |
" to\u001b[35m:f32[1,1,1,512]\u001b[39m = add tm 9.999999747378752e-06\n", | |
" tp\u001b[35m:f32[1,1,1,512]\u001b[39m = rsqrt to\n", | |
" tq\u001b[35m:f32[1,1,1,512]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 512)] gf\n", | |
" tr\u001b[35m:f32[1,1,1,512]\u001b[39m = mul tp tq\n", | |
" ts\u001b[35m:f32[1000,7,7,512]\u001b[39m = mul tn tr\n", | |
" tt\u001b[35m:f32[1,1,1,512]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1, 1, 512)] ge\n", | |
" tu\u001b[35m:f32[1000,7,7,512]\u001b[39m = add ts tt\n", | |
" tv\u001b[35m:f32[1000,7,7,512]\u001b[39m = add tu sp\n", | |
" tw\u001b[35m:f32[1000,7,7,512]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; tx\u001b[35m:f32[1000,7,7,512]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mty\u001b[35m:f32[1000,7,7,512]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; tz\u001b[35m:f32[1000,7,7,512]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mua\u001b[35m:f32[1000,7,7,512]\u001b[39m = max tz 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(ua,) }\n", | |
" name=relu\n", | |
" ] tx\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(ty,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1e9a1e7ef0>\n", | |
" num_consts=0\n", | |
" ] tv\n", | |
" ub\u001b[35m:f32[1000,512]\u001b[39m = reduce_sum[axes=(1, 2)] tw\n", | |
" uc\u001b[35m:f32[1000,512]\u001b[39m = div ub 49.0\n", | |
" ud\u001b[35m:f32[1000,1000]\u001b[39m = dot_general[\n", | |
" dimension_numbers=(((1,), (0,)), ((), ()))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" ] uc ef\n", | |
" ue\u001b[35m:f32[1,1000]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1000)] ee\n", | |
" uf\u001b[35m:f32[1000,1000]\u001b[39m = add ud ue\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(uf,) }\n", | |
" name=predict\n", | |
" ] a b c d e f g h i j k l m n o p q r s t u v w x y z ba bb bc bd be bf bg bh\n", | |
" bi bj bk bl bm bn bo bp bq br bs bt bu bv bw bx by bz ca cb cc cd ce cf cg\n", | |
" ch ci cj ck cl cm cn co cp cq cr cs ct cu cv cw cx cy cz da db dc dd de df\n", | |
" dg dh di dj dk dl dm dn do dp dq dr ds dt du dv dw dx dy\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(dz,) }" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 11 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"Let's load all images and copy them to GPU memory" | |
], | |
"metadata": { | |
"id": "H9ekpGpViINR" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": { | |
"id": "hIUOD0k4lot3" | |
}, | |
"outputs": [], | |
"source": [ | |
"all_images = []\n", | |
"for filename in sorted(os.listdir('./imgs')):\n", | |
" all_images.append(get_image(os.path.join('./imgs', filename)))\n", | |
"\n", | |
"all_images = jnp.array(all_images)\n", | |
"\n", | |
"all_images_dev = device_put(all_images)\n", | |
"params_dev = device_put(params)\n", | |
"batch_stats_dev = device_put(batch_stats)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"Let's run some scenarios and check whether running on the GPU gives us any speedup over the CPU." | |
], | |
"metadata": { | |
"id": "7Z_aB8wRjGbO" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"First, let's run it on the CPU (if you are running this yourself, it's a good time to grab some coffee ;-))" | |
], | |
"metadata": { | |
"id": "TaClaYOcjRXB" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"cpu_time = %timeit -o -r 3 predict_cpu(all_images, params, batch_stats).block_until_ready()" | |
], | |
"metadata": { | |
"id": "jT3BDoVyjQbb" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"Now let's try it on the GPU" | |
], | |
"metadata": { | |
"id": "_q_YjweopZwR" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"gpu_time = %timeit -o -r 3 predict_gpu(all_images_dev, params_dev, batch_stats_dev).block_until_ready()" | |
], | |
"metadata": { | |
"id": "uRW8_VQkfoPH" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "cttDh5qCW8nF" | |
}, | |
"outputs": [], | |
"source": [ | |
"plot_speedup({'cpu':cpu_time, 'gpu': gpu_time}, 'cpu')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"Quite impressive, eh?" | |
], | |
"metadata": { | |
"id": "dnRYqcNJs9K5" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"Ok, but could we easily improve performance when running on the CPU?\n", | |
"\n", | |
"We can use batching with `vmap`. It takes a function written for a single image and maps it over the batch axis of the input (1000 images here), so the whole batch is processed in a vectorized way." | |
], | |
"metadata": { | |
"id": "PkFQzuY-vuXC" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"predict_single_cpu = jit(predict_single, backend='cpu')\n", | |
"make_jaxpr(predict_single_cpu)(jnp.ones((224,224,3)))" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "miUJKs93gHop", | |
"outputId": "bed426f7-09c1-4125-f55a-93db51d5b214" | |
}, | |
"execution_count": 16, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"{ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22ma\u001b[35m:f32[7,7,3,64]\u001b[39m b\u001b[35m:f32[1,1,1,64]\u001b[39m c\u001b[35m:f32[1,1,1,64]\u001b[39m d\u001b[35m:f32[1,1,1,64]\u001b[39m e\u001b[35m:f32[1,1,1,64]\u001b[39m\n", | |
" f\u001b[35m:f32[3,3,64,64]\u001b[39m g\u001b[35m:f32[1,1,1,64]\u001b[39m h\u001b[35m:f32[1,1,1,64]\u001b[39m i\u001b[35m:f32[1,1,1,64]\u001b[39m j\u001b[35m:f32[1,1,1,64]\u001b[39m\n", | |
" k\u001b[35m:f32[3,3,64,64]\u001b[39m l\u001b[35m:f32[1,1,1,64]\u001b[39m m\u001b[35m:f32[1,1,1,64]\u001b[39m n\u001b[35m:f32[1,1,1,64]\u001b[39m o\u001b[35m:f32[1,1,1,64]\u001b[39m\n", | |
" p\u001b[35m:f32[3,3,64,64]\u001b[39m q\u001b[35m:f32[1,1,1,64]\u001b[39m r\u001b[35m:f32[1,1,1,64]\u001b[39m s\u001b[35m:f32[1,1,1,64]\u001b[39m t\u001b[35m:f32[1,1,1,64]\u001b[39m\n", | |
" u\u001b[35m:f32[3,3,64,64]\u001b[39m v\u001b[35m:f32[1,1,1,64]\u001b[39m w\u001b[35m:f32[1,1,1,64]\u001b[39m x\u001b[35m:f32[1,1,1,64]\u001b[39m y\u001b[35m:f32[1,1,1,64]\u001b[39m\n", | |
" z\u001b[35m:f32[3,3,64,128]\u001b[39m ba\u001b[35m:f32[1,1,1,128]\u001b[39m bb\u001b[35m:f32[1,1,1,128]\u001b[39m bc\u001b[35m:f32[1,1,1,128]\u001b[39m bd\u001b[35m:f32[1,1,1,128]\u001b[39m\n", | |
" be\u001b[35m:f32[3,3,128,128]\u001b[39m bf\u001b[35m:f32[1,1,1,128]\u001b[39m bg\u001b[35m:f32[1,1,1,128]\u001b[39m bh\u001b[35m:f32[1,1,1,128]\u001b[39m bi\u001b[35m:f32[1,1,1,128]\u001b[39m\n", | |
" bj\u001b[35m:f32[1,1,64,128]\u001b[39m bk\u001b[35m:f32[1,1,1,128]\u001b[39m bl\u001b[35m:f32[1,1,1,128]\u001b[39m bm\u001b[35m:f32[1,1,1,128]\u001b[39m bn\u001b[35m:f32[1,1,1,128]\u001b[39m\n", | |
" bo\u001b[35m:f32[3,3,128,128]\u001b[39m bp\u001b[35m:f32[1,1,1,128]\u001b[39m bq\u001b[35m:f32[1,1,1,128]\u001b[39m br\u001b[35m:f32[1,1,1,128]\u001b[39m bs\u001b[35m:f32[1,1,1,128]\u001b[39m\n", | |
" bt\u001b[35m:f32[3,3,128,128]\u001b[39m bu\u001b[35m:f32[1,1,1,128]\u001b[39m bv\u001b[35m:f32[1,1,1,128]\u001b[39m bw\u001b[35m:f32[1,1,1,128]\u001b[39m bx\u001b[35m:f32[1,1,1,128]\u001b[39m\n", | |
" by\u001b[35m:f32[3,3,128,256]\u001b[39m bz\u001b[35m:f32[1,1,1,256]\u001b[39m ca\u001b[35m:f32[1,1,1,256]\u001b[39m cb\u001b[35m:f32[1,1,1,256]\u001b[39m cc\u001b[35m:f32[1,1,1,256]\u001b[39m\n", | |
" cd\u001b[35m:f32[3,3,256,256]\u001b[39m ce\u001b[35m:f32[1,1,1,256]\u001b[39m cf\u001b[35m:f32[1,1,1,256]\u001b[39m cg\u001b[35m:f32[1,1,1,256]\u001b[39m ch\u001b[35m:f32[1,1,1,256]\u001b[39m\n", | |
" ci\u001b[35m:f32[1,1,128,256]\u001b[39m cj\u001b[35m:f32[1,1,1,256]\u001b[39m ck\u001b[35m:f32[1,1,1,256]\u001b[39m cl\u001b[35m:f32[1,1,1,256]\u001b[39m cm\u001b[35m:f32[1,1,1,256]\u001b[39m\n", | |
" cn\u001b[35m:f32[3,3,256,256]\u001b[39m co\u001b[35m:f32[1,1,1,256]\u001b[39m cp\u001b[35m:f32[1,1,1,256]\u001b[39m cq\u001b[35m:f32[1,1,1,256]\u001b[39m cr\u001b[35m:f32[1,1,1,256]\u001b[39m\n", | |
" cs\u001b[35m:f32[3,3,256,256]\u001b[39m ct\u001b[35m:f32[1,1,1,256]\u001b[39m cu\u001b[35m:f32[1,1,1,256]\u001b[39m cv\u001b[35m:f32[1,1,1,256]\u001b[39m cw\u001b[35m:f32[1,1,1,256]\u001b[39m\n", | |
" cx\u001b[35m:f32[3,3,256,512]\u001b[39m cy\u001b[35m:f32[1,1,1,512]\u001b[39m cz\u001b[35m:f32[1,1,1,512]\u001b[39m da\u001b[35m:f32[1,1,1,512]\u001b[39m db\u001b[35m:f32[1,1,1,512]\u001b[39m\n", | |
" dc\u001b[35m:f32[3,3,512,512]\u001b[39m dd\u001b[35m:f32[1,1,1,512]\u001b[39m de\u001b[35m:f32[1,1,1,512]\u001b[39m df\u001b[35m:f32[1,1,1,512]\u001b[39m dg\u001b[35m:f32[1,1,1,512]\u001b[39m\n", | |
" dh\u001b[35m:f32[1,1,256,512]\u001b[39m di\u001b[35m:f32[1,1,1,512]\u001b[39m dj\u001b[35m:f32[1,1,1,512]\u001b[39m dk\u001b[35m:f32[1,1,1,512]\u001b[39m dl\u001b[35m:f32[1,1,1,512]\u001b[39m\n", | |
" dm\u001b[35m:f32[3,3,512,512]\u001b[39m dn\u001b[35m:f32[1,1,1,512]\u001b[39m do\u001b[35m:f32[1,1,1,512]\u001b[39m dp\u001b[35m:f32[1,1,1,512]\u001b[39m dq\u001b[35m:f32[1,1,1,512]\u001b[39m\n", | |
" dr\u001b[35m:f32[3,3,512,512]\u001b[39m ds\u001b[35m:f32[1,1,1,512]\u001b[39m dt\u001b[35m:f32[1,1,1,512]\u001b[39m du\u001b[35m:f32[1,1,1,512]\u001b[39m dv\u001b[35m:f32[1,1,1,512]\u001b[39m\n", | |
" dw\u001b[35m:f32[512,1000]\u001b[39m dx\u001b[35m:f32[1000]\u001b[39m; dy\u001b[35m:f32[224,224,3]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mdz\u001b[35m:f32[1,1000]\u001b[39m = xla_call[\n", | |
" backend=cpu\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; ea\u001b[35m:f32[7,7,3,64]\u001b[39m eb\u001b[35m:f32[1,1,1,64]\u001b[39m ec\u001b[35m:f32[1,1,1,64]\u001b[39m ed\u001b[35m:f32[1,1,1,64]\u001b[39m\n", | |
" ee\u001b[35m:f32[1,1,1,64]\u001b[39m ef\u001b[35m:f32[3,3,64,64]\u001b[39m eg\u001b[35m:f32[1,1,1,64]\u001b[39m eh\u001b[35m:f32[1,1,1,64]\u001b[39m ei\u001b[35m:f32[1,1,1,64]\u001b[39m\n", | |
" ej\u001b[35m:f32[1,1,1,64]\u001b[39m ek\u001b[35m:f32[3,3,64,64]\u001b[39m el\u001b[35m:f32[1,1,1,64]\u001b[39m em\u001b[35m:f32[1,1,1,64]\u001b[39m en\u001b[35m:f32[1,1,1,64]\u001b[39m\n", | |
" eo\u001b[35m:f32[1,1,1,64]\u001b[39m ep\u001b[35m:f32[3,3,64,64]\u001b[39m eq\u001b[35m:f32[1,1,1,64]\u001b[39m er\u001b[35m:f32[1,1,1,64]\u001b[39m es\u001b[35m:f32[1,1,1,64]\u001b[39m\n", | |
" et\u001b[35m:f32[1,1,1,64]\u001b[39m eu\u001b[35m:f32[3,3,64,64]\u001b[39m ev\u001b[35m:f32[1,1,1,64]\u001b[39m ew\u001b[35m:f32[1,1,1,64]\u001b[39m ex\u001b[35m:f32[1,1,1,64]\u001b[39m\n", | |
" ey\u001b[35m:f32[1,1,1,64]\u001b[39m ez\u001b[35m:f32[3,3,64,128]\u001b[39m fa\u001b[35m:f32[1,1,1,128]\u001b[39m fb\u001b[35m:f32[1,1,1,128]\u001b[39m\n", | |
" fc\u001b[35m:f32[1,1,1,128]\u001b[39m fd\u001b[35m:f32[1,1,1,128]\u001b[39m fe\u001b[35m:f32[3,3,128,128]\u001b[39m ff\u001b[35m:f32[1,1,1,128]\u001b[39m\n", | |
" fg\u001b[35m:f32[1,1,1,128]\u001b[39m fh\u001b[35m:f32[1,1,1,128]\u001b[39m fi\u001b[35m:f32[1,1,1,128]\u001b[39m fj\u001b[35m:f32[1,1,64,128]\u001b[39m\n", | |
" fk\u001b[35m:f32[1,1,1,128]\u001b[39m fl\u001b[35m:f32[1,1,1,128]\u001b[39m fm\u001b[35m:f32[1,1,1,128]\u001b[39m fn\u001b[35m:f32[1,1,1,128]\u001b[39m\n", | |
" fo\u001b[35m:f32[3,3,128,128]\u001b[39m fp\u001b[35m:f32[1,1,1,128]\u001b[39m fq\u001b[35m:f32[1,1,1,128]\u001b[39m fr\u001b[35m:f32[1,1,1,128]\u001b[39m\n", | |
" fs\u001b[35m:f32[1,1,1,128]\u001b[39m ft\u001b[35m:f32[3,3,128,128]\u001b[39m fu\u001b[35m:f32[1,1,1,128]\u001b[39m fv\u001b[35m:f32[1,1,1,128]\u001b[39m\n", | |
" fw\u001b[35m:f32[1,1,1,128]\u001b[39m fx\u001b[35m:f32[1,1,1,128]\u001b[39m fy\u001b[35m:f32[3,3,128,256]\u001b[39m fz\u001b[35m:f32[1,1,1,256]\u001b[39m\n", | |
" ga\u001b[35m:f32[1,1,1,256]\u001b[39m gb\u001b[35m:f32[1,1,1,256]\u001b[39m gc\u001b[35m:f32[1,1,1,256]\u001b[39m gd\u001b[35m:f32[3,3,256,256]\u001b[39m\n", | |
" ge\u001b[35m:f32[1,1,1,256]\u001b[39m gf\u001b[35m:f32[1,1,1,256]\u001b[39m gg\u001b[35m:f32[1,1,1,256]\u001b[39m gh\u001b[35m:f32[1,1,1,256]\u001b[39m\n", | |
" gi\u001b[35m:f32[1,1,128,256]\u001b[39m gj\u001b[35m:f32[1,1,1,256]\u001b[39m gk\u001b[35m:f32[1,1,1,256]\u001b[39m gl\u001b[35m:f32[1,1,1,256]\u001b[39m\n", | |
" gm\u001b[35m:f32[1,1,1,256]\u001b[39m gn\u001b[35m:f32[3,3,256,256]\u001b[39m go\u001b[35m:f32[1,1,1,256]\u001b[39m gp\u001b[35m:f32[1,1,1,256]\u001b[39m\n", | |
" gq\u001b[35m:f32[1,1,1,256]\u001b[39m gr\u001b[35m:f32[1,1,1,256]\u001b[39m gs\u001b[35m:f32[3,3,256,256]\u001b[39m gt\u001b[35m:f32[1,1,1,256]\u001b[39m\n", | |
" gu\u001b[35m:f32[1,1,1,256]\u001b[39m gv\u001b[35m:f32[1,1,1,256]\u001b[39m gw\u001b[35m:f32[1,1,1,256]\u001b[39m gx\u001b[35m:f32[3,3,256,512]\u001b[39m\n", | |
" gy\u001b[35m:f32[1,1,1,512]\u001b[39m gz\u001b[35m:f32[1,1,1,512]\u001b[39m ha\u001b[35m:f32[1,1,1,512]\u001b[39m hb\u001b[35m:f32[1,1,1,512]\u001b[39m\n", | |
" hc\u001b[35m:f32[3,3,512,512]\u001b[39m hd\u001b[35m:f32[1,1,1,512]\u001b[39m he\u001b[35m:f32[1,1,1,512]\u001b[39m hf\u001b[35m:f32[1,1,1,512]\u001b[39m\n", | |
" hg\u001b[35m:f32[1,1,1,512]\u001b[39m hh\u001b[35m:f32[1,1,256,512]\u001b[39m hi\u001b[35m:f32[1,1,1,512]\u001b[39m hj\u001b[35m:f32[1,1,1,512]\u001b[39m\n", | |
" hk\u001b[35m:f32[1,1,1,512]\u001b[39m hl\u001b[35m:f32[1,1,1,512]\u001b[39m hm\u001b[35m:f32[3,3,512,512]\u001b[39m hn\u001b[35m:f32[1,1,1,512]\u001b[39m\n", | |
" ho\u001b[35m:f32[1,1,1,512]\u001b[39m hp\u001b[35m:f32[1,1,1,512]\u001b[39m hq\u001b[35m:f32[1,1,1,512]\u001b[39m hr\u001b[35m:f32[3,3,512,512]\u001b[39m\n", | |
" hs\u001b[35m:f32[1,1,1,512]\u001b[39m ht\u001b[35m:f32[1,1,1,512]\u001b[39m hu\u001b[35m:f32[1,1,1,512]\u001b[39m hv\u001b[35m:f32[1,1,1,512]\u001b[39m\n", | |
" hw\u001b[35m:f32[512,1000]\u001b[39m hx\u001b[35m:f32[1000]\u001b[39m hy\u001b[35m:f32[224,224,3]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mhz\u001b[35m:f32[1,224,224,3]\u001b[39m = broadcast_in_dim[\n", | |
" broadcast_dimensions=(1, 2, 3)\n", | |
" shape=(1, 224, 224, 3)\n", | |
" ] hy\n", | |
" ia\u001b[35m:f32[1,112,112,64]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1, 224, 224, 3)\n", | |
" padding=((3, 3), (3, 3))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(7, 7, 3, 64)\n", | |
" window_strides=(2, 2)\n", | |
" ] hz ea\n", | |
" ib\u001b[35m:f32[1,112,112,64]\u001b[39m = sub ia eb\n", | |
" ic\u001b[35m:f32[1,1,1,64]\u001b[39m = rsqrt ec\n", | |
" id\u001b[35m:f32[1,1,1,64]\u001b[39m = mul ic ed\n", | |
" ie\u001b[35m:f32[1,112,112,64]\u001b[39m = mul ib id\n", | |
" if\u001b[35m:f32[1,112,112,64]\u001b[39m = add ie ee\n", | |
" ig\u001b[35m:f32[1,112,112,64]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; ih\u001b[35m:f32[1,112,112,64]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mii\u001b[35m:f32[1,112,112,64]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; ij\u001b[35m:f32[1,112,112,64]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mik\u001b[35m:f32[1,112,112,64]\u001b[39m = max ij 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(ik,) }\n", | |
" name=relu\n", | |
" ] ih\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(ii,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1e9a148b00>\n", | |
" num_consts=0\n", | |
" ] if\n", | |
" il\u001b[35m:f32[1,56,56,64]\u001b[39m = reduce_window_max[\n", | |
" base_dilation=(1, 1, 1, 1)\n", | |
" padding=((0, 0), (1, 1), (1, 1), (0, 0))\n", | |
" window_dilation=(1, 1, 1, 1)\n", | |
" window_dimensions=(1, 3, 3, 1)\n", | |
" window_strides=(1, 2, 2, 1)\n", | |
" ] ig\n", | |
" im\u001b[35m:f32[1,56,56,64]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1, 56, 56, 64)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 64, 64)\n", | |
" window_strides=(1, 1)\n", | |
" ] il ef\n", | |
" in\u001b[35m:f32[1,56,56,64]\u001b[39m = sub im eg\n", | |
" io\u001b[35m:f32[1,1,1,64]\u001b[39m = rsqrt eh\n", | |
" ip\u001b[35m:f32[1,1,1,64]\u001b[39m = mul io ei\n", | |
" iq\u001b[35m:f32[1,56,56,64]\u001b[39m = mul in ip\n", | |
" ir\u001b[35m:f32[1,56,56,64]\u001b[39m = add iq ej\n", | |
" is\u001b[35m:f32[1,56,56,64]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; it\u001b[35m:f32[1,56,56,64]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22miu\u001b[35m:f32[1,56,56,64]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; iv\u001b[35m:f32[1,56,56,64]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22miw\u001b[35m:f32[1,56,56,64]\u001b[39m = max iv 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(iw,) }\n", | |
" name=relu\n", | |
" ] it\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(iu,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1e9036fdd0>\n", | |
" num_consts=0\n", | |
" ] ir\n", | |
" ix\u001b[35m:f32[1,56,56,64]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1, 56, 56, 64)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 64, 64)\n", | |
" window_strides=(1, 1)\n", | |
" ] is ek\n", | |
" iy\u001b[35m:f32[1,56,56,64]\u001b[39m = sub ix el\n", | |
" iz\u001b[35m:f32[1,1,1,64]\u001b[39m = rsqrt em\n", | |
" ja\u001b[35m:f32[1,1,1,64]\u001b[39m = mul iz en\n", | |
" jb\u001b[35m:f32[1,56,56,64]\u001b[39m = mul iy ja\n", | |
" jc\u001b[35m:f32[1,56,56,64]\u001b[39m = add jb eo\n", | |
" jd\u001b[35m:f32[1,56,56,64]\u001b[39m = add jc il\n", | |
" je\u001b[35m:f32[1,56,56,64]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; jf\u001b[35m:f32[1,56,56,64]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mjg\u001b[35m:f32[1,56,56,64]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; jh\u001b[35m:f32[1,56,56,64]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mji\u001b[35m:f32[1,56,56,64]\u001b[39m = max jh 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(ji,) }\n", | |
" name=relu\n", | |
" ] jf\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(jg,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1e9036fc20>\n", | |
" num_consts=0\n", | |
" ] jd\n", | |
" jj\u001b[35m:f32[1,56,56,64]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1, 56, 56, 64)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 64, 64)\n", | |
" window_strides=(1, 1)\n", | |
" ] je ep\n", | |
" jk\u001b[35m:f32[1,56,56,64]\u001b[39m = sub jj eq\n", | |
" jl\u001b[35m:f32[1,1,1,64]\u001b[39m = rsqrt er\n", | |
" jm\u001b[35m:f32[1,1,1,64]\u001b[39m = mul jl es\n", | |
" jn\u001b[35m:f32[1,56,56,64]\u001b[39m = mul jk jm\n", | |
" jo\u001b[35m:f32[1,56,56,64]\u001b[39m = add jn et\n", | |
" jp\u001b[35m:f32[1,56,56,64]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; jq\u001b[35m:f32[1,56,56,64]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mjr\u001b[35m:f32[1,56,56,64]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; js\u001b[35m:f32[1,56,56,64]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mjt\u001b[35m:f32[1,56,56,64]\u001b[39m = max js 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(jt,) }\n", | |
" name=relu\n", | |
" ] jq\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(jr,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1e9036fd40>\n", | |
" num_consts=0\n", | |
" ] jo\n", | |
" ju\u001b[35m:f32[1,56,56,64]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1, 56, 56, 64)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 64, 64)\n", | |
" window_strides=(1, 1)\n", | |
" ] jp eu\n", | |
" jv\u001b[35m:f32[1,56,56,64]\u001b[39m = sub ju ev\n", | |
" jw\u001b[35m:f32[1,1,1,64]\u001b[39m = rsqrt ew\n", | |
" jx\u001b[35m:f32[1,1,1,64]\u001b[39m = mul jw ex\n", | |
" jy\u001b[35m:f32[1,56,56,64]\u001b[39m = mul jv jx\n", | |
" jz\u001b[35m:f32[1,56,56,64]\u001b[39m = add jy ey\n", | |
" ka\u001b[35m:f32[1,56,56,64]\u001b[39m = add jz je\n", | |
" kb\u001b[35m:f32[1,56,56,64]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; kc\u001b[35m:f32[1,56,56,64]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mkd\u001b[35m:f32[1,56,56,64]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; ke\u001b[35m:f32[1,56,56,64]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mkf\u001b[35m:f32[1,56,56,64]\u001b[39m = max ke 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(kf,) }\n", | |
" name=relu\n", | |
" ] kc\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(kd,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1e9036fe60>\n", | |
" num_consts=0\n", | |
" ] ka\n", | |
" kg\u001b[35m:f32[1,28,28,128]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1, 56, 56, 64)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 64, 128)\n", | |
" window_strides=(2, 2)\n", | |
" ] kb ez\n", | |
" kh\u001b[35m:f32[1,28,28,128]\u001b[39m = sub kg fa\n", | |
" ki\u001b[35m:f32[1,1,1,128]\u001b[39m = rsqrt fb\n", | |
" kj\u001b[35m:f32[1,1,1,128]\u001b[39m = mul ki fc\n", | |
" kk\u001b[35m:f32[1,28,28,128]\u001b[39m = mul kh kj\n", | |
" kl\u001b[35m:f32[1,28,28,128]\u001b[39m = add kk fd\n", | |
" km\u001b[35m:f32[1,28,28,128]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; kn\u001b[35m:f32[1,28,28,128]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mko\u001b[35m:f32[1,28,28,128]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; kp\u001b[35m:f32[1,28,28,128]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mkq\u001b[35m:f32[1,28,28,128]\u001b[39m = max kp 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(kq,) }\n", | |
" name=relu\n", | |
" ] kn\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(ko,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1e903400e0>\n", | |
" num_consts=0\n", | |
" ] kl\n", | |
" kr\u001b[35m:f32[1,28,28,128]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1, 28, 28, 128)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 128, 128)\n", | |
" window_strides=(1, 1)\n", | |
" ] km fe\n", | |
" ks\u001b[35m:f32[1,28,28,128]\u001b[39m = sub kr ff\n", | |
" kt\u001b[35m:f32[1,1,1,128]\u001b[39m = rsqrt fg\n", | |
" ku\u001b[35m:f32[1,1,1,128]\u001b[39m = mul kt fh\n", | |
" kv\u001b[35m:f32[1,28,28,128]\u001b[39m = mul ks ku\n", | |
" kw\u001b[35m:f32[1,28,28,128]\u001b[39m = add kv fi\n", | |
" kx\u001b[35m:f32[1,28,28,128]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1, 56, 56, 64)\n", | |
" padding=((0, 0), (0, 0))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(1, 1, 64, 128)\n", | |
" window_strides=(2, 2)\n", | |
" ] kb fj\n", | |
" ky\u001b[35m:f32[1,28,28,128]\u001b[39m = sub kx fk\n", | |
" kz\u001b[35m:f32[1,1,1,128]\u001b[39m = rsqrt fl\n", | |
" la\u001b[35m:f32[1,1,1,128]\u001b[39m = mul kz fm\n", | |
" lb\u001b[35m:f32[1,28,28,128]\u001b[39m = mul ky la\n", | |
" lc\u001b[35m:f32[1,28,28,128]\u001b[39m = add lb fn\n", | |
" ld\u001b[35m:f32[1,28,28,128]\u001b[39m = add kw lc\n", | |
" le\u001b[35m:f32[1,28,28,128]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; lf\u001b[35m:f32[1,28,28,128]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mlg\u001b[35m:f32[1,28,28,128]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; lh\u001b[35m:f32[1,28,28,128]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mli\u001b[35m:f32[1,28,28,128]\u001b[39m = max lh 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(li,) }\n", | |
" name=relu\n", | |
" ] lf\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(lg,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1e90340050>\n", | |
" num_consts=0\n", | |
" ] ld\n", | |
" lj\u001b[35m:f32[1,28,28,128]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1, 28, 28, 128)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 128, 128)\n", | |
" window_strides=(1, 1)\n", | |
" ] le fo\n", | |
" lk\u001b[35m:f32[1,28,28,128]\u001b[39m = sub lj fp\n", | |
" ll\u001b[35m:f32[1,1,1,128]\u001b[39m = rsqrt fq\n", | |
" lm\u001b[35m:f32[1,1,1,128]\u001b[39m = mul ll fr\n", | |
" ln\u001b[35m:f32[1,28,28,128]\u001b[39m = mul lk lm\n", | |
" lo\u001b[35m:f32[1,28,28,128]\u001b[39m = add ln fs\n", | |
" lp\u001b[35m:f32[1,28,28,128]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; lq\u001b[35m:f32[1,28,28,128]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mlr\u001b[35m:f32[1,28,28,128]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; ls\u001b[35m:f32[1,28,28,128]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mlt\u001b[35m:f32[1,28,28,128]\u001b[39m = max ls 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(lt,) }\n", | |
" name=relu\n", | |
" ] lq\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(lr,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1e90340710>\n", | |
" num_consts=0\n", | |
" ] lo\n", | |
" lu\u001b[35m:f32[1,28,28,128]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1, 28, 28, 128)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 128, 128)\n", | |
" window_strides=(1, 1)\n", | |
" ] lp ft\n", | |
" lv\u001b[35m:f32[1,28,28,128]\u001b[39m = sub lu fu\n", | |
" lw\u001b[35m:f32[1,1,1,128]\u001b[39m = rsqrt fv\n", | |
" lx\u001b[35m:f32[1,1,1,128]\u001b[39m = mul lw fw\n", | |
" ly\u001b[35m:f32[1,28,28,128]\u001b[39m = mul lv lx\n", | |
" lz\u001b[35m:f32[1,28,28,128]\u001b[39m = add ly fx\n", | |
" ma\u001b[35m:f32[1,28,28,128]\u001b[39m = add lz le\n", | |
" mb\u001b[35m:f32[1,28,28,128]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; mc\u001b[35m:f32[1,28,28,128]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mmd\u001b[35m:f32[1,28,28,128]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; me\u001b[35m:f32[1,28,28,128]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mmf\u001b[35m:f32[1,28,28,128]\u001b[39m = max me 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(mf,) }\n", | |
" name=relu\n", | |
" ] mc\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(md,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1e90340560>\n", | |
" num_consts=0\n", | |
" ] ma\n", | |
" mg\u001b[35m:f32[1,14,14,256]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1, 28, 28, 128)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 128, 256)\n", | |
" window_strides=(2, 2)\n", | |
" ] mb fy\n", | |
" mh\u001b[35m:f32[1,14,14,256]\u001b[39m = sub mg fz\n", | |
" mi\u001b[35m:f32[1,1,1,256]\u001b[39m = rsqrt ga\n", | |
" mj\u001b[35m:f32[1,1,1,256]\u001b[39m = mul mi gb\n", | |
" mk\u001b[35m:f32[1,14,14,256]\u001b[39m = mul mh mj\n", | |
" ml\u001b[35m:f32[1,14,14,256]\u001b[39m = add mk gc\n", | |
" mm\u001b[35m:f32[1,14,14,256]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; mn\u001b[35m:f32[1,14,14,256]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mmo\u001b[35m:f32[1,14,14,256]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; mp\u001b[35m:f32[1,14,14,256]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mmq\u001b[35m:f32[1,14,14,256]\u001b[39m = max mp 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(mq,) }\n", | |
" name=relu\n", | |
" ] mn\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(mo,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1e90340e60>\n", | |
" num_consts=0\n", | |
" ] ml\n", | |
" mr\u001b[35m:f32[1,14,14,256]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1, 14, 14, 256)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 256, 256)\n", | |
" window_strides=(1, 1)\n", | |
" ] mm gd\n", | |
" ms\u001b[35m:f32[1,14,14,256]\u001b[39m = sub mr ge\n", | |
" mt\u001b[35m:f32[1,1,1,256]\u001b[39m = rsqrt gf\n", | |
" mu\u001b[35m:f32[1,1,1,256]\u001b[39m = mul mt gg\n", | |
" mv\u001b[35m:f32[1,14,14,256]\u001b[39m = mul ms mu\n", | |
" mw\u001b[35m:f32[1,14,14,256]\u001b[39m = add mv gh\n", | |
" mx\u001b[35m:f32[1,14,14,256]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1, 28, 28, 128)\n", | |
" padding=((0, 0), (0, 0))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(1, 1, 128, 256)\n", | |
" window_strides=(2, 2)\n", | |
" ] mb gi\n", | |
" my\u001b[35m:f32[1,14,14,256]\u001b[39m = sub mx gj\n", | |
" mz\u001b[35m:f32[1,1,1,256]\u001b[39m = rsqrt gk\n", | |
" na\u001b[35m:f32[1,1,1,256]\u001b[39m = mul mz gl\n", | |
" nb\u001b[35m:f32[1,14,14,256]\u001b[39m = mul my na\n", | |
" nc\u001b[35m:f32[1,14,14,256]\u001b[39m = add nb gm\n", | |
" nd\u001b[35m:f32[1,14,14,256]\u001b[39m = add mw nc\n", | |
" ne\u001b[35m:f32[1,14,14,256]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; nf\u001b[35m:f32[1,14,14,256]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mng\u001b[35m:f32[1,14,14,256]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; nh\u001b[35m:f32[1,14,14,256]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mni\u001b[35m:f32[1,14,14,256]\u001b[39m = max nh 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(ni,) }\n", | |
" name=relu\n", | |
" ] nf\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(ng,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1e90340d40>\n", | |
" num_consts=0\n", | |
" ] nd\n", | |
" nj\u001b[35m:f32[1,14,14,256]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1, 14, 14, 256)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 256, 256)\n", | |
" window_strides=(1, 1)\n", | |
" ] ne gn\n", | |
" nk\u001b[35m:f32[1,14,14,256]\u001b[39m = sub nj go\n", | |
" nl\u001b[35m:f32[1,1,1,256]\u001b[39m = rsqrt gp\n", | |
" nm\u001b[35m:f32[1,1,1,256]\u001b[39m = mul nl gq\n", | |
" nn\u001b[35m:f32[1,14,14,256]\u001b[39m = mul nk nm\n", | |
" no\u001b[35m:f32[1,14,14,256]\u001b[39m = add nn gr\n", | |
" np\u001b[35m:f32[1,14,14,256]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; nq\u001b[35m:f32[1,14,14,256]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mnr\u001b[35m:f32[1,14,14,256]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; ns\u001b[35m:f32[1,14,14,256]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mnt\u001b[35m:f32[1,14,14,256]\u001b[39m = max ns 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(nt,) }\n", | |
" name=relu\n", | |
" ] nq\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(nr,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1e90340950>\n", | |
" num_consts=0\n", | |
" ] no\n", | |
" nu\u001b[35m:f32[1,14,14,256]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1, 14, 14, 256)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 256, 256)\n", | |
" window_strides=(1, 1)\n", | |
" ] np gs\n", | |
" nv\u001b[35m:f32[1,14,14,256]\u001b[39m = sub nu gt\n", | |
" nw\u001b[35m:f32[1,1,1,256]\u001b[39m = rsqrt gu\n", | |
" nx\u001b[35m:f32[1,1,1,256]\u001b[39m = mul nw gv\n", | |
" ny\u001b[35m:f32[1,14,14,256]\u001b[39m = mul nv nx\n", | |
" nz\u001b[35m:f32[1,14,14,256]\u001b[39m = add ny gw\n", | |
" oa\u001b[35m:f32[1,14,14,256]\u001b[39m = add nz ne\n", | |
" ob\u001b[35m:f32[1,14,14,256]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; oc\u001b[35m:f32[1,14,14,256]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mod\u001b[35m:f32[1,14,14,256]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; oe\u001b[35m:f32[1,14,14,256]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mof\u001b[35m:f32[1,14,14,256]\u001b[39m = max oe 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(of,) }\n", | |
" name=relu\n", | |
" ] oc\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(od,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1e903407a0>\n", | |
" num_consts=0\n", | |
" ] oa\n", | |
" og\u001b[35m:f32[1,7,7,512]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1, 14, 14, 256)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 256, 512)\n", | |
" window_strides=(2, 2)\n", | |
" ] ob gx\n", | |
" oh\u001b[35m:f32[1,7,7,512]\u001b[39m = sub og gy\n", | |
" oi\u001b[35m:f32[1,1,1,512]\u001b[39m = rsqrt gz\n", | |
" oj\u001b[35m:f32[1,1,1,512]\u001b[39m = mul oi ha\n", | |
" ok\u001b[35m:f32[1,7,7,512]\u001b[39m = mul oh oj\n", | |
" ol\u001b[35m:f32[1,7,7,512]\u001b[39m = add ok hb\n", | |
" om\u001b[35m:f32[1,7,7,512]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; on\u001b[35m:f32[1,7,7,512]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22moo\u001b[35m:f32[1,7,7,512]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; op\u001b[35m:f32[1,7,7,512]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22moq\u001b[35m:f32[1,7,7,512]\u001b[39m = max op 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(oq,) }\n", | |
" name=relu\n", | |
" ] on\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(oo,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1e90340cb0>\n", | |
" num_consts=0\n", | |
" ] ol\n", | |
" or\u001b[35m:f32[1,7,7,512]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1, 7, 7, 512)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 512, 512)\n", | |
" window_strides=(1, 1)\n", | |
" ] om hc\n", | |
" os\u001b[35m:f32[1,7,7,512]\u001b[39m = sub or hd\n", | |
" ot\u001b[35m:f32[1,1,1,512]\u001b[39m = rsqrt he\n", | |
" ou\u001b[35m:f32[1,1,1,512]\u001b[39m = mul ot hf\n", | |
" ov\u001b[35m:f32[1,7,7,512]\u001b[39m = mul os ou\n", | |
" ow\u001b[35m:f32[1,7,7,512]\u001b[39m = add ov hg\n", | |
" ox\u001b[35m:f32[1,7,7,512]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1, 14, 14, 256)\n", | |
" padding=((0, 0), (0, 0))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(1, 1, 256, 512)\n", | |
" window_strides=(2, 2)\n", | |
" ] ob hh\n", | |
" oy\u001b[35m:f32[1,7,7,512]\u001b[39m = sub ox hi\n", | |
" oz\u001b[35m:f32[1,1,1,512]\u001b[39m = rsqrt hj\n", | |
" pa\u001b[35m:f32[1,1,1,512]\u001b[39m = mul oz hk\n", | |
" pb\u001b[35m:f32[1,7,7,512]\u001b[39m = mul oy pa\n", | |
" pc\u001b[35m:f32[1,7,7,512]\u001b[39m = add pb hl\n", | |
" pd\u001b[35m:f32[1,7,7,512]\u001b[39m = add ow pc\n", | |
" pe\u001b[35m:f32[1,7,7,512]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; pf\u001b[35m:f32[1,7,7,512]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mpg\u001b[35m:f32[1,7,7,512]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; ph\u001b[35m:f32[1,7,7,512]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mpi\u001b[35m:f32[1,7,7,512]\u001b[39m = max ph 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(pi,) }\n", | |
" name=relu\n", | |
" ] pf\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(pg,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1e903408c0>\n", | |
" num_consts=0\n", | |
" ] pd\n", | |
" pj\u001b[35m:f32[1,7,7,512]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1, 7, 7, 512)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 512, 512)\n", | |
" window_strides=(1, 1)\n", | |
" ] pe hm\n", | |
" pk\u001b[35m:f32[1,7,7,512]\u001b[39m = sub pj hn\n", | |
" pl\u001b[35m:f32[1,1,1,512]\u001b[39m = rsqrt ho\n", | |
" pm\u001b[35m:f32[1,1,1,512]\u001b[39m = mul pl hp\n", | |
" pn\u001b[35m:f32[1,7,7,512]\u001b[39m = mul pk pm\n", | |
" po\u001b[35m:f32[1,7,7,512]\u001b[39m = add pn hq\n", | |
" pp\u001b[35m:f32[1,7,7,512]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; pq\u001b[35m:f32[1,7,7,512]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mpr\u001b[35m:f32[1,7,7,512]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; ps\u001b[35m:f32[1,7,7,512]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mpt\u001b[35m:f32[1,7,7,512]\u001b[39m = max ps 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(pt,) }\n", | |
" name=relu\n", | |
" ] pq\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(pr,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1e90340830>\n", | |
" num_consts=0\n", | |
" ] po\n", | |
" pu\u001b[35m:f32[1,7,7,512]\u001b[39m = conv_general_dilated[\n", | |
" batch_group_count=1\n", | |
" dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n", | |
" feature_group_count=1\n", | |
" lhs_dilation=(1, 1)\n", | |
" lhs_shape=(1, 7, 7, 512)\n", | |
" padding=((1, 1), (1, 1))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" rhs_dilation=(1, 1)\n", | |
" rhs_shape=(3, 3, 512, 512)\n", | |
" window_strides=(1, 1)\n", | |
" ] pp hr\n", | |
" pv\u001b[35m:f32[1,7,7,512]\u001b[39m = sub pu hs\n", | |
" pw\u001b[35m:f32[1,1,1,512]\u001b[39m = rsqrt ht\n", | |
" px\u001b[35m:f32[1,1,1,512]\u001b[39m = mul pw hu\n", | |
" py\u001b[35m:f32[1,7,7,512]\u001b[39m = mul pv px\n", | |
" pz\u001b[35m:f32[1,7,7,512]\u001b[39m = add py hv\n", | |
" qa\u001b[35m:f32[1,7,7,512]\u001b[39m = add pz pe\n", | |
" qb\u001b[35m:f32[1,7,7,512]\u001b[39m = custom_jvp_call_jaxpr[\n", | |
" fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; qc\u001b[35m:f32[1,7,7,512]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mqd\u001b[35m:f32[1,7,7,512]\u001b[39m = xla_call[\n", | |
" call_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; qe\u001b[35m:f32[1,7,7,512]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mqf\u001b[35m:f32[1,7,7,512]\u001b[39m = max qe 0.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(qf,) }\n", | |
" name=relu\n", | |
" ] qc\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(qd,) }\n", | |
" jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f1e90340ef0>\n", | |
" num_consts=0\n", | |
" ] qa\n", | |
" qg\u001b[35m:f32[1,512]\u001b[39m = reduce_sum[axes=(1, 2)] qb\n", | |
" qh\u001b[35m:f32[1,512]\u001b[39m = div qg 49.0\n", | |
" qi\u001b[35m:f32[1,1000]\u001b[39m = dot_general[\n", | |
" dimension_numbers=(((1,), (0,)), ((), ()))\n", | |
" precision=None\n", | |
" preferred_element_type=None\n", | |
" ] qh hw\n", | |
" qj\u001b[35m:f32[1,1000]\u001b[39m = reshape[dimensions=None new_sizes=(1, 1000)] hx\n", | |
" qk\u001b[35m:f32[1,1000]\u001b[39m = add qi qj\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(qk,) }\n", | |
" name=predict_single\n", | |
" ] a b c d e f g h i j k l m n o p q r s t u v w x y z ba bb bc bd be bf bg bh\n", | |
" bi bj bk bl bm bn bo bp bq br bs bt bu bv bw bx by bz ca cb cc cd ce cf cg\n", | |
" ch ci cj ck cl cm cn co cp cq cr cs ct cu cv cw cx cy cz da db dc dd de df\n", | |
" dg dh di dj dk dl dm dn do dp dq dr ds dt du dv dw dx dy\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(dz,) }" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 16 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"vmap_time = %timeit -o -r 3 vmap(predict_single_cpu)(all_images)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "Zem81xRnga9-", | |
"outputId": "af28a9ec-7a88-499c-bb35-025b04cc0e56" | |
}, | |
"execution_count": 17, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"1 loop, best of 3: 1min 23s per loop\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"plot_speedup({'cpu': cpu_time, 'vmap_cpu': vmap_time,}, 'cpu')" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 265 | |
}, | |
"id": "p4BWYWw90aKe", | |
"outputId": "b0646339-9e7e-4f3d-d18e-95c8729e6a59" | |
}, | |
"execution_count": 18, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
], | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAD4CAYAAADhNOGaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAASiElEQVR4nO3df7DddX3n8eeLxJAVRaYkWEgIN5TUSmvEcBctZVr80d3wozC7gIa2iuKQdVf8Mba70m2lDB1nKW5l6goodpEOHcCgDmYxirUqVoqW8FPADQ1Bm4QfgiBLBYLB9/5xDvZwuffmEPK9N7mf52Pmzj2f7+fz/Z73yXxzX+f743xOqgpJUrt2m+4CJEnTyyCQpMYZBJLUOINAkhpnEEhS42ZPdwHP17x582pkZGS6y5CkXcqNN974UFXNH69vlwuCkZER1q5dO91lSNIuJckPJurz1JAkNc4gkKTGGQS7uPe+9728/OUvJwnHHnvshOOuu+46li5dyu67786yZcu46aabft531VVXcdBBBzF37lyOPPJI7rnnnqkoXQ1xP925GQQzwIoVKybtf/LJJznhhBN47LHHOO+883jggQc48cQTefrpp7n//vtZsWIFe+65Jx/5yEe48cYbOeWUU6aocrXE/XQnVlW71M+hhx5aerZ77rmngDrmmGPG7f/85z9fQJ177rlVVfWhD32ogPrqV79aH/3oRwuoVatWVVXVW9/61gJq/fr1ddJJJ9Xs2bPrtttuq+uvv7522223OvXUU6fsdWlmcT+dXsDamuDv6i5315Cev2cOoRcsWADAwoULAdiwYcOkfRdccAHXXnstp512Gj/5yU9YsGAB55133lSXr0a4n06fzoIgycXAscAPq+rXxukP8JfA0cDjwNur6qax47Tj1SQzzg72zZs3jwsvvJATTjgBgGuuuYY999yz8/okcD+dSl1eI7gEWD5J/1HAkv7PSuDCDmtpzpYtW9iyZQsAixcvBmDTpk0AbN68GYADDzxw0j6A++677+fbvP/++6egcrXE/XQnMdE5ox3xA4wAt0/Q90ng5IH2OmDfbW3TawTPdvXVV9c555xTQC1durQ+9alP1V133VUHHHBA7bHHHlVV9cQTT9Q+++xTIyMjdcEFF9R+++1XIyMjtXXr1rr33ntrzpw5tWzZsvrYxz5WL3nJS+qII46oqqq777679thjj1q+fHkdfvjhtddee9XmzZun8+VqF+V+Ov2Y5BpBqsMvpkkyAlxd458auho4p6q+1W//HfDBqnrOx4aTrKR31MCiRYsO/cEPJvyA3KRGzvjidq23M7v/sjPYsvH2Zy3b++j38+NvXcbPnvh/LPrAZwF4cuPtPPyVC/npw5t50bxF7L38Pey+7xIAHl/3DzzyjYvZ+thD7L7vK9j76Pcze69f5IHL/4inHrib/d55PvXTLdx3yfuYu2gp+5x01lS/zCnx/XOOme4SZuQ+Cu6nO8oL2UeT3FhVo+P27QpBMGh0dLS2d4qJmfqfTDuGQaCdXVdBMJ2fI9gM7D/QXthfJkmaQtMZBKuBt6XndcCjVXXftlaSJO1YXd4+ejlwJDAvySbgT4EXAVTVJ4A19G4dXU/v9tF3dFWLJGlinQVBVZ28jf4C3t3V80uShuNcQ5LUOINAkhpnEEhS4wwCSWqcQSBJjTMIJKlxBoEkNc4gkKTGGQSS1DiDQJIaZxBIUuMMAklqnEEgSY0zCCSpcQaBJDXOIJCkxhkEktQ4g0CSGmcQSFLjDAJJapxBIEmNMwgkqXEGgSQ1ziCQpMYZBJLUOINAkhpnEEhS4wwCSWqcQSBJjTMIJKlxBoEkNc4gkKTGdRoESZYnWZdkfZIzxulflOTrSW5OcluSo7usR5L0XJ0FQZJZwPnAUcDBwMlJDh4z7E+AVVX1GmAFcEFX9UiSxtflEcFhwPqq2lBVTwFXAMePGVPAnv3HLw
Pu7bAeSdI4Zne47QXAxoH2JuC1Y8acBXwlyXuAPYA3dViPJGkc032x+GTgkqpaCBwNXJrkOTUlWZlkbZK1Dz744JQXKUkzWZdBsBnYf6C9sL9s0DuBVQBVdT0wF5g3dkNVdVFVjVbV6Pz58zsqV5La1GUQ3AAsSbI4yRx6F4NXjxnzz8AbAZK8kl4Q+JZfkqZQZ0FQVVuB04FrgO/RuzvojiRnJzmuP+wPgNOS3ApcDry9qqqrmiRJz9XlxWKqag2wZsyyMwce3wn8Rpc1SJImN90XiyVJ08wgkKTGGQSS1DiDQJIaZxBIUuMMAklqnEEgSY0zCCSpcQaBJDXOIJCkxhkEktQ4g0CSGmcQSFLjDAJJapxBIEmNMwgkqXEGgSQ1ziCQpMYZBJLUOINAkhpnEEhS4wwCSWqcQSBJjTMIJKlxBoEkNc4gkKTGGQSS1DiDQJIaZxBIUuMMAklqnEEgSY0zCCSpcbOHHZhkDvArQAHrquqpzqqSJE2ZoY4IkhwD3A18DPg4sD7JUUOstzzJuiTrk5wxwZg3J7kzyR1JLns+xUuSXrhhjwj+Anh9Va0HSPJLwBeBL020QpJZwPnAbwObgBuSrK6qOwfGLAH+CPiNqnokyT7b9zIkSdtr2GsEjz0TAn0bgMe2sc5hwPqq2tA/jXQFcPyYMacB51fVIwBV9cMh65Ek7SDDHhGsTbIGWEXvGsFJ9N7h/0eAqvr8OOssADYOtDcBrx0z5pcBklwHzALOqqovj91QkpXASoBFixYNWbIkaRjDBsFc4AHgt/rtB4F/A/wOvWAYLwiGff4lwJHAQuCbSV5VVT8eHFRVFwEXAYyOjtZ2PpckaRxDBUFVvWM7tr0Z2H+gvbC/bNAm4DtV9VPgniR30QuGG7bj+SRJ22GoIEjyaXrv/J+lqk6dZLUbgCVJFtMLgBXA744ZcxVwMvDpJPPonSraMExNkqQdY9hTQ1cPPJ4L/Afg3slWqKqtSU4HrqF3/v/iqrojydnA2qpa3e/7d0nuBJ4G/mtV/ej5vghJ0vYb9tTQ5wbbSS4HvjXEemuANWOWnTnwuIAP9H8kSdNge6eYWAJ4z78kzQDDXiN4jN41gvR/3w98sMO6JElTZNhTQy/tuhBJ0vSYNAiSLJusv6pu2rHlSJKm2raOCP6i/3suMArcSu/00FJgLfDr3ZUmSZoKk14srqrXV9XrgfuAZVU1WlWHAq/huR8OkyTtgoa9a+gVVfXdZxpVdTvwym5KkiRNpWE/UHZbkr8C/qbf/j3gtm5KkiRNpWGD4B3Afwbe129/E7iwk4okSVNq2NtHn0zyCWBNVa3ruCZJ0hQa9qsqjwNuAb7cbx+SZHWXhUmSpsawF4v/lN43jv0YoKpuARZ3VZQkaeoMGwQ/rapHxyzzC2IkaQYY9mLxHUl+F5jV/8L59wL/0F1ZkqSpMuwRwXuAXwW2AJcBjwLv76ooSdLUGfauoceBP07y4f5jSdIMMexdQ4f3v0Xs//bbr05yQaeVSZKmxLCnhs4D/j3wI4CquhX4za6KkiRNnaG/oayqNo5Z9PQOrkWSNA2GvWtoY5LDgUryInpTTXyvu7IkSVNl2COCdwHvBhYA9wKH9NuSpF3csHcNPURvxlFJ0gwz7F1DByb5P0keTPLDJF9IcmDXxUmSujfsqaHLgFXAvsB+wJXA5V0VJUmaOsMGwYur6tKq2tr/+Rt632MsSdrFDXvX0JeSnAFcQW+yubcAa5L8AkBVPdxRfZKkjg0bBG/u/17Z/53+7xX0gsHrBZK0i5o0CJL8W2BjVS3ut08BTgC+D5zlkYAk7fq2dY3gk8BTAEl+E/gfwF/Tm330om5LkyRNhW2dGpo18K7/LcBFVfU54HNJbum2NEnSVNjWEcGsJM+ExRuBrw30DXt9QZK0E9vWH/PLgWuTPAQ8Afw9QJKD6J0ekiTt4iY9IqiqDwN/AFwCHFFVz3xP8W70vrVsUkmWJ1mXZH3/9tOJxp2QpJ
KMDl+6JGlH2Obpnar69jjL7trWeklmAecDvw1sAm5Isrqq7hwz7qX0ZjP9zrBFS5J2nKG/j2A7HAasr6oNVfUUvQ+jHT/OuD8D/hx4ssNaJEkT6DIIFgCDX2azqb/s55IsA/avqi9OtqEkK5OsTbL2wQcf3PGVSlLDugyCSSXZDfgovWsQk6qqi6pqtKpG58+f331xktSQLoNgM7D/QHthf9kzXgr8GvCNJN8HXges9oKxJE2tLoPgBmBJksVJ5tCbl2j1M51V9WhVzauqkaoaAb4NHFdVazusSZI0RmdBUFVbgdOBa+h9v/GqqrojydlJjuvqeSVJz0+nnw6uqjXAmjHLzpxg7JFd1iJJGt+0XSyWJO0cDAJJapxBIEmNMwgkqXEGgSQ1ziCQpMYZBJLUOINAkhpnEEhS4wwCSWqcQSBJjTMIJKlxBoEkNc4gkKTGGQSS1DiDQJIaZxBIUuMMAklqnEEgSY0zCCSpcQaBJDXOIJCkxhkEktQ4g0CSGmcQSFLjDAJJapxBIEmNMwgkqXEGgSQ1ziCQpMYZBJLUOINAkhrXaRAkWZ5kXZL1Sc4Yp/8DSe5McluSv0tyQJf1SJKeq7MgSDILOB84CjgYODnJwWOG3QyMVtVS4LPAuV3VI0kaX5dHBIcB66tqQ1U9BVwBHD84oKq+XlWP95vfBhZ2WI8kaRxdBsECYONAe1N/2UTeCXypw3okSeOYPd0FACT5fWAU+K0J+lcCKwEWLVo0hZVJ0szX5RHBZmD/gfbC/rJnSfIm4I+B46pqy3gbqqqLqmq0qkbnz5/fSbGS1Koug+AGYEmSxUnmACuA1YMDkrwG+CS9EPhhh7VIkibQWRBU1VbgdOAa4HvAqqq6I8nZSY7rD/sI8BLgyiS3JFk9weYkSR3p9BpBVa0B1oxZdubA4zd1+fySpG3zk8WS1DiDQJIaZxBIUuMMAklqnEEgSY0zCCSpcQaBJDXOIJCkxhkEktQ4g0CSGmcQSFLjDAJJapxBIEmNMwgkqXEGgSQ1ziCQpMYZBJLUOINAkhpnEEhS4wwCSWqcQSBJjTMIJKlxBoEkNc4gkKTGGQSS1DiDQJIaZxBIUuMMAklqnEEgSY0zCCSpcQaBJDXOIJCkxhkEktS4ToMgyfIk65KsT3LGOP27J/lMv/87SUa6rEeS9FydBUGSWcD5wFHAwcDJSQ4eM+ydwCNVdRBwHvDnXdUjSRpfl0cEhwHrq2pDVT0FXAEcP2bM8cBf9x9/FnhjknRYkyRpjNkdbnsBsHGgvQl47URjqmprkkeBvYGHBgclWQms7Df/Jcm6TipuzzzG/Fu3LB6P7ozcRwe8wH30gIk6ugyCHaaqLgIumu46Zpoka6tqdLrrkCbiPjo1ujw1tBnYf6C9sL9s3DFJZgMvA37UYU2SpDG6DIIbgCVJFieZA6wAVo8Zsxo4pf/4ROBrVVUd1iRJGqOzU0P9c/6nA9cAs4CLq+qOJGcDa6tqNfC/gUuTrAcephcWmjqebtPOzn10CsQ34JLUNj9ZLEmNMwgkqXEGgSQ1ziCQpMYZBDNYkrcluS3JrUkuTXJJkk8kWZvkriTH9se9PcnHB9a7OsmR01a4dglJzkny7oH2WUn+MMm1Sb6QZEN/zO8l+cck303yS/2xv9OfaPLmJF9N8vKBbVya5Pok/5TktG3U8MH+dm9Nck5/2TeS/GWSW5LcnuSwwfoG1r3diS57DIIZKsmvAn8CvKGqXg28r981Qm8eqGOATySZOz0Vagb4DPDmgfabgQeAVwPvAl4JvBX45ao6DPgr4D39sd8CXldVr6E3D9l/G9jOUuANwK8DZybZb7wnT3IUvfnKXtvfx88d6H5xVR0C/Bfg4hfyIluwS0wxoe3yBuDKqnoIoKoe7s/nt6qqfgb8U5INwK9MY43ahVXVzUn26f+hng88Qm/usBuq6j6AJHcDX+mv8l3g9f3HC4HPJNkXmAPcM7DpL1TVE8ATSb5O743LVeOU8C
bg01X1eL+ehwf6Lu8v+2aSPZPs9cJf8czlEUF7xn5wpICtPHtf8ChBw7qS3qwAb6F3hACwZaD/ZwPtn/Gvbz7/F/DxqnoV8J949j433j76fLmfPw8Gwcz1NeCkJHsDJPmF/vKTkuzWP1d7ILAO+D5wSH/5/vTegUnD+Ay9GQFOpBcKw3oZ/zr32Clj+o5PMre/7x5Jb7qa8fwt8I4kL4Zn7ePQCyaSHAE8WlWP0tvPl/WXLwMWP496ZzRPDc1Q/ek8Pgxcm+Rp4OZ+1z8D/wjsCbyrqp5Mch29Q/M7ge8BN01Hzdr19PezlwKbq+q+JK8YctWzgCuTPELvTcvgH+XbgK/Tm4L6z6rq3gme+8tJDgHWJnkKWAP89373k0luBl4EnNpf9jngbUnuAL4D3DXs65zpnGKiIUkuAa6uqs9Ody3SeJKcBfxLVf3PF7CNbwB/WFVrd1RdM52nhiSpcR4RSNqpJXkVcOmYxVuqauw3Hmo7GQSS1DhPDUlS4wwCSWqcQSBJjTMIJKlx/x8dFPXnCblyIgAAAABJRU5ErkJggg==\n" | |
}, | |
"metadata": { | |
"needs_background": "light" | |
} | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"Using `vmap` alone does not seem effective here. Probably Flax already vectorizes the computation behind the scenes.\n", | |
"\n", | |
"One important thing we still need to consider: predicting 1000 images one at a time in a naive Python loop would take a **very** long time. Don't believe me? Take a look at this:" | |
], | |
"metadata": { | |
"id": "5WIO4iG0IQid" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"results = []\n", | |
"\n", | |
"def naive_py_loop(images):\n", | |
" for image in images:\n", | |
" predict_single(image)\n", | |
"\n", | |
"naive_python_loop_time = %timeit -o -r 3 naive_py_loop(all_images)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "gbPn4rUPJTIZ", | |
"outputId": "bda7df08-2a75-40dd-bb05-7b59f1fb977f" | |
}, | |
"execution_count": 19, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"1 loop, best of 3: 9min 33s per loop\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"plot_speedup({'naive_python_loop': naive_python_loop_time, 'vmap_cpu': vmap_time,}, 'naive_python_loop')" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 266 | |
}, | |
"id": "5NdqMf09Jtxm", | |
"outputId": "b530fe3b-429a-4bb5-a112-25c2a4a5bf95" | |
}, | |
"execution_count": 20, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
], | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXgAAAD5CAYAAAAtBi5vAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAWTUlEQVR4nO3dfZhWdZ3H8ffHAURd0E1GV0EafMg0H2GyMk00bSEtMxQpbU25mHzIp7YUd7dVr71sJbclMyHIlBZzMhSzQkVTIc20BiXBp4oHr4FgHCoVSAbB7/5xn5F7hnm4mZkzM/z4vK5rrrnP+Z1zft97OPdnzvz43edWRGBmZunZqacLMDOzfDjgzcwS5YA3M0uUA97MLFEOeDOzRPXp6QKKDRo0KCoqKnq6DDOz7caCBQvWRER5S229KuArKiqoqanp6TLMzLYbkl5trc1DNGZmiXLAm5klygFvZh02Y8YMJG31tXz58q22nTZtGkOGDGGXXXbh9NNP5y9/+UuT9scee+zd/T1U2zUc8GbWYSeccALV1dVUV1czc+ZM+vXrx957783gwYObbPfcc89x4YUXcsghh3D99dczZ84crrzyynfb33rrLaqqqth11127+ykkzQFvZh02bNgwxo0bx7hx4+jfvz8bN27kggsuoG/fvk22mzFjBgDf+MY3uOqqqzj22GOprq5mw4YNAFx77bXsvvvufPazn22y39ixY+nbty+LFi3i6aefpqysjPHjx3fLc0tBr5pFY2bbr2nTprHTTjtRVVW1VduyZcsA3r2yHzJkCJs2baK2tpZ169Zxyy238NRTT3HzzTc32W/KlCnMnz+fCRMmsH79egYPHszkyZPzfzKJcMCbWactWbKERx99lNGjR1PKe1mK72J72WWXcdZZZzFgwADWrl0LwIoVKzjssMMYNGgQU6dOZcyYMQDMnTuXgQMH5vIcUpTbEI2kgyUtLPp6U9IVefVnZj1n2rRpRAQXXXQRUAjwDRs28PbbbwOFoRwoBDfAypUr6dOnD/vttx+1tbXMnDmTgw46iNmzZwNwxhlnvPsfratWrXq3n9WrV3fbc0pCROT+BZQBq4H3trXdiBEjwsy2Lw0NDVFeXh5Dhw6NzZs3R0TEsmXLAohTTz01IiJqamoCiJNPPjkmTZoUZWVlce6550ZExAMPPBCzZs2KWbNmxciRIwOISZMmxZo1a2LJkiWx2267xahRo+LYY4+NPfbYI1auXNljz7U3AmqilUztriGajwNLIqLVd1yZpa5i4pyeLiEX61+cz5r6enY//lz2/7cHAdj0Rh0Aj7382rvP+z2nXMTjv7mbXz4+n10qhjPvH0cX/Ux2AWDNmjIAJr/Qj1u/+Rvqqq9h49vvsGjYWOLtBt747eUcePzp7HXWdd36HPO2/MZTczmuohs+0UnS7cCzEfHdFtqqgCqAoUOHjnj1Vf8OsDSlGvDWeZ0JeEkLIqKypbbcp0lK6gd8GpjVUntETI+IyoioLC9v8X45ZmbWAd0xD340hav3um7oy8zMMt0R8J8DqruhHzMzK5JrwEvaDTgFmJ1nP2ZmtrVcZ9FExHpgzzz7MDOzlvleNGZmiXLAm5klygFvZpYoB7yZWaIc8GZmiXLAm5klygFvZpYoB7yZWaIc8GZmiXLAm5klygFvZpYoB7yZWaIc8GZmiXLAm5klygFvZpYoB7yZWaIc8GZmiXLAm5klygFvZpYoB7yZWaJyDXhJe0i6R9LLkl6S9JE8+zMzsy365Hz8m4GHIuJMSf2AXXPuz8zMMrkFvKTdgY8BXwSIiI3Axrz6MzOzpvIcohkG1AN3SHpO0m2Sdmu+kaQqSTWSaurr63Msx8xsx5JnwPcBhgNTI+JoYD0wsflGETE9IiojorK8vDzHcszMdix5BvwKYEVEPJMt30Mh8M3MrBvkFvARsRqolXRwturjwIt59WdmZk3lPYvmUu
BH2QyapcD5OfdnZmaZXAM+IhYClXn2YWZmLfM7Wc3MEuWANzNLlAPezCxRDngzs0Q54M3MEuWANzNLlAPezCxRDngzs0Q54M3MEuWANzNLlAPezCxRDngzs0Q54M3MEuWANzNLlAPezCxRDngzs0Q54M3MEuWANzNLlAPezCxRDngzs0Q54M3MEtUnz4NLWg6sBTYDmyKiMs/+zMxsi1wDPnNiRKzphn7MzKyIh2jMzBKVd8AH8LCkBZKqWtpAUpWkGkk19fX1OZdjZrbjyDvgj4uI4cBo4BJJH2u+QURMj4jKiKgsLy/PuRwzsx1HrgEfESuz768B9wHH5NmfmZltkVvAS9pN0oDGx8AngMV59WdmZk3lOYtmb+A+SY393BURD+XYn5mZFckt4CNiKXBkXsc3M7O2eZqkmVmiHPBmZolywJuZJcoBb2aWKAe8mVmiHPBmZolywJuZJcoBb2aWKAe8mVmiHPBmZolywJuZJcoBb2aWKAe8mVmiSr6bpKR+wPspfAzfKxGxMbeqzMys00oKeEmnAt8DlgAChkn6UkQ8mGdxZmbWcaVewX8LODEi/gQg6QBgDuCANzPrpUodg1/bGO6ZpcDaHOoxM7MuUuoVfI2kB4CfUBiDPwv4naTPAkTE7JzqMzOzDio14PsDdcAJ2XI9sAvwKQqB74A3M+tlSgr4iDg/70LMzKxrlTqL5g4KV+pNRMQFJexbBtQAKyPitG2u0MzMOqTUIZpfFD3uD5wB/LnEfS8HXgIGbkNdZmbWSaUO0dxbvCypGniyvf0kDQFOBW4AvtKRAs3MrGM6equCg4C9Stju28BVwDutbSCpSlKNpJr6+voOlmNmZs2VFPCS1kp6s/E78HPg6nb2OQ14LSIWtLVdREyPiMqIqCwvLy+5cDMza1upQzQDOnDsjwKflvRJCuP2AyXdGRHnduBYZma2jdoMeEnD22qPiGfbaLsGuCY7zkjgqw53M7Pu094V/Ley7/2BSuD3FG42dgSFqY8fya80MzPrjDbH4CPixIg4EVgFDM/GykcARwMrS+0kIuZ5DryZWfcqdRbNwRGxqHEhIhYDh+RTkpmZdYVS3+j0vKTbgDuz5XOA5/MpyczMukKpAX8+cBGFd6UC/AqYmktFZmbWJUqdJrlB0veAByLilZxrMjOzLlDqG50+DSwEHsqWj5L0szwLMzOzzin1P1mvBY4BXgeIiIXAsLyKMjOzzis14N+OiDeardvq9sFmZtZ7lPqfrC9I+jxQJukg4DLgqfzKMjOzzir1Cv5S4ANAA3AX8AZwRV5FmZlZ55U6i+bvwL9LuiF7bGZmvVyps2iOlfQi8HK2fKSkKblWZmZmnVLqEM1k4J+BvwBExO+Bj+VVlJmZdV7Jn+gUEbXNVm3u4lrMzKwLlTqLplbSsUBI6suWD9I2M7NeqtQr+AuBS4DBwJ+Bo7JlMzPrpUqdRbOGwh0kzcxsO1HqLJr9Jf1cUr2k1yTdL2n/vIszM7OOK3WI5i7gJ8A+wL7ALKA6r6LMzKzzSg34XSNiZkRsyr7upPA5rWZm1kuVOovmQUkTgR9TuMnY2cADkt4DEBF/zak+MzProFIDfmz2vSr7ruz7OAqBv9V4vKT+FD75aeesn3si4tqOl2pmZtuizYCX9EGgNiKGZcvnAWOA5cB17Vy5NwAnRcS6bO78k5IejIinu6Z0MzNrS3tj8NOAjQCSPgb8N/BDCneTnN7WjlGwLlvsm335HvJmZt2kvYAvK7pKPxuYHhH3RsTXgQPbO7ikMkkLgdeARyLimc6Va2ZmpWo34CU1DuN8HHisqK3d8fuI2BwRRwFDgGMkHdZ8G0lVkmok1dTX15dat5mZtaO9gK8G5ku6H3gLeAJA0oEUhmlKEhGvA48Do1pomx4RlRFRWV5eXnLhZmbWtjavwiPiBkmPUniD08MR0TiGvhOFT3lqlaRyCp/l+rqkXYBTgEldULOZmZWglGGWrWa9RMQfSjj2PsAPJZ
VR+IXwk4j4xbaXaGZmHVHqPPhtFhHPA0fndXwzM2tbyR/4YWZm2xcHvJlZohzwZmaJcsCbmSXKAW9mligHvJlZohzwZmaJcsCbmSXKAW9mligHvJlZohzwZmaJcsCbmSXKAW9mligHvJlZohzwZmaJcsCbmSXKAW9mligHvJlZohzwZmaJcsCbmSXKAW9mlqjcAl7SfpIel/SipBckXZ5XX2ZmtrU+OR57E/CvEfGspAHAAkmPRMSLOfZpZmaZ3K7gI2JVRDybPV4LvAQMzqs/MzNrqlvG4CVVAEcDz7TQViWpRlJNfX19d5RjZrZDyD3gJf0DcC9wRUS82bw9IqZHRGVEVJaXl+ddjpnZDiPXgJfUl0K4/ygiZufZl5mZNZXnLBoBPwBeioj/zasfMzNrWZ5X8B8FvgCcJGlh9vXJHPszM7MiuU2TjIgnAeV1fDMza5vfyWpmligHvJlZohzwZmaJcsCbmSXKAW9mligHvJlZohzwZmaJcsCbmSXKAW9mligHvJlZohzwZmaJcsCbmSXKAW9mligHvJlZohzwZmaJcsCbmSXKAd+LXXbZZey9995I4rTTTmt1u1//+tccccQR7LzzzgwfPpxnn3323baf/vSnHHjggfTv35+RI0eybNmy7ijdzHoBB3wvN27cuDbbN2zYwJgxY1i7di2TJ0+mrq6OM888k82bN7N69WrGjRvHwIEDuemmm1iwYAHnnXdeN1VuZj3NAd+Lfec73+HKK69sc5sHH3yQuro6Lr74Yi6++GLGjx/PsmXLmDdvHtXV1TQ0NHDNNddw6aWXcsYZZ/DEE0+wZMkSxo4dS9++fVm0aBFPP/00ZWVljB8/vpuemZl1h9w+k9W6R+OQy+DBgwEYMmQIAEuXLm2zbcqUKcyfP58JEyawfv16Bg8ezOTJk7u7fDPLUW4BL+l24DTgtYg4LK9+rKmIKKlt0KBBTJ06lTFjxgAwd+5cBg4cmHt9ZtZ98hyimQGMyvH4O6yGhgYaGhoAGDZsGAArVqwAYOXKlQDsv//+bbYBrFq16t1jrl69uhsqN7PulNsVfET8SlJFXsffEcyZM4fFixcDUFtby2233cYJJ5zAKaecwpo1a1i3bh2jR49mr732YurUqQwYMIAf/OAHVFRUMHLkSA499FAmTpzIpEmTqKur47777uO4447jgAMOYOnSpVx99dWMGjWKN998k8svv5yTTz6Zfffdt4eftZl1FbX1J32nD14I+F+0NUQjqQqoAhg6dOiIV199tUN9VUyc06H9erPVd02koXZxk3V7fvIKXn/yLt55602GfuUeADbULuavD0/l7b+upO+goew56lJ23ucgAP7+ylP8bd7tbFq7hp33OZg9P3kFffb4J+qqr2Fj3RL2HX8r8XYDq2ZcTv+hR7DXWdd199PM3fIbT+3pEoA0z1HrGp05RyUtiIjKFtt6OuCLVVZWRk1NTYf68ovHWuOAt94ur4D3NEkzs0Q54M3MEpVbwEuqBn4DHCxphSS/i8bMrBvlOYvmc3kd28zM2uchGjOzRDngzcwS5YA3M0uUA97MLFEOeDOzRDngzcwS5YA3M0uUA97MLFEOeDOzRDngzcwS5YA3M0uUA97MLFEOeDOzRDngzcwS5YA3M0uUA97MLFEOeDOzRDngzcwS5YA3M0uUA97MLFEOeDOzROUa8JJGSXpF0p8kTcyzLzMzayq3gJdUBtwKjAYOBT4n6dC8+jMzs6byvII/BvhTRCyNiI3Aj4HTc+zPzMyK9Mnx2IOB2qLlFcCHmm8kqQqoyhbXSXolx5p2FIOANT1dRG+hST1dgbXC52mmk+foe1tryDPgSxIR04HpPV1HSiTVRERlT9dh1hafp/nLc4hmJbBf0fKQbJ2ZmXWDPAP+d8BBkoZJ6geMA36WY39mZlYktyGaiNgk6cvAXKAMuD0iXsirP2vCQ162PfB5mjNFRE/XYGZmOfA7Wc3MEuWANzNLlAPezCxRDvguJu
lCSf/Sg/1/pviWEJLmScptrrGk5ZIG5XV8M+s4B3wXi4jvRcT/9WAJn6Fw7x+zDpF0o6RLipavk/RVSfMl3S9pabbNOZJ+K2mRpAOybT8l6RlJz0n6paS9i44xU9JvJP1R0oR2arg6O+7vJd2YrZsn6WZJCyUtlnRMcX1F+y6WVNH1P5ntjwO+HZIqJL0k6fuSXpD0sKRdJE2Q9LvsBLxX0q7Z9o0vhvdL+m2z4yzKHo/IXiwLJM2VtE8b/W91UkvaKXuRlGfb7JTdsfME4NPATdn2B2SHOSt7If5B0vHZPv0l3ZG9iJ6TdGK2/ouSZkt6KOvjm9vws/pKVuNiSVe0tT77ebws6UfZz/eexp+h9bi7gbFFy2OBOuBI4ELgEOALwPsi4hjgNuDSbNsngQ9HxNEU7j91VdFxjgBOAj4C/KekfVvqXNJoCvet+lBEHAkUn4O7RsRRwMXA7Z15kjsCB3xpDgJujYgPAK8DY4DZEfHB7AR8CRhfvENEvAz0kzQsW3U2cLekvsAtwJkRMYLCSXpDO/03Oakj4h3gTuCcrP1k4PcRMZ/Cm8m+FhFHRcSSrL1P9kK8Arg2W3dJocw4HPgc8ENJ/bO2o7J6DwfOllT8juQWSRoBnE/hfkMfBiZIOrq19dluBwNTIuIQ4M3s+VkPi4jngL0k7SvpSOBvFO4r9buIWBURDcAS4OFsl0VARfZ4CDA3u5j5GvCBokPfHxFvRcQa4HEKNyRsycnAHRHx96yevxa1VWfrfgUMlLRH555t2hzwpVkWEQuzxwsonMyHSXoiO5HPoemJ3OgnFIKS7PvdFELtMOARSQuB/6DwomhLSyf17UDjWP8FwB1t7D+7We0Ax1H4JdH4y+hV4H1Z26MR8UZEbABepI2bGRU5DrgvItZHxLqsz+PbWA9QGxG/zh7fmW1rvcMs4Ey2nLcADUXt7xQtv8OWN03eAnw3u3D4EtC/aJ/mb7rpyJtwWjrGJppmWX8McMCXqvjE3kzhZJ4BfDk7ka+n5ZPqbmCspPdRuFr+IyDghewK+6iIODwiPtFO/1ud1BFRC9RJOonCldCDJdTfWHt7Wnq+eeiKF7zl424Ktxc5k0LYl2p3ttxz6rxmbadnQ4N7AiMp3M6kJY8A5xcNe76nqO3sbN1xwBsR8QawHBierR8ODMMAB3xnDABWZUMu57S0QTZEshn4Oluugl4ByiV9BEBSX0ktXf0Xa+mkhsLY553ArIjYnK1bm9XWnica685+AQ3NauuoJ4DPSNpV0m7AGdm61tYDDG38OQCfpzB+a71AdluRAcDKiFi1DbteB8yStICtbwX8PIWhmaeB/4qIP7fS90MUhhprsr9yv1rUvEHSc8D32DIsei/wHkkvAF8G/rAN9Satx28XvB37OvAMUJ99by1U7wZuIruqiIiNks4EviNpdwr/Bt8G2rpPT+NJ3ZfCcEyjn1EYmikenvkx8H1Jl1G4+mrNFGBqNsS0CfhiRDRIamOX1kXEs5JmAI3/sXxbNpZLS+uzWQ6vAJdIup3CUNDUDnVuucj+Om18PA+YV7Q8sqW2iLgfuL+VQz4fESVNIY6IG4EbW2i6MyKuaLbtW0B7fwXvkHwvml5O0jzgqxFR00JbJTA5Io7fasdeLgv4X0TEYT1cinUDSdcB6yLifzpxjHm08lqwlvkKfjulwoeYX0Qrw0NmvUlEXNd8naTDgZnNVjdExFaf/JYdY2TXV5Y2X8H3EpJuBT7abPXNEdHW7JhuI+kZYOdmq78QEYt6oh4za58D3swsUZ5FY2aWKAe8mVmiHPBmZolywJuZJer/AYaE5xv87uXIAAAAAElFTkSuQmCC\n" | |
}, | |
"metadata": { | |
"needs_background": "light" | |
} | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"This means `vmap` makes predictions much faster because it vectorizes the computation over the input.\n", | |
"\n", | |
"The most interesting thing about JAX is how easy and transparent it makes using the GPU and vectorization.\n", | |
"\n", | |
"It is also possible to parallelize it across multiple devices with `pmap`, but unfortunately Colab provides only one GPU in the free tier, so we won't be able to apply `pmap` here." | |
], | |
"metadata": { | |
"id": "mOLXuZtVLIle" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Verifying accuracy\n", | |
"\n", | |
"We got the same accuracy as in the previous work, as expected. This means the Flax implementation does not reduce accuracy." | |
], | |
"metadata": { | |
"id": "G0vzKPrUcSEY" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"ground_truth = [490,361,171,822,297,482,13,704,599,164,649,11,73,286,554,6,648,399,749,545,13,204,318,693,399,304,102,207,480,780,644,275,14,954,249,790,501,547,809,606,297,927,424,156,60,983,256,207,281,456,413,498,561,750,182,267,118,893,597,840,836,107,647,471,945,451,214,790,291,837,707,193,397,568,401,705,200,202,31,949,361,98,709,483,563,695,122,497,914,476,102,199,104,221,138,257,188,436,229,52,377,200,597,544,131,935,602,342,508,748,617,901,184,513,204,999,782,922,869,598,596,701,970,452,631,817,57,398,885,10,288,305,34,112,273,120,959,657,823,356,32,657,753,372,364,349,106,513,551,366,228,917,464,655,737,58,62,219,832,922,192,303,946,996,737,333,729,824,517,89,724,2,829,172,266,123,434,149,39,171,44,662,734,782,270,153,283,58,217,581,170,306,689,753,613,454,885,392,316,956,527,817,47,223,516,100,441,659,803,297,95,343,576,367,576,330,392,609,961,143,818,764,377,192,494,820,150,204,460,769,620,800,153,683,365,450,207,616,985,994,315,104,595,541,620,52,389,45,756,303,328,901,991,290,982,821,914,496,984,976,711,450,66,564,584,549,844,540,3,361,680,718,44,861,578,496,588,413,27,669,294,961,10,334,935,819,518,235,246,528,825,381,449,289,684,682,468,489,601,66,658,673,76,886,750,162,450,592,885,379,955,447,360,920,341,451,994,400,215,229,543,44,67,490,809,40,42,397,383,720,668,738,143,923,843,820,485,790,529,37,848,626,186,761,593,301,239,206,775,113,912,576,890,296,168,425,270,719,988,504,621,462,428,917,284,833,573,769,231,412,649,703,564,346,176,827,559,999,91,794,220,717,97,374,973,647,696,341,429,428,88,539,171,368,606,608,970,199,151,967,40,241,243,458,21,734,224,329,102,696,109,329,574,390,418,555,547,639,879,413,310,752,735,98,106,782,814,841,107,599,267,267,969,284,818,123,723,888,392,109,31,938,744,9,995,44,44,344,312,172,659,285,997,572,709,162,87,566,847,714,740,99,743,675,783,421,346,542,541,435,464,956,570,72,682,394,581,147,349,40,169,317,733,331,545,137,11,764,789,714,830,441,966,767,100,6,502,495,547,31,102,53,315,834,592,301,141,507,848,625,18,
481,766,839,542,671,62,74,661,175,351,692,464,191,700,766,747,401,982,205,114,695,68,670,957,168,382,906,196,915,473,799,864,263,22,654,420,437,506,776,358,162,96,739,349,238,929,580,234,257,710,887,729,370,101,827,780,590,733,387,238,900,882,974,562,407,428,193,711,548,645,370,794,718,74,795,904,126,373,179,694,163,424,115,247,539,194,819,132,85,252,786,184,155,323,900,963,781,6,30,289,145,198,355,938,329,157,645,955,793,542,809,247,808,622,722,978,215,950,999,230,331,704,779,130,562,545,392,640,508,757,101,94,129,372,999,980,429,585,85,386,596,970,777,216,783,169,148,269,228,147,227,252,244,5,544,219,139,612,686,282,128,355,742,183,923,552,376,610,691,352,572,471,892,188,773,905,487,656,422,324,627,683,72,73,59,992,667,455,251,828,596,656,574,893,967,541,154,8,338,320,666,389,287,687,94,277,670,46,267,907,964,48,689,27,16,318,813,15,865,913,567,10,748,424,342,664,695,397,707,750,813,919,418,958,249,386,26,492,797,435,293,693,214,878,1,830,115,307,697,821,705,940,279,857,910,278,621,808,571,845,635,948,424,447,802,313,681,35,911,652,953,508,283,155,486,905,96,813,540,459,952,920,427,882,55,665,325,131,151,823,158,511,307,831,482,496,872,700,642,885,686,705,333,535,816,598,615,627,873,616,564,865,146,431,675,552,37,249,553,749,210,882,926,551,651,395,161,714,381,595,340,593,352,802,511,47,797,687,755,896,725,277,899,988,895,557,819,852,402,572,554,734,691,163,928,49,244,331,96,61,693,388,765,145,505,602,424,706,393,745,158,119,32,401,740,375,397,156,494,975,39,548,541,874,706,577,201,766,770,237,750,932,543,563,805,116,466,32,659,383,205,236,780,698,399,533,455,554,571,445,145,854,626,471,624,511,793,412,459,286,677,451,608,545,986,188,459,425,217,481,512,610,500,137,211,297,200,59,873,682,417,254,886,597,623,348,458,593,564,677,396,394,135,640,245,893,381,560,213,544,631,874,96,140,270,744,468,752,774,164,836,34,867,930,609,781,919,329,657,565,161,3,450,932,452,50,264,536,301]\n", | |
"glow_results = [490,590,171,822,298,493,13,970,599,87,649,11,73,286,554,6,450,397,749,545,13,189,318,693,399,304,102,18,489,897,644,275,212,954,249,790,495,547,809,629,296,927,424,123,60,659,256,132,267,809,639,498,561,750,40,267,69,893,856,840,836,144,647,471,945,451,214,790,291,968,278,193,397,568,401,710,200,202,31,949,361,98,709,654,274,695,166,497,914,476,102,75,104,221,138,257,188,436,229,52,377,200,597,544,131,848,602,343,508,845,617,320,184,570,132,999,782,512,548,598,689,701,970,455,328,794,57,397,885,10,288,556,34,64,273,183,370,657,823,49,32,657,753,372,364,992,117,513,551,367,228,917,464,655,737,136,62,219,832,922,211,317,946,996,737,333,729,824,517,89,724,2,829,172,266,109,434,15,39,171,186,662,734,985,270,153,283,58,217,939,140,306,689,753,613,454,885,392,316,956,527,854,47,804,978,100,441,659,803,966,95,343,576,367,576,330,392,609,961,143,818,764,377,125,494,756,150,253,460,769,620,280,153,683,365,655,207,616,712,994,315,104,595,356,620,52,389,45,756,677,328,779,991,290,982,821,914,496,984,125,786,498,66,564,584,549,844,540,172,359,680,718,185,861,742,496,975,413,63,669,294,961,10,334,817,373,518,235,246,529,759,555,449,289,684,596,473,489,601,66,656,771,85,886,879,162,450,592,885,530,955,443,360,271,702,451,232,400,215,754,543,44,67,490,809,40,154,397,383,720,726,738,143,923,843,820,486,790,914,37,848,626,186,761,534,300,239,206,775,189,904,576,890,296,168,425,270,366,988,14,621,462,428,723,284,833,573,769,246,412,649,710,564,346,176,827,291,999,210,562,220,717,97,374,306,647,550,341,426,428,88,539,171,368,606,608,807,989,151,967,40,241,243,458,21,734,224,331,102,696,109,329,826,390,418,555,547,639,573,413,310,752,735,98,49,977,304,841,107,599,267,285,939,284,818,123,248,943,392,109,33,819,744,9,995,44,44,344,312,141,659,285,997,572,709,473,87,566,943,714,740,99,743,675,811,421,346,251,347,435,464,956,570,72,682,394,581,147,349,182,169,317,733,329,545,137,11,863,789,714,830,441,966,767,100,6,502,495,547,31,205,53,378,834,592,301,179,507,236,622,18,4
81,829,839,542,297,1,74,791,175,351,685,471,191,700,949,747,401,982,108,114,695,68,771,957,168,382,906,196,915,473,891,873,263,22,654,420,437,987,776,358,162,96,739,349,238,890,580,234,257,524,887,729,379,918,755,780,21,733,387,238,900,882,974,562,407,428,446,711,708,328,370,794,698,74,795,904,77,776,179,694,209,426,114,247,289,214,958,132,30,252,786,50,155,331,900,963,776,6,30,289,91,198,349,938,329,157,645,312,793,251,809,247,808,622,722,857,215,259,999,230,329,704,779,130,562,545,392,640,508,757,195,94,81,372,999,980,429,585,30,389,721,970,788,216,783,169,148,283,228,116,227,252,564,5,313,377,176,612,933,279,128,355,742,183,923,553,509,610,248,352,572,464,892,188,773,905,487,656,422,324,627,683,72,73,59,992,790,454,251,916,596,656,705,893,967,558,154,8,338,320,302,389,287,990,94,255,670,46,267,934,964,23,689,27,16,318,813,149,949,913,548,10,748,424,342,664,695,397,707,750,813,306,418,547,249,386,210,492,598,422,293,693,214,878,1,830,114,307,721,821,705,940,282,857,943,278,644,808,571,845,635,948,439,447,275,313,681,35,790,652,781,661,283,184,486,905,96,813,540,491,758,920,427,882,174,665,325,59,60,974,154,511,705,831,482,496,848,700,745,885,686,944,333,535,701,598,614,627,873,619,684,949,146,431,675,228,37,249,553,898,210,882,926,551,575,395,18,714,381,595,340,865,352,938,908,47,963,687,755,896,725,277,299,988,517,557,674,852,402,572,289,734,691,163,928,49,244,974,150,61,693,388,765,59,831,602,424,706,393,745,158,119,32,401,740,654,397,158,494,975,39,548,540,874,710,729,12,580,770,237,790,932,543,563,805,116,466,42,659,384,205,826,989,698,399,533,507,554,571,445,145,839,626,471,606,523,793,412,459,286,677,490,608,545,986,188,957,425,654,481,659,610,500,137,51,296,90,59,873,682,417,220,654,965,623,348,14,814,564,677,396,394,135,640,221,495,381,560,75,858,632,905,96,140,589,859,476,662,774,164,540,34,867,930,609,781,919,329,657,565,161,3,450,932,452,50,250,539,302,]\n", | |
"activations = predict_gpu(all_images_dev, params_dev, batch_stats_dev).block_until_ready()\n", | |
"predictions = get_top1(activations)\n", | |
"\n" | |
], | |
"metadata": { | |
"id": "HCORIREWcReu" | |
}, | |
"execution_count": 21, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"print('Accuracy ---> ', sum(1 for x,y in zip(ground_truth, predictions) if x == y) / float(len(ground_truth)))\n", | |
"print('Accuracy comparing to results from MO436 projects 2 and 3 --> ', sum(1 for x,y in zip(glow_results, predictions) if x == y) / float(len(ground_truth)))" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "yAyTUypHdC-A", | |
"outputId": "298c1645-b078-49a0-9022-9cd566e392a9" | |
}, | |
"execution_count": 22, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Accuracy ---> 0.69\n", | |
"Accuracy comparing to results from MO436 projects 2 and 3 --> 1.0\n" | |
] | |
} | |
] | |
} | |
], | |
"metadata": { | |
"colab": { | |
"collapsed_sections": [], | |
"name": "Flax_Resnet18.ipynb", | |
"provenance": [] | |
}, | |
"gpuClass": "standard", | |
"kernelspec": { | |
"display_name": "Python 3", | |
"name": "python3" | |
}, | |
"language_info": { | |
"name": "python" | |
}, | |
"accelerator": "GPU" | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment