UniLM v1 attention matrix generation. Source post: https://terrychan.org/2021/10/unilm_mask-%e6%b3%a8%e6%84%8f%e5%8a%9b%e7%9f%a9%e9%98%b5%e7%94%9f%e6%88%90/
# `inputs` below is assumed to be a HuggingFace-style tokenizer encoding
# containing 'input_ids', 'token_type_ids', and 'attention_mask'.
import torch


def unilm_mask(inputs, s):
    # UniLM seq2seq mask built from token_type_ids `s`: position i may
    # attend to position j iff cumsum(s)[j] <= cumsum(s)[i], so source
    # tokens (segment 0) attend bidirectionally within the source, while
    # target tokens (segment 1) also see the source plus earlier targets.
    # Note: `inputs` is unused; it is kept to mirror the original call site.
    idxs = torch.cumsum(s, dim=1)
    mask = idxs[:, None, :] <= idxs[:, :, None]
    return mask.to(dtype=torch.int64)


def create_lm_mask(attention_mask, direction='l2r'):
    # Causal language-model mask combined with an existing padding mask.
    seq_len = attention_mask.size(-1)
    if attention_mask.ndim == 2:
        attention_mask = attention_mask.view(-1, 1, seq_len)
    idxs = torch.arange(0, seq_len).to(attention_mask)
    if direction == 'l2r':
        # Lower triangle: each position sees itself and earlier positions.
        triu = (idxs.unsqueeze(-1) >= idxs).float()
    elif direction == 'r2l':
        # Upper triangle: each position sees itself and later positions.
        triu = (idxs.unsqueeze(-1) <= idxs).float()
    else:
        raise ValueError("direction must be 'l2r' or 'r2l'")
    # Keep a position only where both the padding mask and triangle allow it.
    return (attention_mask + triu > 1).float()


u_mask = unilm_mask(inputs['input_ids'], inputs['token_type_ids'])
print("unilm_mask", u_mask)
# Pass the padding mask (not input_ids) so only real tokens are kept.
lm_mask = create_lm_mask(inputs['attention_mask'])
print("create_lm_mask", lm_mask)