-
-
Save fpgaminer/95f8df092c1d0154df5d970a2f82b07f to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# BigGAN-deep models\n", | |
"# module_path may be a local directory or a TF Hub URL. The default is a local directory;\n", | |
"# set the BIGGAN_MODULE_PATH environment variable to load a different module without\n", | |
"# editing the notebook. The commented-out URLs below are ready-made alternatives.\n", | |
"import os\n", | |
"module_path = os.environ.get('BIGGAN_MODULE_PATH', '/home/pi/biggan-deep-128_1')\n", | |
"# module_path = 'https://tfhub.dev/deepmind/biggan-deep-128/1'  # 128x128 BigGAN-deep\n", | |
"# module_path = 'https://tfhub.dev/deepmind/biggan-deep-256/1'  # 256x256 BigGAN-deep\n", | |
"# module_path = 'https://tfhub.dev/deepmind/biggan-deep-512/1'  # 512x512 BigGAN-deep\n", | |
"\n", | |
"# BigGAN (original) models\n", | |
"# module_path = 'https://tfhub.dev/deepmind/biggan-128/2'  # 128x128 BigGAN\n", | |
"# module_path = 'https://tfhub.dev/deepmind/biggan-256/2'  # 256x256 BigGAN\n", | |
"# module_path = 'https://tfhub.dev/deepmind/biggan-512/2'  # 512x512 BigGAN" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import tensorflow.compat.v1 as tf\n", | |
"tf.disable_v2_behavior()\n", | |
"\n", | |
"import io\n", | |
"import IPython.display\n", | |
"import numpy as np\n", | |
"import PIL.Image\n", | |
"from scipy.stats import truncnorm\n", | |
"import tensorflow_hub as hub\n", | |
"import time\n", | |
"import os\n", | |
"import random" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Load the generator into a fresh graph and create one placeholder per module\n", | |
"# input, mirroring that input's dtype, shape and name.\n", | |
"tf.reset_default_graph()\n", | |
"print('Loading BigGAN module from:', module_path)\n", | |
"module = hub.Module(module_path)\n", | |
"inputs = {\n", | |
"    name: tf.placeholder(info.dtype, info.get_shape().as_list(), name)\n", | |
"    for name, info in module.get_input_info_dict().items()\n", | |
"}\n", | |
"output = module(inputs)\n", | |
"\n", | |
"# Summarise the wiring: one line per input placeholder, then the output tensor.\n", | |
"print()\n", | |
"print('Inputs:\\n', '\\n'.join(' {}: {}'.format(name, tensor)\n", | |
"                              for name, tensor in inputs.items()))\n", | |
"print()\n", | |
"print('Output:', output)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Unpack the module's input placeholders for z, y and truncation.\n", | |
"input_z, input_y, input_trunc = inputs['z'], inputs['y'], inputs['truncation']\n", | |
"\n", | |
"# Latent width and number of classes, read from axis 1 of the z and y placeholders.\n", | |
"dim_z, vocab_size = (t.shape.as_list()[1] for t in (input_z, input_y))\n", | |
"\n", | |
"def truncated_z_sample(batch_size, truncation=1., seed=None):\n", | |
"    \"\"\"Draw latent vectors of shape (batch_size, dim_z) from a normal truncated to [-2, 2].\n", | |
"\n", | |
"    The draw is multiplied by `truncation`. A non-None `seed` makes it reproducible\n", | |
"    through a dedicated RandomState; with seed=None scipy's default random state is used.\n", | |
"    \"\"\"\n", | |
"    rng = np.random.RandomState(seed) if seed is not None else None\n", | |
"    return truncation * truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=rng)\n", | |
"\n", | |
"def one_hot(index, vocab_size=vocab_size):\n", | |
"    \"\"\"Return a float32 one-hot matrix of shape (len(index), vocab_size).\n", | |
"\n", | |
"    A scalar `index` is treated as a batch of one; otherwise it must be 1-D.\n", | |
"    \"\"\"\n", | |
"    idx = np.atleast_1d(np.asarray(index))\n", | |
"    assert idx.ndim == 1\n", | |
"    rows = idx.shape[0]\n", | |
"    encoded = np.zeros((rows, vocab_size), dtype=np.float32)\n", | |
"    encoded[np.arange(rows), idx] = 1\n", | |
"    return encoded\n", | |
"\n", | |
"def one_hot_if_needed(label, vocab_size=vocab_size):\n", | |
"    \"\"\"Return `label` as a 2-D matrix, one-hot encoding it when it is a scalar or 1-D ids.\"\"\"\n", | |
"    label = np.asarray(label)\n", | |
"    encoded = label if label.ndim > 1 else one_hot(label, vocab_size)\n", | |
"    assert encoded.ndim == 2\n", | |
"    return encoded\n", | |
"\n", | |
"def sample(sess, noise, label, truncation=1., batch_size=8,\n", | |
"           vocab_size=vocab_size):\n", | |
"    \"\"\"Run the generator in mini-batches and return the images as uint8.\n", | |
"\n", | |
"    Args:\n", | |
"      sess: tf.Session for the graph that holds `output`.\n", | |
"      noise: latent vectors, one per row.\n", | |
"      label: a scalar class id (repeated for every row), 1-D class ids, or a\n", | |
"        2-D matrix of per-row class weights.\n", | |
"      truncation: value fed to the module's truncation input.\n", | |
"      batch_size: rows evaluated per sess.run call.\n", | |
"      vocab_size: number of classes used when one-hot encoding ids.\n", | |
"\n", | |
"    Returns:\n", | |
"      uint8 array with one image per noise row. Values are mapped via\n", | |
"      (x + 1) / 2 * 256 and clipped to [0, 255] (assumes the generator emits\n", | |
"      values in [-1, 1] -- TODO confirm).\n", | |
"\n", | |
"    Raises:\n", | |
"      ValueError: if the number of labels differs from the number of noise rows.\n", | |
"    \"\"\"\n", | |
"    noise = np.asarray(noise)\n", | |
"    label = np.asarray(label)\n", | |
"    num = noise.shape[0]\n", | |
"    if label.ndim == 0:\n", | |
"        label = np.asarray([label] * num)\n", | |
"    if label.shape[0] != num:\n", | |
"        raise ValueError('Got # noise samples ({}) != # label samples ({})'\n", | |
"                         .format(noise.shape[0], label.shape[0]))\n", | |
"    label = one_hot_if_needed(label, vocab_size)\n", | |
"\n", | |
"    batches = [\n", | |
"        sess.run(output, feed_dict={\n", | |
"            input_z: noise[start:start + batch_size],\n", | |
"            input_y: label[start:start + batch_size],\n", | |
"            input_trunc: truncation,\n", | |
"        })\n", | |
"        for start in range(0, num, batch_size)\n", | |
"    ]\n", | |
"    ims = np.concatenate(batches, axis=0)\n", | |
"    assert ims.shape[0] == num\n", | |
"    return np.uint8(np.clip((ims + 1) / 2.0 * 256, 0, 255))\n", | |
"\n", | |
"def interpolate(A, B, num_interps):\n", | |
"    \"\"\"Blend linearly from A to B in num_interps evenly spaced steps, endpoints included.\n", | |
"\n", | |
"    The frames are stacked along a new leading axis. Raises ValueError when A and\n", | |
"    B differ in shape.\n", | |
"    \"\"\"\n", | |
"    if A.shape != B.shape:\n", | |
"        raise ValueError('A and B must have the same shape to interpolate.')\n", | |
"    frames = []\n", | |
"    for alpha in np.linspace(0, 1, num_interps):\n", | |
"        frames.append((1 - alpha) * A + alpha * B)\n", | |
"    return np.array(frames)\n", | |
"\n", | |
"def imgrid(imarray, cols=5, pad=1):\n", | |
"    \"\"\"Tile a uint8 image batch of shape (N, H, W, C) into a single grid image.\n", | |
"\n", | |
"    Images fill the grid row by row, `cols` per row. Unused trailing cells and\n", | |
"    the `pad`-pixel gutters between cells are filled with 255; the gutter on the\n", | |
"    far right and bottom edges is trimmed off.\n", | |
"    \"\"\"\n", | |
"    if imarray.dtype != np.uint8:\n", | |
"        raise ValueError('imgrid input imarray must be uint8')\n", | |
"    pad = int(pad)\n", | |
"    assert pad >= 0\n", | |
"    cols = int(cols)\n", | |
"    assert cols >= 1\n", | |
"    n, h, w, c = imarray.shape\n", | |
"    rows = -(-n // cols)  # ceiling division\n", | |
"    filler = rows * cols - n\n", | |
"    assert filler >= 0\n", | |
"    padded = np.pad(imarray, [[0, filler], [0, pad], [0, pad], [0, 0]],\n", | |
"                    'constant', constant_values=255)\n", | |
"    cell_h, cell_w = h + pad, w + pad\n", | |
"    tiles = padded.reshape(rows, cols, cell_h, cell_w, c).transpose(0, 2, 1, 3, 4)\n", | |
"    grid = tiles.reshape(rows * cell_h, cols * cell_w, c)\n", | |
"    return grid[:-pad, :-pad] if pad else grid\n", | |
"\n", | |
"def imshow(a, format='png', jpeg_fallback=True):\n", | |
"    \"\"\"Encode `a` as a `format` image and display it inline with IPython.\n", | |
"\n", | |
"    If display raises IOError and jpeg_fallback is true (and format is not\n", | |
"    already 'jpeg'), a warning is printed and the image is retried as JPEG;\n", | |
"    otherwise the error propagates. Returns the result of IPython's display().\n", | |
"    \"\"\"\n", | |
"    a = np.asarray(a, dtype=np.uint8)\n", | |
"    buffer = io.BytesIO()\n", | |
"    PIL.Image.fromarray(a).save(buffer, format)\n", | |
"    image_bytes = buffer.getvalue()\n", | |
"    try:\n", | |
"        return IPython.display.display(IPython.display.Image(image_bytes))\n", | |
"    except IOError:\n", | |
"        if not jpeg_fallback or format == 'jpeg':\n", | |
"            raise\n", | |
"        print(('Warning: image was too large to display in format \"{}\"; '\n", | |
"               'trying jpeg instead.').format(format))\n", | |
"        return imshow(a, format='jpeg')\n", | |
"\n", | |
"\n", | |
"def display_image_to_framebuffer(im, fb_path='/dev/fb0', size=480, screen_width=800):\n", | |
"    \"\"\"Write image `im` as raw bytes to a framebuffer device, centred with black side bars.\n", | |
"\n", | |
"    The image is resized to size x size, its channel order is reversed (RGB -> BGR\n", | |
"    for RGB input), it is padded left and right with zeros up to screen_width\n", | |
"    columns, and a fourth channel filled with 255 is appended. The defaults\n", | |
"    reproduce the original hard-coded 800-column, 480-row layout.\n", | |
"    NOTE(review): assumes the framebuffer expects exactly this geometry with\n", | |
"    4 bytes per pixel in B, G, R, X/A order -- confirm for other displays.\n", | |
"\n", | |
"    Raises:\n", | |
"      ValueError: if screen_width is smaller than size.\n", | |
"    \"\"\"\n", | |
"    if screen_width < size:\n", | |
"        raise ValueError('screen_width must be at least size')\n", | |
"    a = np.array(PIL.Image.fromarray(im).resize((size, size)))\n", | |
"    a = a[..., ::-1]  # reverse channel order\n", | |
"    left = (screen_width - size) // 2\n", | |
"    right = screen_width - size - left\n", | |
"    # Side bars are 0 (black); the appended fourth channel is 255.\n", | |
"    a = np.pad(a, ((0, 0), (left, right), (0, 1)),\n", | |
"               constant_values=((0, 0), (0, 0), (255, 255)))\n", | |
"    a.tofile(fb_path)\n", | |
"\n", | |
" \n", | |
"def create_labels(num, max_classes=3, vocab_size=vocab_size):\n", | |
"    \"\"\"Return a (num, vocab_size) matrix of random class mixtures; each row sums to 1.\n", | |
"\n", | |
"    Each row mixes between 1 and max_classes (inclusive) distinct classes drawn\n", | |
"    uniformly from all vocab_size classes, with random positive weights.\n", | |
"\n", | |
"    Raises:\n", | |
"      ValueError: if max_classes is not between 1 and vocab_size.\n", | |
"    \"\"\"\n", | |
"    if not 1 <= max_classes <= vocab_size:\n", | |
"        raise ValueError('max_classes must be between 1 and vocab_size')\n", | |
"    label = np.zeros((num, vocab_size))\n", | |
"    for row in label:  # rows are views, so in-place edits update `label`\n", | |
"        # np.random.randint excludes its upper bound, hence the + 1.\n", | |
"        k = np.random.randint(1, max_classes + 1)\n", | |
"        classes = np.random.choice(vocab_size, size=k, replace=False)\n", | |
"        # 1 - random() lies in (0, 1], so a row can never sum to zero.\n", | |
"        row[classes] = 1.0 - np.random.random(k)\n", | |
"        row /= row.sum()\n", | |
"    return label\n", | |
"\n", | |
"\n", | |
"def create_mutation(vector, label):\n", | |
"    \"\"\"Randomly perturb a latent vector and a class-weight label.\n", | |
"\n", | |
"    Args:\n", | |
"      vector: 1-D latent vector.\n", | |
"      label: 1-D class weights (normally summing to 1).\n", | |
"\n", | |
"    Returns:\n", | |
"      (new_vector, new_label), each of shape (1, len(input)). new_vector is\n", | |
"      rescaled so its largest absolute entry is 1 (an all-zero vector is left\n", | |
"      as is); new_label is renormalised to sum to 1.\n", | |
"    \"\"\"\n", | |
"    # Uniform noise in [-2 * std, 2 * std) of the current vector.\n", | |
"    vector_mutation_rate = vector.std() * 4\n", | |
"    dv = (np.random.rand(*vector.shape) - 0.5) * vector_mutation_rate\n", | |
"    mutated = vector + dv\n", | |
"    scale = max(-mutated.min(), mutated.max())\n", | |
"    if scale > 0:\n", | |
"        mutated = mutated / scale\n", | |
"\n", | |
"    weights = np.array(label, dtype=float)\n", | |
"\n", | |
"    # Reduce class: shrink one active class to 20-80% of its weight, but only\n", | |
"    # when more than one class is active (random.choice fails on an empty set).\n", | |
"    if random.random() < 0.2:\n", | |
"        opts = np.nonzero(weights)[0]\n", | |
"        if len(opts) > 1:\n", | |
"            idx = random.choice(opts)\n", | |
"            weights[idx] *= 0.2 + random.random() * 0.6\n", | |
"\n", | |
"    # Add class: bump a random class (random.randint is inclusive) by up to 0.5.\n", | |
"    if random.random() < 0.3:\n", | |
"        idx = random.randint(0, label.shape[0] - 1)\n", | |
"        weights[idx] += random.random() * 0.5\n", | |
"\n", | |
"    # Remove classes below two percent, unless that would remove every class\n", | |
"    # and make the normalisation below divide by zero.\n", | |
"    pruned = np.where(weights < .02, 0.0, weights)\n", | |
"    if pruned.sum() > 0:\n", | |
"        weights = pruned\n", | |
"\n", | |
"    # Normalize.\n", | |
"    weights /= weights.sum()\n", | |
"\n", | |
"    return mutated[np.newaxis, :], weights[np.newaxis, :]\n", | |
"\n", | |
"\n", | |
"def interpolate_and_shape(A, B, num_interps):\n", | |
"    \"\"\"Interpolate between A and B row by row and flatten the result into one batch.\n", | |
"\n", | |
"    A and B have shape (num_samples, ...). The result has shape\n", | |
"    (num_samples * num_interps, ...): all num_interps frames for sample 0 first,\n", | |
"    then those for sample 1, and so on.\n", | |
"    \"\"\"\n", | |
"    interps = interpolate(A, B, num_interps)\n", | |
"    # Take the sample count from A; no global num_samples is defined in this notebook.\n", | |
"    num_samples = A.shape[0]\n", | |
"    return (interps.transpose(1, 0, *range(2, interps.ndim))\n", | |
"            .reshape(num_samples * num_interps, *interps.shape[2:]))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Open a session and initialise every global variable in the default graph\n", | |
"# (which now holds the BigGAN module).\n", | |
"sess = tf.Session()\n", | |
"initializer = tf.global_variables_initializer()\n", | |
"sess.run(initializer)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
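"# Main Loop\n", | |
"\n", | |
"The next two code cells are alternative display loops. Each runs forever, so run only the one you want and interrupt the kernel to stop it." | |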
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Random-walk display loop: pick a random target point, interpolate towards it,\n", | |
"# and on arrival pick another random target and continue from there. This should\n", | |
"# explore the latent space fairly well. Runs until the kernel is interrupted.\n", | |
"truncation = 0.5\n", | |
"num_interps = 16\n", | |
"frame_delay_s = 4  # pause after each frame; helps keep the CPU cool\n", | |
"\n", | |
"# endpoint=False: each target is shown as the first frame of the next leg instead\n", | |
"# of twice in a row.\n", | |
"alphas = np.linspace(0, 1, num_interps, endpoint=False)\n", | |
"\n", | |
"z_A, y_A = truncated_z_sample(1, truncation), create_labels(1, random.randrange(3, 9))\n", | |
"\n", | |
"while True:\n", | |
"    # Choose the next destination in latent and class space.\n", | |
"    z_B, y_B = truncated_z_sample(1, truncation), create_labels(1, random.randrange(3, 9))\n", | |
"\n", | |
"    for alpha in alphas:\n", | |
"        z = (1 - alpha) * z_A + alpha * z_B\n", | |
"        y = (1 - alpha) * y_A + alpha * y_B\n", | |
"        ims = sample(sess, z, y, truncation=truncation)\n", | |
"        display_image_to_framebuffer(ims[0])\n", | |
"        time.sleep(frame_delay_s)\n", | |
"\n", | |
"    # Arrived: the target becomes the new starting point.\n", | |
"    z_A, y_A = z_B, y_B" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Evolutionary display loop: start from a random point and repeatedly apply a\n", | |
"# small random mutation, showing each result. Runs until the kernel is interrupted.\n", | |
"truncation = 0.5\n", | |
"frame_delay_s = 4  # pause after each frame; helps keep the CPU cool\n", | |
"\n", | |
"z, y = truncated_z_sample(1, truncation), create_labels(1, 3)\n", | |
"\n", | |
"while True:\n", | |
"    # create_mutation takes 1-D inputs and returns arrays with a batch axis of 1.\n", | |
"    z, y = create_mutation(z[0], y[0])\n", | |
"    ims = sample(sess, z, y, truncation=truncation)\n", | |
"    display_image_to_framebuffer(ims[0])\n", | |
"    time.sleep(frame_delay_s)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.3" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment