Last active
February 11, 2020 16:10
-
-
Save gngdb/ccbe4dbe31b5b84322b3c540d7204e96 to your computer and use it in GitHub Desktop.
Display images from NumPy arrays and PyTorch tensors in Python 3 (based on https://github.com/google/deepdream/blob/master/dream.ipynb)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from io import BytesIO

import numpy as np
import PIL.Image
import torch
from IPython.display import clear_output, Image, display
def showarray(a, fmt='jpeg'):
    """Display array image data inline in an IPython/Jupyter front end.

    The data is min-max rescaled to 0..255, converted to uint8, encoded with
    PIL in ``fmt`` and passed to IPython's ``display``.

    Args:
        a: array-like image data (anything ``np.asarray`` accepts, including
            lists and boolean arrays). Presumably (H, W) greyscale or
            (H, W, C) channels-last, as required by ``PIL.Image.fromarray``
            -- TODO confirm with callers.
        fmt: image format name passed to ``PIL.Image.Image.save``.
    """
    # Work in float so bool/int inputs can be shifted and scaled safely.
    a = np.asarray(a, dtype=np.float64)
    a = a - a.min()
    peak = a.max()
    # A constant image has peak == 0; dividing would yield NaNs, so leave it
    # at all zeros (black) instead.
    if peak > 0:
        a = 255. * (a / peak)
    a = np.clip(a, 0, 255).astype(np.uint8)
    f = BytesIO()
    PIL.Image.fromarray(a).save(f, fmt)
    display(Image(data=f.getvalue()))
def showtensor(t, fmt='jpeg'):
    """Display a PyTorch tensor as an image inline in IPython/Jupyter.

    The tensor is min-max rescaled to 0..255, converted to uint8, encoded
    with PIL in ``fmt`` and passed to IPython's ``display``.

    Args:
        t: tensor of image data. The tensor is not permuted, so it presumably
            must already be (H, W) or channels-last (H, W, C) for
            ``PIL.Image.fromarray`` -- TODO confirm callers do not pass CHW.
        fmt: image format name passed to ``PIL.Image.Image.save``.
    """
    # Detach from autograd and copy to host memory so ``.numpy()`` also works
    # for tensors that require grad or live on an accelerator; float makes
    # the shift/scale below valid for every input dtype (bool, int, uint8).
    t = t.detach().cpu().to(torch.float)
    t = t - t.min()
    peak = t.max()
    # A constant tensor has peak == 0; dividing would yield NaNs, so leave it
    # at all zeros (black) instead.
    if peak > 0:
        t = 255. * (t / peak)
    t = torch.clamp(t, 0, 255).to(torch.uint8)
    f = BytesIO()
    PIL.Image.fromarray(t.numpy()).save(f, fmt)
    display(Image(data=f.getvalue()))
def square_grid(x):
    """Arrange a stack of 2-D images into one 2-D tensor.

    Args:
        x: tensor of shape ``(n_images, width, height)``; it does not need
            to be contiguous.

    Returns:
        If ``n_images`` is a non-zero perfect square ``d*d``: a tensor of
        shape ``(d*width, d*height)`` laid out as a d x d grid, filled column
        by column (image ``k`` lands in grid column ``k // d``, grid row
        ``k % d``). Otherwise: all images stacked in a single column, shape
        ``(n_images*width, height)``.
    """
    from math import sqrt
    n, w, h = x.size()
    d = round(sqrt(n))
    # Exact integer check rather than a float tolerance. n == 0 uses the
    # column layout because torch.cat cannot concatenate an empty sequence.
    if n > 0 and d * d == n:
        # reshape (not view) so non-contiguous inputs are accepted.
        y = x.reshape(d, d * w, h)
        # y[i] is grid column i: images i*d .. i*d+d-1 stacked vertically.
        y = torch.cat(y.unbind(0), 1)
    else:
        y = x.reshape(n * w, h)
    return y
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment