Skip to content

Instantly share code, notes, and snippets.

@TimboKZ
Last active July 3, 2019 23:09
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save TimboKZ/ef3aec4c570497035b06eae30c8e61f5 to your computer and use it in GitHub Desktop.
Save TimboKZ/ef3aec4c570497035b06eae30c8e61f5 to your computer and use it in GitHub Desktop.
import numpy as np
import matplotlib.patches as patches
from matplotlib import pyplot as plt
# Size of the plotted grid: T_count rows (T axis) by x_count columns (x axis).
T_count = 5
x_count = 7
# Every (T, x) cell holds a distribution over `state_count` states, drawn as a
# state_rows x state_cols tile of pixels.
state_count = 6
state_rows = 2
state_cols = 3
# main() puts state k at tile row k // state_cols and tile column
# k % state_cols, so the tile must contain exactly one pixel per state.
# Checked explicitly rather than with `assert`, which is removed under -O.
if state_count != state_rows * state_cols:
    raise ValueError(
        'state_count ({}) must equal state_rows * state_cols ({} * {})'.format(
            state_count, state_rows, state_cols))
def get_state_distribution(T, x):
    """Return a random probability vector over the ``state_count`` states.

    ``T`` and ``x`` identify the grid cell being filled, but the current
    implementation does not use them: each call draws fresh uniform weights
    in [0, 1) and normalises them to sum to 1.
    """
    weights = np.random.uniform(low=0.0, high=1.0, size=state_count)
    total = weights.sum()
    return weights / total
def _build_state_matrix():
    """Sample a state distribution for every (T, x) grid cell.

    Returns:
        Array of shape (T_count, x_count, state_count) whose ``[T, x, :]``
        slice is ``get_state_distribution(T, x)``.
    """
    matrix = np.zeros((T_count, x_count, state_count))
    for T in range(T_count):
        for x in range(x_count):
            matrix[T, x, :] = get_state_distribution(T, x)
    return matrix


def _tile_states(matrix):
    """Rearrange a (T, x, state) array into a 2D image of per-cell tiles.

    State k of cell (T, x) lands at row ``T * state_rows + k // state_cols``
    and column ``x * state_cols + k % state_cols``: each cell's states fill a
    state_rows x state_cols tile in row-major order.
    """
    n_t, n_x, _ = matrix.shape
    # Split the state axis into (tile row, tile col), then bring each row axis
    # (T, tile row) and column axis (x, tile col) next to each other.
    tiles = matrix.reshape(n_t, n_x, state_rows, state_cols)
    return tiles.transpose(0, 2, 1, 3).reshape(n_t * state_rows,
                                                n_x * state_cols)


def main():
    """Sample random state distributions and show them as a tiled heatmap.

    Each (T, x) cell becomes a state_rows x state_cols tile of pixels coloured
    by that cell's state probabilities. Tiles are separated by thick black
    borders, and every pixel is labelled with its 1-based state number. The
    figure is displayed with ``plt.show()``.
    """
    matrix_2d = _tile_states(_build_state_matrix())

    # imshow centres pixel (row, col) at data coordinates (col, row), so the
    # image spans [-0.5, n - 0.5] along each axis.
    x_lo = y_lo = -0.5
    x_hi = x_count * state_cols - 0.5
    y_hi = T_count * state_rows - 0.5

    _, ax = plt.subplots()
    ax.set_title('State diagram')
    ax.imshow(matrix_2d)

    # One tick at the centre of each tile row / tile column.
    ax.set_ylabel('T')
    ax.set_yticks(np.arange(T_count) * state_rows + (state_rows - 1) / 2)
    ax.set_yticklabels(np.arange(T_count))
    ax.set_xlabel('x')
    ax.set_xticks(np.arange(x_count) * state_cols + (state_cols - 1) / 2)
    ax.set_xticklabels(np.arange(x_count))

    # Tile borders, clamped to the image extent. Each border is a two-point
    # closed polygon, i.e. a segment traced there and back.
    border_style = dict(facecolor='none', edgecolor='black', linewidth=3,
                        closed=True, joinstyle='round')
    for i in range(T_count + 1):
        y = i * state_rows - 0.5
        ax.add_patch(patches.Polygon([(x_lo, y), (x_hi, y)], **border_style))
    for j in range(x_count + 1):
        x = j * state_cols - 0.5
        ax.add_patch(patches.Polygon([(x, y_lo), (x, y_hi)], **border_style))

    # Label every pixel with its 1-based state number inside its tile.
    for i in range(T_count):
        for j in range(x_count):
            for k in range(state_count):
                tile_row, tile_col = divmod(k, state_cols)
                col = j * state_cols + tile_col
                row = i * state_rows + tile_row
                ax.text(col, row, k + 1, ha="center", va="center", color="w")
    plt.show()
if __name__ == '__main__':
    # Build and display the plot only when run as a script, not on import.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment