Label relabelling in fastai
from fastai import *
from fastai.vision import *
from pathlib import Path
import pandas as pd
# data/boxes.csv is a simple CSV with one line per image in the form (file,bb,class):
# img1.jpg,16 121 390 492,0
# img2.jpg,396 1057 1450 2100,1
# img3.jpg,39 213 519 673,0
imgs = []
with open('data/boxes.csv', 'r') as img_file:
    for line in img_file:
        en = line.strip().split(',')
        imgs.append((en[0], [int(i) for i in en[1].split()], en[2]))
# Let's build the labels in the COCO-style format ([ [list of bboxes], [list of labels] ])
# and a function to look them up by file name
images = [i[0] for i in imgs]
lbl_bbox = [[[i[1]], [i[2]]] for i in imgs]
img2bbox = dict(zip(images, lbl_bbox))
get_y_func = lambda o: img2bbox[o.name]
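# For the sample rows above, get_y_func(Path('data/imgs/img1.jpg')) would return
# [[[16, 121, 390, 492]], ['0']] (illustrative values taken from the sample CSV lines)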
# Let's now create the data bunch
data = (ObjectItemList.from_folder('data/imgs')
        .split_by_rand_pct(seed=42)
        .label_from_func(get_y_func)
        .transform(None, size=224, tfm_y=True, resize_method=ResizeMethod.SQUISH)
        .databunch(bs=3, collate_fn=bb_pad_collate, num_workers=0)
        )
# If we show a batch in a notebook, we see the resized images with the right bounding boxes and the classes shown as 0 and 1
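# e.g., in a notebook: data.show_batch(rows=1, figsize=(8, 8))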
# Let's define a custom loss to train with:
def detn_loss(input, target1, target2):
    h, w = target2.shape
    target2 = target2.reshape(h)  # needed to get the right dimensions for the class target tensor
    bb_t, c_t = target1, target2
    print("Target classes %s" % c_t)  # printing the target classes here makes the problem evident when this runs
    bb_i, c_i = input[:, :4], input[:, 4:]
    bb_i = torch.sigmoid(bb_i) * 224
    return F.l1_loss(bb_i, bb_t) + F.cross_entropy(c_i, c_t) * 20
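# With bs=3 and a single box per image, bb_pad_collate should hand this loss target1 of
# shape (3, 1, 4) and target2 of shape (3, 1) (shapes inferred from the collate function,
# not printed in this gist), so the reshape above flattens target2 into 3 class ids.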
# Let's create the model:
f_model = models.resnet34
sz = 224
n_classes = 2
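# The resnet34 body produces a 512 x 7 x 7 feature map for a 224 x 224 input,
# so the flattened input to the custom head has 512 * 7 * 7 = 25088 values.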
head_reg = nn.Sequential(
    Flatten(),
    nn.Linear(25088, 4 + n_classes)
)
learn = cnn_learner(data, f_model, custom_head=head_reg)
learn.opt_fn = optim.Adam
learn.loss_func = detn_loss
# Let's run lr_find
learn.lr_find()
# lr_find prints the following output before failing:
# Target classes tensor([1, 1, 2])
# For some reason, the labels were transformed from (0,1) to (1,2) and we get an error when computing
# the cross_entropy: RuntimeError: Assertion `cur_target >= 0 && cur_target < n_classes'
Target classes tensor([1, 1, 2])
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-31-1c9b4f5231f8> in <module>
      7 learn.opt_fn = optim.Adam
      8 learn.loss_func = detn_loss
----> 9 learn.lr_find()
~/fastai/lib/python3.6/site-packages/fastai/train.py in lr_find(learn, start_lr, end_lr, num_it, stop_div, wd)
     30 cb = LRFinder(learn, start_lr, end_lr, num_it, stop_div)
     31 epochs = int(np.ceil(num_it/len(learn.data.train_dl)))
---> 32 learn.fit(epochs, start_lr, callbacks=[cb], wd=wd)
     33
     34 def to_fp16(learn:Learner, loss_scale:float=None, max_noskip:int=1000, dynamic:bool=True, clip:float=None,
~/fastai/lib/python3.6/site-packages/fastai/basic_train.py in fit(self, epochs, lr, wd, callbacks)
    198 callbacks = [cb(self) for cb in self.callback_fns + listify(defaults.extra_callback_fns)] + listify(callbacks)
    199 if defaults.extra_callbacks is not None: callbacks += defaults.extra_callbacks
--> 200 fit(epochs, self, metrics=self.metrics, callbacks=self.callbacks+callbacks)
    201
    202 def create_opt(self, lr:Floats, wd:Floats=0.)->None:
~/fastai/lib/python3.6/site-packages/fastai/basic_train.py in fit(epochs, learn, callbacks, metrics)
     99 for xb,yb in progress_bar(learn.data.train_dl, parent=pbar):
    100 xb, yb = cb_handler.on_batch_begin(xb, yb)
--> 101 loss = loss_batch(learn.model, xb, yb, learn.loss_func, learn.opt, cb_handler)
    102 if cb_handler.on_batch_end(loss): break
    103
~/fastai/lib/python3.6/site-packages/fastai/basic_train.py in loss_batch(model, xb, yb, loss_func, opt, cb_handler)
     28
     29 if not loss_func: return to_detach(out), yb[0].detach()
---> 30 loss = loss_func(out, *yb)
     31
     32 if opt is not None:
<ipython-input-30-23e359252d00> in detn_loss(input, target1, target2)
      9 print(c_i.shape)
     10 print(c_t.shape)
---> 11 return F.cross_entropy(c_i, c_t) * 20
     12 return F.l1_loss(bb_i, bb_t) #+ F.cross_entropy(c_i, c_t) * 20
~/fastai/lib/python3.6/site-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
   2054 if size_average is not None or reduce is not None:
   2055 reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2056 return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
   2057
   2058
~/fastai/lib/python3.6/site-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
   1869 .format(input.size(0), target.size(0)))
   1870 if dim == 2:
-> 1871 ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
   1872 elif dim == 4:
   1873 ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: Assertion `cur_target >= 0 && cur_target < n_classes' failed. at ../aten/src/THNN/generic/ClassNLLCriterion.c:92
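A note on the likely cause and a possible workaround. This is an assumption based on how fastai v1's ObjectCategoryList and bb_pad_collate handle padding, not something verified in this gist: fastai prepends a background class at index 0 so it can pad batches whose images have different numbers of boxes, which shifts the real classes (0, 1) to (1, 2). The custom head above only produces n_classes = 2 scores, so cross_entropy hits the cur_target < n_classes assertion. A minimal sketch that sizes the head and the loss for the extra class, keeping the box-regression part exactly as in the gist, could look like this:

# Possible workaround (sketch; assumes the 'background' class explanation above is correct)
# print(data.train_ds.y.classes)   # expected to show something like ['background', '0', '1']
n_out = len(data.train_ds.y.classes)        # expected to be 3: 'background' plus the two real classes

head_reg = nn.Sequential(
    Flatten(),
    nn.Linear(25088, 4 + n_out)             # 4 box coordinates + one score per class, incl. background
)

def detn_loss(input, target1, target2):
    h, w = target2.shape
    target2 = target2.reshape(h)            # same reshape as in the original loss
    bb_t, c_t = target1, target2
    bb_i, c_i = input[:, :4], input[:, 4:]
    bb_i = torch.sigmoid(bb_i) * 224
    # c_t holds values 1..2 here; with n_out = 3 scores, cross_entropy no longer asserts
    return F.l1_loss(bb_i, bb_t) + F.cross_entropy(c_i, c_t) * 20

learn = cnn_learner(data, f_model, custom_head=head_reg)
learn.loss_func = detn_loss
learn.lr_find()

An alternative with the same effect on the classification term would be to keep the two-output head and pass c_t - 1 to cross_entropy, provided no padded rows (label 0) ever reach the loss.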