# PyTorch module for classification or regression of categorical + continuous + text inputs.
# This module is based on the fastai library.
class MyDataset(Dataset):
    """Dataset of (categorical, continuous, text, target) rows.

    ``__getitem__`` returns ``[cats[idx], conts[idx], np.array(text), y[idx]]``,
    i.e. the ``(x_cat, x_cont, text_inp)`` argument order used by the
    structured+text model, followed by the target.
    """

    def __init__(self, cats, conts, texts, y, is_reg, is_multi, reverse_text=False):
        """Stack column-wise inputs into per-row arrays.

        Args:
            cats: list of 1-D integer-coded categorical columns; may be empty.
            conts: list of 1-D continuous columns; may be empty.
            texts: per-row token sequences (lengths may differ), or None.
            y: targets, or None to use zeros (e.g. for unlabeled data).
            is_reg: True for regression.
            is_multi: True for multi-label classification.
            reverse_text: if True, each row's tokens are returned reversed.

        Raises:
            ValueError: if neither cats, conts nor texts is given, so the
                number of rows cannot be inferred.
        """
        if cats:
            n = len(cats[0])
        elif conts:
            n = len(conts[0])
        elif texts is not None:
            n = len(texts)
        else:
            raise ValueError("MyDataset needs at least one of cats, conts or texts")
        self.cats = np.stack(cats, 1).astype(np.int64) if cats else np.zeros((n, 1))
        self.conts = np.stack(conts, 1).astype(np.float32) if conts else np.zeros((n, 1))
        self.texts = np.zeros((n, 1)) if texts is None else self._to_object_array(texts)
        # y is stored as a 2-D (n, 1) column so each item's target has shape (1,)
        # and a batch lines up with a (batch, 1) prediction. Do not add another
        # axis (e.g. ``[:, None]``): that gives (n, 1, 1) and mis-broadcasts.
        # NOTE(review): for classification with nll_loss, integer class indices
        # are usually expected rather than a float32 column — confirm.
        self.y = np.zeros((n, 1)) if y is None else np.asarray(y).reshape(-1, 1).astype(np.float32)
        self.is_reg = is_reg
        self.is_multi = is_multi
        self.reverse_text = reverse_text

    @staticmethod
    def _to_object_array(texts):
        """Return a 1-D object array holding one token sequence per row.

        ``np.array`` raises on ragged nested lists in recent numpy (and builds a
        2-D array for equal-length ones); filling an object array element by
        element handles both cases uniformly.
        """
        arr = np.empty(len(texts), dtype=object)
        for i, t in enumerate(texts):
            arr[i] = t
        return arr

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        """Return ``[cats, conts, text_tokens, target]`` for row ``idx``."""
        t = self.texts[idx]
        if self.reverse_text:
            t = list(reversed(t))
        return [self.cats[idx], self.conts[idx], np.array(t), self.y[idx]]

    @classmethod
    def from_data_frame(cls, df, cat_flds, cont_flds, text_fld, y_fld, is_reg=True, is_multi=False,
                        reverse_text=False):
        """Build a dataset from columns of ``df``.

        Args:
            df: source DataFrame.
            cat_flds: names of the integer-coded categorical columns.
            cont_flds: names of the continuous columns.
            text_fld: name of the column holding each row's token sequence.
            y_fld: name of the target column, or None for unlabeled data.
            is_reg, is_multi, reverse_text: forwarded to ``__init__``.
        """
        cat_cols = [col.values for _, col in df[cat_flds].items()]
        cont_cols = [col.values for _, col in df[cont_flds].items()]
        y = None if y_fld is None else df[y_fld]
        return cls(cat_cols, cont_cols, df[text_fld], y, is_reg, is_multi, reverse_text)
class MyModel(BasicModel):
    """Model wrapper exposing the text encoder and structured head as layer groups."""

    def get_layer_groups(self):
        """Return ``[text encoder, structured head]`` of the wrapped network.

        The wrapped network is ``self.model`` (passed as the first constructor
        argument, e.g. ``MyModel(to_gpu(model))``); the previous body referred
        to an undefined name ``m``.
        """
        m = self.model
        return [m.rnn_enc, m.structured_model]
class MyLearner(Learner):
    """Learner for the text + structured model (see RNN_Structured_regressor)."""

    def _get_crit(self, data):
        """Pick the loss: MSE for regression, BCE for multi-label, NLL otherwise."""
        if data.is_reg:
            return F.mse_loss
        elif data.is_multi:
            return F.binary_cross_entropy
        return F.nll_loss

    def predict_array(self, x_cat, x_cont, text):
        """Run the model on raw arrays and return the predictions as numpy.

        The model is switched to eval mode first so dropout (and batch-norm
        statistics updates) are disabled during prediction.
        """
        self.model.eval()
        return to_np(self.model(to_gpu(V(T(x_cat))), to_gpu(V(T(x_cont))), to_gpu(V(T(text)))))

    def summary(self):
        """Summarise the model using a dummy batch of 3 rows.

        Column counts are taken from the training dataset's ``cats``/``conts``
        arrays (assumes ``self.data.trn_ds`` is the MyDataset — confirm). The
        text input is a short dummy (batch, seq) block of token ids.
        """
        trn_ds = self.data.trn_ds
        x = [torch.ones(3, trn_ds.cats.shape[1]).long(),
             torch.rand(3, trn_ds.conts.shape[1]),
             torch.ones(3, 2).long()]
        return model_summary(self.model, x)

    # The wrapped network is an nn.Module with a named ``rnn_enc`` submodule, not
    # a Sequential, so the encoder is addressed by attribute rather than ``[0]``.
    def save_encoder(self, name):
        save_model(self.model.rnn_enc, self.get_model_path(name))

    def load_encoder(self, name):
        load_model(self.model.rnn_enc, self.get_model_path(name))
def build_learner(trn_df, val_df, cat_flds, cont_flds, text_fld, y_fld):
    """Assemble a MyLearner for the given training/validation DataFrames.

    Depends on module-level settings defined outside this function: ``bs``,
    ``PATH``, the ``text_*`` hyper-parameters, ``struct_emb_szs``,
    ``struct_n_cont`` and ``y_range``.
    """
    trn_ds = MyDataset.from_data_frame(trn_df, cat_flds, cont_flds, text_fld, y_fld)
    val_ds = MyDataset.from_data_frame(val_df, cat_flds, cont_flds, text_fld, y_fld)

    half_bs = bs // 2
    # Samplers are keyed on each row's token count (presumably to batch
    # similar-length texts together and limit padding).
    trn_samp = SortishSampler(trn_df[text_fld],
                              key=lambda i: len(trn_df[text_fld].iloc[i]),
                              bs=half_bs)
    val_samp = SortSampler(val_df[text_fld],
                           key=lambda i: len(val_df[text_fld].iloc[i]))

    trn_dl = DataLoader(trn_ds, half_bs, num_workers=1, pad_idx=1, sampler=trn_samp)
    val_dl = DataLoader(val_ds, bs, num_workers=1, pad_idx=1, sampler=val_samp)
    md = ModelData(PATH, trn_dl, val_dl)

    net = RNN_Structured_regressor(
        text_bptt, text_max_seq, text_ntoken, text_emb_sz, text_n_hid,
        text_n_layers, text_pad_token, struct_emb_szs, struct_n_cont, y_range,
    )
    wrapped = MyModel(to_gpu(net))
    return MyLearner(md, wrapped, opt_fn=optim.Adam)
from fastai.text import *
from fastai.structured import proc_df
import pandas as pd
import numpy as np
class MixedInputModelWithText(nn.Module):
    """MLP over categorical embeddings, a text encoding and continuous features.

    The embedded categorical columns, the (dropped-out) text encoding and the
    batch-normalised continuous columns are concatenated and fed through the
    hidden layers ``szs``, then an output layer of width ``out_sz``.

    Args:
        emb_szs: list of ``(cardinality, embedding_dim)`` pairs, one per
            categorical column.
        n_cont: number of continuous columns.
        emb_drop: dropout applied to the categorical embeddings and the text
            encoding.
        out_sz: width of the output layer.
        szs: hidden layer widths.
        drops: dropout probability after each hidden layer; needs at least one
            entry per element of ``szs``.
        y_range: optional ``(low, high)``; for regression the output is passed
            through a sigmoid and rescaled into this range.
        use_bn: apply batch norm after each hidden activation.
        is_reg: regression (True) or classification (False).
        is_multi: for classification, sigmoid outputs (multi-label) instead of
            log-softmax.
        n_text: width of the text encoding given to ``forward`` (0 = unused).

    Raises:
        ValueError: if ``drops`` has fewer entries than ``szs`` (``zip`` in
            ``forward`` would otherwise silently skip hidden layers).
    """

    def __init__(self, emb_szs, n_cont, emb_drop, out_sz, szs, drops,
                 y_range=None, use_bn=False, is_reg=True, is_multi=False, n_text=0):
        super().__init__()
        for i, (c, s) in enumerate(emb_szs):
            assert c > 1, f"cardinality must be >=2, got emb_szs[{i}]: ({c},{s})"
        if not is_reg and not is_multi:
            assert out_sz >= 2, "For classification with out_sz=1, use is_multi=True"
        if len(drops) < len(szs):
            raise ValueError(f"need one dropout per hidden layer: got {len(drops)} drops "
                             f"for {len(szs)} hidden layers")

        self.embs = nn.ModuleList([nn.Embedding(c, s) for c, s in emb_szs])
        for emb in self.embs:
            self.emb_init(emb)
        n_emb = sum(e.embedding_dim for e in self.embs)
        self.n_emb, self.n_cont, self.text_emb_sz = n_emb, n_cont, n_text

        szs = [n_emb + n_cont + n_text] + list(szs)
        self.lins = nn.ModuleList([nn.Linear(szs[i], szs[i + 1]) for i in range(len(szs) - 1)])
        self.bns = nn.ModuleList([nn.BatchNorm1d(sz) for sz in szs[1:]])
        for o in self.lins:
            self._kaiming_normal(o.weight.data)
        self.outp = nn.Linear(szs[-1], out_sz)

        self.emb_drop = nn.Dropout(emb_drop)
        self.drops = nn.ModuleList([nn.Dropout(drop) for drop in drops])
        self.bn = nn.BatchNorm1d(n_cont)
        self.use_bn, self.y_range = use_bn, y_range
        self.is_reg = is_reg
        self.is_multi = is_multi

    @staticmethod
    def _kaiming_normal(w):
        """He-normal init in place, using whichever name this torch provides."""
        init_fn = getattr(nn.init, "kaiming_normal_", None) or nn.init.kaiming_normal
        init_fn(w)

    def emb_init(self, x):
        """Initialise embedding ``x`` uniformly in ``[-sc, sc]``, ``sc = 2/(dim+1)``."""
        x = x.weight.data
        sc = 2 / (x.size(1) + 1)
        x.uniform_(-sc, sc)

    def forward(self, x_cat, x_cont, text_emb):
        """Combine the three inputs and return the (range-scaled / normalised) output.

        ``text_emb`` is only read when ``n_text`` != 0; it is concatenated on
        dim 1, so it is assumed to be (batch, n_text) — confirm with caller.
        """
        parts = []
        if self.n_emb != 0:
            x1 = torch.cat([e(x_cat[:, i]) for i, e in enumerate(self.embs)], 1)
            parts.append(self.emb_drop(x1))
        if self.text_emb_sz != 0:
            parts.append(self.emb_drop(text_emb))
        if self.n_cont != 0:
            parts.append(self.bn(x_cont))
        x = torch.cat(parts, 1)
        for l, d, b in zip(self.lins, self.drops, self.bns):
            x = F.relu(l(x))
            if self.use_bn:
                x = b(x)
            x = d(x)
        x = self.outp(x)
        if not self.is_reg:
            if self.is_multi:
                x = torch.sigmoid(x)
            else:
                x = F.log_softmax(x, dim=-1)
        elif self.y_range:
            x = torch.sigmoid(x)
            x = x * (self.y_range[1] - self.y_range[0])
            x = x + self.y_range[0]
        return x
class MyMultiBatchRNN(RNN_Encoder):
    """RNN encoder that consumes a long sequence ``bptt`` tokens at a time.

    Every chunk advances the recurrent state, but only chunks that start
    within the last ``max_seq`` tokens (plus, always, the final chunk) are
    kept in the returned outputs.
    """

    def __init__(self, bptt, max_seq, *args, **kwargs):
        # Plain ints can be set before nn.Module.__init__; only parameters and
        # submodules require it to have run.
        self.max_seq, self.bptt = max_seq, bptt
        super().__init__(*args, **kwargs)

    def concat(self, arrs):
        """Concatenate per-chunk outputs layer by layer along dim 0 (time).

        ``arrs`` is a list (one entry per kept chunk) of per-layer outputs.
        """
        return [torch.cat([chunk[layer] for chunk in arrs]) for layer in range(len(arrs[0]))]

    def forward(self, input):
        """Encode ``input`` of shape (seq_len, batch) chunk by chunk.

        Returns ``(raw_outputs, outputs)``, each a per-layer list of the kept
        chunks' outputs concatenated along time.
        """
        sl, bs = input.size()
        # Start each sequence from a zeroed hidden state.
        # NOTE(review): ``hidden`` is owned by RNN_Encoder and may not exist
        # until its first reset — hence the guard; confirm.
        for layer_hidden in getattr(self, "hidden", ()):
            for h in layer_hidden:
                h.data.zero_()
        raw_outputs, outputs = [], []
        for i in range(0, sl, self.bptt):
            r, o = super().forward(input[i: min(i + self.bptt, sl)])
            # Keep chunks starting inside the last max_seq tokens. The final
            # chunk (i + bptt >= sl) is always kept so the outputs are never
            # empty (concat([]) would raise).
            if i > (sl - self.max_seq) or i + self.bptt >= sl:
                raw_outputs.append(r)
                outputs.append(o)
        return self.concat(raw_outputs), self.concat(outputs)
class RNN_Structured_regressor(nn.Module):
    """Single-output regressor over text plus structured (tabular) inputs.

    The text is encoded with ``MyMultiBatchRNN``. The last layer's output at
    the final time step is passed, with the categorical and continuous inputs,
    to ``MixedInputModelWithText`` (``out_sz=1``, ``is_reg=True``).

    ``struct_layers_szs`` and ``struct_drops`` give the hidden widths and the
    per-layer dropouts of the structured head; they should have equal length.
    """

    def __init__(self, text_bptt, text_max_seq, text_ntoken, text_emb_sz, text_n_hid, text_n_layers, text_pad_token,
                 struct_emb_szs, struct_n_cont, y_range, struct_layers_szs=(1000, 500),
                 struct_drops=(0.001, 0.01)):
        super().__init__()
        self.rnn_enc = MyMultiBatchRNN(bptt=text_bptt, max_seq=text_max_seq, ntoken=text_ntoken, emb_sz=text_emb_sz,
                                       n_hid=text_n_hid, n_layers=text_n_layers, pad_token=text_pad_token,
                                       dropouth=0.3, dropouti=0.65, dropoute=0.1, wdrop=0.5, qrnn=False)
        self.structured_model = MixedInputModelWithText(struct_emb_szs, struct_n_cont, emb_drop=0.04, out_sz=1,
                                                        szs=list(struct_layers_szs), drops=list(struct_drops),
                                                        y_range=y_range, use_bn=False, is_reg=True, is_multi=False,
                                                        n_text=text_emb_sz)

    def reset(self):
        """Reset the text encoder's recurrent state.

        NOTE(review): assumes the training/prediction loop calls ``reset()`` on
        the model when defined, and that ``RNN_Encoder.reset`` creates the
        ``hidden`` state the encoder reads — confirm.
        """
        self.rnn_enc.reset()

    def forward(self, x_cat, x_cont, text_inp):
        """Predict from categorical, continuous and token-id inputs.

        ``text_inp`` is transposed before encoding because the encoder treats
        dim 0 as time; presumably it arrives as (batch, seq_len) — confirm.
        """
        _, outputs = self.rnn_enc(torch.t(text_inp))
        # Last layer, last time step. TODO: consider max/mean pooling over time.
        encoded_text = outputs[-1][-1]
        return self.structured_model(x_cat, x_cont, encoded_text)
