Skip to content

Instantly share code, notes, and snippets.

@sjperkins
Last active August 3, 2020 16:54
Show Gist options
  • Save sjperkins/3bb05e988f7861bf135ce5d55c7ae490 to your computer and use it in GitHub Desktop.
Save sjperkins/3bb05e988f7861bf135ce5d55c7ae490 to your computer and use it in GitHub Desktop.
Extend a CASA table with new rows and write column data into them using dask and pyrap.tables
import argparse
import contextlib
import gc
import os
import shutil
import tempfile
import dask
import dask.array as da
from dask.core import flatten
from dask.highlevelgraph import HighLevelGraph
import numpy as np
from numpy.testing import assert_array_equal
import pyrap.tables as pt
@contextlib.contextmanager
def temptablefile():
    """Yield a path for a scratch table inside a fresh temporary directory.

    A new directory is created with ``tempfile.mkdtemp`` and the path
    ``<tmpdir>/test.table`` is yielded; nothing is created at that path
    here. The directory and all of its contents are removed when the
    ``with`` block exits, whether normally or through an exception.
    """
    # Create the directory *before* entering the try block: if mkdtemp itself
    # fails there is nothing to clean up, and the finally clause must not
    # reference an unbound ``tmpdir`` (which would mask the real error).
    tmpdir = tempfile.mkdtemp()
    try:
        yield os.path.join(tmpdir, "test.table")
    finally:
        shutil.rmtree(tmpdir)
def table_descriptor(column, casa_type):
    """Build a table description containing a single column.

    Parameters
    ----------
    column : str
        Name of the column; also interpolated into the column comment.
    casa_type : str
        Value type of the column, e.g. ``"DOUBLE"``. It is lower-cased
        before being stored as ``valueType``. An empty or ``None`` value
        falls back to ``"double"``, the type that was previously hard-coded.

    Returns
    -------
    The result of ``pt.maketabdesc`` applied to a one-element list holding
    the column descriptor.
    """
    # Previously this argument was ignored and every column was "double".
    value_type = casa_type.lower() if casa_type else "double"

    desc = {'desc': {'_c_order': True,
                     'comment': '%s column' % column,
                     'dataManagerGroup': '',
                     'dataManagerType': '',
                     'keywords': {},
                     # 'ndim' is left unset, as in the original descriptor.
                     'maxlen': 0,
                     'option': 0,
                     'valueType': value_type},
            'name': column}

    return pt.maketabdesc([desc])
def _add_rows(table, data, prev_start_end=None):
rows = data.shape[0]
start = 0 if prev_start_end is None else prev_start_end[-1]
end = start + rows
table.addrows(rows)
return (start, end)
def add_rows(data, table):
    """Build a dask collection that adds table rows for every chunk of ``data``.

    One ``_add_rows`` task is created per key of ``data``. Each task receives
    the corresponding ``data`` chunk and the key of the task created just
    before it (``None`` for the first), so the tasks form a chain and each
    row range starts where the previous one ended.

    Parameters
    ----------
    data : dask.array.Array
        Array whose chunks determine how many rows each task adds.
    table
        Table object handed to every ``_add_rows`` task.

    Returns
    -------
    dask.array.Array
        Array wrapper with the same ``chunks`` and ``dtype`` metadata as
        ``data``. Its blocks are whatever ``_add_rows`` returns, i.e.
        ``(start, end)`` tuples rather than ndarrays.
    """
    out_name = "add-rows-" + dask.base.tokenize(data)
    tasks = {}
    previous_key = None

    for data_key in flatten(data.__dask_keys__()):
        # Same block index as the input chunk, under the new collection name.
        out_key = (out_name,) + data_key[1:]
        tasks[out_key] = (_add_rows, table, data_key, previous_key)
        # The next task depends on this one, which serialises the row ranges.
        previous_key = out_key

    hlg = HighLevelGraph.from_collections(out_name, tasks, [data])
    return da.Array(hlg, out_name, chunks=data.chunks, dtype=data.dtype)
def _write_data(data, start_end, table, column):
start, end = start_end
table.putcol(column, data, startrow=start, nrow=end-start)
return np.full_like(data, True)
def write_data(table, column, data):
    """Create a lazy dask array that writes ``data`` into ``table[column]``.

    Rows are first added to the table chunk by chunk via ``add_rows``; each
    chunk of ``data`` is then written into its row range by ``_write_data``.
    Nothing is written until the returned array is computed.

    Parameters
    ----------
    table
        Table object passed through to ``add_rows`` and ``_write_data``.
    column : str
        Name of the column to write.
    data : dask.array.Array
        One-dimensional array of values to write.

    Returns
    -------
    dask.array.Array
        Array declared with boolean dtype, chunked like ``data``.
    """
    assert data.ndim == 1

    extents = add_rows(data, table)

    # Use the builtin ``bool`` for the dtype: the ``np.bool`` alias was
    # deprecated in NumPy 1.20 and removed in 1.24.
    return da.blockwise(_write_data, "x",
                        data, "x",
                        extents, "x",
                        table, None,
                        column, None,
                        dtype=bool)
if __name__ == "__main__":
    # Local import: only this script entry point needs it.
    import traceback

    # 100 random doubles split into chunks of 10.
    data = da.random.random((100,), chunks=10).astype(np.float64)
    desc = table_descriptor("TEMP", "DOUBLE")

    with temptablefile() as T, pt.table(T, desc) as table:
        writes = write_data(table, "TEMP", data)
        # Render the unoptimised task graph for inspection.
        writes.visualize("graph.png", optimize_graph=False, rankdir="LR")

        try:
            writes.compute(scheduler='sync')
        except Exception:
            # Still best-effort (the script keeps going to the cleanup), but
            # report the failure instead of hiding it: a silent ``pass`` made
            # a failed write look the same as a verified one.
            traceback.print_exc()
        else:
            # Only verify the column contents when the write succeeded.
            table_data = table.getcol("TEMP")
            assert_array_equal(table_data, data)
        finally:
            # Explicit close kept from the original script.
            table.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment