Skip to content

Instantly share code, notes, and snippets.

@Gdahuks
Created January 27, 2024 19:22
Show Gist options
  • Save Gdahuks/77d977509d62ba2b1257916fe404da22 to your computer and use it in GitHub Desktop.
Save Gdahuks/77d977509d62ba2b1257916fe404da22 to your computer and use it in GitHub Desktop.
Calculates the transition matrix from a given 1D numpy array.
from typing import Tuple

import numpy as np
def get_valid_indices(series: np.ndarray) -> (np.ndarray, np.ndarray):
    """
    Find the positions of consecutive pairs whose two values are both non-NaN.

    Position ``i`` is reported when neither ``series[i]`` nor ``series[i + 1]``
    is NaN, i.e. when the step ``i -> i + 1`` may be counted as a transition.

    Parameters
    ----------
    series : np.ndarray
        The series to scan. It is expected to be a 1D numpy array.

    Returns
    -------
    valid_indices : np.ndarray
        1D int64 array with the starting position ``i`` of every valid pair.
    valid_indices + 1 : np.ndarray
        1D int64 array with the matching successor positions ``i + 1``.
    """
    is_number = np.logical_not(np.isnan(series))
    # A pair is usable only when both of its endpoints are numbers.
    pair_ok = np.logical_and(is_number[:-1], is_number[1:])
    starts = pair_ok.nonzero()[0].astype(np.int64)
    return starts, starts + 1
def remove_zero_rows(
    matrix: np.ndarray,
    states: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Remove all-zero rows, together with their columns and states, until none remain.

    When a transition matrix is built from a series, a state that occurs only
    once and sits at the end of the series or directly before ``np.nan`` has no
    outgoing transition, so its row is zero everywhere. This function drops
    such a row, the column with the same index, and the matching entry of
    ``states``. Dropping a column can leave another row with only zeros (its
    only transitions led to the removed state). The removal is therefore
    repeated until no all-zero row is left.

    Parameters
    ----------
    matrix : np.ndarray
        Square 2D numpy array in which row ``i`` and column ``i`` both refer to
        ``states[i]``.
    states : np.ndarray
        1D numpy array with one entry per row of ``matrix``.

    Returns
    -------
    matrix : np.ndarray
        ``matrix`` restricted to the surviving rows and columns. If no row was
        removed, the input object itself is returned.
    states : np.ndarray
        ``states`` restricted to the surviving entries. If nothing was removed,
        the input object itself is returned.
    """
    # An entry counts as a transition when it is not equal to zero. This is the
    # exact complement of `matrix == 0`, so NaN entries count as non-zero.
    has_transition = matrix != 0
    keep = np.ones(matrix.shape[0], dtype=bool)
    # Number of non-zero entries in each row, restricted to the kept columns.
    out_degree = has_transition.sum(axis=1)

    # Iterative rather than recursive: a chain of states that empty one another
    # in turn would otherwise need one stack frame (and one full matrix copy)
    # per removal round and can exceed Python's recursion limit.
    while True:
        newly_empty = keep & (out_degree == 0)
        if not newly_empty.any():
            break
        keep &= ~newly_empty
        # Transitions into the states just removed no longer count.
        out_degree -= has_transition[:, newly_empty].sum(axis=1)

    if keep.all():
        return matrix, states
    return matrix[np.ix_(keep, keep)], states[keep]
def compute_transition_matrix(
    series: np.ndarray,
    normalize: bool = True
) -> Tuple[np.ndarray, list]:
    """
    Compute a transition matrix from a given series.

    Element ``(i, j)`` of the result counts how often ``states[i]`` is
    immediately followed by ``states[j]`` in ``series``. The candidate states
    are the sorted unique non-NaN values of the series. A pair in which either
    value is NaN is not counted, so NaN acts as a break in the series.

    A state with no outgoing transition produces an all-zero row. This happens
    for a value that occurs only at the end of the series or directly before a
    NaN. Such states are removed together with their columns by
    ``remove_zero_rows``, repeatedly, until no all-zero row is left. Transitions
    *into* a removed state are therefore discarded as well.

    Parameters
    ----------
    series : np.ndarray
        The input series from which to compute the transition matrix.
        It is expected to be a 1D numpy array; NaN marks missing values.
        Should not contain None values.
    normalize : bool, optional
        If True (default), each row is divided by its sum so that it sums to 1
        and the result is a float64 array. If False, the raw counts are
        returned as an ``np.uint64`` array.

    Returns
    -------
    transition_matrix : np.ndarray
        Square 2D numpy array of shape ``(n_states, n_states)``. Here
        ``n_states`` is the number of states that survive the zero-row
        removal, which is at most the number of unique non-NaN values in the
        series.
    states : list
        The surviving states in sorted order; ``states[i]`` labels row ``i``
        and column ``i`` of ``transition_matrix``.
    """
    # np.unique returns sorted values and NaN sorts last, so dropping NaN from
    # the end of `states` leaves the codes of all non-NaN values unchanged.
    states, symbolized_series = np.unique(series, return_inverse=True)
    states = states[~np.isnan(states)]
    n_states = len(states)
    transition_matrix = np.zeros((n_states, n_states), dtype=np.uint64)

    # Only pairs (t, t + 1) with both values non-NaN are counted, so codes that
    # refer to NaN never reach the indexing below.
    valid_indices, next_valid_indices = get_valid_indices(series)
    # np.add.at accumulates repeated (row, col) pairs; fancy-index `+=` would not.
    np.add.at(
        transition_matrix,
        (symbolized_series[valid_indices], symbolized_series[next_valid_indices]),
        1,
    )

    transition_matrix, states = remove_zero_rows(
        transition_matrix, states,
    )

    if normalize:
        # Every remaining row holds at least one positive count, so no row sum
        # is zero here.
        row_sums = transition_matrix.sum(axis=1, keepdims=True)
        transition_matrix = transition_matrix / row_sums
    return transition_matrix, list(states)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment