Skip to content

Instantly share code, notes, and snippets.

@CodeZombie
Last active March 27, 2023 03:26
Show Gist options
  • Save CodeZombie/e085b4b621d609447d6e13ea97a8c6a7 to your computer and use it in GitHub Desktop.
Save CodeZombie/e085b4b621d609447d6e13ea97a8c6a7 to your computer and use it in GitHub Desktop.
a drop-in replacement for Automatic1111's `safe.py` which as it turns out, is not safe at all :)
# An actual safe model loader.
# I consider this file to actually be safe because it physically cannot load a pickle file.
# If you try to load a pickle file, it will instead look for a safetensors file with the same name and try to load that.
# If it does find a suitable Safetensors alternative, it will load it in such a way that makes it compatible with all pickle-related code (params_ema/params checking)
# This code was written by Jeremy C (badnoise.net)
import safetensors.torch
import torch
import os
class StateDictCompatibleDictionary(dict):
    """
    A dict that additionally pretends to hold itself under the key 'params_ema'.

    Looks and behaves exactly like a normal dictionary, except that when you check for,
    index, or `.get()` the key 'params_ema', it answers with a reference to itself.
    This makes the object compatible with a lot of hacky pickle-dependent code that
    manually checks the dict for 'params_ema' and loads that value as the state_dict.

    The key is never actually stored: `keys()`, `items()`, iteration and `len()` are not
    overridden and therefore do not include 'params_ema'. We need this because when loading
    a state dict in strict mode, `dict.keys()` must match `torch.nn.Module.state_dict`,
    which does not include `params_ema`. One object is thus compatible with all loading
    logic, whether it looks for params_ema or not.
    """

    # The virtual key that resolves to the dictionary itself.
    hidden_key = "params_ema"

    def __contains__(self, __key: object) -> bool:
        """Return True for the hidden key, otherwise defer to normal dict membership."""
        if __key == self.hidden_key:
            return True
        return super().__contains__(__key)

    def __getitem__(self, __key: object) -> object:
        """Return this dictionary for the hidden key, otherwise the stored value.

        Raises KeyError (from dict) for any other key that is not present.
        """
        if __key == self.hidden_key:
            return self
        return super().__getitem__(__key)

    def get(self, __key: object, __default: object = None) -> object:
        """Return this dictionary for the hidden key, otherwise behave like `dict.get`.

        `dict.get` does not go through `__getitem__`, so without this override
        `d.get('params_ema')` would return the default even though
        `'params_ema' in d` is True.
        """
        if __key == self.hidden_key:
            return self
        return super().get(__key, __default)
def load(filename, *args, **kwargs):
    """
    Drop-in stand-in for `torch.load` that only ever reads safetensors files.

    If `filename` does not end in ".safetensors", its extension is swapped for
    ".safetensors" and that sibling file is loaded instead; the original file is
    never opened.

    Parameters:
        filename: path (str or os.PathLike) of the checkpoint the caller asked for.
        *args: positional arguments in `torch.load` order; only the first one
            (`map_location`) is honoured, the rest are ignored.
        **kwargs: keyword arguments in `torch.load` style; only `map_location` is
            honoured. Everything else is dropped because
            `safetensors.torch.load_file` only takes `filename` and `device`
            and would raise TypeError on them.

    Returns:
        A StateDictCompatibleDictionary wrapping the loaded tensors.

    Raises:
        FileNotFoundError: if the safetensors file to load does not exist.
    """
    filename = os.fspath(filename)
    safetensors_variant_filename = filename
    if not filename.endswith(".safetensors"):
        # splitext only strips an extension from the last path component, so a
        # dot inside a directory name is never mistaken for the extension.
        filepath_without_extension = os.path.splitext(filename)[0]
        safetensors_variant_filename = "{}.safetensors".format(filepath_without_extension)
        print("A non-safetensors file was attempting to load. This is unacceptable. Attempting to load safetensors variant: {}".format(safetensors_variant_filename))
    # Check the file we are actually going to read, not the one that was requested.
    if not os.path.exists(safetensors_variant_filename):
        raise FileNotFoundError("Safetensors variant not found. Please ensure one exists: \"{}\"".format(safetensors_variant_filename))
    # torch.load accepts map_location either by keyword or as its first positional
    # argument after the file; pick it up from whichever place the caller used.
    device = kwargs.pop('map_location', args[0] if args else None)
    # safetensors.torch.load_file wants 'device' to be a string, but torch.load callers
    # often pass a torch.device. Only the device type is forwarded (any index is discarded).
    if isinstance(device, torch.device):
        device = device.type
    elif device is None:
        # No location requested: use load_file's own default.
        device = "cpu"
    return StateDictCompatibleDictionary(safetensors.torch.load_file(safetensors_variant_filename, device=device))
# Override torch.load to use our method instead :)
# This rebinds the attribute on the torch module itself, so once this line has run,
# any code that looks up `torch.load` gets the safetensors-only `load` function
# defined in this file.
torch.load = load
@CodeZombie
Copy link
Author

@CodeZombie
Copy link
Author

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment