Skip to content

Instantly share code, notes, and snippets.

@itamarhaber
Created October 2, 2018 22:01
Show Gist options
  • Save itamarhaber/8bedc325139bf99dec67221042d4e203 to your computer and use it in GitHub Desktop.
Save itamarhaber/8bedc325139bf99dec67221042d4e203 to your computer and use it in GitHub Desktop.
Naively call RedisTF from Python with redis-py
  1. Clone https://github.com/lantiga/RedisTF, then run `sh get_deps.sh` and `make run`
  2. pip install redis tensorflow
  3. cd models
  4. Run python tf-minimal.py to prepare a minimal graph
  5. The attached redistf-py.py demonstrates how to run the test in the README from Python:
$ python redistf-py.py
Setting the graph: OK
Setting tensor t1: OK
Setting tensor t2: OK
Running the thing: OK
Resulting values: ['4', '9']
  6. Success!
import redis
class RedisTF(object):
    """Thin wrapper that issues RedisTF module commands over a redis-py connection.

    Every method builds the argument list for one ``TF.*`` command, sends it
    through ``conn.execute_command`` and returns the server's reply unmodified.
    """

    def __init__(self, conn):
        """Keep *conn*, any object with an ``execute_command(*args)`` method
        (e.g. a ``redis.StrictRedis`` instance)."""
        self.__conn = conn

    # Loads a graph
    def SetGraph(self, graph, path):
        """Read the serialized graph stored at *path* and store it under key *graph*.

        The file is read in binary mode: a serialized graph is arbitrary bytes,
        and text mode would try to decode it (raising on Python 3 for invalid
        byte sequences) or translate newlines (corrupting it on Windows).

        Returns the raw reply to ``TF.GRAPH``.
        """
        with open(path, 'rb') as f:
            payload = f.read()
        return self.__conn.execute_command('TF.GRAPH', graph, payload)

    # Sets a tensor
    def SetTensor(self, tensor, dtype, shape, values):
        """Create tensor key *tensor* of type *dtype* with dimensions *shape*.

        Sends ``TF.TENSOR <tensor> <dtype> <dim>... VALUES <value>...``; every
        dimension and value is converted with ``str`` before sending.
        Returns the raw reply.
        """
        args = [tensor, dtype]
        args.extend(str(dim) for dim in shape)
        args.append('VALUES')
        args.extend(str(value) for value in values)
        return self.__conn.execute_command('TF.TENSOR', *args)

    # Runs a graph with a list of input tensor-name tuples, storing the result
    def Run(self, graph, inputs, output):
        """Run graph *graph* and store its result.

        :param graph: key of a graph previously stored with ``SetGraph``.
        :param inputs: sequence of ``(tensor, name)`` pairs; the element count
            is sent first, followed by each pair's first two items.
        :param output: iterable (in the example, a ``(tensor, name)`` pair)
            whose items are sent, converted with ``str``, after the inputs.
        :returns: the raw reply to ``TF.RUN``.
        """
        args = [graph, len(inputs)]
        for pair in inputs:
            args.extend((pair[0], pair[1]))
        args.extend(str(item) for item in output)
        return self.__conn.execute_command('TF.RUN', *args)

    # Gets the value of a tensor
    def Values(self, tensor):
        """Return the server's reply to ``TF.VALUES`` for tensor key *tensor*."""
        return self.__conn.execute_command('TF.VALUES', tensor)
if __name__ == '__main__':
    # Reproduces the RedisTF README example: load graph.pb from the current
    # directory, set two 1x2 FLOAT tensors, run the graph and print the output.
    # redis.StrictRedis() with no arguments uses redis-py's default connection
    # settings. print(...) with a single argument behaves the same on
    # Python 2 and Python 3, unlike the Python 2-only print statement.
    conn = redis.StrictRedis()
    rtf = RedisTF(conn)
    print('Setting the graph: {}'.format(rtf.SetGraph('graph', 'graph.pb')))
    print('Setting tensor t1: {}'.format(rtf.SetTensor('t1', 'FLOAT', [1, 2], [2, 3])))
    print('Setting tensor t2: {}'.format(rtf.SetTensor('t2', 'FLOAT', [1, 2], [2, 3])))
    print('Running the thing: {}'.format(rtf.Run('graph', [('t1', 'a'), ('t2', 'b')], ('t3', 'c'))))
    print('Resulting values: {}'.format(rtf.Values('t3')))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment