Skip to content

Instantly share code, notes, and snippets.

@OverLordGoldDragon
Created July 6, 2022 21:52
Show Gist options
  • Save OverLordGoldDragon/52cd86d9b7e5124f258fa328bd1a0896 to your computer and use it in GitHub Desktop.
Save OverLordGoldDragon/52cd86d9b7e5124f258fa328bd1a0896 to your computer and use it in GitHub Desktop.
SSQ/CWT inversion tool, work in progress
import numpy as np
import matplotlib.pyplot as plt
from ssqueezepy import ssq_cwt, issq_cwt, TestSignals, icwt
from ssqueezepy.visuals import plot, imshow
from PIL import Image
# Build the test signal and its transforms: an `echirp` test signal
# (fmin=64, fmax=512) of N samples, zero-padded to exactly Npad samples, then
# its synchrosqueezed (Tx) and continuous (Wx) wavelet transforms via ssq_cwt.
N = 7000
Npad = 8192
# NOTE(review): `[0]` presumably selects the signal from a (signal, time) pair
x = TestSignals().echirp(N=N, fmin=64, fmax=512)[0]
# Split the padding explicitly so the result is exactly Npad long even when
# Npad - N is odd (a symmetric `(Npad - N)//2` would then yield Npad - 1).
pad_left = (Npad - N) // 2
x = np.pad(x, (pad_left, Npad - N - pad_left))
Tx, Wx, *_ = ssq_cwt(x)
# Keep only the CWT magnitude (used for rendering below); Tx stays complex.
Wx = np.abs(Wx)
#%%
# Render |Wx| to a borderless PNG whose pixel grid is a fixed multiple of the
# scalogram's grid, so pixels drawn on the image map back to Wx rows/columns.
kw = dict(abs=1, interpolation='none')  # not referenced elsewhere in this script
y_mult = 2.    # image pixels per Wx row
x_mult = 1/8   # image pixels per Wx column (i.e. 8 columns per pixel)
DPI = 120
assert (Wx.shape[0] * y_mult).is_integer()  # image height must be whole pixels
assert (Wx.shape[1] * x_mult).is_integer()  # image width must be whole pixels
base_x = Wx.shape[1] / DPI  # figure width (inches) at 1 pixel per column
base_y = Wx.shape[0] / DPI  # figure height (inches) at 1 pixel per row
fig, ax = plt.subplots(1, 1, figsize=(base_x*x_mult, base_y*y_mult), dpi=DPI)
ax.imshow(Wx, cmap='bone', aspect='auto', interpolation='none')
ax.axis('off')
# Make the axes fill the entire figure so no margins end up in the PNG.
ax.set_position((0, 0, 1, 1))
# Pass dpi explicitly: the PNG must be exactly figsize * DPI pixels for the
# later pixel -> (row, col) mapping; without it, savefig uses
# rcParams['savefig.dpi'], which a user config may override.
plt.savefig('img.png', dpi=DPI)
plt.show()
#%%
# Load the edited render (presumably img.png with the curve to invert drawn
# onto it) as an (H, W, 3) RGB array.
ref = 255  # red-channel value that marks drawn pixels
# Context manager closes the file handle; convert('RGB') makes grayscale,
# palette and RGBA PNGs all yield 3 channels (RGBA's alpha is dropped).
with Image.open('img_mod.png') as img:
    Wxm = np.asarray(img.convert('RGB'))
plt.imshow(Wxm, interpolation='none')
#%%
# Extract the drawn curve from the edited image and map it onto the
# scalogram grid: one row index per Wx column, -1 where nothing was drawn.
#
# NOTE(review): pixels are selected by the red channel alone (R == ref); the
# brightest scalogram pixels in the 'bone' colormap are presumably near-white
# and may also have R == 255. If the curve is drawn in pure red, consider also
# requiring G == B == 0 — confirm the drawing color.
Wxr = Wxm[..., 0]
img_h, img_w = Wxr.shape
y_idxs, x_idxs = np.where(Wxr == ref)
if x_idxs.size == 0:
    raise ValueError(f"no pixels with red channel == {ref} found in the edited image")

# For each selected image column, average the y-positions of its selected
# pixels (midpoint of the stroke). Vectorized: O(P log P) instead of one full
# scan per unique column.
x_idxs_unique, col_of_pixel = np.unique(x_idxs, return_inverse=True)
y_idxs_mean = (np.bincount(col_of_pixel, weights=y_idxs)
               / np.bincount(col_of_pixel))

# Scale values to the scalogram's height. Using the image's actual height
# equals dividing by y_mult when the PNG has the intended size, and stays
# proportional if it is off by a pixel.
y_idxs_mean = (y_idxs_mean * Wx.shape[0] / img_h).astype(int)

# Place each column's row index at its own image column. Columns with no
# selection (left/right of the stroke, or gaps inside it) stay -1, so the
# result always spans the full image width and stays column-aligned.
# NOTE(review): -1 presumably tells issq_cwt "no curve here" — confirm.
ridge_img = np.full(img_w, -1, dtype=int)
ridge_img[x_idxs_unique] = y_idxs_mean
plt.plot(ridge_img); plt.show()

# Scale length to the scalogram's width: Wx column t takes image column
# floor(t * img_w / n_cols). When n_cols is a multiple of img_w, this equals
# np.repeat by n_cols // img_w.
n_cols = Wx.shape[-1]
y_idxs_mean = ridge_img[np.arange(n_cols) * img_w // n_cols]

# Sanity checks
assert y_idxs_mean.min() >= -1               # no negatives (except -1)
assert y_idxs_mean.max() < Wx.shape[0]       # valid row index into Wx
assert len(y_idxs_mean) == n_cols            # matches scalogram width
plt.plot(y_idxs_mean); plt.show()
#%%
# Invert only the part of the synchrosqueezed transform around the traced
# curve. cc / cw are column vectors (one curve): the center row per time
# step and a width parameter of 5 rows.
cc = y_idxs_mean[:, None]
cw = np.full((len(y_idxs_mean), 1), 5, dtype=int)
# Pass the complex SSQ transform Tx. Wx was replaced by np.abs(Wx) for
# rendering, and inverting a magnitude discards phase, so it cannot
# reconstruct the oscillating signal.
# NOTE(review): the curve was traced on |Wx|; this assumes Tx rows line up
# with Wx rows (same count and frequency ordering) — confirm for the installed
# ssqueezepy version.
out = issq_cwt(Tx, cc=cc, cw=cw).squeeze()
plot(out[0])  # presumably the component extracted along the curve
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment