# -*- coding: utf-8 -*-
"""
Code for self-attention over image pixels.
See Self-Attention GAN (SAGAN) and Non-local Neural Networks.
@author: yl
"""
from boxx import *
ylsys.usecuda = False  # force CPU
from boxx.ylth import *  # torch helpers: exposes torch as both `th` and `torch`, plus tht, etc.
from functools import reduce
from operator import mul

eps = 1e-10
size = lambda t: reduce(mul, t.size())  # total number of elements in a tensor
def flatten(t, dim=-1):
    # Merge dimension `dim` into the one before it, e.g. (C, H, W) -> (C, H*W)
    shape = list(t.shape)
    shape[dim - 1] *= shape[dim]
    shape.pop(dim)
    return t.reshape(tuple(shape))
def v2t(v, shape=None):
    # Inverse of flatten: reshape a flat pixel vector back to a square map
    if shape is None:
        shape = v.shape[:-1] + (int(v.shape[-1] ** .5),) * 2
    return v.reshape(shape)
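
# A quick self-check of flatten/v2t on dummy tensors (an added sketch, not
# part of the original gist): flatten merges the last two dims, and v2t
# restores a square spatial map.
_d = th.arange(24.).reshape(2, 3, 4)
assert flatten(_d).shape == (2, 12)
_s = th.arange(18.).reshape(2, 3, 3)
assert th.equal(v2t(flatten(_s)), _s)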
img = sda.astronaut()  # skimage's 512x512x3 astronaut test image
arr = norma(img)[::10, ::10]  # normalize to [0, 1], downsample to 52x52
t = tht(arr).permute(2, 0, 1).float()  # to tensor, HWC -> CHW
t.requires_grad = True
v = flatten(t)  # (C, N) with N = H*W: each pixel becomes one column
cols = th.matmul(v[..., None], torch.ones(v.shape[-1])[..., None, :])  # cols[c, i, j] = v[c, i]
rows = th.matmul(torch.ones(v.shape[-1])[..., None], v[..., None, :])  # rows[c, i, j] = v[c, j]
distance = ((rows - cols) ** 2).mean(-3) ** .5  # (N, N) RMS color distance; symmetric, so only one triangle really needs computing
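
# The (C, N, N) cols/rows temporaries are memory-hungry. Assuming PyTorch
# >= 1.1, torch.cdist yields an equivalent pairwise distance directly: the
# Euclidean distance over the C channels differs from the RMS distance above
# only by a constant factor of sqrt(C).
distance_alt = th.cdist(v.t(), v.t()) / v.shape[0] ** .5
assert th.allclose(distance, distance_alt, atol=1e-4)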
expt = th.exp(-distance)  # similarity: nearby colors get larger weights
#expt = th.exp(th.exp(expt))  # sharper alternative weighting (disabled)
attention = expt / (expt.sum(-1)[..., None] + eps)  # row-normalize so each row sums to 1
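
# Up to the eps guard, this normalization is just a softmax over each row of
# the negated distance matrix; a hedged one-line equivalent:
attention_alt = th.softmax(-distance, dim=-1)
assert th.allclose(attention, attention_alt, atol=1e-6)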
newv = (attention * v[..., None, :]).sum(-1)  # each output pixel is an attention-weighted mean of all pixels
newt = v2t(newv)  # back to (C, H, W)
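
# Sanity checks (added): every attention row is a probability distribution
# over pixels, and the attended output keeps the input tensor's shape.
assert th.allclose(attention.sum(-1), th.ones(attention.shape[-1]), atol=1e-4)
assert newt.shape == t.shape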
pix = 10  # index of a query pixel in the flattened 52x52 grid
show(v2t(attention[pix]), img, None and v2t(-distance[pix]))  # `None and ...` hides the distance map; replace None with 1 to show it
pix = 25 * 52  # first pixel of row 25
show(v2t(attention[pix]), img, None and v2t(-distance[pix]))
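
# A hedged usage example: visualize the attended result next to the input
# (show already accepts torch tensors, as the calls above demonstrate).
show(newt, t)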