Created July 12, 2016 22:34
-
-
Save juancamilog/0450e044e0ad297891fc67f700a0ee56 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _is_dot(var):
    """Return True if ``var`` is the output of a ``Dot`` or ``Dot22`` node."""
    return var.owner is not None and isinstance(var.owner.op, (Dot, Dot22))


def _leftmost_operand(var):
    """Follow left inputs through nested Dot/Dot22 nodes; return the first non-Dot variable."""
    while _is_dot(var):
        var = var.owner.inputs[0]
    return var


def _rightmost_operand(var):
    """Follow right inputs through nested Dot/Dot22 nodes; return the first non-Dot variable."""
    while _is_dot(var):
        var = var.owner.inputs[1]
    return var


def _drop_leftmost_operand(var):
    """Rebuild the Dot/Dot22 chain ``var`` without its leftmost operand.

    ``var`` must be a Dot/Dot22 output. Each rebuilt node reuses the op of
    the node it replaces.
    """
    lhs, rhs = var.owner.inputs
    if not _is_dot(lhs):
        return rhs
    return var.owner.op(_drop_leftmost_operand(lhs), rhs)


def _drop_rightmost_operand(var):
    """Rebuild the Dot/Dot22 chain ``var`` without its rightmost operand.

    ``var`` must be a Dot/Dot22 output. Each rebuilt node reuses the op of
    the node it replaces.
    """
    lhs, rhs = var.owner.inputs
    if not _is_dot(rhs):
        return lhs
    return var.owner.op(lhs, _drop_rightmost_operand(rhs))


@register_stabilize
@local_optimizer([Dot, Dot22])
def inv_as_solve(node):
    """Rewrite Dot/Dot22 products involving ``matrix_inverse``.

    For a node computing ``left . right``:

    * ``inv(A) . right``: if the leftmost factor of the Dot chain ``right``
      is ``A`` itself, cancel ``inv(A) . A`` and return the remaining chain.
      Otherwise return ``solve(A, right)``.
    * ``left . inv(A)``: if the rightmost factor of the Dot chain ``left``
      is ``A``, cancel ``A . inv(A)`` and return the remaining chain.
      Otherwise return ``solve(A, left.T).T``, or ``solve(A.T, left.T).T``
      when ``A`` is not hinted as symmetric.

    A cancelled chain is only returned when its type equals the node's
    output type; otherwise the ``solve`` rewrite is used instead.

    Returns:
        A one-element list with the replacement variable; ``False`` when
        ``imported_scipy`` is falsy; ``None`` when no rewrite applies.
    """
    if not imported_scipy:
        # presumably the solve() rewrite depends on scipy being importable
        # -- TODO confirm against where `imported_scipy` is set
        return False
    if not isinstance(node.op, (Dot, Dot22)):
        return None

    out_type = node.outputs[0].type
    left, right = node.inputs

    if left.owner and left.owner.op == matrix_inverse:
        a = left.owner.inputs[0]
        # inv(A) . (A . X) == X for any invertible A.  The solve() rewrite
        # below already assumes A is invertible, so no PSD hint is required.
        if _is_dot(right) and _leftmost_operand(right) is a:
            reduced = _drop_leftmost_operand(right)
            # A mismatched broadcast pattern would make the replacement
            # invalid; fall back to solve() in that case.
            if reduced.type == out_type:
                return [reduced]
        return [solve(a, right)]

    if right.owner and right.owner.op == matrix_inverse:
        a = right.owner.inputs[0]
        # (X . A) . inv(A) == X for any invertible A.
        if _is_dot(left) and _rightmost_operand(left) is a:
            reduced = _drop_rightmost_operand(left)
            if reduced.type == out_type:
                return [reduced]
        # X . inv(A) == solve(A^T, X^T)^T; A^T == A when A is symmetric.
        if is_symmetric(a):
            return [solve(a, left.T).T]
        return [solve(a.T, left.T).T]

    return None
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.