@tarekziade · Created February 22, 2024 09:35
T5 distillation with bert-squeeze

This gist distills a T5 summarization model on the BookSum dataset using bert-squeeze's DistilAssistant, with a callback that dynamically quantizes the resulting model.
from bert_squeeze.assistants import DistilAssistant
from lightning.pytorch import Trainer

# Configure the distillation run: teacher and student are both initialized
# from the same T5 summarization checkpoint, and the data module pulls the
# BookSum dataset, reading chapters as inputs and their summaries as targets.
config_assistant = {
    "teacher_kwargs": {
        "pretrained_model": "cnicu/t5-small-booksum",
    },
    "student_kwargs": {
        "pretrained_model": "cnicu/t5-small-booksum",
    },
    "data_kwargs": {
        "teacher_module": {
            "dataset_config": {
                "path": "kmfoda/booksum",
                "target_col": "summary_text",
                "source_col": "chapter",
            }
        }
    },
    "callbacks": [
        # Callback that applies dynamic quantization to the trained model.
        {"_target_": "bert_squeeze.utils.callbacks.quantization.DynamicQuantization"},
    ],
}

# The assistant wires up the model, callbacks, and dataloaders for a
# sequence-to-sequence distillation task.
assistant = DistilAssistant("distil-seq2seq", **config_assistant)

model = assistant.model
callbacks = assistant.callbacks
train_dataloader = assistant.data.train_dataloader()
test_dataloader = assistant.data.test_dataloader()

# Run a single-epoch training loop with Lightning.
basic_trainer = Trainer(
    max_epochs=1,
    callbacks=callbacks,
)
basic_trainer.fit(
    model=model,
    train_dataloaders=train_dataloader,
    val_dataloaders=test_dataloader,
)
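
Once training finishes, the student can be used for summarization. The sketch below is a minimal example, assuming the Lightning module exposes the distilled network as model.student (an assumed attribute name, not confirmed by this gist or the bert-squeeze docs) and that it is a Hugging Face seq2seq model compatible with the checkpoint's tokenizer.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("cnicu/t5-small-booksum")
student = model.student  # assumption: the distiller keeps the trained student here
student.eval()

chapter = "Some chapter text to summarize."
inputs = tokenizer(chapter, return_tensors="pt", truncation=True, max_length=512)
summary_ids = student.generate(**inputs, max_length=128, num_beams=4)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))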