Skip to content

Instantly share code, notes, and snippets.

@crosstyan
Created October 10, 2022 01:43
Show Gist options
  • Save crosstyan/a2dd8ba6e479e6002c553acd7ee050b5 to your computer and use it in GitHub Desktop.
Save crosstyan/a2dd8ba6e479e6002c553acd7ee050b5 to your computer and use it in GitHub Desktop.
Source: the SD Training Labs Discord server.
Target file: modules\prompt_parser.py.
To load v2.pt, place it in the main (root) folder of the repo and add the following code
---------------------------------------------------------------------------
import torch
from torch import nn
from modules import devices
class VectorAdjustPrior(nn.Module):
    """Adjust each token embedding using the mean embedding of its sequence.

    For every token, the mean of its sequence and the token itself are
    concatenated and projected down to ``inter_dim`` features; the token is
    then concatenated with that projection and mapped back to ``hidden_size``.
    """

    def __init__(self, hidden_size, inter_dim=64):
        super().__init__()
        # Input is [sequence mean, token] -> hidden_size * 2 features.
        self.vector_proj = nn.Linear(hidden_size * 2, inter_dim, bias=True)
        # Input is [token, vector_proj output] -> hidden_size + inter_dim features.
        self.out_proj = nn.Linear(hidden_size + inter_dim, hidden_size, bias=True)

    def forward(self, z):
        """Return the adjusted embeddings for ``z``.

        Args:
            z: tensor of shape (batch, seq, hidden_size).

        Returns:
            Tensor of shape (batch, seq, hidden_size).
        """
        seq_len = z.shape[1]
        # Pair every token with the mean of *its own* sequence. The former
        # ``mean.repeat(s, 1)`` laid the means out as b0, b1, ..., b0, b1, ...
        # while the flattened tokens run b0s0, b0s1, ..., b1s0, ..., so with a
        # batch of more than one prompt tokens were mixed with other prompts'
        # means. nn.Linear works on the last dim, so no flattening is needed.
        seq_mean = z.mean(dim=1, keepdim=True).expand(-1, seq_len, -1)
        adjust = self.vector_proj(torch.cat((seq_mean, z), dim=-1))
        return self.out_proj(torch.cat((z, adjust), dim=-1))

    @classmethod
    def load_model(cls, model_path, hidden_size=768, inter_dim=64):
        """Build a VectorAdjustPrior and load its weights from ``model_path``.

        The checkpoint must be a mapping with a ``"state_dict"`` entry. Weights
        are loaded onto the CPU first and then moved to ``devices.device``, so a
        checkpoint saved on a GPU also loads on machines without that GPU.

        Returns:
            The model on ``devices.device``, in eval mode.
        """
        model = cls(hidden_size=hidden_size, inter_dim=inter_dim)
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(model_path, map_location="cpu")
        model.load_state_dict(checkpoint["state_dict"])
        model.to(devices.device)
        # Inference-only use: make that explicit.
        model.eval()
        return model
# Load the adjuster once, from 'v2.pt' relative to the current working
# directory. load_model() already moves it to devices.device; an extra
# .cuda() would force CUDA regardless of the configured device.
vap = VectorAdjustPrior.load_model('v2.pt')
------------------------------------------------------------------
after the
----------------------------
import lark
---------------------------
line, and add
---------------------
conds = vap(conds)
---------------------
between
---------------------------------------------
conds = model.get_learned_conditioning(texts)
---------------------------------------------
and
---------------------------------------------
cond_schedule = []
---------------------------------------------
lines.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment