@atiorh
Last active November 21, 2023 18:52
Activation Sparsity in LLMs
# !pip install transformers sentencepiece
import torch
import torch.nn as nn
torch.set_grad_enabled(False)
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.activations import ReLUSquaredActivation
from collections import defaultdict, OrderedDict
import numpy as np
# https://www.adept.ai/blog/persimmon-8b: (Footnote 2)
# In contrast to the more standard SwiGLU and GeLU activations,
# the squared ReLU often results in output activations consisting of 90+% zeros.
# This provides interesting opportunities for inference
model_version = "adept/persimmon-8b-chat"
torch_device = "mps"
# Create model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    model_version,
    torch_dtype=torch.float16,
).to(torch_device).eval()
tokenizer = AutoTokenizer.from_pretrained(model_version, use_fast=False)
# Test case to verify accurate model inference
MAX_NEW_TOKENS = 8
PROMPT = "human: Simply put, the theory of relativity states that?\n\nadept:"
EXPECTED_GREEDY_COMPLETION = (
    "|ENDOFTEXT|human: Simply put, the theory of relativity states that?\n\nadept: "
    "The theory of relativity states that the laws of physics are the same"
)
# Register activation sparsity loggers on ReLUSquaredActivation modules to verify the 90%+ claim
# in the Persimmon release notes. Also register them on nn.Linear modules to measure the residual
# sparsity (if any) that trickles down from the nearest preceding ReLUSquaredActivation
activation_sparsity_stats = defaultdict(list)
module_types_to_hook = (ReLUSquaredActivation, nn.Linear)
# Map each hooked module to its qualified name so the hook can label its statistics
hash_to_name = {
    hash(module): name for name, module in model.named_modules()
    if isinstance(module, module_types_to_hook)
}
def activation_sparsity_logger_hook(module, input, output):
    # Record boolean masks marking which input and output activations are exactly zero
    if isinstance(module, module_types_to_hook):
        activation_sparsity_stats[hash_to_name[hash(module)]].append({
            'input': input[0].eq(0),
            'output': output.eq(0.),
        })
def register_hook(module):
    if isinstance(module, module_types_to_hook):
        # Clear any previously registered forward hooks so a re-run does not log twice
        module._forward_hooks = OrderedDict()
        module.register_forward_hook(activation_sparsity_logger_hook)
_ = model.apply(register_hook)
inputs = tokenizer(PROMPT, return_tensors="pt")["input_ids"].to(torch_device)
outputs = model.eval().generate(
    inputs,
    max_new_tokens=MAX_NEW_TOKENS,
    return_dict_in_generate=True,
    output_scores=True,
    top_p=1.0,
    top_k=0,
    temperature=1.0,
    do_sample=False,  # greedy decoding; the sampling parameters above are not used
)
generated_text = tokenizer.decode(outputs.sequences[0])
print(f"Generated text: {generated_text}")
assert generated_text == EXPECTED_GREEDY_COMPLETION
for module_name, stats in activation_sparsity_stats.items():
    # One entry per forward pass; [0, 0] selects batch element 0 and the first position of that pass
    input_sparsity_vectors = [stats[i]["input"][0, 0].cpu().float() for i in range(MAX_NEW_TOKENS)]
    output_sparsity_vectors = [stats[i]["output"][0, 0].cpu().float() for i in range(MAX_NEW_TOKENS)]
    # Average the fraction of zeros over all forward passes
    input_sparsity = sum(isv.mean().item() for isv in input_sparsity_vectors) / MAX_NEW_TOKENS
    output_sparsity = sum(osv.mean().item() for osv in output_sparsity_vectors) / MAX_NEW_TOKENS
    # Print input and output activation sparsities
    print(
        f"`{module_name}`: \t\t\t input sparsity= {input_sparsity:.2f} \t\t| "
        f"output sparsity={output_sparsity:.2f}"
    )
    if input_sparsity > 0.5:
        # Print shared sparsity % across tokens (how repeatable are these sparsity patterns?)
        common_sparsity_percent = torch.stack(input_sparsity_vectors, dim=0).all(dim=0)
        print(f"Shared sparsity across tokens: {common_sparsity_percent.float().mean():.2f}\n")
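The two statistics printed by the loop above can be checked by hand on toy data. The following is a minimal sketch, assuming plain Python lists stand in for the boolean zero-masks so that it runs without PyTorch or the 8B model. The helper names `sparsity` and `shared_sparsity` and the toy masks are hypothetical and do not appear in the script.

```python
def sparsity(mask):
    """Fraction of True entries in a boolean zero-mask."""
    return sum(mask) / len(mask)


def shared_sparsity(masks):
    """Fraction of positions that are zero in every mask (elementwise AND, then mean)."""
    shared = [all(column) for column in zip(*masks)]
    return sum(shared) / len(shared)


# Hypothetical zero-masks for three decoding steps over eight hidden units
# (True means the activation was exactly zero at that step).
toy_masks = [
    [True, True, True, False, True, True, False, True],
    [True, True, False, True, True, True, False, True],
    [True, True, True, True, False, True, False, True],
]

per_step = [sparsity(m) for m in toy_masks]
mean_sparsity = sum(per_step) / len(per_step)  # analogue of "input sparsity"
shared = shared_sparsity(toy_masks)            # analogue of "Shared sparsity across tokens"
print(f"mean sparsity={mean_sparsity:.2f}, shared sparsity={shared:.2f}")
```

Because the shared mask is an intersection, the shared figure can never exceed the smallest per-step sparsity. In this toy case each step is 75% zeros, but only 50% of the units are zero at every step.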
atiorh commented Nov 9, 2023

Prints:

Activation Sparsity Statistics:
`model.layers.0.self_attn.query_key_value`: 	 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.0.self_attn.dense`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.0.mlp.dense_h_to_4h`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.0.mlp.act`: 			 input sparsity= 0.00 	 | 	 output sparsity=0.97
`model.layers.0.mlp.dense_4h_to_h`: 		 input sparsity= 0.97 	 | 	 output sparsity=0.00
Shared sparsity across tokens: 0.82

`model.layers.1.self_attn.query_key_value`: 	 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.1.self_attn.dense`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.1.mlp.dense_h_to_4h`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.1.mlp.act`: 			 input sparsity= 0.00 	 | 	 output sparsity=0.96
`model.layers.1.mlp.dense_4h_to_h`: 		 input sparsity= 0.96 	 | 	 output sparsity=0.00
Shared sparsity across tokens: 0.74

`model.layers.2.self_attn.query_key_value`: 	 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.2.self_attn.dense`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.2.mlp.dense_h_to_4h`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.2.mlp.act`: 			 input sparsity= 0.00 	 | 	 output sparsity=0.95
`model.layers.2.mlp.dense_4h_to_h`: 		 input sparsity= 0.95 	 | 	 output sparsity=0.00
Shared sparsity across tokens: 0.74

`model.layers.3.self_attn.query_key_value`: 	 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.3.self_attn.dense`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.3.mlp.dense_h_to_4h`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.3.mlp.act`: 			 input sparsity= 0.00 	 | 	 output sparsity=0.95
`model.layers.3.mlp.dense_4h_to_h`: 		 input sparsity= 0.95 	 | 	 output sparsity=0.00
Shared sparsity across tokens: 0.77

`model.layers.4.self_attn.query_key_value`: 	 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.4.self_attn.dense`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.4.mlp.dense_h_to_4h`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.4.mlp.act`: 			 input sparsity= 0.00 	 | 	 output sparsity=0.96
`model.layers.4.mlp.dense_4h_to_h`: 		 input sparsity= 0.96 	 | 	 output sparsity=0.00
Shared sparsity across tokens: 0.81

`model.layers.5.self_attn.query_key_value`: 	 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.5.self_attn.dense`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.5.mlp.dense_h_to_4h`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.5.mlp.act`: 			 input sparsity= 0.00 	 | 	 output sparsity=0.96
`model.layers.5.mlp.dense_4h_to_h`: 		 input sparsity= 0.96 	 | 	 output sparsity=0.00
Shared sparsity across tokens: 0.78

`model.layers.6.self_attn.query_key_value`: 	 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.6.self_attn.dense`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.6.mlp.dense_h_to_4h`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.6.mlp.act`: 			 input sparsity= 0.00 	 | 	 output sparsity=0.96
`model.layers.6.mlp.dense_4h_to_h`: 		 input sparsity= 0.96 	 | 	 output sparsity=0.00
Shared sparsity across tokens: 0.79

`model.layers.7.self_attn.query_key_value`: 	 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.7.self_attn.dense`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.7.mlp.dense_h_to_4h`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.7.mlp.act`: 			 input sparsity= 0.00 	 | 	 output sparsity=0.96
`model.layers.7.mlp.dense_4h_to_h`: 		 input sparsity= 0.96 	 | 	 output sparsity=0.00
Shared sparsity across tokens: 0.80

`model.layers.8.self_attn.query_key_value`: 	 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.8.self_attn.dense`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.8.mlp.dense_h_to_4h`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.8.mlp.act`: 			 input sparsity= 0.00 	 | 	 output sparsity=0.95
`model.layers.8.mlp.dense_4h_to_h`: 		 input sparsity= 0.95 	 | 	 output sparsity=0.00
Shared sparsity across tokens: 0.76

`model.layers.9.self_attn.query_key_value`: 	 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.9.self_attn.dense`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.9.mlp.dense_h_to_4h`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.9.mlp.act`: 			 input sparsity= 0.00 	 | 	 output sparsity=0.94
`model.layers.9.mlp.dense_4h_to_h`: 		 input sparsity= 0.94 	 | 	 output sparsity=0.00
Shared sparsity across tokens: 0.74

`model.layers.10.self_attn.query_key_value`: 	 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.10.self_attn.dense`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.10.mlp.dense_h_to_4h`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.10.mlp.act`: 			 input sparsity= 0.00 	 | 	 output sparsity=0.93
`model.layers.10.mlp.dense_4h_to_h`: 		 input sparsity= 0.93 	 | 	 output sparsity=0.00
Shared sparsity across tokens: 0.71

`model.layers.11.self_attn.query_key_value`: 	 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.11.self_attn.dense`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.11.mlp.dense_h_to_4h`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.11.mlp.act`: 			 input sparsity= 0.00 	 | 	 output sparsity=0.90
`model.layers.11.mlp.dense_4h_to_h`: 		 input sparsity= 0.90 	 | 	 output sparsity=0.00
Shared sparsity across tokens: 0.64

`model.layers.12.self_attn.query_key_value`: 	 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.12.self_attn.dense`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.12.mlp.dense_h_to_4h`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.12.mlp.act`: 			 input sparsity= 0.00 	 | 	 output sparsity=0.89
`model.layers.12.mlp.dense_4h_to_h`: 		 input sparsity= 0.89 	 | 	 output sparsity=0.00
Shared sparsity across tokens: 0.63

`model.layers.13.self_attn.query_key_value`: 	 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.13.self_attn.dense`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.13.mlp.dense_h_to_4h`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.13.mlp.act`: 			 input sparsity= 0.00 	 | 	 output sparsity=0.89
`model.layers.13.mlp.dense_4h_to_h`: 		 input sparsity= 0.89 	 | 	 output sparsity=0.00
Shared sparsity across tokens: 0.61

`model.layers.14.self_attn.query_key_value`: 	 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.14.self_attn.dense`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.14.mlp.dense_h_to_4h`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.14.mlp.act`: 			 input sparsity= 0.00 	 | 	 output sparsity=0.87
`model.layers.14.mlp.dense_4h_to_h`: 		 input sparsity= 0.87 	 | 	 output sparsity=0.00
Shared sparsity across tokens: 0.57

`model.layers.15.self_attn.query_key_value`: 	 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.15.self_attn.dense`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.15.mlp.dense_h_to_4h`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.15.mlp.act`: 			 input sparsity= 0.00 	 | 	 output sparsity=0.88
`model.layers.15.mlp.dense_4h_to_h`: 		 input sparsity= 0.88 	 | 	 output sparsity=0.00
Shared sparsity across tokens: 0.61

`model.layers.16.self_attn.query_key_value`: 	 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.16.self_attn.dense`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.16.mlp.dense_h_to_4h`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.16.mlp.act`: 			 input sparsity= 0.00 	 | 	 output sparsity=0.88
`model.layers.16.mlp.dense_4h_to_h`: 		 input sparsity= 0.88 	 | 	 output sparsity=0.00
Shared sparsity across tokens: 0.62

`model.layers.17.self_attn.query_key_value`: 	 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.17.self_attn.dense`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.17.mlp.dense_h_to_4h`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.17.mlp.act`: 			 input sparsity= 0.00 	 | 	 output sparsity=0.88
`model.layers.17.mlp.dense_4h_to_h`: 		 input sparsity= 0.88 	 | 	 output sparsity=0.00
Shared sparsity across tokens: 0.63

`model.layers.18.self_attn.query_key_value`: 	 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.18.self_attn.dense`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.18.mlp.dense_h_to_4h`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.18.mlp.act`: 			 input sparsity= 0.00 	 | 	 output sparsity=0.89
`model.layers.18.mlp.dense_4h_to_h`: 		 input sparsity= 0.89 	 | 	 output sparsity=0.00
Shared sparsity across tokens: 0.66

`model.layers.19.self_attn.query_key_value`: 	 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.19.self_attn.dense`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.19.mlp.dense_h_to_4h`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.19.mlp.act`: 			 input sparsity= 0.00 	 | 	 output sparsity=0.90
`model.layers.19.mlp.dense_4h_to_h`: 		 input sparsity= 0.90 	 | 	 output sparsity=0.00
Shared sparsity across tokens: 0.68

`model.layers.20.self_attn.query_key_value`: 	 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.20.self_attn.dense`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.20.mlp.dense_h_to_4h`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.20.mlp.act`: 			 input sparsity= 0.00 	 | 	 output sparsity=0.91
`model.layers.20.mlp.dense_4h_to_h`: 		 input sparsity= 0.91 	 | 	 output sparsity=0.00
Shared sparsity across tokens: 0.69

`model.layers.21.self_attn.query_key_value`: 	 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.21.self_attn.dense`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.21.mlp.dense_h_to_4h`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.21.mlp.act`: 			 input sparsity= 0.00 	 | 	 output sparsity=0.92
`model.layers.21.mlp.dense_4h_to_h`: 		 input sparsity= 0.92 	 | 	 output sparsity=0.00
Shared sparsity across tokens: 0.73

`model.layers.22.self_attn.query_key_value`: 	 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.22.self_attn.dense`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.22.mlp.dense_h_to_4h`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.22.mlp.act`: 			 input sparsity= 0.00 	 | 	 output sparsity=0.92
`model.layers.22.mlp.dense_4h_to_h`: 		 input sparsity= 0.92 	 | 	 output sparsity=0.00
Shared sparsity across tokens: 0.72

`model.layers.23.self_attn.query_key_value`: 	 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.23.self_attn.dense`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.23.mlp.dense_h_to_4h`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.23.mlp.act`: 			 input sparsity= 0.00 	 | 	 output sparsity=0.92
`model.layers.23.mlp.dense_4h_to_h`: 		 input sparsity= 0.92 	 | 	 output sparsity=0.00
Shared sparsity across tokens: 0.74

`model.layers.24.self_attn.query_key_value`: 	 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.24.self_attn.dense`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.24.mlp.dense_h_to_4h`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.24.mlp.act`: 			 input sparsity= 0.00 	 | 	 output sparsity=0.93
`model.layers.24.mlp.dense_4h_to_h`: 		 input sparsity= 0.93 	 | 	 output sparsity=0.00
Shared sparsity across tokens: 0.76

`model.layers.25.self_attn.query_key_value`: 	 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.25.self_attn.dense`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.25.mlp.dense_h_to_4h`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.25.mlp.act`: 			 input sparsity= 0.00 	 | 	 output sparsity=0.94
`model.layers.25.mlp.dense_4h_to_h`: 		 input sparsity= 0.94 	 | 	 output sparsity=0.00
Shared sparsity across tokens: 0.77

`model.layers.26.self_attn.query_key_value`: 	 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.26.self_attn.dense`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.26.mlp.dense_h_to_4h`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.26.mlp.act`: 			 input sparsity= 0.00 	 | 	 output sparsity=0.94
`model.layers.26.mlp.dense_4h_to_h`: 		 input sparsity= 0.94 	 | 	 output sparsity=0.00
Shared sparsity across tokens: 0.77

`model.layers.27.self_attn.query_key_value`: 	 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.27.self_attn.dense`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.27.mlp.dense_h_to_4h`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.27.mlp.act`: 			 input sparsity= 0.00 	 | 	 output sparsity=0.93
`model.layers.27.mlp.dense_4h_to_h`: 		 input sparsity= 0.93 	 | 	 output sparsity=0.00
Shared sparsity across tokens: 0.75

`model.layers.28.self_attn.query_key_value`: 	 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.28.self_attn.dense`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.28.mlp.dense_h_to_4h`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.28.mlp.act`: 			 input sparsity= 0.00 	 | 	 output sparsity=0.94
`model.layers.28.mlp.dense_4h_to_h`: 		 input sparsity= 0.94 	 | 	 output sparsity=0.00
Shared sparsity across tokens: 0.78

`model.layers.29.self_attn.query_key_value`: 	 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.29.self_attn.dense`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.29.mlp.dense_h_to_4h`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.29.mlp.act`: 			 input sparsity= 0.00 	 | 	 output sparsity=0.94
`model.layers.29.mlp.dense_4h_to_h`: 		 input sparsity= 0.94 	 | 	 output sparsity=0.00
Shared sparsity across tokens: 0.76

`model.layers.30.self_attn.query_key_value`: 	 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.30.self_attn.dense`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.30.mlp.dense_h_to_4h`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.30.mlp.act`: 			 input sparsity= 0.00 	 | 	 output sparsity=0.93
`model.layers.30.mlp.dense_4h_to_h`: 		 input sparsity= 0.93 	 | 	 output sparsity=0.00
Shared sparsity across tokens: 0.73

`model.layers.31.self_attn.query_key_value`: 	 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.31.self_attn.dense`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.31.mlp.dense_h_to_4h`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.31.mlp.act`: 			 input sparsity= 0.00 	 | 	 output sparsity=0.92
`model.layers.31.mlp.dense_4h_to_h`: 		 input sparsity= 0.92 	 | 	 output sparsity=0.00
Shared sparsity across tokens: 0.70

`model.layers.32.self_attn.query_key_value`: 	 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.32.self_attn.dense`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.32.mlp.dense_h_to_4h`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.32.mlp.act`: 			 input sparsity= 0.00 	 | 	 output sparsity=0.91
`model.layers.32.mlp.dense_4h_to_h`: 		 input sparsity= 0.91 	 | 	 output sparsity=0.00
Shared sparsity across tokens: 0.67

`model.layers.33.self_attn.query_key_value`: 	 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.33.self_attn.dense`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.33.mlp.dense_h_to_4h`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.33.mlp.act`: 			 input sparsity= 0.00 	 | 	 output sparsity=0.87
`model.layers.33.mlp.dense_4h_to_h`: 		 input sparsity= 0.87 	 | 	 output sparsity=0.00
Shared sparsity across tokens: 0.56

`model.layers.34.self_attn.query_key_value`: 	 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.34.self_attn.dense`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.34.mlp.dense_h_to_4h`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.34.mlp.act`: 			 input sparsity= 0.00 	 | 	 output sparsity=0.82
`model.layers.34.mlp.dense_4h_to_h`: 		 input sparsity= 0.82 	 | 	 output sparsity=0.00
Shared sparsity across tokens: 0.47

`model.layers.35.self_attn.query_key_value`: 	 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.35.self_attn.dense`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.35.mlp.dense_h_to_4h`: 		 input sparsity= 0.00 	 | 	 output sparsity=0.00
`model.layers.35.mlp.act`: 			 input sparsity= 0.00 	 | 	 output sparsity=0.87
`model.layers.35.mlp.dense_4h_to_h`: 		 input sparsity= 0.87 	 | 	 output sparsity=0.00
Shared sparsity across tokens: 0.57

`lm_head`: 					 input sparsity= 0.00 	 | 	 output sparsity=0.00
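The Persimmon footnote quoted in the script mentions "interesting opportunities for inference". One such opportunity, suggested by the 0.82 to 0.97 input sparsity measured at `dense_4h_to_h`, is to skip the weight columns that would be multiplied by zero. The following is a minimal sketch of that idea in plain Python, assuming a toy weight matrix and input. The function names are hypothetical, and this is not how Persimmon or `transformers` implements the layer. Whether it yields a real speedup depends on the kernel and hardware.

```python
def dense_matvec(weight, x):
    """Reference product weight @ x that touches every column."""
    return [sum(w * xj for w, xj in zip(row, x)) for row in weight]


def sparse_input_matvec(weight, x):
    """Compute weight @ x using only the columns where x is nonzero."""
    active = [j for j, xj in enumerate(x) if xj != 0]
    y = [sum(row[j] * x[j] for j in active) for row in weight]
    return y, len(active)


# Toy stand-ins: a 2x8 weight matrix and an input with 6 of 8 entries equal to zero
weight = [
    [1, 2, 3, 4, 5, 6, 7, 8],
    [8, 7, 6, 5, 4, 3, 2, 1],
]
x = [0, 0, 0, 2.0, 0, 0, 0.5, 0]

y_dense = dense_matvec(weight, x)
y_sparse, n_active = sparse_input_matvec(weight, x)
skipped_fraction = 1 - n_active / len(x)
print(y_dense, y_sparse, skipped_fraction)
```

Both paths give the same result, while the sparse path reads only the active columns. Here it skips 75% of them, and the skipped fraction always equals the input sparsity.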
