Activation Sparsity in LLMs
# !pip install transformers sentencepiece
import torch
import torch.nn as nn
torch.set_grad_enabled(False)
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.activations import ReLUSquaredActivation
from collections import defaultdict, OrderedDict
import numpy as np
# https://www.adept.ai/blog/persimmon-8b (Footnote 2):
# In contrast to the more standard SwiGLU and GeLU activations,
# the squared ReLU often results in output activations consisting of 90+% zeros.
# This provides interesting opportunities for inference
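# For reference, a minimal sketch of what ReLUSquaredActivation computes (not part
# of the measurement below): relu(x) ** 2, which zeroes every negative pre-activation,
# so even a zero-mean random input comes out roughly half zeros before any training.
_example_input = torch.randn(4)
print(f"ReLU^2 of {_example_input} -> {ReLUSquaredActivation()(_example_input)}")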
model_version = "adept/persimmon-8b-chat"
torch_device = "mps"  # Apple Silicon GPU; switch to "cuda" or "cpu" as needed
# Create model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    model_version,
    torch_dtype=torch.float16
).to(torch_device).eval()
tokenizer = AutoTokenizer.from_pretrained(model_version, use_fast=False)
# Test case to verify accurate model inference
MAX_NEW_TOKENS = 8
PROMPT = "human: Simply put, the theory of relativity states that?\n\nadept:"
EXPECTED_GREEDY_COMPLETION = \
    "|ENDOFTEXT|human: Simply put, the theory of relativity states that?\n\nadept: " \
    "The theory of relativity states that the laws of physics are the same"
# Register activation sparsity loggers for ReLUSquaredActivation modules to verify the 90%+ claim
# in the Persimmon release notes. Also register them for nn.Linear modules to measure the residual
# sparsity (if any) that trickles down from the nearest preceding ReLUSquaredActivation
activation_sparsity_stats = defaultdict(list)
module_types_to_hook = (ReLUSquaredActivation, nn.Linear)
hash_to_name = {
    hash(module): name for name, module in model.named_modules()
    if isinstance(module, module_types_to_hook)
}
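# hash(module) falls back to the default id-based object hash for nn.Module, so the
# mapping above keys each hooked module instance to its qualified name for reporting.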
def activation_sparsity_logger_hook(module, input, output):
    if isinstance(module, module_types_to_hook):
        activation_sparsity_stats[hash_to_name[hash(module)]].append({
            'input': input[0].eq(0.),
            'output': output.eq(0.),
        })
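# Each forward pass appends boolean masks marking the exactly-zero entries of the
# module's (first) input tensor and its output tensor; the mean of a mask is the sparsity.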
def register_hook(module):
    if isinstance(module, module_types_to_hook):
        # Clear any previously registered forward hooks before attaching the logger
        module._forward_hooks = OrderedDict()
        module.register_forward_hook(activation_sparsity_logger_hook)

_ = model.apply(register_hook)
inputs = tokenizer(PROMPT, return_tensors="pt")["input_ids"].to(torch_device)
outputs = model.eval().generate(
    inputs,
    max_new_tokens=MAX_NEW_TOKENS,
    return_dict_in_generate=True,
    output_scores=True,
    top_p=1.0,
    top_k=0,
    temperature=1.0,
    do_sample=False,  # greedy decoding; the sampling knobs above are effectively no-ops
)
generated_text = tokenizer.decode(outputs.sequences[0])
print(f"Generated text: {generated_text}")
assert generated_text == EXPECTED_GREEDY_COMPLETION
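# With the default KV cache, generate() runs one forward pass over the prompt and then
# one pass per additional token, so (assuming no early stop) each hooked module should
# have logged MAX_NEW_TOKENS entries, indexed 0..MAX_NEW_TOKENS-1 below.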
for module_name, stats in activation_sparsity_stats.items():
    input_sparsity_vectors = [stats[i]["input"][0, 0].cpu().float() for i in range(MAX_NEW_TOKENS)]
    output_sparsity_vectors = [stats[i]["output"][0, 0].cpu().float() for i in range(MAX_NEW_TOKENS)]
    input_sparsity = sum(isv.mean().item() for isv in input_sparsity_vectors) / MAX_NEW_TOKENS
    output_sparsity = sum(osv.mean().item() for osv in output_sparsity_vectors) / MAX_NEW_TOKENS
    # Print input and output activation sparsities
    print(
        f"`{module_name}`: \t\t\t input sparsity={input_sparsity:.2f} \t\t| "
        f"output sparsity={output_sparsity:.2f}"
    )
    if input_sparsity > 0.5:
        # Print the fraction of zero positions shared across tokens (how repeatable are these sparsity patterns?)
        common_sparsity_mask = torch.stack(input_sparsity_vectors, dim=0).all(dim=0)
        print(f"Shared sparsity across tokens: {common_sparsity_mask.float().mean():.2f}\n")
Prints: per-module input and output activation sparsity, plus the fraction of zero positions shared across generated tokens for modules with high input sparsity.