Skip to content

Instantly share code, notes, and snippets.

@rakki194
Created April 26, 2025 13:35
Show Gist options
  • Save rakki194/645021d7d34ed00d9efbd57c2f8785f0 to your computer and use it in GitHub Desktop.
Save rakki194/645021d7d34ed00d9efbd57c2f8785f0 to your computer and use it in GitHub Desktop.
class Fit(torch.nn.Module):
    """
    A PyTorch module that resizes and optionally pads images to fit within specified bounds
    while maintaining aspect ratio.

    This transformer performs the following operations:
    1. Calculates the appropriate scale to fit the image within bounds while preserving aspect ratio
    2. Resizes the image using the calculated scale (skipped when the scale is exactly 1.0)
    3. Optionally pads the image, centred, to match the target bounds

    Args:
        bounds (tuple[int, int] | int): Target dimensions as (height, width) or a single integer
            for square bounds
        interpolation (InterpolationMode): The interpolation method for resizing. Defaults to LANCZOS
        grow (bool): If True, allows upscaling of images smaller than bounds. If False, only
            downscales. Defaults to True
        pad (float | None): Fill value used to extend the image to bounds. If None, no padding is
            applied. Defaults to None

    Example:
        ```python
        # Create a transformer that fits images to 384x384 with padding value of 0.5
        transform = Fit(bounds=(384, 384), pad=0.5)
        # Apply the transformation to an image
        fitted_image = transform(original_image)
        ```
    """

    def __init__(
        self,
        bounds: tuple[int, int] | int,
        interpolation=InterpolationMode.LANCZOS,
        grow: bool = True,
        pad: float | None = None,
    ):
        super().__init__()
        # A single int means square bounds; otherwise bounds is (height, width).
        self.bounds = (bounds, bounds) if isinstance(bounds, int) else bounds
        self.interpolation = interpolation
        self.grow = grow
        # NOTE(review): the fill is handed to transforms.Pad unchanged; for integer-mode
        # PIL images a fractional value such as 0.5 is presumably truncated by
        # torchvision -- confirm it produces the intended colour.
        self.pad = pad

    def forward(self, img: Image) -> Image:
        """
        Resize and optionally pad the input image to fit within specified bounds.

        Args:
            img (Image): Input PIL Image to be transformed

        Returns:
            Image: Transformed PIL Image that fits within the specified bounds. When ``pad`` is
            set, the result is padded out to the bounds whenever the fitted image is smaller
            than them, including when no resizing was needed. The input object itself is
            returned when neither resizing nor padding is required.
        """
        # PIL reports size as (width, height); bounds are stored as (height, width).
        image_width, image_height = img.size
        target_height, target_width = self.bounds

        # The smaller of the two per-axis scales preserves the aspect ratio.
        final_scale = min(target_height / image_height, target_width / image_width)
        if not self.grow:
            # Never upscale: cap the scale at 1.0.
            final_scale = min(final_scale, 1.0)

        if final_scale == 1.0:
            # No resampling needed, but padding below may still be required.
            new_height, new_width = image_height, image_width
            resized_img = img
        else:
            # Clamp to the target (rounding may overshoot by a pixel) and to at least
            # one pixel (extreme aspect ratios may round down to zero).
            new_height = min(max(round(image_height * final_scale), 1), target_height)
            new_width = min(max(round(image_width * final_scale), 1), target_width)
            resized_img = transforms.Resize(
                (new_height, new_width), interpolation=self.interpolation
            )(img)

        if self.pad is None:
            return resized_img

        height_padding = target_height - new_height
        width_padding = target_width - new_width
        if height_padding == 0 and width_padding == 0:
            # Already exactly at the target bounds; nothing to pad.
            return resized_img

        # Centre the image; any odd leftover pixel goes to the bottom/right edge.
        top_padding = height_padding // 2
        bottom_padding = height_padding - top_padding
        left_padding = width_padding // 2
        right_padding = width_padding - left_padding

        return transforms.Pad(
            padding=(left_padding, top_padding, right_padding, bottom_padding),
            fill=self.pad,
        )(resized_img)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment