
@chizhang529
Last active August 17, 2018 03:51
Useful Snippets for TensorFlow

Check That the GPU Works with TensorFlow

You may want to verify at runtime that TensorFlow can actually use your GPU. Here are some useful Python snippets:

## show all devices available ##
from tensorflow.python.client import device_lib
print(device_lib.list_local_devices())
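The printed list can be long. As a small sketch (the helper name `gpu_names` is hypothetical, not part of TensorFlow), the GPU entries can be picked out by their `device_type` field, assuming each entry returned by `device_lib.list_local_devices()` exposes `name` and `device_type` as the TensorFlow 1.x device protobufs do. Stand-in objects are used so the sketch runs without TensorFlow installed:

```python
from types import SimpleNamespace  # only needed for the stand-in devices below


def gpu_names(devices):
    """Return the names of the GPU entries in a device list.

    `devices` is assumed to be the list returned by
    device_lib.list_local_devices(); each entry is assumed to carry
    `name` (e.g. "/device:GPU:0") and `device_type` (e.g. "GPU").
    """
    return [d.name for d in devices if d.device_type == "GPU"]


# Stand-ins for real device entries, so the sketch runs without TensorFlow.
fake_devices = [
    SimpleNamespace(name="/device:CPU:0", device_type="CPU"),
    SimpleNamespace(name="/device:GPU:0", device_type="GPU"),
]
print(gpu_names(fake_devices))  # ['/device:GPU:0']
```

With TensorFlow installed, `gpu_names(device_lib.list_local_devices())` should come back empty on a machine without a usable GPU.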

To make sure the GPU is actually used for computation:

import tensorflow as tf

# Pin the ops to the first GPU; without a usable GPU this fails when the session runs.
with tf.device('/gpu:0'):
    a = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3], name='a')
    b = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[3, 2], name='b')
    c = tf.matmul(a, b)

# tf.Session is the TensorFlow 1.x API.
with tf.Session() as sess:
    print(sess.run(c))
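To know what a correct run should print, the same product can be worked out in plain Python. This is only a reference calculation, assuming `tf.constant` fills the requested shape row by row; it does not touch the GPU:

```python
# Reference calculation for the matmul above, in plain Python.
values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
a = [values[0:3], values[3:6]]               # 2x3: [[1, 2, 3], [4, 5, 6]]
b = [values[0:2], values[2:4], values[4:6]]  # 3x2: [[1, 2], [3, 4], [5, 6]]


def matmul(x, y):
    """Naive matrix product of two nested lists (rows of x times columns of y)."""
    return [[sum(x[i][k] * y[k][j] for k in range(len(y)))
             for j in range(len(y[0]))]
            for i in range(len(x))]


c = matmul(a, b)
print(c)  # [[22.0, 28.0], [49.0, 64.0]]
```

For example, the top-left entry is 1·1 + 2·3 + 3·5 = 22.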

If you have a GPU that TensorFlow can use, you will see the result [[ 22. 28.] [ 49. 64.]] along with device information such as:

Found device 0 with properties:
      name: GeForce GTX 1080 Ti
      major: 6 minor: 1 
      memoryClockRate(GHz): 1.582 
      pciBusID: 0000:02:00.0 
      totalMemory: 10.92GiB 
      freeMemory: 10.76GiB

Otherwise, you will get an error with a long stack trace ending in something like this:

Cannot assign a device to node 'MatMul':
Could not satisfy explicit device specification '/device:GPU:0' because no devices matching that specification are registered in this process

A simpler check, which spares you from reading through the verbose logs above, is:

import tensorflow as tf
if tf.test.gpu_device_name():
    print("Default GPU Device: {}".format(tf.test.gpu_device_name()))
else:
    print("Please install Tensorflow GPU or buy a graphics card. :-)")
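The snippets above spell the same device two ways (`'/gpu:0'` and `'/device:GPU:0'`). A hypothetical helper, `parse_device`, sketches how to normalise either short form into a `(type, index)` pair, treating the empty string as "no GPU", which is what `tf.test.gpu_device_name()` returns when it finds none. Fully qualified names such as `/job:localhost/replica:0/task:0/device:GPU:0` are deliberately not handled:

```python
import re

# Accepts the short forms '/gpu:0' and '/device:GPU:0' (case-insensitive type).
_DEVICE_RE = re.compile(r"^/(?:device:)?(?P<type>[A-Za-z_]+):(?P<index>\d+)$")


def parse_device(name):
    """Split a device string into (TYPE, index), e.g. '/gpu:0' -> ('GPU', 0).

    Returns None for an empty string (no GPU found); raises ValueError
    for strings this sketch does not recognise.
    """
    if not name:
        return None
    m = _DEVICE_RE.match(name)
    if m is None:
        raise ValueError("unrecognised device string: %r" % name)
    return m.group("type").upper(), int(m.group("index"))


print(parse_device('/device:GPU:0'))  # ('GPU', 0)
print(parse_device('/gpu:0'))         # ('GPU', 0)
```

With TensorFlow 1.x installed, `parse_device(tf.test.gpu_device_name())` would then give `None` on a CPU-only machine.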

Shut Up TensorFlow Warnings

If you are annoyed by warnings that TensorFlow was not compiled to use SSE, AVX, or FMA instructions, you can mute them by setting the TF_CPP_MIN_LOG_LEVEL environment variable to filter out INFO and WARNING messages. Set it before importing TensorFlow:

import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
## do your work below ##
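For reference, `TF_CPP_MIN_LOG_LEVEL` accepts `'0'` (show everything), `'1'` (hide INFO), `'2'` (also hide WARNING) and `'3'` (also hide ERROR). It is intended for TensorFlow's C++ logging; Python-side `WARNING:tensorflow` messages go through Python logging and may still appear. The sketch below wraps the setting in a hypothetical helper, `quiet_tf_cpp_logs`, that keeps any value already exported from the shell and warns if TensorFlow was imported first:

```python
import os
import sys
import warnings


def quiet_tf_cpp_logs(level="2"):
    """Set TF_CPP_MIN_LOG_LEVEL unless it is already set; return the effective value.

    Levels: "0" = all, "1" = hide INFO, "2" = also hide WARNING, "3" = also hide ERROR.
    """
    if "tensorflow" in sys.modules:
        # The variable should be set before TensorFlow is imported.
        warnings.warn("tensorflow is already imported; "
                      "TF_CPP_MIN_LOG_LEVEL may have no effect")
    # setdefault leaves a value exported from the shell untouched.
    return os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", str(level))


quiet_tf_cpp_logs()
# import tensorflow as tf   # import only after the variable is set
```

Using `setdefault` means `TF_CPP_MIN_LOG_LEVEL=0 python my_script.py` can still bring the logs back when debugging.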

As a more general approach to filtering out uninteresting TensorFlow messages, you can run TensorFlow scripts through a wrapper:

./run_tf.sh my_script.py

where run_tf.sh contains:

#!/bin/sh
# Run python script, filtering out TensorFlow logging
# https://github.com/tensorflow/tensorflow/issues/566#issuecomment-259170351
# Swap stdout and stderr so that TensorFlow's log output (stderr) is piped through grep.
python "$@" 3>&1 1>&2 2>&3 3>&- |
  grep -v -e ": I " -e "WARNING:tensorflow" -e "^pciBusID" -e "^major:" \
          -e "^name:" -e "^Total memory:" -e "^Free memory:"
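If a shell wrapper is inconvenient, the same filtering can be sketched in Python. The patterns below simply mirror the `grep -v` filters in `run_tf.sh`; the name `filter_tf_log` is hypothetical:

```python
import re

# Same patterns as the grep -v filters in run_tf.sh.
_NOISE = [re.compile(p) for p in (
    r": I ",                # C++ INFO lines, e.g. "...: I tensorflow/core/..."
    r"WARNING:tensorflow",  # Python-side TensorFlow warnings
    r"^pciBusID",
    r"^major:",
    r"^name:",
    r"^Total memory:",
    r"^Free memory:",
)]


def filter_tf_log(lines):
    """Yield only the lines that none of the noise patterns match."""
    for line in lines:
        if not any(p.search(line) for p in _NOISE):
            yield line
```

One way to use it is to save it in a file (say, a hypothetical `tf_log_filter.py`), add `import sys` and a final line `sys.stdout.writelines(filter_tf_log(sys.stdin))`, and pipe a script's combined output through it: `python my_script.py 2>&1 | python tf_log_filter.py`.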