@Gdahuks
Created January 27, 2024 19:59
Calculates the stationary state for the given transition matrix by solving an eigenproblem.
import numpy as np
import scipy.linalg


def calculate_stationary_state(transition_matrix: np.ndarray) -> np.ndarray:
    """
    Calculate the stationary state of a transition matrix.

    The stationary state is a probability vector that the transition matrix leaves unchanged. For an
    irreducible, aperiodic chain it is also the state that the system converges to over time. It is
    calculated by finding a left eigenvector of the transition matrix that corresponds to an eigenvalue
    of 1 and whose entries all share the same sign. If no such eigenvector exists, an exception is raised.

    Parameters
    ----------
    transition_matrix : np.ndarray
        The transition matrix for which to calculate the stationary state. It is expected to be a square
        2D numpy array. Because left eigenvectors are used, each row should sum to 1.

    Returns
    -------
    stationary_state : np.ndarray
        The stationary state of the transition matrix. It is a 1D numpy array whose entries sum to 1.

    Raises
    ------
    ValueError
        If no stationary state is found.
    """
    # Left eigenvectors are returned as columns; for a real eigenvector v, v @ P = eigenvalue * v.
    eigenvalues, eigenvectors = scipy.linalg.eig(transition_matrix, left=True, right=False)
    # Snap numerically negligible entries to exactly zero so the sign test below is not thrown off.
    eigenvectors[np.isclose(eigenvectors, 0)] = 0
    # Transpose so each row is one eigenvector; drop imaginary parts if they are all negligible.
    eigenvectors = np.real_if_close(np.transpose(eigenvectors))
    # Keep eigenvectors for eigenvalue 1 whose entries are all non-negative or all non-positive.
    condition = (
        np.isclose(eigenvalues, 1) & (
            np.all(np.greater_equal(eigenvectors, 0), axis=1) |
            np.all(np.less_equal(eigenvectors, 0), axis=1)
        )
    )
    if not np.any(condition):
        raise ValueError("No stationary state found")
    # Take the first matching eigenvector and normalize it into a probability vector.
    index = np.argmax(condition)
    stationary_state = np.abs(eigenvectors[index])
    stationary_state /= stationary_state.sum()
    return stationary_state
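
One way to sanity-check the result of calculate_stationary_state is to approximate the stationary state independently by power iteration. The sketch below is not part of the gist: the helper name stationary_by_power_iteration, its tolerance parameters, and the two-state matrix P are illustrative assumptions. It also assumes a row-stochastic matrix whose chain is irreducible and aperiodic, which is what makes the iteration converge.

```python
import numpy as np


def stationary_by_power_iteration(transition_matrix, tol=1e-12, max_iter=10_000):
    """Approximate the stationary state by repeatedly applying the matrix.

    Hypothetical cross-check helper; assumes a row-stochastic matrix whose
    chain is irreducible and aperiodic, so the iteration converges.
    """
    n = transition_matrix.shape[0]
    state = np.full(n, 1.0 / n)  # start from the uniform distribution
    for _ in range(max_iter):
        next_state = state @ transition_matrix  # one step of the chain
        if np.abs(next_state - state).max() < tol:
            return next_state
        state = next_state
    raise RuntimeError("Power iteration did not converge")


# Illustrative two-state chain; each row sums to 1.
P = np.array([[0.9, 0.1],
              [0.5, 0.5]])
pi = stationary_by_power_iteration(P)

# A stationary state is unchanged by one more step of the chain.
assert np.allclose(pi @ P, pi)
print(pi)  # approximately [0.8333 0.1667], i.e. (5/6, 1/6)
```

For the same P, calculate_stationary_state(P) should agree with pi up to floating-point tolerance, which can be checked with np.allclose.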