@prateekjoshi565
Created July 18, 2020 10:43
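import torch
import torch.nn as nn

# convert the list of class weights to a float tensor
# (class_weights and device are assumed to be defined earlier)
weights = torch.tensor(class_weights, dtype=torch.float)
# move the weights to the same device as the model (e.g. GPU)
weights = weights.to(device)
# define the loss function; NLLLoss expects log-probabilities as input,
# so the model's final layer should be a LogSoftmax
cross_entropy = nn.NLLLoss(weight=weights)
# number of training epochs
epochs = 10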
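
The snippet does not show where `class_weights` comes from or what the weighted loss computes. The following is a minimal pure-Python sketch under two assumptions: that the weights follow the common "balanced" heuristic, n_samples / (n_classes * class_count), and that the loss uses the default mean reduction, which divides by the sum of the target weights. The helper names `balanced_class_weights` and `weighted_nll` are hypothetical and only illustrate the arithmetic.

```python
import math
from collections import Counter

def balanced_class_weights(labels):
    """Weight each class by n_samples / (n_classes * class_count).

    Assumes labels are integers 0..K-1, so list position equals class id.
    """
    counts = Counter(labels)
    n_samples, n_classes = len(labels), len(counts)
    return [n_samples / (n_classes * counts[c]) for c in sorted(counts)]

def weighted_nll(log_probs, targets, weights):
    """Weighted mean of -log p(target), normalised by the sum of target weights."""
    total = sum(-weights[t] * lp[t] for lp, t in zip(log_probs, targets))
    norm = sum(weights[t] for t in targets)
    return total / norm

# toy data: three samples of class 0, one of class 1
labels = [0, 0, 0, 1]
class_weights = balanced_class_weights(labels)

# a model that predicts 50/50 for every sample (log-probabilities)
log_probs = [[math.log(0.5), math.log(0.5)] for _ in labels]
loss = weighted_nll(log_probs, labels, class_weights)
```

Here the rare class gets a weight three times that of the frequent class, so mistakes on it contribute more to the loss. A list built this way could be passed to `torch.tensor` as in the snippet above.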