@minesh1291
Forked from codeKgu/model_loading.py
Created August 14, 2023 17:22
Tutorial for multimodal_transformers
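import numpy as np
from multimodal_transformers.model import AutoModelWithTabular, TabularConfig
from transformers import AutoConfig

# torch_dataset is assumed to be built in an earlier step of the tutorial
# (a dataset exposing .labels, .cat_feats and .numerical_feats).

# Number of target classes = number of distinct label values.
num_labels = len(np.unique(torch_dataset.labels))

# Base configuration of the pretrained text transformer.
config = AutoConfig.from_pretrained('bert-base-uncased')

# Describe the tabular inputs and how they are combined with the text features.
tabular_config = TabularConfig(
    num_labels=num_labels,
    cat_feat_dim=torch_dataset.cat_feats.shape[1],              # columns in the categorical feature matrix
    numerical_feat_dim=torch_dataset.numerical_feats.shape[1],  # columns in the numerical feature matrix
    combine_feat_method='weighted_feature_sum_on_transformer_cat_and_numerical_feats',
)

# Attach the tabular settings to the transformer config, then load the combined model.
config.tabular_config = tabular_config
model = AutoModelWithTabular.from_pretrained('bert-base-uncased', config=config)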
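The three sizes passed to TabularConfig are read directly off the dataset. The sketch below mirrors those expressions in plain Python so they can be checked without installing the library. ToyTabularDataset and tabular_sizes are hypothetical names, and plain lists stand in for the arrays: len(set(...)) plays the role of len(np.unique(...)), and the length of a row plays the role of .shape[1].

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class ToyTabularDataset:
    """Hypothetical stand-in for torch_dataset: only the attributes the snippet reads."""
    labels: List[int]
    cat_feats: List[List[float]]        # one row per sample, one column per encoded categorical value (e.g. one-hot)
    numerical_feats: List[List[float]]  # one row per sample, one column per numerical feature


def tabular_sizes(ds: ToyTabularDataset) -> Dict[str, int]:
    """Hypothetical helper returning the sizes the snippet passes to TabularConfig."""
    return {
        # counts distinct class ids, like len(np.unique(labels))
        "num_labels": len(set(ds.labels)),
        # number of columns in a row, like .shape[1] on a 2-D array
        "cat_feat_dim": len(ds.cat_feats[0]),
        "numerical_feat_dim": len(ds.numerical_feats[0]),
    }


ds = ToyTabularDataset(
    labels=[0, 2, 1, 2, 0],
    cat_feats=[[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 1, 0]],
    numerical_feats=[[3.5, 10.0], [1.2, 7.0], [0.4, 2.0], [2.2, 5.0], [9.9, 1.0]],
)
print(tabular_sizes(ds))
# → {'num_labels': 3, 'cat_feat_dim': 3, 'numerical_feat_dim': 2}
```

Note that num_labels counts distinct labels, not samples: five rows with labels 0, 1 and 2 give three classes.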
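As we understand the library's documentation, the combine_feat_method chosen here adds the categorical and numerical features to the transformer output after scaling each element-wise by a learned weight vector, roughly m = x + w_c ⊙ c + w_n ⊙ n. The following is a simplified sketch of only that final sum. It assumes c and n have already been projected to the transformer's output size, and it takes the weights as plain inputs rather than learned parameters. weighted_feature_sum is a hypothetical helper, not part of multimodal_transformers.

```python
from typing import List

Vector = List[float]


def weighted_feature_sum(text: Vector, cat: Vector, num: Vector,
                         w_cat: Vector, w_num: Vector) -> Vector:
    """Hypothetical sketch of m = text + w_cat * cat + w_num * num, element-wise.

    Assumes cat and num are already projected to len(text). In the library the
    projection and the weight vectors are learned; here they are plain inputs.
    """
    n = len(text)
    if not all(len(v) == n for v in (cat, num, w_cat, w_num)):
        raise ValueError("all vectors must have the transformer output size")
    # Scale each tabular vector by its weights, then add it to the text vector.
    return [t + wc * c + wn * nv
            for t, c, nv, wc, wn in zip(text, cat, num, w_cat, w_num)]


m = weighted_feature_sum(
    text=[1.0, 2.0, 3.0],
    cat=[1.0, 0.0, 2.0],
    num=[0.5, 0.5, 0.5],
    w_cat=[0.5, 0.5, 0.5],
    w_num=[2.0, 0.0, 1.0],
)
print(m)  # → [2.5, 2.0, 4.5]
```

Because the result keeps the transformer's output size, the classification head that follows sees a vector of the same width as it would for text alone, however many tabular columns there are.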