from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
import torch
import torch_mlir
from transformers.generation import GenerationConfig
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
from typing import List, Tuple
from pathlib import Path
from shark.shark_downloader import download_public_file
from shark.shark_importer import get_f16_inputs
from shark.shark_importer import import_with_fx, save_mlir
from shark.shark_inference import SharkInference
import io
from io import BytesIO
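# Configuration: quantization precision, where to write the generated MLIR,
# the SHARK/IREE target device, and a debug flag for compilation.
# Note: the `save_mlir = True` flag below shadows the `save_mlir` helper
# imported from shark.shark_importer above.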
precision = "int4"
save_mlir = True
model_path = "qwen-7b-int4.mlir"
device = "cpu-task"
debug = True
# fmt: off
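# torch-mlir abstract interpretation functions for the custom
# quant.matmul_rhs_group_quant op: they tell the compiler the result shape and
# dtype of the group-quantized matmul and that it has value semantics, so the
# op can be kept backend-legal during lowering (passed via `extra_library`).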
def quant〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
    if len(lhs) == 3 and len(rhs) == 2:
        return [lhs[0], lhs[1], rhs[0]]
    elif len(lhs) == 2 and len(rhs) == 2:
        return [lhs[0], rhs[0]]
    else:
        raise ValueError("Input shapes not supported.")


def quant〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
    # output dtype is the dtype of the lhs float input
    lhs_rank, lhs_dtype = lhs_rank_dtype
    return lhs_dtype


def quant〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
    return


brevitas_matmul_rhs_group_quant_library = [
    quant〇matmul_rhs_group_quant〡shape,
    quant〇matmul_rhs_group_quant〡dtype,
    quant〇matmul_rhs_group_quant〡has_value_semantics,
]
# fmt: on
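# Wrapper around the Hugging Face Qwen-7B-Chat model. For int4/int8 precision,
# Brevitas applies weight-only, per-group, asymmetric quantization in place.
# `weight_group_size` is read from the module-level global defined below.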
class Qwen(torch.nn.Module):
    def __init__(self, precision):
        super().__init__()
        kwargs = {"torch_dtype": torch.float32}
        self.model = AutoModelForCausalLM.from_pretrained(
            "Qwen/Qwen-7B-Chat", trust_remote_code=True, **kwargs
        )
        print("Model before quantization: ", self.model)
        if precision in ["int4", "int8"]:
            from brevitas_examples.common.generative.quantize import (
                quantize_model,
            )
            from brevitas_examples.llm.llm_quant.run_utils import (
                get_model_impl,
            )

            print("Applying weight quantization..")
            weight_bit_width = 4 if precision == "int4" else 8
            quantize_model(
                self.model,
                dtype=torch.float16,
                weight_bit_width=weight_bit_width,
                weight_param_method="stats",
                weight_scale_precision="float_scale",
                weight_quant_type="asym",
                weight_quant_granularity="per_group",
                weight_group_size=weight_group_size,
                quantize_weight_zero_point=False,
            )
            print("Weight quantization applied.")
        print("Model after quantization: ", self.model)

    def forward(self, input_ids):
        output = self.model(input_ids).logits
        return output
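# Tokenizer and module-level quantization hyperparameters (weight_group_size is
# the global consumed inside Qwen.__init__ above).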
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
accumulates = torch.float16
weight_bit_width = 4
weight_group_size = 64
print("[DEBUG] generating mlir on device")
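# "你好" ("Hello") is only used to produce example input_ids of the right
# shape and dtype for tracing and compilation.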
compilation_prompt = "你好"
input_ids = tokenizer(compilation_prompt, return_tensors="pt").input_ids
# input_ids, attention_mask = inputs.data["input_ids"], inputs.data["attention_mask"]
# input_ids is already a torch.Tensor from the tokenizer; no conversion needed.
inputs = (input_ids,)
print("[DEBUG] generating torchscript graph")
is_f16 = precision in ["fp16", "int4"]
model = Qwen(precision)
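# Trace the quantized model with torch.fx and import it as a TorchScript-level
# MLIR graph. f16_input_mask marks which example inputs should be cast to fp16
# (token ids stay integer here).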
ts_graph = import_with_fx(
    model,
    inputs,
    is_f16=is_f16,
    precision=precision,
    f16_input_mask=[False, False],
    mlir_type="torchscript",
)
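# Compile the traced graph to the Torch dialect, keeping the custom
# quant.matmul_rhs_group_quant op backend-legal and resolving its shape/dtype
# through the extra library defined above.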
print("[DEBUG] Compiling torchscript graph")
module = torch_mlir.compile(
    ts_graph,
    inputs,
    output_type=torch_mlir.OutputType.TORCH,
    backend_legal_ops=["quant.matmul_rhs_group_quant"],
    extra_library=brevitas_matmul_rhs_group_quant_library,
    use_tracing=False,
    verbose=False,
)
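# Lower Torch-dialect IR to linalg-on-tensors: unpack the packed quantized
# weight tensors and convert the custom quant matmul op along the way.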
print("[DEBUG] Lowering Torch -> Linalg")
run_pipeline_with_repro_report(
    module,
    "builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
    description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
)
print("[DEBUG] Successfully Generated mlir on device")
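# Serialize the MLIR module to bytecode, write it to disk, and free the
# in-memory copies.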
print("[DEBUG] converting to bytecode")
bytecode_stream = BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
del module
with open(model_path, "wb") as f_:
    f_.write(bytecode)
print("Saved Qwen mlir at ", str(model_path))
del bytecode
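# Construct a descriptive vmfb filename from the MLIR filename, precision, and
# device. (save_module below is called with the module name "qwen", so this
# constructed path is informational only.)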
model_name = model_path.split(".")[0]
differentiator = ""
suffix = "vmfb"
vmfb_model_path = Path(f"{model_name}_{precision}_{device}{differentiator}.{suffix}")
print(f"Compiling for device : {device}")
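# Compile the tm_tensor-dialect MLIR to an IREE vmfb for the selected device
# and save it to the current directory.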
shark_module = SharkInference(
    mlir_module=model_path,
    device=device,
    mlir_dialect="tm_tensor",
    device_idx=None,
)
path = shark_module.save_module(
    "./",
    "qwen",
    extra_args=[],
    debug=debug,
)
print("Saved Qwen vmfb at ", str(path))
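# Optional smoke test: a minimal sketch that assumes SharkInference exposes
# load_module() and a callable "forward" entry point returning host logits,
# as in other SHARK LLM examples, and that the compiled module expects inputs
# with the same shape it was traced with.
import numpy as np

shark_module.load_module(path)
test_ids = tokenizer(compilation_prompt, return_tensors="pt").input_ids
result = shark_module("forward", (test_ids.numpy(),))
# Some runtimes return a tuple of outputs; take the first (the logits).
logits = result[0] if isinstance(result, (list, tuple)) else result
next_token_id = int(np.argmax(logits[0, -1, :]))
print("Predicted next token:", tokenizer.decode([next_token_id]))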