# NOTE: the lines below are GitHub gist page text captured with the source; they are not part of the program.
# Save manishghop/d4e51bb95c491229c5d94b4c7ca04491 to your computer and use it in GitHub Desktop.
# This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters
import io
from io import BytesIO
from pathlib import Path
from typing import List, Tuple

import torch
import torch_mlir
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
from transformers.generation import GenerationConfig

from shark.shark_downloader import download_public_file
from shark.shark_importer import get_f16_inputs
from shark.shark_importer import import_with_fx, save_mlir
from shark.shark_inference import SharkInference
# Build configuration for the Qwen-7B-Chat export performed by this script.

# Selects the weight-quantization branch in Qwen ("int4"/"int8") and feeds
# the is_f16 flag computed later in the script.
precision = "int4"
# NOTE(review): this rebinds the name `save_mlir`, which is also imported as a
# function from shark.shark_importer; the function is never called in this
# script, so the rebinding is harmless here but easy to trip over.
save_mlir = True
# File the MLIR bytecode is written to and that SharkInference is pointed at.
model_path = "qwen-7b-int4.mlir"
# Device string passed to SharkInference.
device = "cpu-task"
# Forwarded as `debug=` to SharkInference.save_module.
debug = True
# fmt: off
def quant〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
    """Compute the result shape for the ``matmul_rhs_group_quant`` op.

    Only the ``lhs`` and ``rhs`` shapes are consulted; the scale, zero-point,
    bit-width and group-size arguments are accepted but not read.

    Args:
        lhs: Shape of the left operand; rank 2 or rank 3 is supported.
        rhs: Shape of the right operand; must be rank 2.

    Returns:
        The leading dimension(s) of ``lhs`` (all but the last) followed by
        ``rhs[0]``.

    Raises:
        ValueError: For any other combination of ranks.
    """
    # The branch structure is kept explicit (no slicing tricks) so the
    # function stays simple to compile when handed to torch_mlir as part of
    # an extra library.
    if len(lhs) == 3 and len(rhs) == 2:
        return [lhs[0], lhs[1], rhs[0]]
    elif len(lhs) == 2 and len(rhs) == 2:
        return [lhs[0], rhs[0]]
    else:
        raise ValueError("Input shapes not supported.")
def quant〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
    """Compute the result dtype for the ``matmul_rhs_group_quant`` op.

    Each ``*_rank_dtype`` argument is a ``(rank, dtype)`` pair. Only the pair
    for ``lhs`` is read; every other argument is accepted but ignored.

    Returns:
        The dtype element of ``lhs_rank_dtype``: the output takes the dtype
        of the lhs float input.
    """
    # The rank half of the pair is not needed for dtype inference.
    _lhs_rank, lhs_dtype = lhs_rank_dtype
    return lhs_dtype
def quant〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
    """Marker entry for the ``matmul_rhs_group_quant`` op; does nothing.

    The body is intentionally empty: none of the arguments are read and the
    function always returns ``None``.
    NOTE(review): presumably the mere presence of this function in the extra
    library is what tells torch-mlir the op has value semantics -- confirm
    against the torch-mlir custom-op documentation.
    """
    return
# Shape, dtype and value-semantics functions for the group-quantized matmul
# op, bundled so they can be passed together as `extra_library=` to
# torch_mlir.compile.
brevitas_matmul_rhs_group_quant_library = [
    quant〇matmul_rhs_group_quant〡shape,
    quant〇matmul_rhs_group_quant〡dtype,
    quant〇matmul_rhs_group_quant〡has_value_semantics,
]
# fmt: on
class Qwen(torch.nn.Module):
    """Wrapper around the ``Qwen/Qwen-7B-Chat`` causal LM that returns logits only.

    The wrapped Hugging Face model is loaded in float32 and, when
    ``precision`` is ``"int4"`` or ``"int8"``, its weights are quantized in
    place with brevitas' ``quantize_model``.
    """

    def __init__(self, precision):
        """Load the pretrained model and optionally quantize its weights.

        Args:
            precision: ``"int4"`` or ``"int8"`` enables weight quantization
                with a bit width of 4 or 8 respectively; any other value
                leaves the float32 model untouched.

        Note:
            The quantization branch reads the module-level name
            ``weight_group_size`` when it runs, so that name must be
            assigned before a quantized ``Qwen`` is instantiated.
        """
        super().__init__()
        # trust_remote_code=True allows transformers to run the modelling
        # code shipped in the model repository.
        self.model = AutoModelForCausalLM.from_pretrained(
            "Qwen/Qwen-7B-Chat",
            trust_remote_code=True,
            torch_dtype=torch.float32,
        )
        print("Model before quantization: ", self.model)
        if precision in ["int4", "int8"]:
            # Imported lazily so brevitas is only required when quantizing.
            from brevitas_examples.common.generative.quantize import (
                quantize_model,
            )

            print("Applying weight quantization..")
            weight_bit_width = 4 if precision == "int4" else 8
            quantize_model(
                self.model,
                dtype=torch.float16,
                weight_bit_width=weight_bit_width,
                weight_param_method="stats",
                weight_scale_precision="float_scale",
                weight_quant_type="asym",
                weight_quant_granularity="per_group",
                weight_group_size=weight_group_size,
                quantize_weight_zero_point=False,
            )
            print("Weight quantization applied.")
            print("Model after quantization: ", self.model)

    def forward(self, input_ids):
        """Run the wrapped model on ``input_ids`` and return its ``.logits``."""
        return self.model(input_ids).logits
# ---------------------------------------------------------------------------
# Export script: trace Qwen with a sample prompt, lower it through torch-mlir
# to linalg-on-tensors, write the MLIR bytecode to `model_path`, then compile
# that file with SHARK for `device`.
# ---------------------------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)

# Read by Qwen.__init__ as a module-level name when quantizing; it must stay
# defined at module scope before the model is instantiated below.
weight_group_size = 64

print("[DEBUG] generating mlir on device")
compilation_prompt = "你好"
input_ids = tokenizer(compilation_prompt, return_tensors="pt").input_ids
# return_tensors="pt" already yields a tensor; clone().detach() makes the same
# independent copy that torch.tensor(tensor) would, without triggering
# PyTorch's copy-construct UserWarning.
input_ids = input_ids.clone().detach()
inputs = (input_ids,)

print("[DEBUG] generating torchscript graph")
is_f16 = precision in ["fp16", "int4"]
model = Qwen(precision)
ts_graph = import_with_fx(
    model,
    inputs,
    is_f16=is_f16,
    precision=precision,
    f16_input_mask=[False, False],
    mlir_type="torchscript",
)

print("[DEBUG] Compiling torchscript graph")
module = torch_mlir.compile(
    ts_graph,
    inputs,
    output_type=torch_mlir.OutputType.TORCH,
    # Keep the custom quantized matmul op intact at this stage; it is
    # converted by the dedicated passes in the pipeline below.
    backend_legal_ops=["quant.matmul_rhs_group_quant"],
    extra_library=brevitas_matmul_rhs_group_quant_library,
    use_tracing=False,
    verbose=False,
)

print("[DEBUG] Lowering Torch -> Linalg")
run_pipeline_with_repro_report(
    module,
    "builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
    description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
)
print("[DEBUG] Successfully Generated mlir on device")

print("[DEBUG] converting to bytecode")
bytecode_stream = BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
# Drop the in-memory module before writing/compiling to lower peak memory.
del module
# Context manager guarantees the file is closed even if the write fails.
with open(model_path, "wb") as f_:
    f_.write(bytecode)
print("Saved qwen mlir at ", str(model_path))
del bytecode_stream, bytecode

print(f"Compiling for device : {device}")
shark_module = SharkInference(
    mlir_module=model_path,
    device=device,
    mlir_dialect="tm_tensor",
    device_idx=None,
)
# save_module is given the output directory and module name; its return
# value is printed as the location of the compiled artifact.
path = shark_module.save_module(
    "./",
    "qwen",
    extra_args=[],
    debug=debug,
)
print("Saved qwen vmfb at ", str(path))
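# NOTE: the lines below are GitHub page footer text captured with the source; they are not part of the program.
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment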