
@napoler
Last active October 22, 2021 02:00
torch.where example (created with Copy to Gist)
>>> import torch
>>> x = torch.randn(3, 2)
>>> y = torch.ones(3, 2)
>>> x
tensor([[-0.4620,  0.3139],
        [ 0.3898, -0.7197],
        [ 0.0478, -0.1657]])
>>> torch.where(x > 0, x, y)
tensor([[ 1.0000,  0.3139],
        [ 0.3898,  1.0000],
        [ 0.0478,  1.0000]])
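Because `x` above comes from `torch.randn`, the printed values change on every run. The sketch below reuses the printed values as fixed inputs so the result is reproducible. It also spells out the elementwise rule `torch.where(condition, input, other)` follows: wherever the condition is True it takes `input`, and everywhere else it takes `other`. The mask-indexing comparison is only an illustration of that rule and assumes a recent PyTorch release.

```python
import torch

# Fixed values copied from the printed x above, so the result is reproducible.
x = torch.tensor([[-0.4620, 0.3139],
                  [0.3898, -0.7197],
                  [0.0478, -0.1657]])
y = torch.ones(3, 2)

# Elementwise: take x where the condition holds, otherwise the matching y.
cond = x > 0
result = torch.where(cond, x, y)

# The same rule written out with boolean-mask indexing, for comparison.
manual = y.clone()
manual[cond] = x[cond]

print(result)
```

Every non-positive entry of `x` is replaced by the `1.0` from `y`, and the positive entries pass through unchanged, which matches the transcript output.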
>>> x = torch.randn(2, 2, dtype=torch.double)
>>> x
tensor([[ 1.0779,  0.0383],
        [-0.8785, -1.1089]], dtype=torch.float64)
>>> torch.where(x > 0, x, 0.)
tensor([[1.0779, 0.0383],
        [0.0000, 0.0000]], dtype=torch.float64)
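In the second example, `other` is the Python scalar `0.`. That scalar is broadcast to the shape of `x`, and the result keeps the `float64` dtype of `x`. The sketch below is a reproducible version using the printed values. It assumes a recent PyTorch release that accepts scalars in `torch.where`. The comparison with `torch.relu` is an added observation rather than part of the original: keeping positives and zeroing everything else gives the same values as a ReLU for these inputs.

```python
import torch

# Fixed double-precision values copied from the printed x above.
x = torch.tensor([[1.0779, 0.0383],
                  [-0.8785, -1.1089]], dtype=torch.double)

# A Python scalar for `other` is broadcast to x's shape; dtype stays float64.
result = torch.where(x > 0, x, 0.0)

# Keeping positives and zeroing the rest matches ReLU on these values.
relu = torch.relu(x)

print(result)
```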