-
-
Save alexandremuzio/3ba9d8669f57718139da36158180baaf to your computer and use it in GitHub Desktop.
Weird Triton kernel behavior for grayscale conversion (meant to be copy-pasted into a Colab notebook with a T4 GPU).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Mainly copied from: https://github.com/cuda-mode/lectures/blob/main/lecture%2014/A_Practitioners_Guide_to_Triton.ipynb | |
import os | |
import matplotlib.pyplot as plt | |
from urllib.request import urlretrieve | |
from pathlib import Path | |
import torch | |
from torch import tensor | |
import torchvision as tv | |
import torchvision.transforms.functional as tvf | |
from torchvision import io | |
import triton | |
import triton.language as tl | |
def cdiv(a, b):
    """Return ceil(a / b) for positive integers using integer arithmetic only.

    Gives the number of size-``b`` chunks needed to cover ``a`` items, which is
    how the launch grid below is sized.
    """
    padded = a + (b - 1)
    return padded // b
url = 'https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/1600px-Cute_dog.jpg?20140729055059'
path_img = Path('puppy.jpg')
# Download only once; later runs reuse the cached file.
if not path_img.exists():
    urlretrieve(url, path_img)
# Read from the same path we downloaded to, instead of repeating the filename
# literal, so the two cannot drift apart.
img = io.read_image(str(path_img))
print(img.shape)
# A bare expression in the middle of a script/cell is discarded; print the
# small corner slice so it is actually visible.
print(img[:2, :3, :4])
def show_img(x, figsize=(4,3), **kwargs):
    """Draw an image tensor with matplotlib in a new figure with axes hidden.

    A 3-D input is treated as channels-first (C, H, W) and moved to (H, W, C)
    for matplotlib; any other rank is shown as-is. Extra keyword arguments are
    forwarded to ``plt.imshow`` (e.g. ``cmap='gray'``).
    """
    plt.figure(figsize=figsize)
    plt.axis('off')
    channels_last = x.permute(1, 2, 0) if x.ndim == 3 else x  # CHW -> HWC
    plt.imshow(channels_last.cpu(), **kwargs)
#########################################################################
# Resize the image before converting it to greyscale.
#
# The Triton kernel rgb2grey_k addresses its input as
# channel*h*w + row*w + col, i.e. it assumes a contiguous CHW buffer.
# Image tensors coming out of decoding/resizing are not guaranteed to have
# that layout (e.g. they can carry channels-last strides). A non-contiguous
# tensor makes the kernel read the wrong pixels, and the output is a garbled
# grey image. That is the "only size 1066 works" symptom.
# NOTE(review): confirm by printing img.is_contiguous() / img.stride() for
# different sizes. .contiguous() is a no-op on an already-contiguous tensor.
size = 1066
img = tvf.resize(img, size, antialias=False).contiguous()
ch, h, w = img.shape
print(ch, h, w, h * w)
show_img(img)
##########################################################################
@triton.jit
def rgb2grey_k(x_ptr, out_ptr, h, w, bs0: tl.constexpr, bs1: tl.constexpr):
    # Triton kernel: convert one (bs0 x bs1) tile of an RGB image to greyscale.
    #
    # x_ptr   : input pointer. Channel c, pixel (row, col) is read at
    #           c*h*w + row*w + col, so the kernel assumes a contiguous CHW buffer
    #           with at least 3 channel planes (R, G, B). A tensor with other
    #           strides (e.g. channels-last) is silently read incorrectly.
    # out_ptr : output pointer, written at row*w + col (contiguous h x w).
    # h, w    : image height and width in pixels, used for addressing and masking.
    # bs0/bs1 : tile size along rows / columns (compile-time constants).
    # Program (pid_0, pid_1) handles rows pid_0*bs0 .. pid_0*bs0+bs0-1 and
    # columns pid_1*bs1 .. pid_1*bs1+bs1-1.
    pid_0 = tl.program_id(0)
    pid_1 = tl.program_id(1)
    offs_0 = pid_0 * bs0 + tl.arange(0,bs0) # 1d vector of row indices for this tile
    offs_1 = pid_1 * bs1 + tl.arange(0,bs1) # 1d vector of column indices for this tile
    # Weirdness: None-slicing currently doesn't work when simulating on cpu. Use tl.expand_dims instead.
    # offs = w * tl.expand_dims(offs_0, 1) + tl.expand_dims(offs_1, 0)
    offs = w * offs_0[:,None] + offs_1[None, :] # 2d matrix! - row-major layout: the row offset is multiplied by the width
    mask_0 = offs_0 < h # 1d vector: rows inside the image
    mask_1 = offs_1 < w # 1d vector: columns inside the image
    # mask = tl.expand_dims(mask_0, 1) & tl.expand_dims(mask_1, 0)
    mask = mask_0[:,None] & mask_1[None,:] # 2d matrix! - data mustn't go out of bounds along either axis, therefore `logical and` of the individual masks
    # Each channel plane is h*w elements long, so channel c starts at c*h*w.
    r = tl.load(x_ptr + 0*h*w+offs, mask=mask)
    g = tl.load(x_ptr + 1*h*w+offs, mask=mask)
    b = tl.load(x_ptr + 2*h*w+offs, mask=mask)
    # Weirdness: multiplying float with uint vectors fails when simulating on cpu
    # Weighted sum of R, G, B. These appear to be the ITU-R BT.601 luma weights
    # (the same values MATLAB's rgb2gray uses).
    out = 0.2989*r + 0.5870*g + 0.1140*b
    # tl.store casts `out` to out_ptr's element type; masked-off lanes are not written.
    tl.store(out_ptr + offs, out, mask=mask)
def rgb2grey(x, bs):
    """Convert an RGB image tensor to greyscale with the Triton kernel ``rgb2grey_k``.

    Args:
        x: image tensor of shape (C, H, W) with C >= 3, on the device the kernel
            runs on. Channels 0, 1, 2 are used as R, G, B.
        bs: (rows, cols) tile size handled by each Triton program.

    Returns:
        An (H, W) tensor with the same dtype and device as ``x``.

    Raises:
        ValueError: if ``x`` has fewer than 3 channels (the kernel would read
            past the end of the buffer).
    """
    c, h, w = x.shape
    if c < 3:
        raise ValueError(f"rgb2grey expects at least 3 channels, got {c}")
    # The kernel computes addresses as channel*h*w + row*w + col, i.e. it
    # assumes a contiguous CHW buffer. Tensors produced by image decoding or
    # resizing may carry other strides (e.g. channels-last); passing those
    # straight through makes the kernel read the wrong pixels and yields a
    # garbled greyscale image. .contiguous() is free when already contiguous.
    x = x.contiguous()
    out = torch.empty((h, w), dtype=x.dtype, device=x.device)
    # grid can be a function returning a 1d/2d/3d-tuple. Triton calls it with
    # the launch meta-parameters (including bs0/bs1): one program per tile.
    grid = lambda meta: (cdiv(h, meta['bs0']), cdiv(w, meta['bs1']))
    rgb2grey_k[grid](x, out, h, w, bs0=bs[0], bs1=bs[1])  # all kwargs are passed into grid function
    return out
# Move the image to the GPU, run the Triton greyscale conversion with 32x32
# tiles, bring the result back to the host and display it in grey.
img_gpu = img.to('cuda')
grey_gpu = rgb2grey(img_gpu, bs=(32, 32))
grey_img = grey_gpu.to('cpu')
show_img(grey_img, cmap='gray')
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.
Running this as-is produces the following output: