@petrosDemetrakopoulos
Created December 17, 2022 15:46
Predicting the frames
import numpy as np
import matplotlib.pyplot as plt

# Assumes X_val, y_val (validation inputs/targets) and a trained `model`
# already exist from the earlier training steps.

# Pick a random index from the validation dataset.
random_index = np.random.choice(range(len(X_val)), size=1)
test_serie_X = X_val[random_index[0]]
test_serie_Y = y_val[random_index[0]]
first_frames = test_serie_X
original_frames = test_serie_Y

# Predict the next 18 frames: add a batch axis for model.predict,
# then drop it again from the output.
new_prediction = model.predict(np.expand_dims(first_frames, axis=0))
new_prediction = np.squeeze(new_prediction, axis=0)

fig, axes = plt.subplots(2, 18, figsize=(20, 4))

# Plot the ground truth frames (top row).
for idx, ax in enumerate(axes[0]):
    ax.imshow(np.squeeze(original_frames[idx]), cmap="viridis")
    ax.set_title(f"Frame {idx + 18}")
    ax.axis("off")

# Plot the predicted frames (bottom row), each reshaped to a 344x315 image.
for idx, ax in enumerate(axes[1]):
    ax.imshow(new_prediction[idx].reshape((344, 315)), cmap="viridis")
    ax.set_title(f"Frame {idx + 18}")
    ax.axis("off")

plt.show()
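The snippet relies on adding a batch axis before `model.predict` and removing it afterwards. The sketch below shows that round trip in isolation. It assumes sequences shaped `(frames, height, width, channels)` with 18 single-channel 344x315 frames. That layout is inferred from the plotting code, not stated in the gist. `fake_model_predict` is a hypothetical stand-in for a trained model, and the random `X_val` is placeholder data.

```python
import numpy as np

# Assumed layout (inferred from the plotting code): 18 frames of 344x315, 1 channel.
N_FRAMES, HEIGHT, WIDTH, CHANNELS = 18, 344, 315, 1


def fake_model_predict(batch: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for model.predict: returns a copy of the input batch.

    A real frame-prediction model would return predicted frames with the
    same (batch, frames, height, width, channels) shape.
    """
    return batch.copy()


def predict_sequence(sequence: np.ndarray, predict_fn=fake_model_predict) -> np.ndarray:
    """Run one (frames, H, W, C) sequence through a batch-oriented predictor."""
    batch = np.expand_dims(sequence, axis=0)   # (1, frames, H, W, C)
    prediction = predict_fn(batch)             # (1, frames, H, W, C)
    return np.squeeze(prediction, axis=0)      # back to (frames, H, W, C)


rng = np.random.default_rng(0)
# Placeholder validation set of 3 random sequences (not real data).
X_val = rng.random((3, N_FRAMES, HEIGHT, WIDTH, CHANNELS), dtype=np.float32)

idx = rng.integers(len(X_val))  # one random validation index
pred = predict_sequence(X_val[idx])
print(pred.shape)                  # (18, 344, 315, 1)
print(np.squeeze(pred[0]).shape)   # (344, 315), the 2-D array imshow receives
```

For a single-channel frame, `reshape((344, 315))` and `np.squeeze` produce the same 2-D array. `np.squeeze` has the advantage of not hard-coding the frame size.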