MultiTask API for fast.ai
from fastai.basics import ItemBase, ItemList, Iterator, Collection, Callable, LabelList, Any, IntsOrStrs, listify, show_some, array, extract_kwargs
from fastai.vision import ImageDataBunch, ImageItemList
from fastai.data_block import PreProcessors, PreProcessor

class MultiLabelProcessor(PreProcessor):
    "Run each task's own processor, then regroup the per-task labels per example."
    def __init__(self, ds:Collection=None):
        self.ref_ds = ds
        self.procs_y = []

    def process(self, ds:Collection=None):
        for y in ds.items:
            p = y._processor
            yp = None
            if p is not None:
                yp = p(ds=y)
            self.procs_y.append(yp)
            y.process(yp)
            ds.c += y.c         # total class count across all tasks
            ds.c_n.append(y.c)  # class count for this task
        # Transpose: ds.items goes from one ItemList per task to one tuple of
        # per-task labels per example.
        ds.items = array(list(zip(*ds.items)), dtype=object)

class MultiLabel(ItemBase):
    def __init__(self, label_list):
        self.ll = label_list

    @property
    def data(self):
        return tuple(l.data for l in self.ll)

    @property
    def obj(self):
        return tuple(l.obj for l in self.ll)

    def __repr__(self):
        return " - ".join(l.__repr__() for l in self.ll)

    def __str__(self):
        return " - ".join(str(l) for l in self.ll)

class MultiTaskLabelList(ItemList):
    _processor = MultiLabelProcessor
    c:int = 0
    c_n:Iterator = []

    def __init__(self, items:Iterator, **kwargs):
        super().__init__([], **kwargs)
        self.items = items
        # Per-instance counters, so repeated labelling doesn't accumulate into
        # the mutable class-level attributes above.
        self.c, self.c_n = 0, []

    def get(self, i):
        o = self.items[i]
        if o is None: return None
        return MultiLabel(o)

    @property
    def classes(self):
        return tuple(y.classes if hasattr(y, 'classes') else 1 for y in self.items)

    @property
    def c2i(self):
        return tuple(y.c2i if hasattr(y, 'c2i') else None for y in self.items)

# ItemList
class MultiTaskImageItemList(ImageItemList):
    task_n: int      # number of tasks
    c_n: Iterator    # number of classes per task

    def label_from_df(self, cols:Collection[IntsOrStrs], labels_cls:Collection[Callable]=None, **kwargs):
        "Label each task from its own column; fall back to the stock behaviour for a single column."
        task_n = len(cols)
        self.task_n = task_n
        self.copy_new.append('task_n')
        if self.task_n == 1:
            return super().label_from_df(cols=cols, **kwargs)
        # Pad `labels_cls` with None so every column has an (optional) label class.
        labels_cls = listify(labels_cls)
        if len(cols) > len(labels_cls):
            labels_cls += [None] * (len(cols) - len(labels_cls))
        y = []
        for i in range(len(cols)):
            _, kwargs = extract_kwargs(['label_cls','labels_cls'], kwargs)
            itemlist = super().label_from_df(cols=cols[i], label_cls=labels_cls[i], **kwargs)
            y.append(itemlist.y)
        y = MultiTaskLabelList(y)
        return self._label_list(x=self, y=y)
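
# --- Usage sketch (illustrative, not part of the gist) ----------------------
# Assumes a CSV with an image filename column plus two label columns, 'color'
# and 'shape' (both hypothetical), using the fastai v1 data block API of the
# ImageItemList era. With more than one entry in `cols`, each target is a
# MultiLabel whose .data is a tuple with one element per task.
#
# data = (MultiTaskImageItemList.from_csv('data', 'labels.csv', folder='images')
#         .random_split_by_pct(0.2)
#         .label_from_df(cols=['color', 'shape'])
#         .databunch())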
from fastai.vision.learner import cnn_config
from fastai.vision import *
from fastai.callbacks.hooks import num_features_model

class MultiHeadWrapper(nn.Module):
    "Share one pooling stage across several `create_head` heads, one per task."
    def __init__(self, head):
        super().__init__()
        self.pool = nn.Sequential(AdaptiveConcatPool2d(), Flatten())
        self.head = nn.ModuleList()
        for h in head:
            # `create_head` begins with AdaptiveConcatPool2d + Flatten; drop
            # those two layers since `self.pool` already provides them.
            self.head.append(h[2:])

    def forward(self, x):
        vec = self.pool(x)
        return tuple(h(vec) for h in self.head)

def create_mtl_cnn(data:DataBunch, arch:Callable, cut:Union[int,Callable]=None, multiple_head:bool=True, pretrained:bool=True,
                   lin_ftrs:Optional[Collection[int]]=None, ps:Floats=0.5,
                   custom_head:Optional[nn.Module]=None, split_on:Optional[SplitFuncOrIdxList]=None,
                   bn_final:bool=False, **kwargs:Any)->Learner:
    "Build a convnet-style multi-task `Learner`: one shared body, one head per task."
    meta = cnn_config(arch)
    body = create_body(arch, pretrained, cut)
    nf = num_features_model(body) * 2
    # `custom_head`, when given, takes precedence over both branches.
    head = custom_head or (MultiHeadWrapper([create_head(nf, data.c_n[i], lin_ftrs, ps=ps, bn_final=bn_final)
                                             for i in range(data.task_n)])
                           if multiple_head else
                           create_head(nf, data.c, lin_ftrs, ps=ps, bn_final=bn_final))
    model = nn.Sequential(body, head)
    learn = Learner(data, model, **kwargs)
    learn.split(ifnone(split_on, meta['split']))
    if pretrained: learn.freeze()
    apply_init(model[1], nn.init.kaiming_normal_)
    return learn
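
# --- Usage sketch (illustrative, not part of the gist) ----------------------
# fastai computes the loss as loss_func(output, *targets); with
# multiple_head=True the output is the tuple returned by
# MultiHeadWrapper.forward, so the loss below sums one cross-entropy per task.
# The summed loss is an assumption of this sketch, not something the gist
# prescribes.
#
# def mtl_loss(preds, *ys):
#     return sum(F.cross_entropy(p, y) for p, y in zip(preds, ys))
#
# learn = create_mtl_cnn(data, models.resnet34, loss_func=mtl_loss)
# learn.fit_one_cycle(4)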

def freeze_task(learner:Learner, nb:int):
    "Stop gradients for the head of task `nb`."
    assert len(learner.model[1].head) > nb
    for l in learner.model[1].head[nb]:
        requires_grad(l, False)
    return learner

def unfreeze_task(learner:Learner, nb:int):
    "Re-enable gradients for the head of task `nb`."
    assert len(learner.model[1].head) > nb
    for l in learner.model[1].head[nb]:
        requires_grad(l, True)
    return learner

Learner.freeze_task = freeze_task
Learner.unfreeze_task = unfreeze_task
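
# --- Usage sketch (illustrative, not part of the gist) ----------------------
# Freeze the head of task 0 so only the other task heads (and any unfrozen
# body layers) receive gradients, train, then unfreeze it again.
#
# learn.freeze_task(0)
# learn.fit_one_cycle(1)
# learn.unfreeze_task(0)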