Last active
February 19, 2019 07:00
-
-
Save andrijdavid/9930154d649f2174dd712a6708b9ec32 to your computer and use it in GitHub Desktop.
MultiTask API for fast.ai
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from fastai.basics import ItemBase, ItemList, Iterator, Collection, Callable, LabelList, Any, IntsOrStrs, listify, show_some, array, extract_kwargs | |
from fastai.vision import ImageDataBunch, ImageItemList | |
from fastai.data_block import PreProcessors, PreProcessor | |
class MultiLabelProcessor(PreProcessor):
    """Processor for a multi-task label list whose `items` are per-task label lists.

    Every per-task list is processed with its own processor (built from that
    list's `_processor` attribute). Afterwards the per-task lists are zipped so
    that row `i` of `ds.items` holds the labels of sample `i` for every task.
    """

    def __init__(self, ds:Collection=None):
        self.ref_ds = ds
        # One entry per task, filled on the first `process` call
        # (None for a task whose label list declares no `_processor`).
        self.procs_y = []

    def process(self, ds:Collection=None):
        """Process each per-task label list of `ds`, then transpose `ds.items`.

        Side effects on `ds`:
          * `ds.c`    - sum of the per-task `c` values,
          * `ds.c_n`  - fresh list with the per-task `c` values,
          * `ds.items` - object array built from `zip(*ds.items)`.

        The per-task processors are created on the first call and reused on
        later calls, so a dataset processed afterwards is handled by the
        processors built for the first one.
        NOTE(review): this relies on fastai applying the same processor
        instance to the training set first and the validation set second;
        confirm against the fastai version in use.
        """
        counts = []
        for i, y in enumerate(ds.items):
            if i < len(self.procs_y):
                # Task already has a processor from an earlier call: reuse it.
                yp = self.procs_y[i]
            else:
                p = y._processor
                yp = p(ds=y) if p is not None else None
                self.procs_y.append(yp)
            y.process(yp)
            counts.append(y.c)
        # Assign new objects instead of mutating `ds.c_n` in place: the default
        # `c_n` is a class-level list that would otherwise be shared (and keep
        # growing) across every label list created in the session.
        ds.c = sum(counts)
        ds.c_n = counts
        ds.items = array(list(zip(*ds.items)), dtype=object)
class MultiLabel(ItemBase):
    """Item bundling the labels of one sample, one label per task."""

    def __init__(self, label_list):
        # Sequence of per-task label items; each is expected to expose
        # `data` and `obj` (read by the properties below).
        self.ll = label_list

    @property
    def data(self):
        """Tuple with the `data` of every wrapped label, in task order."""
        return tuple(label.data for label in self.ll)

    @property
    def obj(self):
        """Tuple with the `obj` of every wrapped label, in task order."""
        return tuple(label.obj for label in self.ll)

    def __repr__(self):
        return " - ".join(map(repr, self.ll))

    def __str__(self):
        return " - ".join(map(str, self.ll))
class MultiTaskLabelList(ItemList):
    """Label list holding one label list per task.

    `c` is the total number of outputs over all tasks and `c_n` the number of
    outputs of each task; both are filled in by the processor.
    """
    _processor = MultiLabelProcessor
    # Class-level defaults kept for backward compatibility; every instance
    # shadows them with its own values in `__init__`.
    c:int = 0
    c_n:Iterator = []

    def __init__(self, items:Iterator, **kwargs):
        # The base class is initialised with an empty list and `items` is
        # assigned afterwards, exactly as given by the caller.
        super().__init__([], **kwargs)
        self.items = items
        # Per-instance state: without this, appending to `c_n` would mutate the
        # single class-level list shared by all instances.
        self.c = 0
        self.c_n = []

    def get(self, i):
        """Return the labels at index `i` wrapped in a `MultiLabel` (None if missing)."""
        o = self.items[i]
        if o is None:
            return None
        return MultiLabel(o)

    @property
    def classes(self):
        """Per element of `items`: its `classes` attribute, or 1 when it has none."""
        return tuple(getattr(y, 'classes', 1) for y in self.items)

    @property
    def c2i(self):
        """Per element of `items`: its `c2i` attribute, or None when it has none."""
        return tuple(getattr(y, 'c2i', None) for y in self.items)
# Image ItemList that builds one label list per task
class MultiTaskImageItemList(ImageItemList):
    """Image item list that can be labelled with several dataframe columns, one per task."""
    task_n: int # Number of tasks
    c_n: Iterator # Number of outputs (`c`) per task

    def label_from_df(self, cols:Collection[IntsOrStrs], labels_cls:Collection[Callable]=None, **kwargs):
        """Label the items from dataframe columns, one task per entry of `cols`.

        `cols` must be a sized collection of column names or indices (its
        length is the number of tasks). `labels_cls` optionally gives the label
        class of each task; missing entries default to None, which leaves the
        choice to the parent `label_from_df`.

        With a single column the call is delegated to the parent
        implementation. Otherwise the parent is called once per column and the
        resulting label lists are combined in a `MultiTaskLabelList`.
        """
        self.task_n = len(cols)
        # Register `task_n` only once, even if labelling is invoked repeatedly
        # on the same item list.
        if 'task_n' not in self.copy_new:
            self.copy_new.append('task_n')
        labels_cls = listify(labels_cls)
        if self.task_n == 1:
            # Honour `labels_cls` for the single task instead of dropping it,
            # unless the caller already passed an explicit `label_cls`.
            if labels_cls and 'label_cls' not in kwargs:
                kwargs['label_cls'] = labels_cls[0]
            return super().label_from_df(cols=cols, **kwargs)
        if self.task_n > len(labels_cls):
            labels_cls += [None] * (self.task_n - len(labels_cls))
        # `label_cls` is supplied per task below, so strip it from the
        # forwarded kwargs to avoid a duplicate keyword argument.
        _, kwargs = extract_kwargs(['label_cls','labels_cls'], kwargs)
        y = []
        # Explicit loop: zero-argument `super()` is not usable inside a
        # comprehension on older Python versions.
        for col, label_cls in zip(cols, labels_cls):
            itemlist = super().label_from_df(cols=col, label_cls=label_cls, **kwargs)
            y.append(itemlist.y)
        y = MultiTaskLabelList(y)
        return self._label_list(x=self, y=y)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from fastai.vision.learner import cnn_config | |
from fastai.vision import * | |
from fastai.callbacks.hooks import num_features_model | |
class MultiHeadWrapper(nn.Module):
    """Apply several task heads to one shared, pooled feature vector."""

    def __init__(self, head):
        super().__init__()
        self.pool = nn.Sequential(AdaptiveConcatPool2d(), Flatten())
        # Keep each head from its third layer on.
        # NOTE(review): presumably the two dropped layers are the pooling and
        # flatten layers that `create_head` puts first, which `self.pool`
        # already provides once for all heads - confirm.
        self.head = nn.ModuleList([task_head[2:] for task_head in head])

    def forward(self, x):
        """Return a tuple with the output of every head, in registration order."""
        pooled = self.pool(x)
        outputs = []
        for task_head in self.head:
            outputs.append(task_head(pooled))
        return tuple(outputs)
def create_mtl_cnn(data:DataBunch, arch:Callable, cut:Union[int,Callable]=None, multiple_head:bool=True, pretrained:bool=True,
                   lin_ftrs:Optional[Collection[int]]=None, ps:Floats=0.5,
                   custom_head:Optional[nn.Module]=None, split_on:Optional[SplitFuncOrIdxList]=None,
                   bn_final:bool=False, **kwargs:Any)->Learner:
    """Build a convnet-style learner, optionally with one head per task.

    Head selection:
      * `custom_head` is used whenever it is given (regardless of
        `multiple_head`);
      * otherwise, with `multiple_head=True`, one head per task is created
        (task `i` gets `data.c_n[i]` outputs, for `i < data.task_n`) and the
        heads are wrapped in a `MultiHeadWrapper`;
      * otherwise a single head with `data.c` outputs is created.

    Remaining `kwargs` are forwarded to `Learner`. The learner is split with
    `split_on` (falling back to the architecture's default split), frozen when
    `pretrained` is true, and the head is initialised with
    `nn.init.kaiming_normal_`.
    """
    meta = cnn_config(arch)
    body = create_body(arch, pretrained, cut)
    # NOTE(review): the factor 2 presumably accounts for the concat pooling
    # doubling the number of features - confirm.
    nf = num_features_model(body) * 2
    # Explicit branches: the former one-line conditional parsed as
    # `(custom_head or multi) if multiple_head else single`, which ignored
    # `custom_head` whenever `multiple_head` was False.
    if custom_head is not None:
        head = custom_head
    elif multiple_head:
        head = MultiHeadWrapper([create_head(nf, data.c_n[i], lin_ftrs, ps=ps, bn_final=bn_final)
                                 for i in range(data.task_n)])
    else:
        head = create_head(nf, data.c, lin_ftrs, ps=ps, bn_final=bn_final)
    model = nn.Sequential(body, head)
    learn = Learner(data, model, **kwargs)
    learn.split(ifnone(split_on,meta['split']))
    if pretrained: learn.freeze()
    apply_init(model[1], nn.init.kaiming_normal_)
    return learn
def freeze_task(learner:Learner, nb:int):
    """Call `requires_grad(layer, False)` on every layer of task head `nb`.

    The heads are read from `learner.model[1].head`. Returns `learner`.
    """
    task_heads = learner.model[1].head
    assert(len(task_heads) > nb)
    for layer in task_heads[nb]:
        requires_grad(layer, False)
    return learner
def unfreeze_task(learner:Learner, nb:int):
    """Call `requires_grad(layer, True)` on every layer of task head `nb`.

    The heads are read from `learner.model[1].head`. Returns `learner`.
    """
    task_heads = learner.model[1].head
    assert(len(task_heads) > nb)
    for layer in task_heads[nb]:
        requires_grad(layer, True)
    return learner
# Attach the helpers to `Learner` so they can be called as methods,
# e.g. `learn.freeze_task(nb)` / `learn.unfreeze_task(nb)`.
setattr(Learner, 'freeze_task', freeze_task)
setattr(Learner, 'unfreeze_task', unfreeze_task)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment