@kusal1990
Created June 25, 2022 07:09
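import tensorflow as tf
from tensorflow.keras.layers import Dense, Dropout, Input
from tensorflow.keras.models import Model
from transformers import TFXLNetForMultipleChoice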
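# easy_*_input_ids / easy_*_attention_mask (int arrays of shape (N, 5, 128)) and
# easy_*_labels (pandas Series of answer indices 0-4) are assumed to come from a
# preprocessing step that is not part of this gist.
easy_train_dict = {'input_tokens': easy_train_input_ids,
                   'attention_mask': easy_train_attention_mask}
# Pass num_classes explicitly so the one-hot width always matches the 5 choices,
# even if the largest answer index never occurs in the training labels.
viola = tf.data.Dataset.from_tensor_slices(
    (easy_train_dict,
     tf.keras.utils.to_categorical(easy_train_labels.values, num_classes=5)))
# Cache before shuffling so each epoch gets a fresh order instead of replaying
# the order that was cached during the first epoch.
viola = viola.cache().shuffle(32).batch(8).prefetch(tf.data.AUTOTUNE)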
easy_dev_dict = {'input_tokens': easy_dev_input_ids,
                 'attention_mask': easy_dev_attention_mask}
viola_dev = tf.data.Dataset.from_tensor_slices(
    (easy_dev_dict,
     tf.keras.utils.to_categorical(easy_dev_labels.values, num_classes=5)))
# The validation set does not need shuffling.
viola_dev = viola_dev.batch(8).cache().prefetch(tf.data.AUTOTUNE)
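pre_trained_model = TFXLNetForMultipleChoice.from_pretrained('xlnet-base-cased')

# Each example holds 5 answer choices, each padded/truncated to 128 tokens.
# The Input names must match the keys of the dataset dicts above, because
# Keras matches dict-style input data to model inputs by name.
model_input_ids = Input(shape=(5, 128), name='input_tokens', dtype='int32')
masks_input = Input(shape=(5, 128), name='attention_mask', dtype='int32')
x = {'input_ids': model_input_ids,
     'attention_mask': masks_input}
# TFXLNetForMultipleChoice already returns one score per choice: shape (batch, 5).
x = pre_trained_model(x)['logits']
# x = Dense(64, activation='relu', kernel_initializer='he_normal')(x)
# x = Dropout(0.2)(x)
# This Dense layer re-mixes the 5 choice scores and turns them into probabilities.
outputs = Dense(5, activation='softmax')(x)
model = Model(inputs=[model_input_ids, masks_input], outputs=outputs)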
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=5e-6),
              loss=tf.keras.losses.CategoricalCrossentropy(),
              metrics=['accuracy'])
# Evaluate on the dev set at the end of every epoch.
model.fit(viola, validation_data=viola_dev, epochs=4)
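Input(shape=(5, 128)) expects, for every example, 5 rows of 128 token ids (one row per answer choice) plus a 0/1 attention mask of the same shape. The gist does not show how easy_train_input_ids and easy_train_attention_mask were built, so the following is only a minimal pure-Python sketch of that padding step under stated assumptions: the helper name pad_choices is hypothetical, the token ids are made up, and pad_id=5 is an illustrative value (read the real one from the tokenizer's pad_token_id). Hugging Face's XLNet tokenizer pads on the left by default, which is why the sketch supports both sides.

```python
from typing import List, Tuple


def pad_choices(
    choice_token_ids: List[List[int]],
    pad_id: int,
    num_choices: int = 5,
    max_len: int = 128,
    pad_left: bool = False,
) -> Tuple[List[List[int]], List[List[int]]]:
    """Hypothetical helper: pad each choice to max_len and build its attention mask.

    Assumes the tokenizer has already truncated every choice to at most max_len tokens.
    """
    if len(choice_token_ids) != num_choices:
        raise ValueError(f"expected {num_choices} choices, got {len(choice_token_ids)}")
    ids_rows, mask_rows = [], []
    for ids in choice_token_ids:
        if len(ids) > max_len:
            raise ValueError(f"choice has {len(ids)} tokens, more than max_len={max_len}")
        n_pad = max_len - len(ids)
        padding = [pad_id] * n_pad
        mask = [1] * len(ids)        # 1 marks a real token
        no_attend = [0] * n_pad      # 0 marks padding the model should ignore
        if pad_left:
            ids_rows.append(padding + list(ids))
            mask_rows.append(no_attend + mask)
        else:
            ids_rows.append(list(ids) + padding)
            mask_rows.append(mask + no_attend)
    return ids_rows, mask_rows


# Toy example: 5 choices of 3 made-up token ids, left-padded to length 6.
ids, mask = pad_choices([[10, 11, 12]] * 5, pad_id=5, max_len=6, pad_left=True)
print(ids[0])   # [5, 5, 5, 10, 11, 12]
print(mask[0])  # [0, 0, 0, 1, 1, 1]
```

Stacking one such (5, max_len) pair per example yields the (N, 5, 128) arrays that from_tensor_slices slices along the first axis.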
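Passing num_classes=5 to tf.keras.utils.to_categorical matters for this model: without it, the one-hot width is inferred from the largest label present, so a label set that never contains answer index 4 produces only 4 columns and no longer matches the Dense(5) output. The sketch below is a hypothetical pure-Python stand-in for to_categorical on 1-D integer labels (the name one_hot is made up), written only to illustrate that width inference.

```python
from typing import List, Optional


def one_hot(labels: List[int], num_classes: Optional[int] = None) -> List[List[float]]:
    """Hypothetical stand-in for tf.keras.utils.to_categorical on 1-D integer labels."""
    # Like to_categorical, infer the width from the largest label when num_classes is omitted.
    if num_classes is None:
        num_classes = max(labels) + 1
    rows = []
    for label in labels:
        if not 0 <= label < num_classes:
            raise ValueError(f"label {label} out of range for {num_classes} classes")
        row = [0.0] * num_classes
        row[label] = 1.0
        rows.append(row)
    return rows


# Made-up labels from a 5-choice task where no example has answer index 4.
labels = [0, 2, 3, 1]
print(len(one_hot(labels)[0]))                 # 4: width inferred from the max label
print(len(one_hot(labels, num_classes=5)[0]))  # 5: matches Dense(5)
```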
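The model ends in Dense(5, activation='softmax') trained with CategoricalCrossentropy: it outputs one probability per answer choice, and the loss is the negative log-probability assigned to the correct choice. A minimal pure-Python sketch of that computation for a single example follows; the scores are made-up numbers, and the epsilon clipping only approximates what Keras does with probability inputs (it clips predictions with a small epsilon, 1e-7 by default, before taking the log).

```python
import math
from typing import List


def softmax(scores: List[float]) -> List[float]:
    # Subtract the max before exponentiating for numerical stability.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def categorical_crossentropy(target: List[float], probs: List[float], eps: float = 1e-7) -> float:
    # Clip probabilities into [eps, 1 - eps] so log never sees 0.
    return -sum(t * math.log(min(max(p, eps), 1.0 - eps)) for t, p in zip(target, probs))


scores = [2.0, 0.5, -1.0, 0.0, 1.0]   # one made-up score per answer choice
probs = softmax(scores)
predicted_choice = max(range(len(probs)), key=probs.__getitem__)
loss = categorical_crossentropy([1.0, 0.0, 0.0, 0.0, 0.0], probs)
print(predicted_choice)  # 0
```

At inference time the predicted answer is simply the argmax over the 5 probabilities, which is also what the 'accuracy' metric compares against the one-hot label.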
@kusal1990 (author) commented:

ok
