Skip to content

Instantly share code, notes, and snippets.

@AmericanPresidentJimmyCarter
Last active March 17, 2023 18:48
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save AmericanPresidentJimmyCarter/1947162f371e601ce183070443f41dc2 to your computer and use it in GitHub Desktop.
Save AmericanPresidentJimmyCarter/1947162f371e601ce183070443f41dc2 to your computer and use it in GitHub Desktop.
Attempting to make small DreamBooth (DB) patches
import argparse
from pathlib import Path
import sys
import torch
# Divisor applied to the largest positive and most negative weight difference
# of each tensor to obtain the cut-offs used when compressing: a larger value
# lowers the cut-offs, so more elements are kept in the patch.
THRESHOLD_STRENGTH = 2.
# File name used when -o is not given.
DEFAULT_OUT_NAME = 'output.ckpt'

parser = argparse.ArgumentParser(
    description='Create a compressed dreambooth patch or patch weights',
)
# Restricting the choices makes a misspelled mode fail loudly instead of
# silently falling through both mode branches and doing nothing.
parser.add_argument(
    'mode',
    type=str,
    choices=('compress', 'inflate'),
    help='"compress" or "inflate"',
)
parser.add_argument(
    '-m',
    type=str,
    help='Base model (Compvis)',
    required=False,
)
parser.add_argument(
    '-m2',
    type=str,
    help='Dreambooth model (Compvis)',
    required=False,
)
parser.add_argument(
    '-c',
    type=str,
    help='Model configuration for compvis model',
    required=False,
)
parser.add_argument(
    '-p',
    type=str,
    help='Patch file name',
    required=False,
)
parser.add_argument(
    '-o',
    type=str,
    help='Output file name',
    required=False,
)
args = parser.parse_args()
def _missing_inputs(required):
    """Return the flags whose value is absent or is not an existing file.

    *required* maps a command-line flag (e.g. ``'-m'``) to the value parsed
    for it.  A value of ``None`` means the flag was omitted; it is reported
    as missing instead of being handed to ``Path``, which would raise
    ``TypeError``.
    """
    return [
        flag
        for flag, value in required.items()
        if value is None or not Path(value).is_file()
    ]


# Input files each mode cannot work without.
_required_by_mode = {
    'compress': {'-m': args.m, '-m2': args.m2},
    'inflate': {'-m': args.m, '-p': args.p},
}
_missing = _missing_inputs(_required_by_mode.get(args.mode, {}))
if _missing:
    print(
        f'Missing or invalid file for: {", ".join(_missing)}; '
        'please provide a correct path',
        file=sys.stderr,
    )
    sys.exit(1)

out_path = args.o or DEFAULT_OUT_NAME
def _as_half(tensor):
    """Return *tensor* as float16 if it is float32, otherwise unchanged."""
    return tensor.half() if tensor.dtype == torch.float32 else tensor


def _sparse_patch(base, tuned, strength):
    """Return the strongly changed elements of *tuned* as a sparse tensor.

    The difference ``base - tuned`` is compared against two cut-offs, its
    maximum and its minimum each divided by *strength*.  An element is kept
    when its difference is non-zero and does not lie strictly between the
    two cut-offs.  Kept positions carry the value from *tuned*; every other
    position is zero, so the result is stored sparsely.
    """
    diff = base - tuned
    if diff.numel() == 0:
        # max()/min() raise on empty tensors; there is nothing to threshold.
        return tuned.to_sparse()
    threshold_pos = diff.max() / strength
    threshold_neg = diff.min() / strength
    inside_band = (diff > threshold_neg) & (diff < threshold_pos)
    changed = (diff != 0) & ~inside_band
    return (changed.to(diff.dtype) * tuned).to_sparse()


if args.mode == 'compress':
    # SECURITY: torch.load unpickles the file, which can execute arbitrary
    # code.  Only use checkpoints from sources you trust.
    model = torch.load(args.m)
    model_db = torch.load(args.m2)
    m_state = model['state_dict']
    m_db_state = model_db['state_dict']

    # Entries present in both models: keep only the strongly changed weights.
    for key in list(m_state):
        if 'model' not in key or key not in m_db_state:
            continue
        base, tuned = m_state[key], m_db_state[key]
        if not (isinstance(base, torch.Tensor)
                and isinstance(tuned, torch.Tensor)):
            continue
        m_state[key] = _sparse_patch(
            _as_half(base), _as_half(tuned), THRESHOLD_STRENGTH)

    # Entries that exist only in the DreamBooth model are stored whole.  This
    # has to be a separate pass: a test for "key not in m_state" placed inside
    # a loop over m_state's own keys can never succeed.
    for key, tuned in m_db_state.items():
        if 'model' in key and key not in m_state and \
                isinstance(tuned, torch.Tensor):
            m_state[key] = tuned.to_sparse()

    del m_db_state
    del model_db
    # Save patch.  The patch is the base state dict with the processed entries
    # replaced by sparse tensors; entries that were skipped above are written
    # unchanged.
    torch.save(m_state, out_path)
if args.mode == 'inflate':
    # SECURITY: torch.load unpickles the file, which can execute arbitrary
    # code.  Only use checkpoints and patches from sources you trust.
    model = torch.load(args.m)
    model_patch = torch.load(args.p)
    m_state = model['state_dict']
    # The patch file is itself a mapping from key to tensor.
    m_patch_state = model_patch

    # Entries present in both: overwrite the base weights wherever the patch
    # holds a non-zero value.
    for key in list(m_state):
        if 'model' not in key or key not in m_patch_state:
            continue
        base, patch = m_state[key], m_patch_state[key]
        if not (isinstance(base, torch.Tensor)
                and isinstance(patch, torch.Tensor)):
            continue
        if base.dtype == torch.float32:
            base = base.half()
        # Cast to the base dtype so that integer tensors keep their dtype
        # instead of being promoted to float16 by the merge.
        inflated = patch.to_dense().to(base.dtype)
        if base.size() != inflated.size():
            # An explicit error, because assert statements vanish under -O.
            raise ValueError(
                f'Patch shape mismatch for {key}: model '
                f'{tuple(base.size())} vs patch {tuple(inflated.size())}')
        # Select instead of multiplying by a 0/1 mask: inf or NaN in the base
        # times zero would leak NaN into the patched positions.
        m_state[key] = torch.where(inflated != 0, inflated, base)

    # Entries that exist only in the patch are added whole.  This has to be a
    # separate pass: a test for "key not in m_state" placed inside a loop over
    # m_state's own keys can never succeed.
    for key, patch in m_patch_state.items():
        if 'model' in key and key not in m_state and \
                isinstance(patch, torch.Tensor):
            m_state[key] = patch.to_dense()

    # Save patched model
    del m_patch_state
    del model_patch
    model['state_dict'] = m_state
    torch.save(model, out_path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment