@akkefa
Created December 20, 2022 19:20
Threshold a tensor into binary values using PyTorch (torch.where)
import torch
# Create a 3x3 tensor of random values drawn uniformly from [0, 1)
tensor = torch.rand(3, 3)
# Set a threshold value
threshold = 0.5
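# Use torch.where to threshold the tensor: elements strictly greater than the
# threshold become 1, all others become 0 (the result dtype is torch.int64)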
binary_tensor = torch.where(tensor > threshold, torch.tensor(1), torch.tensor(0))
# Print the original tensor and the thresholded tensor
print(tensor)
print(binary_tensor)
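Because the output of torch.rand is random, the snippet above prints different values on every run. The sketch below instead uses a fixed input so the result can be checked by hand. It also shows an alternative to torch.where: cast the boolean mask produced by the comparison directly to an integer or float dtype. The helper name `binarize` and its `dtype` parameter are hypothetical and used only for illustration. Note that the comparison is strict (`>`), so a value exactly equal to the threshold maps to 0.

```python
import torch


def binarize(tensor: torch.Tensor, threshold: float = 0.5,
             dtype: torch.dtype = torch.long) -> torch.Tensor:
    """Return 1 where tensor > threshold and 0 elsewhere, cast to `dtype`.

    `binarize` is a hypothetical helper used for illustration only.
    """
    # The comparison yields a bool tensor; casting maps True -> 1 and False -> 0.
    return (tensor > threshold).to(dtype)


# Fixed input so the expected result is easy to verify by hand.
x = torch.tensor([[0.2, 0.5, 0.7],
                  [0.9, 0.1, 0.5]])

# The torch.where form from the gist: zero-dim int64 tensors broadcast to x's shape.
via_where = torch.where(x > 0.5, torch.tensor(1), torch.tensor(0))

# The cast form gives the same int64 values: [[0, 0, 1], [1, 0, 0]].
via_cast = binarize(x, 0.5)

# A float mask, e.g. for multiplying with other float tensors.
float_mask = binarize(x, 0.5, torch.float32)

print(via_where)
print(via_cast)
print(float_mask)
```

The cast form avoids building the two branch tensors and lets you choose the output dtype directly. torch.where remains the more general tool when the "above" and "below" values are something other than 1 and 0.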