TensorRT and DirectML with regular diffusers pipelines
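This script exports the Stable Diffusion 1.5 UNet to ONNX, builds a TensorRT engine from it, and then swaps `pipe.unet.forward` for a TensorRT-backed (or ONNX Runtime DirectML-backed) implementation, so the unmodified diffusers `StableDiffusionPipeline` runs its denoising loop on the accelerated UNet. `init()` at the bottom runs the whole flow: convert to ONNX, build the engine, then generate a test image.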
import torch
import tensorrt as trt
from polygraphy import cuda
import sys
from packaging import version
from diffusers import StableDiffusionPipeline
from diffusers.pipelines.onnx_utils import OnnxRuntimeModel, ORT_TO_NP_TYPE
from dataclasses import dataclass
import numpy as np
import onnxruntime as ort
import onnx
import shutil
import os

device = "cuda:0"
dtype = torch.float32

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", local_files_only=True)
pipe.requires_safety_checker = False
pipe.safety_checker = None
pipe = pipe.to(device, torch_dtype=dtype)

TRT_LOGGER = trt.Logger(trt.Logger.INFO)

is_torch_less_than_1_11 = version.parse(version.parse(torch.__version__).base_version) < version.parse("1.11")


def init():
    convert_pipeline_unet_to_onnx(pipe, "F:/sd_1_5_unet.onnx")
    convert_unet_onnx_to_trt(onnx_path="F:/sd_1_5_unet.onnx", save_path="F:/sd_1_5_unet.trt")

    infer_trt(unet_path="F:/sd_1_5_unet.trt", width=512, height=512)
    # infer_dml(unet_path="F:/sd_1_5_unet.onnx")


## INFERENCE
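# Both inference paths below work the same way: pipe.unet.forward is replaced with a
# function that runs the exported UNet (a TensorRT engine, or an ONNX Runtime session
# on DirectML) and returns a UnetResult, whose `.sample` field is all the pipeline
# reads from the UNet output here. The rest of the diffusers pipeline (text encoder,
# scheduler, VAE) keeps running unchanged in PyTorch.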
@dataclass
class UnetResult:
    sample: torch.FloatTensor = None


class UnetTRT:
    def __init__(self, engine_path):
        trt.init_libnvinfer_plugins(None, "")
        with open(engine_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
            self.engine = runtime.deserialize_cuda_engine(f.read())

        self.trt_context = self.engine.create_execution_context()
        self.tensors = {}
    def allocate_buffers(self, width=512, height=512):
        "Call this once before an image is generated, not per sample"
        self.tensors.clear()

        # Latents are NCHW: batch of 2 (cond + uncond for classifier-free guidance),
        # 4 latent channels, and spatial dims of 1/8th the image size.
        shape_dict = {
            "sample": (2, 4, height // 8, width // 8),
            "encoder_hidden_states": (2, 77, 768),
            "timestep": (2,),
        }
        for i, binding in enumerate(self.engine):
            if binding in shape_dict:
                shape = shape_dict[binding]
            else:
                shape = self.engine.get_binding_shape(binding)
                if binding == "out_sample":
                    shape = (2, 4, height // 8, width // 8)

            if self.engine.binding_is_input(binding):
                self.trt_context.set_binding_shape(i, shape)

            self.tensors[binding] = torch.empty(tuple(shape), dtype=dtype, device=device)
    def forward(self, sample, timestep, encoder_hidden_states, **kwargs):
        # Copy the pipeline's inputs into the pre-allocated device buffers,
        # point the engine at those buffers, and run it.
        feed_dict = {
            "sample": sample,
            "timestep": timestep,
            "encoder_hidden_states": encoder_hidden_states,
        }
        stream = cuda.Stream()

        for name, tensor in feed_dict.items():
            self.tensors[name].copy_(tensor)

        for name, tensor in self.tensors.items():
            self.trt_context.set_tensor_address(name, tensor.data_ptr())

        if not self.trt_context.execute_async_v3(stream_handle=stream.ptr):
            raise RuntimeError("Inference failed!")

        return UnetResult(sample=self.tensors["out_sample"])
def infer_trt(unet_path: str, width=512, height=512):
    unet_trt = UnetTRT(unet_path)
    pipe.unet.forward = unet_trt.forward

    unet_trt.allocate_buffers(width=width, height=height)

    images = pipe(
        "photograph of an astronaut standing on mars", num_inference_steps=50, width=width, height=height
    ).images
    images[0].save("astro_trt.jpg")


def infer_dml(unet_path: str):
    # batch_size = 1

    # these are supposed to make things faster, haven't tested if they do
    sess_options = ort.SessionOptions()
    sess_options.enable_mem_pattern = False
    # sess_options.add_free_dimension_override_by_name("unet_sample_batch", batch_size * 2)
    # sess_options.add_free_dimension_override_by_name("unet_sample_channels", 4)
    # sess_options.add_free_dimension_override_by_name("unet_sample_height", 64)
    # sess_options.add_free_dimension_override_by_name("unet_sample_width", 64)
    # sess_options.add_free_dimension_override_by_name("unet_time_batch", 1)
    # sess_options.add_free_dimension_override_by_name("unet_hidden_batch", batch_size * 2)
    # sess_options.add_free_dimension_override_by_name("unet_hidden_sequence", 77)

    unet_dml = OnnxRuntimeModel(
        model=OnnxRuntimeModel.load_model(unet_path, "DmlExecutionProvider", sess_options=sess_options)
    )

    def forward(sample, timestep, encoder_hidden_states, **kwargs):
        timestep_dtype = next(
            (input.type for input in unet_dml.model.get_inputs() if input.name == "timestep"), "tensor(float)"
        )
        timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]

        input = {
            "sample": sample.cpu().numpy(),
            "timestep": np.array([timestep.cpu()], dtype=timestep_dtype),
            "encoder_hidden_states": encoder_hidden_states.cpu().numpy(),
        }
        sample = unet_dml(**input)[0]
        sample = torch.from_numpy(sample).to(device)
        return UnetResult(sample)

    pipe.unet.forward = forward

    images = pipe("photograph of an astronaut standing on mars", num_inference_steps=50).images
    images[0].save("astro_dml.jpg")


## CONVERT
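# Conversion happens in two steps. convert_pipeline_unet_to_onnx exports the UNet to
# ONNX with external weight files (the fp32 UNet exceeds the 2GB protobuf limit) and
# then collates them into a single weights.pb next to the model. convert_unet_onnx_to_trt
# parses that ONNX graph with TensorRT and builds an FP16 engine with a dynamic-shape
# optimization profile covering 512x512 to 1024x1024 images and 77 to 77*3 context tokens.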
def convert_pipeline_unet_to_onnx(pipeline, save_path, opset=17, fp16: bool = False):
    unet_in_channels = pipeline.unet.config.in_channels
    unet_sample_size = pipeline.unet.config.sample_size
    num_tokens = pipeline.text_encoder.config.max_position_embeddings
    text_hidden_size = pipeline.text_encoder.config.hidden_size

    _dtype = torch.float16 if fp16 else torch.float32
    _device = device if fp16 else "cpu"

    pipeline = pipeline.to(_device, torch_dtype=_dtype)

    tmp_dir = save_path + "_"  # collect the individual weights here
    if os.path.exists(tmp_dir):
        shutil.rmtree(tmp_dir)
    os.mkdir(tmp_dir)
    tmp_model_path = os.path.join(tmp_dir, "model.onnx")

    onnx_export(
        pipeline.unet,
        model_args=(
            torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(device=_device, dtype=_dtype),
            torch.randn(2).to(device=_device, dtype=_dtype),
            torch.randn(2, num_tokens, text_hidden_size).to(device=_device, dtype=_dtype),
            False,
        ),
        output_path=tmp_model_path,
        ordered_input_names=["sample", "timestep", "encoder_hidden_states", "return_dict"],
        output_names=["out_sample"],  # has to be different from "sample" for correct tracing
        dynamic_axes={
            "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
            "timestep": {0: "batch"},
            "encoder_hidden_states": {0: "batch", 1: "sequence"},
        },
        opset=opset,
        use_external_data_format=True,  # UNet is > 2GB, so the weights need to be split
    )

    unet = onnx.load(tmp_model_path)
    shutil.rmtree(tmp_dir)

    # collate external tensor files into one
    onnx.save_model(
        unet,
        save_path,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location="weights.pb",
        convert_attribute=False,
    )

    pipeline = pipeline.to(device, torch_dtype=dtype)
def convert_unet_onnx_to_trt(
    onnx_path,
    save_path,
    batch_size=1,
    in_channels=4,
    min_size=(512, 512),
    max_size=(1024, 1024),
    min_seq_length=77,
    max_seq_length=77 * 3,
    text_hidden_size=768,
):
    TRT_BUILDER = trt.Builder(TRT_LOGGER)
    network = TRT_BUILDER.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    onnx_parser = trt.OnnxParser(network, TRT_LOGGER)
    parse_success = onnx_parser.parse_from_file(onnx_path)
    for idx in range(onnx_parser.num_errors):
        print(onnx_parser.get_error(idx))
    if not parse_success:
        sys.exit("ONNX model parsing failed")

    config = TRT_BUILDER.create_builder_config()
    profile = TRT_BUILDER.create_optimization_profile()
    min_shape = {
        "sample": (batch_size * 2, in_channels, min_size[1] // 8, min_size[0] // 8),
        "encoder_hidden_states": (batch_size * 2, min_seq_length, text_hidden_size),
        "timestep": (batch_size * 2,),
    }
    max_shape = {
        "sample": (batch_size * 2, in_channels, max_size[1] // 8, max_size[0] // 8),
        "encoder_hidden_states": (batch_size * 2, max_seq_length, text_hidden_size),
        "timestep": (batch_size * 2,),
    }
    for name in min_shape.keys():
        profile.set_shape(name, min_shape[name], min_shape[name], max_shape[name])
    config.add_optimization_profile(profile)

    # config.max_workspace_size = 4096 * (1 << 20)
    config.set_flag(trt.BuilderFlag.FP16)

    serialized_engine = TRT_BUILDER.build_serialized_network(network, config)

    ## save TRT engine
    with open(save_path, "wb") as f:
        f.write(serialized_engine)
    print(f"Engine is saved to {save_path}")
def onnx_export(
    model,
    model_args: tuple,
    output_path,
    ordered_input_names,
    output_names,
    dynamic_axes,
    opset,
    use_external_data_format=False,
):
    # PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
    # so we check the torch version for backwards compatibility
    if is_torch_less_than_1_11:
        torch.onnx.export(
            model,
            model_args,
            f=output_path,
            input_names=ordered_input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            do_constant_folding=True,
            use_external_data_format=use_external_data_format,
            enable_onnx_checker=True,
            opset_version=opset,
        )
    else:
        torch.onnx.export(
            model,
            model_args,
            f=output_path,
            input_names=ordered_input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            do_constant_folding=True,
            opset_version=opset,
        )


init()
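Note on setup (not pinned in the gist, so treat this as a rough sketch of the assumed environment): the script expects a CUDA-capable GPU with `torch`, `diffusers`, `onnx`, `tensorrt`, `polygraphy`, and `onnxruntime` installed; the DirectML path needs an ONNX Runtime build that ships the DirectML execution provider (the `onnxruntime-directml` package). The Windows-style `F:/...` paths for the exported ONNX model and TRT engine are placeholders to adjust before calling `init()`.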