@nirum
Last active April 17, 2019 19:18
lax_shape_bug
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "JAX Quickstart.ipynb",
"version": "0.3.2",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/nirum/6f7f345bc6b46224f36cc40c73a8147c/jax-quickstart.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"metadata": {
"id": "logZcM_HEnve",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"##### Copyright 2018 Google LLC.\n",
"\n",
"Licensed under the Apache License, Version 2.0 (the \"License\");"
]
},
{
"metadata": {
"id": "QwN47xiBEsKz",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"you may not use this file except in compliance with the License.\n",
"You may obtain a copy of the License at\n",
"\n",
"https://www.apache.org/licenses/LICENSE-2.0\n",
"\n",
"Unless required by applicable law or agreed to in writing, software\n",
"distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"See the License for the specific language governing permissions and\n",
"limitations under the License."
]
},
{
"metadata": {
"colab_type": "text",
"id": "xtWX4x9DCF5_"
},
"cell_type": "markdown",
"source": [
"# JAX Quickstart\n",
"Dougal Maclaurin, Peter Hawkins, Matthew Johnson, Roy Frostig, Alex Wiltschko, Chris Leary\n",
"\n",
"![](https://raw.githubusercontent.com/google/jax/master/images/jax_logo_250px.png)\n",
"\n",
"#### [JAX](https://github.com/google/jax) is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research.\n",
"\n",
"With its updated version of [Autograd](https://github.com/hips/autograd), JAX\n",
"can automatically differentiate native Python and NumPy code. It can\n",
"differentiate through a large subset of Python’s features, including loops, ifs,\n",
"recursion, and closures, and it can even take derivatives of derivatives of\n",
"derivatives. It supports reverse-mode as well as forward-mode differentiation, and the two can be composed arbitrarily\n",
"to any order.\n",
"\n",
"What’s new is that JAX uses\n",
"[XLA](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/overview.md)\n",
"to compile and run your NumPy code on accelerators, like GPUs and TPUs.\n",
"Compilation happens under the hood by default, with library calls getting\n",
"just-in-time compiled and executed. But JAX even lets you just-in-time compile\n",
"your own Python functions into XLA-optimized kernels using a one-function API.\n",
"Compilation and automatic differentiation can be composed arbitrarily, so you\n",
"can express sophisticated algorithms and get maximal performance without having\n",
"to leave Python.\n"
]
},
{
"metadata": {
"colab_type": "code",
"id": "PaW85yP_BrCF",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 101
},
"outputId": "44dcfc96-680d-49d4-907f-0778d218d64b"
},
"cell_type": "code",
"source": [
"#!pip install --upgrade -q https://storage.googleapis.com/jax-wheels/cuda$(echo $CUDA_VERSION | sed -e 's/\\.//' -e 's/\\..*//')/jaxlib-0.1.12-cp36-none-linux_x86_64.whl\n",
"!pip install --upgrade -q jax"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"\u001b[K 100% |████████████████████████████████| 37.4MB 851kB/s \n",
"\u001b[K 100% |████████████████████████████████| 163kB 10.8MB/s \n",
"\u001b[K 100% |████████████████████████████████| 61kB 27.2MB/s \n",
"\u001b[?25h Building wheel for jax (setup.py) ... \u001b[?25ldone\n",
"\u001b[?25h Building wheel for opt-einsum (setup.py) ... \u001b[?25ldone\n",
"\u001b[?25h"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "i0ugCIZVZYAP",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "69bfa987-9b62-4808-bf9c-a5c96d400da4"
},
"cell_type": "code",
"source": [
"import jax\n",
"jax.__version__"
],
"execution_count": 9,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"'0.1.25'"
]
},
"metadata": {
"tags": []
},
"execution_count": 9
}
]
},
{
"metadata": {
"colab_type": "code",
"id": "SY8mDvEvCGqk",
"colab": {}
},
"cell_type": "code",
"source": [
"from __future__ import print_function, division\n",
"import jax.numpy as np\n",
"from jax import grad, jit, vmap\n",
"from jax import random\n",
"from jax.experimental import stax"
],
"execution_count": 0,
"outputs": []
},
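{
"metadata": {
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"As a quick sanity check of the install (an illustrative sketch, not part of the original bug report), `grad` and `jit` from the imports above compose as the quickstart describes:"
]
},
{
"metadata": {
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"# Illustrative: differentiate a simple scalar function, then jit-compile the gradient.\n",
"def loss(x):\n",
"  return np.sum(np.tanh(x) ** 2)\n",
"\n",
"grad_loss = jit(grad(loss))  # autodiff composed with XLA compilation\n",
"print(grad_loss(np.ones(3)))"
],
"execution_count": 0,
"outputs": []
},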
{
"metadata": {
"id": "n5dVaHWLXstM",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Building a single conv layer with None in the input_shape fails"
]
},
{
"metadata": {
"id": "O6E8bs1XXRGc",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"input_shape = (None, 32, 32, 3)\n",
"\n",
"init_fn, predict_fn = stax.Conv(16, (3, 3), padding='SAME')"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "8q065lnAXfs4",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"rng_key = random.PRNGKey(0)"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "2rixAURJXd3e",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 321
},
"outputId": "de239da1-5e78-4118-e70b-c99e1d7d98d8"
},
"cell_type": "code",
"source": [
"output_shape, params = init_fn(rng_key, input_shape)\n",
"print(output_shape)"
],
"execution_count": 16,
"outputs": [
{
"output_type": "error",
"ename": "AttributeError",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-16-ffbfb75734ac>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0moutput_shape\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minit_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrng_key\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_shape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput_shape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/jax/experimental/stax.py\u001b[0m in \u001b[0;36minit_fun\u001b[0;34m(rng, input_shape)\u001b[0m\n\u001b[1;32m 111\u001b[0m next(filter_shape_iter) for c in rhs_spec]\n\u001b[1;32m 112\u001b[0m output_shape = lax.conv_general_shape_tuple(\n\u001b[0;32m--> 113\u001b[0;31m input_shape, kernel_shape, strides, padding, dimension_numbers)\n\u001b[0m\u001b[1;32m 114\u001b[0m \u001b[0mbias_shape\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mout_chan\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mc\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'C'\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;36m1\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mc\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mout_spec\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[0mbias_shape\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtuple\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mitertools\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdropwhile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbias_shape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/jax/lax.py\u001b[0m in \u001b[0;36mconv_general_shape_tuple\u001b[0;34m(lhs_shape, rhs_shape, window_strides, padding, dimension_numbers)\u001b[0m\n\u001b[1;32m 4300\u001b[0m \u001b[0mlhs_trans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0monp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtake\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlhs_shape\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlhs_perm\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4301\u001b[0m \u001b[0mrhs_trans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0monp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtake\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrhs_shape\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrhs_perm\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 4302\u001b[0;31m \u001b[0mout_trans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mconv_shape_tuple\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlhs_trans\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrhs_trans\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwindow_strides\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpadding\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4303\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mtuple\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0monp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtake\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout_trans\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0monp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0margsort\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout_perm\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4304\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/jax/lax.py\u001b[0m in \u001b[0;36mconv_shape_tuple\u001b[0;34m(lhs_shape, rhs_shape, strides, pads)\u001b[0m\n\u001b[1;32m 4282\u001b[0m \u001b[0;34m\"\"\"Compute the shape tuple of a conv given input shapes in canonical order.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4283\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpads\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 4284\u001b[0;31m \u001b[0mpads\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpadtype_to_pads\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlhs_shape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrhs_shape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstrides\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpads\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4285\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpads\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlhs_shape\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4286\u001b[0m \u001b[0mmsg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\"Wrong number of explicit pads for convolution: expected {}, got {}.\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/jax/lax.py\u001b[0m in \u001b[0;36mpadtype_to_pads\u001b[0;34m(in_shape, window_shape, window_strides, padding)\u001b[0m\n\u001b[1;32m 4227\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4228\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mpadding\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mPaddingType\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSAME\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 4229\u001b[0;31m \u001b[0mout_shape\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0monp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mceil\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0monp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrue_divide\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0min_shape\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwindow_strides\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mastype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4230\u001b[0m pad_sizes = [_max((out_size - 1) * stride + window_shape - in_size, 0)\n\u001b[1;32m 4231\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mout_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstride\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwindow_shape\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_size\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mAttributeError\u001b[0m: 'float' object has no attribute 'ceil'"
]
}
]
},
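{
"metadata": {
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"The traceback points at a likely root cause (inferred from the trace above, not confirmed against the JAX source): with `None` in the shape, `onp.take` returns a `dtype=object` array, `onp.true_divide` on that array yields Python floats, and `onp.ceil` on an object array looks up a `.ceil()` method on each element, which Python floats lack. A minimal repro of that numpy behavior:"
]
},
{
"metadata": {
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"import numpy as onp\n",
"\n",
"# Illustrative repro of the numpy behavior behind the traceback.\n",
"lhs_trans = onp.take(onp.array((None, 32, 32, 3)), (0, 3, 1, 2))  # dtype=object\n",
"spatial = lhs_trans[2:]                    # object array: [32 32]\n",
"ratio = onp.true_divide(spatial, (1, 1))   # elementwise division -> Python floats\n",
"onp.ceil(ratio)                            # AttributeError: 'float' object has no attribute 'ceil'"
],
"execution_count": 0,
"outputs": []
},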
{
"metadata": {
"id": "FFS-bL3XXwVd",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## The same layer with an integer batch_size works"
]
},
{
"metadata": {
"id": "NuLwbYxrXyey",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "820a1bcc-e54a-426f-bd49-9c4b863985e0"
},
"cell_type": "code",
"source": [
"input_shape = (128, 32, 32, 3)\n",
"init_fn, predict_fn = stax.Conv(16, (3, 3), padding='SAME')\n",
"output_shape, params = init_fn(rng_key, input_shape)\n",
"\n",
"print(output_shape)"
],
"execution_count": 17,
"outputs": [
{
"output_type": "stream",
"text": [
"(128, 32, 32, 16)\n"
],
"name": "stdout"
}
]
},
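{
"metadata": {
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"A possible workaround (a sketch, not part of the original report): the initialized parameters do not depend on the batch dimension, so you can initialize with a dummy batch size of 1 and still apply `predict_fn` to a batch of any size."
]
},
{
"metadata": {
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"# Workaround sketch: init with a dummy batch size; the conv params\n",
"# (kernel and bias) are independent of the batch dimension.\n",
"dummy_shape = (1, 32, 32, 3)\n",
"_, params = init_fn(rng_key, dummy_shape)\n",
"\n",
"# Apply the layer to a batch of a different size.\n",
"batch = random.normal(rng_key, (128, 32, 32, 3))\n",
"out = predict_fn(params, batch)\n",
"print(out.shape)  # expected: (128, 32, 32, 16)"
],
"execution_count": 0,
"outputs": []
}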
]
}