@movefast
Last active December 12, 2018 10:29
A weighted random sampler for the fastai (0.7) library that draws training examples in proportion to their loss on the previous pass, so harder examples are sampled more often. Includes a loss-sorted SortishSampler variant, an ImageClassifierData subclass that wires the sampler into the training DataLoader, and a patch to fastai/model.py that surfaces per-example raw losses to the samplers' on_batch_end callbacks.
import numpy as np
import torch
from torch.utils.data.sampler import Sampler, RandomSampler
# fastai 0.7-era helpers (assumed module layout for that release)
from fastai.core import to_gpu, to_np
from fastai.sgdr import Callback
from fastai.dataset import ImageData, ArraysIndexDataset
from fastai.dataloader import DataLoader
class WeightedLossSampler(Sampler, Callback):
    """Draws training indices with probability proportional to each example's last recorded loss."""

    def __init__(self, data_source, replacement=True, bs=64, init_fac=1, ls_init_fac=1e-2):
        self.num_samples = len(data_source)
        # Start uniform; weights are overwritten with raw per-example losses as batches finish.
        self.weights = to_gpu(torch.ones(self.num_samples)*init_fac)
        self.replacement = replacement
        self.i = 0
        self.bs = bs
        self.ls_param = 0  # additive smoothing so easy (low-loss) examples keep some probability
        self.init_fac, self.ls_init_fac = init_fac, ls_init_fac

    def __iter__(self):
        self.idxes = to_gpu(torch.multinomial(self.weights.add(self.ls_param),
                                              self.num_samples, self.replacement))
        return iter(self.idxes)

    def __len__(self):
        return self.num_samples

    def set_weights_per_batch(self, wts):
        # Record the raw losses of the batch that just finished against the indices it drew.
        end = min(self.i+self.bs, self.num_samples)
        self.weights[self.idxes[self.i:end]] = wts
        self.i += self.bs

    def on_epoch_end(self, metrics):
        if not hasattr(self, 'prev_loss'):
            self.prev_loss = metrics[0][0]
            self.ls_param = self.prev_loss * self.ls_init_fac
        else:
            cur_loss = metrics[0][0]
            # Assume a normal learning curve: scale the smoothing term
            # by the relative change in validation loss.
            ls_fac = np.exp((cur_loss - self.prev_loss) / self.prev_loss)
            self.ls_param = self.ls_param * ls_fac
            self.prev_loss = cur_loss
        self.i = 0

    def on_batch_end(self, raw_losses):
        self.set_weights_per_batch(raw_losses.data)
class SortishSampler(Sampler, Callback):
    """Shuffles, then sorts indices by loss within chunks of 50 batches, so each batch groups similarly hard examples."""

    def __init__(self, data_source, bs):
        self.data_source,self.bs = data_source,bs
        self.i = 0
        self.num_samples = len(self.data_source)
        self.weights = to_gpu(torch.ones(self.num_samples))

    def __len__(self): return len(self.data_source)

    def __iter__(self):
        idxs = np.random.permutation(len(self.data_source))
        sz = self.bs*50
        ck_idx = [idxs[i:i+sz] for i in range(0, len(idxs), sz)]
        weights = to_np(self.weights)
        # Within each chunk, order indices by descending stored loss.
        sort_idx = sum([sorted(s, key=lambda x: weights[x], reverse=True) for s in ck_idx], [])
        self.idxes = torch.LongTensor(np.array(sort_idx)).cuda()
        return iter(self.idxes)

    def set_weights_per_batch(self, wts):
        end = min(self.i+self.bs, self.num_samples)
        self.weights[self.idxes[self.i:end]] = wts
        if end == self.num_samples:
            self.i = 0
        else:
            self.i += self.bs

    def on_batch_end(self, raw_losses):
        self.set_weights_per_batch(raw_losses.data)
class ImageClassifierData(ImageData):
    def __init__(self, path, datasets, bs, num_workers, classes):
        trn_ds,val_ds,fix_ds,aug_ds,test_ds,test_aug_ds = datasets
        self.path,self.bs,self.num_workers,self.classes = path,bs,num_workers,classes
        # Pick one sampler for the training loader:
        # self.our_sampler = RandomSampler(trn_ds)
        self.our_sampler = WeightedLossSampler(trn_ds, replacement=True, bs=64)
        # self.our_sampler = SortishSampler(trn_ds, bs=64)
        self.trn_dl = self.get_dl(trn_ds, False, self.our_sampler)
        self.val_dl, self.fix_dl,self.aug_dl,self.test_dl,self.test_aug_dl = [
            self.get_dl(ds,shuf) for ds,shuf in [
                (val_ds,False),(fix_ds,False),(aug_ds,False),
                (test_ds,False),(test_aug_ds,False)
            ]
        ]

    @classmethod
    def from_arrays(cls, path, trn, val, bs=64, tfms=(None,None), classes=None, num_workers=4, test=None):
        datasets = cls.get_ds(ArraysIndexDataset, trn, val, tfms, test=test)
        return cls(path, datasets, bs, num_workers, classes=classes)

    def get_dl(self, ds, shuffle, sampler=None):
        if ds is None: return None
        # shuffle must stay False whenever an explicit sampler is passed
        return DataLoader(ds, batch_size=self.bs, shuffle=shuffle,
            num_workers=self.num_workers, pin_memory=False, sampler=sampler)
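
The multinomial draw in WeightedLossSampler.__iter__ samples each index in proportion to its stored loss plus the smoothing term ls_param, which on_epoch_end anneals as training progresses. A minimal self-contained sketch of that effect, in plain PyTorch with made-up loss values:

import torch

# Made-up per-example losses: one hard example among three easy ones.
losses = torch.tensor([0.1, 0.1, 0.1, 2.0])

for ls_param in (0.0, 1.0):
    probs = (losses + ls_param) / (losses + ls_param).sum()
    print(f"ls_param={ls_param}: {probs.tolist()}")
# ls_param=0.0: the hard example takes ~87% of the sampling mass
# ls_param=1.0: it takes ~48%, much closer to uniform

For these callbacks to receive anything useful, fastai's Stepper has to hand back the unreduced per-example losses; the patch below adds a step_with_raw_loss variant and routes its raw losses into on_batch_end.
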
diff --git a/fastai/model.py b/fastai/model.py
index a9b41bc..4f82b6d 100644
--- a/fastai/model.py
+++ b/fastai/model.py
@@ -48,8 +48,8 @@ class Stepper():
         output = self.m(*xs)
         if isinstance(output,tuple): output,*xtra = output
         if self.fp16: self.m.zero_grad()
-        else: self.opt.zero_grad()
-        loss = raw_loss = self.crit(output, y)
+        else: self.opt.zero_grad()
+        loss = raw_loss = torch.mean(self.crit(output, y))
         if self.loss_scale != 1: assert(self.fp16); loss = loss*self.loss_scale
         if self.reg_fn: loss = self.reg_fn(output, xtra, raw_loss)
         loss.backward()
@@ -64,10 +64,32 @@ class Stepper():
         torch.cuda.synchronize()
         return torch_item(raw_loss.data)
 
+    def step_with_raw_loss(self, xs, y, epoch):
+        xtra = []
+        output = self.m(*xs)
+        if isinstance(output,tuple): output,*xtra = output
+        if self.fp16: self.m.zero_grad()
+        else: self.opt.zero_grad()
+        raw_loss = raw_loss_1 = self.crit(output, y)
+        loss = raw_loss = torch.mean(raw_loss)
+        if self.loss_scale != 1: assert(self.fp16); loss = loss*self.loss_scale
+        if self.reg_fn: loss = self.reg_fn(output, xtra, raw_loss)
+        loss.backward()
+        if self.fp16: update_fp32_grads(self.fp32_params, self.m)
+        if self.loss_scale != 1:
+            for param in self.fp32_params: param.grad.data.div_(self.loss_scale)
+        if self.clip:   # Gradient clipping
+            nn.utils.clip_grad_norm(trainable_params_(self.m), self.clip)
+        self.opt.step()
+        if self.fp16:
+            copy_fp32_to_model(self.m, self.fp32_params)
+            torch.cuda.synchronize()
+        return torch_item(raw_loss.data), raw_loss_1
+
     def evaluate(self, xs, y):
         preds = self.m(*xs)
         if isinstance(preds,tuple): preds=preds[0]
-        return preds, self.crit(preds, y)
+        return preds, torch.mean(self.crit(preds, y))
 
 def set_train_mode(m):
     if (hasattr(m, 'running_mean') and (getattr(m,'bn_freeze',False)
@@ -125,13 +147,13 @@ def fit(model, data, n_epochs, opt, crit, metrics=None, callbacks=None, stepper=
         for (*x,y) in t:
             batch_num += 1
             for cb in callbacks: cb.on_batch_begin()
-            loss = model_stepper.step(V(x),V(y), epoch)
+            loss, raw_losses = model_stepper.step_with_raw_loss(V(x),V(y), epoch)
             avg_loss = avg_loss * avg_mom + loss * (1-avg_mom)
             debias_loss = avg_loss / (1 - avg_mom**batch_num)
             t.set_postfix(loss=debias_loss)
             stop=False
             los = debias_loss if not all_val else [debias_loss] + validate_next(model_stepper,metrics, val_iter)
-            for cb in callbacks: stop = stop or cb.on_batch_end(los)
+            for cb in callbacks: stop = stop or cb.on_batch_end(raw_losses)
             if stop: return
             if batch_num >= cnt_phases[phase]:
                 for cb in callbacks: cb.on_phase_end()
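
To actually train with one of these samplers (after applying the patch), the criterion must return one loss per example so that raw_loss_1 is a vector; the patched Stepper takes the mean itself. A usage sketch under those assumptions; ConvLearner.from_model_data is the standard fastai 0.7 entry point, while PATH, trn_x/trn_y, val_x/val_y, and net are placeholders:

import torch.nn.functional as F

def per_example_ce(output, target):
    # reduce=False keeps one cross-entropy value per example (PyTorch 0.4-era API)
    return F.cross_entropy(output, target, reduce=False)

data = ImageClassifierData.from_arrays(PATH, (trn_x, trn_y), (val_x, val_y), bs=64)
learn = ConvLearner.from_model_data(net, data)
learn.crit = per_example_ce
# The sampler doubles as a Callback; register it so on_batch_end/on_epoch_end fire.
learn.fit(1e-2, 3, callbacks=[data.our_sampler])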