Skip to content

Instantly share code, notes, and snippets.

@wassname
Last active November 26, 2019 02:42
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save wassname/eec589d7e8dac6d5bc12ad8149e081b8 to your computer and use it in GitHub Desktop.
Save wassname/eec589d7e8dac6d5bc12ad8149e081b8 to your computer and use it in GitHub Desktop.
fastai binary class data block
from typing import Iterator, Collection
from fastai.data_block import CategoryListBase
from fastai.text import *
class BinaryProcessor(CategoryProcessor):
    """`CategoryProcessor` for labels that are already 0/1.

    The label -> index mapping is a fixed identity on {0, 1}. The generated
    class list always has a single entry, presumably so the model is built
    with a single output column (one logit) — TODO confirm against fastai.
    """

    def create_classes(self, classes):
        """Store `classes`; when they are given, install the identity map on {0, 1}."""
        self.classes = classes
        if classes is None:
            return
        self.c2i = {label: label for label in (0, 1)}

    def generate_classes(self, items):
        """Ignore `items` and always report the single class `[0]`."""
        return [0]
class BinaryCategoryList(CategoryListBase):
    """Basic `ItemList` for binary (0/1) classification labels.

    Labels go through `BinaryProcessor`, which maps 0 -> 0 and 1 -> 1 and
    generates a single class. The list also carries a `BCEWithLogitsFlat`
    loss for a single-logit model head.
    """
    _processor = BinaryProcessor

    def __init__(self, items: Iterator, classes: Collection = None, label_delim: str = None, **kwargs):
        # `label_delim` is accepted only for signature compatibility; it is unused.
        super().__init__(items, classes=classes, **kwargs)
        # Fraction of positive labels. Guard the empty case so we never take
        # the mean of an empty array (NaN plus a numpy warning). Cast to a
        # Python float so the weight tensor gets torch's default float dtype.
        mean = float(self.items.mean()) if len(self.items) else 0.0
        # `mean > 0` is False for 0 and for NaN, so degenerate label sets fall
        # back to an unweighted loss instead of dividing by zero.
        if mean > 0:
            # Stay on the GPU when one exists, but do not crash on CPU-only hosts.
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            weight = torch.tensor([1 / mean], device=device)
            print(f'Weighting BCEWithLogitsFlat by {weight.item()}')
        else:
            weight = None
        # NOTE(review): if BCEWithLogitsFlat forwards `weight` to
        # nn.BCEWithLogitsLoss, a single scalar only rescales the whole loss.
        # Class re-balancing would normally use `pos_weight`. Confirm intent.
        self.loss_func = BCEWithLogitsFlat(weight=weight)

    def reconstruct(self, t):
        """Wrap a one-element label tensor `t` in a `Category`."""
        return Category(t, self.c2i[t.item()])

    def get(self, i):
        """Return item `i` as a `Category`, or None if the stored label is None."""
        o = self.items[i]
        if o is None:
            return None
        return Category(o, self.c2i[o])

    def analyze_pred(self, pred, thresh: float = 0.5):
        """Convert a single-output prediction into a 0/1 label tensor.

        There is only one output column (see `BinaryProcessor.generate_classes`),
        so `argmax` would always return 0. Instead, compare against `thresh`.
        Assumes `pred` already holds the sigmoid probability of the positive
        class, not the raw logit — TODO confirm against the calling Learner.
        """
        return (pred >= thresh).long()
import torch
from fastai.metrics import auc_roc_score
def auc_roc_score_multi(input, targ):
    """Average ROC AUC over the columns of a multi-target prediction.

    Column `c` of `input` is scored against column `c` of `targ` with
    `auc_roc_score`. The per-column scores are then averaged. Both arguments
    must be indexable as `[:, c]`, and `targ` needs at least as many columns
    as `input`.
    """
    n_cols = input.shape[1]
    per_column = [auc_roc_score(input[:, col], targ[:, col]) for col in range(n_cols)]
    return torch.tensor(per_column).mean()
def fbeta_binary(y_pred, y_true, **args):
    """F-beta score for a single binary output, computed over the whole batch.

    fastai v1's `fbeta` reduces over dim=1, computing precision and recall per
    row, and then averages the rows. Giving it one item per row (the previous
    `[:, None]`) collapses the score to TP / N. If `y_pred` is already (N, 1),
    it also broadcasts against `y_true` to an (N, N, 1) tensor. Flattening both
    tensors into a single (1, N) row makes the TP, predicted-positive and
    actual-positive counts cover the whole batch, which is the binary F-beta.

    Extra keyword arguments (e.g. `thresh`, `beta`, `sigmoid`) are forwarded
    unchanged to `fbeta`.
    """
    return fbeta(y_pred.reshape(1, -1), y_true.reshape(1, -1), **args)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment