@CodeZombie
Last active March 27, 2023 03:26
A drop-in replacement for Automatic1111's `safe.py`, which, as it turns out, is not safe at all :)
```python
# An actually safe model loader.
# I consider this file to actually be safe because it physically cannot load pickle files.
# If you try to load a pickle file, it will instead look for a safetensors file with the same name and try to load that.
# If it does find a suitable safetensors alternative, it loads it in a way that stays compatible with code written for pickle checkpoints (params_ema/params checking).
# This code was written by Jeremy C (badnoise.net)
import os

import safetensors.torch
import torch


class StateDictCompatibleDictionary(dict):
    """
    Looks and behaves exactly like a normal dictionary, except that when you check for or get a key called 'params_ema', it returns a reference to itself.
    This makes the model compatible with a lot of hacky pickle-dependent code that manually checks the dict for 'params_ema' and loads that value as the state_dict.
    We need to do this because when loading a state dict with strict=True, dict.keys() must match `torch.nn.Module.state_dict()`, which does not include `params_ema`.
    So this one object is compatible with all loading logic, whether it checks for params_ema or not.
    """
    hidden_key = "params_ema"

    def __contains__(self, key: object) -> bool:
        if key == self.hidden_key:
            return True
        return super().__contains__(key)

    def __getitem__(self, key: object) -> object:
        if key == self.hidden_key:
            return self
        return super().__getitem__(key)

    def get(self, key, default=None):
        # dict.get() does not go through __getitem__, so it needs the same special case.
        if key == self.hidden_key:
            return self
        return super().get(key, default)


def load(filename, map_location=None, *args, **kwargs):
    # The remaining torch.load options (pickle_module, weights_only, ...) only apply to pickle
    # and are not accepted by safetensors.torch.load_file, so *args and **kwargs are ignored.
    filename = os.fspath(filename)
    safetensors_variant_filename = filename
    if not filename.endswith(".safetensors"):
        filepath_without_extension = os.path.splitext(filename)[0]
        safetensors_variant_filename = "{}.safetensors".format(filepath_without_extension)
        print("Attempted to load a non-safetensors file. This is unacceptable. Attempting to load safetensors variant: {}".format(safetensors_variant_filename))
    # Check that the safetensors variant (not the original pickle file) exists.
    if not os.path.exists(safetensors_variant_filename):
        raise FileNotFoundError("Safetensors variant not found. Please ensure one exists: \"{}\"".format(safetensors_variant_filename))
    # safetensors.torch.load_file wants `device` as a string ("cpu", "cuda:0"), while torch.load's
    # map_location may be a torch.device. str() keeps the device index; None, dicts and callables fall back to CPU.
    if isinstance(map_location, (str, torch.device)):
        device = str(map_location)
    else:
        device = "cpu"
    return StateDictCompatibleDictionary(safetensors.torch.load_file(safetensors_variant_filename, device=device))


# Override torch.load to use our method instead :)
torch.load = load
```

CodeZombie commented Mar 26, 2023

The motivation for this is that loading a pickle file (.pt, .ckpt, .pth, etc.) can execute arbitrary code embedded in that file. Some work has been done to mitigate this by only allowing certain kinds of code to run while the file is being loaded, but this is still unacceptable.
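To illustrate the risk, here is a minimal sketch using only Python's standard `pickle` module (not torch and not the web UI's own code). A class whose `__reduce__` returns a callable makes `pickle.loads` call it; `print` stands in for something harmful such as `os.system`. The `RestrictedUnpickler` follows the "Restricting Globals" pattern from the Python `pickle` documentation, which is the general shape of an allow-list mitigation; its `ALLOWED` set is a hypothetical example, not the web UI's actual list.

```python
import builtins
import io
import pickle
from collections import OrderedDict


class Payload:
    # pickle stores whatever __reduce__ returns and *calls* it on load.
    # print is a harmless stand-in for something like os.system.
    def __reduce__(self):
        return (builtins.print, ("arbitrary code ran during unpickling",))


malicious_bytes = pickle.dumps(Payload())
pickle.loads(malicious_bytes)  # prints the message: the call happens inside loads()


class RestrictedUnpickler(pickle.Unpickler):
    # Hypothetical allow-list; a real loader would list the tensor/storage types it needs.
    ALLOWED = {("collections", "OrderedDict")}

    def find_class(self, module, name):
        # Every global the file refers to is resolved through this method.
        if (module, name) in self.ALLOWED:
            return super().find_class(module, name)
        raise pickle.UnpicklingError(f"forbidden global: {module}.{name}")


def restricted_loads(data):
    return RestrictedUnpickler(io.BytesIO(data)).load()


# An allowed object still loads...
assert restricted_loads(pickle.dumps(OrderedDict(weight=[0.1, 0.2]))) == OrderedDict(weight=[0.1, 0.2])
# ...but the payload is rejected before print is ever called.
try:
    restricted_loads(malicious_bytes)
except pickle.UnpicklingError as err:
    print("blocked:", err)
```

An allow-list only narrows which callables a file may invoke; anything on the list can still be called with arguments chosen by whoever made the file, which is why refusing pickle files outright is the stricter option.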

This replacement for `safe.py` (`stable-diffusion-webui\modules\safe.py`) fully removes the web-ui's ability to load pickle files through the standard `torch.load` function. Instead, it attempts to find a safer safetensors variant of the same file.

This requires you to manually replace any pickle files with safetensors files that have the same base name and sit in the same folder as the pickle files (for example, `model.ckpt` becomes `model.safetensors`).
I will keep this comments section updated with safetensors variants of popular pickle models as I find them.
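Before switching to this loader, it can help to know which checkpoints still lack a replacement. The following is a minimal, stdlib-only sketch; the function name `missing_safetensors` and the extension set are assumptions for illustration. It walks a models folder and lists pickle files that have no `.safetensors` file with the same base name next to them, mirroring the filename swap the loader performs.

```python
from pathlib import Path

# Assumed pickle checkpoint extensions, taken from the list in the comment above.
PICKLE_EXTENSIONS = {".pt", ".ckpt", ".pth"}


def missing_safetensors(models_dir):
    """Return paths (relative to models_dir, POSIX style) of pickle checkpoints
    that have no same-named .safetensors file in the same folder."""
    root = Path(models_dir)
    missing = []
    for path in sorted(root.rglob("*")):
        if not path.is_file() or path.suffix.lower() not in PICKLE_EXTENSIONS:
            continue
        # model.ckpt -> model.safetensors, the same swap the loader makes
        if not path.with_suffix(".safetensors").exists():
            missing.append(path.relative_to(root).as_posix())
    return missing
```

An empty result means every pickle checkpoint under that folder has a safetensors variant the loader can pick up.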
