@umbra-scientia
Created August 29, 2021 01:01
import math
import numpy as np
import torch

def AttentionMask(encoder_len, state_len, decoder_len, offset=0, near_decay=0, far_decay=0, device='cpu'):
    # Build a signed relative-distance matrix for the decoder rows: after the
    # loop, m[d, c] is roughly (distance from decoder position d back to key
    # column c) minus offset, so nearby keys come out negative and distant
    # keys positive.
    m = -offset * np.tri(decoder_len, encoder_len + decoder_len + state_len, encoder_len)
    for i in range(encoder_len + decoder_len - 1):
        m += np.tri(decoder_len, encoder_len + decoder_len + state_len, encoder_len - i - 1)
    if state_len:
        # State rows get zero distance everywhere, i.e. no positional decay.
        ms = np.zeros((state_len, encoder_len + decoder_len + state_len))
        m = np.concatenate([m, ms], axis=0)
    m = torch.tensor(m, dtype=torch.float32, device=device)
    # mx marks positions to exclude entirely: decoder rows may not attend to
    # future decoder columns (causal mask), while state columns stay visible.
    mx = 1 - np.tri(decoder_len, encoder_len + decoder_len, encoder_len)
    mx = np.concatenate([mx, np.zeros((decoder_len, state_len))], axis=1)
    if state_len:
        # State rows may attend to encoder and state columns,
        # but not to decoder columns.
        msx = np.concatenate([
            np.zeros((state_len, encoder_len)),
            np.ones((state_len, decoder_len)),
            np.zeros((state_len, state_len))
        ], axis=1)
        mx = np.concatenate([mx, msx], axis=0)
    mx = torch.tensor(mx, device=device)
    # Turn the signed distance into an additive penalty: near_decay scales
    # keys nearer than offset, far_decay scales keys farther away.
    m = -(near_decay * torch.relu(-m) + far_decay * torch.relu(m))
    # Fully masked positions get -inf so softmax assigns them zero weight.
    m[mx.bool()] = -math.inf
    return m
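
A minimal usage sketch follows; it is not part of the original gist. The sequence lengths, decay values, and the way the mask is added to raw attention scores are assumptions about how this helper is meant to be used.

# Usage sketch (assumption, not from the original gist): apply the mask as an
# additive bias on raw attention scores before softmax.
encoder_len, state_len, decoder_len = 4, 2, 3
mask = AttentionMask(encoder_len, state_len, decoder_len,
                     offset=1, near_decay=0.1, far_decay=0.02)
# One row per decoder/state query, one column per key:
# (decoder_len + state_len, encoder_len + decoder_len + state_len)
print(mask.shape)  # torch.Size([5, 9])

# Hypothetical attention scores of matching shape.
scores = torch.randn(decoder_len + state_len,
                     encoder_len + decoder_len + state_len)
weights = torch.softmax(scores + mask, dim=-1)  # masked entries get weight 0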