MinPy Next Step Prototype

Run it with Python 3. It is not tested with Python 2 and will most likely not run there for now.

This just shows how JIT and gradient code generation work together. Execution and optimization are placeholders; they will be hooked up to NNVM.
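As a quick preview of the flow implemented below: under the jit() context, operations are recorded rather than executed, and the whole recorded sequence is flushed at context exit. The cache key is built from operator and operand names, so an identical repeated sequence hits the cache. An illustrative sketch (names match the code below):

with array.jit():
    a + b    # prints 'add delayed' and records ('add', a, b)
    a * b    # prints 'mul delayed' and records ('mul', a, b)
# On exit, the recorded sequence is flushed under the cache key
# (('add', 'a', 'b'), ('mul', 'a', 'b')).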

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Save this module as array.py so the driver script below can import it.
import contextlib
import random

random.seed()


class Array():
    def __init__(self, name):
        self._name = name

    def rename(self, name):
        self._name = name

    def __add__(self, other):
        if jit_enabled:
            # Record the operation instead of executing it.
            jit_sequence.append(('add', self, other))
            print('add delayed')
        else:
            print('add eager')
        if grad_enabled:
            grad_sequence.append(('add', self, other))
        return Array('({} + {})'.format(self, other))

    def __mul__(self, other):
        if jit_enabled:
            jit_sequence.append(('mul', self, other))
            print('mul delayed')
        else:
            print('mul eager')
        if grad_enabled:
            grad_sequence.append(('mul', self, other))
        return Array('({} * {})'.format(self, other))

    def __repr__(self):
        return 'Array {}'.format(self._name)

    def eval(self):
        if jit_enabled:
            # Guard instruction: reading a concrete value forces the
            # pending JIT sequence to run.
            flush_jit_sequence()

    def __getitem__(self, key):
        self.eval()
        return random.randint(0, 1)


jit_enabled = False
jit_sequence = []
grad_enabled = False
grad_sequence = []
jit_cache = {}


def flush_jit_sequence():
    # Key the cache on (op, lhs name, rhs name) triples so an identical
    # sequence hits the cached optimized version.
    k = tuple(map(lambda i: (i[0], i[1]._name, i[2]._name), jit_sequence))
    if k in jit_cache:
        execute(jit_cache[k])
    else:
        # Optimization would run asynchronously.
        seq = optimize(jit_sequence)
        jit_cache[k] = seq
        # Run the unoptimized sequence in the main thread for now.
        execute(jit_sequence)
    jit_sequence.clear()


def flush_grad_sequence():
    # Append the generated gradient operations to the JIT sequence so they
    # are optimized and executed together with the forward pass.
    g = get_grad(grad_sequence)
    jit_sequence.extend(g)
    grad_sequence.clear()


def reset_jit_cache():
    jit_cache.clear()


# The functions below are placeholders for what will be part of NNVM.
def execute(seq):
    print('executing seq {}'.format(seq))


def optimize(seq):
    return 'optimized {}'.format(seq)


def get_grad(seq):
    # Generate gradient operations by walking the recorded sequence in
    # reverse order.
    return list(map(lambda i: (i[0] + '_grad', i[1], i[2]), reversed(seq)))


@contextlib.contextmanager
def jit():
    global jit_enabled
    jit_enabled = True
    yield
    flush_jit_sequence()
    jit_enabled = False


@contextlib.contextmanager
def grad():
    global grad_enabled
    grad_enabled = True
    yield
    flush_grad_sequence()
    grad_enabled = False
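As a minimal sketch of the gradient generation (run in an interpreter, assuming the module above is saved as array.py): get_grad reverses the recorded sequence and maps each op to its gradient op.

>>> import array
>>> a, b, c = array.Array('a'), array.Array('b'), array.Array('c')
>>> array.get_grad([('add', a, b), ('mul', c, a)])
[('mul_grad', Array c, Array a), ('add_grad', Array a, Array b)]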
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Driver script; assumes the module above is saved as array.py.
import array


def main():
    a = array.Array('a')
    b = array.Array('b')
    print('Plain')
    a + b
    a * b
    array.reset_jit_cache()
    print()

    print('Run twice')
    for i in range(2):
        with array.jit():
            a + b
            a * b
    array.reset_jit_cache()
    print()

    print('Run with branch')
    for i in range(4):
        with array.jit():
            a + b
            if i % 2 == 0:
                a * b
            else:
                b * a
    array.reset_jit_cache()
    print()

    print('Run with data dependency')
    for i in range(4):
        with array.jit():
            c = a + b
            # Indexing triggers the guard, flushing the pending sequence.
            if c[0] == 1:
                print('if')
                a * b
            else:
                print('else')
                b * a
    array.reset_jit_cache()
    print()

    print('Run with grad')
    with array.jit():
        with array.grad():
            c = a + b
            d = c * a
            e = b * c
    array.reset_jit_cache()
    print()

    print('Run with grad and data dependency')
    for i in range(4):
        print('Iteration {}'.format(i))
        with array.jit():
            with array.grad():
                c = a + b
                d = c * a
                e = b * c
                c.rename('c')
                d.rename('d')
                e.rename('e')
                # Indexing e forces a flush of the pending JIT sequence.
                if e[0] == 0:
                    print('if')
                    f = e + e
                else:
                    print('else')
                    f = e * e
                g = f + c
                h = g * g
                f.rename('f')
                g.rename('g')
                h.rename('h')


if __name__ == '__main__':
    main()
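For reference, the 'Run twice' section should print roughly the following (traced by hand from the code above): the first pass misses the cache and executes the raw sequence, while the second pass hits the cache and executes the 'optimized' placeholder.

Run twice
add delayed
mul delayed
executing seq [('add', Array a, Array b), ('mul', Array a, Array b)]
add delayed
mul delayed
executing seq optimized [('add', Array a, Array b), ('mul', Array a, Array b)]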