Skip to content

Instantly share code, notes, and snippets.

@manishghop
Created November 22, 2023 08:37
Show Gist options
  • Save manishghop/fb3ae898b4ea2c8c6bb404efac408c0e to your computer and use it in GitHub Desktop.
"""Load a pre-compiled Qwen-7B int4 VMFB with SHARK and run one forward pass.

The model itself is not compiled here; it is loaded from an existing
`.vmfb` artifact (`qwen-7b-int4.vmfb`) expected in the working directory.
"""
from shark.shark_inference import SharkInference
from transformers import AutoTokenizer
import torch

# trust_remote_code=True is required: the Qwen repo ships its own tokenizer code.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
compilation_prompt = "你好"

# With return_tensors="pt" the tokenizer already returns a torch.Tensor of
# shape [1, seq_len], so no re-wrapping is needed. The original code called
# torch.tensor() on this existing tensor, which makes a redundant copy and
# raises a UserWarning (PyTorch recommends .clone().detach() for copies),
# and then reshaped it to the shape it already had.
input_ids = tokenizer(compilation_prompt, return_tensors="pt").input_ids
print(input_ids, input_ids.shape)
inputs = (input_ids,)

vmfb_path = "qwen-7b-int4.vmfb"
device = "cpu"
mlir_dialect = "tm_tensor"
device_id = None  # None lets SHARK pick the default device index.

# First positional arg (mlir_module) is None because the compiled module is
# supplied via load_module() from the .vmfb file instead of MLIR source.
shark_module = SharkInference(
    None, device=device, mlir_dialect=mlir_dialect, device_idx=device_id
)
shark_module.load_module(vmfb_path)
print("Successfully Loaded vmfb model")

output = shark_module.forward(inputs)
print(output)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment