Skip to content

Instantly share code, notes, and snippets.

@fxtentacle
Created April 16, 2020 09:21
Show Gist options
Save fxtentacle/0c19dc1ce013f4f98ff57b1261d4b644 to your computer and use it in GitHub Desktop.
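TensorFlow problem: tf-hang.ipynb, a notebook reproducing a hang in a tf.data pipeline that maps tf.py_function over a dataset with num_parallel_calls=4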
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "tf-hang.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "code",
"metadata": {
"id": "YhVmNy60MIQ-",
"colab_type": "code",
"colab": {}
},
"source": [
"#!/usr/bin/env python\n",
"# coding: utf-8\n",
"\n",
"from __future__ import absolute_import, division, print_function, unicode_literals\n",
"import random\n",
"import tensorflow as tf\n",
"\n",
"\n",
"@tf.function(input_signature=[tf.TensorSpec(shape=(2,), dtype=tf.int32)])\n",
"def advance_random_seed(seed):\n",
" s1, s2 = tf.unstack(seed, num=2)\n",
" halfmaxint = tf.constant(2 ** 30 - 1, tf.int32)\n",
" return tf.stack([s1, tf.truncatemod(s2+1, halfmaxint)])\n",
"\n",
"\n",
"def make_example(seed):\n",
" seed = advance_random_seed(seed)\n",
" target_exp = tf.random.stateless_uniform(shape=(), seed=seed, minval=0, maxval=6, dtype=tf.int32)\n",
" scale_factor = 2 ** target_exp\n",
" use_device = 'GPU:0'\n",
" if scale_factor > 16: use_device = 'CPU:0'\n",
" #tf.print(scale_factor, use_device)\n",
" with tf.device(use_device):\n",
" seed = advance_random_seed(seed)\n",
" input_data = tf.random.stateless_uniform(shape=(1,32*scale_factor,32*scale_factor,3), seed=seed, minval=0.0, maxval=1.0, dtype=tf.float32)\n",
" patches = tf.image.extract_patches(input_data, sizes=[1, scale_factor, scale_factor, 1], strides=[1, scale_factor, scale_factor, 1], rates=[1, 1, 1, 1], padding='VALID')\n",
" patches = tf.reshape(patches, (1,32,32,scale_factor,scale_factor,3))\n",
" lines = tf.reduce_mean(patches, axis=-2)\n",
" return lines\n",
"\n",
"\n",
"def run():\n",
" seeds = []\n",
" for i in range(32):\n",
" seeds.append( [random.getrandbits(30), random.getrandbits(30)] )\n",
" seeds = tf.constant(seeds, tf.int32)\n",
" x = tf.data.Dataset.from_tensor_slices(seeds)\n",
" x = x.map(lambda s: tf.py_function(make_example,[s],tf.float32), num_parallel_calls=4)\n",
" for d in iter(x):\n",
" print('.', end='')\n",
" print('done')\n",
"\n",
"run()"
],
"execution_count": 0,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment