Skip to content

Instantly share code, notes, and snippets.

@atharvas
Last active January 29, 2024 18:12
Show Gist options
Save atharvas/34a9a423027504848f96281f881b4d4c to your computer and use it in GitHub Desktop.
Profiling masks_to_boxes
import torch
import time
import tracemalloc
import matplotlib.pyplot as plt
def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
    """
    Compute the tight bounding box of every mask (reference per-mask loop).

    A pixel belongs to a mask when its value is non-zero. Box coordinates are
    inclusive pixel indices, so ``0 <= x1 <= x2 < W`` and ``0 <= y1 <= y2 < H``
    (equality holds when a mask occupies a single column / row).

    Args:
        masks (Tensor[N, H, W]): N masks with spatial size (H, W).

    Returns:
        Tensor[N, 4]: float32 boxes in ``(x1, y1, x2, y2)`` format on the same
        device as ``masks``. An input with no elements yields a ``(0, 4)`` tensor.

    Raises:
        RuntimeError: if an individual mask is entirely zero, because the
        min/max reduction over its (empty) coordinate set fails.
    """
    if masks.numel() == 0:
        return torch.zeros((0, 4), device=masks.device, dtype=torch.float)

    boxes = torch.zeros((masks.shape[0], 4), device=masks.device, dtype=torch.float)
    for row, single_mask in enumerate(masks):
        ys, xs = torch.where(single_mask != 0)
        # Integer extrema are cast to float32 when written into the output row.
        boxes[row] = torch.stack((xs.min(), ys.min(), xs.max(), ys.max()))
    return boxes
def batched_masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
    """
    Vectorized counterpart of ``masks_to_boxes`` with no per-mask Python loop.

    For each mask, a column profile (does column j contain a non-zero pixel?)
    and a row profile are reduced from the mask. The first and last occupied
    entries of those profiles give the box edges as inclusive pixel indices.

    Args:
        masks (Tensor[N, H, W]): N masks with spatial size (H, W).

    Returns:
        Tensor[N, 4]: float32 boxes in ``(x1, y1, x2, y2)`` format on the same
        device as ``masks``. An input with no elements yields a ``(0, 4)`` tensor.

    Note:
        An all-zero mask yields ``(0, 0, W - 1, H - 1)`` (argmax of an all-zero
        profile is index 0), whereas ``masks_to_boxes`` raises for such a mask.
    """
    if masks.numel() == 0:
        return torch.zeros((0, 4), device=masks.device, dtype=torch.float)

    last_row = masks.shape[1] - 1
    last_col = masks.shape[2] - 1

    # Occupancy profiles, cast to float for argmax (presumably because argmax
    # does not accept bool tensors).
    col_occupied = masks.any(dim=1).float()  # (N, W)
    row_occupied = masks.any(dim=2).float()  # (N, H)

    # argmax returns the first index holding the maximum, i.e. the first
    # occupied column/row; on the reversed profile it finds the last one.
    left = col_occupied.argmax(dim=1)
    top = row_occupied.argmax(dim=1)
    right = last_col - col_occupied.flip(dims=[1]).argmax(dim=1)
    bottom = last_row - row_occupied.flip(dims=[1]).argmax(dim=1)

    return torch.stack((left, top, right, bottom), dim=1).float()
# Sanity check before benchmarking: on 100 random batches of ten binary 64x64
# masks, the vectorized implementation must reproduce the reference exactly.
# Random masks this dense are practically never all-zero, so the two
# implementations' differing all-zero behavior is not exercised here.
for _ in range(100):
    masks = torch.randint(2, (10, 64, 64))
    boxes_match = torch.eq(masks_to_boxes(masks), batched_masks_to_boxes(masks)).all()
    assert boxes_match, "masks_to_boxes and batched_masks_to_boxes are not equivalent"
def profile_function(func, batch_sizes):
    """
    Benchmark ``func`` on CPU with random binary masks of shape (2**b, 64, 64).

    For every exponent ``b`` in ``batch_sizes`` a fresh batch is generated and
    ``func`` is called twice:
      1. untraced, timed with ``time.perf_counter``;
      2. under ``tracemalloc``, to record the peak traced allocation.
    Timing and tracing are kept apart because tracemalloc instruments every
    Python-level allocation. That would inflate the measured time, and inflate
    it more for implementations that allocate many small objects in a loop.

    NOTE(review): tracemalloc only sees allocations made through Python's
    allocators. Tensor storage from PyTorch's own CPU allocator is presumably
    not counted, so these memory figures likely understate real usage —
    confirm before drawing conclusions from them.

    Args:
        func: callable taking a mask tensor; its return value is ignored.
        batch_sizes: iterable of exponents; the batch size is ``2**b``.

    Returns:
        (times, memory): wall-clock seconds per call and peak traced memory in
        MB (10**6 bytes), one entry per element of ``batch_sizes``.
    """
    times, memory = [], []
    for b in batch_sizes:
        masks = torch.randint(2, (int(2**b), 64, 64))

        start = time.perf_counter()
        func(masks)
        times.append(time.perf_counter() - start)

        tracemalloc.start()
        try:
            func(masks)
            _, peak = tracemalloc.get_traced_memory()
        finally:
            # Always stop tracing, even if func raises, so later code is not slowed down.
            tracemalloc.stop()
        memory.append(peak / 10**6)
    return times, memory
def profile_function_on_gpu(func, batch_sizes, device="cuda:0"):
    """
    Benchmark ``func`` on a CUDA device with random binary masks (2**b, 64, 64).

    This is separate from the CPU profiler because device memory is tracked by
    PyTorch's CUDA allocator statistics rather than by tracemalloc.

    CUDA kernels are queued asynchronously, so the device is synchronized
    before the clock starts and again before it stops. Without that, the
    measured time would cover only kernel launches, plus any still-pending
    ``randint`` work from mask generation.

    Args:
        func: callable taking a mask tensor on ``device``; return value ignored.
        batch_sizes: iterable of exponents; the batch size is ``2**b``.
        device: CUDA device to run on (defaults to ``"cuda:0"``).

    Returns:
        (times, memory): wall-clock seconds per call, and the peak device memory
        allocated during the call above the pre-call level, in MB (10**6 bytes).
        Both lists have one entry per batch size.
    """
    times, memory = [], []
    for b in batch_sizes:
        masks = torch.randint(2, (int(2**b), 64, 64), device=device)
        torch.cuda.reset_peak_memory_stats(device=device)
        torch.cuda.reset_accumulated_memory_stats(device=device)
        # Drain queued work (e.g. the randint above) so it is not timed.
        torch.cuda.synchronize(device)
        # After the reset, the "peak" equals the currently allocated memory (baseline).
        start_mem = torch.cuda.max_memory_allocated(device=device)
        start = time.perf_counter()
        func(masks)
        # Wait for func's kernels to finish before reading the clock.
        torch.cuda.synchronize(device)
        end = time.perf_counter()
        end_mem = torch.cuda.max_memory_allocated(device=device)
        times.append(end - start)
        memory.append((end_mem - start_mem) / 10**6)
    return times, memory
def plot_results(x, y_old, y_new, ylabel, title, filename):
    """
    Plot old-vs-new measurements against batch size and save the figure.

    Args:
        x: batch sizes, used both as x values and as x tick positions.
        y_old: measurements for the reference implementation (red, "old").
        y_new: measurements for the vectorized implementation (blue, "new").
        ylabel: y-axis label.
        title: figure title.
        filename: path the figure is saved to.
    """
    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        ax.set_xlabel("batch size (B, 64, 64)")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.plot(x, y_old, label="old", color="red", marker="o")
        ax.plot(x, y_new, label="new", color="blue", marker="o")
        ax.set_xticks(x)
        ax.legend()
        fig.savefig(filename)
    finally:
        # pyplot keeps every figure alive until it is explicitly closed.
        plt.close(fig)
# Exponents 1..15 -> batch sizes 2 .. 32768 (masks of shape (B, 64, 64)).
batch_sizes = range(1, 16)
batch_axis = [2**b for b in batch_sizes]

# CPU benchmarks. Plot them immediately so their figures are saved even if
# the GPU section below cannot run.
times_old, mem_old = profile_function(masks_to_boxes, batch_sizes)
times_new, mem_new = profile_function(batched_masks_to_boxes, batch_sizes)
plot_results(batch_axis, mem_old, mem_new, "Memory (MB)", "Memory Usage Comparison", "memory.png")
plot_results(batch_axis, times_old, times_new, "Time (s)", "Speed Comparison", "speed.png")

# GPU benchmarks need CUDA; skip them instead of crashing on CPU-only machines.
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_times_old, gpu_mem_old = profile_function_on_gpu(masks_to_boxes, batch_sizes)
    gpu_times_new, gpu_mem_new = profile_function_on_gpu(batched_masks_to_boxes, batch_sizes)
    plot_results(batch_axis, gpu_mem_old, gpu_mem_new, "Memory (MB)", f"Memory Usage Comparison (GPU: {gpu_name})", "memory_gpu.png")
    plot_results(batch_axis, gpu_times_old, gpu_times_new, "Time (s)", f"Speed Comparison (GPU: {gpu_name})", "speed_gpu.png")
else:
    print("CUDA is not available; skipping GPU benchmarks.")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment