@pyaf
Last active February 2, 2018 16:11
Loss function for MURA model
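Written out for a single phase x, with A_x = tai[x] and N_x = tni[x], these are the class weights and the per-element loss that the code computes. Here p is a predicted probability (`inputs`) and y is its 0/1 label (`targets`):

```latex
w_{x,1} = \frac{N_x}{N_x + A_x}, \qquad
w_{x,0} = \frac{A_x}{N_x + A_x}

\mathcal{L}_x(p, y) = -\bigl( w_{x,1}\, y \log p + w_{x,0}\,(1 - y) \log(1 - p) \bigr)
```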
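import torch

# tai = total abnormal images, tni = total normal images
# study_data[x] is a study level dataframe
# get_count, n_p, study_data and data_cat are defined elsewhere in the project.
tai = {x: get_count(study_data[x], 'positive') for x in data_cat}
tni = {x: get_count(study_data[x], 'negative') for x in data_cat}
# Each class is weighted by the frequency of the other, so the rarer class counts more.
Wt1 = {x: n_p(tni[x] / (tni[x] + tai[x])) for x in data_cat}
Wt0 = {x: n_p(tai[x] / (tni[x] + tai[x])) for x in data_cat}


class Loss(torch.nn.Module):
    """Class-weighted binary cross-entropy with separate weights per phase."""

    def __init__(self, Wt1, Wt0):
        super().__init__()
        self.Wt1 = Wt1
        self.Wt0 = Wt0

    def forward(self, inputs, targets, phase):
        # inputs: predicted probabilities in (0, 1); targets: 0/1 labels.
        loss = -(self.Wt1[phase] * targets * inputs.log()
                 + self.Wt0[phase] * (1 - targets) * (1 - inputs).log())
        # Elementwise loss: reduce it (e.g. loss.mean()) before backward() if needed.
        return loss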
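The following is a dependency-free sketch of the same arithmetic, using plain floats and `math.log` in place of tensors. The per-phase counts are made-up numbers for illustration (not MURA statistics), and `weighted_bce` is a hypothetical helper that scores one prediction at a time rather than a tensor.

```python
import math

# Hypothetical per-phase counts, for illustration only (not MURA statistics).
data_cat = ["train", "valid"]
tai = {"train": 300, "valid": 40}  # abnormal (positive)
tni = {"train": 700, "valid": 60}  # normal (negative)

# Same weighting as the gist: abnormal examples are scaled by the fraction
# of normal images and vice versa, so the rarer class gets the larger weight.
Wt1 = {x: tni[x] / (tni[x] + tai[x]) for x in data_cat}
Wt0 = {x: tai[x] / (tni[x] + tai[x]) for x in data_cat}


def weighted_bce(p, y, phase):
    """Weighted binary cross-entropy for one predicted probability p in (0, 1)
    and one label y in {0, 1}, using the weights for the given phase."""
    return -(Wt1[phase] * y * math.log(p)
             + Wt0[phase] * (1 - y) * math.log(1 - p))


# Equally confident mistakes on each class: missing an abnormal study costs
# more than flagging a normal one, because abnormal examples are rarer here.
print(round(weighted_bce(0.1, 1, "train"), 4))  # 1.6118
print(round(weighted_bce(0.9, 0, "train"), 4))  # 0.6908
```

With 30% abnormal examples in the hypothetical training counts, Wt1 is 0.7 and Wt0 is 0.3, so the same-sized error costs more than twice as much on the minority class.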