from transformers import *
from fastai.text.all import *
from blurr.data.all import *
from blurr.modeling.all import *
import torch.nn as nn
class BatchLossFilter(Callback):
    """Callback that selects the hardest samples in every batch, representing a percentage
    (`loss_perc`) of the total batch loss."""

    def __init__(self, loss_perc=1., schedule_func: Optional[callable] = None):
        store_attr()

    def before_fit(self):
        # skip this callback during prediction / validation-only runs
        self.run = not hasattr(self, "gather_preds")
        if not self.run: return
        self.crit = self.learn.loss_func
        # remember the loss function's original reduction so it can be restored later
        if hasattr(self.crit, 'reduction'): self.red = self.crit.reduction

    def before_batch(self):
        if not self.training or self.loss_perc == 1.: return
        with torch.no_grad():
            # compute per-sample losses by temporarily switching the reduction to 'none'
            if hasattr(self.crit, 'reduction'): setattr(self.crit, 'reduction', 'none')
            self.losses = self.crit(self.learn.model(self.x), self.y)
            if hasattr(self.crit, 'reduction'): setattr(self.crit, 'reduction', self.red)
        # normalize the losses and sort samples from hardest to easiest
        self.losses /= self.losses.sum()
        idxs = torch.argsort(self.losses, descending=True)
        if self.schedule_func is not None: loss_perc = self.loss_perc * self.schedule_func(self.pct_train)
        else: loss_perc = self.loss_perc
        # keep the hardest samples whose cumulative (normalized) loss stays within `loss_perc`
        cut_idx = torch.argmax((self.losses[idxs].cumsum(0) > loss_perc).float())
        if cut_idx == 0: cut_idx = 1  # always keep at least the hardest sample
        idxs = idxs[:cut_idx]
        self.learn.xb = tuple(xbi[idxs] for xbi in self.learn.xb)
        self.learn.yb = tuple(ybi[idxs] for ybi in self.learn.yb)

    def after_fit(self):
        # restore the loss function's original reduction
        if hasattr(self.learn.loss_func, 'reduction'): setattr(self.learn.loss_func, 'reduction', self.red)
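
# --- Hedged example (not in the original gist): using schedule_func ---
# BatchLossFilter multiplies loss_perc by schedule_func(pct_train) on every training batch,
# so a schedule can start by keeping nearly all of the batch loss and gradually focus on the
# hardest samples. Note the early return above: the callback only runs when the constructor's
# loss_perc is strictly below 1.0. The function below is an assumed sketch, not part of the
# original callback.
import math

def cosine_loss_perc_schedule(pct_train):
    # multiplier ramping from 1.0 at the start of training down to 0.5 at the end
    return 0.75 + 0.25 * math.cos(math.pi * pct_train)

# e.g. BatchLossFilter(loss_perc=0.8, schedule_func=cosine_loss_perc_schedule) would move the
# effective cutoff from 0.8 of the normalized batch loss at the start to 0.4 at the end.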
path = untar_data(URLs.IMDB_SAMPLE)
model_path = Path('models')
imdb_df = pd.read_csv(path/'texts.csv')
model_cls = AutoModelForSequenceClassification
pretrained_model_name = "distilbert-base-uncased"
hf_arch, hf_config, hf_tokenizer, hf_model = BLURR.get_hf_objects(pretrained_model_name, model_cls=model_cls)
blocks = (HF_TextBlock(hf_arch, hf_config, hf_tokenizer, hf_model), CategoryBlock)
dblock = DataBlock(blocks=blocks, get_x=ColReader('text'), get_y=ColReader('label'), splitter=ColSplitter())
dls = dblock.dataloaders(imdb_df, bs=4)
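
# Quick sanity check (assumed, not in the original gist): pull one batch to confirm the blurr
# text block produced tokenized inputs and encoded labels before building the Learner.
xb, yb = dls.one_batch()
print(type(xb), yb.shape)  # tokenized inputs handled by blurr, and a label tensor of shape (bs,)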
model = HF_BaseModelWrapper(hf_model)
learn = Learner(dls,
                model,
                loss_func=LabelSmoothingCrossEntropyFlat(),
                metrics=[accuracy],
                cbs=[HF_BaseModelCallback],
                splitter=hf_splitter).to_fp16()
learn.unfreeze()
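
# Optional (assumed, not in the original gist): fastai's LR finder can be run here to
# sanity-check that the lr_max passed to fit_one_cycle below is in a reasonable range.
# learn.lr_find()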
cbs = [
    BatchLossFilter(loss_perc=0.4)
]

learn.fit_one_cycle(
    3,
    lr_max=3e-5,
    cbs=cbs
)
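
# After training (a sketch; assumes the standard fastai prediction path works for the
# blurr-wrapped model): report validation loss/accuracy and collect predictions.
val_loss, val_acc = learn.validate()
preds, targs = learn.get_preds()
print(val_loss, val_acc, preds.shape, targs.shape)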