Skip to content

Instantly share code, notes, and snippets.

@napoler
Last active October 22, 2021 02:09
Show Gist options
  • Save napoler/13917c8cee276439b768769f0b71d400 to your computer and use it in GitHub Desktop.
def unilm_mask(inputs, s):
    """Build a UniLM-style attention mask from segment ids.

    Positions whose cumulative segment sum is <= that of position i are
    visible to i. With 0/1 token_type_ids this makes the first segment
    fully bidirectional and the second segment causal over itself.

    Args:
        inputs: unused; kept only for signature compatibility with
            existing callers (they pass input_ids here).
        s: (batch, seq_len) integer segment-id tensor, e.g. token_type_ids.

    Returns:
        (batch, seq_len, seq_len) int64 tensor; 1 = may attend, 0 = masked.
    """
    idxs = torch.cumsum(s, dim=1)
    # Position i may attend to position k iff cumsum[k] <= cumsum[i].
    mask = idxs[:, None, :] <= idxs[:, :, None]
    # NOTE: the original `mask = mask[:, None].squeeze(1)` was a no-op
    # (unsqueeze axis 1, then immediately squeeze axis 1) and was removed.
    return mask.to(dtype=torch.int64)
def create_lm_mask(attention_mask, direction='l2r'):
    """Combine a padding mask with a causal (language-model) mask.

    Args:
        attention_mask: (batch, seq_len) tensor — presumably 1 for real
            tokens and 0 for padding (TODO confirm against callers); a
            3-D mask broadcastable to (batch, seq_len, seq_len) also works.
        direction: 'l2r' for left-to-right (each position sees itself and
            earlier positions), 'r2l' for right-to-left.

    Returns:
        (batch, seq_len, seq_len) float tensor; 1.0 = may attend, 0.0 = masked.

    Raises:
        ValueError: if `direction` is neither 'l2r' nor 'r2l'.
    """
    seq_len = attention_mask.size(-1)
    if attention_mask.ndim == 2:
        attention_mask = attention_mask.view(-1, 1, seq_len)
    # Match dtype/device of attention_mask so broadcasting works cleanly.
    idxs = torch.arange(0, seq_len).to(attention_mask)
    if direction == 'l2r':
        triu = (idxs.unsqueeze(-1) >= idxs).float()  # lower-triangular
    elif direction == 'r2l':
        triu = (idxs.unsqueeze(-1) <= idxs).float()  # upper-triangular
    else:
        # Original code fell through and crashed with NameError on `triu`;
        # fail explicitly instead.
        raise ValueError(f"direction must be 'l2r' or 'r2l', got {direction!r}")
    # A position is visible only if BOTH masks allow it (sum > 1).
    attention_mask = (attention_mask + triu > 1).float()
    return attention_mask
# Demo usage.
# NOTE(review): `inputs` is not defined in this snippet — presumably a
# HF-tokenizer-style dict with 'input_ids'/'token_type_ids'; verify upstream.
# NOTE(review): both assignments below SHADOW the function names, so the
# functions cannot be called again after this point.
unilm_mask= unilm_mask(inputs['input_ids'],inputs['token_type_ids'])
print("unilm_mask",unilm_mask)
# NOTE(review): passes input_ids where the function seems to expect a 0/1
# attention mask — looks suspicious; confirm intent with the author.
create_lm_mask=create_lm_mask(inputs['input_ids'])
print("create_lm_mask",create_lm_mask)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment