@cbaziotis
Created January 23, 2017 11:55
from collections import Counter


def get_class_weights(y, smooth_factor=0):
    """
    Returns the weight for each class, based on the class frequencies in y.
    The most frequent class gets weight 1.0; rarer classes get larger weights.
    :param y: list of true labels (the labels must be hashable)
    :param smooth_factor: factor that smooths extremely uneven weights
    :return: dictionary with the weight for each class
    """
    counter = Counter(y)

    if smooth_factor > 0:
        # Add a fraction of the largest class count to every class
        p = max(counter.values()) * smooth_factor
        for k in counter.keys():
            counter[k] += p

    majority = max(counter.values())

    # float() on the numerator avoids integer division under Python 2
    return {cls: float(majority) / count for cls, count in counter.items()}
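A short usage sketch may make the effect of `smooth_factor` easier to see. The function is restated in compact form so the example runs on its own; the label list `y` and the value `0.5` are made up for illustration and do not come from the original gist.

```python
from collections import Counter


def get_class_weights(y, smooth_factor=0):
    # Compact restatement of the function above, for a self-contained example
    counter = Counter(y)
    if smooth_factor > 0:
        p = max(counter.values()) * smooth_factor
        for k in counter.keys():
            counter[k] += p
    majority = max(counter.values())
    return {cls: float(majority) / count for cls, count in counter.items()}


# Hypothetical imbalanced labels: 6 x "a", 3 x "b", 1 x "c"
y = ["a"] * 6 + ["b"] * 3 + ["c"] * 1

# Without smoothing, each weight is majority_count / class_count
plain = get_class_weights(y)
print(plain)     # a: 6/6 = 1.0, b: 6/3 = 2.0, c: 6/1 = 6.0

# With smooth_factor=0.5, p = 6 * 0.5 = 3 is added to every count,
# so the counts become 9, 6, 4 and the weights move closer together
smoothed = get_class_weights(y, smooth_factor=0.5)
print(smoothed)  # a: 9/9 = 1.0, b: 9/6 = 1.5, c: 9/4 = 2.25
```

In this sketch the rarest class drops from a weight of 6.0 to 2.25, which shows how a larger `smooth_factor` pulls the weights toward 1.0 while keeping their order.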