@krsnewwave
Created June 16, 2022 16:48
import torch


class DictTransform:
    """Splits a dict-style batch into categorical, continuous, and label tensors."""

    def __init__(self, cat_names, cont_names, label_names=None):
        self.cats = cat_names
        self.conts = cont_names
        self.labels = label_names

    def transform_with_label(self, batch):
        cats = None
        conts = None
        batch, labels = batch
        # take apart the batch and put the columns together into subsets
        if self.cats:
            # keep both the dense single-hot tensor and the dict of multi-hot columns
            cats = self.create_stack(batch, self.cats)
        if self.conts:
            # continuous columns are dense, so the multi-hot dict is discarded
            conts, _ = self.create_stack(batch, self.conts)
        return cats, conts, labels

    def create_stack(self, batch, target_columns):
        columns = []
        mh_s = {}
        for column_name in target_columns:
            target = batch[column_name]
            if isinstance(target, torch.Tensor):
                if target.is_sparse:
                    # sparse tensors are multi-hot columns; keep them keyed by name
                    mh_s[column_name] = target
                else:
                    columns.append(target)
            else:
                # if not a tensor, it must be a tuple:
                # a multi-hot column kept in its tuple representation
                mh_s[column_name] = target
        if columns:
            if len(columns) > 1:
                # concatenate categoricals -- converting the list above into a single
                # tensor of shape batch x n_categoricals
                columns = torch.cat(columns, 1)
            else:
                columns = columns[0].unsqueeze(1)
        return columns, mh_s
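
A minimal usage sketch follows. The batch below is hand-built for illustration: the column names ("item_category", "user_id", "price", "age", "click"), the per-column shape of (batch_size, 1), and the (dict_of_tensors, labels) layout are assumptions, standing in for what a dict-style dataloader would yield.

# hypothetical batch of 32 rows: a dict of per-column tensors plus a label tensor
batch = (
    {
        "item_category": torch.randint(0, 10, (32, 1)),  # single-hot categorical
        "user_id": torch.randint(0, 100, (32, 1)),       # single-hot categorical
        "price": torch.rand(32, 1),                      # continuous
        "age": torch.rand(32, 1),                        # continuous
    },
    torch.randint(0, 2, (32, 1)).float(),                # labels
)

transform = DictTransform(cat_names=["item_category", "user_id"],
                          cont_names=["price", "age"],
                          label_names=["click"])
cats, conts, labels = transform.transform_with_label(batch)
# cats  -> (tensor of shape [32, 2], {})  dense categoricals plus an (empty) multi-hot dict
# conts -> tensor of shape [32, 2]        continuous features concatenated column-wise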