Skip to content

Instantly share code, notes, and snippets.

@juancamilog
Created July 12, 2016 22:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save juancamilog/0450e044e0ad297891fc67f700a0ee56 to your computer and use it in GitHub Desktop.
Save juancamilog/0450e044e0ad297891fc67f700a0ee56 to your computer and use it in GitHub Desktop.
def _is_dot_output(v):
    """Return True if ``v`` is the output of a Dot or Dot22 node."""
    return v.owner is not None and isinstance(v.owner.op, (Dot, Dot22))


def _outermost_dot_operand(v, side):
    """Follow a chain of Dot/Dot22 nodes down one side and return its end.

    ``side`` is 0 to follow left inputs (the leftmost factor of the product)
    or 1 to follow right inputs (the rightmost factor). A variable that is
    not a dot output is returned unchanged.
    """
    while _is_dot_output(v):
        v = v.owner.inputs[side]
    return v


def _drop_outermost_dot_operand(v, side):
    """Rebuild the dot chain ``v`` without its outermost factor on ``side``.

    The dropped factor is exactly the one ``_outermost_dot_operand(v, side)``
    returns. ``v`` must itself be a dot output. Every rebuilt node reuses the
    op (Dot or Dot22) of the node it replaces.
    """
    inputs = list(v.owner.inputs)
    if _is_dot_output(inputs[side]):
        inputs[side] = _drop_outermost_dot_operand(inputs[side], side)
        return v.owner.op(*inputs)
    # The factor on `side` is the outermost one; the other input remains.
    return inputs[1 - side]


@register_stabilize
@local_optimizer([Dot, Dot22])
def inv_as_solve(node):
    """Replace products involving an explicit matrix inverse with solves.

    Rewrites, for a Dot/Dot22 ``node``:

    * ``dot(inv(A), B)`` -> ``solve(A, B)``
    * ``dot(B, inv(A))`` -> ``solve(A, B.T).T`` when A carries a symmetric
      hint, otherwise ``solve(A.T, B.T).T``

    When A carries a PSD hint and A is also the outermost factor of the other
    operand's dot chain, the inverse cancels instead. For example,
    ``inv(A) . (A . B . C)`` becomes ``B . C``.

    Returns False when scipy is unavailable, a one-element replacement list
    when a rewrite applies, and None otherwise.
    """
    if not imported_scipy:
        return False
    if isinstance(node.op, (Dot, Dot22)):
        lhs, rhs = node.inputs
        out_type = node.outputs[0].type

        if lhs.owner and lhs.owner.op == matrix_inverse:
            a = lhs.owner.inputs[0]
            # Cancellation is gated on a PSD hint, as the original draft
            # intended. Algebraically inv(A).A = I for any invertible A.
            # NOTE(review): confirm whether this gate can be relaxed.
            if (is_psd(a) and _is_dot_output(rhs)
                    and _outermost_dot_operand(rhs, 0) is a):
                cancelled = _drop_outermost_dot_operand(rhs, 0)
                # A replacement must match the node's output type exactly.
                # Dropping A can change the upcast dtype or broadcast
                # pattern, so fall back to the solve form when it does.
                if cancelled.type == out_type:
                    return [cancelled]
            return [solve(a, rhs)]

        if rhs.owner and rhs.owner.op == matrix_inverse:
            a = rhs.owner.inputs[0]
            # Mirror image: B . inv(A) where A is the rightmost factor of lhs.
            if (is_psd(a) and _is_dot_output(lhs)
                    and _outermost_dot_operand(lhs, 1) is a):
                cancelled = _drop_outermost_dot_operand(lhs, 1)
                if cancelled.type == out_type:
                    return [cancelled]
            # B . inv(A) == (inv(A).T . B.T).T == solve(A.T, B.T).T,
            # and A.T == A when A is known to be symmetric.
            if is_symmetric(a):
                return [solve(a, lhs.T).T]
            else:
                return [solve(a.T, lhs.T).T]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment