Created
November 4, 2022 20:47
-
-
Save aaprasad/8791f09bd6888acad4dcf4a9bf40f3f2 to your computer and use it in GitHub Desktop.
ABR Thresholding
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
Python implementation of the ABR thresholding algorithm from Suthakar and Liberman, Hearing Research (2019).
'''
import numpy as np | |
import scipy as sp | |
import pandas as pd | |
import seaborn as sns | |
from scipy import signal | |
from scipy.optimize import curve_fit | |
from sklearn.metrics import mean_squared_error,r2_score | |
from matplotlib import pyplot as plt | |
def cross_correlate(signal1, signal2):
    '''
    Cross-correlate two ABR waveforms.

    signal1: first ABR waveform, expected to be a 1-D array of length d1.
    signal2: second ABR waveform, expected to be a 1-D array of length d2.
    returns a tuple (correlation, lags): the full cross-correlation from
        `scipy.signal.correlate` and the matching lag indices from
        `scipy.signal.correlation_lags`, each 1-D of length d1 + d2 - 1.
    '''
    n_first = len(signal1)
    n_second = len(signal2)
    correlation = signal.correlate(signal1, signal2)
    lag_axis = signal.correlation_lags(n_first, n_second)
    return correlation, lag_axis
def pairwise_correlation(abr_signals):
    '''
    Cross-correlate each ABR waveform with the waveform at the next level.

    abr_signals: an n x d array (or nested sequence) with one ABR waveform
        per row. Rows are assumed to be sorted by level in increasing order,
        so row i is paired with row i+1.
    returns: a tuple (pairwise_corrs, pairwise_lags) of (n-1) x (2d-1)
        arrays. Row i of pairwise_corrs is the full cross-correlation of rows
        i and i+1 divided by its own maximum value (so its largest value
        becomes 1 when that maximum is positive). Row i of pairwise_lags
        holds the matching lag indices. With fewer than two rows, both arrays
        have zero rows.
    raises ValueError if abr_signals is not two-dimensional.
    '''
    # Work in floating point so the normalisation below cannot hit integer
    # casting errors when the caller passes integer samples.
    abr_signals = np.asarray(abr_signals, dtype=float)
    if abr_signals.ndim != 2:
        raise ValueError('abr_signals must be a 2-D array with one waveform per row.')
    width = max(2 * abr_signals.shape[1] - 1, 0)
    if abr_signals.shape[0] < 2:
        return np.empty((0, width)), np.empty((0, width), dtype=int)
    pairwise_corrs = []
    pairwise_lags = []
    for lower, upper in zip(abr_signals[:-1], abr_signals[1:]):
        abr_corr, lags = cross_correlate(lower, upper)
        # Normalise by the peak value of this pair's cross-correlation.
        pairwise_corrs.append(abr_corr / np.max(abr_corr))
        pairwise_lags.append(lags)
    return np.array(pairwise_corrs), np.array(pairwise_lags)
def sigmoid(x, a, b, c, d):
    '''
    Evaluate the sigmoid $y = a + (b-a)/(1+10^{d(c-x)})$.

    For ABR thresholding, x is the stimulus level (dB) and y is the
    correlation between adjacent-level waveforms.
    x: the input value(s); a scalar or an array.
    a, b: the asymptotes (a as x -> -inf and b as x -> +inf when d > 0).
    c: the midpoint, where y = (a + b) / 2.
    d: the slope parameter.
    The parameters are fitted in `abr_curve_fit`.
    returns y evaluated at x (broadcast against x).
    '''
    # d scales (c - x). The float base lets np.power accept negative
    # exponents even when every input is an integer.
    return a + (b - a) / (1 + np.power(10.0, d * (c - x)))
def inverse_sigmoid(x, a, b, c, d):
    '''
    Invert `sigmoid`: return the input at which the sigmoid outputs x.

    x: the target output value(s) of the sigmoid. Here this is a correlation
        value such as the threshold criterion. It must lie strictly between
        a and b for the result to be finite.
    a, b, c, d: the sigmoid parameters, as fitted in `abr_curve_fit`.
    returns the sigmoid input (here the decibel level) that produces x.
    '''
    span_ratio = (b - a) / (x - a)
    return c - np.log10(span_ratio - 1) / d
def power_law(x, a, b, c):
    '''
    Evaluate the power law $y = a x^b + c$.

    x: the input value(s), here the decibel level(s).
    a, b, c: the power-law parameters, as fitted in `abr_curve_fit`.
    returns y at x, here the modelled correlation.
    '''
    scaled = a * np.power(x, b)
    return scaled + c
def inverse_power_law(x, a, b, c):
    '''
    Invert `power_law`: return the input at which a * input^b + c equals x.

    x: the target output value(s), here a correlation value such as the
        threshold criterion.
    a, b, c: the power-law parameters, as fitted in `abr_curve_fit`.
    returns ((x - c) / a) ** (1 / b), here the decibel level(s). Where
        (x - c) / a is negative and 1 / b is not an integer, the result is
        nan (numpy emits a RuntimeWarning) rather than a complex number.
    '''
    # The exponent must be parenthesised: `** 1/b` parses as `(... ** 1) / b`.
    return np.power((x - c) / a, 1.0 / b)
def rmse(x, y, func, params):
    '''
    Root-mean-squared error of a fitted curve.

    x: x values at which to evaluate the fit, here decibel levels
    y: observed y values, here correlations
    func: the model function, called as `func(x, *params)`
    params: the fitted parameters for `func`
    returns the square root of sklearn's `mean_squared_error` between y
        and the model predictions
    '''
    predicted = func(x, *params)
    return np.sqrt(mean_squared_error(y, predicted))
def r_squared(x, y, func, params):
    '''
    Coefficient of determination (R^2) of a fitted curve.

    x: x values at which to evaluate the fit
    y: observed y values
    func: the model function, called as `func(x, *params)`
    params: the fitted parameters for `func`
    returns sklearn's `r2_score` of the observed y against the predictions
    '''
    return r2_score(y, func(x, *params))
def adjusted_r_squared(x, y, func, params, n_predictors=1):
    '''
    Calculate the adjusted coefficient of determination,
    $1 - (1 - R^2)(n - 1)/(n - k - 1)$.

    x: actual x values
    y: actual y values
    func: the fitted function whose fit quality is measured
    params: the fitted parameters, passed as `func(x, *params)`
    n_predictors: k in the formula above. Defaults to 1, the single
        independent variable x.
    returns the adjusted R^2, or nan when n - k - 1 <= 0 (too few points
        for the adjustment to be defined)
    '''
    n = len(x)
    dof = n - n_predictors - 1
    if dof <= 0:
        return np.nan
    # Use a different local name: assigning to `r_squared` inside this
    # function would make it local and hide the module-level function.
    r2 = r_squared(x, y, func, params)
    return 1 - (1 - r2) * (n - 1) / dof
def abr_curve_fit(levels, abr_corrs, abr_lags):
    '''
    Fit sigmoid and power-law curves to the zero-lag correlations.

    levels: the level (dB) attributed to each row of `abr_corrs`. There must
        be exactly one entry per row.
    abr_corrs: an (n-1) x m array of cross-correlations between consecutive
        ABR waveforms, as returned by `pairwise_correlation`.
    abr_lags: the lags matching `abr_corrs`, either with the same shape or
        as a single row broadcastable to it.
    returns (sigmoid_params, power_params): the optimal parameters found by
        `scipy.optimize.curve_fit` for `sigmoid` (a, b, c, d) and for
        `power_law` (a, b, c).
    raises ValueError if a row of lags does not contain exactly one zero
        lag, or if `levels` does not match the number of correlation rows.
        Errors from curve_fit (e.g. RuntimeError when a fit does not
        converge) propagate.
    '''
    levels = np.asarray(levels, dtype=float)
    abr_corrs = np.asarray(abr_corrs, dtype=float)
    zero_lag = np.broadcast_to(np.asarray(abr_lags), abr_corrs.shape) == 0
    if not np.all(np.sum(zero_lag, axis=-1) == 1):
        raise ValueError('Each row of abr_lags must contain exactly one zero lag.')
    # Boolean indexing keeps row order, giving one correlation per row.
    zero_lag_corr = abr_corrs[zero_lag]
    if levels.shape != zero_lag_corr.shape:
        raise ValueError(
            f'levels has {levels.size} entries but there are {zero_lag_corr.size} '
            'zero-lag correlations; pass one level per row of abr_corrs.'
        )
    # curve_fit's default starting point (all ones) sets a == b, which makes
    # the sigmoid flat. Start from the observed range instead, with the slope
    # inside the (0.005, 0.999) range that the decision tree accepts.
    sigmoid_p0 = [zero_lag_corr.min(), zero_lag_corr.max(), np.median(levels), 0.1]
    sigmoid_params = curve_fit(sigmoid, levels, zero_lag_corr, p0=sigmoid_p0)[0]
    power_params = curve_fit(power_law, levels, zero_lag_corr)[0]
    return sigmoid_params, power_params
def _zero_lag_correlations(abr_corrs, abr_lags):
    '''Return each row of `abr_corrs` at lag 0, as a 1-D array in row order.'''
    abr_corrs = np.asarray(abr_corrs, dtype=float)
    zero_lag = np.broadcast_to(np.asarray(abr_lags), abr_corrs.shape) == 0
    return abr_corrs[zero_lag]


def _power_law_branch(levels, zero_lag_corrs, power_params, criterion):
    '''Shared tail of both tree branches: use the power fit if good, else fall back.'''
    power_r = adjusted_r_squared(levels, zero_lag_corrs, power_law, power_params)
    if power_r > 0.7:
        noisy = False
        print(f'True. Returning Power Threshold. Noisy={noisy}')
        return inverse_power_law(criterion, *power_params), noisy
    print('False. Checking if Max Corr > Criterion')
    # Every remaining outcome comes from a poor fit, so flag it as noisy.
    noisy = True
    if np.max(zero_lag_corrs) > criterion:
        print(f'True. Returning Power Threshold. Noisy = {noisy}.')
        return inverse_power_law(criterion, *power_params), noisy
    print('False. No threshold found.')
    return np.nan, noisy


def decision_tree(levels, abr_signals, criterion=0.35):
    '''
    Go through the decision tree to get the decibel threshold for ABR
    analysis. See figure 4 of the paper for more details.

    levels: decibel levels from the ABR experiment, one per row of
        `abr_signals`, sorted in increasing order. A sequence that already
        holds one level per adjacent pair (one fewer than the rows) is used
        as is.
    abr_signals: an n x d array of ABR waveforms, one row per level.
    criterion: the user-selected correlation criterion.
    returns (threshold, noisy). threshold is the level at which the chosen
        fitted curve reaches `criterion`, or nan if no threshold is found.
        noisy is True when the result comes from the max-correlation
        fallback or no threshold was found.
    '''
    noisy = False  # flag for whether the pipeline had to fall back
    abr_corrs, abr_lags = pairwise_correlation(abr_signals)
    zero_lag_corrs = _zero_lag_correlations(abr_corrs, abr_lags)
    levels = np.asarray(levels, dtype=float)
    if len(levels) == len(zero_lag_corrs) + 1:
        # NOTE(review): correlation i compares level i with level i+1 and is
        # attributed here to the lower level i. Confirm against the paper.
        levels = levels[:-1]
    sigmoid_params, power_params = abr_curve_fit(levels, abr_corrs, abr_lags)
    a, b, c, d = sigmoid_params
    # Begin decision tree.
    print('Checking if a < criterion < b and 0.005 < d < 0.999')
    if a < criterion < b and 0.005 < d < 0.999:
        print('True. Going down left branch...')
        sigmoid_rmse = rmse(levels, zero_lag_corrs, sigmoid, sigmoid_params)
        power_rmse = rmse(levels, zero_lag_corrs, power_law, power_params)
        print('Comparing Sigmoid RMSE to Power RMSE.')
        if sigmoid_rmse < power_rmse:
            print(f'Sigmoid RMSE = {sigmoid_rmse} < Power RMSE = {power_rmse}. Returning Sigmoid Threshold. Noisy={noisy}')
            return inverse_sigmoid(criterion, *sigmoid_params), noisy
        print(f'Sigmoid RMSE = {sigmoid_rmse} > Power RMSE = {power_rmse}. Checking if powerfit adj R^2 > 0.7.')
    else:
        print('False. Checking if powerfit adj R^2 > 0.7.')
    return _power_law_branch(levels, zero_lag_corrs, power_params, criterion)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment