Skip to content

Instantly share code, notes, and snippets.

@Birch-san
Created September 2, 2023 15:31
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Birch-san/b1abfb7001d0b27fc042b7204ef5c490 to your computer and use it in GitHub Desktop.
Tester for neighbourhood_mask, perimeter_mask
from typing import Optional, NamedTuple
from torch import BoolTensor, arange, meshgrid, clamp
import torch
class Dimensions(NamedTuple):
    """Height/width pair for a 2-D grid; unpacks as ``(height, width)``."""

    # extent along the first (row) axis
    height: int
    # extent along the second (column) axis
    width: int
def make_neighbourhood_mask(size: "Dimensions", size_orig: "Dimensions", device='cpu') -> BoolTensor:
    """Build a boolean mask linking every grid position to a window around it.

    Positions of an ``h x w`` grid are flattened row-major. Row ``i`` of the
    result describes grid position ``i``; column ``j`` is True when position
    ``j`` lies inside an ``h_orig x w_orig`` window placed around position
    ``i``.

    The window starts ``h_orig // 2`` rows above and ``w_orig // 2`` columns
    left of the position (so an even-sized window reaches one cell further
    up/left than down/right). Near the borders the window is shifted inwards
    so that it stays inside the grid and keeps its full size.

    Args:
        size: ``(height, width)`` of the grid.
        size_orig: ``(height, width)`` of the window. If it exceeds the grid
            along an axis, the whole axis is covered.
        device: device on which the tensors are created.

    Returns:
        Boolean tensor of shape ``(h * w, h * w)``.
    """
    h, w = size
    h_orig, w_orig = size_orig

    rows = arange(h, device=device)
    cols = arange(w, device=device)

    # First in-window row/column for each position. The upper clamp bound is
    # floored at 0 so a window larger than the grid is handled explicitly
    # instead of relying on how clamp treats min > max.
    row_start = clamp(rows - h_orig // 2, 0, max(h - h_orig, 0)).reshape(h, 1, 1, 1)
    col_start = clamp(cols - w_orig // 2, 0, max(w - w_orig, 0)).reshape(1, w, 1, 1)

    # Candidate rows/columns live on the trailing two axes.
    row_range = rows.reshape(1, 1, h, 1)
    col_range = cols.reshape(1, 1, 1, w)

    in_rows = (row_range >= row_start) & (row_range < row_start + h_orig)  # (h, 1, h, 1)
    in_cols = (col_range >= col_start) & (col_range < col_start + w_orig)  # (1, w, 1, w)

    # Broadcasting the two axis tests yields the full (h, w, h, w) mask.
    mask = in_rows & in_cols
    return mask.reshape(h * w, h * w)
def make_perimeter_mask(size: "Dimensions", canvas_edge: Optional[int] = None, device='cpu') -> BoolTensor:
    """Build a flat boolean mask marking the border cells of a grid.

    A cell is marked True when it lies within ``canvas_edge`` cells of the
    top, bottom, left or right edge of the ``h x w`` grid.

    Args:
        size: ``(height, width)`` of the grid.
        canvas_edge: border thickness in cells. ``None`` is treated as 0
            (no border, all-False mask); previously the default raised
            ``TypeError`` because tensors were compared against ``None``.
        device: device on which the tensors are created.

    Returns:
        Boolean tensor of shape ``(h * w,)``, flattened row-major.
    """
    h, w = size
    # NOTE(review): "no border" is the assumed meaning of None - confirm.
    edge = 0 if canvas_edge is None else canvas_edge

    # Column vector of row indices and row vector of column indices, so the
    # comparisons below broadcast to the full (h, w) grid.
    rows = arange(h, device=device).reshape(h, 1)
    cols = arange(w, device=device).reshape(1, w)

    mask: BoolTensor = (rows < edge) | (rows >= h - edge) | (cols < edge) | (cols >= w - edge)
    return mask.flatten()
# Raise the summarisation threshold and line width so the mask prints in full.
torch.set_printoptions(threshold=10_000, linewidth=200)

spatial = Dimensions(8, 8)
pref = Dimensions(4, 4)
perimeter = 1
# Window size left once `perimeter` cells are shaved off each side.
pref_shaved = Dimensions(pref.height-perimeter*2, pref.width-perimeter*2)

attn_mask = make_neighbourhood_mask(spatial, pref_shaved)
# The (h*w,) perimeter mask broadcasts over rows, so border positions are set
# True in every row of the (h*w, h*w) mask.
attn_mask |= make_perimeter_mask(spatial, perimeter)

# Show the mask row belonging to the centre position as an h x w grid of 0/1.
# Printed explicitly: as a bare expression the value is discarded when this
# file is run as a script.
print(attn_mask.int().reshape(*spatial, *spatial)[spatial.height//2,spatial.width//2])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment