@torridgristle
Last active October 1, 2022 14:21
Generate every combination of prompt parts and encode all of the prompts in batches to avoid running out of memory. Alternatively, keep only the min/max channel values and min/max token norms, and generate random prompt embeddings from noise that match those statistics. Intended for Stable Diffusion, but usable with anything CLIP-based by swapping out the model.get_learned_conditioning call.
import itertools
import math

import torch

def prompt_combinations(prompt_parts):
    '''
    Provide a list of lists of prompt parts, like:
    [ ["A ","An "], ["anteater","feather duster"] ]
    '''
    opt_prompt = itertools.product(*prompt_parts)
    return [''.join(parts) for parts in opt_prompt]
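# For reference, the docstring's example expands like this (itertools.product
# iterates the last list fastest):
# >>> prompt_combinations([ ["A ","An "], ["anteater","feather duster"] ])
# ['A anteater', 'A feather duster', 'An anteater', 'An feather duster']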
def encode_all_prompts(opt_prompt):
    # Encode prompts in batches of 64 so the whole list never has to pass
    # through the text encoder at once.
    with torch.no_grad(), torch.autocast("cuda", cache_enabled=True), model.ema_scope():
        c_all = []
        for b in range(math.ceil(len(opt_prompt) / 64)):
            c_all.append(model.get_learned_conditioning(opt_prompt[b * 64:(b + 1) * 64]))
        c_all = torch.cat(c_all)  # optionally .cpu() to free VRAM
    return c_all
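# As the description says, the same batching works for anything CLIP-based by
# swapping out model.get_learned_conditioning. A minimal sketch with the
# openai/CLIP package (an assumption, not part of the original gist): Stable
# Diffusion's text encoder returns per-token embeddings of shape [B, 77, 768],
# so the pooled clip_model.encode_text output won't match; instead keep every
# position of the text transformer's output, the same forward pass encode_text
# runs internally before pooling.
import clip

def encode_all_prompts_clip(opt_prompt, clip_model, device="cuda"):
    with torch.no_grad():
        c_all = []
        for b in range(math.ceil(len(opt_prompt) / 64)):
            tokens = clip.tokenize(opt_prompt[b * 64:(b + 1) * 64], truncate=True).to(device)
            x = clip_model.token_embedding(tokens).type(clip_model.dtype)
            x = x + clip_model.positional_embedding.type(clip_model.dtype)
            x = clip_model.transformer(x.permute(1, 0, 2)).permute(1, 0, 2)  # NLD <-> LND
            c_all.append(clip_model.ln_final(x))  # [batch, 77, width]; 768 for ViT-L/14
        return torch.cat(c_all)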
def encode_all_prompts_stats(opt_prompt):
    # Same batching as encode_all_prompts, but only running min/max statistics
    # are kept: per-channel extremes and per-token norm extremes. Each batch is
    # concatenated with the running extremes before max/min so they carry over.
    with torch.no_grad(), torch.autocast("cuda", cache_enabled=True), model.ema_scope():
        max_ch = None
        min_ch = None
        max_norm = None
        min_norm = None
        for b in range(math.ceil(len(opt_prompt) / 64)):
            x = model.get_learned_conditioning(opt_prompt[b * 64:(b + 1) * 64])
            # Token norms are taken from the raw batch, before the synthetic
            # channel-extreme rows are concatenated in, so those rows can't
            # inflate the norm statistics.
            norm_token = x.norm(2, -1, keepdim=True)
            if max_ch is not None:
                x = torch.cat([x, max_ch, min_ch], 0)
            max_ch = x.max(0, keepdim=True).values
            min_ch = x.min(0, keepdim=True).values
            if max_norm is not None:
                norm_token = torch.cat([norm_token, max_norm, min_norm], 0)
            max_norm = norm_token.max(0, keepdim=True).values
            min_norm = norm_token.min(0, keepdim=True).values
    return max_ch, min_ch, max_norm, min_norm
def token_stats(x):
    # Per-channel extremes and per-token norm extremes over the batch dim.
    max_ch = x.max(0, keepdim=True).values
    min_ch = x.min(0, keepdim=True).values
    norm_token = x.norm(2, -1, keepdim=True)
    max_norm = norm_token.max(0, keepdim=True).values
    min_norm = norm_token.min(0, keepdim=True).values
    return max_ch, min_ch, max_norm, min_norm
def match_token_stats(x, max_ch, min_ch, max_norm, min_norm, eps=1e-6):
    # Pick a direction by interpolating each channel between its observed
    # extremes, normalize it, then rescale to a random norm drawn between
    # the observed per-token norm extremes.
    ch_out = torch.lerp(min_ch, max_ch, x)
    ch_out = ch_out / ch_out.norm(2, -1, keepdim=True).add(eps)
    rand_norm = torch.rand([x.shape[0], x.shape[1], 1], device=x.device).to(x.dtype)
    return ch_out * torch.lerp(min_norm, max_norm, rand_norm)
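# Quick sanity check (shapes here are arbitrary stand-ins): the direction is
# re-normalized and then scaled by a norm interpolated between min_norm and
# max_norm, so every generated token norm should land inside the observed
# range (up to eps).
# >>> stats = token_stats(torch.randn([16, 77, 768]))
# >>> fake = match_token_stats(torch.rand([4, 77, 768]), *stats)
# >>> norms = fake.norm(2, -1, keepdim=True)
# >>> bool((norms <= stats[2] + 1e-3).all() and (norms >= stats[3] - 1e-3).all())
# True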
def match_token_stats_simple(x):
    # x is the (max_ch, min_ch, max_norm, min_norm) tuple returned by
    # token_stats or encode_all_prompts_stats. 77 tokens x 768 channels
    # matches Stable Diffusion's CLIP text encoder.
    return match_token_stats(torch.rand([1, 77, 768], device=x[0].device, dtype=x[0].dtype), *x)

#Example: match_token_stats(torch.rand([4,3,6]), *token_stats(torch.randn([4,3,6])))
#Example: match_token_stats_simple(token_stats(torch.randn([4,77,768])))
import clip

tokenizer = clip.simple_tokenizer.SimpleTokenizer()

def token_check(x):
    # Show how CLIP's BPE tokenizer splits a prompt: the raw token ids, the
    # round-tripped string, each token decoded on its own, and the count.
    word_tokens = tokenizer.encode(x)
    print(word_tokens)
    print(tokenizer.decode(word_tokens))
    if len(word_tokens) > 1:
        print([str(i) + " " + tokenizer.decode([word_tokens[i]]) for i in range(len(word_tokens))])
    print(len(word_tokens))

#Example: token_check("john carpenter, David Cronenberg, David Lynch, Clive Barker")
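# CLIP prepends a start-of-text token and appends an end-of-text token, so a
# prompt only has 75 of the 77 context positions to itself. A hypothetical
# helper (not in the original gist) for a quick fit check:
def prompt_fits_context(prompt, context_len=77):
    # +2 accounts for the start-of-text and end-of-text tokens.
    return len(tokenizer.encode(prompt)) + 2 <= context_len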
#Example:
opt_prompt = prompt_combinations([
    ["eldritch creature, practical effects, imdb"],
    [", iso "],
    ["100","200","400","800"],
    [", danny devito, body horror, metamorphosis carapace cocoon"],
    [", john carpenter",", Cronenberg",", David Lynch",", Clive Barker"],
])

# Either encode every combination and keep all of the embeddings...
c_all = encode_all_prompts(opt_prompt)
c_stats = token_stats(c_all)
c = match_token_stats_simple(c_stats)

# ...or, to avoid encoding all the prompts and keeping them in memory,
# just keep the stats from the prompts.
c_stats = encode_all_prompts_stats(opt_prompt)
c = match_token_stats_simple(c_stats)
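# From here, c drops into the usual Stable Diffusion sampling path in place of
# an encoded prompt. A sketch against the CompVis ldm codebase (step count,
# guidance scale, and latent shape are placeholder values; `model` is the
# loaded LatentDiffusion model as above):
from ldm.models.diffusion.ddim import DDIMSampler

sampler = DDIMSampler(model)
uc = model.get_learned_conditioning([""])  # empty prompt for classifier-free guidance
samples, _ = sampler.sample(
    S=50,
    conditioning=c,
    batch_size=c.shape[0],
    shape=[4, 64, 64],      # latent shape for 512x512 output
    verbose=False,
    unconditional_guidance_scale=7.5,
    unconditional_conditioning=uc,
    eta=0.0,
)
x_samples = model.decode_first_stage(samples)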