Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save ArthurDelannoyazerty/76a16a11bedaaf38575f4aa2e4660820 to your computer and use it in GitHub Desktop.
Take a NumPy array or PyTorch tensor and format it so it can be displayed directly with matplotlib.
import math
import torch
import numpy as np
def prepare_image_for_display(
data: torch.Tensor|np.ndarray,
channels_selected: list[int] = [0,1,2],
permute_channels: bool = False,
normalize: bool = True,
uint8: bool = False
) -> np.ndarray:
"""
Prepare a tensor or numpy array for image display.
Args:
data (torch.Tensor | np.ndarray): Input data
channels_selected (list[int]): List of channel indices to select (default: first 3 channels)
permute_channels (bool): Whether to permute from (C,H,W) to (H,W,C) format
normalize (bool): Whether to normalize values to [0,1] range --> For 3 channels img, if not used matplotlib will clip the outside values
uint8 (bool): Wether to return the data as a uint8 normalized image
Returns:
np.ndarray: 2D grayscale or 3D RGB image array ready for display
Raises:
ValueError: For unsupported input shapes or invalid parameters
"""
# --------------------------------- To Numpy --------------------------------- #
if isinstance(data, torch.Tensor):
img_array = data.detach().cpu().numpy()
else:
img_array = np.array(data)
# ----------------------------- Handle Dimension ----------------------------- #
# Reduce dimensions until we have a 3D or 2D array
while len(img_array.shape) > 3:
img_array = img_array[0]
if len(img_array.shape)==1: # If we have a 1D array
size_2d = int(math.sqrt(img_array.shape[0]))
if size_2d * size_2d != img_array.shape[0]:
raise ValueError(f"Array of shape {img_array.shape} cannot be converted to a 2D array of shape ({size_2d},{size_2d})")
img_array = img_array.reshape((size_2d,size_2d)) # reshape into a 2D square array
if len(img_array.shape)==2: # If we have a 2D array we can display it
pass
if len(img_array.shape)==3:
# Check channels
if img_array.shape[0]>3:
if max(channels_selected)>img_array.shape[0] or len(channels_selected)!=3: # If wrong channels selected, defaulting to the first 3 channels
img_array = img_array[[0,1,2]]
else:
img_array = img_array[channels_selected]
if img_array.shape[0]==2: #If 2 channels only, we add an empty 3rd channel
zeros = np.zeros((1, img_array.shape[1], img_array.shape[2]), dtype=img_array.dtype)
img_array = np.concatenate((img_array, zeros), axis=0)
if img_array.shape[0]==1: #If 1 channel we make the image 2D
img_array = np.squeeze(img_array)
if normalize or uint8:
min_val = np.min(img_array)
max_val = np.max(img_array)
if max_val > min_val:
img_array = (img_array - min_val) / (max_val - min_val)
if uint8:
img_array = (img_array * 255).astype(np.uint8)
if permute_channels and len(img_array.shape) == 3:
# Permute from (C, H, W) to (H, W, C)
img_array = img_array.transpose((1, 2, 0))
return img_array
@ArthurDelannoyazerty
Copy link
Author

Some test code:

def run_tests(output_path: str = 'test_img.png'):
    """
    Visually check ``prepare_image_for_display`` on a set of demo inputs.

    This is a non-pytest smoke test. Each case is converted and drawn with
    matplotlib, using a grayscale colormap for 2D results. The whole grid is
    saved to ``output_path``. Input and output shapes and dtypes are printed to
    stdout. A case that raises is reported and drawn as red error text instead
    of aborting the remaining cases. If matplotlib is not installed, an install
    hint is printed and the function returns without plotting.

    Args:
        output_path (str): Path the figure is saved to. Default: 'test_img.png'.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("Matplotlib not found. Please install it to run the display tests: pip install matplotlib")
        return

    print("--- Testing prepare_image_for_display ---")

    # --- Test Cases ---
    test_cases = []

    # Case 1: 1D numpy array (144 = 12 * 12, so it reshapes to a square image)
    test_cases.append({
        "title": "1. 1D numpy array (144 elements)",
        "data": np.arange(144),
        "params": {}
    })

    # Case 2: 2D torch tensor
    test_cases.append({
        "title": "2. 2D torch tensor (50x50)",
        "data": torch.randn(50, 50),
        "params": {}
    })

    # Case 3: 3D torch tensor with 1 channel
    test_cases.append({
        "title": "3. 3D torch tensor (1, 64, 64)",
        "data": torch.rand(1, 64, 64),
        "params": {"permute_channels": True}
    })

    # Case 4: 3D torch tensor with 2 channels (a Gaussian blob and a sine/cosine pattern)
    x, y = np.meshgrid(np.linspace(-1, 1, 64), np.linspace(-1, 1, 64))
    ch1 = np.exp(-(x**2 + y**2))
    ch2 = np.sin(x*5) * np.cos(y*5)
    test_cases.append({
        "title": "4. 3D torch tensor (2, 64, 64)",
        "data": torch.from_numpy(np.stack([ch1, ch2])).float(),
        "params": {"permute_channels": True}
    })

    # Case 5: 3D numpy array with 4 channels (gradients, a flat channel and an identity "alpha")
    r, g = np.meshgrid(np.linspace(0, 1, 64), np.linspace(0, 1, 64))
    b = np.ones((64, 64)) * 0.5
    a = np.eye(64, dtype=np.float32)
    data_4ch = np.stack([r, g, b, a], axis=0)
    test_cases.append({
        "title": "5. 3D numpy array (4, 64, 64)",
        "data": data_4ch,
        "params": {"channels_selected": [0, 1, 2], "permute_channels": True}
    })

    # Case 6: 4D torch tensor (batch of 5; only the first image is displayed)
    test_cases.append({
        "title": "6. 4D torch tensor (5, 3, 32, 32)",
        "data": torch.rand(5, 3, 32, 32),
        "params": {"permute_channels": True}
    })

    # Case 7: uint8 output
    r, g = np.meshgrid(np.linspace(0, 1, 64), np.linspace(0, 1, 64))
    b = np.ones((64, 64)) * 0.5
    data_3ch = np.stack([r, g, b], axis=0)
    test_cases.append({
        "title": "7. 3D numpy array to uint8",
        "data": data_3ch,
        "params": {"permute_channels": True, "uint8": True}
    })

    # --- Run and Plot ---
    num_cases = len(test_cases)
    cols = 3
    rows = (num_cases + cols - 1) // cols  # ceiling division
    # squeeze=False keeps `axes` a 2D array for any grid shape, so flatten() always works.
    fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 5 * rows), squeeze=False)
    axes = axes.flatten()

    try:
        for ax, case in zip(axes, test_cases):
            print(f"\n{case['title']}")
            data = case["data"]
            params = case["params"]

            input_shape = data.shape if hasattr(data, 'shape') else 'N/A'
            print(f"   Input shape: {input_shape}")

            try:
                img_out = prepare_image_for_display(data, **params)
                print(f"   Output shape: {img_out.shape}, dtype: {img_out.dtype}")

                if len(img_out.shape) == 2:
                    ax.imshow(img_out, cmap='gray')
                else:
                    ax.imshow(img_out)

            except Exception as e:
                # Deliberately broad: this is a demo harness, so any failure is shown
                # in its grid cell and the remaining cases still run.
                print(f"   ERROR processing case: {e}")
                ax.text(0.5, 0.5, f"ERROR:\n{e}", ha='center', va='center', color='red')

            ax.set_title(case['title'])
            ax.axis('off')

        # Hide the grid cells left over when num_cases is not a multiple of cols.
        for ax in axes[num_cases:]:
            ax.axis('off')

        fig.tight_layout()
        fig.savefig(output_path)
    finally:
        # pyplot keeps every open figure alive until it is explicitly closed.
        plt.close(fig)


if __name__ == "__main__":
    # Run the visual demo only when this file is executed as a script, not on import.
    run_tests()

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.