@bitplane
Last active September 17, 2022 23:22
Stable Diffusion keyword inspection

If we could take a low-res sample of the network's hidden layers, then with the right sampling algorithm it might be possible to compare those samples and see how they relate to each other.

If so, it might tell us all kinds of useful things about the labels used in the training data: let us visualise, group and graph them, find interesting intersections and areas of the network that seem suboptimal, write prompt preprocessors and weight translators, link bad prompts to weird combination effects, and so on.

This is the crudest example idea at the moment, but it's my first thought in this space.

fingerprint the network while it's running

Every G generations, take a fingerprint something like the below (and maybe combine these somehow):

from itertools import chain

import numpy as np

# get the features out of the latent space
feature_groups = [net.features for net in latent_nets]
features = list(chain(*feature_groups))

# split into equal-sized groups
bucket_count = 512
grouper = np.array_split  # might need to be better
groups = grouper(features, bucket_count)

# downsample to a 1k fingerprint (512 buckets x 16 bits)
aggregate = lambda group: sum(group) / len(group)
bucket = lambda group: int(aggregate(group) * 65535)
buckets = [bucket(group) for group in groups]
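
One way the "maybe combine these somehow" part could go: keep a running mean of the snapshots. The on_generation hook and take_fingerprint helper here are made up, standing in for however the sampler exposes its loop and the bucket code above:

# hypothetical: keep a running mean of fingerprints, one every G generations
G = 10
running = [0.0] * bucket_count
samples = 0

def on_generation(n):
    """Made-up hook, called once per image generation."""
    global samples
    if n % G == 0:
        snapshot = take_fingerprint()  # the bucket code above, as a function
        for i, value in enumerate(snapshot):
            running[i] += value
        samples += 1

def combined():
    # average the snapshots into a single fingerprint
    return [total / samples for total in running]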

store fingerprints for each group of words in the phrase

class Fingerprint:
    def __init__(self, ngram, data=None, count=0):
        self.ngram = ngram
        self.data = data or [0] * bucket_count
        self.count = count

    # persistence layer not designed yet
    load = save = NotImplemented
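
A minimal sketch of what load and save could look like, assuming one JSON file per n-gram (the directory and naming scheme here are made up):

import json
from pathlib import Path

STORE = Path("fingerprints")  # hypothetical storage directory

class JsonFingerprint(Fingerprint):
    @classmethod
    def load(cls, ngram):
        path = STORE / ("_".join(ngram) + ".json")
        if path.exists():
            blob = json.loads(path.read_text())
            return cls(ngram, blob["data"], blob["count"])
        return cls(ngram)  # fresh, all-zero fingerprint

    def save(self):
        STORE.mkdir(exist_ok=True)
        path = STORE / ("_".join(self.ngram) + ".json")
        path.write_text(json.dumps({"data": self.data, "count": self.count}))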

from more_itertools import sliding_window

def get_ngrams(words):
    # every contiguous run of words, from single words up to the whole phrase
    ngrams = []
    for length in range(1, len(words) + 1):
        ngrams.extend(sliding_window(words, length))

    return ngrams
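
For example, a three-word phrase yields all six contiguous word runs:

>>> get_ngrams(["red", "fox", "jumps"])
[('red',), ('fox',), ('jumps',), ('red', 'fox'), ('fox', 'jumps'), ('red', 'fox', 'jumps')]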

words = phrase.split()
ngrams = get_ngrams(words)

for ngram in ngrams:
    fingerprint = Fingerprint.load(ngram)

    # weight function might need a better choice:
    # longer n-grams are more specific, so weight by relative length
    weight = len(ngram) / len(words)

    for idx, score in enumerate(buckets):
        fingerprint.data[idx] += score * weight

    fingerprint.count += weight
    fingerprint.save()
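
When reading a fingerprint back out later, the accumulated totals would presumably be normalised by the accumulated weight, something like:

# hypothetical read-back: mean bucket values for an n-gram
fingerprint = Fingerprint.load(ngram)
mean_data = [total / fingerprint.count for total in fingerprint.data]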

then what?

  • Explore better grouper, aggregate, weight and reduce functions, possibly using ML itself.
  • Try to find stability across image sizes and permutations.
  • Get real-world phrases to test with.
  • Prevent duplicate (seed, ngram) pairs from skewing the data.
  • Filter ngrams before fingerprints run out of precision. Statistically? By max? Store values as 0..1, start at 0.5, and convert to a bitmap + mask of convergent/chaotic regions? (Sketched below.)
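
A rough sketch of that last idea, plus the comparison from the top of this note: treat buckets that settle away from 0.5 as convergent bits, mask out the chaotic ones, and only compare fingerprints where both masks agree. The threshold and function names here are made up:

# hypothetical: values are mean bucket scores in 0..1, starting at 0.5
THRESHOLD = 0.1  # made-up tolerance for "has converged"

def to_bitmap_and_mask(values):
    bitmap = [1 if v > 0.5 else 0 for v in values]
    mask = [1 if abs(v - 0.5) > THRESHOLD else 0 for v in values]
    return bitmap, mask

def similarity(a, b):
    # compare only the regions where both fingerprints converged
    bitmap_a, mask_a = to_bitmap_and_mask(a)
    bitmap_b, mask_b = to_bitmap_and_mask(b)
    shared = [i for i, (m, n) in enumerate(zip(mask_a, mask_b)) if m and n]
    if not shared:
        return 0.0
    return sum(bitmap_a[i] == bitmap_b[i] for i in shared) / len(shared)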