Skip to content

Instantly share code, notes, and snippets.

@BenLangmead
Created November 14, 2013 02:49
Show Gist options
  • Save BenLangmead/7460513 to your computer and use it in GitHub Desktop.
Save BenLangmead/7460513 to your computer and use it in GitHub Desktop.
{
"metadata": {
"name": ""
},
"nbformat": 3,
"nbformat_minor": 0,
"worksheets": [
{
"cells": [
{
"cell_type": "code",
"collapsed": false,
"input": [
"import math\n",
"import numpy\n",
"\n",
"class HMM(object):\n",
" ''' Simple Hidden Markov Model implementation. User provides\n",
" transition, emission and initial probabilities in dictionaries.\n",
" The transition and emission tables map 2-character codes onto\n",
" floating-point probabilities; the initial table maps 1-character\n",
" state labels onto probabilities. States and emissions are\n",
" represented with single characters. Emission symbols come from a\n",
" finite alphabet. '''\n",
" \n",
" def __init__(self, A, E, I):\n",
" ''' Initialize the HMM given transition, emission and initial\n",
" probability tables. '''\n",
" \n",
" # add all the state labels to the set self.Q\n",
" self.Q, self.S = set(), set() # states and symbols\n",
" for a, prob in A.iteritems():\n",
" asrc, adst = a[0], a[1]\n",
" self.Q.add(asrc)\n",
" self.Q.add(adst)\n",
" # add all the symbols to the set self.S\n",
" for e, prob in E.iteritems():\n",
" eq, es = e[0], e[1]\n",
" self.Q.add(eq)\n",
" self.S.add(es)\n",
" \n",
" self.Q = sorted(list(self.Q))\n",
" self.S = sorted(list(self.S))\n",
" \n",
" # create maps from state labels / emission symbols to integers\n",
" # that function as unique IDs\n",
" qmap, smap = {}, {}\n",
" for i in xrange(len(self.Q)): qmap[self.Q[i]] = i\n",
" for i in xrange(len(self.S)): smap[self.S[i]] = i\n",
" lenq = len(self.Q)\n",
" \n",
" # create and populate transition probability matrix\n",
" self.A = numpy.zeros(shape=(lenq, lenq), dtype=float)\n",
" for a, prob in A.iteritems():\n",
" asrc, adst = a[0], a[1]\n",
" self.A[qmap[asrc], qmap[adst]] = prob\n",
" # make A stochastic (i.e. make rows add to 1)\n",
" self.A /= self.A.sum(axis=1)[:, numpy.newaxis]\n",
" \n",
" # create and populate emission probability matrix\n",
" self.E = numpy.zeros(shape=(lenq, len(self.S)), dtype=float)\n",
" for e, prob in E.iteritems():\n",
" eq, es = e[0], e[1]\n",
" self.E[qmap[eq], smap[es]] = prob\n",
" # make E stochastic (i.e. make rows add to 1)\n",
" self.E /= self.E.sum(axis=1)[:, numpy.newaxis]\n",
" \n",
" # initial probabilities\n",
" self.I = [ 0.0 ] * len(self.Q)\n",
" for a, prob in I.iteritems():\n",
" self.I[qmap[a]] = prob\n",
" # make I stochastic (i.e. adds to 1)\n",
" self.I = numpy.divide(self.I, sum(self.I))\n",
" \n",
" self.qmap, self.smap = qmap, smap\n",
" \n",
" # Make log-base-2 versions for log-space functions\n",
" self.Alog = numpy.log2(self.A)\n",
" self.Elog = numpy.log2(self.E)\n",
" self.Ilog = numpy.log2(self.I)\n",
" \n",
" def jointProb(self, p, x):\n",
" ''' Return joint probability of path p and emission string x '''\n",
" p = map(self.qmap.get, p) # turn state characters into ids\n",
" x = map(self.smap.get, x) # turn emission characters into ids\n",
" tot = self.I[p[0]] # start with initial probability\n",
" for i in xrange(1, len(p)):\n",
" tot *= self.A[p[i-1], p[i]] # transition probability\n",
" for i in xrange(0, len(p)):\n",
" tot *= self.E[p[i], x[i]] # emission probability\n",
" return tot\n",
" \n",
" def jointProbL(self, p, x):\n",
" ''' Return log2 of joint probability of path p and emission\n",
" string x. Just like self.jointProb(...) but log2 domain. '''\n",
" p = map(self.qmap.get, p) # turn state characters into ids\n",
" x = map(self.smap.get, x) # turn emission characters into ids\n",
" tot = self.Ilog[p[0]] # start with initial probability\n",
" for i in xrange(1, len(p)):\n",
" tot += self.Alog[p[i-1], p[i]] # transition probability\n",
" for i in xrange(0, len(p)):\n",
" tot += self.Elog[p[i], x[i]] # emission probability\n",
" return tot\n",
" \n",
" def viterbi(self, x):\n",
" ''' Given sequence of emissions, return the most probable path\n",
" along with its probability, as a (probability, path) tuple. '''\n",
" x = map(self.smap.get, x) # turn emission characters into ids\n",
" nrow, ncol = len(self.Q), len(x)\n",
" mat = numpy.zeros(shape=(nrow, ncol), dtype=float) # prob\n",
" matTb = numpy.zeros(shape=(nrow, ncol), dtype=int) # backtrace\n",
" # Fill in first column\n",
" for i in xrange(0, nrow):\n",
" mat[i, 0] = self.E[i, x[0]] * self.I[i]\n",
" # Fill in rest of prob and Tb tables\n",
" for j in xrange(1, ncol):\n",
" for i in xrange(0, nrow):\n",
" ep = self.E[i, x[j]]\n",
" mx, mxi = mat[0, j-1] * self.A[0, i] * ep, 0\n",
" for i2 in xrange(1, nrow):\n",
" pr = mat[i2, j-1] * self.A[i2, i] * ep\n",
" if pr > mx:\n",
" mx, mxi = pr, i2\n",
" mat[i, j], matTb[i, j] = mx, mxi\n",
" # Find final state with maximal probability\n",
" omx, omxi = mat[0, ncol-1], 0\n",
" for i in xrange(1, nrow):\n",
" if mat[i, ncol-1] > omx:\n",
" omx, omxi = mat[i, ncol-1], i\n",
" # Backtrace\n",
" i, p = omxi, [omxi]\n",
" for j in xrange(ncol-1, 0, -1):\n",
" i = matTb[i, j]\n",
" p.append(i)\n",
" p = ''.join(map(lambda x: self.Q[x], p[::-1]))\n",
" return omx, p # Return probability and path\n",
" \n",
" def viterbiL(self, x):\n",
" ''' Given sequence of emissions, return the most probable path\n",
" along with log2 of its probability, as a (log2 probability, path)\n",
" tuple. Just like viterbi(...) but in log2 domain. '''\n",
" x = map(self.smap.get, x) # turn emission characters into ids\n",
" nrow, ncol = len(self.Q), len(x)\n",
" mat = numpy.zeros(shape=(nrow, ncol), dtype=float) # prob\n",
" matTb = numpy.zeros(shape=(nrow, ncol), dtype=int) # backtrace\n",
" # Fill in first column\n",
" for i in xrange(0, nrow):\n",
" mat[i, 0] = self.Elog[i, x[0]] + self.Ilog[i]\n",
" # Fill in rest of log prob and Tb tables\n",
" for j in xrange(1, ncol):\n",
" for i in xrange(0, nrow):\n",
" ep = self.Elog[i, x[j]]\n",
" mx, mxi = mat[0, j-1] + self.Alog[0, i] + ep, 0\n",
" for i2 in xrange(1, nrow):\n",
" pr = mat[i2, j-1] + self.Alog[i2, i] + ep\n",
" if pr > mx:\n",
" mx, mxi = pr, i2\n",
" mat[i, j], matTb[i, j] = mx, mxi\n",
" # Find final state with maximal log probability\n",
" omx, omxi = mat[0, ncol-1], 0\n",
" for i in xrange(1, nrow):\n",
" if mat[i, ncol-1] > omx:\n",
" omx, omxi = mat[i, ncol-1], i\n",
" # Backtrace\n",
" i, p = omxi, [omxi]\n",
" for j in xrange(ncol-1, 0, -1):\n",
" i = matTb[i, j]\n",
" p.append(i)\n",
" p = ''.join(map(lambda x: self.Q[x], p[::-1]))\n",
" return omx, p # Return log probability and path"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 1
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# We experiment with joint probabilities first\n",
"\n",
"hmm = HMM({\"FF\":0.9, \"FL\":0.1, \"LF\":0.1, \"LL\":0.9}, # transition matrix A\n",
" {\"FH\":0.5, \"FT\":0.5, \"LH\":0.75, \"LT\":0.25}, # emission matrix E\n",
" {\"F\":0.5, \"L\":0.5}) # initial probabilities I\n",
"jprob1 = hmm.jointProb(\"FFFLLLFFFFF\", \"THTHHHTHTTH\")\n",
"myprob1 = (0.5 ** 9) * (0.75 ** 3) * (0.9 ** 8) * (0.1 ** 2)\n",
"jprob1, myprob1\n",
"# these should be about equal"
],
"language": "python",
"metadata": {},
"outputs": [
{
"metadata": {},
"output_type": "pyout",
"prompt_number": 2,
"text": [
"(3.5469405120849628e-06, 3.5469405120849624e-06)"
]
}
],
"prompt_number": 2
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# confirming that log version of jointProb works as expected\n",
"jprobL1 = hmm.jointProbL(\"FFFLLLFFFFF\", \"THTHHHTHTTH\")\n",
"math.log(jprob1, 2), jprobL1"
],
"language": "python",
"metadata": {},
"outputs": [
{
"metadata": {},
"output_type": "pyout",
"prompt_number": 3,
"text": [
"(-18.104993435171657, -18.104993435171654)"
]
}
],
"prompt_number": 3
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# Trying another path\n",
"jprob2 = hmm.jointProb(\"FFFFFFFFFFF\", \"THTHHHTHTTH\")\n",
"myprob2 = (0.5 ** 12) * (0.9 ** 10)\n",
"jprob2, myprob2\n",
"# these should be about equal"
],
"language": "python",
"metadata": {},
"outputs": [
{
"metadata": {},
"output_type": "pyout",
"prompt_number": 4,
"text": [
"(8.51265722900391e-05, 8.512657229003909e-05)"
]
}
],
"prompt_number": 4
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# Note that jprob2 is greater than jprob1\n",
"\n",
"# Now we experiment with viterbi decoding\n",
"jprobOpt, path = hmm.viterbi(\"THTHHHTHTTH\")\n",
"path"
],
"language": "python",
"metadata": {},
"outputs": [
{
"metadata": {},
"output_type": "pyout",
"prompt_number": 5,
"text": [
"'FFFFFFFFFFF'"
]
}
],
"prompt_number": 5
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# the maximum likelihood path is the same (all-fair) path as the second\n",
"# one we tried above, so jprobOpt should equal jprob2\n",
"jprobOpt, jprob2"
],
"language": "python",
"metadata": {},
"outputs": [
{
"metadata": {},
"output_type": "pyout",
"prompt_number": 6,
"text": [
"(8.51265722900391e-05, 8.51265722900391e-05)"
]
}
],
"prompt_number": 6
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# confirming that log version of viterbi works as expected\n",
"jprobLOpt, _ = hmm.viterbiL(\"THTHHHTHTTH\")\n",
"math.log(jprobOpt, 2), jprobLOpt"
],
"language": "python",
"metadata": {},
"outputs": [
{
"metadata": {},
"output_type": "pyout",
"prompt_number": 7,
"text": [
"(-13.520030934450498, -13.520030934450496)"
]
}
],
"prompt_number": 7
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# Now let's make a new HMM with the same states but where jumps\n",
"# between fair (F) and loaded (L) are much more probable\n",
"hmm = HMM({\"FF\":0.6, \"FL\":0.4, \"LF\":0.4, \"LL\":0.6}, # transition matrix A\n",
" {\"FH\":0.5, \"FT\":0.5, \"LH\":0.8, \"LT\":0.2}, # emission matrix E\n",
" {\"F\":0.5, \"L\":0.5}) # initial probabilities I\n",
"hmm.viterbi(\"THTHHHTHTTH\")"
],
"language": "python",
"metadata": {},
"outputs": [
{
"metadata": {},
"output_type": "pyout",
"prompt_number": 8,
"text": [
"(2.8665446400000001e-06, 'FFFLLLFFFFL')"
]
}
],
"prompt_number": 8
},
{
"cell_type": "code",
"collapsed": false,
"input": [
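"# Here's an example of underflow. Note that the probability returned\n",
"# is 0.0 and the state string becomes all Fs after a while: once every\n",
"# entry in a column has underflowed to 0.0, all candidates tie, and\n",
"# viterbi breaks ties in favor of the first state in sorted order (F).\n",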
"hmm.viterbi(\"THTHHHTHTTH\" * 100)"
],
"language": "python",
"metadata": {},
"outputs": [
{
"metadata": {},
"output_type": "pyout",
"prompt_number": 9,
"text": [
"(0.0,\n",
" 'FFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF')"
]
}
],
"prompt_number": 9
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# Moving to log2 domain fixes underflow\n",
"hmm.viterbiL(\"THTHHHTHTTH\" * 100)"
],
"language": "python",
"metadata": {},
"outputs": [
{
"metadata": {},
"output_type": "pyout",
"prompt_number": 10,
"text": [
"(-1824.4030071946879,\n",
" 'FFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFFFFFLLLFFFFL')"
]
}
],
"prompt_number": 10
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"cpgHmm = HMM({'IO':0.20, 'OI':0.20, 'II':0.80, 'OO':0.80},\n",
" {'IA':0.10, 'IC':0.40, 'IG':0.40, 'IT':0.10,\n",
" 'OA':0.25, 'OC':0.25, 'OG':0.25, 'OT':0.25},\n",
" {'I' :0.50, 'O' :0.50})\n",
"x = 'ATATATACGCGCGCGCGCGCGATATATATATATA'\n",
"logp, path = cpgHmm.viterbiL(x)\n",
"print x\n",
"print path # finds the CpG island fine"
],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
"ATATATACGCGCGCGCGCGCGATATATATATATA\n",
"OOOOOOOIIIIIIIIIIIIIIOOOOOOOOOOOOO\n"
]
}
],
"prompt_number": 11
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"x = 'ATATCGCGCGCGATATATCGCGCGCGATATATAT'\n",
"logp, path = cpgHmm.viterbiL(x)\n",
"print x\n",
"print path # finds two CpG islands fine"
],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
"ATATCGCGCGCGATATATCGCGCGCGATATATAT\n",
"OOOOIIIIIIIIOOOOOOIIIIIIIIOOOOOOOO\n"
]
}
],
"prompt_number": 12
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"x = 'ATATATACCCCCCCCCCCCCCATATATATATATA'\n",
"logp, path = cpgHmm.viterbiL(x)\n",
"print x\n",
"print path # oops! - this is just a bunch of Cs. What went wrong?"
],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
"ATATATACCCCCCCCCCCCCCATATATATATATA\n",
"OOOOOOOIIIIIIIIIIIIIIOOOOOOOOOOOOO\n"
]
}
],
"prompt_number": 13
},
{
"cell_type": "code",
"collapsed": false,
"input": [],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 13
}
],
"metadata": {}
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment