Last active
November 26, 2019 02:42
-
-
Save wassname/eec589d7e8dac6d5bc12ad8149e081b8 to your computer and use it in GitHub Desktop.
fastai binary class data block
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Iterator, Collection | |
from fastai.data_block import CategoryListBase | |
from fastai.text import * | |
class BinaryProcessor(CategoryProcessor):
    """Category processor with a fixed label-to-index mapping for labels 0 and 1."""

    def create_classes(self, classes):
        """Store `classes` and, when it is not None, install the 0/1 identity mapping.

        When `classes` is None, `c2i` is left untouched.
        """
        self.classes = classes
        if classes is None:
            return
        # The mapping does not depend on the contents of `classes`.
        self.c2i = {label: label for label in (0, 1)}

    def generate_classes(self, items):
        """Return the constant class list `[0]`; `items` is not inspected."""
        binary_classes = [0]
        return binary_classes
class BinaryCategoryList(CategoryListBase):
    """`ItemList` of binary (0/1) labels trained with `BCEWithLogitsFlat`.

    On construction the loss is weighted by the inverse of the mean label,
    when that mean is non-zero.
    """
    _processor = BinaryProcessor

    def __init__(self, items: Iterator, classes: Collection = None, label_delim: str = None, **kwargs):
        """Build the list and attach a (possibly weighted) `BCEWithLogitsFlat` loss.

        Args:
            items: Labels; after the base-class constructor runs, `self.items`
                must support `len()` and `.mean()`.
            classes: Forwarded to the base class.
            label_delim: Accepted for signature compatibility; not used here.
            **kwargs: Forwarded to the base class.
        """
        super().__init__(items, classes=classes, **kwargs)

        weight = None
        # An empty split has no defined mean (it would be NaN, which is truthy
        # and would produce a NaN weight), so treat it as "no positives".
        mean = self.items.mean() if len(self.items) else 0
        if mean:
            # Follow the available hardware instead of hard-coding `.cuda()`,
            # which crashes on CPU-only machines.
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            # NOTE(review): `weight` rescales the whole loss; `pos_weight` is the
            # BCEWithLogitsLoss argument that up-weights only positive targets.
            # Confirm which one was intended.
            weight = torch.tensor([1 / mean], device=device)
            print(f'Weighting BCEWithLogitsFlat by {weight.item()}')
        # With no positives (or no items) the loss is simply left unweighted;
        # the previous debugging `raise` made such lists impossible to build.
        self.loss_func = BCEWithLogitsFlat(weight=weight)

    def reconstruct(self, t):
        """Wrap the single-element tensor `t` in a `Category`, looking it up in `self.c2i`."""
        # `self.c2i` is not set in this class; presumably it is populated by the
        # processor before this is called. Verify against fastai.
        return Category(t, self.c2i[t.item()])

    def get(self, i):
        """Return item `i` as a `Category`, or None when the stored item is None."""
        o = self.items[i]
        if o is None:
            return None
        return Category(o, self.c2i[o])

    def analyze_pred(self, pred, thresh: float = 0.5):
        """Turn a raw prediction into a class index tensor.

        A single-element `pred` is compared against `thresh` (1 when
        `pred >= thresh`, else 0). `argmax` over one element is always 0, so
        the original code could never predict the positive class in that case.
        Predictions with more than one element keep the original `argmax`.
        """
        # Assumes a single-element `pred` is a probability-like score on the
        # same scale as `thresh`. TODO confirm against the caller.
        if pred.numel() == 1:
            return (pred.reshape(()) >= thresh).long()
        return pred.argmax()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from fastai.metrics import auc_roc_score | |
def auc_roc_score_multi(input, targ):
    """Mean of `auc_roc_score` taken column by column.

    Column `i` of `input` is scored against column `i` of `targ`; the number
    of columns is read from `input.shape[1]`. The per-column scores are
    collected into a tensor and averaged.
    """
    per_column = []
    for col in range(input.shape[1]):
        per_column.append(auc_roc_score(input[:, col], targ[:, col]))
    return torch.tensor(per_column).mean()
def fbeta_binary(y_pred, y_true, **args):
    """Call fastai's `fbeta` on predictions and targets with an extra axis added.

    Both `y_pred` and `y_true` are indexed with `[:, None]`, which inserts a
    new axis at position 1, before being passed to `fbeta`. Any extra keyword
    arguments are forwarded to `fbeta` unchanged.
    """
    # Imported here because the surrounding imports only bring in
    # `auc_roc_score`; without this the call fails with NameError.
    from fastai.metrics import fbeta

    return fbeta(y_pred[:, None], y_true[:, None], **args)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment