Created
December 4, 2016 02:28
-
-
Save kvchen/08c856b1fce0bdd737045ef709447ebb to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@ray.remote(num_return_vals=3)
def block_lu(a, block_size=100):
    """Returns the block LU decomposition of a square matrix.

    The matrix is split into quadrants around its leading
    ``block_size`` x ``block_size`` block::

        a = [[a11, a12],
             [a21, a22]]

    The leading block is factored directly, the off-diagonal blocks are
    derived from that factorization, and the function recurses on the
    Schur complement ``s = a22 - a21 * inv(a11) * a12``. The pieces are
    then assembled into full-size factors.

    NOTE(review): the assembly assumes ``ra.linalg.lu`` returns factors
    such that ``a11 = p11 . l11 . u11`` (the same convention the ``u12``
    formula below relies on) and that ``a`` is square -- confirm against
    the ``ra.linalg.lu`` implementation.

    Parameters
    ----------
    a : array_like
        Square matrix to decompose.
    block_size : int, optional
        Side length of the leading block factored at each level. Matrices
        with either dimension at or below this size are factored directly.

    Returns
    -------
    p : array_like
        Block-diagonal matrix built from the per-level ``p`` factors.
    l : array_like
        Block lower-triangular factor.
    u : array_like
        Block upper-triangular factor.
    """
    # Local import: the module-level imports are not visible in this view.
    import numpy as np

    # Base case: small enough to factor in a single task.
    if a.shape[0] <= block_size or a.shape[1] <= block_size:
        return ray.get(ra.linalg.lu.remote(a))

    a11 = a[:block_size, :block_size]
    a12 = a[:block_size, block_size:]
    a21 = a[block_size:, :block_size]
    a22 = a[block_size:, block_size:]

    # Find the LU decomposition of the top-left quadrant on a single node
    p11, l11, u11 = ray.get(ra.linalg.lu.remote(a11))

    # Compute the inverses for each of the components of the top-left.
    # The request order must match the unpacking order.
    p11_inverse, l11_inverse, u11_inverse = ray.get([
        ra.linalg.inv.remote(p11),
        ra.linalg.inv.remote(l11),
        ra.linalg.inv.remote(u11),
    ])

    # TODO(kvchen): Change this to use distributed block multiplication
    # u12 = inv(l11) . inv(p11) . a12   and   l21 = a21 . inv(u11)
    u12, l21 = ray.get([
        ra.dot.remote(l11_inverse, ra.dot.remote(p11_inverse, a12)),
        ra.dot.remote(a21, u11_inverse),
    ])

    # Recurse on the Schur complement of the top-left quadrant, keeping the
    # caller's block size at every level.
    s = a22 - ray.get(ra.dot.remote(l21, u12))
    ps, ls, us = ray.get(block_lu.remote(s, block_size))

    # s = ps . ls . us, so the lower-left block of L has to absorb the
    # inverse of ps. Assumes ps is a permutation matrix, whose inverse is
    # its transpose.
    l21 = np.dot(ps.T, l21)

    # Assemble the full-size factors; blocks not assigned below stay zero.
    p = np.zeros(a.shape, dtype=p11.dtype)
    p[:block_size, :block_size] = p11
    p[block_size:, block_size:] = ps

    l = np.zeros(a.shape, dtype=l11.dtype)
    l[:block_size, :block_size] = l11
    l[block_size:, :block_size] = l21
    l[block_size:, block_size:] = ls

    u = np.zeros(a.shape, dtype=u11.dtype)
    u[:block_size, :block_size] = u11
    u[:block_size, block_size:] = u12
    u[block_size:, block_size:] = us

    return p, l, u
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment