CodeZombie/e085b4b621d609447d6e13ea97a8c6a7
Last active: March 27, 2023 03:26
A drop-in replacement for Automatic1111's `safe.py`, which, as it turns out, is not safe at all :)
```python
# An actually safe model loader.
# I consider this file to be safe because it physically cannot load pickle files.
# If you try to load a pickle file, it will instead look for a safetensors file with the same name and try to load that.
# If it finds a suitable safetensors alternative, it loads it in a way that stays compatible with pickle-oriented loading code (params_ema/params checking).
# This code was written by Jeremy C (badnoise.net)
import os

import safetensors.torch
import torch


class StateDictCompatibleDictionary(dict):
    """
    Looks and behaves exactly like a normal dictionary, except that when you check for or get a key called 'params_ema', it returns a reference to itself.
    This makes it compatible with a lot of hacky pickle-dependent code that manually checks the dict for 'params_ema' and loads that value as the state_dict.
    We need this because when loading a state dict in strict mode, dict.keys() must match `torch.nn.Module.state_dict()`, which does not include `params_ema`.
    So this one object is compatible with all loading logic, whether it checks for params_ema or not.
    """
    hidden_key = "params_ema"

    def __contains__(self, key: object) -> bool:
        if key == self.hidden_key:
            return True
        return super().__contains__(key)

    def __getitem__(self, key: object) -> object:
        if key == self.hidden_key:
            return self
        return super().__getitem__(key)

    def get(self, key, default=None):
        # dict.get() does not go through __getitem__, so intercept the hidden key here too.
        if key == self.hidden_key:
            return self
        return super().get(key, default)


def load(filename, *args, **kwargs):
    # Accept str or os.PathLike; file-like objects are not supported.
    filename = os.fspath(filename)
    safetensors_variant_filename = filename
    if not filename.endswith(".safetensors"):
        filepath_without_extension = os.path.splitext(filename)[0]
        safetensors_variant_filename = "{}.safetensors".format(filepath_without_extension)
        print("Refusing to load non-safetensors file \"{}\". Attempting to load safetensors variant: \"{}\"".format(filename, safetensors_variant_filename))
    # Check that the safetensors variant (not the original pickle file) exists.
    if not os.path.exists(safetensors_variant_filename):
        raise FileNotFoundError("Safetensors variant not found. Please ensure one exists: \"{}\"".format(safetensors_variant_filename))
    # torch.load accepts map_location positionally or as a keyword; safetensors calls it 'device'.
    device = kwargs.pop("map_location", args[0] if args else None)
    # safetensors.torch.load_file wants 'device' as a string, but torch.load callers may pass a torch.device.
    # str() keeps the device index (e.g. "cuda:1"); anything else (None, dict, callable) falls back to the CPU.
    if isinstance(device, torch.device):
        device = str(device)
    elif not isinstance(device, str):
        device = "cpu"
    # Other torch.load arguments (pickle_module, extra_handler, ...) do not apply to safetensors and are ignored.
    return StateDictCompatibleDictionary(safetensors.torch.load_file(safetensors_variant_filename, device=device))


# Override torch.load to use our method instead :)
torch.load = load
```
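The `params_ema` trick is easiest to see in isolation. The sketch below uses a trimmed copy of `StateDictCompatibleDictionary` (no torch needed) and a hypothetical `select_state_dict` function standing in for the kind of legacy loader code the docstring describes; it is an illustration, not code from the web UI.

```python
# Trimmed copy of StateDictCompatibleDictionary from the gist above.
class StateDictCompatibleDictionary(dict):
    hidden_key = "params_ema"

    def __contains__(self, key: object) -> bool:
        if key == self.hidden_key:
            return True
        return super().__contains__(key)

    def __getitem__(self, key: object) -> object:
        if key == self.hidden_key:
            return self
        return super().__getitem__(key)


def select_state_dict(checkpoint):
    """Hypothetical stand-in for pickle-era loader code that prefers 'params_ema'."""
    if "params_ema" in checkpoint:
        return checkpoint["params_ema"]
    return checkpoint


weights = StateDictCompatibleDictionary({"conv.weight": [1.0, 2.0], "conv.bias": [0.0]})

# The legacy branch is taken, and it hands back the very same dict.
chosen = select_state_dict(weights)
print(chosen is weights)      # True
# The phantom key never shows up in keys(), so a strict key comparison still passes.
print(sorted(chosen.keys()))  # ['conv.bias', 'conv.weight']
print(chosen["conv.bias"])    # [0.0]
```

Because `params_ema` is never actually stored, `keys()` still matches what a strict `load_state_dict` call expects, while code that looks for `params_ema` gets the same dictionary back.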
ControlNet models:
https://huggingface.co/webui/ControlNet-modules-safetensors/tree/main
Pickle to Safetensors Google Colab
https://colab.research.google.com/drive/1ehsididHPX3kFQTcH93okh9_4W46E6HS?usp=sharing
The motivation for this is that pickle files (.pt, .ckpt, .pth, etc.) can execute arbitrary code when they are loaded. Some work has been done to mitigate this by only allowing a restricted set of classes and functions to be invoked from within the loaded file, but this is still unacceptable.
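To make the risk concrete, here is a small stdlib-only sketch (not taken from the web UI; `Payload`, `RestrictedUnpickler` and `restricted_loads` are hypothetical names). The first half shows that unpickling can call any importable function, here `eval`. The second half approximates the allowlist style of mitigation by overriding `pickle.Unpickler.find_class` so that only explicitly permitted globals resolve; it is not the actual `safe.py` code.

```python
import io
import pickle
from collections import OrderedDict


class Payload:
    """Hypothetical object whose pickle tells the loader to call eval()."""

    def __reduce__(self):
        # On unpickling, pickle will call eval("6 * 7"); any importable callable works here.
        return (eval, ("6 * 7",))


malicious_bytes = pickle.dumps(Payload())
print(pickle.loads(malicious_bytes))  # 42: eval() ran during loading


class RestrictedUnpickler(pickle.Unpickler):
    """Allowlist-style mitigation: only the listed globals may be resolved."""

    ALLOWED = {("collections", "OrderedDict")}

    def find_class(self, module, name):
        if (module, name) in self.ALLOWED:
            return super().find_class(module, name)
        raise pickle.UnpicklingError(f"global '{module}.{name}' is forbidden")


def restricted_loads(data: bytes):
    return RestrictedUnpickler(io.BytesIO(data)).load()


# Allowed: an OrderedDict only references collections.OrderedDict.
print(dict(restricted_loads(pickle.dumps(OrderedDict(a=1)))))  # {'a': 1}
# Blocked: the payload needs builtins.eval, which is not on the list.
try:
    restricted_loads(malicious_bytes)
except pickle.UnpicklingError as exc:
    print(exc)  # global 'builtins.eval' is forbidden
```

The allowlist blocks `eval` here, but its safety rests on every allowed global being harmless in every combination, which is exactly the assumption this gist declines to make.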
This replacement for `safe.py` (`stable-diffusion-webui\modules\safe.py`) fully removes the web UI's ability to load pickle files through the standard `torch.load` method. Instead, it attempts to find a safer safetensors variant of the same file. This requires you to obtain safetensors versions of your pickle files manually and place them in the same folder as the originals, with the same base name (for example, `model.ckpt` needs a `model.safetensors` next to it).
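Before switching loaders, it may help to check which pickle files still lack a safetensors sibling. The sketch below mirrors the naming rule in `load()` (same folder, same base name, extension swapped for `.safetensors`) using `os.path.splitext`. The helper names `safetensors_variant_path` and `missing_variants` are hypothetical, and the `PICKLE_EXTENSIONS` tuple is an assumption you may need to extend.

```python
import os
import tempfile

# Assumption: the pickle extensions worth checking; extend as needed.
PICKLE_EXTENSIONS = (".ckpt", ".pt", ".pth")


def safetensors_variant_path(path: str) -> str:
    """Map e.g. model.ckpt -> model.safetensors in the same folder."""
    root, ext = os.path.splitext(path)
    return path if ext == ".safetensors" else root + ".safetensors"


def missing_variants(folder: str) -> list:
    """List pickle files in `folder` that have no safetensors sibling yet."""
    missing = []
    for name in sorted(os.listdir(folder)):
        path = os.path.join(folder, name)
        if name.lower().endswith(PICKLE_EXTENSIONS) and not os.path.exists(safetensors_variant_path(path)):
            missing.append(name)
    return missing


# Demo in a throwaway folder with empty placeholder files.
with tempfile.TemporaryDirectory() as folder:
    for name in ("converted.ckpt", "converted.safetensors", "legacy.pt"):
        open(os.path.join(folder, name), "wb").close()
    print(missing_variants(folder))  # ['legacy.pt']
```

Any file this reports would make the replacement `load()` raise its "Safetensors variant not found" error until a converted copy is placed next to it.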
I will keep this comments section updated with safetensors variants of popular pickle models as I find them.