Skip to content

Instantly share code, notes, and snippets.

@danking
Last active July 17, 2020 20:22
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save danking/5972bf5dba62acab7ef9c2c38af9893b to your computer and use it in GitHub Desktop.
Save danking/5972bf5dba62acab7ef9c2c38af9893b to your computer and use it in GitHub Desktop.
import hail as hl
# Sanity check: view the flat ndarray [1, 2, 3, 4] as a 2x2 ndarray and display it.
(
    hl.nd.array([1, 2, 3, 4])
    .reshape((2, 2))
    .show()
)
# FIXME: use ndarray sum / fma
def block_product(left, right):
    """Multiply two 2-D ndarray expressions and package the result for aggregation.

    Returns a struct expression with two fields:
      * ``shape`` -- the shape tuple of ``left @ right``;
      * ``block`` -- the entries of ``left @ right`` flattened in column-major
        order, i.e. element ``i`` is the entry at row ``i % n_rows`` and column
        ``i // n_rows``.
    """
    mat = left @ right
    rows, cols = mat.shape

    def entry_at(flat_index):
        # Column-major layout: the row index varies fastest.
        return mat[flat_index % rows, flat_index // rows]

    # hl.range needs an int32 bound; the shape dimensions are wider ints.
    flat = hl.range(hl.int(rows * cols)).map(entry_at)
    return hl.struct(shape=mat.shape, block=flat)
def block_aggregate(prod):
    """Sum the flattened blocks of many rows and rebuild them as one ndarray.

    ``prod`` is a struct expression with ``shape`` and ``block`` fields, such
    as the value returned by ``block_product``. The result uses aggregators
    (``hl.agg.array_sum`` / ``hl.agg.take``), so it must be evaluated in an
    aggregation context, e.g. ``Table.aggregate``.

    The ``block`` arrays are summed element-wise and reshaped with
    ``hl.nd.from_column_major`` using the shape of a single row.

    NOTE(review): only one row's shape is consulted; all rows are assumed to
    share the same shape -- confirm with callers.
    """
    summed_entries = hl.agg.array_sum(prod.block)
    one_shape = hl.agg.take(prod.shape, 1)[0]
    return hl.nd.from_column_major(summed_entries, one_shape)
# Example inputs: ``x`` is a 2x2 ndarray built from [1, 2, 3, 4] and ``y``
# (built from [1, 0, 0, 1]) is the 2x2 identity, so ``x @ y`` equals ``x``.
# Results are printed explicitly: bare expressions are only displayed in an
# interactive session and are silently discarded when run as a script.
x = hl.nd.array([1, 2, 3, 4]).reshape((2, 2))
y = hl.nd.array([1, 0, 0, 1]).reshape((2, 2))
print(x.collect())
print(y.collect())
print(block_product(x, y).collect())

# Attach the same block to each of three rows and form per-row products.
t = hl.utils.range_table(3)
t = t.annotate(block=x)
print(t.collect())
t = t.annotate(product=block_product(t.block, y))
print(t.product.collect())
print(t.aggregate(hl.agg.array_sum(t.product.block)))

# Two-step version: aggregate the flat sum and one shape, then rebuild locally.
thing = t.aggregate(
    hl.struct(
        the_sum=hl.agg.array_sum(t.product.block),
        the_shape=hl.agg.take(t.product.shape, 1)[0],
    )
)
print(thing)
print(hl.nd.from_column_major(thing.the_sum, thing.the_shape).collect())

# One-step version: rebuild the ndarray inside the aggregation itself. This is
# exactly the expression block_aggregate builds, so reuse it instead of
# repeating the aggregator wiring by hand.
thing = t.aggregate(block_aggregate(t.product))
print(thing)
def to_column_major(ndarray):
    """Flatten a 2-D ndarray expression into a Hail array in column-major order.

    Element ``i`` of the returned array is ``ndarray[i % n_rows, i // n_rows]``,
    i.e. the entries are listed column by column.
    """
    height, width = ndarray.shape
    # hl.range needs an int32 bound; the shape dimensions are wider ints.
    total = hl.int(height * width)

    def pick(position):
        # Column-major layout: the row index advances fastest.
        return ndarray[position % height, position // height]

    return hl.range(total).map(pick)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment