Last active
January 29, 2024 18:12
-
-
Save atharvas/34a9a423027504848f96281f881b4d4c to your computer and use it in GitHub Desktop.
Profiling masks_to_boxes
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import time | |
import tracemalloc | |
import matplotlib.pyplot as plt | |
def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
    """
    Compute one tight bounding box per mask, visiting the masks one at a time.

    Each box is stored as ``(x1, y1, x2, y2)``: ``x1``/``y1`` are the smallest
    and ``x2``/``y2`` the largest column/row index that holds a non-zero value,
    so both corners are inclusive pixel coordinates.

    Args:
        masks (Tensor[N, H, W]): masks to transform, where N is the number of
            masks and (H, W) are the spatial dimensions.

    Returns:
        Tensor[N, 4]: float bounding boxes on the same device as ``masks``.
        A ``(0, 4)`` tensor is returned when ``masks`` has no elements.

    Note:
        A mask without any non-zero entry is not special-cased: the min/max
        reductions then run on empty index tensors, which is expected to raise.
    """
    # The _log_api_usage_once(...) hook is intentionally left out here.
    if masks.numel() == 0:
        return torch.zeros((0, 4), device=masks.device, dtype=torch.float)

    boxes = torch.zeros((masks.shape[0], 4), device=masks.device, dtype=torch.float)
    # Iterating a tensor yields views, so writing into ``box`` fills ``boxes``.
    for box, single_mask in zip(boxes, masks):
        rows, cols = torch.where(single_mask != 0)
        box[0] = cols.min()
        box[1] = rows.min()
        box[2] = cols.max()
        box[3] = rows.max()
    return boxes
def batched_masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
    """
    Compute the bounding boxes around the provided masks without a Python loop.

    Each box is stored as ``(x1, y1, x2, y2)`` with inclusive pixel
    coordinates: ``x1``/``y1`` are the first and ``x2``/``y2`` the last
    column/row that contains a non-zero value.

    Args:
        masks (Tensor[N, H, W]): masks to transform, where N is the number of
            masks and (H, W) are the spatial dimensions.

    Returns:
        Tensor[N, 4]: float bounding boxes on the same device as ``masks``.
        A ``(0, 4)`` tensor is returned when ``masks`` has no elements. A mask
        with no non-zero entry yields the degenerate box ``(0, 0, 0, 0)``
        rather than a box spanning the whole image.
    """
    # The _log_api_usage_once(...) hook is intentionally left out here.
    if masks.numel() == 0:
        return torch.zeros((0, 4), device=masks.device, dtype=torch.float)

    height, width = masks.shape[1], masks.shape[2]

    # 1.0 where a column / row holds at least one non-zero pixel, else 0.0.
    non_zero_xs = torch.any(masks, dim=1).float()  # [N, W]
    non_zero_ys = torch.any(masks, dim=2).float()  # [N, H]

    # argmax returns the first maximal index, i.e. the first occupied
    # column/row; running it on the flipped tensor locates the last one.
    x1 = non_zero_xs.argmax(dim=1)
    y1 = non_zero_ys.argmax(dim=1)
    x2 = (width - 1) - non_zero_xs.flip(dims=[1]).argmax(dim=1)
    y2 = (height - 1) - non_zero_ys.flip(dims=[1]).argmax(dim=1)
    bounding_boxes = torch.stack((x1, y1, x2, y2), dim=1).float()

    # For an all-zero mask every argmax above is 0, which would otherwise be
    # reported as a full-image box (0, 0, W-1, H-1). Zero those rows instead.
    has_pixels = (non_zero_xs.sum(dim=1) > 0).unsqueeze(1)  # [N, 1]
    return torch.where(has_pixels, bounding_boxes, torch.zeros_like(bounding_boxes))
# Sanity check: on random binary masks the looped and the batched
# implementation must produce exactly the same boxes.
for _ in range(100):
    masks = torch.randint(2, (10, 64, 64))
    expected_boxes = masks_to_boxes(masks)
    batched_boxes = batched_masks_to_boxes(masks)
    assert bool((expected_boxes == batched_boxes).all()), "masks_to_boxes and batched_masks_to_boxes are not equivalent"
def profile_function(func, batch_sizes):
    """
    Time ``func`` and record its peak traced memory on CPU for several batch sizes.

    For every exponent ``b`` in ``batch_sizes`` a random binary mask tensor of
    shape ``(2**b, 64, 64)`` is created and passed to ``func`` once.

    Args:
        func: callable taking a single mask tensor.
        batch_sizes: iterable of exponents; the batch size used is ``2**b``.

    Returns:
        tuple(list[float], list[float]): wall-clock seconds per call and peak
        memory reported by ``tracemalloc`` per call, in MB (bytes / 10**6).

    Note:
        NOTE(review): ``tracemalloc`` only sees allocations made through
        Python's memory allocators; tensor storage allocated by native code is
        presumably not included in the reported peak. Confirm before relying
        on the memory figures.
    """
    times, memory = [], []
    for b in batch_sizes:
        masks = torch.randint(2, (int(2**b), 64, 64))
        tracemalloc.start()
        try:
            # perf_counter is the monotonic, high-resolution clock intended
            # for measuring short durations.
            start = time.perf_counter()
            func(masks)
            elapsed = time.perf_counter() - start
            _, peak = tracemalloc.get_traced_memory()
        finally:
            # Always stop tracing, even if func raises, so a failure does not
            # leave tracemalloc enabled for the rest of the process.
            tracemalloc.stop()
        times.append(elapsed)
        memory.append(peak / 10**6)
    return times, memory
def profile_function_on_gpu(func, batch_sizes, device="cuda:0"):
    """
    Time ``func`` and record its peak CUDA memory growth for several batch sizes.

    A separate function is needed for GPU profiling because memory is read
    from the CUDA allocator statistics instead of ``tracemalloc``.

    For every exponent ``b`` in ``batch_sizes`` a random binary mask tensor of
    shape ``(2**b, 64, 64)`` is created on ``device`` and passed to ``func``
    once.

    Args:
        func: callable taking a single mask tensor.
        batch_sizes: iterable of exponents; the batch size used is ``2**b``.
        device: CUDA device to profile on. Defaults to ``"cuda:0"``.

    Returns:
        tuple(list[float], list[float]): wall-clock seconds per call and the
        increase of ``torch.cuda.max_memory_allocated`` across the call, in MB
        (bytes / 10**6).
    """
    times, memory = [], []
    for b in batch_sizes:
        masks = torch.randint(2, (int(2**b), 64, 64), device=device)
        torch.cuda.reset_peak_memory_stats(device=device)
        torch.cuda.reset_accumulated_memory_stats(device=device)
        # Baseline taken right after the reset, before func runs; the input
        # tensor already exists at this point.
        start_mem = torch.cuda.max_memory_allocated(device=device)

        # CUDA kernels are launched asynchronously: drain pending work before
        # starting the clock and wait for func's kernels before stopping it,
        # otherwise only the launch overhead is measured.
        torch.cuda.synchronize(device=device)
        start = time.perf_counter()
        func(masks)
        torch.cuda.synchronize(device=device)
        elapsed = time.perf_counter() - start

        end_mem = torch.cuda.max_memory_allocated(device=device)
        times.append(elapsed)
        memory.append((end_mem - start_mem) / 10**6)
    return times, memory
def plot_results(x, y_old, y_new, ylabel, title, filename):
    """
    Plot the "old" and "new" measurements against batch size and save the figure.

    Args:
        x: batch sizes, used both as x values and as x tick positions.
        y_old: measurements of the old implementation (red line, label "old").
        y_new: measurements of the new implementation (blue line, label "new").
        ylabel: label for the y axis.
        title: figure title.
        filename: path the figure is written to via ``plt.savefig``.
    """
    fig = plt.figure(figsize=(12, 8))
    try:
        plt.xlabel("batch size (B, 64, 64)")
        plt.ylabel(ylabel)
        plt.title(title)
        plt.plot(x, y_old, label="old", color="red", marker="o")
        plt.plot(x, y_new, label="new", color="blue", marker="o")
        plt.xticks(x)
        plt.legend()
        plt.savefig(filename)
    finally:
        # pyplot keeps every figure alive until it is closed; release it so
        # repeated calls do not accumulate open figures.
        plt.close(fig)
# Exponents of the profiled batch sizes: 2**1 .. 2**15, i.e. 2 -> 32768 masks.
batch_sizes = range(1, 16)
num_masks = [2**b for b in batch_sizes]

# CPU comparison. The plots are written before any GPU work so they are kept
# even when no CUDA device is present.
times_old, mem_old = profile_function(masks_to_boxes, batch_sizes)
times_new, mem_new = profile_function(batched_masks_to_boxes, batch_sizes)
plot_results(num_masks, mem_old, mem_new, "Memory (MB)", "Memory Usage Comparison", "memory.png")
plot_results(num_masks, times_old, times_new, "Time (s)", "Speed Comparison", "speed.png")

# GPU comparison, only attempted when a CUDA device is available.
if torch.cuda.is_available():
    gpu_times_old, gpu_mem_old = profile_function_on_gpu(masks_to_boxes, batch_sizes)
    gpu_times_new, gpu_mem_new = profile_function_on_gpu(batched_masks_to_boxes, batch_sizes)
    plot_results(num_masks, gpu_mem_old, gpu_mem_new, "Memory (MB)", f"Memory Usage Comparison (GPU: {torch.cuda.get_device_name(0)})", "memory_gpu.png")
    plot_results(num_masks, gpu_times_old, gpu_times_new, "Time (s)", f"Speed Comparison (GPU: {torch.cuda.get_device_name(0)})", "speed_gpu.png")
else:
    print("CUDA is not available; skipping GPU profiling.")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment