import torch
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch._export import capture_pre_autograd_graph
from torch.export import Dim

# Fold weights into the compiled graph and emit per-kernel timings
# from the generated C++ kernels.
torch._inductor.config.freezing = True
torch._inductor.config.cpp.enable_kernel_profile = True
if __name__ == "__main__":
    # Load the albert-base-v1 config/tokenizer/model set up for MRPC
    # (binary sentence-pair classification).
    config = AutoConfig.from_pretrained(
        'albert-base-v1',
        num_labels=2,
        finetuning_task='mrpc',
        cache_dir=None,
        revision='main',
        token=None,
        trust_remote_code=False,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        'albert-base-v1',
        cache_dir=None,
        use_fast=True,
        revision='main',
        token=None,
        trust_remote_code=False,
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        'albert-base-v1',
        from_tf=False,
        config=config,
        cache_dir=None,
        revision='main',
        token=None,
        trust_remote_code=False,
        ignore_mismatched_sizes=False,
    )
    # Use the default dynamic int8 recipe of the x86 Inductor quantizer,
    # and explicitly register aten.matmul with the same global qconfig
    # (_set_aten_operator_qconfig is a private API).
    quantizer = xiq.X86InductorQuantizer()
    quantizer.set_global(xiq.get_default_x86_inductor_quantization_config(is_dynamic=True))
    quantizer._set_aten_operator_qconfig(torch.ops.aten.matmul.default, quantizer.global_config)
    # In a real run the example batch would come from the MRPC dataloader,
    # e.g. example_inputs = dict(next(iter(dataloader))); random tensors with
    # the same shapes and dtypes are enough for export (batch 64, seq len 16).
    example_inputs = {
        "labels": torch.randint(0, 2, (64,), dtype=torch.int64),
        "input_ids": torch.randint(0, 1000, (64, 16), dtype=torch.int64),
        "token_type_ids": torch.randint(0, 2, (64, 16), dtype=torch.int64),
        "attention_mask": torch.randint(0, 2, (64, 16), dtype=torch.int64),
    }
    # Create one torch.export Dim per distinct size in the example batch and
    # tag every input dimension with its Dim so export treats it as dynamic.
    # Caveat: sizes act as keys, so a shape with a repeated size (e.g. (64, 64))
    # would collide in v.index(d) below.
    input_shapes = {k: list(v.shape) for (k, v) in example_inputs.items()}
    print("input_shapes is: {}".format(input_shapes), flush=True)
    dims = set()
    for v in input_shapes.values():
        dims.update(v)
    dim_str_map = {d: Dim("dim{}".format(i)) for i, d in enumerate(sorted(dims))}
    dynamic_shapes = {
        k: {v.index(d): dim_str_map[d] for d in v} for (k, v) in input_shapes.items()
    }
    # Keep only the sizes that also occur in the labels shape, i.e. the batch
    # dimension; the sequence-length dimension stays static.
    if "labels" in dynamic_shapes:
        for k in dynamic_shapes:
            if k == "labels":
                continue
            for d in input_shapes[k]:
                if d not in input_shapes["labels"]:
                    del dynamic_shapes[k][input_shapes[k].index(d)]
    print("dynamic_shapes is: {}".format(dynamic_shapes), flush=True)
    with torch.no_grad():
        print("======================= export model ===============================")
        model = model.eval()
        # Capture the pre-autograd ATen graph with the batch dimension dynamic.
        exported_model = capture_pre_autograd_graph(
            model,
            (),
            kwargs=example_inputs,
            dynamic_shapes=dynamic_shapes,
        )
        print("---- exported_model is: {}".format(exported_model), flush=True)