Skip to content

Instantly share code, notes, and snippets.

@cstorm125
Created December 26, 2018 09:30
Show Gist options
  • Save cstorm125/c4096b32512678269e3170b2b97cbd77 to your computer and use it in GitHub Desktop.
Save cstorm125/c4096b32512678269e3170b2b97cbd77 to your computer and use it in GitHub Desktop.
class weights for weighted loss / weighted sampler
def get_class_weights(label_freq, mu=1., return_log=False):
    """Compute inverse-frequency class weights, floored at 1.0.

    Each label's raw score is ``total / frequency``, where ``total`` is the
    sum of all frequencies in ``label_freq``.

    Args:
        label_freq: Mapping from class label to its frequency (count).
            Frequencies are assumed to be positive; zero counts are not
            specially handled.
        mu: Multiplier applied to the raw score before taking the natural
            log. Only used when ``return_log`` is true.
        return_log: If true, return log-smoothed weights
            ``log(mu * total / frequency)``; otherwise return the plain
            inverse-frequency weights.

    Returns:
        A dict mapping every label in ``label_freq`` to its weight, rounded
        to two decimals. Any weight that is not greater than 1.0 is
        replaced by 1.0.
    """
    total = np.sum(list(label_freq.values()))
    weights = {}
    for key, freq in label_freq.items():
        score = total / float(freq)
        if return_log:
            # Take the log only when it was asked for, so a non-positive
            # ``mu`` cannot trigger numpy warnings in the plain mode.
            score = np.log(mu * score)
        # Never let a class weigh less than 1.0.
        weights[key] = round(score, 2) if score > 1.0 else 1.0
    return weights
def get_sample_weights(labels, class_weights, splitter=' '):
    """Return one weight per sample: the largest weight among its classes.

    Args:
        labels: Iterable of per-sample labels. Each label is converted with
            ``str()`` and split on ``splitter`` to obtain its class names.
        class_weights: Mapping from class name (as a string token) to its
            weight. A token missing from this mapping raises ``KeyError``.
        splitter: Separator between class names inside one label.

    Returns:
        A list with, for each entry of ``labels`` in order, the maximum
        class weight over that entry's tokens.
    """
    def heaviest(label):
        # A multi-label sample is weighted by its heaviest class.
        tokens = str(label).split(splitter)
        return np.max([class_weights[token] for token in tokens])

    return [heaviest(label) for label in labels]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment