Skip to content

Instantly share code, notes, and snippets.

@Karts27

Karts27/plot.py Secret

Created September 14, 2021 09:45
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Karts27/f6e095f276b7dc0f9dcb3a32c6cbcdae to your computer and use it in GitHub Desktop.
Save Karts27/f6e095f276b7dc0f9dcb3a32c6cbcdae to your computer and use it in GitHub Desktop.
def visualize_images(df, img_size, number_of_images, name, n_cols=5):
    """Plot the first ``number_of_images`` samples of ``df`` as a grid of images.

    Each sample is reshaped to an ``img_size`` x ``img_size`` grayscale image
    and titled (in blue) with ``np.argmax(name[i])``.

    Args:
        df: Array-like of images (e.g. a NumPy array or DataFrame). It is
            converted with ``np.asarray`` and reshaped to
            ``(len(df), img_size, img_size)``, so each row must hold exactly
            ``img_size * img_size`` values.
        img_size: Side length, in pixels, of each square image.
        number_of_images: How many leading samples to draw.
        name: Per-sample label vectors; subplot ``i`` is titled with
            ``np.argmax(name[i])`` (presumably one-hot labels — confirm
            against callers).
        n_cols: Number of subplot columns in the grid (default 5).

    Raises:
        ValueError: If ``df`` cannot be reshaped to square ``img_size`` images.
        IndexError: If ``number_of_images`` exceeds the number of samples.
    """
    data = np.asarray(df)
    images = data.reshape(data.shape[0], img_size, img_size)

    # Ceiling division keeps the row count an integer: add_subplot rejects
    # float grid dimensions such as 2.0 or 2.4.
    n_rows = -(-number_of_images // n_cols)

    fig = plt.figure(figsize=(15, 15))
    for i in range(number_of_images):
        ax = fig.add_subplot(n_rows, n_cols, i + 1, xticks=[], yticks=[])
        ax.set_title(np.argmax(name[i]), color='blue', fontdict={'size': 25})
        ax.imshow(images[i], cmap='gray')
def visualize_input(img, ax, img_size=32):
    """Draw ``img`` on ``ax`` in grayscale and print every pixel value on top.

    Args:
        img: Array holding ``img_size * img_size`` values; it is reshaped to a
            square ``(img_size, img_size)`` image.
        ax: Matplotlib-style axes providing ``imshow`` and ``annotate``.
        img_size: Side length of the square image (default 32).

    Raises:
        ValueError: If ``img`` cannot be reshaped to ``(img_size, img_size)``
            (NumPy's reshape error).
    """
    img = img.reshape(img_size, img_size)
    ax.imshow(img, cmap='gray')
    n_rows, n_cols = img.shape
    # Values below this threshold render dark under the gray colormap, so
    # their labels are drawn in white; brighter pixels get black labels.
    thresh = img.max() / 2.5
    for r in range(n_rows):
        for c in range(n_cols):
            value = img[r, c]
            # annotate's xy is (x, y) in data coordinates, i.e. (column, row).
            ax.annotate(str(round(value, 2)), xy=(c, r),
                        horizontalalignment='center',
                        verticalalignment='center',
                        color='white' if value < thresh else 'black')
# Render the first training sample with each pixel's value printed on it.
fig = plt.figure(
    figsize=(15, 15),
)
ax = fig.add_subplot(1, 1, 1, xticks=[], yticks=[])
visualize_input(
    X_train[0],
    ax,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment