TensorRT and DirectML with regular diffusers pipelines
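# Overview (summarizing the code below): export the Stable Diffusion 1.5 UNet to ONNX,
# build a TensorRT engine from it, then monkey-patch `pipe.unet.forward` so a regular
# diffusers StableDiffusionPipeline runs its UNet through TensorRT (or through ONNX
# Runtime's DirectML provider), while the text encoder, VAE and scheduler stay in PyTorch.
# Assumed environment (not stated in the gist): a CUDA GPU plus a recent TensorRT 8.x
# (the code uses set_tensor_address / execute_async_v3), polygraphy, and the
# onnxruntime-directml package for the DirectML path. The "F:/..." paths are only examples.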
import torch
import tensorrt as trt
from polygraphy import cuda
import sys
from packaging import version
from diffusers import StableDiffusionPipeline
from diffusers.pipelines.onnx_utils import OnnxRuntimeModel, ORT_TO_NP_TYPE
from dataclasses import dataclass
import numpy as np
import onnxruntime as ort
import onnx
import shutil
import os
device = "cuda:0"
dtype = torch.float32
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", local_files_only=True)
pipe.requires_safety_checker = False
pipe.safety_checker = None
pipe = pipe.to(device, torch_dtype=dtype)
TRT_LOGGER = trt.Logger(trt.Logger.INFO)
is_torch_less_than_1_11 = version.parse(version.parse(torch.__version__).base_version) < version.parse("1.11")

def init():
    convert_pipeline_unet_to_onnx(pipe, "F:/sd_1_5_unet.onnx")
    convert_unet_onnx_to_trt(onnx_path="F:/sd_1_5_unet.onnx", save_path="F:/sd_1_5_unet.trt")

    infer_trt(unet_path="F:/sd_1_5_unet.trt", width=512, height=512)
    # infer_dml(unet_path="F:/sd_1_5_unet.onnx")

## INFERENCE
@dataclass
class UnetResult:
    sample: torch.FloatTensor = None


class UnetTRT:
    def __init__(self, engine_path):
        trt.init_libnvinfer_plugins(None, "")
        with open(engine_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
            self.engine = runtime.deserialize_cuda_engine(f.read())

        self.trt_context = self.engine.create_execution_context()
        self.tensors = {}

    def allocate_buffers(self, width=512, height=512):
        "Call this once before an image is generated, not per sample"
        self.tensors.clear()

        # batch of 2 = unconditional + conditional latents for classifier-free guidance;
        # latents are NCHW, so height comes before width
        shape_dict = {
            "sample": (2, 4, height // 8, width // 8),
            "encoder_hidden_states": (2, 77, 768),
            "timestep": (2,),
        }
        for i, binding in enumerate(self.engine):
            if binding in shape_dict:
                shape = shape_dict[binding]
            else:
                shape = self.engine.get_binding_shape(binding)
                if binding == "out_sample":
                    shape = (2, 4, height // 8, width // 8)

            if self.engine.binding_is_input(binding):
                self.trt_context.set_binding_shape(i, shape)

            # persistent device buffers that the TRT execution context reads from / writes to
            self.tensors[binding] = torch.empty(tuple(shape), dtype=dtype, device=device)
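
    # Usage note: if the target resolution changes between generations, call
    # allocate_buffers() again with the new width/height before invoking the pipeline,
    # so the cached buffers match the new latent shape.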

    def forward(self, sample, timestep, encoder_hidden_states, **kwargs):
        feed_dict = {
            "sample": sample,
            "timestep": timestep,
            "encoder_hidden_states": encoder_hidden_states,
        }

        stream = cuda.Stream()

        # copy the inputs into the preallocated device buffers
        for name, tensor in feed_dict.items():
            self.tensors[name].copy_(tensor)

        # tell TensorRT where each input/output buffer lives
        for name, tensor in self.tensors.items():
            self.trt_context.set_tensor_address(name, tensor.data_ptr())

        if not self.trt_context.execute_async_v3(stream_handle=stream.ptr):
            raise RuntimeError("Inference failed!")
        stream.synchronize()  # make sure the output buffer is ready before PyTorch reads it

        return UnetResult(sample=self.tensors["out_sample"])

def infer_trt(unet_path: str, width=512, height=512):
    unet_trt = UnetTRT(unet_path)

    # route the pipeline's UNet calls through the TensorRT engine
    pipe.unet.forward = unet_trt.forward
    unet_trt.allocate_buffers(width=width, height=height)

    images = pipe(
        "photograph of an astronaut standing on mars", num_inference_steps=50, width=width, height=height
    ).images
    images[0].save("astro_trt.jpg")

def infer_dml(unet_path: str):
    # batch_size = 1

    # these are supposed to make things faster, haven't tested if they do
    sess_options = ort.SessionOptions()
    sess_options.enable_mem_pattern = False
    # sess_options.add_free_dimension_override_by_name("unet_sample_batch", batch_size * 2)
    # sess_options.add_free_dimension_override_by_name("unet_sample_channels", 4)
    # sess_options.add_free_dimension_override_by_name("unet_sample_height", 64)
    # sess_options.add_free_dimension_override_by_name("unet_sample_width", 64)
    # sess_options.add_free_dimension_override_by_name("unet_time_batch", 1)
    # sess_options.add_free_dimension_override_by_name("unet_hidden_batch", batch_size * 2)
    # sess_options.add_free_dimension_override_by_name("unet_hidden_sequence", 77)

    unet_dml = OnnxRuntimeModel(
        model=OnnxRuntimeModel.load_model(unet_path, "DmlExecutionProvider", sess_options=sess_options)
    )

    def forward(sample, timestep, encoder_hidden_states, **kwargs):
        timestep_dtype = next(
            (input.type for input in unet_dml.model.get_inputs() if input.name == "timestep"), "tensor(float)"
        )
        timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]

        input = {
            "sample": sample.cpu().numpy(),
            "timestep": np.array([timestep.cpu()], dtype=timestep_dtype),
            "encoder_hidden_states": encoder_hidden_states.cpu().numpy(),
        }
        sample = unet_dml(**input)[0]
        sample = torch.from_numpy(sample).to(device)

        return UnetResult(sample)

    pipe.unet.forward = forward

    images = pipe("photograph of an astronaut standing on mars", num_inference_steps=50).images
    images[0].save("astro_dml.jpg")

## CONVERT

def convert_pipeline_unet_to_onnx(pipeline, save_path, opset=17, fp16: bool = False):
    unet_in_channels = pipeline.unet.config.in_channels
    unet_sample_size = pipeline.unet.config.sample_size
    num_tokens = pipeline.text_encoder.config.max_position_embeddings
    text_hidden_size = pipeline.text_encoder.config.hidden_size

    _dtype = torch.float16 if fp16 else torch.float32
    _device = device if fp16 else "cpu"
    pipeline = pipeline.to(_device, torch_dtype=_dtype)

    tmp_dir = save_path + "_"  # collect the individual weights here
    if os.path.exists(tmp_dir):
        shutil.rmtree(tmp_dir)
    os.mkdir(tmp_dir)
    tmp_model_path = os.path.join(tmp_dir, "model.onnx")

    onnx_export(
        pipeline.unet,
        model_args=(
            torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(device=_device, dtype=_dtype),
            torch.randn(2).to(device=_device, dtype=_dtype),
            torch.randn(2, num_tokens, text_hidden_size).to(device=_device, dtype=_dtype),
            False,
        ),
        output_path=tmp_model_path,
        ordered_input_names=["sample", "timestep", "encoder_hidden_states", "return_dict"],
        output_names=["out_sample"],  # has to be different from "sample" for correct tracing
        dynamic_axes={
            "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
            "timestep": {0: "batch"},
            "encoder_hidden_states": {0: "batch", 1: "sequence"},
        },
        opset=opset,
        use_external_data_format=True,  # UNet is > 2GB, so the weights need to be split
    )

    unet = onnx.load(tmp_model_path)
    shutil.rmtree(tmp_dir)

    # collate external tensor files into one
    onnx.save_model(
        unet,
        save_path,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location="weights.pb",
        convert_attribute=False,
    )

    pipeline = pipeline.to(device, torch_dtype=dtype)
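
# Optional sanity check (not part of the original flow): for models saved with external
# data, onnx.checker accepts a file path, which is required for models larger than 2GB, e.g.
#   onnx.checker.check_model("F:/sd_1_5_unet.onnx")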

def convert_unet_onnx_to_trt(
    onnx_path,
    save_path,
    batch_size=1,
    in_channels=4,
    min_size=(512, 512),
    max_size=(1024, 1024),
    min_seq_length=77,
    max_seq_length=77 * 3,
    text_hidden_size=768,
):
    TRT_BUILDER = trt.Builder(TRT_LOGGER)

    network = TRT_BUILDER.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    onnx_parser = trt.OnnxParser(network, TRT_LOGGER)
    parse_success = onnx_parser.parse_from_file(onnx_path)
    for idx in range(onnx_parser.num_errors):
        print(onnx_parser.get_error(idx))
    if not parse_success:
        sys.exit("ONNX model parsing failed")

    config = TRT_BUILDER.create_builder_config()
    profile = TRT_BUILDER.create_optimization_profile()

    # batch_size * 2: unconditional + conditional inputs for classifier-free guidance
    min_shape = {
        "sample": (batch_size * 2, in_channels, min_size[1] // 8, min_size[0] // 8),
        "encoder_hidden_states": (batch_size * 2, min_seq_length, text_hidden_size),
        "timestep": (batch_size * 2,),
    }
    max_shape = {
        "sample": (batch_size * 2, in_channels, max_size[1] // 8, max_size[0] // 8),
        "encoder_hidden_states": (batch_size * 2, max_seq_length, text_hidden_size),
        "timestep": (batch_size * 2,),
    }
    for name in min_shape.keys():
        # the optimal shape is set to the minimum shape here
        profile.set_shape(name, min_shape[name], min_shape[name], max_shape[name])
    config.add_optimization_profile(profile)

    # config.max_workspace_size = 4096 * (1 << 20)
    # (max_workspace_size is deprecated in newer TensorRT releases; the equivalent is
    # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, ...))
    config.set_flag(trt.BuilderFlag.FP16)

    serialized_engine = TRT_BUILDER.build_serialized_network(network, config)

    ## save TRT engine
    with open(save_path, "wb") as f:
        f.write(serialized_engine)
    print(f"Engine is saved to {save_path}")

def onnx_export(
    model,
    model_args: tuple,
    output_path,
    ordered_input_names,
    output_names,
    dynamic_axes,
    opset,
    use_external_data_format=False,
):
    # PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
    # so we check the torch version for backwards compatibility
    if is_torch_less_than_1_11:
        torch.onnx.export(
            model,
            model_args,
            f=output_path,
            input_names=ordered_input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            do_constant_folding=True,
            use_external_data_format=use_external_data_format,
            enable_onnx_checker=True,
            opset_version=opset,
        )
    else:
        torch.onnx.export(
            model,
            model_args,
            f=output_path,
            input_names=ordered_input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            do_constant_folding=True,
            opset_version=opset,
        )

init()
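
# If the ONNX / TRT files already exist on disk, the export and build steps inside init()
# can be skipped and inference run directly, e.g. infer_trt(unet_path="F:/sd_1_5_unet.trt").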