Skip to content

Instantly share code, notes, and snippets.

@alper111
Created March 23, 2022 08:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save alper111/5ba8f76461ac14ca718c119201831370 to your computer and use it in GitHub Desktop.
Save alper111/5ba8f76461ac14ca718c119201831370 to your computer and use it in GitHub Desktop.
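A PyTorch image-grid saver, similar to torchvision.utils.save_image (which builds its grid with torchvision.utils.make_grid), but with a colored (per-channel RGB) pad_value.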
import torch
from PIL import Image
def save_with_color(x, filename, n_per_row, pad_value=(0., 0., 0.), padding=2):
    """Save a batch of images as a single grid image with colored padding.

    Images are laid out row by row, ``n_per_row`` per row. Gaps of
    ``padding`` pixels are inserted *between* neighbouring images only
    (there is no outer border), and any unused cells in the last row are
    left filled with the padding color.

    Args:
        x: Tensor of shape ``(N, C, H, W)`` with ``C`` equal to 1 or 3.
            Values are multiplied by 255 and cast to ``uint8``, so they are
            expected to lie in ``[0, 1]``; values outside that range are
            clamped. Single-channel images are replicated to three channels.
        filename: Destination passed to ``PIL.Image.Image.save``.
        n_per_row: Number of images placed in each grid row.
        pad_value: Three color components in ``[0, 1]`` (scaled by 255)
            used for the gaps and for empty cells.
        padding: Gap between neighbouring images, in pixels.

    Returns:
        Always ``0`` (kept for backward compatibility with existing callers).

    Raises:
        ValueError: If the batch is empty, ``n_per_row`` is not positive, or
            the channel count is neither 1 nor 3.
    """
    N, C, H, W = x.shape
    if N == 0:
        raise ValueError("x must contain at least one image")
    if n_per_row < 1:
        raise ValueError("n_per_row must be a positive integer")
    if C not in (1, 3):
        raise ValueError("expected 1 or 3 channels, got {}".format(C))

    # Ceiling division: a partially filled last row still needs a full row.
    n_per_col = (N + n_per_row - 1) // n_per_row

    # Clamp before scaling so out-of-range components saturate rather than
    # wrap around when stored as uint8.
    pad_color = torch.tensor(
        [int(min(max(p, 0.0), 1.0) * 255) for p in pad_value],
        dtype=torch.uint8,
    )
    canvas = torch.empty(
        H * n_per_col + (n_per_col - 1) * padding,
        W * n_per_row + (n_per_row - 1) * padding,
        3,
        dtype=torch.uint8,
    )
    canvas[:, :] = pad_color

    # Convert the whole batch once: (N, C, H, W) in [0, 1] -> (N, H, W, C) uint8.
    pixels = (x.detach().clamp(0, 1) * 255).byte().cpu().permute(0, 2, 3, 1)
    if C == 1:
        # Broadcast the single gray channel across the three color channels.
        pixels = pixels.expand(-1, -1, -1, 3)

    for idx in range(N):
        row, col = divmod(idx, n_per_row)
        top = row * (H + padding)
        left = col * (W + padding)
        canvas[top:top + H, left:left + W] = pixels[idx]

    Image.fromarray(canvas.numpy()).save(filename)
    return 0
@alper111
Copy link
Author

It is a quick and dirty implementation, feel free to update and optimize.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment