@tmbdev
Created August 13, 2015
# -*- coding: utf-8 -*-
# <nbformat>3.0</nbformat>
# <codecell>
import openfst
from scipy.ndimage import measurements,filters
from collections import defaultdict
from pylab import *
import unicodedata
class AsciiCodec:
    # Maps printable ASCII (32..127) to integer class labels starting
    # at 1; label 0 is reserved for epsilon.
    def __init__(self):
        d = {}
        for i,c in enumerate(range(32,128)):
            c = unichr(c)
            d[i+1] = c
            d[c] = i+1
        self.d = d
    def nclasses(self):
        return 97
    def encode(self,s):
        # Normalize to ASCII and map characters to labels; anything
        # unmappable falls back to the label for "~".
        s = unicode(s)
        s = unicodedata.normalize('NFKD',s)
        s = s.encode("ascii","replace")
        reject = self.d["~"]
        return array([self.d.get(c,reject) for c in s],'i')
    def decode(self,a):
        return "".join([self.d.get(i,"~") for i in a])
default_codec = AsciiCodec()
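# <markdowncell>
# A quick usage sketch (added for illustration, not part of the original
# gist): the codec maps printable ASCII to class labels starting at 1,
# reserving 0 for epsilon; characters it cannot represent become "~".
# <codecell>
codes = default_codec.encode(u"Hi!")
print codes                        # [41 74  2]
print default_codec.decode(codes)  # Hi!
# <codecell>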
def fstprint(fst,states=None):
    # Print all arcs of an FST, optionally restricted to a set of states.
    for state in fst:
        for x in fst.iterarcs(state):
            if states is None or state in states:
                print state,"->",x.nextstate,(x.ilabel,x.olabel),x.weight.Value()
# <markdowncell>
# We start by making an FST for the transcript. This is of the form
# "#+A+#+B+#+C+#+", where "#" is the "other" label. This transduces
# class labels to positions in the ground truth string. We need to
# make sure class labels start at 1, since 0 is reserved for epsilon.
# <codecell>
def make_transcript_fst(targets):
    # Make an FST for the list of classes "targets".
    # We're assuming that class 0 is the "skip" class and actual
    # classes are numbered starting at 1. Since 0 is epsilon in OpenFST,
    # we add an offset.
    gt = [1]
    for i,c in enumerate(targets):
        gt += [c+1,1]
    transcript_fst = openfst.LogVectorFst()
    states = [transcript_fst.AddState() for i in range(len(gt)+1)]
    for i,c in enumerate(gt):
        # Either stay at the current position or advance to the next one;
        # the output label is the (1-based) position in gt.
        transcript_fst.AddArc(states[i],int(c),i+1,0.0,states[i])
        transcript_fst.AddArc(states[i],int(c),i+1,0.0,states[i+1])
    transcript_fst.SetStart(states[0])
    transcript_fst.SetFinal(states[-1],0.0)
    return transcript_fst,gt
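# <markdowncell>
# As a sanity check (an added sketch): for the two-class transcript [1,2],
# `gt` comes out as [1, 2, 1, 3, 1], i.e. "#A#B#" after the epsilon offset,
# and every position in `gt` gets one stay arc and one advance arc.
# <codecell>
demo_fst,demo_gt = make_transcript_fst([1,2])
print demo_gt       # [1, 2, 1, 3, 1]
fstprint(demo_fst)  # two arcs per state: a self-loop and an advance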
# <markdowncell>
# Next we turn the classifier outputs into a transducer; for demonstration
# purposes, the main program below simply uses random outputs. The
# input labels are times and the output labels are classes. We add 1
# again because 0 means epsilon.
# <codecell>
def make_output_fst(outputs,threshold=100.0):
    # Turn an (nframes x nclasses) cost matrix into a linear transducer.
    # Input labels are times (starting at 1), output labels are classes
    # (offset by 1, since 0 means epsilon); arcs with costs at or above
    # the threshold are pruned.
    n,nc = outputs.shape
    signal_fst = openfst.LogVectorFst()
    states = [signal_fst.AddState() for i in range(n+1)]
    for i in range(n):
        for c in range(0,nc):
            if outputs[i,c]>=threshold: continue
            signal_fst.AddArc(states[i],i+1,int(c)+1,outputs[i,c],states[i+1])
    signal_fst.SetStart(states[0])
    signal_fst.SetFinal(states[-1],0.0)
    return signal_fst
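# <markdowncell>
# Similarly (added sketch), a 3-frame, 2-class cost matrix becomes a chain
# of four states with one arc per (frame, class) pair that survives the
# threshold; the arc weights carry the costs.
# <codecell>
demo_sig = make_output_fst(rand(3,2))
fstprint(demo_sig)  # arcs labeled (t, c+1) with the costs as weights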
# <markdowncell>
# Now we compose the two and run the forward-backward computations.
# The helpers below compute shortest distances in the log semiring,
# map composed states back to time steps, collect the arcs between
# consecutive time steps, and turn them into per-frame posteriors.
# <codecell>
def shortest_distance(comp,reverse=False):
    # A wrapper for the ShortestDistance function that returns
    # a NumPy vector.
    dist = openfst.vector_logweight()
    openfst.ShortestDistance(comp,dist,reverse)
    return array([x.Value() for x in dist])
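# <markdowncell>
# (Added note.) With `reverse=False` this gives the forward ("alpha") costs
# from the start state; with `reverse=True`, the backward ("beta") costs to
# the final state. In the log semiring, `dist[s] + w + rdist[t]` for an arc
# from `s` to `t` with weight `w` is the negative log probability mass of
# all paths through that arc, up to the constant total path cost; this is
# what `compute_transitions` below accumulates.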
# <codecell>
def compute_time(comp):
    # Compute a map from states to times.
    from collections import defaultdict
    time = defaultdict(set)
    for state in comp:
        for x in comp.iterarcs(state):
            time[state].add(x.ilabel)
    for state in time.keys():
        # A state gets a unique time only if all its outgoing arcs
        # consume the same frame; otherwise it gets -1.
        l = list(time[state])
        time[state] = l[0] if len(l)==1 else -1
    # The final state has no outgoing arcs; give it time -1 as well.
    time[1+max(time.keys())] = -1
    return time
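# <markdowncell>
# (Added note.) In the composition of the linear signal FST with the
# transcript FST, all arcs leaving a given state consume the same frame,
# so each state can be assigned a unique time; states with conflicting
# or no outgoing frame labels get -1.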
# <codecell>
def compute_transitions(comp,time,gt,dist,rdist):
    # Compute a table indexed by pairs of times and containing
    # a list of (label,cost) entries for the arcs between those times.
    transitions = defaultdict(list)
    for state in comp:
        for x in comp.iterarcs(state):
            t0 = time[state]
            t1 = time[x.nextstate]
            # print (state,x.nextstate),(t0,t1)
            label = gt[x.olabel-1]
            lcost = dist[state]         # forward (alpha) cost to the source
            tcost = x.weight.Value()    # arc cost
            rcost = rdist[x.nextstate]  # backward (beta) cost from the target
            cost = lcost+tcost+rcost
            transitions[(t0,t1)].append((label,cost))
    return transitions
# <codecell>
def arc_posteriors(ts,nc=None):
    # Given a list of (label,cost) arcs with negative log probability
    # costs, compute a posterior distribution over labels.
    c = array([x[0] for x in ts],'i')
    if nc is None: nc = amax(c)+1
    l = array([x[1] for x in ts])
    l -= amin(l)
    l = -l-log(sum(exp(-l)))
    return measurements.sum(exp(l),c,range(nc))
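# <markdowncell>
# A worked example (added sketch): two arcs with label 1 and costs 1.0 and
# 2.0, plus one arc with label 2 and cost 1.0. The posteriors are the
# softmax of the negated costs, summed per label.
# <codecell>
print arc_posteriors([(1,1.0),(1,2.0),(2,1.0)],nc=3)
# roughly [ 0.    0.58  0.42]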
# <codecell>
def ctc_align(outputs,transcript,threshold=100.0,verbose=0):
    # Perform CTC-style alignment between a 2D array
    # representing classifier outputs and a corresponding
    # vector of transcript labels. This replaces each
    # element x in the transcript with a pattern _+x+_+
    # and then performs forward-backward computations.
    # It outputs an array of per-frame posteriors with the same
    # number of classes as the classifier outputs; the final
    # frame is dropped, so the result has n-1 rows.
    n,nc = outputs.shape
    signal_fst = make_output_fst(outputs,threshold=threshold)
    assert openfst.Verify(signal_fst)
    transcript = array(transcript,'i')
    transcript_fst,gt = make_transcript_fst(transcript)
    assert openfst.Verify(transcript_fst)
    comp = openfst.LogVectorFst()
    openfst.ArcSortOutput(signal_fst)
    openfst.ArcSortInput(transcript_fst)
    if verbose: print "compose"
    openfst.Compose(signal_fst,transcript_fst,comp)
    openfst.Connect(comp)
    assert openfst.Verify(comp)
    if verbose: print "sd1"
    dist = shortest_distance(comp)
    if verbose: print "sd2"
    rdist = shortest_distance(comp,True)
    if verbose: print "compute time"
    time = compute_time(comp)
    if verbose: print "transitions"
    transitions = compute_transitions(comp,time,gt,dist,rdist)
    result = []
    if verbose: print "posteriors"
    for i in range(1,n):
        ps = arc_posteriors(transitions[(i,i+1)],nc+1)
        result.append(ps)
    if verbose: print "done"
    result = array(result)
    # Drop the epsilon-offset column so columns line up with classes.
    return result[:,1:]
# <codecell>
if __name__=="__main__":
    transcript = arange(50,dtype='i')+1
    outputs = filters.gaussian_filter(rand(500,51),1.0)
    ctc = ctc_align(outputs,transcript)
    figsize(15,4)
    print outputs.shape,ctc.shape
    subplot(211); imshow(outputs.T)
    subplot(212); imshow(ctc.T)
    figsize(8,8)
    for i in range(ctc.shape[1]):
        plot(ctc[:,i])