Memory-efficient fp8 conversion script.
import json
import struct
import sys
from pathlib import Path
from typing import Any, Dict

import torch
from tqdm import tqdm

# input and output paths from the command line
if len(sys.argv) < 3:
    print("Usage: mem_eff_fp8_convert.py {fp16 model path} {output path}")
    sys.exit(1)
path = sys.argv[1]
output = sys.argv[2]
model_file = Path(path)
class MemoryEfficientSafeOpen:
    # does not support metadata loading
    def __init__(self, filename):
        self.filename = filename
        self.header, self.header_size = self._read_header()
        self.file = open(filename, "rb")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.file.close()

    def keys(self):
        return [k for k in self.header.keys() if k != "__metadata__"]

    def get_tensor(self, key):
        if key not in self.header:
            raise KeyError(f"Tensor '{key}' not found in the file")

        metadata = self.header[key]
        offset_start, offset_end = metadata["data_offsets"]

        if offset_start == offset_end:
            tensor_bytes = None
        else:
            # adjust offset by header size
            self.file.seek(self.header_size + 8 + offset_start)
            tensor_bytes = self.file.read(offset_end - offset_start)

        return self._deserialize_tensor(tensor_bytes, metadata)

    def _read_header(self):
        with open(self.filename, "rb") as f:
            header_size = struct.unpack("<Q", f.read(8))[0]
            header_json = f.read(header_size).decode("utf-8")
            return json.loads(header_json), header_size

    def _deserialize_tensor(self, tensor_bytes, metadata):
        dtype = self._get_torch_dtype(metadata["dtype"])
        shape = metadata["shape"]

        if tensor_bytes is None:
            byte_tensor = torch.empty(0, dtype=torch.uint8)
        else:
            tensor_bytes = bytearray(tensor_bytes)  # make it writable
            byte_tensor = torch.frombuffer(tensor_bytes, dtype=torch.uint8)

        # process float8 types
        if metadata["dtype"] in ["F8_E5M2", "F8_E4M3"]:
            return self._convert_float8(byte_tensor, metadata["dtype"], shape)

        # convert to the target dtype and reshape
        return byte_tensor.view(dtype).reshape(shape)

    @staticmethod
    def _get_torch_dtype(dtype_str):
        dtype_map = {
            "F64": torch.float64,
            "F32": torch.float32,
            "F16": torch.float16,
            "BF16": torch.bfloat16,
            "I64": torch.int64,
            "I32": torch.int32,
            "I16": torch.int16,
            "I8": torch.int8,
            "U8": torch.uint8,
            "BOOL": torch.bool,
        }
        # add float8 types if available
        if hasattr(torch, "float8_e5m2"):
            dtype_map["F8_E5M2"] = torch.float8_e5m2
        if hasattr(torch, "float8_e4m3fn"):
            dtype_map["F8_E4M3"] = torch.float8_e4m3fn
        return dtype_map.get(dtype_str)

    @staticmethod
    def _convert_float8(byte_tensor, dtype_str, shape):
        if dtype_str == "F8_E5M2" and hasattr(torch, "float8_e5m2"):
            return byte_tensor.view(torch.float8_e5m2).reshape(shape)
        elif dtype_str == "F8_E4M3" and hasattr(torch, "float8_e4m3fn"):
            return byte_tensor.view(torch.float8_e4m3fn).reshape(shape)
        else:
            # # convert to float16 if float8 is not supported
            # print(f"Warning: {dtype_str} is not supported in this PyTorch version. Converting to float16.")
            # return byte_tensor.view(torch.uint8).to(torch.float16).reshape(shape)
            raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")
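
# For reference, the safetensors layout this reader walks: the file starts with an
# 8-byte little-endian unsigned integer giving the size of a JSON header, followed
# by the header itself, then the raw tensor data. Each header entry records the
# tensor's dtype, shape, and [start, end) byte offsets relative to the end of the
# header, which is why get_tensor() seeks to header_size + 8 + offset_start.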
def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None):
    _TYPES = {
        torch.float64: "F64",
        torch.float32: "F32",
        torch.float16: "F16",
        torch.bfloat16: "BF16",
        torch.int64: "I64",
        torch.int32: "I32",
        torch.int16: "I16",
        torch.int8: "I8",
        torch.uint8: "U8",
        torch.bool: "BOOL",
        getattr(torch, "float8_e5m2", None): "F8_E5M2",
        getattr(torch, "float8_e4m3fn", None): "F8_E4M3",
    }
    _ALIGN = 256

    def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]:
        validated = {}
        for key, value in metadata.items():
            if not isinstance(key, str):
                raise ValueError(f"Metadata key must be a string, got {type(key)}")
            if not isinstance(value, str):
                print(f"Warning: Metadata value for key '{key}' is not a string. Converting to string.")
                validated[key] = str(value)
            else:
                validated[key] = value
        return validated

    header = {}
    offset = 0
    if metadata:
        header["__metadata__"] = validate_metadata(metadata)
    for k, v in tensors.items():
        if v.numel() == 0:  # empty tensor
            header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset]}
        else:
            size = v.numel() * v.element_size()
            header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size]}
            offset += size

    hjson = json.dumps(header).encode("utf-8")
    # pad the header with spaces so the data section starts on an _ALIGN-byte boundary
    hjson += b" " * (-(len(hjson) + 8) % _ALIGN)

    with open(filename, "wb") as f:
        f.write(struct.pack("<Q", len(hjson)))
        f.write(hjson)

        for k, v in tensors.items():
            if v.numel() == 0:
                continue
            if v.is_cuda:
                # direct GPU-to-disk save
                with torch.cuda.device(v.device):
                    if v.dim() == 0:  # if scalar, need to add a dimension to work with view
                        v = v.unsqueeze(0)
                    tensor_bytes = v.contiguous().view(torch.uint8)
                    tensor_bytes.cpu().numpy().tofile(f)
            else:
                # CPU tensor save
                if v.dim() == 0:  # if scalar, need to add a dimension to work with view
                    v = v.unsqueeze(0)
                v.contiguous().view(torch.uint8).numpy().tofile(f)
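
# Note: mem_eff_save_file writes the header first and then streams each tensor's
# raw bytes to disk one at a time (copying a CUDA tensor to the CPU only at that
# moment), so no second full copy of the state dict is materialized in memory.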
# read safetensors metadata
def read_safetensors_metadata(path: str):
    with open(path, "rb") as f:
        header_size = int.from_bytes(f.read(8), "little")
        header_json = f.read(header_size).decode("utf-8")
        header = json.loads(header_json)
        metadata = header.get("__metadata__", {})
    return metadata

metadata = read_safetensors_metadata(path)
print(json.dumps(metadata, indent=4))  # show metadata

sd_pruned = dict()  # initialize empty dict
with MemoryEfficientSafeOpen(path) as reader:
    keys = reader.keys()
    for key in tqdm(keys):  # for each key in the safetensors file
        sd_pruned[key] = reader.get_tensor(key).to(torch.float8_e4m3fn)  # convert to fp8

# save the pruned safetensors file
mem_eff_save_file(sd_pruned, output, metadata={"format": "pt", **metadata})
Update: I incorporated kohya-ss's mem_eff_safeopen.py for an even more memory-efficient conversion. For example, Animagine XL 3.1 can be converted to fp8 with 3.4 GB of memory usage.
License: kohya-ss's mem_eff_safeopen.py and mem_eff_save_file.py are provided under the Apache 2.0 license.
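
For example, assuming the script is saved as mem_eff_fp8_convert.py (the name used in its usage string) and the file names below are placeholders for your own paths, a conversion run looks like:

python mem_eff_fp8_convert.py animagine-xl-3.1.safetensors animagine-xl-3.1-fp8.safetensors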
Special Thanks:
Without their code, I would not have been able to create this script; a huge thanks goes out to them both. I would also like to thank haruharu-1105 for their help with development.