@chadbrewbaker · Created April 2, 2023 16:32
ParallelGPT

Asking GPT-3.5 to refactor gpt2() into a parallel prefix over the inputs, starting from the code at:

https://github.com/jaymody/picoGPT/blob/3b7f4d180bb125a76b2f4f7b3a74268e5ec5f131/gpt2_pico.py#L37
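
For reference, the gpt2() being refactored is (roughly) picoGPT's forward pass: token plus positional embeddings, a stack of transformer blocks, then a final layer norm projected back onto the token embedding matrix:

def gpt2(inputs, wte, wpe, blocks, ln_f, n_head):
    x = wte[inputs] + wpe[range(len(inputs))]  # token + positional embeddings
    for block in blocks:
        x = transformer_block(x, **block, n_head=n_head)
    return layer_norm(x, **ln_f) @ wte.T  # project back onto the vocabulary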

import numpy as np
# transformer_block() and layer_norm() are the helpers defined in picoGPT's gpt2_pico.py

def gpt2(inputs, wte, wpe, blocks, ln_f, n_head, chunk_size):
    # Split the token ids into contiguous chunks (assumes len(inputs) is divisible by chunk_size)
    n_chunks = len(inputs) // chunk_size
    chunked_inputs = np.split(np.asarray(inputs), n_chunks)

    # Token + positional embeddings, computed chunk by chunk
    embeddings = []
    for i, chunk in enumerate(chunked_inputs):
        emb_chunk = wte[chunk] + wpe[range(i * chunk_size, (i + 1) * chunk_size)]
        embeddings.append(emb_chunk)
    embeddings = np.concatenate(embeddings, axis=0)

    for block in blocks:
        # Run the transformer block on each chunk independently (the parallelizable part)
        attn_outputs = []
        for i, chunk in enumerate(np.split(embeddings, n_chunks)):
            attn_chunk = transformer_block(chunk, **block, n_head=n_head)
            attn_outputs.append(attn_chunk)
        attn_outputs = np.concatenate(attn_outputs, axis=0)

        # Inclusive prefix sum of the block outputs over chunks
        attn_outputs_prefix = np.zeros_like(attn_outputs)
        attn_outputs_prefix[:chunk_size] = attn_outputs[:chunk_size]
        for i in range(1, n_chunks):
            attn_outputs_prefix[i * chunk_size:(i + 1) * chunk_size] = \
                attn_outputs_prefix[(i - 1) * chunk_size:i * chunk_size] + \
                attn_outputs[i * chunk_size:(i + 1) * chunk_size]

        embeddings = embeddings + attn_outputs_prefix

    # Final layer norm, then project back onto the token embedding matrix
    final_output = layer_norm(embeddings, **ln_f) @ wte.T
    return final_output
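
The per-block accumulation above is written as a serial loop over chunks, but it is essentially an inclusive prefix sum, so it can be expressed as a scan. A minimal sketch with np.cumsum (the helper name chunk_prefix_sum is hypothetical; it assumes len(inputs) is divisible by chunk_size):

def chunk_prefix_sum(attn_outputs, n_chunks, chunk_size):
    # (n_chunks * chunk_size, d) -> (n_chunks, chunk_size, d)
    chunks = attn_outputs.reshape(n_chunks, chunk_size, -1)
    # Inclusive scan over the chunk axis: prefix[i] = chunks[0] + ... + chunks[i]
    prefix = np.cumsum(chunks, axis=0)
    return prefix.reshape(n_chunks * chunk_size, -1)

np.cumsum still runs sequentially on one core, but the same scan structure is what a genuinely parallel prefix (e.g. a Blelloch scan across workers) would exploit.
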
@chadbrewbaker (Author):

A GPT-3.5 suggestion to permute the token space for a given list of prompts to speed up inference:

from collections import Counter
import itertools

def permute_tokens(prompts, gpt_model, block_size):
    # Concatenate the prompts into a single sequence
    combined_sequence = ''.join(prompts)

    # Count how often each token occurs
    token_frequencies = Counter(combined_sequence)

    # Sort the tokens by frequency of occurrence (most frequent first)
    sorted_tokens = sorted(token_frequencies.keys(), key=lambda x: token_frequencies[x], reverse=True)

    # Divide the sorted tokens into contiguous segments of block_size
    token_segments = [sorted_tokens[i:i + block_size] for i in range(0, len(sorted_tokens), block_size)]

    # Cache-aware permutation within each segment: order tokens by first occurrence
    for segment in token_segments:
        segment.sort(key=lambda x: combined_sequence.index(x))

    # Permute the segments themselves (permute_segments is left undefined in the suggestion)
    permuted_segments = permute_segments(token_segments)

    # Map the permuted tokens back to token ids
    token_ids = [gpt_model.vocab[token] for token in itertools.chain(*permuted_segments)]

    return token_ids
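
The suggestion leaves permute_segments undefined. A purely hypothetical placeholder, assuming the intent is just to reorder the frequency-sorted blocks while keeping related tokens grouped (here: leave them in frequency order), might look like:

def permute_segments(token_segments):
    # Hypothetical stand-in, not part of the original suggestion:
    # keep the segments in frequency order, so the most common tokens
    # stay in the earliest (hottest) blocks.
    return list(token_segments)

Any ordering that keeps frequently co-occurring tokens inside the same block would serve the stated cache-locality goal.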
