Skip to content

Instantly share code, notes, and snippets.

@fpgaminer
Created June 8, 2020 21:23
Show Gist options
  • Save fpgaminer/95f8df092c1d0154df5d970a2f82b07f to your computer and use it in GitHub Desktop.
Save fpgaminer/95f8df092c1d0154df5d970a2f82b07f to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Location of the generator module loaded by the cells below.\n",
"# The active value is a path on the local filesystem (presumably a downloaded\n",
"# copy of biggan-deep-128/1 -- confirm). Any of the TF Hub URLs listed here\n",
"# can be assigned instead to pick a different model or resolution.\n",
"module_path = '/home/pi/biggan-deep-128_1'\n",
"\n",
"# BigGAN-deep models\n",
"# module_path = 'https://tfhub.dev/deepmind/biggan-deep-128/1' # 128x128 BigGAN-deep\n",
"# module_path = 'https://tfhub.dev/deepmind/biggan-deep-256/1' # 256x256 BigGAN-deep\n",
"# module_path = 'https://tfhub.dev/deepmind/biggan-deep-512/1' # 512x512 BigGAN-deep\n",
"\n",
"# BigGAN (original) models\n",
"# module_path = 'https://tfhub.dev/deepmind/biggan-128/2' # 128x128 BigGAN\n",
"# module_path = 'https://tfhub.dev/deepmind/biggan-256/2' # 256x256 BigGAN\n",
"# module_path = 'https://tfhub.dev/deepmind/biggan-512/2' # 512x512 BigGAN"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow.compat.v1 as tf\n",
"tf.disable_v2_behavior()\n",
"\n",
"import io\n",
"import IPython.display\n",
"import numpy as np\n",
"import PIL.Image\n",
"from scipy.stats import truncnorm\n",
"import tensorflow_hub as hub\n",
"import time\n",
"import os\n",
"import random"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Clear the default graph before building the generator into it.\n",
"tf.reset_default_graph()\n",
"print('Loading BigGAN module from:', module_path)\n",
"module = hub.Module(module_path)\n",
"\n",
"# One placeholder per input the module declares, keyed and named by input name.\n",
"inputs = {\n",
"    name: tf.placeholder(info.dtype, info.get_shape().as_list(), name)\n",
"    for name, info in module.get_input_info_dict().items()\n",
"}\n",
"output = module(inputs)\n",
"\n",
"# Report what was wired up.\n",
"print()\n",
"print('Inputs:\\n', '\\n'.join(\n",
"    ' {}: {}'.format(key, placeholder) for key, placeholder in inputs.items()))\n",
"print()\n",
"print('Output:', output)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Handles to the three placeholders the generator is fed through.\n",
"input_z = inputs['z']\n",
"input_y = inputs['y']\n",
"input_trunc = inputs['truncation']\n",
"\n",
"# Size of the second axis of the z and y placeholders; the helpers below use\n",
"# them as the latent-vector width and the label-vector width.\n",
"dim_z = input_z.get_shape().as_list()[1]\n",
"vocab_size = input_y.get_shape().as_list()[1]\n",
"\n",
"def truncated_z_sample(batch_size, truncation=1., seed=None):\n",
"    \"\"\"Draw a (batch_size, dim_z) array of truncated-normal latent values.\n",
"\n",
"    Values are sampled from a standard normal truncated to [-2, 2] and then\n",
"    multiplied by `truncation`. Passing a non-None `seed` makes the draw\n",
"    repeatable; with `seed=None` SciPy's default random state is used.\n",
"    \"\"\"\n",
"    if seed is None:\n",
"        rng = None\n",
"    else:\n",
"        rng = np.random.RandomState(seed)\n",
"    draws = truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=rng)\n",
"    return truncation * draws\n",
"\n",
"def one_hot(index, vocab_size=vocab_size):\n",
"    \"\"\"Return float32 one-hot rows for a class index or a 1-D list of them.\n",
"\n",
"    A scalar index is treated as a one-element list. The result has shape\n",
"    (number of indices, vocab_size).\n",
"    \"\"\"\n",
"    index = np.asarray(index)\n",
"    if index.ndim == 0:\n",
"        # Promote a bare scalar to a length-1 vector.\n",
"        index = np.asarray([index])\n",
"    assert len(index.shape) == 1\n",
"    count = index.shape[0]\n",
"    encoded = np.zeros((count, vocab_size), dtype=np.float32)\n",
"    encoded[np.arange(count), index] = 1\n",
"    return encoded\n",
"\n",
"def one_hot_if_needed(label, vocab_size=vocab_size):\n",
"    \"\"\"Return `label` as a 2-D array, one-hot encoding it when it is not.\n",
"\n",
"    Scalar and 1-D inputs are taken to be class indices and go through\n",
"    `one_hot`; 2-D inputs are returned unchanged.\n",
"    \"\"\"\n",
"    label = np.asarray(label)\n",
"    if label.ndim < 2:\n",
"        label = one_hot(label, vocab_size)\n",
"    assert len(label.shape) == 2\n",
"    return label\n",
"\n",
"def sample(sess, noise, label, truncation=1., batch_size=8,\n",
"           vocab_size=vocab_size):\n",
"    \"\"\"Run the generator on `noise`/`label` and return uint8 images.\n",
"\n",
"    `label` may be a single class index (reused for every noise row), a 1-D\n",
"    list of indices, or a 2-D array of label vectors. Inputs are fed to the\n",
"    graph `batch_size` rows at a time and the outputs are concatenated.\n",
"\n",
"    Raises:\n",
"        ValueError: if the number of labels differs from the number of\n",
"            noise rows.\n",
"    \"\"\"\n",
"    noise = np.asarray(noise)\n",
"    label = np.asarray(label)\n",
"    num = noise.shape[0]\n",
"    if label.ndim == 0:\n",
"        # A lone class index applies to every sample.\n",
"        label = np.asarray([label] * num)\n",
"    if label.shape[0] != num:\n",
"        raise ValueError('Got # noise samples ({}) != # label samples ({})'\n",
"                         .format(noise.shape[0], label.shape[0]))\n",
"    label = one_hot_if_needed(label, vocab_size)\n",
"\n",
"    batches = []\n",
"    for start in range(0, num, batch_size):\n",
"        stop = min(num, start + batch_size)\n",
"        batches.append(sess.run(output, feed_dict={\n",
"            input_z: noise[start:stop],\n",
"            input_y: label[start:stop],\n",
"            input_trunc: truncation,\n",
"        }))\n",
"    ims = np.concatenate(batches, axis=0)\n",
"    assert ims.shape[0] == num\n",
"\n",
"    # Shift and scale the generator output into the 0..255 byte range.\n",
"    scaled = np.clip(((ims + 1) / 2.0) * 256, 0, 255)\n",
"    return np.uint8(scaled)\n",
"\n",
"def interpolate(A, B, num_interps):\n",
"    \"\"\"Linearly blend A into B over `num_interps` evenly spaced steps.\n",
"\n",
"    The blend weights run from 0 to 1 inclusive, and the blended arrays are\n",
"    stacked along a new leading axis.\n",
"\n",
"    Raises:\n",
"        ValueError: if A and B differ in shape.\n",
"    \"\"\"\n",
"    if A.shape != B.shape:\n",
"        raise ValueError('A and B must have the same shape to interpolate.')\n",
"    blends = []\n",
"    for weight in np.linspace(0, 1, num_interps):\n",
"        blends.append((1 - weight) * A + weight * B)\n",
"    return np.array(blends)\n",
"\n",
"def imgrid(imarray, cols=5, pad=1):\n",
"    \"\"\"Tile a batch of uint8 images (N, H, W, C) into one grid image.\n",
"\n",
"    Images fill the grid row by row, `cols` per row, separated by `pad`\n",
"    pixels of value 255. Unused cells in the last row are filled with 255.\n",
"\n",
"    Raises:\n",
"        ValueError: if `imarray` is not uint8.\n",
"    \"\"\"\n",
"    if imarray.dtype != np.uint8:\n",
"        raise ValueError('imgrid input imarray must be uint8')\n",
"    pad = int(pad)\n",
"    assert pad >= 0\n",
"    cols = int(cols)\n",
"    assert cols >= 1\n",
"\n",
"    count, height, width, channels = imarray.shape\n",
"    rows = -(-count // cols)  # ceiling division\n",
"    filler = rows * cols - count\n",
"    assert filler >= 0\n",
"\n",
"    # Pad only at the end of each axis: extra blank images to complete the\n",
"    # last row, plus `pad` pixels below and to the right of every image.\n",
"    padded = np.pad(\n",
"        imarray,\n",
"        [[0, filler], [0, pad], [0, pad], [0, 0]],\n",
"        'constant', constant_values=255)\n",
"    cell_h = height + pad\n",
"    cell_w = width + pad\n",
"\n",
"    grid = padded.reshape(rows, cols, cell_h, cell_w, channels)\n",
"    grid = grid.transpose(0, 2, 1, 3, 4)\n",
"    grid = grid.reshape(rows * cell_h, cols * cell_w, channels)\n",
"    if pad:\n",
"        # Drop the padding strip along the bottom and right edges.\n",
"        grid = grid[:-pad, :-pad]\n",
"    return grid\n",
"\n",
"def imshow(a, format='png', jpeg_fallback=True):\n",
"    \"\"\"Encode array `a` as an image and display it inline in the notebook.\n",
"\n",
"    The array is cast to uint8 and encoded with PIL in `format`. If\n",
"    displaying raises IOError and `jpeg_fallback` is set, one retry is made\n",
"    with JPEG; otherwise the IOError propagates.\n",
"    \"\"\"\n",
"    a = np.asarray(a, dtype=np.uint8)\n",
"    buffer = io.BytesIO()\n",
"    PIL.Image.fromarray(a).save(buffer, format)\n",
"    encoded = buffer.getvalue()\n",
"    try:\n",
"        return IPython.display.display(IPython.display.Image(encoded))\n",
"    except IOError:\n",
"        if not jpeg_fallback or format == 'jpeg':\n",
"            raise\n",
"        print(('Warning: image was too large to display in format \"{}\"; '\n",
"               'trying jpeg instead.').format(format))\n",
"        return imshow(a, format='jpeg')\n",
"\n",
"\n",
"def display_image_to_framebuffer(im):\n",
"    \"\"\"Write image array `im` to the Linux framebuffer device /dev/fb0.\n",
"\n",
"    The image is resized to 480x480, its last axis is reversed, 160 columns\n",
"    of zeros are added on the left and right, and one extra channel filled\n",
"    with 255 is appended before the raw bytes are written out.\n",
"    NOTE(review): this presumably targets an 800x480, 32-bit BGRA display\n",
"    -- confirm against the actual framebuffer mode.\n",
"    \"\"\"\n",
"    resized = PIL.Image.fromarray(im).resize((480, 480))\n",
"    pixels = np.array(resized)[..., ::-1]\n",
"    frame = np.pad(\n",
"        pixels,\n",
"        ((0, 0), (160, 160), (0, 1)),\n",
"        constant_values=((0, 0), (0, 0), (255, 255)))\n",
"    frame.tofile(\"/dev/fb0\")\n",
"\n",
" \n",
"def create_labels(num, max_classes=3):\n",
"    \"\"\"Build `num` random soft label vectors, each of length `vocab_size`.\n",
"\n",
"    For every row, between 1 and `max_classes` class indices (inclusive) are\n",
"    drawn uniformly from the whole vocabulary and each is given a random\n",
"    weight; the row is then normalised to sum to 1. The same index may be\n",
"    drawn more than once, so a row can end up with fewer distinct classes\n",
"    than draws.\n",
"\n",
"    Returns:\n",
"        A float array of shape (num, vocab_size).\n",
"    \"\"\"\n",
"    label = np.zeros((num, vocab_size))\n",
"    for row in label:\n",
"        # np.random.randint excludes its upper bound: use max_classes + 1 so\n",
"        # that max_classes itself is reachable (and max_classes=1 is valid),\n",
"        # and vocab_size so that the last class can be selected.\n",
"        for _ in range(np.random.randint(1, max_classes + 1)):\n",
"            j = np.random.randint(0, vocab_size)\n",
"            row[j] = np.random.random()\n",
"        # `row` is a view into `label`, so this normalises in place.\n",
"        row /= row.sum()\n",
"    return label\n",
"\n",
"\n",
"def create_mutation(vector, label):\n",
"    \"\"\"Return a randomly perturbed copy of a latent vector and label vector.\n",
"\n",
"    `vector` and `label` are 1-D; both results carry a leading axis of\n",
"    length 1. The latent vector gets uniform noise scaled by four times its\n",
"    standard deviation and is rescaled by its largest-magnitude component.\n",
"    The label may have one class shrunk (20% chance) and one class boosted\n",
"    (30% chance); weights under 0.02 are dropped and the rest renormalised\n",
"    to sum to 1.\n",
"    \"\"\"\n",
"    vector_mutation_rate = vector.std() * 4\n",
"    jitter = (np.random.rand(*vector.shape) - 0.5) * vector_mutation_rate\n",
"\n",
"    new_vector = np.zeros((1, vector.shape[0]))\n",
"    new_vector[0] = vector + jitter\n",
"    # Divide by the largest absolute component.\n",
"    new_vector[0] /= max(-new_vector[0].min(), new_vector[0].max())\n",
"\n",
"    new_label = np.zeros((1, label.shape[0]))\n",
"    new_label[0][:] = label\n",
"    weights = new_label[0]  # view: edits below land in new_label\n",
"\n",
"    # Reduce class (skipped when exactly one class is present).\n",
"    if random.random() < 0.2:\n",
"        present = np.nonzero(weights)[0]\n",
"        if len(present) != 1:\n",
"            shrink_at = random.choice(present)\n",
"            weights[shrink_at] *= 0.2 + random.random() * 0.6\n",
"\n",
"    # Add class.\n",
"    if random.random() < 0.3:\n",
"        boost_at = random.randint(0, label.shape[0] - 1)\n",
"        weights[boost_at] += random.random() * 0.5\n",
"\n",
"    # Remove if less than two percent.\n",
"    new_label[new_label < .02] = 0\n",
"\n",
"    # Normalize.\n",
"    weights /= weights.sum()\n",
"\n",
"    return new_vector, new_label\n",
"\n",
"\n",
"def interpolate_and_shape(A, B, num_interps):\n",
"    \"\"\"Interpolate between A and B and flatten the result sample-major.\n",
"\n",
"    `interpolate` returns an array shaped (num_interps, num_samples, ...).\n",
"    The first two axes are swapped and merged, giving an array shaped\n",
"    (num_samples * num_interps, ...) in which all steps for the first sample\n",
"    come first, then all steps for the second, and so on.\n",
"    \"\"\"\n",
"    interps = interpolate(A, B, num_interps)\n",
"    # The sample count is read from the data; no global of that name exists\n",
"    # in this notebook.\n",
"    num_samples = interps.shape[1]\n",
"    return (interps.transpose(1, 0, *range(2, len(interps.shape)))\n",
"            .reshape(num_samples * num_interps, *interps.shape[2:]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Build the variable-initialisation op, open a session on the default graph,\n",
"# and run the op once. `sess` is the session the cells below sample with.\n",
"initializer = tf.global_variables_initializer()\n",
"sess = tf.Session()\n",
"sess.run(initializer)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Main Loop"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Endless random walk through the latent space: choose a random target,\n",
"# blend from the current point towards it frame by frame, then treat the\n",
"# target as the new starting point and choose another one.\n",
"# This should explore the latent space fairly well.\n",
"truncation = 0.5\n",
"num_interps = 16\n",
"\n",
"# endpoint=False keeps the weight below 1, so the target itself is first\n",
"# shown as the weight-0 frame of the following leg rather than twice.\n",
"alphas = np.linspace(0, 1, num_interps, endpoint=False)\n",
"\n",
"# Starting point of the walk.\n",
"z_A = truncated_z_sample(1, truncation)\n",
"y_A = create_labels(1, random.randrange(3,9))\n",
"\n",
"while True:\n",
"    # Pick a new target point in the latent space\n",
"    z_B = truncated_z_sample(1, truncation)\n",
"    y_B = create_labels(1, random.randrange(3,9))\n",
"\n",
"    # Interpolate towards the new point\n",
"    for a in alphas:\n",
"        z = (1-a)*z_A + a*z_B\n",
"        y = (1-a)*y_A + a*y_B\n",
"\n",
"        ims = sample(sess, z, y, truncation=truncation)\n",
"        display_image_to_framebuffer(ims[0])\n",
"\n",
"        # Helps keep CPU cool by giving it a break\n",
"        time.sleep(4)\n",
"\n",
"    # We've reached the target, set it as the new start point and repeat\n",
"    z_A, y_A = z_B, y_B"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Alternative main loop: start from one random point and keep mutating it,\n",
"# showing each mutated result on the framebuffer.\n",
"truncation = 0.5\n",
"\n",
"z = truncated_z_sample(1, truncation)\n",
"y = create_labels(1, 3)\n",
"\n",
"while True:\n",
"    # create_mutation takes 1-D inputs and returns arrays with a leading\n",
"    # axis of length 1, hence the [0] on the way in.\n",
"    z, y = create_mutation(z[0], y[0])\n",
"\n",
"    ims = sample(sess, z, y, truncation=truncation)\n",
"    display_image_to_framebuffer(ims[0])\n",
"\n",
"    # Pause between frames.\n",
"    time.sleep(4)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment