long mistral 2-2b
import json
import argparse
from transformers import MistralConfig, AutoModelForCausalLM
import torch
import sys
import os
def calculate_model_parameters(config):
    # Extract the necessary values from the configuration dictionary
    vocab_size = config['vocab_size']
    hidden_size = config['hidden_size']
    intermediate_size = config['intermediate_size']
    num_attention_heads = config['num_attention_heads']
    num_hidden_layers = config['num_hidden_layers']
    # Parameters for the token embedding layer
    parameters_embedding = vocab_size * hidden_size

    # Parameters for the multi-head attention mechanism in each transformer block.
    # Assumes full multi-head attention; grouped-query attention (num_key_value_heads)
    # would shrink the K and V projections somewhat.
    size_per_head = hidden_size // num_attention_heads
    parameters_attention_per_layer = (
        3 * hidden_size * num_attention_heads * size_per_head  # Q, K and V projections
        + hidden_size * hidden_size  # Output projection
    )
    # Parameters for the feed-forward network in each transformer block.
    # Note: Mistral's MLP also has a gate projection (hidden_size * intermediate_size),
    # which this estimate does not count.
    parameters_ffn_per_layer = (
        hidden_size * intermediate_size  # First linear layer
        + intermediate_size * hidden_size  # Second linear layer
    )
    # Parameters for layer normalization (usually 2 norm layers per block)
    parameters_ln_per_layer = 2 * 2 * hidden_size

    # Total parameters for each transformer block
    parameters_per_block = (
        parameters_attention_per_layer + parameters_ffn_per_layer + parameters_ln_per_layer
    )

    # Total parameters across all transformer blocks
    parameters_transformer_blocks = parameters_per_block * num_hidden_layers

    # Total parameters in the model (embeddings + transformer blocks).
    # The output layer often uses weight tying with the input embedding layer,
    # so it is not counted separately.
    total_parameters = parameters_embedding + parameters_transformer_blocks

    return parameters_embedding, parameters_transformer_blocks, total_parameters
def main():
    # Initialize the parser to accept a file path input
    parser = argparse.ArgumentParser(description="Calculate model parameters from config file.")
    parser.add_argument('config_file', type=str, help="Path to the configuration file.")

    # Parse the arguments provided
    args = parser.parse_args()

    # Load the model configuration from the file provided in the argument
    with open(args.config_file, 'r') as file:
        config = json.load(file)

    # Calculate the parameters based on the loaded configuration
    embedding_params, transformer_block_params, total_params = calculate_model_parameters(config)

    # Print the results
    print(f"Parameters in Embedding: {embedding_params}")
    print(f"Parameters in Transformer Blocks: {transformer_block_params}")
    print(f"Total Parameters: {total_params}")

    # Extract the base filename without the .json extension
    base_filename = os.path.splitext(os.path.basename(args.config_file))[0]

    # Convert total parameters to billions and format the string
    param_count_in_billions = total_params / 1e9  # Convert to 'B' (billions)
    formatted_param_count = f"{param_count_in_billions:.2f}B"  # Format to 2 decimal places
    # Convert the loaded dictionary to a MistralConfig object
    model_config = MistralConfig(
        vocab_size=config.get('vocab_size'),  # pass through so the saved model matches the counted embedding size
        hidden_size=config.get('hidden_size'),
        intermediate_size=config.get('intermediate_size'),
        num_hidden_layers=config.get('num_hidden_layers'),
        num_attention_heads=config.get('num_attention_heads'),
        num_key_value_heads=config.get('num_key_value_heads'),
        tie_word_embeddings=config.get('tie_word_embeddings'),
        max_position_embeddings=config.get('max_position_embeddings'),
        sliding_window=config.get('sliding_window'),
        rope_theta=config.get('rope_theta')
    )

    # Construct the new filename based on the original config file and the computed size
    model_filename = f"{base_filename}_{formatted_param_count}"

    # Instantiate a randomly initialized model from the config in bfloat16.
    # this config results in 3.75B parameters.
    model = AutoModelForCausalLM.from_config(model_config, torch_dtype=torch.bfloat16)

    with torch.no_grad():
        # Optionally zero out all weights before saving:
        # for name, param in model.named_parameters():
        #     param.data = torch.zeros(size=param.size(), dtype=param.dtype)
        model.save_pretrained(model_filename)


if __name__ == "__main__":
    main()
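For reference, here is a minimal usage sketch. The config values, file names, and script name below are illustrative assumptions rather than values from this gist, and the direct call assumes calculate_model_parameters from the script above is importable or already defined in the session.

import json

# Hypothetical Mistral-style config; every value here is an illustrative assumption.
example_config = {
    "vocab_size": 32000,
    "hidden_size": 2048,
    "intermediate_size": 5632,
    "num_attention_heads": 32,
    "num_key_value_heads": 8,
    "num_hidden_layers": 24,
    "tie_word_embeddings": False,
    "max_position_embeddings": 32768,
    "sliding_window": 4096,
    "rope_theta": 10000.0
}

# Write it to disk so it can also be passed to the script on the command line.
with open("example_config.json", "w") as f:
    json.dump(example_config, f, indent=2)

# Estimate the size directly (assumes calculate_model_parameters is in scope).
emb, blocks, total = calculate_model_parameters(example_config)
print(f"~{total / 1e9:.2f}B parameters for this illustrative config")

# Or via the CLI (the script filename here is hypothetical):
#   python calc_mistral_params.py example_config.json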
# Modified after reviewing mistrallite, phi, open_llama_3b_v2, and tiny_llama, contrasted with mistral_7b_v0.1.
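As a follow-up sketch, the directory written by save_pretrained above can be reloaded with the standard transformers API to check the exact count against the script's estimate; the directory name below is a hypothetical example of the "<config name>_<size>B" pattern the script produces.

import torch
from transformers import AutoModelForCausalLM

# Hypothetical path; in practice use whatever directory save_pretrained created above.
saved_dir = "example_config_1.02B"

model = AutoModelForCausalLM.from_pretrained(saved_dir, torch_dtype=torch.bfloat16)

# Exact count from the instantiated weights, for comparison with the script's estimate.
print(f"{sum(p.numel() for p in model.parameters()):,} parameters")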