Created
September 22, 2015 15:25
-
-
Save sammosummo/71bcde28572937380785 to your computer and use it in GitHub Desktop.
Implementation of Kaernbach's (1991) adaptive staircase procedure in Python.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
""" | |
Created on Tue Dec 10 17:41:11 2013 | |
@author: smathias | |
""" | |
import matplotlib.pyplot as plt | |
import numpy as np | |
class Kaernbach1991(object):
    """
    Helper class for tracking an adaptive staircase using the weighted
    transformed up/down method proposed by Kaernbach (1991).

    Keywords are used to set parameters at initialisation, but they can be
    changed at any point during the run if necessary. The main part is the
    method 'trial', which advances the staircase. Once the staircase is
    over, the threshold is available from 'getthreshold' and the run can be
    summarised graphically with 'makefig' (requires matplotlib).
    """

    def __init__(self, dv0=1, p=0.75, reversals=(2, 4), stepsizes=(2.25, 1.5),
                 initialerrfix=True, geometric=True, avgrevsonly=True,
                 cap=False):
        """
        Set up a new staircase.

        dv0 -- starting value of the dependent variable (dv).
        p -- probability used to weight the steps; the step after a correct
            response is scaled by factor = p / (1 - p) relative to the step
            after an incorrect response.
        reversals -- two counts: reversals needed to leave the first phase,
            and further reversals needed before the staircase is over.
        stepsizes -- step size for each phase, indexed by phase (0 or 1).
        initialerrfix -- if true, an incorrect response within the first
            factor + 1 trials adds two reversals to the first phase.
        geometric -- if true, steps are multiplicative and the threshold is
            a geometric mean; otherwise additive steps, arithmetic mean.
        avgrevsonly -- if true, only dvs at reversals in the second phase
            enter the threshold; otherwise every second-phase trial does.
        cap -- upper limit on the dv; a false value disables the limit.
        """
        s = self
        s.dv = dv0
        s.dvs = []
        s.dvs4avg = []
        s.p = p
        s.factor = s.p / (1 - s.p)
        # Private copy: 'trial' modifies reversals[0] in place, which must
        # not leak into the default value or into the caller's sequence.
        s.reversals = list(reversals)
        s.stepsizes = stepsizes
        s.initialerrfix = initialerrfix
        s.geometric = geometric
        s.avgrevsonly = avgrevsonly
        s.revn = 0
        s.phase = 0
        s.staircaseover = False
        s.firsttrial = True
        s.prevcorr = None
        s.trialn = 0
        s.cap = cap

    def trial(self, corr):
        """
        Advance the staircase by one trial. Takes a Boolean which indicates
        whether the listener got the trial correct or incorrect.
        """
        s = self
        # do nothing if the staircase is already over
        if s.staircaseover:
            return

        s.trialn += 1
        s.dvs.append(s.dv)

        # a reversal is a response that differs from the previous one; the
        # first trial has nothing to compare against
        reversal = False
        if not s.firsttrial and corr != s.prevcorr:
            reversal = True
            s.revn += 1

        # record the dv presented on this trial if it counts towards the
        # threshold (second phase only)
        if s.phase == 1 and (reversal or not s.avgrevsonly):
            s.dvs4avg.append(s.dv)

        # initial error fix: if the dv goes above the initial value during
        # the first phase, add more reversals ...
        if s.initialerrfix and not corr:
            if s.trialn <= s.factor + 1:
                s.reversals[0] += 2
            # NOTE(review): nesting reconstructed from a de-indented
            # listing; the fix is disarmed after the first incorrect
            # response -- confirm against the original.
            s.initialerrfix = False

        # change the dv: small step down after a correct response, full
        # step up after an incorrect one
        step = s.stepsizes[s.phase]
        if s.geometric:
            if corr:
                s.dv /= step ** (1 / float(s.factor))
            else:
                s.dv *= step
        else:
            if corr:
                s.dv -= step / float(s.factor)
            else:
                s.dv += step

        # cap dv
        if s.cap and s.dv > s.cap:
            s.dv = s.cap

        # update the object
        if s.revn >= s.reversals[0]:
            s.phase = 1
        if s.revn >= sum(s.reversals):
            s.staircaseover = True
        s.firsttrial = False
        s.prevcorr = corr

    def getthreshold(self):
        """
        Once the staircase is over, get the average (geometric by default) of
        the dvs to calculate the threshold. Returns None while the staircase
        is still running.
        """
        s = self
        if not s.staircaseover:
            return None
        if s.geometric:
            # geometric mean computed as exp(mean(log(x)))
            return np.exp(np.mean(np.log(s.dvs4avg)))
        return np.mean(s.dvs4avg)

    def makefig(self, f=None):
        """
        View or save the staircase. If 'f' is given the figure is saved to
        it, otherwise the figure is shown on screen.
        """
        s = self
        x = np.arange(s.trialn) + 1
        y = s.dvs
        if s.geometric:
            plt.semilogy(x, y)
        else:
            plt.plot(x, y)
        plt.xlim(min(x), max(x))
        plt.ylim(min(y), max(y))
        plt.ylabel('Dependent variable')
        plt.xlabel('Trial')
        if s.staircaseover:
            # mark the estimated threshold with a red horizontal line
            plt.hlines(s.getthreshold(), min(x), max(x), 'r')
        if f:
            # savefig belongs to pyplot, not to this class
            plt.savefig(f)
        else:
            plt.show()
def main():
    """
    Demo: run a staircase on up to 50 random responses, print the threshold
    (None if the staircase did not finish) and show the figure.
    """
    responses = np.random.randint(0, 2, 50)
    staircase = Kaernbach1991()
    for response in responses:
        # 'trial' expects a Boolean; the simulated responses are 0/1 ints
        staircase.trial(bool(response))
        if staircase.staircaseover:
            break
    # single-argument print() behaves the same on Python 2 and Python 3
    print(staircase.getthreshold())
    staircase.makefig()


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment