Last active
November 12, 2020 16:48
-
-
Save chrisroat/7314ba14eeb28390349a9c9d97477b91 to your computer and use it in GitHub Desktop.
Dask-based histogram matching
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import dask | |
import dask.array as da | |
import numpy as np | |
from dask.array import reductions | |
def _match_cumulative_cdf(source, template):
    """
    Return modified source array so that the cumulative density function of
    its values matches the cumulative density function of the template.

    Parameters
    ----------
    source, template : dask.array.Array
        Integer arrays whose dtype is one of uint8, uint16, int8, int16.
        The dtypes may differ from each other.

    Returns
    -------
    dask.array.Array
        float64 array with the shape of ``source``.

    Raises
    ------
    ValueError
        If either input has a dtype outside the allowed set.
    """
    # Only these types are allowed, since some intermediate memory storage
    # scales as the number of unique values in the range of the dtype.
    allowed_dtypes = (np.uint8, np.uint16, np.int8, np.int16)
    for arr in (source, template):
        if arr.dtype not in allowed_dtypes:
            raise ValueError('_match_cumulative_cdf does not support input data '
                             'of type %s. Please consider converting your data '
                             'to one of the following types: %s' % (arr.dtype, allowed_dtypes))

    src_counts = unique_counts(source)
    tmpl_counts = unique_counts(template)

    interp_a_values = dask.delayed(interp)(src_counts, source.size, tmpl_counts, template.size)
    # Length is the number of distinct source values, unknown until computed.
    interp_a_values = da.from_delayed(interp_a_values, dtype=np.float64, shape=(np.nan,))

    # ``interp`` yields template values as offsets from the *template* dtype's
    # minimum (indices into its histogram), while ``match`` adds back the
    # *source* dtype's minimum. Shift by the difference so that mixed-dtype
    # inputs (e.g. an int8 template for a uint8 source) land on the real
    # template values. This is a no-op when both minima agree.
    min_shift = int(np.iinfo(template.dtype).min) - int(np.iinfo(source.dtype).min)
    if min_shift:
        interp_a_values = interp_a_values + min_shift

    # np.float64 replaces the ``np.float`` alias, which NumPy 1.24 removed.
    result = source.map_blocks(match, src_counts=src_counts, values=interp_a_values,
                               dtype=np.float64)
    return result.reshape(source.shape)
def unique_counts(arr):
    """Build a lazy histogram of ``arr`` over its dtype's full value range.

    Each block is counted by ``unique_chunk`` and the partial results are
    merged by ``unique_agg``. ``concatenate=False`` makes dask hand the
    aggregator the partial histograms themselves rather than a concatenation.

    NOTE(review): no ``axis`` is passed, so dask's shape metadata presumably
    describes a fully reduced result even though the computed value is a 1-D
    histogram. Confirm this before relying on ``.shape``.
    """
    flat = arr.ravel()
    return reductions.reduction(
        flat,
        unique_chunk,
        unique_agg,
        concatenate=False,
        dtype=np.intp,
    )
def unique_chunk(chunk, axis=None, keepdims=None):
    """Histogram one block over every value representable by its dtype.

    ``axis`` and ``keepdims`` are accepted to satisfy the dask reduction
    calling convention and are ignored.

    Returns
    -------
    numpy.ndarray
        intp array of length ``iinfo.max - iinfo.min + 1``. Entry ``i`` is the
        number of elements of ``chunk`` equal to ``iinfo.min + i``.
    """
    info = np.iinfo(chunk.dtype)
    span = info.max - info.min + 1
    # Widen before shifting. Otherwise, for signed dtypes, ``chunk - info.min``
    # is computed in the narrow dtype and overflows (e.g. int8: 0 - (-128)).
    offsets = np.ravel(chunk).astype(np.intp) - info.min
    # bincount is a single O(n) pass, unlike np.unique, which sorts the chunk.
    return np.bincount(offsets, minlength=span).astype(np.intp, copy=False)
def unique_agg(chunk, axis=None, keepdims=None):
    """Merge partial histograms produced by ``unique_chunk``.

    A list input (presumably the per-block histograms, since ``unique_counts``
    uses ``concatenate=False``) is summed element-wise. Any other input is
    returned unchanged. ``axis`` and ``keepdims`` are ignored.
    """
    if not isinstance(chunk, list):
        return chunk
    return np.sum(chunk, axis=0)
def interp(src_counts_span, source_size, tmpl_counts_span, template_size):
    """Pair each distinct source value with a quantile-matched template value.

    Both ``*_counts_span`` arguments are full-range histograms. Only their
    nonzero bins contribute. Returns a float array with one entry per nonzero
    source bin (in bin order). Each entry is a template bin index (an offset
    from the template dtype's minimum), linearly interpolated at the source
    value's cumulative quantile.
    """
    def _quantiles(counts_span, total):
        # Cumulative fraction of elements up to and including each present value.
        occupied = counts_span[counts_span != 0]
        return np.cumsum(occupied) / total

    template_bins = np.flatnonzero(tmpl_counts_span)
    return np.interp(
        _quantiles(src_counts_span, source_size),
        _quantiles(tmpl_counts_span, template_size),
        template_bins,
    )
def match(chunk, src_counts, values):
    """Replace every element of ``chunk`` with its CDF-matched value.

    Parameters
    ----------
    chunk : numpy.ndarray
        One block of the (integer-typed) source array.
    src_counts : numpy.ndarray
        Histogram of the whole source, indexed by ``value - iinfo.min`` (the
        layout produced by ``unique_chunk``).
    values : numpy.ndarray
        One float per nonzero entry of ``src_counts``, in the same order. Each
        is an offset from the dtype minimum, which is added back here.

    Returns
    -------
    numpy.ndarray
        float64 array with the same shape as ``chunk``.
    """
    if chunk.size == 0:
        # np.float64 replaces the ``np.float`` alias, which NumPy 1.24 removed.
        return np.zeros(chunk.shape, np.float64)
    info = np.iinfo(chunk.dtype)
    span = info.max - info.min + 1
    # Lookup table: (value - info.min) -> rank of that value among the source's
    # distinct values. This replaces one full pass over ``chunk`` per distinct
    # value (up to 65536 for 16-bit data) with a single gather. Values absent
    # from ``src_counts`` keep rank 0, as in the previous loop-based version.
    present = np.flatnonzero(src_counts)
    rank_lut = np.zeros(span, np.intp)
    rank_lut[present] = np.arange(present.size)
    # Widen before shifting so signed dtypes cannot overflow.
    mapping = rank_lut[chunk.astype(np.intp) - info.min]
    return values[mapping] + info.min
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment