# Created September 30, 2021 13:25
# Gist hengck23/0dd338bed8aad0e98b7477a6f17f9de7: "learnable QT"
# (Saved from GitHub. GitHub warns that this file contains bidirectional Unicode
# text that may be interpreted or compiled differently than what appears below;
# to review it, open the file in an editor that reveals hidden Unicode characters.)
from common import * | |
from scipy.signal import get_window | |
from scipy import signal | |
#https://gist.github.com/keunwoochoi/be18701219fb671f6c74b3d6e0740513 | |
def next_pow2(x):
    """Return the exponent ``p = ceil(log2(x))`` of the next power of two >= ``x``.

    This is the exponent, not the power itself: ``next_pow2(5) == 3`` because
    ``2**3 = 8``. Non-positive ``x`` is not handled (log2 yields -inf/nan and
    the int conversion raises).
    """
    log_value = np.log2(x)
    exponent = np.ceil(log_value)
    return int(exponent)
########################################################################### | |
# https://opus.ostfalia.de/frontdoor/deliver/index/docId/1254/file/Lajmi_2021_CQT_FFT_Frequenzanalyse.pdf | |
# https://arxiv.org/pdf/1902.00631.pdf | |
# http://academics.wellesley.edu/Physics/brown/pubs/cq1stPaper.pdf | |
# https://en.wikipedia.org/wiki/Constant-Q_transform | |
def _pad_window_to_kernel(window, kernel_length):
    """Zero-pad a 1-D ``window`` to ``kernel_length`` samples around the centre.

    The left pad is ceil(kernel_length/2 - l/2), minus one for odd ``l`` so the
    extra zero goes on the right-hand side (when ``kernel_length`` is even).
    An odd window exactly as long as the kernel is returned unpadded instead of
    producing a negative pad width.

    Raises:
        ValueError: if the window is longer than ``kernel_length``.
    """
    l = len(window)
    if l > kernel_length:
        raise ValueError(
            'window length {} exceeds kernel_length {}'.format(l, kernel_length))
    left = int(np.ceil(kernel_length / 2.0 - l / 2.0))
    if l % 2 == 1:  # pad more zeros on RHS
        left -= 1
    # Only l == kernel_length with odd l reaches -1 here; that window needs no padding.
    left = max(left, 0)
    right = kernel_length - l - left
    return np.pad(window, (left, right))


def length_to_rect_window(l, kernel_length):
    """Return a length-``l`` boxcar (all ones) zero-padded to ``kernel_length`` samples.

    Raises:
        ValueError: if ``l`` exceeds ``kernel_length``.
    """
    return _pad_window_to_kernel(np.ones(l), kernel_length)


def length_to_hann_window(l, kernel_length):
    """Return a length-``l`` Hann window zero-padded to ``kernel_length`` samples.

    Uses the periodic form 0.5 * (1 - cos(2*pi*n / l)) for n = 0..l-1
    (denominator ``l`` rather than ``l - 1``).

    Raises:
        ValueError: if ``l`` exceeds ``kernel_length``.
    """
    hann = 0.5 * (1 - np.cos(2 * np.pi * np.arange(l) / l))
    return _pad_window_to_kernel(hann, kernel_length)
def _gaussian_window(sigma, kernel_length):
    """Return a Gaussian with standard deviation ``sigma`` (in samples).

    The Gaussian spans ``kernel_length`` samples, is centred on the kernel and
    is scaled so that its peak is 1.
    """
    x = np.linspace(-(kernel_length - 1) / 2., (kernel_length - 1) / 2., kernel_length)
    gauss = np.exp(-0.5 * np.square(x) / np.square(sigma))
    return gauss / gauss.max()


def length_to_gaussian_window(l, kernel_length, sigma_divisor=4):
    """Return a Gaussian window whose width scales with the nominal length ``l``.

    Unlike the rect/Hann helpers the window is not cut to ``l`` samples: it
    spans the whole kernel with standard deviation ``l / sigma_divisor``.

    Raises:
        ValueError: if ``l`` is not positive (sigma would be 0 and the window NaN).
    """
    if l <= 0:
        raise ValueError('window length must be positive, got {}'.format(l))
    return _gaussian_window(l / sigma_divisor, kernel_length)


def torch_length_to_gaussian_window(l, kernel_length):
    """Narrower Gaussian window with sigma = l / 6 (instead of l / 4).

    Despite the name this returns a NumPy array, like the other window helpers.
    """
    return length_to_gaussian_window(l, kernel_length, sigma_divisor=6)
''' | |
For gravitational-wave signals, binary black hole mergers are most clearly visible with lower Q values (Q = 5-20),
whereas binary neutron star mergers are better resolved with higher Q values (Q = 80-120).
https://dcc.ligo.org/public/0035/G040521/000/G040521-00.pdf | |
''' | |
def create_cqt_kernel(
    Q,
    fs,
    fmin,
    fmax,
    num_freq_per_octave,
    kernel_length,
):
    """Build per-bin complex exponentials and Hann windows for the Q-transform.

    Centre frequencies are spaced geometrically, ``num_freq_per_octave`` bins
    per octave, starting at ``fmin``. The bin count is
    ``ceil(num_freq_per_octave * log2(fmax / fmin))``, so every bin lies below
    ``fmax``. Bin ``k`` gets a window of ``ceil(Q * fs / freq[k])`` samples.

    Args:
        Q: quality factor; a bin's window spans Q * fs / f samples.
        fs: sampling rate in Hz.
        fmin, fmax: frequency range in Hz (0 < fmin < fmax).
        num_freq_per_octave: number of bins per octave (> 0).
        kernel_length: number of taps in every kernel.

    Returns:
        cqt_kernel: (num_freq, kernel_length) complex64, exp(2j*pi*f*t/fs) / l
            sampled at t = -kernel_length//2 ... kernel_length//2 - 1.
        cqt_window: (num_freq, kernel_length) float32 Hann windows of length l,
            zero-padded to kernel_length.
        length: (num_freq,) int32 window lengths l.
        freq: (num_freq,) float64 centre frequencies in Hz.

    Raises:
        ValueError: for an invalid frequency range or bin count, when the top
            bin exceeds the Nyquist frequency, or when the longest window does
            not fit in ``kernel_length``.
    """
    if fmin <= 0 or fmax <= fmin:
        raise ValueError(
            'Require 0 < fmin < fmax, got fmin={} and fmax={}'.format(fmin, fmax))
    if num_freq_per_octave <= 0:
        raise ValueError(
            'num_freq_per_octave must be positive, got {}'.format(num_freq_per_octave))

    # Geometrically spaced centre frequencies, all strictly below fmax.
    # (np.float was removed in NumPy 1.24; the builtin float is equivalent.)
    num_freq = int(np.ceil(num_freq_per_octave * np.log2(fmax / fmin)))
    freq = fmin * 2.0 ** (np.r_[0:num_freq] / float(num_freq_per_octave))
    if np.max(freq) > fs / 2:
        raise ValueError(
            'The top bin {}Hz has exceeded the Nyquist frequency, '
            'please reduce the num of bins'.format(np.max(freq)))

    # Window length in samples: Q periods of each centre frequency.
    length = np.ceil(Q * fs / freq).astype(np.int32)
    longest = int(length.max())
    if longest > kernel_length:
        # Each Hann window is zero-padded to kernel_length, so it must fit inside it.
        raise ValueError(
            'The longest window ({} samples at {}Hz) exceeds kernel_length={}; '
            'increase kernel_length, reduce Q or raise fmin'.format(
                longest, freq[int(np.argmax(length))], kernel_length))

    # Sample offsets -kernel_length//2 ... kernel_length//2 - 1 (same for every bin).
    t = np.r_[-kernel_length // 2:kernel_length // 2]
    cqt_kernel = np.zeros((num_freq, kernel_length), dtype=np.complex64)
    cqt_window = np.zeros((num_freq, kernel_length), dtype=np.float32)
    for k, (f, l) in enumerate(zip(freq, length)):
        cqt_kernel[k] = np.exp(t * 1j * 2 * np.pi * f / fs) / l
        # Alternatives: length_to_rect_window / length_to_gaussian_window.
        cqt_window[k] = length_to_hann_window(l, kernel_length)
    return cqt_kernel, cqt_window, length, freq
''' | |
https://staff.fnwi.uva.nl/r.vandenboomgaard/SP20162017/SystemsSignals/plottingsignals.html | |
t = np.r_[-l//2:l//2] #np.linspace(-0.02, 0.05, l) | |
plt.plot(t, sig.real ) | |
plt.plot(t, sig.imag ) | |
plt.show() | |
''' | |
# def complex_to_mag(real, imag): | |
# mag = torch.sqrt(real.pow(2) + imag.pow(2)) | |
# return mag | |
class MyQT(torch.nn.Module):
    """Q-transform magnitude computed with two strided 1-D convolutions.

    The kernels come from ``create_cqt_kernel``. The real and imaginary parts
    of the per-bin complex exponentials are registered as buffers, or as
    parameters when ``trainable`` is True. On every forward pass they are
    multiplied by fixed window buffers and L2-normalised.

    NOTE(review): with the default arguments the longest window is
    ceil(5 * 22050 / 20) = 5513 samples, longer than kernel_length=4096, so
    kernel construction fails. The checks in this file pass fs=2048.

    Args:
        Q, fs, fmin, fmax, num_freq_per_octave, kernel_length: forwarded to
            ``create_cqt_kernel``.
        hop_length: convolution stride (samples between output frames).
        trainable: learn the real/imaginary kernels (the windows stay fixed).
        signal_length: input length used only to report ``cqt_image.shape``
            in ``self.text``. Defaults to 4096, the previously hard-coded value.
    """

    def __init__(self,
                 Q=5,
                 fs=22050,
                 fmin=20,
                 fmax=500,
                 num_freq_per_octave=12,
                 hop_length=16,
                 kernel_length=4096,
                 trainable=False,
                 signal_length=4096,
                 ):
        super().__init__()
        self.hop_length = hop_length

        cqt_kernel, cqt_window, length, freq = create_cqt_kernel(
            Q,
            fs,
            fmin,
            fmax,
            num_freq_per_octave,
            kernel_length,
        )
        self.num_freq, self.kernel_length = cqt_kernel.shape

        # unsqueeze(1) -> (num_freq, 1, kernel_length): conv1d weights with one input channel.
        cqt_kernel_real = torch.tensor(cqt_kernel.real).unsqueeze(1)
        cqt_kernel_imag = torch.tensor(cqt_kernel.imag).unsqueeze(1)

        self.register_buffer('freq', torch.tensor(freq))
        self.register_buffer('length', torch.tensor(length))
        self.register_buffer('cqt_window', torch.tensor(cqt_window).unsqueeze(1))
        if trainable:
            self.cqt_kernel_real = nn.Parameter(cqt_kernel_real)
            self.cqt_kernel_imag = nn.Parameter(cqt_kernel_imag)
        else:
            self.register_buffer('cqt_kernel_real', cqt_kernel_real)
            self.register_buffer('cqt_kernel_imag', cqt_kernel_imag)

        # Number of frames forward() produces: it pads kernel_length//2 on both
        # sides, then runs a stride-hop_length convolution with no extra padding.
        half = self.kernel_length // 2
        cqt_image_w = (signal_length + 2 * half - self.kernel_length) // self.hop_length + 1
        cqt_image_h = self.num_freq
        self.text = (
            'Q=%f\n' % (Q)
            + 'fs=%d\n' % (fs)
            + 'fmin=%f\n' % (fmin)
            + 'fmax=%f\n' % (fmax)
            + 'num_freq_per_octave=%d\n' % (num_freq_per_octave)
            + 'hop_length=%d\n' % (hop_length)
            + 'trainable=%s\n' % str(trainable)
            + 'cqt_kernel.shape=%s\n' % str(cqt_kernel.shape)
            + 'cqt_image.shape=%s\n' % str((cqt_image_h, cqt_image_w))
        )

    def forward(self, x):
        """Return the magnitude response of ``x``, one row per frequency bin.

        ``x`` must have a single channel, since the conv1d weights are
        (num_freq, 1, kernel_length). It is presumably (batch, 1, time).
        The output is (batch, num_freq, frames).

        NOTE(review): sqrt has an infinite derivative at 0, so an exactly-zero
        response gives NaN gradients when training; consider an epsilon if seen.
        """
        half = self.kernel_length // 2
        x = F.pad(x, pad=(half, half), mode='constant')

        window = self.cqt_window
        kernel_real = F.normalize(self.cqt_kernel_real * window, 2, dim=-1)
        kernel_imag = F.normalize(self.cqt_kernel_imag * window, 2, dim=-1)

        real = F.conv1d(x, kernel_real, stride=self.hop_length)
        # Negated imaginary part: correlate with the conjugate exponential.
        imag = -F.conv1d(x, kernel_imag, stride=self.hop_length)
        return torch.sqrt(real.pow(2) + imag.pow(2))
########################################################################## | |
def _cqt_to_uint8(cqt):
    """Clip ``cqt`` to [0, 1] and rescale it to uint8 values in [0, 255]."""
    image = np.clip(cqt, 0, 1)
    return (image * 255).astype(np.uint8)


def _to_bgr(overlay):
    """Return a 3-channel BGR version of ``overlay``.

    A single-channel (2-D) mosaic is expanded with GRAY2BGR, because cv2's
    RGB2BGR conversion rejects 1-channel input. Multi-channel input keeps the
    RGB->BGR swap.
    """
    if overlay.ndim == 2:
        return cv2.cvtColor(overlay, cv2.COLOR_GRAY2BGR)
    return cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR)


def cqt_to_overlay(cqt, size=(384,384)):
    """Merge the first three CQT maps into one 3-channel uint8 image.

    ``cqt[0]``, ``cqt[1]`` and ``cqt[2]`` become channels 0, 1 and 2 (B, G, R
    when shown with OpenCV). Values are clipped to [0, 1] and scaled to
    uint8. The image is flipped vertically, so row 0 (presumably the
    lowest-frequency bin) ends up at the bottom. It is then resized to
    ``size`` = (width, height) unless ``size`` is None.
    """
    image = _cqt_to_uint8(cqt)
    overlay = np.stack([image[0], image[1], image[2]], -1)
    overlay = cv2.flip(overlay, 0)
    if size is not None:
        overlay = cv2.resize(overlay, dsize=size)
    return overlay


def cqt_to_overlay1(cqt, size=(384,384)):
    """Tile the first three CQT maps side by side into one BGR uint8 image.

    Uses the same clipping and vertical flip as ``cqt_to_overlay``. When
    ``size`` is given, the result is resized to (3 * size[0], size[1]), i.e.
    each tile is size[0] wide.
    """
    image = _cqt_to_uint8(cqt)
    overlay = _to_bgr(np.hstack([image[0], image[1], image[2]]))
    overlay = cv2.flip(overlay, 0)
    if size is not None:
        overlay = cv2.resize(overlay, dsize=(3 * size[0], size[1]))
    return overlay


def cqt_to_overlay2(cqt, size=(384,384)):
    """Stack the first three CQT maps vertically into one BGR uint8 image.

    Uses the same clipping and vertical flip as ``cqt_to_overlay``. When
    ``size`` is given, the result is resized to (size[0], 3 * size[1]), i.e.
    each tile is size[1] tall.
    """
    image = _cqt_to_uint8(cqt)
    overlay = _to_bgr(np.vstack([image[0], image[1], image[2]]))
    overlay = cv2.flip(overlay, 0)
    if size is not None:
        overlay = cv2.resize(overlay, dsize=(size[0], 3 * size[1]))
    return overlay
########################################################################## | |
def run_check_cqt1(plot_each=False, plot_overlaid=True, plot_bins=False):
    """Print and plot the kernels built by ``create_cqt_kernel`` (debug helper).

    Prints reference Q values for three filter scales, then the bin centre
    frequencies and window lengths. It optionally plots the kernels and
    finally exits the process with status 0.

    Args:
        plot_each: plot each kernel's real part on a cleared figure, one at a time.
        plot_overlaid: draw the kernels' real parts on one figure, waiting for
            a key/mouse press between them, then show the result.
        plot_bins: plot window length against centre frequency, with labels.
    """
    import sys

    Q = 5  # float(filter_scale) / (2 ** (1 / num_freq_per_octave) - 1)
    fs = 2048
    fmin = 20
    fmax = 500
    num_freq_per_octave = 27
    kernel_length = 4096

    # Reference Q = filter_scale / (2 ** (1 / num_freq_per_octave) - 1) for
    # filter_scale 1, 0.5 and 0.25. NOTE: the values once recorded here
    # (16.817..., 8.408..., 4.204...) correspond to 12 bins per octave; with
    # num_freq_per_octave = 27 the printed values are larger (~38.46 for 1.0).
    for filter_scale in (1.00, 0.50, 0.25):
        print(float(filter_scale) / (2 ** (1 / num_freq_per_octave) - 1))
    print('')

    cqt_kernel, cqt_window, length, freq = create_cqt_kernel(
        Q,
        fs,
        fmin,
        fmax,
        num_freq_per_octave=num_freq_per_octave,
        kernel_length=kernel_length,
    )
    num_freq, kernel_length = cqt_kernel.shape

    print('Q', Q)
    print('kernel_length', kernel_length)
    print('')
    print('freq', freq.shape)
    print(freq)
    print('')
    print('length', length.shape)
    print(length)
    print('')

    if plot_each:
        for k in range(num_freq):
            plt.clf()
            plt.plot(cqt_kernel.real[k])
            plt.show()

    if plot_overlaid:
        for k in range(num_freq):
            plt.plot(cqt_kernel.real[k])
            plt.waitforbuttonpress()
        plt.show()

    if plot_bins:
        fig = plt.figure()
        ax = fig.add_subplot(111)
        plt.plot(np.zeros(len(freq)), freq, '.')
        plt.plot(length, freq, '.')
        for f, l in zip(freq, length):
            ax.annotate(str(f), xy=(0, f))
        plt.show()

    # sys.exit does not depend on the site module (unlike the builtin exit()).
    sys.exit(0)
def run_check_cqt2(compare_resize=False):
    """Render MyQT magnitude images for a few G2Net training samples (debug helper).

    For each sample id, the function:
    - loads the .npy waveform, scales it by 3e20 per row,
    - band-passes it (25-500 Hz, 4th-order Butterworth, zero-phase),
    - applies a Tukey taper,
    - runs it through ``MyQT``,
    - shows two overlays with OpenCV, waiting for a key press between samples.

    Args:
        compare_resize: also print statistics comparing ``cv2.resize`` with
            torch bilinear interpolation for each output map.
    """
    Q = 16
    kernel_length = 4096
    fs = 2048
    fmin = 20
    fmax = 500
    num_freq_per_octave = 27
    hop_length = 32

    cqt = MyQT(
        Q=Q,
        fs=fs,
        fmin=fmin,
        fmax=fmax,
        num_freq_per_octave=num_freq_per_octave,
        hop_length=hop_length,
        kernel_length=kernel_length,
    )
    print(cqt.text)

    bHP, aHP = signal.butter(4, (25, 500), btype='bandpass', fs=fs)
    # scipy.signal.tukey was removed from the top-level namespace in SciPy 1.13.
    tukey = signal.windows.tukey(4096, 0.2)
    wave_norm = [[3.0e20], [3.0e20], [3.0e20]]

    sample_ids = [
        '05dac2b96f',  # 0
        '013f088058',
        # --
        '07e650dd6f',  # 1
        '01ad950925',
        '0827de7926',
    ]
    data_dir = '/root/share1/kaggle/2021/g2net/data'
    for sample_id in sample_ids:
        path = data_dir + '/train/%s/%s/%s/%s.npy' % (
            sample_id[0], sample_id[1], sample_id[2], sample_id)
        wave = np.load(path)
        wave = wave * wave_norm
        wave = signal.filtfilt(bHP, aHP, wave)
        wave = wave * tukey
        wave = wave.astype(np.float32)

        with torch.no_grad():
            image = cqt(torch.from_numpy(wave).unsqueeze(1))
        image = image.cpu().numpy()
        print('image.shape', image.shape)
        print('image.max', image.max(-1).max(-1))

        if compare_resize:
            # compare opencv and pytorch resize
            for m in image:
                cv_big = cv2.resize(m, dsize=(256, 256), interpolation=cv2.INTER_LINEAR)
                pt_big = F.interpolate(
                    torch.from_numpy(m).unsqueeze(0).unsqueeze(0),
                    size=(256, 256), mode='bilinear', align_corners=False,
                )
                pt_big = pt_big.cpu().numpy()[0, 0]
                diff = np.abs(cv_big - pt_big)
                print(diff.max(), diff.mean(), diff.std())
                print(cv_big.max(), cv_big.mean(), cv_big.std())
                print(pt_big.max(), pt_big.mean(), pt_big.std())
                print('')

        image = image ** 0.5
        overlay_shape = (384, 384)  # or None / (224, 224)
        overlay = cqt_to_overlay(image, overlay_shape)
        overlay1 = cqt_to_overlay1(image, overlay_shape)
        draw_shadow_text(overlay, '%s ' % (sample_id,), (2, 20), 0.6, (255, 255, 255), 2)
        image_show('overlay', overlay, resize=1)
        image_show('overlay1', overlay1, resize=1)
        cv2.waitKey(0)
# main #################################################################
def _main():
    """Script entry point: run the end-to-end MyQT visual check."""
    run_check_cqt2()


if __name__ == '__main__':
    _main()
# (End of the GitHub gist page: sign up for free, or sign in if you already have
# an account, to join this conversation and comment on GitHub.)