Created
June 27, 2025 08:57
-
-
Save ArthurDelannoyazerty/76a16a11bedaaf38575f4aa2e4660820 to your computer and use it in GitHub Desktop.
Takes a NumPy array or PyTorch tensor and formats it so it can be displayed directly with matplotlib.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math

import numpy as np
import torch
def prepare_image_for_display( | |
data: torch.Tensor|np.ndarray, | |
channels_selected: list[int] = [0,1,2], | |
permute_channels: bool = False, | |
normalize: bool = True, | |
uint8: bool = False | |
) -> np.ndarray: | |
""" | |
Prepare a tensor or numpy array for image display. | |
Args: | |
data (torch.Tensor | np.ndarray): Input data | |
channels_selected (list[int]): List of channel indices to select (default: first 3 channels) | |
permute_channels (bool): Whether to permute from (C,H,W) to (H,W,C) format | |
normalize (bool): Whether to normalize values to [0,1] range --> For 3 channels img, if not used matplotlib will clip the outside values | |
uint8 (bool): Wether to return the data as a uint8 normalized image | |
Returns: | |
np.ndarray: 2D grayscale or 3D RGB image array ready for display | |
Raises: | |
ValueError: For unsupported input shapes or invalid parameters | |
""" | |
# --------------------------------- To Numpy --------------------------------- # | |
if isinstance(data, torch.Tensor): | |
img_array = data.detach().cpu().numpy() | |
else: | |
img_array = np.array(data) | |
# ----------------------------- Handle Dimension ----------------------------- # | |
# Reduce dimensions until we have a 3D or 2D array | |
while len(img_array.shape) > 3: | |
img_array = img_array[0] | |
if len(img_array.shape)==1: # If we have a 1D array | |
size_2d = int(math.sqrt(img_array.shape[0])) | |
if size_2d * size_2d != img_array.shape[0]: | |
raise ValueError(f"Array of shape {img_array.shape} cannot be converted to a 2D array of shape ({size_2d},{size_2d})") | |
img_array = img_array.reshape((size_2d,size_2d)) # reshape into a 2D square array | |
if len(img_array.shape)==2: # If we have a 2D array we can display it | |
pass | |
if len(img_array.shape)==3: | |
# Check channels | |
if img_array.shape[0]>3: | |
if max(channels_selected)>img_array.shape[0] or len(channels_selected)!=3: # If wrong channels selected, defaulting to the first 3 channels | |
img_array = img_array[[0,1,2]] | |
else: | |
img_array = img_array[channels_selected] | |
if img_array.shape[0]==2: #If 2 channels only, we add an empty 3rd channel | |
zeros = np.zeros((1, img_array.shape[1], img_array.shape[2]), dtype=img_array.dtype) | |
img_array = np.concatenate((img_array, zeros), axis=0) | |
if img_array.shape[0]==1: #If 1 channel we make the image 2D | |
img_array = np.squeeze(img_array) | |
if normalize or uint8: | |
min_val = np.min(img_array) | |
max_val = np.max(img_array) | |
if max_val > min_val: | |
img_array = (img_array - min_val) / (max_val - min_val) | |
if uint8: | |
img_array = (img_array * 255).astype(np.uint8) | |
if permute_channels and len(img_array.shape) == 3: | |
# Permute from (C, H, W) to (H, W, C) | |
img_array = img_array.transpose((1, 2, 0)) | |
return img_array |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Some test code: