Skip to content

Instantly share code, notes, and snippets.

@adhikjoshi
Last active July 29, 2023 20:17
Show Gist options
  • Save adhikjoshi/2c6da89cbcd7a6a3344d3081ccd1dda0 to your computer and use it in GitHub Desktop.
Save adhikjoshi/2c6da89cbcd7a6a3344d3081ccd1dda0 to your computer and use it in GitHub Desktop.
Lycoris Inference
import lycoris_inference
import lora
from diffusers import DiffusionPipeline
import torch
from safetensors import safe_open
from safetensors.torch import load_file
import time
import json
# Driver script: load the SDXL base pipeline, attach a LoRA / LyCORIS
# checkpoint to it, and render one image.


def _detect_algo(metadata):
    """Return the network algorithm recorded in a checkpoint's metadata.

    Lookup order:
      1. the ``algo`` entry of the JSON stored under ``ss_network_args``;
      2. the ``ss_network_module`` entry (``"networks.lora"`` is mapped to
         ``"lora"``, any other value is returned unchanged);
      3. ``"lora"`` when neither is available.

    ``metadata`` may be ``None`` (checkpoint without a metadata header).
    """
    try:
        network_args_dict = json.loads(metadata["ss_network_args"])
    except (KeyError, TypeError, ValueError):
        # Missing key, no metadata at all, or malformed JSON.
        network_args_dict = None
    # A parsed-but-"algo"-less dict must also fall through to the module
    # name; otherwise the algorithm would never be assigned.
    if isinstance(network_args_dict, dict) and "algo" in network_args_dict:
        return network_args_dict["algo"]

    try:
        module_name = metadata["ss_network_module"]
    except (KeyError, TypeError) as exc:
        print(exc)
        print("Error: could not load ss_network_args")
        return "lora"
    return "lora" if module_name == "networks.lora" else module_name


# load SDXL pipeline
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
).to("cuda")

# Example checkpoints (save one next to this script under the matching name):
#   loha.safetensors    : https://civitai.com/models/111594/sd-xl09-loha-pearly-gates-concept
#   sdxl.safetensors    : https://civitai.com/models/112904/arcane-style-lora-xl09  (plain lora)
#   lycoris.safetensors : https://civitai.com/models/108011/fcbodybuildingxl-10-for-sdxl
lora_model = "lycoris.safetensors"
lora_strength = 1

# Only the metadata header is needed here; the context manager releases the
# file handle as soon as it has been read.
with safe_open(lora_model, framework="pt") as checkpoint:
    network_args = checkpoint.metadata()
print(network_args)

algo = _detect_algo(network_args)

if algo == "lora":
    # Plain LoRA checkpoints are handled by diffusers itself.
    pipe.load_lora_weights(".", weight_name=lora_model)
else:
    # The full state dict is only needed on the LyCORIS path.
    weights_sd = load_file(lora_model)
    # SDXL has two text encoders, so both are handed over.
    network, weights_sd = lycoris_inference.create_network_from_weights(
        multiplier=lora_strength,
        file="",
        vae=pipe.vae,
        text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
        unet=pipe.unet,
        weights_sd=weights_sd,
        for_inference=True,
        algo=algo,
    )
    network.apply_to()
    info = network.load_state_dict(weights_sd, False)
    network.to("cuda", dtype=torch.float16)

# create an image
generator = torch.Generator("cuda").manual_seed(0)
prompt = "pearlygates, 1girl, solo, scenery, long hair, dress, cloudy sky, standing"
image = pipe(prompt=prompt, generator=generator).images[0]
image.save(time.strftime("%Y%m%d_%H%M%S") + ".png")
import torch
from lycoris import kohya
from lycoris.modules import locon, loha, lokr
# Drop "Attention" from the U-Net target collection that kohya.LycorisNetwork
# keeps as a class attribute. This runs at import time and mutates shared class
# state, so it affects every network built on kohya.LycorisNetwork in this
# process; it also raises if "Attention" is not present in that collection.
# NOTE(review): presumably done so whole Attention blocks are not wrapped and
# only their inner layers are targeted — confirm against the lycoris release in use.
kohya.LycorisNetwork.UNET_TARGET_REPLACE_MODULE.remove("Attention")
class LokrModule(lokr.LokrModule):
    """LoKr module with helpers to bake its weight delta into the wrapped layer."""

    def make_weight(self):
        """Return the weight delta, shaped like the wrapped layer's weight.

        The factors are moved to the wrapped weight's device, the product is
        computed in float32, and the result is multiplied by ``self.scale``.
        """
        # NOTE(review): this reads the lora_up / lora_down / lora_mid
        # attributes; confirm lokr.LokrModule actually defines them
        # (get_metadata in this file looks LoKr checkpoints up by ``lokr_w1``).
        target = self.org_module[0].weight
        # Only device and shape of the wrapped weight are needed, so it is
        # not converted (and therefore not copied) to float32 itself.
        device, dtype = target.device, torch.float

        up = self.lora_up.weight.to(device=device, dtype=dtype)
        down = self.lora_down.weight.to(device=device, dtype=dtype)
        if self.cp:
            mid = self.lora_mid.weight.to(device=device, dtype=dtype)
            up = up.reshape(up.size(0), up.size(1))
            down = down.reshape(down.size(0), down.size(1))
            weight = torch.einsum(
                "i j k l, i p, j r -> p r k l", mid, up, down)
        else:
            weight = up.reshape(up.size(0), -1) @ down.reshape(down.size(0), -1)
        return weight.reshape(target.shape) * self.scale

    def merge_to(self, *args):
        """Add ``make_weight() * self.multiplier`` to the wrapped weight in place.

        Extra positional arguments are accepted and ignored.
        """
        org_weight = self.org_module[0].weight
        # Autograd rejects an in-place write to a leaf tensor that requires
        # grad, so the merge has to run with gradient tracking disabled.
        with torch.no_grad():
            weight = self.make_weight() * self.multiplier
            org_weight.copy_(org_weight + weight.to(org_weight.dtype))
class LoConModule(locon.LoConModule):
    """LoCon/LoRA module with helpers to bake its weight delta into the wrapped layer."""

    def make_weight(self):
        """Return the weight delta, shaped like the wrapped layer's weight.

        Without ``self.cp`` this is ``up @ down`` on the flattened factors;
        with ``self.cp`` the ``lora_mid`` tensor is contracted with both
        factors. The computation runs in float32 on the wrapped weight's
        device and the result is multiplied by ``self.scale``.
        """
        target = self.org_module[0].weight
        # Only device and shape of the wrapped weight are needed, so it is
        # not converted (and therefore not copied) to float32 itself.
        device, dtype = target.device, torch.float

        up = self.lora_up.weight.to(device=device, dtype=dtype)
        down = self.lora_down.weight.to(device=device, dtype=dtype)
        if self.cp:
            mid = self.lora_mid.weight.to(device=device, dtype=dtype)
            up = up.reshape(up.size(0), up.size(1))
            down = down.reshape(down.size(0), down.size(1))
            weight = torch.einsum(
                "m n w h, i m, n j -> i j w h", mid, up, down)
        else:
            weight = up.reshape(up.size(0), -1) @ down.reshape(down.size(0), -1)
        return weight.reshape(target.shape) * self.scale

    def merge_to(self):
        """Add ``make_weight() * self.multiplier`` to the wrapped weight in place."""
        org_weight = self.org_module[0].weight
        # Autograd rejects an in-place write to a leaf tensor that requires
        # grad, so the merge has to run with gradient tracking disabled.
        with torch.no_grad():
            weight = self.make_weight() * self.multiplier
            org_weight.copy_(org_weight + weight.to(org_weight.dtype))
class LohaModule(loha.LohaModule):
    """LoHa module with helpers to bake its weight delta into the wrapped layer."""

    def make_weight(self):
        """Return the weight delta, shaped like the wrapped layer's weight.

        Two low-rank products are built (from the ``hada_w1_*`` and
        ``hada_w2_*`` factors, plus ``hada_t1`` / ``hada_t2`` when ``self.cp``
        is set) and multiplied element-wise. The computation runs in float32
        on the wrapped weight's device; the result is multiplied by
        ``self.scale``.
        """
        target = self.org_module[0].weight
        # Only device and shape of the wrapped weight are needed, so it is
        # not converted (and therefore not copied) to float32 itself.
        device, dtype = target.device, torch.float

        w1a = self.hada_w1_a.to(device=device, dtype=dtype)
        w1b = self.hada_w1_b.to(device=device, dtype=dtype)
        w2a = self.hada_w2_a.to(device=device, dtype=dtype)
        w2b = self.hada_w2_b.to(device=device, dtype=dtype)
        if self.cp:
            t1 = self.hada_t1.to(device=device, dtype=dtype)
            t2 = self.hada_t2.to(device=device, dtype=dtype)
            weight_1 = torch.einsum("i j k l, j r -> i r k l", t1, w1b)
            weight_1 = torch.einsum("i j k l, i r -> r j k l", weight_1, w1a)
            weight_2 = torch.einsum("i j k l, j r -> i r k l", t2, w2b)
            weight_2 = torch.einsum("i j k l, i r -> r j k l", weight_2, w2a)
        else:
            weight_1 = w1a @ w1b
            weight_2 = w2a @ w2b
        return (weight_1 * weight_2).reshape(target.shape) * self.scale

    def merge_to(self):
        """Add ``make_weight() * self.multiplier`` to the wrapped weight in place."""
        org_weight = self.org_module[0].weight
        # Autograd rejects an in-place write to a leaf tensor that requires
        # grad, so the merge has to run with gradient tracking disabled.
        with torch.no_grad():
            weight = self.make_weight() * self.multiplier
            org_weight.copy_(org_weight + weight.to(org_weight.dtype))
# Base-key suffixes that mark an entry as belonging to a convolution layer.
_CONV_KEY_SUFFIXES = ("conv", "conv1", "conv2")

# Layer names probed under ``lora_te2_text_model_encoder_layers_9_`` to
# recover a LoKr dimension.
_LOKR_TE2_PROBE_LAYERS = (
    "mlp_fc1", "mlp_fc2", "self_attn_k_proj",
    "self_attn_out_proj", "self_attn_q_proj", "self_attn_v_proj",
)


def _alpha_value(value):
    """Convert a stored alpha to a number without truncating fractional values."""
    alpha = float(value)
    return int(alpha) if alpha.is_integer() else alpha


def _lora_rank(weight, base_key):
    """Return the rank shared by ``lora_up`` (dim 1) and ``lora_down`` (dim 0)."""
    up_rank = weight[f"{base_key}.lora_up.weight"].size()[1]
    down_rank = weight[f"{base_key}.lora_down.weight"].size()[0]
    if up_rank != down_rank:
        raise ValueError("lora_up and lora_down must be same size")
    return up_rank


def _loha_rank(weight, base_key):
    """Return the rank shared by ``hada_w1_b`` and ``hada_w2_b`` (dim 0)."""
    w1_rank = weight[f"{base_key}.hada_w1_b"].size()[0]
    w2_rank = weight[f"{base_key}.hada_w2_b"].size()[0]
    if w1_rank != w2_rank:
        raise ValueError("hada_w1_b and hada_w2_b must be same size")
    return w1_rank


def _no_rank(weight, base_key):
    """LoKr entries carry no per-key rank lookup; always return ``None``."""
    return None


def get_metadata(algo: str, weight):
    """Derive network hyper-parameters from a checkpoint state dict.

    Every ``<base>.alpha`` entry is inspected. Entries whose base name ends
    in ``conv`` / ``conv1`` / ``conv2`` set the conv values, all others set
    the linear values; when several entries qualify, the last one in
    iteration order wins.

    Args:
        algo: ``"lora"``, ``"locon"`` (handled like ``"lora"``), ``"loha"``
            or ``"lokr"``.
        weight: mapping of parameter name to tensor.

    Returns:
        ``(conv_alpha, conv_lora_dim, lora_alpha, lora_dim, {"use_cp": bool})``.
        Values that could not be determined are ``None``. Alphas are ints
        when integral and floats otherwise (e.g. ``0.5`` is kept as ``0.5``).

    Raises:
        ValueError: for an unsupported ``algo`` or mismatched factor ranks.
        KeyError: if a factor tensor expected next to an alpha is missing.
    """
    if algo in ("lora", "locon"):
        rank_of, cp_markers = _lora_rank, ("lora_mid.weight",)
    elif algo == "loha":
        rank_of, cp_markers = _loha_rank, ("hada_t1", "hada_t2")
    elif algo == "lokr":
        rank_of, cp_markers = _no_rank, ("lora_mid.weight",)
    else:
        raise ValueError(f"Unsupported network algorithm: {algo!r}")

    use_cp = False
    conv_alpha = conv_lora_dim = lora_alpha = lora_dim = None
    for key, value in weight.items():
        if not key.endswith(".alpha"):
            continue
        base_key = key[: -len(".alpha")]
        alpha = _alpha_value(value)
        rank = rank_of(weight, base_key)
        if base_key.endswith(_CONV_KEY_SUFFIXES):
            conv_alpha, conv_lora_dim = alpha, rank
        else:
            lora_alpha, lora_dim = alpha, rank
        # All marker tensors must be present for the entry to count as CP.
        if all(f"{base_key}.{marker}" in weight for marker in cp_markers):
            use_cp = True

    if algo == "lokr":
        # NOTE(review): the linear dim for LoKr is taken from dim 1 of
        # ``lokr_w1`` on text-encoder-2 layer 9 only; presumably a stand-in
        # for the real rank — confirm before relying on it.
        for layer in _LOKR_TE2_PROBE_LAYERS:
            prefix = f"lora_te2_text_model_encoder_layers_9_{layer}"
            if f"{prefix}.alpha" in weight:
                lora_alpha = _alpha_value(weight[f"{prefix}.alpha"])
                lora_dim = weight[f"{prefix}.lokr_w1"].size()[1]

    return conv_alpha, conv_lora_dim, lora_alpha, lora_dim, {"use_cp": use_cp}
def create_network_from_weights(
    multiplier: float,
    file: str,
    vae,
    text_encoder,
    unet,
    algo=None,
    weights_sd: "dict[str, torch.Tensor] | None" = None,
    **kwargs,
):
    """Build a ``LycorisNetwork`` sized to match a checkpoint.

    Args:
        multiplier: strength passed on to the network.
        file: checkpoint path; only read when ``weights_sd`` is not given.
        vae: accepted for interface compatibility; not used here.
        text_encoder: text encoder (or list of encoders) to attach to.
        unet: U-Net to attach to.
        algo: ``"lora"``, ``"locon"``, ``"loha"`` or ``"lokr"``. When
            ``None`` it is guessed from the parameter names.
        weights_sd: already-loaded state dict (parameter name -> tensor).
        **kwargs: accepted and ignored.

    Returns:
        ``(network, weights_sd)``.

    Raises:
        ValueError: if neither ``weights_sd`` nor ``file`` is supplied, if the
            algorithm cannot be determined, or if it is not supported.
    """
    if weights_sd is None:
        if not file:
            raise ValueError("Either weights_sd or file must be provided")
        if str(file).endswith(".safetensors"):
            from safetensors.torch import load_file

            weights_sd = load_file(file)
        else:
            # NOTE: torch.load unpickles the file; only use trusted checkpoints.
            weights_sd = torch.load(file, map_location="cpu")

    print(algo)

    # The flags stay None (not False) when no matching key is seen.
    apply_unet = None
    apply_te = None
    for key in weights_sd:
        if key.startswith("lora_unet"):
            apply_unet = True
        elif key.startswith("lora_te"):
            apply_te = True
        if algo is None:
            if "lora_up" in key or "lora_down" in key:
                algo = "lora"
            elif "hada" in key:
                algo = "loha"
            elif "lokr" in key:
                algo = "lokr"
        # Stop scanning once everything of interest has been found.
        if apply_unet is not None and apply_te is not None and algo is not None:
            break

    if algo is None:
        raise ValueError("Could not determine network module")

    network_modules = {
        "lora": LoConModule,
        "locon": LoConModule,
        "loha": LohaModule,
        # "ia3": IA3Module,
        "lokr": LokrModule,
        # "dylora": DyLoraModule,
        # "glora": GLoRAModule,
    }
    if algo not in network_modules:
        raise ValueError(f"Unsupported network module: {algo!r}")
    network_module = network_modules[algo]

    # "locon" uses the same module class as "lora", so its metadata is read
    # via the "lora" path.
    (
        conv_alpha,
        conv_dim,
        lora_alpha,
        lora_dim,
        additional_kwargs,
    ) = get_metadata("lora" if algo == "locon" else algo, weights_sd)

    # A dim/alpha pair is only used when both halves are known.
    if lora_dim is None or lora_alpha is None:
        lora_dim = 0
        lora_alpha = 0
    if conv_dim is None or conv_alpha is None:
        conv_dim = 0
        conv_alpha = 0

    network = LycorisNetwork(
        text_encoder,
        unet,
        multiplier=multiplier,
        lora_dim=lora_dim,
        conv_lora_dim=int(conv_dim),
        alpha=lora_alpha,
        conv_alpha=conv_alpha,
        network_module=network_module,
        apply_unet=apply_unet,
        apply_te=apply_te,
        weights_sd=weights_sd,
        **additional_kwargs,
    )
    return network, weights_sd
class LycorisNetwork(kohya.LycorisNetwork):
    """kohya ``LycorisNetwork`` that remembers the original layers so it can be undone."""

    def __init__(self, *args, **kwargs):
        """Build the network and snapshot every wrapped layer.

        Besides what the base class consumes, these keyword arguments are read:
            apply_unet: attach to the U-Net (default ``True``; ``None`` means
                "decide from the weights" in ``apply_to``).
            apply_te: same, for the text encoder(s).
            weights_sd: optional state dict loaded by ``apply_to``.
        """
        super().__init__(*args, **kwargs)
        self.apply_unet = kwargs.get("apply_unet", True)
        self.apply_te = kwargs.get("apply_te", True)
        self.weights_sd = kwargs.get("weights_sd", None)

        if self.apply_unet:
            for lora in self.unet_loras:
                self.add_module(lora.lora_name, lora)
        if self.apply_te:
            for lora in self.text_encoder_loras:
                self.add_module(lora.lora_name, lora)

        # Keep the original forward and a CPU copy of the original weight on
        # each wrapped module (first network to touch it wins) for restore().
        for lora in self.text_encoder_loras + self.unet_loras:
            org_module = lora.org_module[0]
            if not hasattr(org_module, "_lora_org_forward"):
                org_module._lora_org_forward = org_module.forward
            if not hasattr(org_module, "_lora_org_weight"):
                # detach: the snapshot must not hold autograd history;
                # copy=True: still a real copy when the weight is already on CPU.
                org_module._lora_org_weight = org_module.weight.detach().to(
                    device="cpu", copy=True)

    def apply_to(self):
        """Hook the modules into their layers and load ``weights_sd`` if present.

        When ``weights_sd`` is set, an ``apply_*`` flag of ``None`` is resolved
        from the key prefixes found in it, and an explicit flag must agree
        with them.
        """
        apply_text_encoder = self.apply_te
        apply_unet = self.apply_unet

        if self.weights_sd:
            weights_has_text_encoder = weights_has_unet = False
            for key in self.weights_sd.keys():
                if key.startswith(LycorisNetwork.LORA_PREFIX_TEXT_ENCODER):
                    weights_has_text_encoder = True
                elif key.startswith(LycorisNetwork.LORA_PREFIX_UNET):
                    weights_has_unet = True

            if apply_text_encoder is None:
                apply_text_encoder = weights_has_text_encoder
            else:
                assert apply_text_encoder == weights_has_text_encoder, f"text encoder weights: {weights_has_text_encoder} but text encoder flag: {apply_text_encoder} / 重みとText Encoderのフラグが矛盾しています"

            if apply_unet is None:
                apply_unet = weights_has_unet
            else:
                assert apply_unet == weights_has_unet, f"u-net weights: {weights_has_unet} but u-net flag: {apply_unet} / 重みとU-Netのフラグが矛盾しています"
        else:
            assert apply_text_encoder is not None and apply_unet is not None, "internal error: flag not set"

        if apply_text_encoder:
            print("enable LyCORIS for text encoder")
        else:
            self.text_encoder_loras = []

        if apply_unet:
            print("enable LyCORIS for U-Net")
        else:
            self.unet_loras = []

        for lora in self.text_encoder_loras + self.unet_loras:
            lora.apply_to()
            self.add_module(lora.lora_name, lora)

        if self.weights_sd:
            # Loaded non-strictly; report the result so missing or unexpected
            # keys are visible instead of silently ignored.
            info = self.load_state_dict(self.weights_sd, False)
            print(f"weights are loaded: {info}")

    def merge_to(self):
        """Bake every module's weight delta into the layer it wraps."""
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.merge_to()

    def restore(self, *args):
        """Put back the original forward and weight of every wrapped layer.

        Extra positional arguments are accepted and ignored. The stored
        snapshots are removed once applied.
        """
        for lora in self.text_encoder_loras + self.unet_loras:
            org_module = lora.org_module[0]
            if hasattr(org_module, "_lora_org_forward"):
                org_module.forward = org_module._lora_org_forward
                del org_module._lora_org_forward
            if hasattr(org_module, "_lora_org_weight"):
                # Autograd rejects an in-place write to a leaf tensor that
                # requires grad, so copy with gradient tracking disabled.
                with torch.no_grad():
                    org_module.weight.copy_(org_module._lora_org_weight)
                del org_module._lora_org_weight
anyio==3.7.1
argon2-cffi==21.3.0
argon2-cffi-bindings==21.2.0
arrow==1.2.3
asttokens==2.2.1
async-lru==2.0.3
attrs==23.1.0
Babel==2.12.1
backcall==0.2.0
beautifulsoup4==4.12.2
bleach==6.0.0
blinker==1.4
certifi==2022.12.7
cffi==1.15.1
charset-normalizer==2.1.1
cmake==3.25.0
comm==0.1.3
cryptography==3.4.8
dbus-python==1.2.18
debugpy==1.6.7
decorator==5.1.1
defusedxml==0.7.1
diffusers==0.19.2
distro==1.7.0
einops==0.6.1
exceptiongroup==1.1.2
executing==1.2.0
fastjsonschema==2.17.1
filelock==3.9.0
fqdn==1.5.1
fsspec==2023.6.0
httplib2==0.20.2
huggingface-hub==0.16.4
idna==3.4
importlib-metadata==4.6.4
invisible-watermark==0.2.0
ipykernel==6.24.0
ipython==8.14.0
ipython-genutils==0.2.0
ipywidgets==8.0.7
isoduration==20.11.0
jedi==0.18.2
jeepney==0.7.1
Jinja2==3.1.2
json5==0.9.14
jsonpointer==2.4
jsonschema==4.18.0
jsonschema-specifications==2023.6.1
jupyter-archive==3.3.4
jupyter-contrib-core==0.4.2
jupyter-contrib-nbextensions==0.7.0
jupyter-events==0.6.3
jupyter-highlight-selected-word==0.2.0
jupyter-lsp==2.2.0
jupyter-nbextensions-configurator==0.6.3
jupyter_client==8.3.0
jupyter_core==5.3.1
jupyter_server==2.7.0
jupyter_server_terminals==0.4.4
jupyterlab==4.0.2
jupyterlab-pygments==0.2.2
jupyterlab-widgets==3.0.8
jupyterlab_server==2.23.0
keyring==23.5.0
launchpadlib==1.10.16
lazr.restfulclient==0.14.4
lazr.uri==1.0.6
lit==15.0.7
losalina==1.0.0
lxml==4.9.3
lycoris-lora==1.8.0
MarkupSafe==2.1.2
matplotlib-inline==0.1.6
mistune==3.0.1
more-itertools==8.10.0
mpmath==1.2.1
nbclassic==1.0.0
nbclient==0.8.0
nbconvert==7.6.0
nbformat==5.9.1
nest-asyncio==1.5.6
networkx==3.0
notebook==6.5.4
notebook_shim==0.2.3
numpy==1.24.1
oauthlib==3.2.0
opencv-python==4.8.0.74
overrides==7.3.1
packaging==23.1
pandocfilters==1.5.0
parso==0.8.3
pexpect==4.8.0
pickleshare==0.7.5
Pillow==9.3.0
platformdirs==3.8.1
prometheus-client==0.17.0
prompt-toolkit==3.0.39
psutil==5.9.5
ptyprocess==0.7.0
pure-eval==0.2.2
pycparser==2.21
Pygments==2.15.1
PyGObject==3.42.1
PyJWT==2.3.0
pyparsing==2.4.7
python-apt==2.4.0+ubuntu1
python-dateutil==2.8.2
python-json-logger==2.0.7
PyWavelets==1.4.1
PyYAML==6.0
pyzmq==25.1.0
referencing==0.29.1
regex==2023.6.3
requests==2.28.1
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rpds-py==0.8.10
safetensors==0.3.1
SecretStorage==3.3.1
Send2Trash==1.8.2
six==1.16.0
sniffio==1.3.0
soupsieve==2.4.1
stack-data==0.6.2
sympy==1.11.1
terminado==0.17.1
timm==0.9.2
tinycss2==1.2.1
tokenizers==0.13.3
tomli==2.0.1
torch==2.0.1+cu118
torchaudio==2.0.2+cu118
torchvision==0.15.2+cu118
tornado==6.3.2
tqdm==4.65.0
traitlets==5.9.0
transformers==4.31.0
triton==2.0.0
typing_extensions==4.4.0
uri-template==1.3.0
urllib3==1.26.13
wadllib==1.3.6
wcwidth==0.2.6
webcolors==1.13
webencodings==0.5.1
websocket-client==1.6.1
widgetsnbextension==4.0.8
zipp==1.0.0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment