Created
April 26, 2025 13:35
-
-
Save rakki194/645021d7d34ed00d9efbd57c2f8785f0 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Fit(torch.nn.Module):
    """
    Resize and optionally pad images to fit within bounds while keeping aspect ratio.

    The transform performs the following steps:
        1. Compute the largest scale at which the image fits inside ``bounds``
           while preserving its aspect ratio (capped at 1.0 when ``grow`` is False).
        2. Resize the image by that scale (skipped when the scale is exactly 1.0).
        3. If ``pad`` is not None, pad the result symmetrically to exactly
           ``bounds``. When the padding is odd, the extra pixel goes to the
           bottom/right. Padding is applied even when no resize was needed,
           so the output size is always ``bounds`` when padding is requested.

    Args:
        bounds: Target dimensions as ``(height, width)``, or a single integer
            for square bounds.
        interpolation: Interpolation method used for resizing. Defaults to LANCZOS.
        grow: If True, images smaller than ``bounds`` may be upscaled. If False,
            images are only ever downscaled. Defaults to True.
        pad: Fill value used to extend the image to ``bounds``. If None, no
            padding is applied. Defaults to None.

    Note:
        The torchvision ``Pad`` documentation states that only int or tuple
        fill values are supported for PIL Images. A float such as ``0.5`` may
        therefore not produce mid-gray on an 8-bit PIL image; confirm for
        your image mode.

    Example:
        ```python
        # Fit images into 384x384, padding any leftover area with value 0.5
        transform = Fit(bounds=(384, 384), pad=0.5)
        fitted_image = transform(original_image)
        ```
    """

    def __init__(
        self,
        bounds: tuple[int, int] | int,
        interpolation: InterpolationMode = InterpolationMode.LANCZOS,
        grow: bool = True,
        pad: float | None = None,
    ):
        super().__init__()
        # Normalize to a (height, width) tuple. tuple() also accepts a list.
        self.bounds = (bounds, bounds) if isinstance(bounds, int) else tuple(bounds)
        self.interpolation = interpolation
        self.grow = grow
        self.pad = pad

    def forward(self, img: Image) -> Image:
        """
        Resize and optionally pad the input image to fit within ``self.bounds``.

        Args:
            img: Input PIL Image to be transformed.

        Returns:
            The transformed PIL Image. It fits within ``self.bounds``, and it
            equals ``self.bounds`` exactly when ``self.pad`` is not None. The
            input object itself is returned unchanged when it needs neither
            resizing nor padding.
        """
        # PIL reports size as (width, height); bounds are (height, width).
        image_width, image_height = img.size
        target_height, target_width = self.bounds

        height_scale = target_height / image_height
        width_scale = target_width / image_width

        # grow=False: never upscale along either axis.
        if not self.grow:
            height_scale = min(height_scale, 1.0)
            width_scale = min(width_scale, 1.0)

        # The smaller scale is the one that makes the image fit on both axes.
        final_scale = min(height_scale, width_scale)

        if final_scale == 1.0:
            # No resize needed. Here both per-axis scales are >= 1, so the
            # image already fits, but it may still be smaller than the
            # bounds on one axis. Fall through so padding is still applied.
            new_height, new_width = image_height, image_width
            resized_img = img
        else:
            # Clamp to the bounds so rounding can't overshoot the target.
            # Clamp to at least 1 so extreme aspect ratios never produce a
            # zero-sized dimension.
            new_height = max(1, min(round(image_height * final_scale), target_height))
            new_width = max(1, min(round(image_width * final_scale), target_width))
            resized_img = transforms.Resize(
                (new_height, new_width), interpolation=self.interpolation
            )(img)

        if self.pad is None:
            return resized_img

        height_padding = target_height - new_height
        width_padding = target_width - new_width

        # Already exactly at the bounds: avoid a needless copy through Pad.
        if height_padding == 0 and width_padding == 0:
            return resized_img

        # Split padding evenly; any odd pixel goes to the bottom/right side.
        top_padding = height_padding // 2
        bottom_padding = height_padding - top_padding
        left_padding = width_padding // 2
        right_padding = width_padding - left_padding

        # torchvision Pad expects (left, top, right, bottom) order.
        return transforms.Pad(
            padding=(left_padding, top_padding, right_padding, bottom_padding),
            fill=self.pad,
        )(resized_img)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment