Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save michaelchughes/f39ef71c352dc5f784ec45fd42251cac to your computer and use it in GitHub Desktop.
Save michaelchughes/f39ef71c352dc5f784ec45fd42251cac to your computer and use it in GitHub Desktop.
Create a function that will monotonically transform the intensity values of images from a "target" distribution to match a desired "source" distribution
import numpy as np
import scipy.stats
import matplotlib.pyplot as plt
from statsmodels.distributions.empirical_distribution import ECDF
def create_transform_func_to_match_source(target_x_ND, src_x_MD, n_quantiles=1000):
    '''Build a monotone map that sends target values onto the source distribution.

    Parameters
    ----------
    target_x_ND : array_like, shape (N, ...)
        Reference sample of target-domain values. NaNs are ignored.
    src_x_MD : array_like, shape (M, ...)
        Sample of source-domain values. NaNs are ignored.
        Trailing dims (all but the first) must match those of target_x_ND.
    n_quantiles : int, optional
        Number of evenly spaced quantile knots in [0, 1] (inclusive) used to
        summarize the source distribution. Must be >= 2.

    Returns
    -------
    transform : func
        Maps a given target arr of any shape (including 0-d) into a new float
        arr of the same shape. NaN entries stay NaN.

    Raises
    ------
    ValueError
        If trailing dims differ, if n_quantiles < 2, or if either input has
        no non-NaN values.

    Notes
    -----
    1) Map each target value t to its empirical quantile (value in 0-1),
       i.e. the fraction of non-NaN target samples that are <= t.
    2) Then map that quantile to a source x by linear interpolation between
       the source quantile knots.
    '''
    target_x_ND = np.asarray(target_x_ND)
    src_x_MD = np.asarray(src_x_MD)
    if target_x_ND.shape[1:] != src_x_MD.shape[1:]:
        raise ValueError(
            'target and source must share trailing dims, got %s vs %s'
            % (target_x_ND.shape, src_x_MD.shape))
    if n_quantiles < 2:
        raise ValueError('n_quantiles must be >= 2, got %r' % (n_quantiles,))

    # Empirical CDF of the target: keep sorted non-NaN samples so that
    # ECDF(t) = #{samples <= t} / n is a single searchsorted(side='right').
    target_flat_T = target_x_ND.ravel().astype(float)
    sorted_target_T = np.sort(target_flat_T[~np.isnan(target_flat_T)])
    n_target = sorted_target_T.size
    if n_target == 0:
        raise ValueError('target_x_ND has no non-NaN values')

    src_flat_B = src_x_MD.ravel().astype(float)
    if np.isnan(src_flat_B).all():  # also True for an empty array
        raise ValueError('src_x_MD has no non-NaN values')
    src_qs_Q = np.linspace(0, 1, n_quantiles, endpoint=True)
    x_quantiles_Q = np.nanpercentile(src_flat_B, src_qs_Q * 100)
    # Percentiles are already non-decreasing; sorting guards against float noise
    # so np.interp always sees a monotone table.
    x_quantiles_Q = np.sort(x_quantiles_Q)

    def transform(targetx_ND):
        targetx_ND = np.asarray(targetx_ND, dtype=float)
        # ravel (not reshape(np.prod(shape))) so 0-d inputs work too.
        flat_A = targetx_ND.ravel()
        qs_A = np.searchsorted(sorted_target_T, flat_A, side='right') / n_target
        # Interpolate between knots instead of snapping up to the next knot,
        # which would bias every output toward a higher source quantile.
        out_A = np.interp(qs_A, src_qs_Q, x_quantiles_Q)
        # searchsorted places NaN past the end (quantile 1.0); keep NaN as NaN.
        out_A[np.isnan(flat_A)] = np.nan
        return out_A.reshape(targetx_ND.shape)

    return transform
def make_new_fig_with_subplots():
    '''Create a new figure: a 4x4 grid of image panels above one wide histogram panel.

    Returns
    -------
    axes : dict
        Maps each mosaic label to its Axes.
    im_keys : list of str
        Labels 'A' through 'P' of the 16 image panels, in sorted order.
    hist_key : str
        Label of the histogram panel, which spans the full width of the
        bottom two rows.
    '''
    hist_key = 'X'
    # One single-character label per cell; repeated labels merge into one panel.
    mosaic_rows = ['ABCD', 'EFGH', 'IJKL', 'MNOP', hist_key * 4, hist_key * 4]
    _, axes = plt.subplot_mosaic([list(row) for row in mosaic_rows])
    im_keys = sorted(label for label in axes if label != hist_key)
    return axes, im_keys, hist_key
if __name__ == '__main__':
    N = 100  # number of images per distribution
    D = 32   # side length, in pixels, of each square image
    src_dist = scipy.stats.norm(0.3, 0.05)
    target_dist = scipy.stats.norm(0.5, 0.1)

    # Create many square images from each distribution.
    # Each one is a dark border, then lighter, then a lighter-still center square.
    src_x_NDD = src_dist.rvs(size=(N, D, D), random_state=42)
    target_x_NDD = target_dist.rvs(size=(N, D, D), random_state=43)
    m = D // 4
    M = 3 * D // 4
    src_x_NDD[:, m:M, m:M] += 0.2
    target_x_NDD[:, m:M, m:M] += 0.2

    # Add a (nearly) black border of width B by scaling edge pixels by 0.03.
    B = 3
    for arr in [src_x_NDD, target_x_NDD]:
        arr[:, :B, :] *= 0.03
        arr[:, -B:, :] *= 0.03
        arr[:, :, :B] *= 0.03
        arr[:, :, -B:] *= 0.03

    # Zero the main diagonal of a few target images (draws a black diagonal line).
    for img_idx in (0, 5, 10):
        target_x_NDD[img_idx][np.diag_indices(D)] = 0

    transform = create_transform_func_to_match_source(target_x_NDD, src_x_NDD)
    txfm_x_NDD = transform(target_x_NDD)

    for (arr, suptitle_str) in [
            (src_x_NDD, 'SOURCE images'),
            (target_x_NDD, 'TARGET images'),
            (txfm_x_NDD, 'TRANSFORMED_TO_SRC images')]:
        axes, im_keys, hist_key = make_new_fig_with_subplots()
        # Show the first len(im_keys) images of this set, without ticks.
        for ii, key in enumerate(im_keys):
            cur_ax = axes[key]
            cur_ax.imshow(arr[ii], vmin=0, vmax=1, cmap='gray')
            cur_ax.set_xticks([])
            cur_ax.set_yticks([])
        # Histogram of pixel values pooled over all images in this set.
        ax = axes[hist_key]
        ax.hist(
            arr.flatten(),
            density=True,
            bins=np.linspace(0, 1, 128))
        ax.set_xlabel('pixel value')
        ax.set_ylabel('density')
        ax.set_ylim([0, 5])
        plt.suptitle(suptitle_str)
        plt.savefig(suptitle_str.replace(" ", "_") + '.png', bbox_inches='tight', pad_inches=0)
        # Close the saved figure so open figures do not accumulate across iterations.
        plt.close()
@michaelchughes
Copy link
Author

The first 16 SOURCE images, with histogram of pixel values from all 100 images

SOURCE_images

The first 16 TARGET images, with histogram of pixel values from all 100 images

TARGET_images

The same 16 target images, after being TRANSFORMED to match the source

TRANSFORMED_TO_SRC_images

@michaelchughes
Copy link
Author

Mapping from target values to source values

mapping_func

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment