Skip to content

Instantly share code, notes, and snippets.

@Birch-san
Created September 2, 2023 15:32
Show Gist options
  • Save Birch-san/961623213b54ddb56fe63517f5c51df1 to your computer and use it in GitHub Desktop.
Save Birch-san/961623213b54ddb56fe63517f5c51df1 to your computer and use it in GitHub Desktop.
Tester for neighbourhood_mask, perimeter_mask
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[1, 1, 1, 1, 1, 1, 1, 1],\n",
" [1, 0, 0, 0, 0, 0, 0, 1],\n",
" [1, 0, 0, 0, 0, 0, 0, 1],\n",
" [1, 0, 0, 1, 1, 0, 0, 1],\n",
" [1, 0, 0, 1, 1, 0, 0, 1],\n",
" [1, 0, 0, 0, 0, 0, 0, 1],\n",
" [1, 0, 0, 0, 0, 0, 0, 1],\n",
" [1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int32)"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from typing import Optional, NamedTuple\n",
"from torch import BoolTensor, arange, meshgrid, clamp\n",
"import torch\n",
"\n",
"class Dimensions(NamedTuple):\n",
"    \"\"\"Extent of a 2-D grid: `height` rows by `width` columns.\"\"\"\n",
"\n",
"    height: int  # size along the first (row) axis\n",
"    width: int  # size along the second (column) axis\n",
"\n",
"def make_neighbourhood_mask(size: Dimensions, size_orig: Dimensions, device='cpu') -> BoolTensor:\n",
"    \"\"\"Build a boolean mask pairing each grid cell with a fixed-size window of cells.\n",
"\n",
"    For a query cell (i, j) on a height x width grid, the selected key cells form\n",
"    a size_orig window whose start lies size_orig // 2 cells before the query on\n",
"    each axis. A window that would overhang the grid is slid back inside it (its\n",
"    start is clamped to [0, extent - window]), so as long as the window fits in\n",
"    the grid every query keeps a full-size window.\n",
"\n",
"    NOTE: if size_orig exceeds size on an axis, the clamp bounds cross; see\n",
"    torch.clamp for how that case resolves.\n",
"\n",
"    Args:\n",
"        size: (height, width) of the grid.\n",
"        size_orig: (height, width) of the window.\n",
"        device: device on which the mask is allocated.\n",
"\n",
"    Returns:\n",
"        Bool tensor of shape (h*w, h*w). Row index = flattened query cell,\n",
"        column index = flattened key cell, both in row-major order.\n",
"    \"\"\"\n",
"    def axis_window(extent: int, window: int) -> BoolTensor:\n",
"        # (extent, extent) table; entry [q, k] is True when position k lies in\n",
"        # the window assigned to position q along this one axis.\n",
"        pos = arange(extent, device=device)\n",
"        first = clamp(pos - window // 2, 0, extent - window)\n",
"        past_last = first + window\n",
"        keys = pos[None, :]\n",
"        return (keys >= first[:, None]) & (keys < past_last[:, None])\n",
"\n",
"    h, w = size\n",
"    h_orig, w_orig = size_orig\n",
"    row_ok = axis_window(h, h_orig)\n",
"    col_ok = axis_window(w, w_orig)\n",
"\n",
"    # The window is an axis-aligned rectangle, so membership factorises into\n",
"    # independent row and column tests; broadcast them to (q_h, q_w, k_h, k_w).\n",
"    mask = row_ok[:, None, :, None] & col_ok[None, :, None, :]\n",
"    return mask.view(h * w, h * w)\n",
"\n",
"def make_perimeter_mask(size: Dimensions, canvas_edge: Optional[int] = None, device='cpu') -> BoolTensor:\n",
"    \"\"\"Flag every grid cell within `canvas_edge` cells of the grid's border.\n",
"\n",
"    Args:\n",
"        size: (height, width) of the grid.\n",
"        canvas_edge: thickness of the border band, in cells. None (the default)\n",
"            means no band, the same as 0.\n",
"        device: device on which the mask is allocated.\n",
"\n",
"    Returns:\n",
"        1-D bool tensor of length h*w, flattened in row-major order; True marks\n",
"        a border cell. OR-ing it into an (h*w, h*w) mask broadcasts along the\n",
"        last (key) axis, so each query's row gains the border keys.\n",
"    \"\"\"\n",
"    h, w = size\n",
"    # A tensor cannot be compared with None (TypeError), so the default must be\n",
"    # mapped to a number; 0 selects no cells.\n",
"    edge = 0 if canvas_edge is None else canvas_edge\n",
"\n",
"    # Column vector of row indices and row vector of column indices; the\n",
"    # comparisons below broadcast them to an (h, w) grid.\n",
"    h_range = arange(h, device=device).reshape(h, 1)\n",
"    w_range = arange(w, device=device).reshape(1, w)\n",
"\n",
"    mask: BoolTensor = (h_range < edge) | (h_range >= h - edge) | (w_range < edge) | (w_range >= w - edge)\n",
"\n",
"    return mask.flatten()\n",
"\n",
"torch.set_printoptions(threshold=10_000, linewidth=200)\n",
"\n",
"# 8x8 grid; each query gets a window around itself plus the outer border band.\n",
"spatial = Dimensions(8, 8)\n",
"pref = Dimensions(4, 4)\n",
"perimeter = 1\n",
"# Window = pref minus `perimeter` cells on each side.\n",
"pref_shaved = Dimensions(*(extent - 2 * perimeter for extent in pref))\n",
"attn_mask = make_neighbourhood_mask(spatial, pref_shaved) | make_perimeter_mask(spatial, perimeter)\n",
"\n",
"# Display the centre query's mask row, laid out as a height x width grid of keys.\n",
"centre = (spatial.height // 2, spatial.width // 2)\n",
"attn_mask.view(*spatial, *spatial)[centre].int()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "p311",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.0"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment