@megha444
Last active October 10, 2020 06:59
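import numpy as np
import torch
import torch.nn as nn
from sklearn.utils.class_weight import compute_class_weight

# compute class weights ('balanced': n_samples / (n_classes * samples per class))
# train_labels: 1-D array of integer training labels, assumed to be defined earlier
# recent scikit-learn versions require classes= and y= as keyword arguments
classw = compute_class_weight(class_weight='balanced', classes=np.unique(train_labels), y=train_labels)
print("Class Weights:", classw)
# Output obtained: [0.57743559 3.72848948]
# convert the array of class weights to a float tensor
weight = torch.tensor(classw, dtype=torch.float)
# push to the GPU if one is available, otherwise stay on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
weight = weight.to(device)
# define the loss function
# NLLLoss expects log-probabilities, so the model should end in LogSoftmax;
# for raw logits, nn.CrossEntropyLoss(weight=weight) is the equivalent choice
crossentropy = nn.NLLLoss(weight=weight)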
# number of training epochs
epochs = 10
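The 'balanced' option in scikit-learn gives each class the weight n_samples / (n_classes * count_of_class), which is why the rarer class receives the larger weight (about 3.73 in the output above). The sketch below checks that formula by hand on a small, hypothetical imbalanced label array (`toy_labels` is illustrative, not the author's data) and compares it with `compute_class_weight`.

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Hypothetical toy labels: 6 samples of class 0, 2 samples of class 1
toy_labels = np.array([0, 0, 0, 0, 0, 0, 1, 1])

# Manual 'balanced' weights: n_samples / (n_classes * count of each class)
classes = np.unique(toy_labels)
counts = np.bincount(toy_labels)
manual = len(toy_labels) / (len(classes) * counts)

# scikit-learn's helper, called with keyword arguments
sk = compute_class_weight(class_weight="balanced", classes=classes, y=toy_labels)

print("manual: ", manual)
print("sklearn:", sk)
# Each class now carries the same total weight (count * weight)
print("total weight per class:", counts * manual)
```

Here the weights come out to about 0.667 and 2.0, so 6 × 0.667 and 2 × 2.0 both equal 4: after weighting, both classes contribute the same total weight to the loss.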
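Because `nn.NLLLoss` works on log-probabilities, it gives the same result as `nn.CrossEntropyLoss` applied to raw logits, since the latter combines LogSoftmax and NLLLoss. With a `weight` tensor and the default `reduction='mean'`, PyTorch divides the weighted sum by the sum of the target weights in the batch, not by the batch size. The sketch below checks both points using the class weights reported in the output above; the random logits and the `targets` batch are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Class weights as reported in the snippet's output
weight = torch.tensor([0.57743559, 3.72848948], dtype=torch.float)

# Hypothetical batch: raw scores (logits) for 4 samples and 2 classes
logits = torch.randn(4, 2)
targets = torch.tensor([0, 0, 1, 0])

# NLLLoss applied to log-probabilities ...
log_probs = F.log_softmax(logits, dim=1)
nll = nn.NLLLoss(weight=weight)(log_probs, targets)
# ... matches CrossEntropyLoss applied to raw logits
ce = nn.CrossEntropyLoss(weight=weight)(logits, targets)

# Manual weighted mean: sum(w[y_i] * -log p(y_i)) / sum(w[y_i])
picked = log_probs[torch.arange(len(targets)), targets]
w = weight[targets]
manual = (w * -picked).sum() / w.sum()

print(nll.item(), ce.item(), manual.item())
```

With these weights, each class-1 sample counts roughly 6.5 times (3.728 / 0.577) as much as a class-0 sample in the loss.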