@yoshitaka-xvi
Created December 25, 2017 07:36
import argparse
import sys
from collections import deque

import pyaudio
import tensorflow as tf

# pylint: disable=unused-import
# Imported for its side effects: it loads the contrib audio ops used by the
# speech_commands graphs.
from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio
# pylint: enable=unused-import
FLAGS = None


def load_graph(filename):
  """Unpersists graph from file as default graph."""
  with tf.gfile.FastGFile(filename, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')


def load_labels(filename):
  """Read in labels, one label per line."""
  return [line.rstrip() for line in tf.gfile.GFile(filename)]


def run_graph(labels, input_layer_name, output_layer_name,
              num_top_predictions):
  """Streams microphone audio through the graph and prints predictions."""
  seconds = 1
  rate = 16000
  # Ring buffer holding the most recent `seconds` of 16-bit mono PCM,
  # pre-filled with silence (2 bytes per sample, so maxlen is in bytes).
  ring_buffer = deque((0).to_bytes(2, 'little') * seconds * rate,
                      maxlen=seconds * rate * 2)
  audio = pyaudio.PyAudio()
  stream = audio.open(
      format=pyaudio.paInt16,
      channels=1,
      rate=rate,
      input=True,
      frames_per_buffer=1024)
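  # The wav_data input expects a complete WAVE file, so build the canonical
  # 44-byte RIFF header for 16-bit mono PCM at `rate` Hz: 'RIFF' + chunk size
  # (36 + data bytes) + 'WAVE', a 16-byte 'fmt ' chunk (PCM format 1,
  # 1 channel, sample rate, byte rate, block align 2, 16 bits per sample),
  # and a 'data' chunk header carrying the payload size.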
  data_size = seconds * rate * 2
  wav_header = (b'RIFF' + (36 + data_size).to_bytes(4, 'little') + b'WAVE' +
                b'fmt ' + (16).to_bytes(4, 'little') +
                (1).to_bytes(2, 'little') + (1).to_bytes(2, 'little') +
                rate.to_bytes(4, 'little') + (2 * rate).to_bytes(4, 'little') +
                (2).to_bytes(2, 'little') + (16).to_bytes(2, 'little') +
                b'data' + data_size.to_bytes(4, 'little'))
  with tf.Session() as sess:
    # Feed the audio data as input to the graph. `predictions` will contain a
    # two-dimensional array, where one dimension represents the input sample
    # count and the other has the predictions per class.
    softmax_tensor = sess.graph.get_tensor_by_name(output_layer_name)
    while True:
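      # Read 1024 frames (64 ms at 16 kHz) and append them to the deque; its
      # maxlen keeps only the most recent second, so every inference runs on
      # a sliding one-second window ending at the newest sample.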
      read_buffer = stream.read(1024)
      ring_buffer.extend(read_buffer)
      predictions, = sess.run(
          softmax_tensor,
          {input_layer_name: wav_header + bytes(ring_buffer)})
      # Sort to show labels in order of confidence.
      top_k = predictions.argsort()[-num_top_predictions:][::-1]
      for node_id in top_k:
        human_string = labels[node_id]
        score = predictions[node_id]
        print('%s (score = %.5f)' % (human_string, score))
  return 0


def label_wav(labels, graph, input_name, output_name, how_many_labels):
  """Loads the model and labels, and runs the inference to print predictions."""
  if not labels or not tf.gfile.Exists(labels):
    tf.logging.fatal('Labels file does not exist %s', labels)
  if not graph or not tf.gfile.Exists(graph):
    tf.logging.fatal('Graph file does not exist %s', graph)
  labels_list = load_labels(labels)
  # Load the graph, which is stored in the default session.
  load_graph(graph)
  run_graph(labels_list, input_name, output_name, how_many_labels)


def main(_):
  """Entry point for script, converts flags to arguments."""
  label_wav(FLAGS.labels, FLAGS.graph, FLAGS.input_name,
            FLAGS.output_name, FLAGS.how_many_labels)


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--graph',
      type=str,
      default='D:/tmp/speech_commands_train/frozen_low_latency_conv.pb',
      help='Model to use for identification.')
  parser.add_argument(
      '--labels',
      type=str,
      default='D:/tmp/speech_commands_train/low_latency_conv_labels.txt',
      help='Path to file containing labels.')
  parser.add_argument(
      '--input_name',
      type=str,
      default='wav_data:0',
      help='Name of WAVE data input node in model.')
  parser.add_argument(
      '--output_name',
      type=str,
      default='labels_softmax:0',
      help='Name of node outputting a prediction in the model.')
  parser.add_argument(
      '--how_many_labels',
      type=int,
      default=1,
      help='Number of results to show.')
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
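
A minimal way to run the script, assuming it is saved as streaming_label_wav.py (the gist's filename is not shown here) and that TensorFlow 1.x and PyAudio are installed, is to pass the paths of a frozen speech_commands graph and its labels file explicitly, for example:

    python streaming_label_wav.py --graph=D:/tmp/speech_commands_train/frozen_low_latency_conv.pb --labels=D:/tmp/speech_commands_train/low_latency_conv_labels.txt --how_many_labels=3

Each loop iteration reads a 1024-frame chunk from the default microphone and prints the top how_many_labels labels with their softmax scores for the most recent one-second window.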