Last active
August 3, 2020 16:54
-
-
Save sjperkins/3bb05e988f7861bf135ce5d55c7ae490 to your computer and use it in GitHub Desktop.
Implement extend and write with dask and pyrap.tables
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import contextlib | |
import gc | |
import os | |
import shutil | |
import tempfile | |
import dask | |
import dask.array as da | |
from dask.core import flatten | |
from dask.highlevelgraph import HighLevelGraph | |
import numpy as np | |
from numpy.testing import assert_array_equal | |
import pyrap.tables as pt | |
@contextlib.contextmanager
def temptablefile():
    """Yield a path for a (not yet created) CASA table inside a fresh
    temporary directory, and remove the whole directory on exit.

    Yields
    ------
    str
        ``<tmpdir>/test.table`` — the table itself is created by the caller.
    """
    # Create the directory *before* entering the try block: in the original
    # code a failure in mkdtemp() would leave `tmpdir` unbound and the
    # finally clause would raise NameError, masking the real error.
    tmpdir = tempfile.mkdtemp()
    try:
        yield os.path.join(tmpdir, "test.table")
    finally:
        shutil.rmtree(tmpdir)
def table_descriptor(column, casa_type):
    """Build a CASA table descriptor containing a single scalar column.

    Parameters
    ----------
    column : str
        Name of the column to create.
    casa_type : str
        CASA value type for the column, e.g. ``"DOUBLE"``.

    Returns
    -------
    dict
        Table descriptor produced by :func:`pyrap.tables.maketabdesc`.
    """
    # Column descriptor
    desc = {'desc': {'_c_order': True,
                     'comment': '%s column' % column,
                     'dataManagerGroup': '',
                     'dataManagerType': '',
                     'keywords': {},
                     #'ndim': 0,
                     'maxlen': 0,
                     'option': 0,
                     # Bug fix: casa_type was previously ignored and this
                     # was hard-coded to "double". casacore expects the
                     # lower-case spelling, so "DOUBLE" still maps to
                     # "double" and the existing call site is unaffected.
                     'valueType': casa_type.lower()},
            'name': column}
    return pt.maketabdesc([desc])
def _add_rows(table, data, prev_start_end=None): | |
rows = data.shape[0] | |
start = 0 if prev_start_end is None else prev_start_end[-1] | |
end = start + rows | |
table.addrows(rows) | |
return (start, end) | |
def add_rows(data, table):
    """Build a dask array whose tasks extend ``table`` chunk by chunk.

    Each output chunk calls :func:`_add_rows` on the matching chunk of
    ``data``.  Every task receives the key of its predecessor, which
    serialises the row additions and threads the running ``(start, end)``
    row spans through the chain.

    Parameters
    ----------
    data : dask.array.Array
        Array whose chunking drives the row additions.
    table : pyrap.tables.table
        Open table to extend.

    Returns
    -------
    dask.array.Array
        Array of ``(start, end)`` span tasks, chunked like ``data``.
    """
    name = "add-rows-" + dask.base.tokenize(data)
    layers = {}
    previous = None

    for source_key in flatten(data.__dask_keys__()):
        # Reuse the source key's chunk indices under the new layer name.
        out_key = (name,) + source_key[1:]
        layers[out_key] = (_add_rows, table, source_key, previous)
        # Chain: the next chunk depends on this one's (start, end) result.
        previous = out_key

    graph = HighLevelGraph.from_collections(name, layers, [data])
    return da.Array(graph, name, chunks=data.chunks, dtype=data.dtype)
def _write_data(data, start_end, table, column): | |
start, end = start_end | |
table.putcol(column, data, startrow=start, nrow=end-start) | |
return np.full_like(data, True) | |
def write_data(table, column, data):
    """Create the write graph for a 1-D dask array into a table column.

    Row-extension tasks (``add_rows``) are injected as a blockwise
    dependency so each chunk knows its ``(start, end)`` row span before
    writing.

    Parameters
    ----------
    table : pyrap.tables.table
        Open table to write into.
    column : str
        Destination column name.
    data : dask.array.Array
        One-dimensional array of values to write.

    Returns
    -------
    dask.array.Array
        Boolean array of write-completion flags, chunked like ``data``.
    """
    assert data.ndim == 1
    extents = add_rows(data, table)
    return da.blockwise(_write_data, "x",
                        data, "x",
                        extents, "x",
                        table, None,
                        column, None,
                        # Bug fix: np.bool was deprecated in NumPy 1.20 and
                        # removed in 1.24; the builtin bool is the correct
                        # spelling and behaves identically here.
                        dtype=bool)
if __name__ == "__main__":
    # 100 random doubles, split into 10-row chunks.
    data = da.random.random((100,), chunks=10).astype(np.float64)
    desc = table_descriptor("TEMP", "DOUBLE")

    with temptablefile() as table_path, pt.table(table_path, desc) as table:
        writes = write_data(table, "TEMP", data)
        # Render the task graph; single-threaded scheduler keeps the
        # chained addrows/putcol calls strictly ordered.
        writes.visualize("graph.png", optimize_graph=False, rankdir="LR")
        try:
            writes.compute(scheduler='sync')
        except Exception:
            # Bug fix: the original `pass` silently swallowed any compute
            # failure; report it so a broken write is visible.
            import traceback
            traceback.print_exc()
        else:
            # Round-trip check: what landed in the table matches the input.
            table_data = table.getcol("TEMP")
            assert_array_equal(table_data, data)
        finally:
            table.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment