Skip to content

Instantly share code, notes, and snippets.

@aaprasad
Created November 4, 2022 20:47
Show Gist options
  • Save aaprasad/8791f09bd6888acad4dcf4a9bf40f3f2 to your computer and use it in GitHub Desktop.
Save aaprasad/8791f09bd6888acad4dcf4a9bf40f3f2 to your computer and use it in GitHub Desktop.
ABR Thresholding
'''
Python implementation of ABR thresholding algorithm from Suthakar and Liberman, Hearing Research, (2019).
'''
import numpy as np
import scipy as sp
import pandas as pd
import seaborn as sns
from scipy import signal
from scipy.optimize import curve_fit
from sklearn.metrics import mean_squared_error,r2_score
from matplotlib import pyplot as plt
def cross_correlate(signal1, signal2):
    '''
    Calculate the full cross-correlation between two ABR signals.

    signal1: first ABR signal. 1-D array of length d1; a (1, d1) row vector is also
        accepted and is flattened first.
    signal2: second ABR signal. 1-D array of length d2; a (1, d2) row vector is also
        accepted and is flattened first.
    returns: tuple (correlation, lags) of 1-D arrays, each of length d1 + d2 - 1.
        lags[k] is the lag (in samples) at which correlation[k] was computed, as given
        by `scipy.signal.correlation_lags` for the default 'full' mode.
    '''
    # Flatten so (1, d) row vectors behave like 1-D signals; otherwise len() would
    # report 1 and the lag array would not line up with the correlation output.
    signal1 = np.ravel(signal1)
    signal2 = np.ravel(signal2)
    correlation = signal.correlate(signal1, signal2)
    lags = signal.correlation_lags(len(signal1), len(signal2))
    return correlation, lags
def pairwise_correlation(abr_signals):
    '''
    Cross-correlate every pair of consecutive ABR signals.

    abr_signals: n x d array with one ABR signal per level (n levels). Assumes the rows
        are already sorted by level in increasing order.
    returns: tuple (pairwise_corrs, pairwise_lags) of arrays with shape (n - 1, 2d - 1).
        Row i describes the pair (abr_signals[i], abr_signals[i + 1]); each correlation
        row is divided by that row's maximum value.
    '''
    abr_signals = np.asarray(abr_signals)
    pairwise_corrs = []
    pairwise_lags = []
    for lower, upper in zip(abr_signals[:-1], abr_signals[1:]):
        abr_corr, lags = cross_correlate(lower, upper)
        # Out-of-place division: `/=` raises for integer-valued signals because the
        # float quotient cannot be cast back into an integer array.
        pairwise_corrs.append(abr_corr / np.max(abr_corr))
        pairwise_lags.append(lags)
    return np.array(pairwise_corrs), np.array(pairwise_lags)
def sigmoid(x,a,b,c,d):
    '''
    Calculates the sigmoid function $y = a + (b-a)/(1+10^{d(c-x)})$. In our case it models
    the correlation as a function of decibel level.

    x: the x values for the sigmoid (scalar or array-like); for our use, the decibel levels.
    a, b, c, d: the parameters of the function, which are fit in `abr_curve_fit`.
        a and b are the two asymptotes, c is the midpoint, and d is the slope.
    returns y as a function of x; in our case y is the correlation.
    '''
    x = np.asarray(x, dtype=float)
    # `d * (c - x)`: the original `d(c - x)` tried to call `d` as a function.
    # A float base also avoids NumPy's error for integer bases with negative integer powers.
    return a + (b - a) / (1 + np.power(10.0, d * (c - x)))
def inverse_sigmoid(x,a,b,c,d):
    '''
    Calculates the inverse of the `sigmoid` function, i.e. solves
    y = a + (b-a)/(1+10^{d(c-level)}) for the level.

    x: the y value(s) to invert (scalar or array-like); for our use, a correlation value
        such as the threshold criterion (not a decibel level). For a finite result it
        must lie strictly between a and b.
    a, b, c, d: the fitted `sigmoid` parameters (see `abr_curve_fit`).
    returns c - log10((b-a)/(x-a) - 1)/d; in our case, the decibel level at which the
        sigmoid reaches x. Values of x outside (a, b) give a non-finite result (NaN or inf).
    '''
    # NumPy float arithmetic yields inf/NaN (with a warning) instead of raising
    # ZeroDivisionError when a scalar x equals a.
    x = np.asarray(x, dtype=float)
    return c - np.log10((b - a) / (x - a) - 1) / d
def power_law(x,a,b,c):
    '''
    Evaluate the power law y = a * x^b + c.

    x: input value(s); for this pipeline, the decibel levels.
    a, b, c: the three coefficients of the curve, fitted in `abr_curve_fit`.
    returns y as a function of x; for this pipeline, the predicted correlation.
    '''
    scaled = a * np.power(x, b)
    return scaled + c
def inverse_power_law(x,a,b,c):
    '''
    Calculates the inverse of the `power_law` function: solves y = a * level^b + c for level.

    x: the y value(s) to invert; for our use, a correlation value such as the criterion.
    a, b, c: the fitted `power_law` parameters (see `abr_curve_fit`).
    returns ((x - c) / a)^(1 / b); in our case, the decibel level. NaN (with a NumPy
        warning) when (x - c) / a is negative and 1 / b is not an integer.
    '''
    base = (np.asarray(x, dtype=float) - c) / a
    # The exponent must be 1/b as a whole: `** 1/b` parsed as `(base ** 1) / b`.
    return np.power(base, 1.0 / b)
def rmse(x,y,func,params):
    '''
    Root-mean-squared error of a fitted curve.

    x: x values; in our case, decibel levels.
    y: observed y values; in our case, correlations.
    func: the model function, evaluated as func(x, *params).
    params: the fitted parameters to pass to `func`.
    returns the square root of sklearn's mean_squared_error between y and func(x, *params).
    '''
    return np.sqrt(mean_squared_error(y, func(x, *params)))
def r_squared(x,y,func,params):
    '''
    Coefficient of determination (R^2) of a fitted curve.

    x: observed x values.
    y: observed y values.
    func: the fitted model function, evaluated as func(x, *params).
    params: the fitted parameters to pass to `func`.
    returns sklearn's r2_score of y against func(x, *params).
    '''
    predicted = func(x, *params)
    score = r2_score(y, predicted)
    return score
def adjusted_r_squared(x,y,func,params,n_predictors=1):
    '''
    Calculates the adjusted R^2 value 1-[(1-R^2)(n-1)/(n-k-1)].

    x: actual x values (n of them).
    y: actual y values.
    func: fitted function, evaluated as func(x, *params).
    params: the fitted parameters of `func`, passed as one sequence (not unpacked).
    n_predictors: k in the formula above. Defaults to 1, which reproduces the previous
        hard-coded denominator n - 2.
    returns the adjusted R^2 value, or NaN when n - k - 1 <= 0 (statistic undefined).
    '''
    n = len(x)
    dof = n - n_predictors - 1
    if dof <= 0:
        # Too few points for the requested number of predictors: the denominator
        # would be zero (ZeroDivisionError) or negative (meaningless result).
        return np.nan
    # A distinct local name is required: assigning to `r_squared` inside this function
    # would make it local everywhere here, so the call would raise UnboundLocalError.
    r2 = r_squared(x, y, func, params)
    return 1 - (1 - r2) * (n - 1) / dof
def abr_curve_fit(levels,abr_corrs,abr_lags=None,sigmoid_p0=None,power_p0=None):
    '''
    Fits the zero-lag pairwise correlations to both the sigmoid and the power law.

    levels: decibel levels. Either one level per correlation row (length n - 1), or one
        level per ABR signal (length n, sorted increasing), in which case correlation row i
        (pair i, i + 1) is placed at levels[i], the lower level of its pair.
    abr_corrs: (n - 1) x m array of cross-correlations between consecutive ABR signals,
        as returned by `pairwise_correlation`.
    abr_lags: lags matching `abr_corrs`, either the same shape or a single 1-D lag row
        shared by all rows. If None, the zero lag is taken to be the centre column, which
        is where 'full'-mode correlation of two equal-length signals places it.
    sigmoid_p0, power_p0: optional initial parameter guesses passed to `curve_fit` as `p0`.
    returns: (sigmoid_params, power_params), the optimal parameters from `curve_fit`.
    raises: ValueError if a correlation row does not have exactly one zero lag, or if
        `levels` cannot be aligned with the correlations. `curve_fit` raises RuntimeError
        if a fit does not converge.
    '''
    abr_corrs = np.asarray(abr_corrs, dtype=float)
    if abr_lags is None:
        zero_lag_corr = abr_corrs[:, (abr_corrs.shape[1] - 1) // 2]
    else:
        zero_mask = np.broadcast_to(np.asarray(abr_lags) == 0, abr_corrs.shape)
        if not np.all(zero_mask.sum(axis=1) == 1):
            raise ValueError('Each correlation row must have exactly one zero lag.')
        zero_lag_corr = abr_corrs[zero_mask]  # one value per row, in row order

    levels = np.asarray(levels, dtype=float)
    n_corr = zero_lag_corr.shape[0]
    if levels.shape[0] == n_corr + 1:
        # NOTE(review): using the lower level of each adjacent pair as its x value is an
        # assumption -- confirm against Suthakar & Liberman (2019).
        levels = levels[:-1]
    elif levels.shape[0] != n_corr:
        raise ValueError(
            f'Got {levels.shape[0]} levels for {n_corr} correlations; '
            'expected one level per correlation or one per signal.')

    if sigmoid_p0 is None:
        # curve_fit's default p0 (all ones) sets a == b, which flattens the sigmoid and
        # zeroes its gradient with respect to c and d; start from the data instead.
        sigmoid_p0 = [zero_lag_corr.min(), zero_lag_corr.max(), np.median(levels), 0.1]

    sigmoid_params = curve_fit(sigmoid, levels, zero_lag_corr, p0=sigmoid_p0)[0]
    power_params = curve_fit(power_law, levels, zero_lag_corr, p0=power_p0)[0]
    return sigmoid_params, power_params
def _power_law_fallback(fit_levels, zero_lag_corr, power_params, criterion):
    '''
    Shared tail of the decision tree, used once the sigmoid fit has been rejected.

    Returns (power-law threshold, False) if the power fit's adjusted R^2 exceeds 0.7.
    Otherwise returns (power-law threshold, True) if the largest zero-lag correlation
    still exceeds `criterion`, else (NaN, True).
    '''
    power_r = adjusted_r_squared(fit_levels, zero_lag_corr, power_law, power_params)
    if power_r > 0.7:
        noisy = False
        print(f'True. Returning Power Threshold. Noisy={noisy}')
        return inverse_power_law(criterion, *power_params), noisy
    print('False. Checking if Max Corr > Criterion')
    noisy = True  # set before printing so the message reports the returned flag
    if np.max(zero_lag_corr) > criterion:
        print(f'True. Returning Power Threshold. Noisy = {noisy}.')
        return inverse_power_law(criterion, *power_params), noisy
    print('False. No threshold found.')
    return np.nan, noisy


def decision_tree(levels,abr_signals,criterion=0.35):
    '''
    Goes through the decision tree to get the decibel threshold for ABR analysis.
    See figure 4 of the paper for more details.

    levels: decibel levels from the ABR experiment, sorted increasing. Either one per
        signal (n) or one per adjacent-pair correlation (n - 1).
    abr_signals: n x d array of ABR signals, one row per level.
    criterion: user-selected correlation criterion.
    returns: (threshold, noisy). threshold is the decibel level at which the chosen fit
        reaches `criterion`, or NaN if no threshold was found; noisy is True when the
        threshold comes from a poor power-law fit or none was found.
    raises: ValueError if `levels` cannot be aligned with the pairwise correlations.
    '''
    abr_corrs, abr_lags = pairwise_correlation(abr_signals)  # (n-1) x (2d-1) each
    zero_lag_corr = abr_corrs[abr_lags == 0]  # one zero-lag value per adjacent pair

    levels = np.asarray(levels, dtype=float)
    if levels.shape[0] == zero_lag_corr.shape[0] + 1:
        # NOTE(review): placing each pair's correlation at the lower of its two levels is
        # an assumption -- confirm against Suthakar & Liberman (2019).
        fit_levels = levels[:-1]
    elif levels.shape[0] == zero_lag_corr.shape[0]:
        fit_levels = levels
    else:
        raise ValueError(
            f'Got {levels.shape[0]} levels for {abr_corrs.shape[0] + 1} signals.')

    sigmoid_params, power_params = abr_curve_fit(fit_levels, abr_corrs, abr_lags)
    a, b, c, d = sigmoid_params

    # Begin decision tree.
    print('Checking if a < criterion < b and 0.005 < d < 0.999')
    if not (a < criterion < b and 0.005 < d < 0.999):
        print('False. Checking if powerfit adj R^2 > 0.7.')
        return _power_law_fallback(fit_levels, zero_lag_corr, power_params, criterion)

    print('True. Going down left branch...')
    sigmoid_rmse = rmse(fit_levels, zero_lag_corr, sigmoid, sigmoid_params)
    power_rmse = rmse(fit_levels, zero_lag_corr, power_law, power_params)
    print('Comparing Sigmoid RMSE to Power RMSE.')
    if sigmoid_rmse < power_rmse:
        noisy = False
        print(f'Sigmoid RMSE = {sigmoid_rmse} < Power RMSE = {power_rmse}. Returning Sigmoid Threshold. Noisy={noisy}')
        return inverse_sigmoid(criterion, *sigmoid_params), noisy

    print(f'Sigmoid RMSE = {sigmoid_rmse} > Power RMSE = {power_rmse}. Checking if powerfit adj R^2 > 0.7.')
    return _power_law_fallback(fit_levels, zero_lag_corr, power_params, criterion)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment