Skip to content

Instantly share code, notes, and snippets.

@sammosummo
Created September 22, 2015 15:25
Show Gist options
  • Save sammosummo/71bcde28572937380785 to your computer and use it in GitHub Desktop.
Save sammosummo/71bcde28572937380785 to your computer and use it in GitHub Desktop.
Implementation of Kaernbach's (1991) adaptive staircase procedure in Python.
# -*- coding: utf-8 -*-
"""
Created on Tue Dec 10 17:41:11 2013
@author: smathias
"""
import matplotlib.pyplot as plt
import numpy as np
class Kaernbach1991(object):
    """Adaptive staircase using Kaernbach's (1991) weighted up/down method.

    Keywords are used to set parameters at initialisation, but they can be
    changed at any point during the run if necessary (``factor`` is derived
    from ``p`` on every access, so changing ``p`` takes effect immediately).

    The main part is the method ``trial``, which advances the staircase.
    Once the staircase is over, all of the data can be accessed and
    summarised with ``getthreshold`` or graphically with ``makefig``
    (requires matplotlib).
    """

    def __init__(self, dv0=1, p=0.75, reversals=(2, 4), stepsizes=(2.25, 1.5),
                 initialerrfix=True, geometric=True, avgrevsonly=True,
                 cap=False):
        """Set up a new staircase.

        Parameters
        ----------
        dv0 : float
            Starting value of the dependent variable (dv).
        p : float
            Target proportion correct; determines ``factor = p / (1 - p)``.
        reversals : sequence of two ints
            Number of reversals in the first phase, then the number of
            additional reversals in the second (averaging) phase. The
            staircase ends after ``sum(reversals)`` reversals. Copied into
            a new list, so the caller's sequence is never mutated.
        stepsizes : sequence of two floats
            Up-step for phase 0 and phase 1 (a multiplier if ``geometric``,
            an additive amount otherwise). The down-step is derived from the
            up-step using ``factor``.
        initialerrfix : bool
            If True, an error within the first ``factor + 1`` trials adds
            two reversals to the first phase (applied at most once).
        geometric : bool
            Use multiplicative steps and a geometric-mean threshold.
        avgrevsonly : bool
            In phase 1, record the dv only on reversal trials (True) or on
            every phase-1 trial (False).
        cap : float or False
            If truthy, the dv is never allowed to exceed this value.
        """
        self.dv = dv0
        self.dvs = []        # dv presented on every trial, in order
        self.dvs4avg = []    # dvs used for the threshold estimate
        self.p = p
        # Copy so that the initial error fix (which increments
        # reversals[0]) never mutates a shared default or caller's list.
        self.reversals = list(reversals)
        self.stepsizes = stepsizes
        self.initialerrfix = initialerrfix
        self.geometric = geometric
        self.avgrevsonly = avgrevsonly
        self.revn = 0        # reversals counted so far
        self.phase = 0       # 0 = initial phase, 1 = averaging phase
        self.staircaseover = False
        self.firsttrial = True
        self.prevcorr = None
        self.trialn = 0
        self.cap = cap

    @property
    def factor(self):
        """Return p / (1 - p): how many down-steps balance one up-step."""
        return self.p / (1 - self.p)

    @factor.setter
    def factor(self, value):
        # Keep p consistent with the requested factor:
        # factor = p / (1 - p)  <=>  p = factor / (1 + factor).
        self.p = value / (1.0 + value)

    def trial(self, corr):
        """Advance the staircase by one trial.

        Parameters
        ----------
        corr : bool
            Whether the listener got the trial correct. Any truthy/falsy
            value (e.g. 0/1) is accepted and converted to bool.

        Does nothing once the staircase is over.
        """
        if self.staircaseover:
            return
        corr = bool(corr)
        self.trialn += 1
        self.dvs.append(self.dv)

        # A reversal is a change in correctness relative to the previous
        # trial; the very first trial can never be a reversal.
        reversal = not self.firsttrial and corr != self.prevcorr
        if reversal:
            self.revn += 1

        # In the averaging phase, keep the dv presented on this trial
        # (reversal trials only, unless avgrevsonly is switched off).
        if self.phase == 1 and (reversal or not self.avgrevsonly):
            self.dvs4avg.append(self.dv)

        # Initial error fix: an error within the first factor + 1 trials
        # pushes the dv back up towards/above its starting value, so extend
        # the first phase by two reversals. Applied at most once.
        if self.initialerrfix and not corr and self.trialn <= self.factor + 1:
            self.reversals[0] += 2
            self.initialerrfix = False

        # Change the dv using the current phase's step size; the down-step
        # is the up-step scaled by 1/factor (in log units if geometric).
        step = self.stepsizes[self.phase]
        if self.geometric:
            if corr:
                self.dv /= step ** (1.0 / self.factor)
            else:
                self.dv *= step
        else:
            if corr:
                self.dv -= step / float(self.factor)
            else:
                self.dv += step

        if self.cap and self.dv > self.cap:
            self.dv = self.cap

        # Update phase / completion after the step so this trial used the
        # step size of the phase it was presented in.
        if self.revn >= self.reversals[0]:
            self.phase = 1
        if self.revn >= sum(self.reversals):
            self.staircaseover = True
        self.firsttrial = False
        self.prevcorr = corr

    def getthreshold(self):
        """Return the threshold estimate once the staircase is over.

        The estimate is the geometric mean (if ``geometric``) or arithmetic
        mean of ``dvs4avg``. Returns None while the staircase is still
        running, and NaN if no values were collected for averaging.
        """
        if not self.staircaseover:
            return None
        if not self.dvs4avg:
            return float('nan')
        if self.geometric:
            return np.exp(np.mean(np.log(self.dvs4avg)))
        return np.mean(self.dvs4avg)

    def makefig(self, f=None):
        """Plot the staircase, then show it or save it.

        Parameters
        ----------
        f : str or file-like, optional
            If given (truthy), passed to ``matplotlib.pyplot.savefig``;
            otherwise the figure is shown with ``plt.show()``.

        Raises
        ------
        ValueError
            If no trials have been run yet.
        """
        if not self.dvs:
            raise ValueError('no trials to plot')
        x = np.arange(len(self.dvs)) + 1
        y = self.dvs
        if self.geometric:
            plt.semilogy(x, y)
        else:
            plt.plot(x, y)
        plt.xlim(min(x), max(x))
        plt.ylim(min(y), max(y))
        plt.ylabel('Dependent variable')
        plt.xlabel('Trial')
        if self.staircaseover:
            plt.hlines(self.getthreshold(), min(x), max(x), colors='r')
        if f:
            # Figures are saved through pyplot; the staircase object has
            # no savefig method of its own.
            plt.savefig(f)
        else:
            plt.show()
def main():
    """Demo: run a staircase on 50 simulated random responses.

    Each response is 0 or 1 drawn uniformly with ``np.random.randint``.
    Stops early once the staircase is over, prints the threshold (None if
    the 50 trials did not finish the staircase), and shows the plot.
    """
    responses = np.random.randint(0, 2, 50)
    staircase = Kaernbach1991()
    for response in responses:
        staircase.trial(response)
        if staircase.staircaseover:
            break
    # print() call form works on both Python 2 and Python 3; the original
    # print statement is a SyntaxError on Python 3.
    print(staircase.getthreshold())
    staircase.makefig()


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment