@HamidShojanazeri
Created September 9, 2021 05:07
training_code.py
import torch
import datasets
import transformers
from datasets import load_dataset
from transformers import AutoTokenizer
from transformers import Trainer
from nn_pruning.sparse_trainer import SparseTrainer
from nn_pruning.patch_coordinator import SparseTrainingArguments
from transformers import TrainingArguments
from transformers import AutoModelForSequenceClassification
from nn_pruning.patch_coordinator import ModelPatchingCoordinator
import numpy as np
from datasets import load_metric
from nn_pruning.inference_model_patcher import optimize_model
datasets.logging.set_verbosity_error()
transformers.logging.set_verbosity_error()
print(f"Using transformers v{transformers.__version__} and datasets v{datasets.__version__} and torch v{torch.__version__}")
boolq = load_dataset("super_glue", "boolq")
print(boolq['train'][0])
boolq = boolq.rename_column("label", "labels")  # rename_column_ (in-place) is deprecated; rename_column returns a new DatasetDict
bert_ckpt = "bert-base-uncased"
bert_teacher = "bert-large-uncased"
tokenizer = AutoTokenizer.from_pretrained(bert_ckpt)
def tokenize_and_encode(examples):
    return tokenizer(examples['question'], examples['passage'], truncation="only_second")

boolq_enc = boolq.map(tokenize_and_encode, batched=True)
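# Optional sanity check (not part of the original gist): confirm that tokenization added the
# expected feature columns ("input_ids", "attention_mask", ...) alongside "labels".
print(boolq_enc["train"].column_names)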
class PruningTrainer(SparseTrainer, Trainer):
    def __init__(self, sparse_args, *args, **kwargs):
        Trainer.__init__(self, *args, **kwargs)
        SparseTrainer.__init__(self, sparse_args)

    def compute_loss(self, model, inputs, return_outputs=False):
        """
        We override the default loss in SparseTrainer because it throws an
        error when run without distillation.
        """
        outputs = model(**inputs)
        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]
        # We don't use .loss here since the model may return tuples instead of ModelOutput.
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
        self.metrics["ce_loss"] += float(loss)
        self.loss_counter += 1
        return (loss, outputs) if return_outputs else loss
sparse_args = SparseTrainingArguments()
print(sparse_args)
hyperparams = {
    "initial_warmup": 1,
    "final_warmup": 3,
    "initial_threshold": 1.0,  # When using topK, set to 1 (initial density). With sigmoied_threshold, use 0.0 (cutoff).
    "final_threshold": 0.5,  # When using topK, this is the final density. With sigmoied_threshold, use 0.1 (final cutoff, somewhat arbitrary; set regularization_final_lambda to adjust final sparsity).
    "dense_pruning_method": "topK:1d_alt",  # or "sigmoied_threshold:1d_alt"
    "dense_block_rows": 1,
    "dense_block_cols": 1,
    "dense_lambda": 0.25,
    "attention_pruning_method": "topK",  # or "sigmoied_threshold"
    "attention_block_rows": 32,
    "attention_block_cols": 32,
    "attention_lambda": 1.0,
    "ampere_pruning_method": "disabled",
    "mask_init": "constant",
    "mask_scale": 0.0,
    "regularization": None,  # use "l1" when the pruning methods are sigmoied_threshold
    "regularization_final_lambda": 20,  # Tweak to adjust sparsity: the higher, the sparser. Try doubling it a few times.
    "distil_teacher_name_or_path": None,
    "distil_alpha_ce": 0.1,
    "distil_alpha_teacher": 0.9,
    "attention_output_with_dense": 0,
    "layer_norm_patch": 0,
    "gelu_patch": 0,
}
for k, v in hyperparams.items():
    if hasattr(sparse_args, k):
        setattr(sparse_args, k, v)
    else:
        print(f"sparse_args does not have argument {k}")

print("Sparse args after applying the hyperparameters:", sparse_args)
batch_size = 4
learning_rate = 2e-5
num_train_epochs = 1
logging_steps = len(boolq_enc["train"]) // batch_size
# warmup for 10% of training steps
warmup_steps = int(logging_steps * num_train_epochs * 0.1)
args = TrainingArguments(
    output_dir="checkpoints",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    learning_rate=learning_rate,
    weight_decay=0.01,
    warmup_steps=warmup_steps,  # use the 10% warmup computed above
    logging_steps=logging_steps,
    disable_tqdm=False,
    report_to=None,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mpc = ModelPatchingCoordinator(
    sparse_args=sparse_args,
    device=device,
    cache_dir="checkpoints",
    logit_names="logits",
    model_name_or_path=bert_teacher,
    teacher_constructor=AutoModelForSequenceClassification,
)
bert_model = AutoModelForSequenceClassification.from_pretrained(bert_ckpt).to(device)
mpc.patch_model(bert_model)
bert_model.save_pretrained("models/patched")
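# Optional inspection (not part of the original gist): patching should add learnable
# mask-related parameters to the encoder. Counting parameter names that contain "mask"
# is a cheap way to confirm the patch took effect; the exact naming is an nn_pruning detail.
mask_params = [n for n, _ in bert_model.named_parameters() if "mask" in n]
print(f"Mask-related parameters added by patching: {len(mask_params)}")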
accuracy_score = load_metric('accuracy')
def compute_metrics(pred):
    predictions, labels = pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy_score.compute(predictions=predictions, references=labels)
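# Optional smoke test (not part of the original gist): compute_metrics expects a
# (logits, labels) pair; this tiny fabricated batch should report an accuracy of 1.0.
print(compute_metrics((np.array([[0.1, 0.9], [0.8, 0.2]]), np.array([1, 0]))))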
trainer = PruningTrainer(
    sparse_args=sparse_args,
    args=args,
    model=bert_model,
    train_dataset=boolq_enc["train"],
    eval_dataset=boolq_enc["validation"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    # load_best_model_at_end and metric_for_best_model are already set on TrainingArguments above
)
trainer.set_patch_coordinator(mpc)
trainer.train()
output_model_path = "models/bert-base-uncased-finepruned-boolq-highly-sparse-v2"
trainer.save_model(output_model_path)
mpc.compile_model(trainer.model)
prunebert_model = optimize_model(trainer.model, "dense")
param_ratio = prunebert_model.num_parameters() / bert_model.num_parameters()
print(f"Parameter count ratio (pruned / original): {param_ratio:.3f}")
from time import perf_counter
def compute_latencies(model,
                      question="Is Saving Private Ryan based on a book?",
                      passage="""In 1994, Robert Rodat wrote the script for the film. Rodat’s script was submitted to
producer Mark Gordon, who liked it and in turn passed it along to Spielberg to direct. The film is
loosely based on the World War II life stories of the Niland brothers. A shooting date was set for
June 27, 1997"""):
    inputs = tokenizer(question, passage, truncation="only_second", return_tensors="pt")
    latencies = []
    # Run inference without autograd bookkeeping so it does not skew the timings
    with torch.no_grad():
        # Warmup
        for _ in range(10):
            _ = model(**inputs)
        # Timed runs
        for _ in range(100):
            start_time = perf_counter()
            _ = model(**inputs)
            latency = perf_counter() - start_time
            latencies.append(latency)
    # Compute run statistics
    time_avg_ms = 1000 * np.mean(latencies)
    time_std_ms = 1000 * np.std(latencies)
    print(f"Average latency (ms) - {time_avg_ms:.2f} +/- {time_std_ms:.2f}")
    return {"time_avg_ms": time_avg_ms, "time_std_ms": time_std_ms}
latencies = {}
latencies["prunebert"] = compute_latencies(prunebert_model.to("cpu"))
print("############# Latency of the pruned model #############", latencies["prunebert"])
bert_unpruned = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased").to("cpu")
latencies["bert-base"] = compute_latencies(bert_unpruned.to("cpu"))
print("############# Latency of the pretrained model #############", latencies["bert-base"])