Skip to content

Instantly share code, notes, and snippets.

@kvchen
Created December 4, 2016 22:39
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kvchen/c1b1f021f1d8775565be81928f5c4e14 to your computer and use it in GitHub Desktop.
@ray.remote
def lu_decomp_invert(lu_decomp):
    """Invert each factor of a (P, L, U) decomposition.

    Helper task for the block-level LU decomposition: given the tuple of
    factors produced by a single-node LU, return the tuple of their
    matrix inverses, in the same order.
    """
    inverted = map(np.linalg.inv, lu_decomp)
    return tuple(inverted)
@ray.remote(num_return_vals=3)
def block_lu(a, block_size=100):
    """Returns the LU decomposition of a square matrix.

    Blocked algorithm: peel off a leading block_size x block_size block at a
    time, LU-factor each block remotely, and stitch the per-block factors into
    full-size P, L, U arrays, updating the trailing submatrix with its Schur
    complement at each step.

    Parameters
    ----------
    a : array_like
        Square matrix to factor. Assumes a 2-D NumPy array whose leading
        blocks are invertible — TODO confirm; np.linalg.inv(block) below
        raises otherwise.
    block_size : int, optional
        Side length of the square diagonal blocks (default 100).

    Returns
    -------
    p : array_like
    l : array_like
    u : array_like
    """
    # Base case: small enough for a single-node LU (ra is presumably
    # ray.array's distributed linalg module — NOTE(review): verify import).
    if a.shape[0] <= block_size or a.shape[1] <= block_size:
        return ray.get(ra.linalg.lu.remote(a))

    p, l, u = np.zeros(a.shape), np.zeros(a.shape), np.zeros(a.shape)
    num_blocks = int(np.ceil(float(a.shape[0]) / block_size))

    # Compute all single-node LU decompositions in parallel.
    # Each iteration submits the LU of the current leading block, then
    # locally replaces a_modified with the Schur complement
    # a22 - a21 @ inv(block) @ a12 so the next block comes from the
    # updated trailing submatrix. Note inv(block) is computed serially
    # here because the remote LU result is not yet available.
    block_decomps_remote = []
    a_modified = a
    for idx in range(num_blocks):
        block = a_modified[:block_size, :block_size]
        block_decomps_remote.append(ra.linalg.lu.remote(block))
        a12 = a_modified[:block_size, block_size:]
        a21 = a_modified[block_size:, :block_size]
        # Compute the Schur complement
        a22 = a_modified[block_size:, block_size:]
        a_modified = a22 - np.dot(a21, np.dot(np.linalg.inv(block), a12))
    block_decomps = ray.get(block_decomps_remote)

    # Compute the inverses for each of the LU components in parallel
    # (one lu_decomp_invert task per remote (P, L, U) tuple).
    block_decomp_inverses = ray.get([lu_decomp_invert.remote(decomp)
                                     for decomp in block_decomps_remote])

    # Perform coalescing: write each block's factors into the full-size
    # p, l, u arrays at absolute offsets, and fill the off-diagonal
    # strips u12 / l21. `a` is re-bound to its trailing submatrix each
    # iteration, so the [:block_size, ...] slices always address the
    # current block row/column.
    for idx, (plu, plu_inverse) in enumerate(zip(block_decomps,
                                                 block_decomp_inverses)):
        p11, l11, u11 = plu
        p11_inverse, l11_inverse, u11_inverse = plu_inverse

        # Absolute row/column range of block idx in the output arrays.
        block_low = block_size * idx
        block_high = block_low + block_size

        p[block_low:block_high, block_low:block_high] = p11

        # Off-diagonal strips exist for every block except the last.
        if idx < num_blocks - 1:
            a12 = a[:block_size, block_size:]
            a21 = a[block_size:, :block_size]

            # u12 = inv(l11) @ inv(p11) @ a12;  l21 = a21 @ inv(u11).
            # TODO(kvchen): Change this to use distributed block
            # multiplication. Not sure if it'll be any more efficient.
            u12, l21 = ray.get([
                ra.dot.remote(l11_inverse, ra.dot.remote(p11_inverse, a12)),
                ra.dot.remote(a21, u11_inverse),
            ])

            # Recurse on the lower-right quadrant: subtract this block's
            # contribution (l21 @ u12) from the trailing submatrix.
            a = a[block_size:, block_size:] - ray.get(ra.dot.remote(l21, u12))

            l[block_high:, block_low:block_high] = l21
            u[block_low:block_high, block_high:] = u12

        l[block_low:block_high, block_low:block_high] = l11
        u[block_low:block_high, block_low:block_high] = u11

    return p, l, u
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment