Skip to content

Instantly share code, notes, and snippets.

@cobanov
Created November 21, 2023 04:29
Show Gist options
  • Save cobanov/b0d62165ee823172a062b5b6aed2d639 to your computer and use it in GitHub Desktop.
Save cobanov/b0d62165ee823172a062b5b6aed2d639 to your computer and use it in GitHub Desktop.
torch, opencv & numpy related utils
import os
from typing import Union

import numpy as np
import torch
from skimage import io, color
def get_image(
    image_or_path: Union[str, os.PathLike, np.ndarray, torch.Tensor],
) -> np.ndarray:
    """
    Load an image from a path, NumPy array, or torch tensor as channel-last RGB.

    Conversions applied to the loaded data:

    * A 3-D torch tensor that looks channel-first, i.e. ``(C, H, W)`` with
      ``C`` in {1, 3, 4} and a last axis NOT in {1, 3, 4}, is moved to
      channel-last ``(H, W, C)`` first. NumPy arrays are assumed channel-last.
    * 2-D grayscale ``(H, W)`` and single-channel ``(H, W, 1)`` inputs are
      expanded to three channels with ``skimage.color.gray2rgb``.
    * 4-channel ``(H, W, 4)`` inputs (e.g. RGBA) have the 4th channel dropped.
    * 4-D inputs (presumably batches) have their last axis truncated to at
      most 3 entries.

    Any other shape is returned unchanged, so ``(H, W, 3)`` is not guaranteed
    for unusual inputs. This function itself does not rescale values or
    change dtype.

    Args:
        image_or_path: A filesystem path (``str`` or ``os.PathLike``) read with
            ``skimage.io.imread``, a ``numpy.ndarray``, or a ``torch.Tensor``
            (detached and moved to CPU before conversion).

    Returns:
        numpy.ndarray: The image data. Never ``None``; failures raise instead.

    Raises:
        ValueError: If the input type is unsupported, or if loading/conversion
            fails with any non-``OSError`` exception.
        OSError: (``IOError`` is the same class) if an ``OSError`` is raised
            while loading, e.g. a missing or unreadable file.
    """
    # Validate the type before the try block so this message is not re-wrapped
    # by the generic handler below.
    if not isinstance(image_or_path, (str, os.PathLike, torch.Tensor, np.ndarray)):
        raise ValueError("Unsupported input type")

    try:
        if isinstance(image_or_path, torch.Tensor):
            image = image_or_path.detach().cpu().numpy()
            # Torch image tensors are conventionally (C, H, W); only move the
            # axis when the layout is unambiguous, so HWC tensors are untouched.
            if (
                image.ndim == 3
                and image.shape[0] in (1, 3, 4)
                and image.shape[-1] not in (1, 3, 4)
            ):
                image = np.moveaxis(image, 0, -1)
        elif isinstance(image_or_path, np.ndarray):
            image = image_or_path
        else:
            # os.fspath normalises pathlib.Path and other PathLike objects to str.
            image = io.imread(os.fspath(image_or_path))

        if image.ndim == 2:
            image = color.gray2rgb(image)
        elif image.ndim == 3 and image.shape[-1] == 1:
            # Squeeze to 2-D first so gray2rgb yields (H, W, 3), not (H, W, 1, 3).
            image = color.gray2rgb(image[..., 0])
        elif image.ndim == 4 or (image.ndim == 3 and image.shape[-1] == 4):
            # Drop the alpha channel (or extra trailing channels for 4-D input).
            image = image[..., :3]
        return image
    except OSError as e:
        raise OSError(f"Error opening file: {image_or_path}") from e
    except Exception as e:
        raise ValueError(f"Error processing the image: {e}") from e
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment