Skip to content

Instantly share code, notes, and snippets.

@OverLordGoldDragon
Created October 22, 2023 11:29
Show Gist options
  • Save OverLordGoldDragon/3026aea557c03ce00b2eb5060ab63c33 to your computer and use it in GitHub Desktop.
Save OverLordGoldDragon/3026aea557c03ce00b2eb5060ab63c33 to your computer and use it in GitHub Desktop.
Colormap testing
# -*- coding: utf-8 -*-
"""Skip to "Configure" section.
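Download the zip (URL below) and unzip it so that `jtfs_fbank.npz` (loaded
below by relative path) is in the working directory. Requires the `cmap` package:
pip install cmap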
or
conda install -c conda-forge cmap
"""
# https://github.com/OverLordGoldDragon/ssqueezepy/files/13063053/jtfs_fbank.zip
#%% Imports ------------------------------------------------------------------
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import cmap
#%% Helpers ------------------------------------------------------------------
def imshow(Psi, ax, part):
    """Render `Psi` onto `ax` as the requested `part`, with no ticks or frame.

    part : 'real', 'imag', 'abs', or 'complex'.
        - 'real' / 'imag': that component of `Psi`, color range symmetric
          about zero ([-max|Psi|, max|Psi|]), colormap `CMAPS[part]`.
        - 'abs': `|Psi|` over [0, max|Psi|], colormap `CMAPS['abs']`.
        - 'complex': all `ax.imshow` keyword arguments come from
          `COMPLEX_FN(Psi)`.
    """
    if part == 'complex':
        kw = COMPLEX_FN(Psi)
    else:
        magnitude = abs(Psi)
        peak = magnitude.max()
        # named `cmap_name` so it does not shadow the imported `cmap` module
        cmap_name = CMAPS[part]
        if part == 'abs':
            kw = dict(X=magnitude, vmin=0, vmax=peak, cmap=cmap_name)
        else:
            # 'real' -> Psi.real, 'imag' -> Psi.imag
            kw = dict(X=getattr(Psi, part), vmin=-peak, vmax=peak,
                      cmap=cmap_name)
    ax.imshow(**kw, aspect='auto')

    # strip ticks, tick labels, and the frame
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    for spine in ax.spines.values():
        spine.set_visible(False)
def make_cmap(cname=None, circshift=False):
    """Sample a `cmap`-package colormap into a 1024-entry matplotlib ListedColormap.

    cname : identifier accepted by `cmap.Colormap`; None means
        'colorcet:CET_C1'.
    circshift : if True, roll the sample positions by half the table length,
        so the map's midpoint color lands at the ends of the table.
    """
    source = cmap.Colormap('colorcet:CET_C1' if cname is None else cname)
    positions = np.linspace(0, 1, 1024)
    if circshift:
        positions = np.roll(positions, positions.size // 2)
    return mpl.colors.ListedColormap(source(positions))
def make_cname(name, cetnum=None):
    """Return the `cmap`-package identifier of a colorcet colormap.

    name : colormap family.
        'cet'  -> 'colorcet:CET_C<n>'
        'cetb' -> 'colorcet:CET_CBC<n>'
        'cett' -> 'colorcet:CET_CBTC<n>'
    cetnum : the map number <n>, inserted via `str()`; None means 1.

    Raises
    ------
    ValueError
        If `name` is not one of the families above.
    """
    prefixes = {'cet': 'colorcet:CET_C',
                'cetb': 'colorcet:CET_CBC',
                'cett': 'colorcet:CET_CBTC'}
    if name not in prefixes:
        # previously `1/0`: crashed with an unhelpful ZeroDivisionError
        raise ValueError(f"unknown colormap family {name!r}; "
                         f"expected one of {sorted(prefixes)}")
    suffix = '1' if cetnum is None else str(cetnum)
    return prefixes[name] + suffix
#%% Configure ----------------------------------------------------------------
# what to plot: index into the options tuple ([-1] -> 'complex')
PART = ('real', 'imag', 'abs', 'complex')[-1]
# colormap used for each non-complex PART
CMAPS = dict(real='bwr', imag='bwr', abs='turbo')
# ignore analytic troubles (truthy -> drop the last filter of each bank)
DROP_ILL_BEHAVED = 1
# background color: index 0 keeps matplotlib's default, index 1 uses '#d0d83a'
FACECOLOR = (None, '#d0d83a')[0]
def COMPLEX_FN(Psi):
    """Define PART='complex' handling here, passed to `ax.imshow`.

    Phase is encoded as color and amplitude as opacity:
      - X: `np.angle(Psi) / pi`, i.e. phase mapped into [-1, 1], drawn with
        vmin=-1, vmax=1 and a colorcet map selected by `name`/`cetnum` below.
      - alpha: `|Psi|**(2/3)`, scaled to a maximum of 1.

    Returns a dict of `ax.imshow` keyword arguments.
    """
    # configure
    name = ('cet', 'cetb', 'cett')[0]
    cetnum = 3

    # CET_C3 gets special treatment: colormap roll + phase negation
    is_cet3 = (cetnum == 3 and name == 'cet')

    # cmap
    cname = make_cname(name, cetnum)
    cm = make_cmap(cname, circshift=is_cet3)  # circshift makes zero phase white

    # phase: np.angle is in [-pi, pi]. Divide by pi rather than by this panel's
    # max |angle|, so that a given phase gets the same color in every panel and
    # +-pi always land on the colormap's ends.
    X = np.angle(Psi) / np.pi
    if is_cet3:
        X = -X  # make positive red

    # amplitude -> opacity; the 2/3 power compresses the dynamic range
    alpha = np.abs(Psi)**(2/3)
    amax = alpha.max()
    if amax > 0:  # an all-zero input would otherwise yield NaN alpha
        alpha /= amax

    return dict(X=X, alpha=alpha, cmap=cm, vmin=-1, vmax=1,
                interpolation='none')
#%% Load data ----------------------------------------------------------------
# Load the filterbank. `np.load` on an .npz returns an NpzFile that holds an
# open file handle; indexing it reads each array into memory, so the handle
# can be closed as soon as the arrays are extracted.
with np.load('jtfs_fbank.npz') as fbank:
    psis_up, psis_dn, psis_t, phi_t, phi_f = [
        fbank[nm] for nm in ('psis_up', 'psis_dn', 'psis_t', 'phi_t', 'phi_f')]

if DROP_ILL_BEHAVED:
    # discard the last filter of each bank, and use a tighter display window
    psis_up, psis_dn = psis_up[:-1], psis_dn[:-1]
    psis_t = psis_t[:-1]
    wx, wy = 120, 28
else:
    wx, wy = 155, 35

# Display window of (2*wx + 1) temporal x (2*wy + 1) frequential samples.
# NOTE(review): assumes temporal filters have length 8192 and frequential ones
# 128, each centered at length//2 -- confirm against the .npz contents.
Nh, Mh = 8192//2, 128//2
slc_x = slice(Nh-wx, Nh+wx+1)
slc_y = slice(Mh-wy, Mh+wy+1)
#%% Visualize ----------------------------------------------------------------
n_rows, n_cols = 2*len(psis_up) + 1, len(psis_t) + 1
fig, axes = plt.subplots(n_rows, n_cols, figsize=(11, 20.2), layout='constrained',
                         facecolor=FACECOLOR)

# Each panel is the separable 2D filter (frequential x temporal) for one
# (row, column) pair. The (index, filter) pairs below reproduce the layout:
#   rows: down-spin filters on top, `phi_f` in the middle row (n_rows//2),
#         up-spin filters in reversed order below ("flip & down order for
#         reasons");
#   cols: `phi_t` in column 0, then the temporal filters in reversed order.
mid = n_rows // 2
row_filters = ([(n1_fr, pf) for n1_fr, pf in enumerate(psis_dn)] +
               [(mid, phi_f)] +
               [(mid + 1 + n1_fr, pf) for n1_fr, pf in enumerate(psis_up[::-1])])
col_filters = ([(0, phi_t)] +
               [(1 + n2, pt) for n2, pt in enumerate(psis_t[::-1])])

for row_idx, pf in row_filters:
    pf_win = pf[slc_y][:, None]  # frequential window as a column
    for col_idx, pt in col_filters:
        # outer product of the windowed frequential and temporal filters
        Psi = pf_win * pt[slc_x][None, :]
        imshow(Psi, axes[row_idx, col_idx], PART)

# postprocessing
if FACECOLOR is not None:
    for ax in axes.flat:
        ax.set_facecolor(FACECOLOR)

# finalize
plt.show()
@OverLordGoldDragon
Copy link
Author

(Q&A: https://stackoverflow.com/q/77331564/10133797)

Goals:

  1. see stages of phase clearly
  2. see "spin", direction of complex rotation
  3. see amplitude, not conflated with phase
  4. (optional) use no more than 4 distinct colors; more is too complicated for high-frequency visuals

C6: the blue, green, and purple are too alike (fails 1), and it's tough to interpret without familiarity and memorization (6 colors, fails 4).

"Real superimposed with imag" - it took me some staring to realize there's any blue and red in there... fails (1) drastically.

The most satisfactory one I found is C3, with circshift (to make zero phase white) and negation (to make positive phase red, consistent with bwr):

But it fails (3), since white encodes both amplitude decay and phase. It's also... really blurry? C4 doesn't have that issue, but it fails (2) because it's a two-phase map. Another issue: the black is too dominant relative to the red and blue; I wish it could be toned down a little.

To fix (3), I decided to use a non-white background. Here's what I just came up with:

Kinda ugly, but it works... maybe there's something better.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment