-
-
Save krsnewwave/f3aedeab649e1f7bd5a920d5dd04f3c4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class DictTransform:
    """Split a dict-of-columns batch into categorical / continuous / label parts.

    Batches are ``(features, labels)`` pairs where ``features`` maps a column
    name to either a ``torch.Tensor`` or a non-tensor (tuple) multi-hot
    representation.  NOTE(review): this layout presumably matches the
    dataloader that produces the batches -- confirm against the producer.
    """

    def __init__(self, cat_names, cont_names, label_names=None):
        """Store the column names belonging to each feature group.

        Args:
            cat_names: Names of categorical columns (empty/None disables them).
            cont_names: Names of continuous columns (empty/None disables them).
            label_names: Names of label columns; stored on the instance but not
                read by any method of this class.
        """
        self.cats = cat_names
        self.conts = cont_names
        self.labels = label_names

    def transform_with_label(self, batch):
        """Split a ``(features, labels)`` batch into ``(cats, conts, labels)``.

        Args:
            batch: A 2-item ``(features, labels)`` pair; ``features`` is a
                mapping from column name to tensor/tuple.

        Returns:
            A 3-tuple ``(cats, conts, labels)``:

            * ``cats`` -- ``None`` when no categorical names were configured,
              otherwise the full ``(dense, multihot)`` pair returned by
              :meth:`create_stack` (the multi-hot dict is kept; presumably the
              consuming model unpacks it -- confirm).
            * ``conts`` -- ``None`` when no continuous names were configured,
              otherwise only the dense part of :meth:`create_stack`; any
              multi-hot entries for continuous columns are discarded.
            * ``labels`` -- passed through unchanged.
        """
        features, labels = batch
        cats = None
        conts = None
        if self.cats:
            cats = self.create_stack(features, self.cats)
        if self.conts:
            conts, _ = self.create_stack(features, self.conts)
        return cats, conts, labels

    def create_stack(self, batch, target_columns):
        """Gather ``target_columns`` from ``batch`` into a dense block + extras.

        Dense (non-sparse) tensor columns are coerced to 2-D -- a 1-D column of
        length ``B`` becomes ``(B, 1)`` -- and concatenated along dim 1, so the
        dense result is ``batch x n_dense_columns`` whether the columns arrive
        as ``(B,)`` or ``(B, 1)`` tensors.  Sparse tensors and non-tensor
        values (multi-hot tuple representations) are not concatenated; they
        are returned keyed by column name.

        Args:
            batch: Mapping from column name to tensor or tuple.
            target_columns: Column names to extract, in output order.

        Returns:
            ``(dense, multihot)`` where ``dense`` is a 2-D tensor, or an empty
            list when no dense columns were found, and ``multihot`` is a dict
            holding the sparse / non-tensor columns.
        """
        dense = []
        multihot = {}
        for name in target_columns:
            value = batch[name]
            if isinstance(value, torch.Tensor) and not value.is_sparse:
                # torch.cat along dim 1 needs every column to be at least 2-D:
                # a (B,) column would raise, and a lone (B, 1) column must not
                # gain an extra axis, so only 1-D columns are unsqueezed.
                if value.dim() == 1:
                    value = value.unsqueeze(1)
                dense.append(value)
            else:
                # Sparse tensor or multi-hot (tuple) representation.
                multihot[name] = value
        if not dense:
            # Empty list (not None) kept for compatibility with existing callers.
            return dense, multihot
        if len(dense) == 1:
            # Avoid a needless copy for the single-column case.
            return dense[0], multihot
        return torch.cat(dense, 1), multihot
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment