output = torch.matmul(scores, values)
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
return self.wo(output)
view raw model.py hosted with ❤ by GitHub