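# Hard-example mining for transformer fine-tuning with fastai + blurr:
# BatchLossFilter keeps only the samples in each batch that account for a
# given fraction of the total batch loss, then the script fine-tunes
# distilbert-base-uncased on fastai's IMDB_SAMPLE sentiment dataset.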
from transformers import *
from fastai.text.all import *
from blurr.data.all import *
from blurr.modeling.all import *
import torch.nn as nn
class BatchLossFilter(Callback):
    """Callback that keeps only the hardest samples in every batch: the ones
    that together account for `loss_perc` of the total batch loss."""
    def __init__(self, loss_perc=1., schedule_func: Optional[callable] = None):
        store_attr()

    def before_fit(self):
        # Don't run while gathering predictions (e.g. during learn.get_preds).
        self.run = not hasattr(self, "gather_preds")
        if not self.run: return
        self.crit = self.learn.loss_func
        if hasattr(self.crit, 'reduction'): self.red = self.crit.reduction

    def before_batch(self):
        if not self.training or self.loss_perc == 1.: return
        with torch.no_grad():
            # Temporarily switch to per-sample losses so the batch can be ranked.
            if hasattr(self.crit, 'reduction'): setattr(self.crit, 'reduction', 'none')
            self.losses = self.crit(self.learn.model(self.x), self.y)
            if hasattr(self.crit, 'reduction'): setattr(self.crit, 'reduction', self.red)
            self.losses /= self.losses.sum()
        idxs = torch.argsort(self.losses, descending=True)
        # Optionally anneal the kept fraction over training (pct_train is in [0, 1]).
        if self.schedule_func is not None: loss_perc = self.loss_perc * self.schedule_func(self.pct_train)
        else: loss_perc = self.loss_perc
        # Cut where the cumulative loss share first exceeds loss_perc; keep at
        # least one sample so the filtered batch is never empty.
        cut_idx = torch.argmax((self.losses[idxs].cumsum(0) > loss_perc).float())
        idxs = idxs[:max(int(cut_idx), 1)]
        self.learn.xb = tuple(xbi[idxs] for xbi in self.learn.xb)
        self.learn.yb = tuple(ybi[idxs] for ybi in self.learn.yb)

    def after_fit(self):
        # Restore the loss function's original reduction.
        if hasattr(self.learn.loss_func, 'reduction'): setattr(self.learn.loss_func, 'reduction', self.red)
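# Worked illustration of the cutoff (hypothetical numbers, not from the gist):
# with normalized per-sample losses sorted descending as [0.5, 0.3, 0.2], the
# cumulative sums are [0.5, 0.8, 1.0].
#   loss_perc=0.4 -> cumsum > 0.4 is [True, True, True],   argmax = 0,
#                    and the at-least-one guard keeps just the hardest sample.
#   loss_perc=0.9 -> cumsum > 0.9 is [False, False, True], argmax = 2,
#                    so the two hardest samples (80% of the loss) are kept.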
# Download the IMDB sample dataset and load it into a DataFrame.
path = untar_data(URLs.IMDB_SAMPLE)
model_path = Path('models')
imdb_df = pd.read_csv(path/'texts.csv')

# Grab the Hugging Face architecture, config, tokenizer, and model via blurr.
model_cls = AutoModelForSequenceClassification
pretrained_model_name = "distilbert-base-uncased"
hf_arch, hf_config, hf_tokenizer, hf_model = BLURR.get_hf_objects(pretrained_model_name, model_cls=model_cls)

# Build fastai DataLoaders; ColSplitter uses the `is_valid` column in texts.csv.
blocks = (HF_TextBlock(hf_arch, hf_config, hf_tokenizer, hf_model), CategoryBlock)
dblock = DataBlock(blocks=blocks, get_x=ColReader('text'), get_y=ColReader('label'), splitter=ColSplitter())
dls = dblock.dataloaders(imdb_df, bs=4)
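# Quick sanity check before training (plain fastai API; the batch structure
# itself comes from blurr's HF_TextBlock, so only split sizes are shown here):
print(f"{len(dls.train_ds)} train / {len(dls.valid_ds)} valid samples")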
# Wrap the HF model for fastai and build a mixed-precision Learner.
model = HF_BaseModelWrapper(hf_model)
learn = Learner(dls,
                model,
                loss_func=LabelSmoothingCrossEntropyFlat(),
                metrics=[accuracy],
                cbs=[HF_BaseModelCallback],
                splitter=hf_splitter).to_fp16()
# Train all layers, keeping only the samples that make up 40% of each batch's loss.
learn.unfreeze()
cbs = [BatchLossFilter(loss_perc=0.4)]
learn.fit_one_cycle(3, lr_max=3e-5, cbs=cbs)
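# After training, standard fastai evaluation applies:
learn.validate()  # returns [valid_loss, accuracy] on the validation set

# A hedged variant (hypothetical hyperparameters, not from the original gist):
# anneal the kept-loss fraction with fastai's SchedCos so early batches keep
# most samples and later batches focus on the hardest ones.
# learn.fit_one_cycle(3, lr_max=3e-5,
#                     cbs=[BatchLossFilter(loss_perc=0.8, schedule_func=SchedCos(1., 0.5))])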