# training_code.py
# Fine-prune bert-base-uncased on BoolQ (SuperGLUE) with nn_pruning, then
# compile the pruned model and benchmark its latency against the dense baseline.
from time import perf_counter

import numpy as np
import torch
import datasets
import transformers
from datasets import load_dataset, load_metric
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)
from nn_pruning.sparse_trainer import SparseTrainer
from nn_pruning.patch_coordinator import ModelPatchingCoordinator, SparseTrainingArguments
from nn_pruning.inference_model_patcher import optimize_model

datasets.logging.set_verbosity_error()
transformers.logging.set_verbosity_error()
print(f"Using transformers v{transformers.__version__}, datasets v{datasets.__version__}, torch v{torch.__version__}")
boolq = load_dataset("super_glue", "boolq")
print(boolq["train"][0])
# The Trainer expects the label column to be called "labels".
boolq = boolq.rename_column("label", "labels")
bert_ckpt = "bert-base-uncased"      # student model to prune
bert_teacher = "bert-large-uncased"  # teacher for (optional) distillation

tokenizer = AutoTokenizer.from_pretrained(bert_ckpt)

def tokenize_and_encode(examples):
    # Truncate only the passage so the question is always kept intact.
    return tokenizer(examples["question"], examples["passage"], truncation="only_second")

boolq_enc = boolq.map(tokenize_and_encode, batched=True)
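# Quick sanity check (optional sketch): the encoded dataset should now carry
# the tokenizer features alongside the renamed "labels" column.
print(boolq_enc["train"].column_names)  # expect question, passage, labels, input_ids, ...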
class PruningTrainer(SparseTrainer, Trainer):
    def __init__(self, sparse_args, *args, **kwargs):
        Trainer.__init__(self, *args, **kwargs)
        SparseTrainer.__init__(self, sparse_args)

    def compute_loss(self, model, inputs, return_outputs=False):
        """
        We override the default loss in SparseTrainer because it throws an
        error when run without distillation.
        """
        outputs = model(**inputs)
        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]
        # We don't use .loss here since the model may return tuples instead of ModelOutput.
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
        self.metrics["ce_loss"] += float(loss)
        self.loss_counter += 1
        return (loss, outputs) if return_outputs else loss
sparse_args = SparseTrainingArguments()
print(sparse_args)
hyperparams = {
    "initial_warmup": 1,
    "final_warmup": 3,
    "initial_threshold": 1.0,  # with topK this is the initial density; with sigmoied_threshold use 0.0 (initial cutoff)
    "final_threshold": 0.5,    # with topK this is the final density; with sigmoied_threshold use 0.1 (final cutoff; tune regularization_final_lambda to adjust final sparsity)
    "dense_pruning_method": "topK:1d_alt",  # or "sigmoied_threshold:1d_alt"
    "dense_block_rows": 1,
    "dense_block_cols": 1,
    "dense_lambda": 0.25,
    "attention_pruning_method": "topK",  # or "sigmoied_threshold"
    "attention_block_rows": 32,
    "attention_block_cols": 32,
    "attention_lambda": 1.0,
    "ampere_pruning_method": "disabled",
    "mask_init": "constant",
    "mask_scale": 0.0,
    "regularization": None,  # use "l1" when the pruning methods are sigmoied_threshold
    "regularization_final_lambda": 20,  # higher -> sparser; try doubling it a few times
    "distil_teacher_name_or_path": None,  # set to bert_teacher to enable distillation
    "distil_alpha_ce": 0.1,
    "distil_alpha_teacher": 0.9,
    "attention_output_with_dense": 0,
    "layer_norm_patch": 0,
    "gelu_patch": 0,
}
for k, v in hyperparams.items():
    if hasattr(sparse_args, k):
        setattr(sparse_args, k, v)
    else:
        print(f"sparse_args does not have argument {k}")

print("******************************** sparse args *******************", sparse_args)
batch_size = 4
learning_rate = 2e-5
num_train_epochs = 1
logging_steps = len(boolq_enc["train"]) // batch_size
# Warm up for 10% of the training steps.
warmup_steps = int(logging_steps * num_train_epochs * 0.1)
args = TrainingArguments(
    output_dir="checkpoints",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    learning_rate=learning_rate,
    weight_decay=0.01,
    warmup_steps=warmup_steps,
    logging_steps=logging_steps,
    disable_tqdm=False,
    report_to="none",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The coordinator wires the movement-pruning machinery (mask scores, pruning
# schedule, optional distillation against the teacher) into model and trainer.
mpc = ModelPatchingCoordinator(
    sparse_args=sparse_args,
    device=device,
    cache_dir="checkpoints",
    logit_names="logits",
    model_name_or_path=bert_teacher,
    teacher_constructor=AutoModelForSequenceClassification,
)

bert_model = AutoModelForSequenceClassification.from_pretrained(bert_ckpt).to(device)
mpc.patch_model(bert_model)
bert_model.save_pretrained("models/patched")
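# Sanity check (a sketch, assuming nn_pruning stores its learnable mask scores
# in parameters whose names contain "mask"): count them to confirm patching worked.
n_mask_params = sum(p.numel() for n, p in bert_model.named_parameters() if "mask" in n)
print(f"Mask parameters added by patching: {n_mask_params}")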
accuracy_score = load_metric("accuracy")

def compute_metrics(pred):
    predictions, labels = pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy_score.compute(predictions=predictions, references=labels)
trainer = PruningTrainer(
    sparse_args=sparse_args,
    args=args,
    model=bert_model,
    train_dataset=boolq_enc["train"],
    eval_dataset=boolq_enc["validation"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
trainer.set_patch_coordinator(mpc)
trainer.train()
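# Optional: report validation accuracy of the fine-pruned (still masked) model
# before compilation, using the compute_metrics defined above.
print("Validation metrics:", trainer.evaluate())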
output_model_path = "models/bert-base-uncased-finepruned-boolq-highly-sparse-v2"
trainer.save_model(output_model_path)

# Freeze the learned masks into the weights, then strip the empty rows/columns
# so the surviving layers become smaller dense matrices.
mpc.compile_model(trainer.model)
prunebert_model = optimize_model(trainer.model, "dense")
print(" ############# Parameter downsize ############# ", prunebert_model.num_parameters() / bert_model.num_parameters())
def compute_latencies(model,
                      question="Is Saving Private Ryan based on a book?",
                      passage="""In 1994, Robert Rodat wrote the script for the film. Rodat’s script was submitted to
producer Mark Gordon, who liked it and in turn passed it along to Spielberg to direct. The film is
loosely based on the World War II life stories of the Niland brothers. A shooting date was set for
June 27, 1997"""):
    inputs = tokenizer(question, passage, truncation="only_second", return_tensors="pt")
    latencies = []
    model.eval()
    with torch.no_grad():
        # Warmup
        for _ in range(10):
            _ = model(**inputs)
        # Timed runs
        for _ in range(100):
            start_time = perf_counter()
            _ = model(**inputs)
            latencies.append(perf_counter() - start_time)
    # Compute run statistics
    time_avg_ms = 1000 * np.mean(latencies)
    time_std_ms = 1000 * np.std(latencies)
    print(f"Average latency (ms) - {time_avg_ms:.2f} +/- {time_std_ms:.2f}")
    return {"time_avg_ms": time_avg_ms, "time_std_ms": time_std_ms}
latencies = {}
latencies["prunebert"] = compute_latencies(prunebert_model.to("cpu"))
print("############# Latency of the pruned model #############", latencies["prunebert"])

bert_unpruned = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased").to("cpu")
latencies["bert-base"] = compute_latencies(bert_unpruned.to("cpu"))
print("############# Latency of the pretrained model #############", latencies["bert-base"])