@muellerzr
Created September 15, 2023 18:29
Model memory stuff
import math

import torch
from transformers import AutoConfig, AutoModelForSequenceClassification


def get_model_memory(model: torch.nn.Module):
    """
    Returns the memory usage (in bytes) of the given model's parameters
    """
    total_memory = 0
    for param in model.parameters():
        total_memory += param.numel() * param.element_size()
    return total_memory
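
# Back-of-envelope check (illustrative, not part of the original gist): a
# bert-base-sized model with a classification head has roughly 110M fp32
# parameters, so get_model_memory should report about 110e6 * 4 ≈ 440 MB.
#     print(format_size(get_model_memory(model)))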


class ActivationCounter:
    """Helper class to count the number of activation bytes produced by a model."""

    def __init__(self):
        self.activation_bytes = 0

    def add_activations(self, tensor):
        self.activation_bytes += tensor.numel() * tensor.element_size()

    def add_activation_bytes(self, bytes):
        self.activation_bytes += bytes


def activation_counter_hook(counter: ActivationCounter):
    """Returns a forward hook that accumulates the byte size of a module's outputs."""
    def hook(self, _, output):
        if self.__class__.__name__ == "Dropout":
            # for dropout layers, we only need to store the mask:
            # one byte per element rather than numel() * element_size()
            counter.add_activation_bytes(output.data.numel())
        else:
            if isinstance(output, tuple):
                for o in output:
                    if isinstance(o, torch.Tensor):
                        counter.add_activations(o.data)
                    elif isinstance(o, tuple):
                        for o2 in o:
                            counter.add_activations(o2.data)
            elif isinstance(output, torch.Tensor):
                counter.add_activations(output.data)
    return hook


def register_hooks_recursive(model: torch.nn.Module, counter: ActivationCounter):
    """Recursively injects activation counting hooks into the given model."""
    for module in model.children():
        module.register_forward_hook(activation_counter_hook(counter))
        register_hooks_recursive(module, counter)
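
# Usage sketch for the hook machinery above (illustrative, assuming a model and
# input_ids like the ones built in __main__ below): every submodule gets a
# forward hook, so a single forward pass accumulates the byte size of each
# intermediate output into the counter.
#     counter = ActivationCounter()
#     register_hooks_recursive(model, counter)
#     _ = model(input_ids)
#     print(format_size(counter.activation_bytes))
# Hooks only see tensors a module returns, not temporaries created and freed
# inside its forward, so this is an estimate rather than an exact measurement.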


def format_size(size_bytes):
    """
    Converts the given size in bytes to a human readable format
    Reference: https://stackoverflow.com/questions/5194057/better-way-to-convert-file-sizes-in-python
    """
    if size_bytes == 0:
        return "0B"
    size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")
    i = int(math.floor(math.log(size_bytes, 1024)))
    p = math.pow(1024, i)
    s = round(size_bytes / p, 2)
    return "%s %s" % (s, size_name[i])


def get_current_memory_allocation():
    """Returns the number of bytes currently occupied by tensors on the active CUDA device."""
    return torch.cuda.memory_allocated()


def get_optimizer_memory(model: torch.nn.Module, optimizer: torch.optim.Optimizer):
    """
    Returns the memory usage (in bytes) of the given optimizer's state for the given model.
    Note: Currently only supports SGD, Adam, and AdamW.
    """
    model_parameters = sum(param.numel() for param in model.parameters())
    bytes_per_param = 0
    if type(optimizer) == torch.optim.SGD:
        has_momentum = any(param_group.get('momentum', 0) != 0
                           for param_group in optimizer.param_groups)
        if has_momentum:
            bytes_per_param = 4
    elif type(optimizer) in (torch.optim.Adam, torch.optim.AdamW):
        bytes_per_param = 8
    else:
        raise ValueError(f"Unsupported optimizer: {optimizer}")
    return model_parameters * bytes_per_param
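
# Where the per-parameter constants come from: Adam/AdamW keep two fp32 state
# tensors per parameter (exp_avg and exp_avg_sq), i.e. 2 * 4 = 8 bytes, while
# SGD with momentum keeps a single momentum buffer (4 bytes) and plain SGD
# keeps no per-parameter state. Rough example (parameter count approximate):
# a ~110M-parameter model trained with Adam needs about 110e6 * 8 ≈ 880 MB of
# optimizer state.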


def project_transformer_memory(
        layers, hidden_size, num_attention_heads,
        batch_size, sequence_length, optimizer):
    """Projects total training memory (weights + gradients + activations + optimizer state) in bytes."""
    model_memory = 4 * layers * hidden_size * (13 + 12 * hidden_size)
    gradient_memory = model_memory
    # activation memory formula from: https://arxiv.org/pdf/2205.05198.pdf
    activation_memory = layers * batch_size * sequence_length * hidden_size * (
        67 + (9 * num_attention_heads * sequence_length) / hidden_size
    )
    # NOTE: relies on the module-level `model` created in __main__ below
    optimizer_memory = get_optimizer_memory(model, optimizer)
    return model_memory + gradient_memory + activation_memory + optimizer_memory
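
# Back-of-envelope projection for a bert-base-sized config (illustrative
# arithmetic, not output from this script), with layers=12, hidden_size=768,
# num_attention_heads=12, sequence_length=512, batch_size=1:
#     weights     = 4 * 12 * 768 * (13 + 12 * 768)                 ≈ 340e6 bytes (~324 MB)
#     gradients   = weights                                        ≈ 340e6 bytes (~324 MB)
#     activations = 12 * 1 * 512 * 768 * (67 + 9 * 12 * 512 / 768) ≈ 656e6 bytes (~626 MB)
# plus Adam optimizer state at 8 bytes per parameter.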
if __name__ == "__main__":
batch_size = 1
model_name = "bert-base-cased"
# model_name = "bert-base-uncased"
# model_name = "albert-base-v2"
# model_name = "distilbert-base-uncased"
# model_name = "gpt2"
# model_name = "roberta-base"
config = AutoConfig.from_pretrained(model_name)
config.return_dict = True
model = AutoModelForSequenceClassification.from_config(config)
print(f'Model used: {model_name}')
# bert-base-cased should have 12
projected_total_memory = format_size(project_transformer_memory(
config.num_hidden_groups if hasattr(config, "num_hidden_groups") else config.num_hidden_layers,
config.hidden_size,
config.num_attention_heads, batch_size,
config.max_position_embeddings, torch.optim.Adam(model.parameters()))
)
print(f"Projected total memory usage: {projected_total_memory}")
print("-" * 80)

    ############################################################
    ## Measure model memory
    ############################################################
    device = "cuda"
    model.to(device)
    memory_allocation_with_model = get_current_memory_allocation()
    estimated_model_memory = get_model_memory(model)
    print(f"Measured Model Memory: {format_size(memory_allocation_with_model)}")
    print(f"Estimated Model Memory: {format_size(estimated_model_memory)}")
    print(f"Percent difference: {abs(memory_allocation_with_model - estimated_model_memory) / estimated_model_memory * 100:.2f}%")
    print("-" * 80)

    ############################################################
    ## Measure activation memory
    ############################################################
    activation_counter = ActivationCounter()
    register_hooks_recursive(model, activation_counter)
    batch = {
        "labels": torch.tensor([0]).to(device),
        "input_ids": torch.randint(0, model.config.max_position_embeddings - 1, (batch_size, 64)).to(device),
    }
    outputs = model(batch["input_ids"])
    activation_counter.add_activations(batch["input_ids"])
    memory_allocation_forward_pass = get_current_memory_allocation() - memory_allocation_with_model
    print(f"Consumed Activation Memory: {format_size(memory_allocation_forward_pass)}")
    print(f"Estimated Activation Memory: {format_size(activation_counter.activation_bytes)}")
    print(f"Percent difference: {abs(memory_allocation_forward_pass - activation_counter.activation_bytes) / activation_counter.activation_bytes * 100:.2f}%")
    print("-" * 80)

    ############################################################
    ## Measure gradient memory
    ############################################################
    loss_fn = torch.nn.MSELoss()
    labels = torch.randn_like(outputs.logits).to(device)
    labels_size = labels.numel() * labels.element_size()
    outputs_size = outputs.logits.numel() * outputs.logits.element_size()
    loss = loss_fn(outputs.logits, labels)
    loss.backward(retain_graph=True)
    memory_allocation_with_gradients = get_current_memory_allocation() - memory_allocation_with_model - memory_allocation_forward_pass
    # gradients have the same shape and dtype as the parameters they correspond to
    estimated_gradient_memory = estimated_model_memory
    print(f"Consumed Gradient Memory: {format_size(memory_allocation_with_gradients)}")
    print(f"Estimated Gradient Memory: {format_size(estimated_gradient_memory)}")
    print(f"Percent difference: {abs(memory_allocation_with_gradients - estimated_gradient_memory) / estimated_gradient_memory * 100:.2f}%")
    print("-" * 80)

    ############################################################
    ## Measure optimizer memory
    ############################################################
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    optimizer.step()
    post_step_memory = get_current_memory_allocation() - memory_allocation_with_model - memory_allocation_forward_pass - memory_allocation_with_gradients
    estimated_optimizer_memory = get_optimizer_memory(model, optimizer)
    print(f"Consumed Optimizer + Gradient Memory: {format_size(post_step_memory)}")
    print(f"Estimated Optimizer + Gradient Memory: {format_size(estimated_optimizer_memory)}")
    print(f"Percent difference: {abs(post_step_memory - estimated_optimizer_memory) / estimated_optimizer_memory * 100:.2f}%")
    print("-" * 80)
    print(f"Actual total memory usage: {format_size(get_current_memory_allocation())}")