# Attention logits: multiply each query against the keys with dims 2 and 3
# swapped, then divide by sqrt(head_size).
# NOTE(review): assumes query/keys are 4-D with the sequence and per-head
# feature axes on dims 2 and 3 -- confirm against the caller.
scores = query @ keys.transpose(2, 3)
scores = scores / math.sqrt(self.head_size)

# Normalise over the last axis. The softmax runs on a float32 copy
# (presumably for numerical stability when query is half precision -- confirm)
# and the result is cast back to query's tensor type.
scores = F.softmax(scores.float(), dim=-1)
scores = scores.type_as(query)
# Source: model.py (snippet originally embedded as a GitHub gist)