Skip to content

Instantly share code, notes, and snippets.

@kvchen
Created December 4, 2016 02:28
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kvchen/08c856b1fce0bdd737045ef709447ebb to your computer and use it in GitHub Desktop.
Save kvchen/08c856b1fce0bdd737045ef709447ebb to your computer and use it in GitHub Desktop.
@ray.remote(num_return_vals=3)
def block_lu(a, block_size=100):
    """Return the LU decomposition of a matrix using recursive blocking.

    The top-left ``block_size x block_size`` block is factored directly.
    The off-diagonal blocks of ``l`` and ``u`` are derived from it, and the
    Schur complement of the bottom-right block is factored by a recursive
    ``block_lu`` task. The factors are assumed to follow the
    ``ra.linalg.lu`` convention ``a = p . l . u`` (scipy's convention --
    TODO confirm for ``ra``). The algebra below relies on that ordering.

    Rows are only permuted within each diagonal block, never across blocks.
    This is therefore not full partial pivoting. If the top-left block is
    singular, ``u11`` cannot be inverted, even when ``a`` itself is
    nonsingular.

    Parameters
    ----------
    a : ndarray
        Matrix to decompose (intended to be square). ``a.shape`` and
        slicing are used, so an array object is required.
    block_size : int, optional
        Edge length of the top-left block factored at each recursion
        level. Matrices with either dimension ``<= block_size`` are
        factored directly with ``ra.linalg.lu``. The same value is passed
        down to the recursive call.

    Returns
    -------
    p : ndarray
        Block-diagonal combination of the per-block permutation factors.
    l : ndarray
        Lower-triangular factor.
    u : ndarray
        Upper-triangular factor.
    """
    import numpy as np

    n_rows, n_cols = a.shape
    if n_rows <= block_size or n_cols <= block_size:
        # Small enough to factor in one task.
        return ray.get(ra.linalg.lu.remote(a))

    k = block_size
    a11 = a[:k, :k]
    a12 = a[:k, k:]
    a21 = a[k:, :k]
    a22 = a[k:, k:]

    # Factor the top-left block on a single node: a11 = p11 . l11 . u11.
    p11, l11, u11 = ray.get(ra.linalg.lu.remote(a11))

    # The unpacking order must match the order of the remote calls.
    p11_inverse, l11_inverse, u11_inverse = ray.get([
        ra.linalg.inv.remote(p11),
        ra.linalg.inv.remote(l11),
        ra.linalg.inv.remote(u11),
    ])

    # a12 = p11 . l11 . u12  =>  u12 = l11^-1 . p11^-1 . a12
    # a21 = l21 . u11        =>  l21 = a21 . u11^-1
    # TODO(kvchen): Change this to use distributed block multiplication
    u12, l21 = ray.get([
        ra.dot.remote(l11_inverse, ra.dot.remote(p11_inverse, a12)),
        ra.dot.remote(a21, u11_inverse),
    ])

    # Recurse on the Schur complement s = a22 - l21 . u12 = ps . ls . us,
    # keeping the caller's block size at every level.
    s = a22 - ray.get(ra.dot.remote(l21, u12))
    ps, ls, us = ray.get(block_lu.remote(s, block_size))

    # Assemble the full factors:
    #   p = [[p11, 0], [0, ps]]
    #   l = [[l11, 0], [ps^-1 . l21, ls]]
    #   u = [[u11, u12], [0, us]]
    # The bottom-left block of l is ps^-1 . l21 because p . l must reproduce
    # l21 in that position (a21 = l21 . u11).
    inner = k + ls.shape[1]

    perm = np.zeros((n_rows, n_rows), dtype=np.result_type(p11, ps))
    perm[:k, :k] = p11
    perm[k:, k:] = ps

    lower = np.zeros((n_rows, inner), dtype=np.result_type(l11, l21, ls))
    lower[:k, :k] = l11
    lower[k:, :k] = np.linalg.solve(ps, l21)
    lower[k:, k:] = ls

    upper = np.zeros((inner, n_cols), dtype=np.result_type(u11, u12, us))
    upper[:k, :k] = u11
    upper[:k, k:] = u12
    upper[k:, k:] = us

    return perm, lower, upper
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment