Skip to content

Instantly share code, notes, and snippets.

@kpsychas
Last active February 17, 2020 08:28
Show Gist options
  • Save kpsychas/d2ffd54748f72227193e9059c537b92b to your computer and use it in GitHub Desktop.
Save kpsychas/d2ffd54748f72227193e9059c537b92b to your computer and use it in GitHub Desktop.
Stationary Distribution of a Continuous-Time Markov Chain in Python
"""
Simple class that finds the stationary distribution of a
Continuous Time Markov Chain with two examples
"""
from collections import defaultdict
import numpy as np
class CTMC:
    """A continuous-time Markov chain described by its transition rates.

    States may be any hashable objects. They are kept in insertion order, so
    the state ordering returned by ``stationary_distribution`` is the same on
    every run (a plain ``set`` of strings is not, because of hash
    randomization).
    """

    def __init__(self):
        # Insertion-ordered collection of states (dict used as an ordered set;
        # the values are unused).
        self._states = {}
        # (from_state, to_state) -> transition rate.
        self._rates = defaultdict(float)

    def add_state(self, state):
        """Add a single new state.

        Raises:
            ValueError: if ``state`` is already part of the chain.
        """
        if state in self._states:
            raise ValueError("State '{}' already exists".format(state))
        self._states[state] = None

    def add_states(self, states):
        """Add every state in ``states``; duplicates within the input are ignored.

        Raises:
            ValueError: if one of the states is already part of the chain.
        """
        # dict.fromkeys de-duplicates like set() but keeps the input order.
        for state in dict.fromkeys(states):
            self.add_state(state)

    def add_transition(self, state1, state2, rate):
        """Set the transition rate from ``state1`` to ``state2``.

        States that are not yet known are added automatically. Calling this
        again for the same pair overwrites the previous rate.

        Raises:
            ValueError: if ``rate`` is negative.
        """
        # Validate before touching any state so a rejected call has no effect.
        if rate < 0:
            raise ValueError("Rate must be non-negative, got {}".format(rate))
        if state1 not in self._states:
            self.add_state(state1)
        if state2 not in self._states:
            self.add_state(state2)
        self._rates[(state1, state2)] = rate

    def stationary_distribution(self):
        """Solve ``p Q = 0`` with ``sum(p) = 1`` for the generator matrix ``Q``.

        Returns:
            A tuple ``(p, states)`` where ``p`` is a 1-D numpy array and
            ``p[i]`` is the stationary probability of ``states[i]``.

        Raises:
            ValueError: if the chain has no states.
            numpy.linalg.LinAlgError: if the linear system is singular, e.g.
                when the chain has no unique stationary distribution.
        """
        states = list(self._states)
        N = len(states)
        if N == 0:
            raise ValueError(
                "Cannot compute the stationary distribution of an empty chain")
        index = {state: i for i, state in enumerate(states)}

        # Off-diagonal entries of the generator are the transition rates.
        # Iterating the stored rates (instead of indexing the defaultdict for
        # every pair) avoids inserting phantom zero-rate entries.
        Q = np.zeros((N, N))
        for (s1, s2), rate in self._rates.items():
            if s1 != s2:  # self-loops are ignored; the diagonal is derived below
                Q[index[s1], index[s2]] = rate
        # Each row of a generator sums to zero (diagonal is still 0 here).
        np.fill_diagonal(Q, -Q.sum(axis=1))

        # The balance equations Q^T p = 0 are linearly dependent; replace the
        # first one with the normalization sum(p) = 1 to pin a unique solution.
        Q[:, 0] = np.ones(N)
        b = np.zeros(N)
        b[0] = 1
        return np.linalg.solve(np.transpose(Q), b), states

    @classmethod
    def from_states(cls, states):
        """Create a chain of type ``cls`` pre-populated with ``states``."""
        chain = cls()
        chain.add_states(states)
        return chain

    @classmethod
    def print_stationary_distribution(cls, p, states):
        """Print a two-column table of probabilities and states.

        ``p`` and ``states`` are expected in the form returned by
        ``stationary_distribution`` (``p[i]`` belongs to ``states[i]``).
        """
        print("Probability | State")
        print("-------------------")
        for i, state in enumerate(states):
            print("{:11.5f} | {}".format(p[i], state))
def problem1():
    """Print the stationary distribution of the two-state chain a <-> b.

    a moves to b at rate 2 and b moves back to a at rate 3.
    """
    two_state = CTMC.from_states(["a", "b"])
    for src, dst, rate in (("a", "b", 2), ("b", "a", 3)):
        two_state.add_transition(src, dst, rate)
    probabilities, ordered_states = two_state.stationary_distribution()
    CTMC.print_stationary_distribution(probabilities, ordered_states)
def problem2(n=4, r1=1, r2=1, r3=2):
    """Print the stationary distribution of a two-dimensional grid chain.

    States are pairs ``(i, j)`` with ``0 <= i, j < n``. From ``(i, j)`` the
    chain moves to ``(i - 1, j)`` at rate ``r1`` (when ``i > 0``), to
    ``(i, j - 1)`` at rate ``r2`` (when ``j > 0``), and to ``(i + 1, j + 1)``
    at rate ``r3`` (when both coordinates are below ``n - 1``).

    The default arguments reproduce the original hard-coded example.
    Results are printed in row-major grid order.
    """
    import itertools

    chain = CTMC()
    grid = list(itertools.product(range(n), range(n)))
    # Register every state up front so a state without transitions still exists.
    for state in grid:
        chain.add_state(state)
    for i, j in grid:
        if i + 1 < n:
            chain.add_transition((i + 1, j), (i, j), r1)
        if j + 1 < n:
            chain.add_transition((i, j + 1), (i, j), r2)
        if i + 1 < n and j + 1 < n:
            chain.add_transition((i, j), (i + 1, j + 1), r3)

    p, states = chain.stationary_distribution()
    # Constant-time lookup instead of a linear states.index() per printed row.
    prob_of = dict(zip(states, p))
    print("Probability | State ")
    print("---------------------")
    for state in grid:
        print("{:11.5f} | {} ".format(prob_of[state], state))
def main():
    """Run the two example problems, first problem1 and then problem2."""
    for demo in (problem1, problem2):
        demo()
if __name__ == '__main__':
    # Run the demos only when executed as a script, not when imported.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment