code for self-attention: a toy, parameter-free self-attention over the pixels of a test image, with visualization of per-pixel attention maps
# -*- coding: utf-8 -*-
"""
Code for self-attention.
See Self-Attention GAN and Non-local Neural Networks.
@author: yl
"""
from boxx import *        # boxx convenience namespace: show, norma, sda, ylsys, ...
ylsys.usecuda = False     # force CPU
from boxx.ylth import *   # torch helpers: th (torch alias), tht (ndarray -> tensor)

eps = 1e-10
size = lambda t: reduce(mul, t.size())  # total number of elements (unused below)

def flatten(t, dim=-1):
    """Merge dimension `dim` into the one before it, e.g. (C, H, W) -> (C, H*W)."""
    shape = list(t.shape)
    shape[dim - 1] *= shape[dim]
    shape.pop(dim)
    return t.reshape(tuple(shape))

def v2t(v, shape=None):
    """Inverse of flatten: reshape the last dimension back to a square map."""
    if shape is None:
        shape = v.shape[:-1] + (int(v.shape[-1] ** .5),) * 2
    return v.reshape(shape)

img = sda.astronaut()                  # 512x512x3 test image from skimage.data
arr = norma(img)[::10, ::10]           # normalize to [0, 1], downsample to 52x52
t = tht(arr).permute(2, 0, 1).float()  # HWC ndarray -> CHW tensor
t.requires_grad = True
v = flatten(t)                         # (3, 2704): one 3-channel vector per pixel

# Pairwise pixel features via two rank-1 products:
# cols[c, i, j] = v[c, i] and rows[c, i, j] = v[c, j].
cols = th.matmul(v[..., None], th.ones(v.shape[-1])[..., None, :])
rows = th.matmul(th.ones(v.shape[-1])[..., None], v[..., None, :])
distance = ((rows - cols) ** 2).mean(-3) ** .5  # RMS over channels; symmetric, so one triangle would suffice
expt = th.exp(-distance)
# expt = th.exp(th.exp(expt))  # disabled experiment: sharpen the kernel
attention = expt / (expt.sum(-1)[..., None] + eps)  # row-normalize: softmax over -distance
newv = (attention * v[..., None, :]).sum(-1)  # each pixel becomes a weighted mix of all pixels
newt = v2t(newv)

# Visualize the attention map of a single pixel alongside the source image;
# the `None and ...` trick keeps the distance view switched off.
pix = 10
show(v2t(attention[pix]), img, None and v2t(-distance[pix]))
pix = 25 * 52  # pixel at row 25, column 0 of the 52x52 grid
show(v2t(attention[pix]), img, None and v2t(-distance[pix]))
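
For reference, the attention computed above is a fixed similarity kernel, attention[i, j] = exp(-d(i, j)) / sum_k exp(-d(i, k)), where d is the RMS channel distance between pixels i and j. The Self-Attention GAN and non-local papers cited in the docstring instead learn the similarity with 1x1 query/key/value convolutions. Below is a minimal sketch of that learned variant in plain PyTorch; the class name, reduction factor, and test shapes are illustrative assumptions, not part of this gist.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention2d(nn.Module):
    """SAGAN-style self-attention block (sketch; names are assumptions)."""
    def __init__(self, in_ch, reduction=8):
        super().__init__()
        self.query = nn.Conv2d(in_ch, in_ch // reduction, 1)
        self.key = nn.Conv2d(in_ch, in_ch // reduction, 1)
        self.value = nn.Conv2d(in_ch, in_ch, 1)
        self.gamma = nn.Parameter(torch.zeros(1))  # output gate, starts at 0

    def forward(self, x):
        b, c, h, w = x.shape
        q = self.query(x).flatten(2).transpose(1, 2)  # (B, HW, C')
        k = self.key(x).flatten(2)                    # (B, C', HW)
        v = self.value(x).flatten(2)                  # (B, C, HW)
        attn = F.softmax(q @ k, dim=-1)               # (B, HW, HW), rows sum to 1
        out = (v @ attn.transpose(1, 2)).view(b, c, h, w)
        return self.gamma * out + x                   # residual connection

x = torch.randn(1, 64, 52, 52)
y = SelfAttention2d(64)(x)  # same shape as x

Initializing gamma to zero makes the block the identity mapping at the start of training, which is how SAGAN keeps it stable before the attention weights become useful.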