Skip to content

Instantly share code, notes, and snippets.

@mht-sharma
Created April 4, 2024 14:40
Show Gist options
  • Save mht-sharma/290f7bf9052e92023b4136c6fefd6717 to your computer and use it in GitHub Desktop.
Save mht-sharma/290f7bf9052e92023b4136c6fefd6717 to your computer and use it in GitHub Desktop.
Llava optimum ONNX inference
import os
from typing import List, Optional, Tuple
import onnxruntime as onnxrt
import requests
import torch
from PIL import Image
from transformers import AutoConfig, AutoProcessor, GenerationConfig, PreTrainedModel
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import BaseModelOutput, CausalLMOutputWithPast
from optimum.utils import NormalizedConfigManager
# Silence TensorFlow's C++ logging in case TensorFlow gets loaded by a dependency.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
# All tensors returned by the ONNX wrappers below are placed on this device.
device = torch.device("cpu")
# Directory holding the HF processor/config files and the exported ONNX graphs.
# Keep the trailing slash in case paths are built by plain string concatenation.
model_name = "llava-1.5-7b-hf/"
processor = AutoProcessor.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name)

prompt = "<image>\nUSER: What's the content of the image?\nASSISTANT:"
url = "https://www.ilankelman.org/stopsigns/australia.jpg"

# Download the sample image. Bound the wait so a stalled server cannot hang the
# script forever, and fail loudly on an HTTP error instead of handing an error
# page to PIL. The response is deliberately not closed here: Image.open reads
# lazily, so the stream must stay open until the processor loads the pixels.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)

# Tokenized prompt + preprocessed image as PyTorch tensors for model.generate().
inputs = processor(text=prompt, images=image, return_tensors="pt")
class ORTModel(torch.nn.Module):
    """Thin ``torch.nn.Module`` wrapper around one ONNX Runtime inference session.

    Subclasses implement ``forward`` by feeding numpy arrays to ``self.session``
    and picking results out of ``session.run`` by name via ``self.output_names``.

    Args:
        path: Filesystem path of the ``.onnx`` file to load.
        config: Model configuration object, kept as ``self.config``.
    """

    def __init__(self, path, config):
        super().__init__()
        # Tensors produced by subclasses are moved to the module-level ``device``.
        self._device = device
        self.config = config
        session = onnxrt.InferenceSession(path, providers=["CPUExecutionProvider"])
        self.session = session
        # Map every graph input/output name to its position, so the list returned
        # by ``session.run(None, ...)`` can be indexed by output name.
        self.input_names = {}
        for position, node in enumerate(session.get_inputs()):
            self.input_names[node.name] = position
        self.output_names = {}
        for position, node in enumerate(session.get_outputs()):
            self.output_names[node.name] = position
class ORTEncoder(ORTModel):
    """Runs the exported multimodal encoder graph.

    ``ORTModelForLLava`` loads it from ``encoder_model.onnx`` and uses it on the
    first generation step, turning the tokenized prompt and image into the
    inputs of the language-model decoder graph.
    """

    def forward(
        self,
        input_ids: torch.LongTensor,
        pixel_values: torch.FloatTensor,
        attention_mask: torch.LongTensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the encoder graph on one batch.

        Args:
            input_ids: Token ids produced by the processor.
            pixel_values: Preprocessed image tensor produced by the processor.
            attention_mask: Attention mask for ``input_ids``.
            **kwargs: Ignored.

        Returns:
            ``(inputs_embeds, decoder_attention_mask, position_ids)`` as torch
            tensors on ``self._device``, selected from the graph outputs by name.
        """
        onnx_inputs = {
            "input_ids": input_ids.cpu().detach().numpy(),
            "pixel_values": pixel_values.cpu().detach().numpy(),
            "attention_mask": attention_mask.cpu().detach().numpy(),
        }
        # Run inference; outputs come back as numpy arrays in graph output order.
        outputs = self.session.run(None, onnx_inputs)
        outputs = [torch.from_numpy(output).to(self._device) for output in outputs]
        return (
            outputs[self.output_names["inputs_embeds"]],
            outputs[self.output_names["decoder_attention_mask"]],
            outputs[self.output_names["position_ids"]],
        )
class ORTDecoderProcessor(ORTModel):
    """Runs the exported graph that prepares decoder inputs on later steps.

    ``ORTModelForLLava`` loads it from ``decoder_input_processor_model.onnx`` and
    uses it once a KV cache exists, i.e. for every step after the first.
    """

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.LongTensor,
        past_key_value: torch.FloatTensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the decoder-input-processor graph on one step.

        Args:
            input_ids: Token ids not yet seen by the KV cache.
            attention_mask: Attention mask maintained by the generation loop.
            past_key_value: Fed to the graph input ``past_key_values.0.key``.
                ``ORTModelForLLava`` passes ``past_key_values[0][0][:, :, :, 0]``;
                presumably the graph only needs it to infer the cached sequence
                length — TODO confirm against the export.
            **kwargs: Ignored.

        Returns:
            ``(inputs_embeds, decoder_attention_mask, position_ids)`` as torch
            tensors on ``self._device``, selected from the graph outputs by name.
        """
        onnx_inputs = {
            "input_ids": input_ids.cpu().detach().numpy(),
            "attention_mask": attention_mask.cpu().detach().numpy(),
            "past_key_values.0.key": past_key_value.cpu().detach().numpy(),
        }
        # Run inference; outputs come back as numpy arrays in graph output order.
        outputs = self.session.run(None, onnx_inputs)
        outputs = [torch.from_numpy(output).to(self._device) for output in outputs]
        return (
            outputs[self.output_names["inputs_embeds"]],
            outputs[self.output_names["decoder_attention_mask"]],
            outputs[self.output_names["position_ids"]],
        )
class ORTDecoder(ORTModel):
    """Language-model decoder graph with an explicit, externally held KV cache.

    Cache tensors are exchanged with the ONNX graph as flat named inputs/outputs
    (names containing ``.key`` or ``.value``) and regrouped into per-layer
    tuples of ``num_pkv`` tensors for the Hugging Face generation loop.
    """

    def __init__(self, path, config):
        super().__init__(path, config)
        self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.text_config.model_type)(
            config.text_config
        )
        self.generation_config = GenerationConfig.from_model_config(config)
        # Graph inputs/outputs carrying the KV cache, in graph declaration order.
        self.key_value_input_names = [key for key in self.input_names if (".key" in key) or (".value" in key)]
        self.key_value_output_names = [key for key in self.output_names if (".key" in key) or (".value" in key)]
        # Cache tensors per layer: one key and one value.
        self.num_pkv = 2

    def prepare_pkv(self, batch_size: int) -> Tuple[torch.FloatTensor, ...]:
        """Build an empty KV cache for the first decoding step.

        Returns one float32 tensor of shape
        ``(batch_size, num_heads, 0, head_dim)`` per KV-cache graph input, as a
        flat tuple aligned with ``key_value_input_names``.
        """
        if self.config.text_config.model_type in {"mistral", "llama"}:
            # These configs expose a separate key/value head count (grouped-query
            # attention), which sizes the cache.
            num_attention_heads = self.normalized_config.num_key_value_heads
        else:
            num_attention_heads = self.normalized_config.num_attention_heads
        # The head dimension is always derived from the query head count.
        embed_size_per_head = self.normalized_config.hidden_size // self.normalized_config.num_attention_heads
        shape = (batch_size, num_attention_heads, 0, embed_size_per_head)
        key_or_value = torch.zeros(shape, dtype=torch.float32)
        # Sharing one tensor is safe: it has no elements and is only read (converted to numpy).
        return tuple(key_or_value for _ in range(len(self.key_value_input_names)))

    def forward(
        self,
        attention_mask: torch.LongTensor,
        position_ids: torch.LongTensor,
        inputs_embeds: torch.FloatTensor,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    ) -> CausalLMOutputWithPast:
        """Run one decoder step.

        Args:
            attention_mask: Attention mask for the full (cached + new) sequence.
            position_ids: Position ids for the new positions.
            inputs_embeds: Input embeddings for the new positions.
            past_key_values: Per-layer cache tuples returned by a previous call,
                or ``None`` on the first step (an empty cache is created).

        Returns:
            ``CausalLMOutputWithPast`` with ``logits`` and the updated
            ``past_key_values`` (tuples of ``num_pkv`` tensors), on ``self._device``.

        Raises:
            ValueError: If the number of cache tensors does not match the
                graph's KV-cache inputs.
        """
        onnx_inputs = {
            "attention_mask": attention_mask.cpu().detach().numpy(),
            "position_ids": position_ids.cpu().detach().numpy(),
            "inputs_embeds": inputs_embeds.cpu().detach().numpy(),
        }
        if past_key_values is None:
            past_key_values = self.prepare_pkv(inputs_embeds.shape[0])
        else:
            # Flatten per-layer tuples into one sequence; this presumably matches the
            # graph's input order (layer-major, key before value) — TODO confirm.
            past_key_values = tuple(
                past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer
            )
        # zip() would silently drop extras; report a mismatch explicitly instead.
        if len(past_key_values) != len(self.key_value_input_names):
            raise ValueError(
                f"Expected {len(self.key_value_input_names)} past key/value tensors, "
                f"got {len(past_key_values)}."
            )
        for input_name, past_key_value in zip(self.key_value_input_names, past_key_values):
            onnx_inputs[input_name] = past_key_value.cpu().detach().numpy()

        # Run inference
        outputs = self.session.run(None, onnx_inputs)
        # Move results to the target device, consistent with the encoder wrappers.
        logits = torch.from_numpy(outputs[self.output_names["logits"]]).to(self._device)
        present = tuple(
            torch.from_numpy(outputs[self.output_names[key]]).to(self._device) for key in self.key_value_output_names
        )
        # Regroup the flat outputs into per-layer tuples of num_pkv tensors.
        past_key_values = tuple(present[i : i + self.num_pkv] for i in range(0, len(present), self.num_pkv))
        return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
class ORTModelForLLava(PreTrainedModel, GenerationMixin):
    """LLaVA generation model assembled from three exported ONNX graphs.

    * ``vision_tower`` (``encoder_model.onnx``): builds decoder inputs from the
      prompt and image on the first step.
    * ``decoder_input_processor`` (``decoder_input_processor_model.onnx``):
      builds decoder inputs for the newest token on later steps.
    * ``language_model`` (``decoder_model.onnx``): produces logits and the
      updated KV cache.

    Everything is loaded from the module-level ``model_name`` directory;
    constructor arguments are accepted but ignored.
    """

    def __init__(self, *args, **kwargs):
        config = AutoConfig.from_pretrained(model_name)
        super().__init__(config)
        self.config = config
        self._device = device
        # os.path.join keeps the paths valid whether or not model_name ends with a separator.
        self.vision_tower = ORTEncoder(os.path.join(model_name, "encoder_model.onnx"), config)
        self.language_model = ORTDecoder(os.path.join(model_name, "decoder_model.onnx"), config)
        self.decoder_input_processor = ORTDecoderProcessor(
            os.path.join(model_name, "decoder_input_processor_model.onnx"), config
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor, ...], ...]] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Run one generation step.

        On the first step (``past_key_values is None``), the encoder graph consumes
        the prompt tokens and image. On later steps, the decoder-input-processor
        graph consumes the new tokens, the running attention mask and a slice of
        the layer-0 key cache. Either way, the resulting embeddings, mask and
        position ids are fed to the decoder graph.

        Returns:
            ``CausalLMOutputWithPast`` with ``logits`` and updated ``past_key_values``.
        """
        if past_key_values is None:
            inputs_embeds, attention_mask, position_ids = self.vision_tower(
                input_ids=input_ids,
                pixel_values=pixel_values,
                attention_mask=attention_mask,
            )
        else:
            inputs_embeds, attention_mask, position_ids = self.decoder_input_processor(
                input_ids=input_ids,
                attention_mask=attention_mask,
                past_key_value=past_key_values[0][0][:, :, :, 0],
            )
        # Decode
        return self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
        )

    @classmethod
    def can_generate(cls) -> bool:
        """Report that this model supports ``generate()``.

        A classmethod, matching ``PreTrainedModel.can_generate``; instance calls still work.
        """
        return True

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None, **kwargs
    ):
        """Select the model inputs for the next generation step.

        Once a cache exists, ``input_ids`` is trimmed to the tokens the cache has
        not yet seen. ``inputs_embeds`` is accepted but ignored.
        """
        if past_key_values is not None:
            # Cached sequence length (dim 2 of a cache tensor).
            past_length = past_key_values[0][0].shape[2]
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            elif self.config.image_token_index in input_ids:
                # The cache is at least as long as input_ids, presumably because the
                # image token was expanded into image-feature positions; feed only
                # the newest token.
                input_ids = input_ids[:, input_ids.shape[1] - 1 :]
        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
            "pixel_values": pixel_values,
        }

    @property
    def device(self) -> torch.device:
        """Device that the ONNX wrappers' output tensors are reported on."""
        return self._device

    @device.setter
    def device(self, value: torch.device):
        self._device = value

    def to(self, device):
        """Record ``device`` as the model device; no data is moved. Returns ``self``."""
        self.device = device
        return self
# Build the ONNX-backed model and generate a caption for the sample image.
model = ORTModelForLLava()
# max_length bounds prompt + generated tokens together.
generated_ids = model.generate(**inputs, max_length=30)
# Decode the batch; with one prompt, the first entry is the only sequence.
decoded = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
out = decoded[0]
print(out)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment