Last active: July 3, 2019 23:09
-
-
Save TimboKZ/ef3aec4c570497035b06eae30c8e61f5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import matplotlib.patches as patches | |
from matplotlib import pyplot as plt | |
# Size of the plotted grid: T_count rows of cells by x_count columns.
T_count, x_count = 5, 7

# Each cell holds state_count probabilities, drawn as a
# state_rows x state_cols tile, so the three numbers must agree.
state_rows, state_cols = 2, 3
state_count = 6
assert state_count == state_rows * state_cols
def get_state_distribution(T, x):
    """Return a random probability vector over the `state_count` states.

    `T` and `x` are the grid-cell indices supplied by the caller, but they
    are currently ignored: weights are drawn from
    np.random.uniform(0.0, 1.0) and rescaled so they sum to 1.
    """
    weights = np.random.uniform(0.0, 1.0, size=state_count)
    total = np.sum(weights)
    return weights / total
def _tile_states(matrix, rows, cols):
    """Lay out each matrix[i, j, :] state vector as a rows x cols image tile.

    State k of cell (i, j) lands at pixel (i * rows + k // cols,
    j * cols + k % cols), i.e. each state vector is placed row-major
    inside its tile. `matrix` has shape (n_i, n_j, rows * cols); the
    result has shape (n_i * rows, n_j * cols).
    """
    n_i, n_j, _ = matrix.shape
    # (i, j, k) -> (i, j, r, c) -> (i, r, j, c) -> (i*rows + r, j*cols + c)
    return (matrix.reshape(n_i, n_j, rows, cols)
            .transpose(0, 2, 1, 3)
            .reshape(n_i * rows, n_j * cols))


def _draw_cell_borders(ax, n_i, n_j, rows, cols):
    """Draw thick black lines along the edges of every rows x cols tile.

    imshow puts pixel centres on integer coordinates, so an image of
    H x W pixels spans -0.5 .. W - 0.5 horizontally and -0.5 .. H - 0.5
    vertically; the lines stop exactly at those edges.
    """
    left, right = -0.5, n_j * cols - 0.5
    top, bottom = -0.5, n_i * rows - 0.5
    style = dict(facecolor='none', edgecolor='black', linewidth=3,
                 closed=True, joinstyle='round')
    # Horizontal lines between tile rows (including both outer edges).
    for i in range(n_i + 1):
        y = i * rows - 0.5
        ax.add_patch(patches.Polygon([(left, y), (right, y)], **style))
    # Vertical lines between tile columns (including both outer edges).
    for j in range(n_j + 1):
        x = j * cols - 0.5
        ax.add_patch(patches.Polygon([(x, top), (x, bottom)], **style))


def main():
    """Plot a random state distribution for every (T, x) cell as one image.

    Each cell of the T_count x x_count grid is drawn as a
    state_rows x state_cols tile whose pixels are that cell's state
    probabilities (state k at tile row k // state_cols, column
    k % state_cols). Thick black lines separate the cells, each pixel is
    labelled with its 1-based state number, and the figure is then shown
    with plt.show().
    """
    # matrix[T, x, k] is the probability of state k at cell (T, x).
    matrix = np.zeros((T_count, x_count, state_count))
    for T in range(T_count):
        for x in range(x_count):
            matrix[T, x, :] = get_state_distribution(T, x)

    matrix_2d = _tile_states(matrix, state_rows, state_cols)

    _, ax = plt.subplots()
    ax.set_title('State diagram')
    ax.imshow(matrix_2d)

    # One tick per cell, centred on its tile: the pixel centres of a tile
    # of size n are 0 .. n - 1, so the middle is at (n - 1) / 2.
    ax.set_ylabel('T')
    ax.set_yticks(np.arange(T_count) * state_rows + (state_rows - 1) / 2)
    ax.set_yticklabels(np.arange(T_count))
    ax.set_xlabel('x')
    ax.set_xticks(np.arange(x_count) * state_cols + (state_cols - 1) / 2)
    ax.set_xticklabels(np.arange(x_count))

    _draw_cell_borders(ax, T_count, x_count, state_rows, state_cols)

    # Label every pixel with its 1-based state number; ax.text takes the
    # column (x) coordinate first, then the row (y).
    for i in range(T_count):
        for j in range(x_count):
            for k in range(state_count):
                col = j * state_cols + k % state_cols
                row = i * state_rows + k // state_cols
                ax.text(col, row, k + 1, ha="center", va="center", color="w")
    plt.show()
# Draw the diagram only when run as a script, not when imported.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment