# Source: GitHub gist void-main/840fe163f4c891a7dddb17af076959dc (created December 9, 2022 02:08).
# Compares a Triton kernel for xyxy -> xywh box conversion against a PyTorch reference
# and benchmarks the two.
import torch
import triton
import triton.language as tl
import numpy as np
def torch_xyxy2xywh(x):
    """Convert boxes from corner format to center format.

    Each row ``[x1, y1, x2, y2]`` (xy1 = top-left, xy2 = bottom-right) becomes
    ``[x, y, w, h]`` (center point, width, height). Only columns 0-3 are
    rewritten; any further columns are copied through unchanged.

    Args:
        x: A ``torch.Tensor`` or NumPy array of shape ``(n, >=4)``.

    Returns:
        A new tensor/array of the same type, shape and dtype. ``x`` itself is
        not modified.
    """
    if isinstance(x, torch.Tensor):
        out = x.clone()
    else:
        out = np.copy(x)
    # Read from the untouched input so the writes below cannot feed back in.
    x1, y1, x2, y2 = x[:, 0], x[:, 1], x[:, 2], x[:, 3]
    out[:, 0] = (x1 + x2) / 2  # x center
    out[:, 1] = (y1 + y2) / 2  # y center
    out[:, 2] = x2 - x1  # width
    out[:, 3] = y2 - y1  # height
    return out
# xyxy2xywh
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 64}),
        triton.Config({'BLOCK_SIZE': 128}),
        triton.Config({'BLOCK_SIZE': 256}),
        triton.Config({'BLOCK_SIZE': 512}),
    ],
    # Autotune keys must be kernel argument names; re-tune per distinct row count.
    key=['n_rows'],
)
@triton.jit
def triton_xyxy2xywh_kernel(
    input_ptr,
    output_ptr,
    n_rows,
    BLOCK_SIZE: tl.constexpr,
):
    # Convert [x1, y1, x2, y2] rows to [x_center, y_center, width, height] rows.
    #
    # One program processes a `[BLOCK_SIZE, 4]` block: BLOCK_SIZE consecutive
    # rows starting at row `program_id(0) * BLOCK_SIZE`.
    #
    # input_ptr:  input boxes; the row stride is hard-coded to 4, so the caller
    #             must pass a contiguous (n_rows, 4) buffer.
    # output_ptr: output buffer with the same layout as the input.
    # n_rows:     number of boxes; rows at or beyond this index are masked out.
    # BLOCK_SIZE: rows per program, chosen by the autotuner.
    pid = tl.program_id(0)
    stride_row = 4  # always have 4 cols (contiguous layout assumed)
    offset_row = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # [BLOCK_SIZE, 1] mask so the last program does not touch rows past the end.
    row_mask = offset_row[:, None] < n_rows
    row_start = offset_row[:, None] * stride_row

    x0 = tl.load(input_ptr + row_start, mask=row_mask)
    x1 = tl.load(input_ptr + row_start + 1, mask=row_mask)
    x2 = tl.load(input_ptr + row_start + 2, mask=row_mask)
    x3 = tl.load(input_ptr + row_start + 3, mask=row_mask)

    y0 = (x0 + x2) / 2  # x center
    y1 = (x1 + x3) / 2  # y center
    y2 = x2 - x0  # width
    y3 = x3 - x1  # height

    tl.store(output_ptr + row_start, y0, mask=row_mask)
    tl.store(output_ptr + row_start + 1, y1, mask=row_mask)
    tl.store(output_ptr + row_start + 2, y2, mask=row_mask)
    tl.store(output_ptr + row_start + 3, y3, mask=row_mask)
def triton_xyxy2xywh(x: torch.Tensor):
    """Convert CUDA boxes from ``[x1, y1, x2, y2]`` to ``[x, y, w, h]`` with Triton.

    Args:
        x: CUDA tensor of shape ``(n, 4)``. A non-contiguous input is copied to
            a contiguous buffer first, because the kernel assumes a row stride of 4.

    Returns:
        A new contiguous tensor of shape ``(n, 4)`` with the same dtype as ``x``.

    Raises:
        ValueError: if ``x`` is not 2-D with exactly 4 columns, or is not on CUDA.
    """
    if x.dim() != 2 or x.shape[1] != 4:
        raise ValueError(f'expected a tensor of shape (n, 4), got {tuple(x.shape)}')
    if not x.is_cuda:
        raise ValueError('triton_xyxy2xywh requires a CUDA tensor')
    # The kernel indexes rows as `row * 4`, which is only valid for a contiguous buffer.
    # This also makes `empty_like` below produce a contiguous output.
    x = x.contiguous()
    n_rows = x.shape[0]
    output = torch.empty_like(x)
    if n_rows == 0:
        # Nothing to convert; avoid launching an empty grid.
        return output
    # Each program handles BLOCK_SIZE *rows*, so size the grid by rows, not elements.
    grid = lambda meta: (triton.cdiv(n_rows, meta['BLOCK_SIZE']),)
    triton_xyxy2xywh_kernel[grid](x, output, n_rows)
    return output
# Quick correctness check: compare the Triton path against the PyTorch reference.
# 767 rows is not a multiple of any autotuned BLOCK_SIZE, so the masked tail is exercised.
size = 767
x = torch.rand(size, 4, device='cuda')
output_torch, output_triton = torch_xyxy2xywh(x), triton_xyxy2xywh(x)
print(f'max difference: {(output_triton - output_torch).abs().max()}')
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['size'],  # argument names to use as an x-axis for the plot
        x_vals=[
            2 ** i for i in range(4, 24, 1)
        ],  # different possible values for `x_name`
        x_log=True,  # x axis is logarithmic
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch'],  # possible values for `line_arg`
        line_names=['Triton', 'Torch'],  # label name for the lines
        styles=[('blue', '-'), ('green', '-')],  # line styles
        ylabel='GB/s',  # label name for the y-axis
        plot_name='xyxy2xywh-performance',  # name for the plot. Used also as a file name for saving the plot.
        args={},  # values for function arguments not in `x_names` and `y_name`
    )
)
def benchmark(size, provider):
    """Measure the effective bandwidth (GB/s) of one provider on `size` boxes.

    Returns a ``(median, low, high)`` GB/s triple. Bandwidth is inversely
    proportional to time, so the slow-time percentile gives the low bound.
    """
    x = torch.rand(size, 4, device='cuda', dtype=torch.float32)
    # NOTE(review): assumes do_bench returns (median, p20, p80) as in Triton 2.0;
    # newer Triton releases return a single float unless `quantiles=` is passed.
    if provider == 'torch':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch_xyxy2xywh(x))
    elif provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: triton_xyxy2xywh(x))
    else:
        raise ValueError(f'unknown provider: {provider!r}')
    # The whole (size, 4) input is read once and an equally sized output written once.
    # (The old `12 * size` was the vector-add tutorial's 3-tensor, 1-column formula.)
    n_bytes = 2 * x.numel() * x.element_size()
    gbps = lambda t_ms: n_bytes / t_ms * 1e-6
    return gbps(ms), gbps(max_ms), gbps(min_ms)


benchmark.run(print_data=True, show_plots=True, save_path='.')
# (End of gist; GitHub page footer removed.)