Dali graph transformation Plan
"""
Micro-dali JIT Plan:
- contains gemm, operator fusion, elementwise/reduction ops.
- supports tensordot
- supports 'jit'
- supports conversion from gemm + im2col to conv2d (NHWC)
- supports 'optimization' passes
- supports 'implementation' registries for specialization
(e.g. int vs float)
TODO:
- ?
STRETCH GOALS:
- allow multiple outputs
- efficient patterns for optimization registry
- allow forced assignments (that should not be replaced by optimization)
- reorder matmuls?
- paced JIT running (do reductions, then do elementwise, etc...)
- async running with locks/mutexes?
"""
import numpy as np
import time
import tensorflow as tf
## REGISTRIES for implementations or graph transformations
IMPLEMENTATIONS = {}
OPTIMIZATIONS = []
def tname(x):
return type(x).__name__
def register_implementation(opname, implname):
"""
Mark that elements built from `opname` (e.g. MatMul)
can be implemented using `implname`. implname is
a callback that receives the current Array and
returns the appropriate implementation class.
"""
IMPLEMENTATIONS[opname.__name__] = implname
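# For example (taken from the registrations further down in this file), MatMul chooses
# its implementation per dtype while JITRunner always maps to a single class:
#   register_implementation(MatMul, choose_matmul)
#   register_implementation(JITRunner, lambda x: JITRunnerImpl)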
def register_optimization(condition, transformation):
"""
Update the computation graph if the local Array
matches the condition given by calling transformation
on that node (note: optimizations are run bottom-up).
"""
OPTIMIZATIONS.append((condition, transformation))
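# For example (see the registrations below), JIT-able assignments get fused and chained
# assignments get collapsed:
#   register_optimization(is_jit_assignment, jit_merge)
#   register_optimization(is_chained_or_gemm_assignment, assign_merge)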
class Expression(object):
def ndim(self):
return len(self.shape())
def __str__(self):
return tname(self) + "(" + ", ".join([str(arg) for arg in self.arguments()]) + ")"
def __eq__(self, other):
if type(self) != type(other):
return False
self_args = self.arguments()
other_args = other.arguments()
if len(self_args) != len(other_args):
return False
return all(arg == oarg for arg, oarg in zip(self_args, other_args))
def supports_operator(self, operator):
return operator == "="
def shape_to_trivial_strides(shape):
res = [0 for _ in shape]
residual_shape = 1
for i in reversed(range(0, len(shape))):
res[i] = residual_shape
residual_shape *= shape[i]
return res
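# Example (sketch): for a C-contiguous layout, shape (2, 3, 4) yields trivial strides
# [12, 4, 1], counted in elements rather than bytes, matching Buffer.strides() below.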
class Buffer(Expression):
def __init__(self, data):
self._data = data
def __eq__(self, other):
if not isinstance(other, Buffer) or self._data.shape != other._data.shape:
return False
return np.alltrue(np.equal(self._data, other._data))
def shape(self):
return self._data.shape
def dtype(self):
return self._data.dtype
def value(self):
return self._data
def arguments(self):
return []
def data(self):
return self._data
def reshape(self, shape):
return Buffer(self._data.reshape(shape))
def dimshuffle(self, axes):
return Buffer(self._data.transpose(axes))
def contiguous_memory(self):
return self._data.flags.c_contiguous
def strides(self):
item_size = self._data.itemsize
return [stride // item_size for stride in self._data.strides]
def is_transpose(self):
ndim = self.ndim()
if ndim <= 1:
return True
if self.contiguous_memory():
return False
reversed_shape = list(reversed(self.shape()))
reversed_strides = shape_to_trivial_strides(reversed_shape)
strides = self.strides()
for i in range(0, ndim):
if reversed_strides[i] != strides[ndim - 1 - i]:
return False
return True
def supports_operator(self, operator):
return True
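# Note on Buffer.is_transpose (illustrative): transposing a C-contiguous (3, 4) array
# gives shape (4, 3) with element strides [1, 4]; these match the trivial strides of the
# reversed shape read backwards, so the buffer counts as a "simple transpose" that a gemm
# can consume directly (see ascontiguousarray_or_simple_transpose further down).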
class Array(object):
def __init__(self, internal):
self._expression = internal
self._simplified = False
def __eq__(self, other):
if not isinstance(other, Array):
return False
return self._expression == other._expression
def canonical(self):
node = self
# assignment pass
node = all_assignments_or_buffers(node)
# simplification pass (jit, merge, etc...)
return simplify_destination(node)
def eval(self):
if not isinstance(self._expression, Buffer):
node = self.canonical()
computable = convert_to_ops(node)
# run (DAG evaluation)
for step in computable:
step.run()
self._expression = node._expression._left._expression
def value(self):
self.eval()
return self._expression.value()
def __str__(self):
return str(self._expression)
@property
def dtype(self):
return self._expression.dtype()
@property
def shape(self):
return self._expression.shape()
@property
def ndim(self):
return self._expression.ndim()
@property
def T(self):
return transpose(self)
def buffer(x):
return Array(Buffer(x.copy()))
class Node(Expression):
pass
class Assignment(Expression):
def __init__(self, left, operator, right):
self._left = left
self._operator = operator
self._right = right
def arguments(self):
return [self._left, self._right]
def shape(self):
return self._left._expression.shape()
def dtype(self):
return self._left.dtype
def data(self):
return self._left._expression.data()
def __str__(self):
return (tname(self) + "(" + self._operator + ", " +
str(self._left) + ", " + str(self._right) + ")")
class ControlFlow(Expression):
"""Signify that a buffer (left) requires all the operations
in 'conditions' to complete before progressing"""
def __init__(self, left, conditions):
self._left = left
self._conditions = conditions
def arguments(self):
return [self._left] + self._conditions
def shape(self):
return self._left._expression.shape()
def dtype(self):
return self._left.dtype
def data(self):
return self._left._expression.data()
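# Example (sketch): to_reshape_inplace below emits ControlFlow(reshaped_buffer, [producer])
# so the reshaped view is only considered ready once the assignment that fills the original
# buffer has run; convert_to_ops walks these conditions as additional steps.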
def type_promotion(left_dtype, right_dtype):
np_left_dtype = np.dtype(left_dtype)
np_right_dtype = np.dtype(right_dtype)
if np_left_dtype.kind == "i" and np_right_dtype.kind == "i":
if np_left_dtype.itemsize > np_right_dtype.itemsize:
return left_dtype
return right_dtype
if np_left_dtype.kind == "f":
if np_right_dtype.kind == "f" and np_right_dtype.itemsize > np_left_dtype.itemsize:
return right_dtype
return left_dtype
raise ValueError("unknown dtype promotion.")
def autoreduce_assign(left, right):
"""
Special autoreduction useful for gradient propagation.
    Reduces every axis where the left side has size 1 but the right side
    does not, then accumulates the result into the left. Can be JIT-ed and
    combined with other element-wise operations (if beneficial).
"""
if not isinstance(right._expression, Buffer):
right = to_assignment(right)
reduction_axes = []
for axis, (left_dim, right_dim) in enumerate(zip(left.shape, right.shape)):
if left_dim == 1 and right_dim > 1:
reduction_axes.append(axis)
return assign(left, "+=", reduce_sum(right, reduction_axes, keep_dims=True))
def assign(left, operator, right):
"""Assign right data to left destination with operator.
If the right side is not yet evaluated, and the operator
is not simple equality, then a temporary destination for
saving the right side is added.
"""
if operator == "=":
return Array(Assignment(left, operator, right))
elif operator == "<<=":
res = autoreduce_assign(left, right)
return res
else:
if not isinstance(right._expression, Buffer):
right = to_assignment(right)
# a temp is added so that non overwriting operators
# can be run independently from the right side's evaluation.
return Array(Assignment(left, operator, right))
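# Example (sketch): assign(dest, "+=", tanh(x)) first materializes tanh(x) into a
# temporary via to_assignment, so the accumulation does not depend on how the right side
# is evaluated, whereas assign(dest, "=", tanh(x)) keeps the right side as-is.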
def to_assignment(node):
return assign(buffer(np.zeros(node.shape, node.dtype)),
"=", Array(node._expression))
def right_args(node):
return node._expression._right._expression.arguments()
def buffer_arg(node):
if isinstance(node._expression, Buffer):
return node
if isinstance(node._expression, (Assignment, ControlFlow)):
return node._expression._left
return None
class Allocate(object):
"""Dummy allocation step. Could involve data movement to/from GPU."""
def __init__(self, node):
self._node = node
def run(self):
pass
### GRAPH TRANSFORMATION PASSES ###
def buffer_buffer_op(node):
    # wrap the chained right-hand assignment in an Identity/JITRunner pair so it can be
    # evaluated into the left-hand destination like any other JIT node.
identity_node = Array(Identity(node._expression._right))
el = Array(Assignment(node._expression._left,
node._expression._operator,
Array(JITRunner(identity_node, [node._expression._right]))))
return el
def convert_to_ops(root):
steps = []
elements = [root]
while len(elements) > 0:
element = elements.pop()
if isinstance(element._expression, Buffer):
steps.append(Allocate(element))
elif isinstance(element._expression, Assignment):
#if isinstance(element._expression._right._expression, Buffer) or
if isinstance(element._expression._right._expression, Assignment):
# TODO: clean this up
element = buffer_buffer_op(element)
name_of_object = tname(element._expression._right._expression)
if name_of_object in IMPLEMENTATIONS:
steps.append(IMPLEMENTATIONS[name_of_object](element._expression._right)(
element._expression._right,
element._expression._operator,
element._expression._left))
# ensure destination is allocated:
elements.append(element._expression._left)
elements.extend(right_args(element))
else:
raise ValueError("no way to implement %r" % (name_of_object,))
elif isinstance(element._expression, ControlFlow):
# add all the dependencies of this node as a step to complete:
elements.extend(element._expression.arguments())
else:
raise ValueError("can only convert Assignments and Buffers to ops (got %s)." % (str(element),))
return list(reversed(steps))
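# Example (sketch): a canonicalized Assignment(dest, "=", JITRunner(Add(x, y), [x, y]))
# yields a reversed step list in which the Allocate steps for dest, x and y all come
# before the single JITRunnerImpl step that reads and writes them.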
def can_copyless_reshape(node, shape):
"""
Returns True if the shape/strides of the node
are compatible with the new shape without requiring
a copy.
"""
if node._expression.contiguous_memory():
return True
ndim = node.ndim
shape_ = node.shape
if len(shape) > ndim:
# check if the lowest dimensions will be identical
matching_lowest = True
for i in range(0, ndim):
if shape[len(shape) - i - 1] != shape_[ndim - i - 1]:
matching_lowest = False
break
is_ones_elsewhere = True
for i in range(0, len(shape) - ndim):
if shape[i] != 1:
is_ones_elsewhere = False
break
if matching_lowest and is_ones_elsewhere:
return True
return False
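# Examples (sketch): any contiguous buffer can be reshaped without a copy; a
# non-contiguous (3, 4) view can still be reshaped to (1, 3, 4) because only leading
# size-1 axes are added, but reshaping it to (12,) or (4, 3) falls through to False.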
def can_reshape_inplace(node):
if not (isinstance(node._expression, Assignment) and
node._expression._operator == "=" and
isinstance(node._expression._right._expression, Reshape)):
return False
reshape_node = node._expression._right
return can_copyless_reshape(buffer_arg(reshape_node._expression._node),
reshape_node._expression._shape)
def to_reshape_inplace(node):
"""
    Replace an assignment of the form buffer = reshape(buffer) with an
    in-place reshape(buffer) plus a ControlFlow node that preserves the
    antecedents of the buffer.
"""
reshape_node = node._expression._right
buffer_node = buffer_arg(reshape_node._expression._node)
shape = reshape_node._expression._shape
return Array(ControlFlow(Array(buffer_node._expression.reshape(shape)),
[reshape_node._expression._node]))
def is_jit_assignment(node):
return (isinstance(node._expression, Assignment) and
isinstance(node._expression._right._expression, JITNode) and
not isinstance(node._expression._right._expression, JITRunner))
GEMM_OPERATORS = ("+=", "-=")
def is_chained_assignment(node):
""" assign(Left, operator, assign(Temp, '=', Right)) =>
assign(Left, operator, Right)"""
return (isinstance(node._expression, Assignment) and
node._expression._operator == "=" and
isinstance(node._expression._right._expression, Assignment) and
node._expression._right._expression._operator == "=")
def is_chained_or_gemm_assignment(node):
return (isinstance(node._expression, Assignment) and
isinstance(node._expression._right._expression, Assignment) and
node._expression._right._expression._operator == "=" and
(node._expression._operator == "=" or
(node._expression._operator in GEMM_OPERATORS and
isinstance(node._expression._right._expression._right._expression, MatMul))))
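# Example (matches the canonicalization tests in main()): assign_merge rewrites
#   Assignment(dest, "+=", Assignment(temp, "=", MatMul(a, b)))
# into Assignment(dest, "+=", MatMul(a, b)), since gemm can accumulate directly
# ("+=" and "-=" are listed in GEMM_OPERATORS); a "*=" destination keeps its temporary.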
def jit_root(node):
if isinstance(node._expression, JITRunner):
return node._expression._root
return node
def replace_assign_with_inplace(node):
rightside = jit_root(node._expression._right)
operator = node._expression._operator
if operator == "=":
return rightside, None
elif operator == "+=":
return add(node._expression._left, rightside), node._expression._left
elif operator == "-=":
        return subtract(node._expression._left, rightside), node._expression._left
elif operator == "*=":
return eltmul(node._expression._left, rightside), node._expression._left
elif operator == "/=":
return eltdiv(node._expression._left, rightside), node._expression._left
else:
raise ValueError("cannot replace assign inplace with operator %r" % (operator,))
def assign_merge(root):
original_root_buffer = root._expression._left
original_root_operator = root._expression._operator
return Array(Assignment(original_root_buffer,
original_root_operator,
root._expression._right._expression._right))
def jit_merge(root):
leaves = []
root_buffer = root._expression._left
root_operator = root._expression._operator
for arg in right_args(root):
if (isinstance(arg._expression, Assignment) and
isinstance(arg._expression._right._expression, JITRunner)):
# grab leaves from existing jit-runner recursively:
leaves.extend(arg._expression._right._expression._leaves)
# if the node is an assignment to a buffer, ensure that
# the assignment op gets included within this op
# (e.g. by spoofing the assignment and replacing it with
# the equivalent JIT op)
replaced, left_leaf = replace_assign_with_inplace(arg)
# if the assignment involves using the left-side (e.g.
# left += right -> left + right), then keep the left node
# as a dependency leaf:
if left_leaf is not None:
leaves.append(left_leaf)
# now that the jitrunners and assignments are gone, connect
# up the new operation in the graph:
arg._expression = replaced._expression
# elif isinstance(arg._expression, Assignment):
# new_arg = Array(arg._expression)
# arg._expression = arg._expression._left._expression
# leaves.append(new_arge)
else:
# this node is either an assignment, or a buffer,
# and is needed as an input here:
leaves.append(arg)
new_root = root._expression._right
return Array(Assignment(
# keep the original target buffer:
root_buffer, root_operator,
# use the merged operation instead
Array(JITRunner(new_root, leaves))))
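# Example (sketch): canonicalizing add(tanh(x), tanh(y)) produces a single fused kernel,
# roughly Assignment(dest, "=", JITRunner(Add(Tanh(x), Tanh(y)), [x, y])), which prints as
#   JIT[Add(Tanh(X), Tanh(X))](Buffer(), Buffer())
# via JITRunner.__str__ below.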
register_optimization(can_reshape_inplace, to_reshape_inplace)
register_optimization(is_jit_assignment, jit_merge)
register_optimization(is_chained_or_gemm_assignment, assign_merge)
def conv2d_merge(root):
original_root_buffer = root._expression._left
gemm_node = root._expression._conditions[0]._expression._right
im2col_node = gemm_node._expression._left._expression._conditions[0]._expression._right
x = im2col_node._expression._input
strides = im2col_node._expression._strides
padding = im2col_node._expression._padding
data_format = im2col_node._expression._data_format
filter_size = im2col_node._expression._filter_size
w = reshape(gemm_node._expression._right, (filter_size[0], filter_size[1], x.shape[3],
original_root_buffer.shape[3]))
return Array(Assignment(original_root_buffer,
"=",
Array(Conv2D(x, w, strides, padding, data_format))))
def is_im2col_gemm(node):
"""6 micro second check"""
# TODO: add check if conv2D is supported for this data type.
is_cflow_assign = (isinstance(node._expression, ControlFlow) and
len(node._expression._conditions) == 1 and
isinstance(node._expression._conditions[0]._expression, Assignment) and
node._expression._conditions[0]._expression._operator == "=")
if not is_cflow_assign:
return False
gemm_node = node._expression._conditions[0]._expression._right
is_assign_gemm = (isinstance(gemm_node._expression, MatMul) and
isinstance(gemm_node._expression._left._expression, ControlFlow) and
len(gemm_node._expression._left._expression._conditions) == 1 and
isinstance(gemm_node._expression._left._expression._conditions[0]._expression, Assignment))
if not is_assign_gemm:
return False
im2col_node = gemm_node._expression._left._expression._conditions[0]._expression._right
is_im2col = isinstance(im2col_node._expression, Im2col)
return is_im2col
register_optimization(is_im2col_gemm, conv2d_merge)
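# Example (sketch): im2col_conv2d() further down builds
#   reshape(MatMul(reshape(im2col(x)), reshape(w)))
# and once the reshapes have been canonicalized into ControlFlow/assignment form,
# is_im2col_gemm matches that pattern and conv2d_merge rewrites it into a single
# Conv2D(x, w, strides, padding, data_format) node, as exercised in main().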
def simplify_destination(root):
# leaf node:
if isinstance(root._expression, Buffer):
return root
# recurse on children:
children = ([root._expression._right]
if isinstance(root._expression, Assignment)
else root._expression.arguments())
# recurse on arguments of node:
for arg in children:
arg._expression = simplify_destination(arg)._expression
for condition, transformation in OPTIMIZATIONS:
if condition(root):
root = transformation(root)
return root
def all_assignments_or_buffers(root):
"""
    Transform the graph so that it only contains
    Buffers or assignments into Buffers
    (i.e. give every operation a destination).
"""
if isinstance(root._expression, Buffer):
return root
if not isinstance(root._expression, Assignment):
root = to_assignment(root)
if (isinstance(root._expression._right._expression, Assignment) and
root._expression._right._expression._operator == "=" and
root._expression._right._expression._right._expression.supports_operator(root._expression._operator)):
root._expression._right._expression = root._expression._right._expression._right._expression
for arg in right_args(root):
arg._expression = all_assignments_or_buffers(arg)._expression
return root
### OPS (REGISTRY) ###
class Computation(object):
"""Abstract Computation"""
def __init__(self, op, operator, target):
self._op = op
self._operator = operator
self._target = target
class JITNode(Node):
def supports_operator(self, operator):
return True
class JITRunner(JITNode):
"""Merged jit nodes into one."""
def __init__(self, root, leaves):
if isinstance(root._expression, JITRunner):
raise ValueError("JITRunner should not contain a JITRunner.")
self._root = root
self._leaves = leaves
def arguments(self):
return self._leaves
def shape(self):
return self._root._expression.shape()
def dtype(self):
return self._root._expression.dtype()
def __str__(self):
# pretty print jit merged op into JIT[kernel](inputs)
return "JIT[" + str(self._root).replace("Buffer()", "X") + "](" + ", ".join([str(arg) for arg in self.arguments()]) + ")"
def jit_execute(root):
"""Recursive function for simulating JIT execution."""
if isinstance(root._expression, (Buffer, Assignment, ControlFlow)):
return root._expression.data()
elif isinstance(root._expression, Add):
return jit_execute(root._expression._left) + jit_execute(root._expression._right)
elif isinstance(root._expression, Subtract):
return jit_execute(root._expression._left) - jit_execute(root._expression._right)
elif isinstance(root._expression, EltMul):
return jit_execute(root._expression._left) * jit_execute(root._expression._right)
elif isinstance(root._expression, EltDiv):
return jit_execute(root._expression._left) / jit_execute(root._expression._right)
elif isinstance(root._expression, Tanh):
return np.tanh(jit_execute(root._expression._node))
elif isinstance(root._expression, Identity):
return jit_execute(root._expression._node)
elif isinstance(root._expression, ReduceSum):
return np_reduce_sum(jit_execute(root._expression._node),
root._expression._axis,
root._expression._keep_dims)
else:
raise ValueError("no jit execution for %r (%r)" % (str(root), type(root._expression)))
def assign_with_operator(left, operator, right):
if operator == "=":
left[:] = right
elif operator == "+=":
left[:] += right
elif operator == "-=":
left[:] -= right
elif operator == "*=":
left[:] *= right
elif operator == "/=":
left[:] /= right
else:
raise ValueError("unknown operator behavior %r" % (operator,))
class JITRunnerImpl(Computation):
def run(self):
# simulate generating a kernel
# and doing the actual work in one call
root = self._op._expression._root
assign_with_operator(self._target._expression.data(),
self._operator,
jit_execute(root))
register_implementation(JITRunner, lambda x: JITRunnerImpl)
def buffer_impl(op, operator, target):
right_identity = Array(Identity(op))
right_runner = Array(JITRunner(right_identity, [op]))
return JITRunnerImpl(right_runner, operator, target)
register_implementation(Buffer, lambda x: buffer_impl)
class BinaryElementWise(JITNode):
def __init__(self, left, right):
self._left = left
self._right = right
def arguments(self):
return [self._left, self._right]
def shape(self):
return self._left.shape
def dtype(self):
return type_promotion(self._left._expression.dtype(),
self._right._expression.dtype())
class UnitaryElementWise(JITNode):
def __init__(self, node):
self._node = node
def arguments(self):
return [self._node]
def shape(self):
return self._node.shape
def dtype(self):
return self._node.dtype
class Add(BinaryElementWise):
pass
class Subtract(BinaryElementWise):
pass
class EltMul(BinaryElementWise):
pass
class EltDiv(BinaryElementWise):
pass
def add(a, b):
return Array(Add(a, b))
def subtract(a, b):
return Array(Subtract(a, b))
def eltmul(a, b):
return Array(EltMul(a, b))
def eltdiv(a, b):
return Array(EltDiv(a, b))
class Tanh(UnitaryElementWise):
"""tanh"""
def dtype(self):
return np.float64
class ReduceSum(UnitaryElementWise):
def __init__(self, node, axis, keep_dims):
self._node = node
self._axis = axis
self._keep_dims = keep_dims
def shape(self):
shape = list(self._node.shape)
for ax in self._axis:
shape[ax] = 1 if self._keep_dims else 0
if self._keep_dims:
return shape
return [dim for dim in shape if dim > 0]
def reduce_sum(node, axis, keep_dims=False):
# TODO: distinguish contiguous vs. non contiguous reduction
return Array(ReduceSum(node, axis, keep_dims))
def tanh(x):
return Array(Tanh(x))
class MatMul(Node):
def __init__(self, left, right):
self._left = left
self._right = right
def arguments(self):
return [self._left, self._right]
def shape(self):
return (self._left._expression.shape()[0],
self._right._expression.shape()[1])
def dtype(self):
return type_promotion(self._left.dtype, self._right.dtype)
def calc_pad(pad, in_siz, out_siz, stride, ksize):
"""Calculate padding width.
Args:
        pad: padding method, "SAME", "VALID", or a manually specified integer.
ksize: kernel size [I, J].
Returns:
pad_: Actual padding width.
"""
if pad == 'SAME':
return int((out_siz - 1) * stride + ksize - in_siz)
elif pad == 'VALID':
return 0
else:
return pad
def calc_size(h, kh, pad, sh):
"""Calculate output image size on one dimension.
Args:
h: input image size.
kh: kernel size.
pad: padding strategy.
sh: stride.
Returns:
s: output size.
"""
if pad == 'VALID':
return int(np.ceil((h - kh + 1) / sh))
elif pad == 'SAME':
return int(np.ceil(h / sh))
else:
return int(np.ceil((h - kh + pad + 1) / sh))
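# Worked example (sketch): for h=10, ksize=3, stride=1 and "SAME" padding, calc_size gives
# ceil(10 / 1) = 10 and calc_pad gives (10 - 1) * 1 + 3 - 10 = 2 (split as 1 before and
# 1 after); with "VALID" padding the output size is ceil((10 - 3 + 1) / 1) = 8 and pad is 0.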
def extract_sliding_windows(x, ksize, padding, strides, floor_first=True, out=None):
"""Converts a tensor to sliding windows.
Args:
x: [N, H, W, C]
k: [KH, KW]
pad: [PH, PW]
strides: [NBATCH, SH, SW, NCHANNELS]
Returns:
y: [N, (H-KH+PH+1)/SH, (W-KW+PW+1)/SW, KH * KW, C]
"""
n = x.shape[0]
h = x.shape[1]
w = x.shape[2]
c = x.shape[3]
kh = ksize[0]
kw = ksize[1]
sh = strides[1]
sw = strides[2]
h2 = calc_size(h, kh, padding, sh)
w2 = calc_size(w, kw, padding, sw)
ph = calc_pad(padding, h, h2, sh, kh)
pw = calc_pad(padding, w, w2, sw, kw)
ph0 = int(np.floor(ph / 2))
ph1 = int(np.ceil(ph / 2))
pw0 = int(np.floor(pw / 2))
pw1 = int(np.ceil(pw / 2))
if floor_first:
pph = (ph0, ph1)
ppw = (pw0, pw1)
else:
pph = (ph1, ph0)
ppw = (pw1, pw0)
x = np.pad(
x, ((0, 0), pph, ppw, (0, 0)),
mode='constant',
constant_values=(0.0, ))
if out is None:
out = np.zeros([n, h2, w2, kh, kw, c], dtype=x.dtype)
for ii in range(h2):
for jj in range(w2):
xx = ii * sh
yy = jj * sw
out[:, ii, jj, :, :, :] = x[:, xx:xx + kh, yy:yy + kw, :]
return out
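# Example (sketch): x of shape [1, 4, 4, 1] with ksize (2, 2), "VALID" padding and
# strides (1, 1, 1, 1) produces windows of shape [1, 3, 3, 2, 2, 1]; Im2col.shape()
# below mirrors this computation symbolically.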
class Im2col(Node):
def __init__(self, input, filter_size, strides, padding, data_format):
self._input = input
self._filter_size = filter_size
self._strides = strides
self._padding = padding
self._data_format = data_format
def arguments(self):
return [self._input]
def shape(self):
ksize = self._filter_size
x = self._input
n = x.shape[0]
h = x.shape[1]
w = x.shape[2]
c = x.shape[3]
kh = ksize[0]
kw = ksize[1]
sh = self._strides[1]
sw = self._strides[2]
h2 = calc_size(h, kh, self._padding, sh)
w2 = calc_size(w, kw, self._padding, sw)
ph = calc_pad(self._padding, h, h2, sh, kh)
pw = calc_pad(self._padding, w, w2, sw, kw)
ph0 = int(np.floor(ph / 2))
ph1 = int(np.ceil(ph / 2))
pw0 = int(np.floor(pw / 2))
pw1 = int(np.ceil(pw / 2))
return (n, h2, w2, kh, kw, c)
def dtype(self):
return self._input.dtype
class Im2colImpl(Computation):
def run(self):
op = self._op._expression
extract_sliding_windows(op._input._expression.data(), op._filter_size,
padding=op._padding,
strides=op._strides,
out=self._target._expression.data())
register_implementation(Im2col, lambda x: Im2colImpl)
class Conv2D(Node):
def __init__(self, input, filter, strides, padding, data_format):
self._input = input
self._filter = filter
self._strides = strides
self._padding = padding
self._data_format = data_format
def arguments(self):
return [self._input, self._filter]
def shape(self):
return (self._input.shape[0],
calc_size(self._input.shape[1], self._filter.shape[0], self._padding, self._strides[1]),
calc_size(self._input.shape[2], self._filter.shape[1], self._padding, self._strides[2]),
self._filter.shape[3])
def dtype(self):
return type_promotion(self._input.dtype, self._filter.dtype)
class Conv2DImpl(Computation):
def run(self):
w = self._op._expression._filter._expression.data()
x = self._op._expression._input._expression.data()
ksize = w.shape[:2]
x = extract_sliding_windows(x, ksize,
padding=self._op._expression._padding,
strides=self._op._expression._strides)
ws = w.shape
w = w.reshape([ws[0] * ws[1] * ws[2], ws[3]])
xs = x.shape
x = x.reshape([xs[0] * xs[1] * xs[2], xs[3] * xs[4] * xs[5]])
out = self._target._expression.data()
out = out.reshape([x.shape[0], w.shape[1]])
gemm(x, w, out, alpha=1.0, beta=0.0)
register_implementation(Conv2D, lambda x: Conv2DImpl)
def conv2d(input, filter, strides, padding, data_format="NHWC"):
assert(data_format == "NHWC")
# test for dimensions here...
return Array(Conv2D(input, filter, strides, padding, data_format))
def im2col(input, kernel_size, strides, padding, data_format="NHWC"):
assert(data_format == "NHWC")
return Array(Im2col(input, kernel_size, strides, padding, data_format))
def im2col_conv2d(input, filter, strides, padding, data_format="NHWC"):
ksize = filter.shape[:2]
patches = im2col(input, ksize, strides, padding, data_format)
patches_2d = reshape(patches, (patches.shape[0] * patches.shape[1] * patches.shape[2],
patches.shape[3] * patches.shape[4] * patches.shape[5]))
filter_2d = reshape(filter, (filter.shape[0] * filter.shape[1] * filter.shape[2],
filter.shape[3]))
output_2d = dot(patches_2d, filter_2d)
return reshape(output_2d, (input.shape[0], patches.shape[1],
patches.shape[2], filter.shape[3]))
def gemm(a, b, c, alpha, beta):
"""Reference implementation for actual BLAS gemm
Note: does nothing special with transposes etc..."""
if beta == 0.0:
np.matmul(a, b, c)
c *= alpha
else:
c[:] = c * beta + np.matmul(a, b) * alpha
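# Example (sketch): beta=0.0 overwrites the destination (used for "="), while
# alpha=-1.0, beta=1.0 computes c -= a @ b, which is how MatMulImpl below maps the
# "-=" assignment operator onto a single gemm call.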
class MatMulImpl(Computation):
def _get_alpha(self):
return -1.0 if self._operator == "-=" else 1.0
def _get_beta(self):
return 0.0 if self._operator == "=" else 1.0
def run(self):
gemm(self._op._expression._left._expression.data(),
self._op._expression._right._expression.data(),
self._target._expression.data(),
alpha=self._get_alpha(),
beta=self._get_beta())
class IMatMulImpl(MatMulImpl):
def run(self):
gemm(self._op._expression._left._expression.data(),
self._op._expression._right._expression.data(),
self._target._expression.data(),
alpha=int(self._get_alpha()),
beta=int(self._get_beta()))
def choose_matmul(x):
if x.dtype == np.float32 or x.dtype == np.float64:
return MatMulImpl
elif x.dtype == np.int32 or x.dtype == np.int64:
return IMatMulImpl
else:
raise ValueError("no implementation found.")
register_implementation(MatMul, choose_matmul)
class Dimshuffle(Node):
def __init__(self, node, axes):
self._node = node
self._axes = axes
def arguments(self):
return [self._node]
def shape(self):
original_shape = self._node._expression.shape()
return tuple([original_shape[i] for i in self._axes])
def dtype(self):
return self._node._expression.dtype()
class DimshuffleImpl(Computation):
def run(self):
print("dimshuffle...")
assign_with_operator(self._target._expression.data(),
self._operator,
self._op._expression._node._expression.data().transpose(self._op._expression._axes))
register_implementation(Dimshuffle, lambda x: DimshuffleImpl)
def dimshuffle(node, axes):
if isinstance(node._expression, Buffer):
return Array(node._expression.dimshuffle(axes))
for i, ax in enumerate(axes):
if i != ax:
return Array(Dimshuffle(node, axes))
return node
def transpose(node, axes=None):
if axes is None:
axes = list(reversed(range(node._expression.ndim())))
return dimshuffle(node, axes)
class Reshape(Node):
def __init__(self, node, shape):
self._node = node
self._shape = shape
def arguments(self):
return [self._node]
def shape(self):
return self._shape
def dtype(self):
return self._node._expression.dtype()
class ReshapeImpl(Computation):
def run(self):
print("reshape...")
assign_with_operator(self._target._expression.data(),
self._operator,
self._op._expression._node._expression.data().reshape(self._op._expression._shape))
register_implementation(Reshape, lambda x: ReshapeImpl)
def reshape(node, shape):
if tuple(node._expression.shape()) == tuple(shape):
return node
if isinstance(node._expression, Buffer) and can_copyless_reshape(node, shape):
return Array(node._expression.reshape(shape))
return Array(Reshape(node, shape))
class Identity(JITNode):
def __init__(self, node):
self._node = node
def arguments(self):
return [self._node]
def shape(self):
return self._node._expression.shape()
def dtype(self):
return self._node._expression.dtype()
def ascontiguousarray(array):
buffer_node = buffer_arg(array)
if buffer_node is None:
return ascontiguousarray(to_assignment(array))
elif buffer_node._expression.contiguous_memory():
return array
else:
return Array(Identity(array))
def identity(array):
return ascontiguousarray(array)
### TENSORDOT ###
def check_tensordot_reduce_axes(operand_shape,
name,
reduce_axes,
batched):
# Do not reduce over more dimensions than operand_shape.size().
if len(reduce_axes) > len(operand_shape):
raise ValueError(("length of argument {name}_reduce_axes "
"should be less than the dimensions of {name}"
" ({name}.ndim()={operand_shape}"
", {name}_reduce_axes.size()={size}).").format(
name=name,
operand_shape=operand_shape,
size=len(reduce_axes)))
# all reduction axes must be less than operand_shape.size()
    max_reduce_dim = max(reduce_axes) if len(reduce_axes) > 0 else -1
if not (len(reduce_axes) == 0 or max_reduce_dim < len(operand_shape)):
raise ValueError(("{name}_reduce_axes contains reduction dimensions "
" that are greater than or equal to "
"{name}.ndim() ("
"{name}.ndim()={size}"
", and found max({name}_reduce_axes)"
"={max_reduce_dim}).").format(
name=name,
size=len(operand_shape),
max_reduce_dim=max_reduce_dim))
if batched and 0 in reduce_axes:
raise ValueError(("axes to sum over must not contain the batch axis "
"({name}_reduce_axes={reduce_axes}).").format(
name=name, reduce_axes=reduce_axes))
def tensordot_nonreduced_axes(ndim, reduce_axes, batched):
"""Returns all the axes that are not being reduced."""
other_axes = []
for x in range(0, ndim):
# when batched, 0 is always kept
# as leading dim, and thus will not
# be dimshuffled
if batched and x == 0:
continue
if x not in reduce_axes:
other_axes.append(x)
return other_axes
def matrix_multiply_with_reshape(a, b, output_shape, output_shape_2d):
if a._expression.ndim() != 2:
raise ValueError("a must have ndim=2")
if b._expression.ndim() != 2:
raise ValueError("b must have ndim=2")
left = output_shape_2d[0]
middle = max(a._expression.shape()[1], b._expression.shape()[0])
right = output_shape_2d[1]
# if the broadcasting fails let ReshapedMatrixMultiplyFunction
# throw an error.
new_a = ascontiguousarray(reshape(a, (left, middle)))
new_b = ascontiguousarray(reshape(b, (middle, right)))
return reshape(Array(MatMul(new_a, new_b)), output_shape)
def tensordot_as_dot(a, b, a_reduce_axes=None, b_reduce_axes=None,
batched=False, axis=None):
# This code follows the logic from theano's tensordot as dot
# [source https://github.com/Theano/Theano/blob/master/theano/tensor/basic.py#L5628]
    # The Theano code was itself originally based on
    # Tijmen Tieleman's gnumpy:
# [source http://www.cs.toronto.edu/~tijmen/gnumpy.html]
# if 'axes' is a single number of axes to multiply and sum over
# (trailing axes of a, leading axes of b), we can just reshape
# and use dot.
# validate that the axis used for summing
# is not out of bounds for the arguments a and b
if axis is not None:
if axis < 0:
raise ValueError(("axis must be a non-negative "
"integer (got {axis}).").format(axis=axis))
for i in range(0, 2):
operand = a if i == 0 else b
operand_name = "a" if i == 0 else "b"
if axis > operand._expression.ndim():
raise ValueError(("axis can not be larger than the dimension of "
"{name} ({name}.ndim()={ndim}, axis={axis}).").format(
axis=axis, name=operand_name, ndim=operand._expression.ndim()))
if axis == operand._expression.ndim() and batched:
raise ValueError(("axis to sum over must not include the batch axis "
"of {name} ({name}.ndim()={ndim}, axis={axis}).").format(
                        name=operand_name, axis=axis, ndim=operand._expression.ndim()))
batch_axes = 1 if batched else 0
a_shape, b_shape = [1, 1], [1, 1]
a_old_shape = a._expression.shape()
b_old_shape = b._expression.shape()
# compute total size of summed axes
for i in range(0, axis):
a_shape[1] *= a_old_shape[len(a_old_shape) - (i + 1)]
b_shape[0] *= b_old_shape[batch_axes + i]
# compute total size of other axes
for i in range(0, a._expression.ndim() - axis - batch_axes):
a_shape[0] *= a_old_shape[batch_axes + i]
for i in range(0, b._expression.ndim() - axis - batch_axes):
b_shape[1] *= b_old_shape[len(b_old_shape) -(i + 1)]
if batched:
a_shape.insert(0, a_old_shape[0])
b_shape.insert(0, b_old_shape[0])
output_shape = a_old_shape[:len(a_old_shape) - axis] + b_old_shape[batch_axes + axis:]
return matrix_multiply_with_reshape(
reshape(a, a_shape),
reshape(b, b_shape),
output_shape,
(a_shape[0], b_shape[1]))
else:
if a_reduce_axes is None or b_reduce_axes is None:
raise ValueError("a_reduce_axes and b_reduce_axes must "
"not be None if axis is None.")
if len(a_reduce_axes) != len(b_reduce_axes):
raise ValueError(("must have as many reduction axes for a than b "
"(got a_reduce_axes=%r and "
"b_reduce_axes=%r).") % (a_reduce_axes,
b_reduce_axes))
check_tensordot_reduce_axes(a._expression.shape(), "a", a_reduce_axes, batched)
check_tensordot_reduce_axes(b._expression.shape(), "b", b_reduce_axes, batched)
a_new_axes = tensordot_nonreduced_axes(
a._expression.ndim(), a_reduce_axes, batched)
b_new_axes = tensordot_nonreduced_axes(
b._expression.ndim(), b_reduce_axes, batched)
# for A: add reduction axis at the end of shape
a_new_axes.extend(a_reduce_axes)
# for B: add reduction axis at the beginning of shape
b_new_axes = b_reduce_axes + b_new_axes
if batched:
a_new_axes.insert(0, 0)
b_new_axes.insert(0, 0)
return tensordot_as_dot(dimshuffle(a, a_new_axes),
dimshuffle(b, b_new_axes),
axis=len(a_reduce_axes),
batched=batched)
def ascontiguousarray_or_simple_transpose(node):
"""Gemms support transposed matrix multiplies, but strided
memory generally is unsupported."""
buff = buffer_arg(node)
if buff is not None and (buff._expression.contiguous_memory() or buff._expression.is_transpose()):
return node
return ascontiguousarray(node)
def dot(a, b):
a_ndim, b_ndim = a._expression.ndim(), b._expression.ndim()
if a_ndim == 2 and b_ndim == 2:
a = ascontiguousarray_or_simple_transpose(a)
b = ascontiguousarray_or_simple_transpose(b)
return Array(MatMul(a, b))
elif a_ndim > 2 or b_ndim > 2:
return tensordot_as_dot(a, b,
a_reduce_axes=[a_ndim-1],
b_reduce_axes=[b_ndim-2])
else:
raise ValueError("dot not implemented for a.ndim = %r, b.ndim = %r" % (
a_ndim, b_ndim))
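# Example (mirrors the tensordot test in main()): dot of a (2, 3, 3) operand with a
# (3, 3) operand reduces the last axis of the first against the first axis of the second
# via tensordot_as_dot, producing a (2, 3, 3) result just like np.dot.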
def expect_result(op, expected):
return np.testing.assert_allclose(op.value(), expected)
def np_reduce_sum(array, axis, keep_dims):
for ax in axis:
array = np.expand_dims(array.sum(ax), ax)
if not keep_dims:
array = np.squeeze(array, axis)
return array
def tf_conv2d(input, filter, strides, padding, session=None):
if session is None:
session = tf.InteractiveSession()
return session.run(tf.nn.conv2d(input, filter, strides=strides, padding=padding))
def main():
## arrays:
m3x3 = np.ones((3, 3))
z3x3 = np.zeros((3, 3))
z4x3 = np.zeros((4, 3))
array = np.arange(12).reshape((3, 4)).astype(np.float32)
array_strided = np.zeros((3, 2, 4))[:, 0, :]
array_strided[:] = array
# Additions:
expect_result(add(add(buffer(m3x3), buffer(m3x3)),
add(buffer(m3x3), buffer(m3x3))),
m3x3 + m3x3 + m3x3 + m3x3)
expect_result(add(buffer(m3x3), add(buffer(m3x3), buffer(m3x3))),
m3x3 + m3x3 + m3x3)
expect_result(add(buffer(m3x3), buffer(m3x3)), m3x3 + m3x3)
# Additions & Tanh
expect_result(add(tanh(buffer(m3x3)), tanh(buffer(m3x3))),
np.tanh(m3x3) + np.tanh(m3x3))
expect_result(add(tanh(buffer(array)), tanh(buffer(array))),
np.tanh(array) + np.tanh(array))
expect_result(tanh(buffer(array)), np.tanh(array))
# GEMM
# mix of matrix multiply + elementwise
expect_result(tanh(dot(buffer(array), buffer(array).T)),
np.tanh(np.dot(array, array.T)))
# GEMM sum:
expect_result(add(dot(buffer(m3x3), buffer(m3x3)),
dot(buffer(m3x3), buffer(m3x3))),
np.dot(m3x3, m3x3) + np.dot(m3x3, m3x3))
# CONV op:
batch_size = 128
height = 10
width = 10
channels = 3
out_channels = 10
strides = (1, 1, 1, 1)
padding = "SAME"
x = buffer(np.ones((batch_size, height, width, channels)).astype(np.float32))
w = buffer(np.arange(height * width * channels * out_channels).astype(np.float32).reshape((height, width, channels, out_channels)))
expect_result(conv2d(x, w, strides=strides, padding=padding),
tf_conv2d(x._expression.data(), w._expression.data(),
strides=strides, padding=padding))
expect_result(im2col_conv2d(x, w, strides=strides, padding=padding),
tf_conv2d(x._expression.data(), w._expression.data(),
strides=strides, padding=padding))
## Tensor dots:
expect_result(dot(buffer(np.arange(18).reshape((2, 3, 3))),
buffer(np.arange(9).reshape((3, 3)))),
np.dot(np.arange(18).reshape((2, 3, 3)),
np.arange(9).reshape((3, 3))))
## Transposes:
# the transpose is compatible with gemm, no copy:
expect_result(dot(buffer(array), buffer(array).T),
np.dot(array, array.T))
# the strided nature forces a copy before a gemm:
expect_result(dot(buffer(array), buffer(array_strided).T),
np.dot(array, array_strided.T))
## Autoreduction:
autoreduce_assign_dest = np.zeros((3, 2, 1, 2, 1))
autoreduce_assign_source = np.ones((3, 2, 10, 2, 10))
expect_result(assign(buffer(autoreduce_assign_dest), "<<=",
buffer(autoreduce_assign_source)),
np_reduce_sum(autoreduce_assign_source, (2, 4), True))
expect_result(reduce_sum(buffer(autoreduce_assign_source), (2, 4), False),
np_reduce_sum(autoreduce_assign_source, (2, 4), False))
# create a storage location with data:
a = buffer(m3x3)
# now subtract from that location
c = assign(a, "-=", buffer(m3x3))
# look at the data before evaluation:
expect_result(a, m3x3)
# calling the assignment changes a's value:
c.eval()
expect_result(a, m3x3 - m3x3)
## Canonicalization:
buff = buffer(array)
ops = [(identity(identity(identity(buff))), buff),
(identity(identity(identity(transpose(buff)))),
assign(buffer(z4x3), "=", Array(JITRunner(identity(buff.T), [buff.T])))),
(assign(buffer(m3x3), "*=", dot(buffer(m3x3), buffer(m3x3))),
Array(Assignment(buffer(m3x3), "*=", Array(Assignment(buffer(z3x3), "=", dot(buffer(m3x3), buffer(m3x3))))))),
(assign(buffer(m3x3), "+=", dot(buffer(m3x3), buffer(m3x3))),
Array(Assignment(buffer(m3x3), "+=", dot(buffer(m3x3), buffer(m3x3)))))]
for op, proposed in ops:
assert(op.canonical() == proposed), (str(op.canonical()), str(proposed))
if __name__ == "__main__":
main()