Last active
November 11, 2022 08:21
-
-
Save ugo-nama-kun/eb27f66d3ef98d77164e49b6f9b9d600 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from sklearn import datasets | |
class NistHandle:
    """8x8 handwritten-digit image handler.

    Wraps scikit-learn's bundled digits dataset (``datasets.load_digits()``),
    described at
    https://scikit-learn.org/stable/datasets/toy_dataset.html#digits-dataset,
    and groups its images by digit label so that a random example of a given
    digit can be drawn.

    If you choose ``NistHandle(flat=True)``, each image is the flattened
    64-element vector taken from ``digits.data`` instead of the 8x8 array
    taken from ``digits.images``.

    Usage::

        handler = NistHandle()
        image = handler.get(1)  # a random 8x8 image of the digit "one"
    """

    def __init__(self, flat=False):
        """Load the digits dataset and group its images by label.

        Args:
            flat: If true, store rows of ``digits.data`` (flattened images)
                instead of the 2-D arrays in ``digits.images``.
        """
        digits = datasets.load_digits()
        source = digits.data if flat else digits.images
        # Maps each label from ``digits.target`` to the list of its images,
        # in dataset order. Kept as a plain dict (not a defaultdict) so that
        # looking up an unknown label raises KeyError rather than silently
        # inserting an empty list.
        self.digit_dict = {}
        for target, image in zip(digits.target, source):
            self.digit_dict.setdefault(target, []).append(image)

    def get(self, target: int, rng=None):
        """Return a randomly chosen stored image of the digit ``target``.

        Args:
            target: Digit label to sample.
            rng: Optional NumPy random source exposing ``choice`` (for
                example ``np.random.default_rng(seed)``) for reproducible
                sampling. When omitted, the global ``np.random`` state is
                used.

        Returns:
            One entry of ``self.digit_dict[target]``.

        Raises:
            KeyError: If ``target`` is not a key of ``self.digit_dict``.
        """
        digit_list = self.digit_dict[target]
        chooser = np.random if rng is None else rng
        image_index = chooser.choice(len(digit_list))
        return digit_list[image_index]
if __name__ == "__main__":
    # Examples. Guarded so that importing this module does not print output,
    # load the dataset, or open plot windows.
    import matplotlib.pyplot as plt

    print("mnist images")
    handler = NistHandle()
    print(len(handler.digit_dict[0]))
    plt.imshow(handler.get(0))
    plt.show()
    print(len(handler.digit_dict[1]))
    plt.imshow(handler.get(1))
    plt.show()  # without this the second image is never displayed
    print(handler.get(1).shape)  # (8, 8)

    print("flat mnist")
    handler = NistHandle(flat=True)
    print(len(handler.digit_dict[0]))
    print(handler.get(1).shape)  # (64,)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment