Skip to content

Instantly share code, notes, and snippets.

@hengck23
Created September 30, 2021 13:25
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save hengck23/0dd338bed8aad0e98b7477a6f17f9de7 to your computer and use it in GitHub Desktop.
Save hengck23/0dd338bed8aad0e98b7477a6f17f9de7 to your computer and use it in GitHub Desktop.
learnable QT
from common import *
from scipy.signal import get_window
from scipy import signal
#https://gist.github.com/keunwoochoi/be18701219fb671f6c74b3d6e0740513
def next_pow2(x):
    """Return the smallest integer exponent ``p`` such that ``2 ** p >= x``.

    Note this is the exponent (MATLAB ``nextpow2`` semantics), not the power
    of two itself: ``next_pow2(5) == 3``.
    """
    exponent = np.log2(x)
    return int(np.ceil(exponent))
###########################################################################
# https://opus.ostfalia.de/frontdoor/deliver/index/docId/1254/file/Lajmi_2021_CQT_FFT_Frequenzanalyse.pdf
# https://arxiv.org/pdf/1902.00631.pdf
# http://academics.wellesley.edu/Physics/brown/pubs/cq1stPaper.pdf
# https://en.wikipedia.org/wiki/Constant-Q_transform
def length_to_rect_window(l, kernel_length):
    """Return a rectangular window of ``l`` ones zero-padded to ``kernel_length``.

    The ones are centred; when the total padding ``kernel_length - l`` is odd,
    the extra zero goes on the right-hand side.

    Raises:
        ValueError: if ``l > kernel_length``.
    """
    if l > kernel_length:
        raise ValueError(
            'window length {} exceeds kernel_length {}'.format(l, kernel_length))
    rect = np.ones(l)
    # Floor division leaves the odd leftover zero on the RHS for both odd and
    # even kernel_length (and never yields a negative pad when l == kernel_length).
    left = int((kernel_length - l) // 2)
    right = int(kernel_length - l - left)
    return np.pad(rect, (left, right))
def length_to_hann_window(l, kernel_length):
    """Return a periodic Hann window of ``l`` samples zero-padded to ``kernel_length``.

    Window values are ``0.5 * (1 - cos(2*pi*n/l))`` for ``n = 0..l-1``. The
    window is centred; when the total padding ``kernel_length - l`` is odd,
    the extra zero goes on the right-hand side.

    Raises:
        ValueError: if ``l > kernel_length``.
    """
    if l > kernel_length:
        raise ValueError(
            'window length {} exceeds kernel_length {}'.format(l, kernel_length))
    hann = 0.5 * (1 - np.cos(2 * np.pi * np.arange(l) / l))
    # Floor division leaves the odd leftover zero on the RHS for both odd and
    # even kernel_length (and never yields a negative pad when l == kernel_length).
    left = int((kernel_length - l) // 2)
    right = int(kernel_length - l - left)
    return np.pad(hann, (left, right))
def length_to_gaussian_window(l, kernel_length):
    """Return a peak-normalised Gaussian window spanning all ``kernel_length`` taps.

    Unlike the rect/Hann helpers there is no zero padding: ``l`` only sets the
    width via ``sigma = l / 4``. Samples sit at offsets
    ``-(kernel_length-1)/2 ... (kernel_length-1)/2`` and the curve is divided
    by its maximum so the peak value is 1.
    """
    half_span = (kernel_length - 1) / 2.
    offsets = np.linspace(-half_span, half_span, kernel_length)
    sigma = l / 4
    bell = np.exp(-0.5 * np.square(offsets) / np.square(sigma))
    return bell / bell.max()
def torch_length_to_gaussian_window(l, kernel_length):
    """Return a peak-normalised Gaussian window with ``sigma = l / 6``.

    Despite the name this is pure NumPy and returns an ndarray. It differs from
    ``length_to_gaussian_window`` only in the width: sigma is ``l / 6`` here
    (a narrower bell), over offsets ``-(kernel_length-1)/2 ... (kernel_length-1)/2``,
    divided by its maximum so the peak value is 1.
    """
    half_span = (kernel_length - 1) / 2.
    offsets = np.linspace(-half_span, half_span, kernel_length)
    sigma = l / 6
    bell = np.exp(-0.5 * np.square(offsets) / np.square(sigma))
    return bell / bell.max()
'''
For gravitational-wave signals, binary black hole mergers show up most clearly at lower Q values (Q = 5-20),
whereas binary neutron star mergers work better with higher Q values (Q = 80-120).
https://dcc.ligo.org/public/0035/G040521/000/G040521-00.pdf
'''
def create_cqt_kernel(
    Q,
    fs,
    fmin,
    fmax,
    num_freq_per_octave,
    kernel_length,
):
    """Build the complex kernels and Hann windows of a constant-Q filter bank.

    Bin ``k`` has centre frequency ``fmin * 2 ** (k / num_freq_per_octave)`` and a
    window of ``ceil(Q * fs / freq[k])`` samples, so window length scales
    inversely with frequency.

    Args:
        Q: quality factor; larger values give longer windows.
        fs: sampling rate in Hz.
        fmin: centre frequency of the lowest bin in Hz.
        fmax: upper bound used to size the bank:
            ``num_freq = ceil(num_freq_per_octave * log2(fmax / fmin))``.
        num_freq_per_octave: number of bins per octave.
        kernel_length: number of taps every kernel/window is padded to.

    Returns:
        (cqt_kernel, cqt_window, length, freq) where
        cqt_kernel -- complex64 array (num_freq, kernel_length); row k is
            ``exp(2j*pi*freq[k]*t/fs) / length[k]`` for integer sample offsets
            ``t`` in ``[-kernel_length//2, kernel_length//2)`` (unwindowed).
        cqt_window -- float32 array (num_freq, kernel_length); row k is a Hann
            window of ``length[k]`` samples zero-padded to ``kernel_length``.
        length -- int32 array (num_freq,) of window lengths.
        freq -- float array (num_freq,) of bin centre frequencies.

    Raises:
        ValueError: if ``fmax <= fmin`` (no bins), if the top bin exceeds the
            Nyquist frequency ``fs / 2``, or if any window is longer than
            ``kernel_length``.
    """
    num_freq = int(np.ceil(num_freq_per_octave * np.log2(fmax / fmin)))
    if num_freq <= 0:
        raise ValueError(
            'fmax ({}) must be greater than fmin ({})'.format(fmax, fmin))
    # Builtin float: the np.float alias was removed in NumPy 1.24.
    freq = fmin * 2.0 ** (np.arange(num_freq) / float(num_freq_per_octave))

    top_freq = np.max(freq)
    if top_freq > fs / 2:
        raise ValueError('The top bin {}Hz has exceeded the Nyquist frequency, '
                         'please reduce the num of bins'.format(top_freq))

    length = np.ceil(Q * fs / freq).astype(np.int32)
    longest = int(np.max(length))
    if longest > kernel_length:
        # Fail early with a clear message instead of a negative np.pad width.
        raise ValueError(
            'The longest window ({} samples, at {}Hz) exceeds kernel_length={}; '
            'increase kernel_length or fmin, or reduce Q'.format(
                longest, freq[np.argmax(length)], kernel_length))

    # Integer sample offsets shared by every bin (hoisted out of the loop).
    t = np.r_[-kernel_length // 2:kernel_length // 2]

    cqt_kernel = np.zeros((num_freq, kernel_length), dtype=np.complex64)
    cqt_window = np.zeros((num_freq, kernel_length), dtype=np.float32)
    for k in range(num_freq):
        f = freq[k]
        l = length[k]
        # The exponential and its window are returned separately; MyQT
        # multiplies them in forward(). length_to_rect_window and
        # length_to_gaussian_window are drop-in alternatives to the Hann window.
        cqt_kernel[k] = np.exp(t * 1j * 2 * np.pi * f / fs) / l
        cqt_window[k] = length_to_hann_window(l, kernel_length)
    return cqt_kernel, cqt_window, length, freq
'''
https://staff.fnwi.uva.nl/r.vandenboomgaard/SP20162017/SystemsSignals/plottingsignals.html
t = np.r_[-l//2:l//2] #np.linspace(-0.02, 0.05, l)
plt.plot(t, sig.real )
plt.plot(t, sig.imag )
plt.show()
'''
# def complex_to_mag(real, imag):
# mag = torch.sqrt(real.pow(2) + imag.pow(2))
# return mag
class MyQT(torch.nn.Module):
    """Constant-Q magnitude transform computed as a strided 1-D convolution.

    Kernels come from ``create_cqt_kernel``: a complex exponential per bin
    (split into real/imag conv weights) plus a separate Hann window. In
    ``forward`` the window is applied, each part is L2-normalised along time,
    and the magnitude of the complex response is returned.

    Args:
        Q, fs, fmin, fmax, num_freq_per_octave, kernel_length: forwarded to
            ``create_cqt_kernel``.
        hop_length: convolution stride in samples.
        trainable: if True the real/imag kernels are ``nn.Parameter``s,
            otherwise non-trainable buffers. The window is always a buffer.
        signal_length: input length (samples) used only to report the
            expected output shape in ``self.text``; defaults to 4096.
    """

    def __init__(self,
                 Q=5,
                 fs=22050,
                 fmin=20,
                 fmax=500,
                 num_freq_per_octave=12,
                 hop_length=16,
                 kernel_length=4096,
                 trainable=False,
                 signal_length=4096,
                 ):
        super().__init__()
        self.hop_length = hop_length

        cqt_kernel, cqt_window, length, freq = create_cqt_kernel(
            Q,
            fs,
            fmin,
            fmax,
            num_freq_per_octave,
            kernel_length,
        )
        self.num_freq, self.kernel_length = cqt_kernel.shape

        # Conv weights need shape (out_channels=num_freq, in_channels=1, taps).
        cqt_kernel_real = torch.tensor(cqt_kernel.real).unsqueeze(1)
        cqt_kernel_imag = torch.tensor(cqt_kernel.imag).unsqueeze(1)

        self.register_buffer('freq', torch.tensor(freq))
        self.register_buffer('length', torch.tensor(length))
        self.register_buffer('cqt_window', torch.tensor(cqt_window).unsqueeze(1))
        if trainable:
            self.cqt_kernel_real = nn.Parameter(cqt_kernel_real)
            self.cqt_kernel_imag = nn.Parameter(cqt_kernel_imag)
        else:
            self.register_buffer('cqt_kernel_real', cqt_kernel_real)
            self.register_buffer('cqt_kernel_imag', cqt_kernel_imag)

        # forward() pads kernel_length//2 on each side and convolves with
        # stride hop_length, so this is the number of output frames.
        padded_length = signal_length + 2 * (self.kernel_length // 2)
        cqt_image_w = (padded_length - self.kernel_length) // self.hop_length + 1
        cqt_image_h = self.num_freq
        self.text = ''.join([
            'Q=%f\n' % (Q),
            'fs=%d\n' % (fs),
            'fmin=%f\n' % (fmin),
            'fmax=%f\n' % (fmax),
            'num_freq_per_octave=%d\n' % (num_freq_per_octave),
            'hop_length=%d\n' % (hop_length),
            'trainable=%s\n' % str(trainable),
            'cqt_kernel.shape=%s\n' % str(cqt_kernel.shape),
            'cqt_image.shape=%s\n' % str((cqt_image_h, cqt_image_w)),
        ])

    def forward(self, x):
        """Return the CQT magnitude of ``x``.

        x: tensor of shape (batch, 1, time); the conv weights have a single
            input channel.
        Returns a tensor of shape (batch, num_freq, frames) with
        ``frames = (time + 2*(kernel_length//2) - kernel_length) // hop_length + 1``.
        """
        # Zero-pad half a kernel on each side so frame j is centred near
        # input sample j * hop_length.
        half = self.kernel_length // 2
        x = F.pad(x, pad=(half, half), mode='constant')

        # Window the (possibly trainable) kernels, then L2-normalise the real
        # and imaginary parts separately along the time axis.
        cqt_kernel_real = F.normalize(self.cqt_kernel_real * self.cqt_window, 2, dim=-1)
        cqt_kernel_imag = F.normalize(self.cqt_kernel_imag * self.cqt_window, 2, dim=-1)

        cqt_real = F.conv1d(x, cqt_kernel_real, stride=self.hop_length)
        # The sign flip does not affect the magnitude returned below.
        cqt_imag = -F.conv1d(x, cqt_kernel_imag, stride=self.hop_length)
        return torch.sqrt(cqt_real.pow(2) + cqt_imag.pow(2))
##########################################################################
def cqt_to_overlay(cqt, size=(384,384)):
    """Turn a 3-channel CQT stack into an 8-bit, vertically flipped H x W x 3 image.

    cqt: indexed as ``cqt[c]`` for c = 0, 1, 2 (one map per channel); values are
        clipped to [0, 1] and scaled to 0..255.
    size: ``dsize`` passed to ``cv2.resize``, or None to keep the native size.
    The channel order of the result is cqt[0], cqt[1], cqt[2] (no colour swap).
    """
    scaled = (np.clip(cqt, 0, 1) * 255).astype(np.uint8)
    channels = [scaled[c] for c in range(3)]
    # Flip rows so that, presumably, low frequencies end up at the bottom.
    overlay = cv2.flip(np.stack(channels, -1), 0)
    if size is None:
        return overlay
    return cv2.resize(overlay, dsize=size)
def cqt_to_overlay1(cqt, size=(384,384)):
    """Lay the three CQT channel maps side by side as one 8-bit BGR image.

    cqt: indexed as ``cqt[c]`` for c = 0, 1, 2; values are clipped to [0, 1]
        and scaled to 0..255.
    size: (width, height) of ONE panel; the strip is resized to
        ``(3 * width, height)``. None keeps the native size.
    The result is flipped vertically.
    """
    image = np.clip(cqt, 0, 1)
    image = (image * 255).astype(np.uint8)
    strip = np.hstack([image[0], image[1], image[2]])
    # hstack of 2-D maps is single-channel, which COLOR_RGB2BGR rejects
    # (it needs 3 or 4 channels); expand grey to 3-channel BGR instead.
    if strip.ndim == 2:
        strip = cv2.cvtColor(strip, cv2.COLOR_GRAY2BGR)
    else:
        strip = cv2.cvtColor(strip, cv2.COLOR_RGB2BGR)
    strip = cv2.flip(strip, 0)
    if size is not None:
        strip = cv2.resize(strip, dsize=(3 * size[0], size[1]))
    return strip
def cqt_to_overlay2(cqt, size=(384,384)):
    """Stack the three CQT channel maps vertically as one 8-bit BGR image.

    cqt: indexed as ``cqt[c]`` for c = 0, 1, 2; values are clipped to [0, 1]
        and scaled to 0..255.
    size: (width, height) of ONE panel; the column is resized to
        ``(width, 3 * height)``. None keeps the native size.
    The result is flipped vertically.
    """
    image = np.clip(cqt, 0, 1)
    image = (image * 255).astype(np.uint8)
    column = np.vstack([image[0], image[1], image[2]])
    # vstack of 2-D maps is single-channel, which COLOR_RGB2BGR rejects
    # (it needs 3 or 4 channels); expand grey to 3-channel BGR instead.
    if column.ndim == 2:
        column = cv2.cvtColor(column, cv2.COLOR_GRAY2BGR)
    else:
        column = cv2.cvtColor(column, cv2.COLOR_RGB2BGR)
    column = cv2.flip(column, 0)
    if size is not None:
        column = cv2.resize(column, dsize=(size[0], 3 * size[1]))
    return column
##########################################################################
def run_check_cqt1():
    """Debug check: build a CQT kernel bank, print its bins and plot the kernels.

    Prints the Q implied by three filter scales at 27 bins/octave, prints the
    bin frequencies and window lengths, overlays the real part of every kernel
    on one figure (waiting for a button press between kernels), then ends the
    process with exit(0). The boolean toggles select the optional plots.
    """
    Q = 5  # float(filter_scale) / (2 ** (1 / num_freq_per_octave) - 1)
    fs = 2048
    fmin = 20
    fmax = 500
    num_freq_per_octave = 27
    kernel_length = 4096

    # Q implied by filter_scale = 1.0 / 0.5 / 0.25 at 27 bins per octave:
    # 16.817153745105756, 8.408576872552878, 4.204288436276439
    for filter_scale in (1.00, 0.50, 0.25):
        print(float(filter_scale) / (2 ** (1 / num_freq_per_octave) - 1))
    print('')

    cqt_kernel, cqt_window, length, freq = create_cqt_kernel(
        Q,
        fs,
        fmin,
        fmax,
        num_freq_per_octave=num_freq_per_octave,
        kernel_length=kernel_length,
    )
    num_freq, kernel_length = cqt_kernel.shape

    print('Q', Q)
    print('kernel_length', kernel_length)
    print('')
    for label, values in (('freq', freq), ('length', length)):
        print(label, values.shape)
        print(values)
        print('')

    show_kernels_one_by_one = False
    overlay_all_kernels = True
    plot_length_vs_freq = False

    if show_kernels_one_by_one:
        for k in range(num_freq):
            plt.clf()
            plt.plot(cqt_kernel.real[k])
            plt.show()

    if overlay_all_kernels:
        for k in range(num_freq):
            plt.plot(cqt_kernel.real[k])
            plt.waitforbuttonpress()
        plt.show()

    if plot_length_vs_freq:
        fig = plt.figure()
        ax = fig.add_subplot(111)
        plt.plot(np.zeros(len(freq)), freq, '.')
        plt.plot(length, freq, '.')
        for f, l in zip(freq, length):
            ax.annotate(str(f), xy=(0, f))
        plt.show()

    exit(0)
def run_check_cqt2():
    """Visual check of MyQT on a few G2Net training samples.

    For each hard-coded sample id, the function:
    - loads ``<data_dir>/train/a/b/c/<id>.npy``;
    - scales it by 3e20;
    - band-pass filters it at 25-500 Hz (4th-order Butterworth, zero-phase via filtfilt);
    - applies a Tukey taper;
    - runs MyQT and shows the square-rooted magnitude as overlay images.

    Blocks on cv2.waitKey(0) after every sample.
    """
    Q = 16
    kernel_length = 4096
    fs = 2048
    fmin = 20
    fmax = 500
    num_freq_per_octave = 27
    hop_length = 32

    cqt = MyQT(
        Q=Q,
        fs=fs,
        fmin=fmin,
        fmax=fmax,
        num_freq_per_octave=num_freq_per_octave,
        hop_length=hop_length,
        kernel_length=kernel_length,
    )
    print(cqt.text)

    b_bp, a_bp = signal.butter(4, (25, 500), btype='bandpass', fs=fs)
    # scipy.signal.tukey was removed from the scipy.signal namespace in
    # SciPy 1.13; scipy.signal.windows.tukey is the supported location.
    tukey = signal.windows.tukey(4096, 0.2)
    # Shape (3, 1): one scale per row. Assumes each .npy holds 3 rows of
    # 4096 samples (the code indexes image[0..2] below) -- TODO confirm.
    wave_norm = [[3.0e20], [3.0e20], [3.0e20]]

    sample_ids = [
        '05dac2b96f',  # 0
        '013f088058',
        # --
        '07e650dd6f',  # 1
        '01ad950925',
        '0827de7926',
    ]
    data_dir = '/root/share1/kaggle/2021/g2net/data'
    compare_resize = False  # debug toggle: compare OpenCV vs PyTorch resize

    for sample_id in sample_ids:
        npy_file = data_dir + '/train/%s/%s/%s/%s.npy' % (
            sample_id[0], sample_id[1], sample_id[2], sample_id)
        wave = np.load(npy_file)
        wave = wave * wave_norm
        wave = signal.filtfilt(b_bp, a_bp, wave)
        wave = wave * tukey
        wave = wave.astype(np.float32)

        # Inference only: no autograd graph needed (also keeps .numpy() valid
        # when the kernels are trainable parameters).
        with torch.no_grad():
            image = cqt(torch.from_numpy(wave).unsqueeze(1))
        image = image.cpu().numpy()
        print('image.shape', image.shape)
        print('image.max', image.max(-1).max(-1))

        if compare_resize:
            for m in image:
                cv_big = cv2.resize(m, dsize=(256, 256), interpolation=cv2.INTER_LINEAR)
                pt_big = F.interpolate(
                    torch.from_numpy(m).unsqueeze(0).unsqueeze(0),
                    size=(256, 256), mode='bilinear', align_corners=False,
                )
                pt_big = pt_big.cpu().numpy()[0, 0]
                diff = np.abs(cv_big - pt_big)
                print(diff.max(), diff.mean(), diff.std(),)
                print(cv_big.max(), cv_big.mean(), cv_big.std(),)
                print(pt_big.max(), pt_big.mean(), pt_big.std(),)
                print('')

        image = image ** 0.5
        overlay_shape = (384, 384)
        overlay = cqt_to_overlay(image, overlay_shape)
        overlay1 = cqt_to_overlay1(image, overlay_shape)
        draw_shadow_text(overlay, '%s ' % (sample_id, ), (2, 20), 0.6, (255, 255, 255), 2)
        image_show('overlay', overlay, resize=1)
        image_show('overlay1', overlay1, resize=1)
        cv2.waitKey(0)
# main #################################################################
if __name__ == '__main__':
    # Only the MyQT visual check runs by default; run_check_cqt1() (kernel
    # plotting) is not called here and must be invoked manually.
    run_check_cqt2()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment