# Source: GitHub gist by @ugo-nama-kun
# https://gist.github.com/ugo-nama-kun/eb27f66d3ef98d77164e49b6f9b9d600
# (last active November 11, 2022)
import numpy as np
from sklearn import datasets
"""
8x8 tiny NIST image handler
NIST dataset: https://scikit-learn.org/stable/datasets/toy_dataset.html#digits-dataset
If you chose NistHandle(flat=True), you'll get 64-dim vector instead of 8x8 numpy array.
Usage:
handler = NistHandle()
image = handler.get(1) # you'll get 8x8 image of "one".
"""
class NistHandle:
    """Random access to scikit-learn's 8x8 digit images, grouped by label.

    Args:
        flat: If True, store each sample as its row of ``digits.data``;
            otherwise store the 2-D array from ``digits.images``.

    Attributes:
        digit_dict: Maps each digit label (a plain ``int``) to the list of
            its images, in dataset order.
    """

    def __init__(self, flat: bool = False):
        digits = datasets.load_digits()
        # Choose the source array once instead of re-testing `flat` per sample.
        samples = digits.data if flat else digits.images
        self.digit_dict = {}
        for image, target in zip(samples, digits.target):
            # Plain int keys: lookups with Python or numpy integers both work,
            # and the keys print cleanly.
            self.digit_dict.setdefault(int(target), []).append(image)

    def get(self, target: int) -> np.ndarray:
        """Return a uniformly random stored image of digit ``target``.

        Args:
            target: The digit label to sample.

        Returns:
            One of the stored arrays for ``target``. This is the stored
            object itself, not a copy.

        Raises:
            KeyError: If no images are stored for ``target``.
        """
        try:
            digit_list = self.digit_dict[target]
        except KeyError:
            raise KeyError(
                f"no images for digit {target!r}; "
                f"available: {list(self.digit_dict)}"
            ) from None
        image_index = np.random.choice(len(digit_list))
        return digit_list[image_index]
# Examples
# Guarded so that importing this module does not load the dataset, import
# matplotlib, or open plot windows.
if __name__ == "__main__":
    import matplotlib.pyplot as plt

    print("mnist images")
    # The class defined above is NistHandle; the original example called an
    # undefined MnistHandle and crashed with NameError.
    handler = NistHandle()
    print(len(handler.digit_dict[0]))
    plt.imshow(handler.get(0))
    plt.show()
    print(len(handler.digit_dict[1]))
    plt.imshow(handler.get(1))
    plt.show()  # without this the second image is never displayed
    print(handler.get(1).shape)  # 8x8
    print("flat mnist")
    handler = NistHandle(flat=True)
    print(len(handler.digit_dict[0]))
    print(handler.get(1).shape)  # 64
# Comments on this gist: sign in to GitHub to join the conversation.