Skip to content

Instantly share code, notes, and snippets.

@chrisroat
Last active November 12, 2020 16:48
Show Gist options
  • Save chrisroat/7314ba14eeb28390349a9c9d97477b91 to your computer and use it in GitHub Desktop.
Save chrisroat/7314ba14eeb28390349a9c9d97477b91 to your computer and use it in GitHub Desktop.
Dask-based histogram matching
import dask
import dask.array as da
import numpy as np
from dask.array import reductions
def _match_cumulative_cdf(source, template):
    """
    Return modified source array so that the cumulative density function of
    its values matches the cumulative density function of the template.

    Parameters
    ----------
    source : dask array
        Array to transform. Its dtype must be one of uint8, uint16, int8 or
        int16.
    template : dask array
        Array whose value distribution is the target. Same dtype restriction
        as ``source``.

    Returns
    -------
    dask array
        Lazily evaluated float64 array with the shape of ``source``.

    Raises
    ------
    ValueError
        If either input has a dtype outside the allowed set.
    """
    # Only these types are allowed, since some intermediate memory storage
    # scales as the number of unique values in the range of the dtype.
    allowed_dtypes = (np.uint8, np.uint16, np.int8, np.int16)
    for arr in (source, template):
        if arr.dtype not in allowed_dtypes:
            raise ValueError('_match_cumulative_cdf does not support input data '
                             'of type %s. Please consider converting your data '
                             'to one of the following types: %s' % (arr.dtype, allowed_dtypes))

    # Full-range histograms (one bin per representable value of the dtype).
    src_counts = unique_counts(source)
    tmpl_counts = unique_counts(template)

    # The number of distinct source values is only known at compute time,
    # hence the unknown (nan) length of the interpolated value table.
    # np.float64 is used because the `np.float` alias was removed in NumPy 1.24.
    interp_a_values = dask.delayed(interp)(src_counts, source.size, tmpl_counts, template.size)
    interp_a_values = da.from_delayed(interp_a_values, dtype=np.float64, shape=(np.nan, ))

    result = source.map_blocks(match, src_counts=src_counts, values=interp_a_values, dtype=np.float64)

    # `interp` yields offsets relative to the *template* dtype minimum, while
    # `match` re-bases them with the *source* dtype minimum. Compensate so
    # mixed signed/unsigned inputs land in the template's value range; the
    # correction is zero (and skipped) when both minimums agree.
    offset = int(np.iinfo(template.dtype).min) - int(np.iinfo(source.dtype).min)
    if offset:
        result = result + offset

    return result.reshape(source.shape)
def unique_counts(arr):
    """
    Build a dask reduction over the flattened ``arr`` that applies
    ``unique_chunk`` to each block and merges the per-block results with
    ``unique_agg``.

    ``concatenate=False`` makes dask hand the per-block results to the
    aggregation step as a list rather than as one concatenated array.
    """
    flattened = arr.ravel()
    return reductions.reduction(
        flattened,
        chunk=unique_chunk,
        aggregate=unique_agg,
        dtype=np.intp,
        concatenate=False,
    )
def unique_chunk(chunk, axis=None, keepdims=None):
    """
    Count occurrences of every representable value of ``chunk``'s dtype.

    Parameters
    ----------
    chunk : numpy.ndarray
        Integer array (any shape); ``np.iinfo`` must accept its dtype.
    axis, keepdims
        Accepted but unused; presumably required by the dask reduction
        calling convention.

    Returns
    -------
    numpy.ndarray
        1-D ``intp`` array of length ``iinfo.max - iinfo.min + 1`` where
        element ``i`` is the number of times the value ``iinfo.min + i``
        occurs in ``chunk``.
    """
    info = np.iinfo(chunk.dtype)
    span = info.max - info.min + 1
    # Widen to intp *before* subtracting the dtype minimum: doing the
    # subtraction in a narrow signed dtype (e.g. int8 0 - (-128)) overflows
    # and only indexes correctly by accident via negative-index wraparound.
    offsets = chunk.ravel().astype(np.intp) - info.min
    counts = np.bincount(offsets, minlength=span)
    return counts.astype(np.intp, copy=False)
def unique_agg(chunk, axis=None, keepdims=None):
    """
    Merge per-block count arrays into one.

    When ``chunk`` is a list, its entries are summed element-wise along the
    first axis; any other input is returned unchanged. ``axis`` and
    ``keepdims`` are accepted but not used.
    """
    if not isinstance(chunk, list):
        # Already a single result: nothing to combine.
        return chunk
    return np.sum(chunk, axis=0)
def interp(src_counts_span, source_size, tmpl_counts_span, template_size):
    """
    Compute, for each value present in the source, the template position
    that sits at the same cumulative quantile.

    Parameters
    ----------
    src_counts_span, tmpl_counts_span : numpy.ndarray
        Count arrays; zero entries are treated as absent values.
    source_size, template_size : int
        Totals used to normalise the cumulative counts into quantiles.

    Returns
    -------
    numpy.ndarray
        One interpolated value per non-zero entry of ``src_counts_span``,
        expressed as an index into ``tmpl_counts_span``.
    """
    def _quantiles(counts_span, total):
        # Cumulative fraction reached after each non-empty bin.
        occupied = counts_span != 0
        return np.cumsum(counts_span[occupied]) / total

    source_q = _quantiles(src_counts_span, source_size)
    template_q = _quantiles(tmpl_counts_span, template_size)
    # Indices of the non-empty template bins: the interpolation targets.
    template_positions = np.flatnonzero(tmpl_counts_span)
    return np.interp(source_q, template_q, template_positions)
def match(chunk, src_counts, values):
    """
    Replace every element of ``chunk`` with its histogram-matched value.

    Parameters
    ----------
    chunk : numpy.ndarray
        Integer block of the source array; ``np.iinfo`` must accept its dtype.
    src_counts : numpy.ndarray
        Count array in which index ``i`` refers to the value
        ``iinfo(chunk.dtype).min + i``.
    values : numpy.ndarray
        One replacement per non-zero entry of ``src_counts``, in the same
        order.

    Returns
    -------
    numpy.ndarray
        float64 array shaped like ``chunk`` holding
        ``values[k] + iinfo(chunk.dtype).min``, where ``k`` is the rank of
        the element among the values present in ``src_counts``. Elements
        whose value has a zero count fall back to rank 0.
    """
    if chunk.size == 0:
        # np.float64 rather than the `np.float` alias removed in NumPy 1.24.
        return np.zeros(chunk.shape, np.float64)

    info = np.iinfo(chunk.dtype)
    span = info.max - info.min + 1

    # Lookup table from "value - dtype minimum" to the rank of that value
    # among the present source values. A single gather replaces rescanning
    # the whole chunk once per distinct value.
    present = np.flatnonzero(src_counts)
    lookup = np.zeros(max(span, np.size(src_counts)), dtype=np.intp)
    lookup[present] = np.arange(present.size)

    # Widen before subtracting so signed dtypes cannot overflow.
    mapping = lookup[chunk.astype(np.intp) - info.min]

    # NOTE(review): the offset added here is the minimum of the *chunk*
    # dtype; if `values` was derived from data of a different dtype the
    # caller has to account for the difference in minimums.
    return values[mapping] + info.min
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment