@abetlen · Last active June 5, 2024
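"""Convert an Idefics2 checkpoint (sharded safetensors) to GGUF.

Writes two files to the current working directory:
  * {name}-vision-model-f16.gguf - vision encoder + connector, arch "clip"
  * {name}-text-model-f16.gguf   - language model, arch "llama"
where {name} is the name of the --dir-model folder. That folder is expected to
contain model.safetensors.index.json, config.json, tokenizer.model and,
optionally, tokenizer_config.json / added_tokens.json.
"""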
import os
import json
import typing
import pathlib
import argparse
import numpy as np
import numpy.typing as npt
import gguf
from gguf import KEY_ATTENTION_HEAD_COUNT, KEY_ATTENTION_LAYERNORM_EPS, KEY_BLOCK_COUNT, KEY_EMBEDDING_LENGTH, KEY_FEED_FORWARD_LENGTH, GGUFWriter, TokenType, SpecialVocab
from safetensors import safe_open
class SafetensorsIndexFile(typing.TypedDict):
weight_map: typing.Dict[str, str]
class SafetensorsIndex:
def __init__(self, index_file_path: str):
directory = os.path.dirname(index_file_path)
self.index = typing.cast(SafetensorsIndexFile, json.load(open(index_file_path)))
self.weight_map = self.index["weight_map"]
files = set(self.weight_map.values())
self.tensors = {file: safe_open(os.path.join(directory, file), framework="np") for file in files}
def get_tensor(self, key: str) -> npt.NDArray[np.float32]:
return typing.cast(npt.NDArray[np.float32], self.tensors[self.weight_map[key]].get_tensor(key)) # type: ignore
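# gguf's KEY_* constants are templates containing an "{arch}" placeholder
# (e.g. "{arch}.embedding_length"); fill it in for the target architecture.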
def extract_key(raw_key: str, arch: str) -> str:
return raw_key.format(arch=arch)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"-d",
"--dir-model",
required=True,
help="path to directory containing the tokenizer",
)
args = parser.parse_args()
dir_model = pathlib.Path(args.dir_model)
# set model name to folder name
name = dir_model.name
tensors = SafetensorsIndex((dir_model / "model.safetensors.index.json").as_posix())
# Load the model config
config = json.load(open(dir_model / "config.json"))
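# The Idefics2 config.json nests per-component settings under "text_config",
# "vision_config" and (optionally) "perceiver_config".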
# text config is based on mistral v0.1
text_config = {
"vocab_size": 32000,
"hidden_size": 4096,
"intermediate_size": 14336,
"num_hidden_layers": 32,
"num_attention_heads": 32,
"num_key_value_heads": 8,
"hidden_act": "silu",
"max_position_embeddings": 4096 * 32,
"rms_norm_eps": 1e-05,
"bos_token_id": 1,
"eos_token_id": 2,
"tie_word_embeddings": False,
"rope_theta": 10000.0,
"sliding_window": 4096
}
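# Entries in config.json's text_config override the Mistral-style defaults above.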
text_config.update(config["text_config"])
vision_config = config["vision_config"]
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/idefics2/configuration_idefics2.py#L129
perceiver_config = config.get("perceiver_config", {
"hidden_act": "silu",
"resampler_n_latents": 64,
"resampler_depth": 3,
"resampler_n_heads": 16,
"resampler_head_dim": 96,
"num_key_value_heads": 4,
"attention_dropout": 0.0,
})
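# resampler_depth determines how many perceiver resampler blocks (mm.pr.blk.*)
# are written further below.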
### Vision encoder
ftype = 1 # fp16
fname_out = f"{name}-vision-model-f16.gguf"
fout = GGUFWriter(fname_out, arch="clip")
fout.add_bool("clip.has_text_encoder", False)
fout.add_bool("clip.has_vision_encoder", True)
fout.add_bool("clip.has_llava_projector", True)
fout.add_file_type(ftype)
model_name = "idefics2"
fout.add_name(model_name)
fout.add_description("Vision encoder for " + model_name)
fout.add_string("clip.projector_type", "idefics2")
n_layers_clip = vision_config["num_hidden_layers"]
# vision model hparams
VISION = "clip.vision"
fout.add_uint32("clip.vision.image_size", vision_config["image_size"]) # Update as necessary
fout.add_uint32("clip.vision.patch_size", vision_config["patch_size"]) # Update as necessary
fout.add_uint32(extract_key(KEY_EMBEDDING_LENGTH, VISION), vision_config["hidden_size"])
fout.add_uint32(extract_key(KEY_FEED_FORWARD_LENGTH, VISION), vision_config["intermediate_size"])
fout.add_uint32("clip.vision.projection_dim", 4096) # Update as necessary
fout.add_uint32(extract_key(KEY_ATTENTION_HEAD_COUNT, VISION), vision_config["num_attention_heads"])
fout.add_float32(extract_key(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6)
fout.add_uint32(extract_key(KEY_BLOCK_COUNT, VISION), n_layers_clip + 1)
fout.add_array("clip.vision.image_mean", [0.5, 0.5, 0.5])
fout.add_array("clip.vision.image_std", [0.5, 0.5, 0.5])
fout.add_bool("clip.use_gelu", True) # using regular GELU instead of quick
# Connector (model.connector.*): modality projection ("mm.mp.*" tensors)
# followed by the perceiver resampler ("mm.pr.*" tensors).
# model.connector.modality_projection.down_proj.weight [4096, 14336] (F32)
fout.add_tensor(
"mm.mp.ffn_down.weight",
tensors.get_tensor("model.connector.modality_projection.down_proj.weight").astype(np.float16),
)
# model.connector.modality_projection.gate_proj.weight [14336, 1152] (F32)
fout.add_tensor(
"mm.mp.ffn_gate.weight",
tensors.get_tensor("model.connector.modality_projection.gate_proj.weight").astype(np.float16),
)
# model.connector.modality_projection.up_proj.weight [14336, 1152] (F32)
fout.add_tensor(
"mm.mp.ffn_up.weight",
tensors.get_tensor("model.connector.modality_projection.up_proj.weight").astype(np.float16),
)
# model.connector.perceiver_resampler.latents [64, 4096] (F32)
fout.add_tensor(
"mm.pr.latents.weight",
tensors.get_tensor("model.connector.perceiver_resampler.latents").astype(np.float32),
)
# model.connector.perceiver_resampler.norm.weight [4096] (F32)
fout.add_tensor(
"mm.pr.ln0.weight",
tensors.get_tensor("model.connector.perceiver_resampler.norm.weight").astype(np.float32),
)
for i in range(perceiver_config.get("resampler_depth", 3)): # resampler depth, 3 for idefics2-8b
# model.connector.perceiver_resampler.layers.0.input_context_norm.weight [4096] (F32)
fout.add_tensor(
f"mm.pr.blk.{i}.ln0.weight",
tensors.get_tensor(f"model.connector.perceiver_resampler.layers.{i}.input_context_norm.weight").astype(np.float32),
)
# model.connector.perceiver_resampler.layers.0.input_latents_norm.weight [4096] (F32)
fout.add_tensor(
f"mm.pr.blk.{i}.ln1.weight",
tensors.get_tensor(f"model.connector.perceiver_resampler.layers.{i}.input_latents_norm.weight").astype(np.float32),
)
# model.connector.perceiver_resampler.layers.0.mlp.down_proj.weight [4096, 16384] (F32)
fout.add_tensor(
f"mm.pr.blk.{i}.ffn_down.weight",
tensors.get_tensor(f"model.connector.perceiver_resampler.layers.{i}.mlp.down_proj.weight").astype(np.float16),
)
# model.connector.perceiver_resampler.layers.0.mlp.gate_proj.weight [16384, 4096] (F32)
fout.add_tensor(
f"mm.pr.blk.{i}.ffn_gate.weight",
tensors.get_tensor(f"model.connector.perceiver_resampler.layers.{i}.mlp.gate_proj.weight").astype(np.float16),
)
# model.connector.perceiver_resampler.layers.0.mlp.up_proj.weight [16384, 4096] (F32)
fout.add_tensor(
f"mm.pr.blk.{i}.ffn_up.weight",
tensors.get_tensor(f"model.connector.perceiver_resampler.layers.{i}.mlp.up_proj.weight").astype(np.float16),
)
# model.connector.perceiver_resampler.layers.0.post_attention_layernorm.weight [4096] (F32)
fout.add_tensor(
f"mm.pr.blk.{i}.ln2.weight",
tensors.get_tensor(f"model.connector.perceiver_resampler.layers.{i}.post_attention_layernorm.weight").astype(np.float32),
)
# model.connector.perceiver_resampler.layers.0.self_attn.k_proj.weight [384, 4096] (F32)
fout.add_tensor(
f"mm.pr.blk.{i}.attn_k.weight",
tensors.get_tensor(f"model.connector.perceiver_resampler.layers.{i}.self_attn.k_proj.weight").astype(np.float16),
)
# model.connector.perceiver_resampler.layers.0.self_attn.o_proj.weight [4096, 1536] (F32)
fout.add_tensor(
f"mm.pr.blk.{i}.attn_o.weight",
tensors.get_tensor(f"model.connector.perceiver_resampler.layers.{i}.self_attn.o_proj.weight").astype(np.float16),
)
# model.connector.perceiver_resampler.layers.0.self_attn.q_proj.weight [1536, 4096] (F32)
fout.add_tensor(
f"mm.pr.blk.{i}.attn_q.weight",
tensors.get_tensor(f"model.connector.perceiver_resampler.layers.{i}.self_attn.q_proj.weight").astype(np.float16),
)
# model.connector.perceiver_resampler.layers.0.self_attn.v_proj.weight [384, 4096] (F32)
fout.add_tensor(
f"mm.pr.blk.{i}.attn_v.weight",
tensors.get_tensor(f"model.connector.perceiver_resampler.layers.{i}.self_attn.v_proj.weight").astype(np.float16),
)
# vision_model
fout.add_tensor(
"v.position_embd.weight",
tensors.get_tensor("model.vision_model.embeddings.position_embedding.weight").astype(np.float16),
)
fout.add_tensor(
"v.patch_embd.weight",
tensors.get_tensor("model.vision_model.embeddings.patch_embedding.weight")
.reshape(vision_config["hidden_size"], 3, vision_config["patch_size"], vision_config["patch_size"])
.astype(np.float16),
)
fout.add_tensor(
"v.patch_embd.bias",
tensors.get_tensor("model.vision_model.embeddings.patch_embedding.bias").astype(np.float32),
)
fout.add_tensor(
"v.post_ln.weight",
tensors.get_tensor("model.vision_model.post_layernorm.weight").astype(np.float32),
)
fout.add_tensor(
"v.post_ln.bias",
tensors.get_tensor("model.vision_model.post_layernorm.bias").astype(np.float32),
)
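# Write one vision encoder layer under v.blk.{gguf_id}.*: attention q/k/v/out
# projections (weights and biases), both layer norms, and the MLP.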
def add_vision_tensor(blk_id: int, gguf_id: typing.Optional[int]=None):
if gguf_id is None:
gguf_id = blk_id
attn_prefix = f"model.vision_model.encoder.layers.{blk_id}.self_attn."
fout.add_tensor(
f"v.blk.{gguf_id}.attn_q.weight",
tensors.get_tensor(f"{attn_prefix}q_proj.weight").astype(np.float32),
)
fout.add_tensor(
f"v.blk.{gguf_id}.attn_q.bias",
tensors.get_tensor(f"{attn_prefix}q_proj.bias").astype(np.float32),
)
fout.add_tensor(
f"v.blk.{gguf_id}.attn_k.weight",
tensors.get_tensor(f"{attn_prefix}k_proj.weight").astype(np.float32),
)
fout.add_tensor(
f"v.blk.{gguf_id}.attn_k.bias",
tensors.get_tensor(f"{attn_prefix}k_proj.bias").astype(np.float32),
)
fout.add_tensor(
f"v.blk.{gguf_id}.attn_v.weight",
tensors.get_tensor(f"{attn_prefix}v_proj.weight").astype(np.float32),
)
fout.add_tensor(
f"v.blk.{gguf_id}.attn_v.bias",
tensors.get_tensor(f"{attn_prefix}v_proj.bias").astype(np.float32),
)
fout.add_tensor(
f"v.blk.{gguf_id}.attn_out.weight",
tensors.get_tensor(f"{attn_prefix}out_proj.weight").astype(np.float32),
)
fout.add_tensor(
f"v.blk.{gguf_id}.attn_out.bias",
tensors.get_tensor(f"{attn_prefix}out_proj.bias").astype(np.float32),
)
fout.add_tensor(
f"v.blk.{gguf_id}.ln1.weight",
tensors.get_tensor(f"model.vision_model.encoder.layers.{blk_id}.layer_norm1.weight").astype(np.float32),
)
fout.add_tensor(
f"v.blk.{gguf_id}.ln1.bias",
tensors.get_tensor(f"model.vision_model.encoder.layers.{blk_id}.layer_norm1.bias").astype(np.float32),
)
fout.add_tensor(
f"v.blk.{gguf_id}.ffn_down.weight",
tensors.get_tensor(f"model.vision_model.encoder.layers.{blk_id}.mlp.fc1.weight").astype(np.float32),
)
fout.add_tensor(
f"v.blk.{gguf_id}.ffn_down.bias",
tensors.get_tensor(f"model.vision_model.encoder.layers.{blk_id}.mlp.fc1.bias").astype(np.float32),
)
fout.add_tensor(
f"v.blk.{gguf_id}.ffn_up.weight",
tensors.get_tensor(f"model.vision_model.encoder.layers.{blk_id}.mlp.fc2.weight").astype(np.float32),
)
fout.add_tensor(
f"v.blk.{gguf_id}.ffn_up.bias",
tensors.get_tensor(f"model.vision_model.encoder.layers.{blk_id}.mlp.fc2.bias").astype(np.float32),
)
fout.add_tensor(
f"v.blk.{gguf_id}.ln2.weight",
tensors.get_tensor(f"model.vision_model.encoder.layers.{blk_id}.layer_norm2.weight").astype(np.float32),
)
fout.add_tensor(
f"v.blk.{gguf_id}.ln2.bias",
tensors.get_tensor(f"model.vision_model.encoder.layers.{blk_id}.layer_norm2.bias").astype(np.float32),
)
for i in range(n_layers_clip):
add_vision_tensor(i)
# Duplicate the last block (llava-cli skips over this)
add_vision_tensor(n_layers_clip - 1, n_layers_clip)
fout.write_header_to_file()
fout.write_kv_data_to_file()
fout.write_tensors_to_file()
fout.close()
### Text Model
# general GGUF init
fname_out = f"{name}-text-model-f16.gguf"
fout = GGUFWriter(fname_out, arch="llama")
ftype = 1
block_count = text_config["num_hidden_layers"]
fout.add_name(name)
fout.add_block_count(block_count)
fout.add_context_length(text_config["max_position_embeddings"])
fout.add_embedding_length(text_config["hidden_size"])
fout.add_feed_forward_length(text_config["intermediate_size"])
fout.add_head_count(text_config["num_attention_heads"])
fout.add_head_count_kv(text_config["num_key_value_heads"])
fout.add_rope_freq_base(text_config["rope_theta"])
fout.add_layer_norm_rms_eps(text_config["rms_norm_eps"])
fout.add_file_type(ftype)
fout.add_vocab_size(text_config["vocab_size"])
fout.add_rope_dimension_count(
text_config["hidden_size"] // text_config["num_attention_heads"]
)
tokenizer_config_file = dir_model / 'tokenizer_config.json'
if tokenizer_config_file.is_file():
with open(tokenizer_config_file, "r", encoding="utf-8") as f:
tokenizer_config_json = json.load(f)
if "add_prefix_space" in tokenizer_config_json:
fout.add_add_space_prefix(tokenizer_config_json["add_prefix_space"])
### Tokenizer
# Taken from _set_vocab_sentencepiece in llama.cpp's convert-hf-to-gguf.py
from enum import IntEnum
class SentencePieceTokenTypes(IntEnum):
NORMAL = 1
UNKNOWN = 2
CONTROL = 3
USER_DEFINED = 4
UNUSED = 5
BYTE = 6
from sentencepiece import SentencePieceProcessor
tokenizer_path = dir_model / 'tokenizer.model'
if not tokenizer_path.is_file():
raise FileNotFoundError(f"File not found: {tokenizer_path}")
tokenizer = SentencePieceProcessor()
tokenizer.LoadFromFile(str(tokenizer_path))
vocab_size = text_config["vocab_size"]
tokens: typing.List[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
scores: typing.List[float] = [-10000.0] * vocab_size
toktypes: typing.List[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size
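# The placeholder entries above are overwritten with real pieces from
# tokenizer.model; any ids beyond the SentencePiece vocab stay as [PAD] tokens.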
for token_id in range(tokenizer.vocab_size()):
piece = tokenizer.IdToPiece(token_id)
text = piece.encode("utf-8")
score = tokenizer.GetScore(token_id)
toktype = SentencePieceTokenTypes.NORMAL
if tokenizer.IsUnknown(token_id):
toktype = SentencePieceTokenTypes.UNKNOWN
elif tokenizer.IsControl(token_id):
toktype = SentencePieceTokenTypes.CONTROL
elif tokenizer.IsUnused(token_id):
toktype = SentencePieceTokenTypes.UNUSED
elif tokenizer.IsByte(token_id):
toktype = SentencePieceTokenTypes.BYTE
tokens[token_id] = text
scores[token_id] = score
toktypes[token_id] = toktype
added_tokens_file = dir_model / 'added_tokens.json'
if added_tokens_file.is_file():
with open(added_tokens_file, "r", encoding="utf-8") as f:
added_tokens_json = json.load(f)
for key in added_tokens_json:
token_id = added_tokens_json[key]
if (token_id >= vocab_size):
print(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}')
continue
tokens[token_id] = key.encode("utf-8")
scores[token_id] = -1000.0
toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED
if vocab_size > len(tokens):
pad_count = vocab_size - len(tokens)
print(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]")
for i in range(1, pad_count + 1):
tokens.append(bytes(f"[PAD{i}]", encoding="utf-8"))
scores.append(-1000.0)
toktypes.append(SentencePieceTokenTypes.UNUSED)
fout.add_tokenizer_model("llama")
fout.add_tokenizer_pre("default")
fout.add_token_list(tokens)
fout.add_token_scores(scores)
fout.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(dir_model, n_vocab=len(tokens))
special_vocab.add_to_gguf(fout)
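# Same permutation llama.cpp's convert scripts apply to Q/K projections:
# reorder the interleaved rotary halves into the row layout llama.cpp expects.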
def permute(weights: npt.NDArray[np.float16], n_head: int, n_head_kv: typing.Optional[int]):
if n_head_kv is not None and n_head != n_head_kv:
n_head = n_head_kv
return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
.swapaxes(1, 2)
.reshape(weights.shape))
n_head = typing.cast(int, text_config["num_attention_heads"])
n_kv_head = typing.cast(int, text_config["num_key_value_heads"])
fout.add_tensor(
"token_embd.weight",
tensors.get_tensor("model.text_model.embed_tokens.weight").astype(np.float32),
)
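# Map one decoder layer (model.text_model.layers.{i}.*) onto llama.cpp's
# blk.{i}.* tensor names; norms are kept in F32, projection matrices in F16.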
def add_text_tensor(i: int):
fout.add_tensor(
f"blk.{i}.attn_norm.weight",
tensors.get_tensor(f"model.text_model.layers.{i}.input_layernorm.weight").astype(
np.float32
),
)
fout.add_tensor(
f"blk.{i}.ffn_down.weight",
tensors.get_tensor(f"model.text_model.layers.{i}.mlp.down_proj.weight").astype(
np.float16
),
)
fout.add_tensor(
f"blk.{i}.ffn_gate.weight",
tensors.get_tensor(f"model.text_model.layers.{i}.mlp.gate_proj.weight").astype(
np.float16
),
)
fout.add_tensor(
f"blk.{i}.ffn_up.weight",
tensors.get_tensor(f"model.text_model.layers.{i}.mlp.up_proj.weight").astype(
np.float16
),
)
fout.add_tensor(
f"blk.{i}.ffn_norm.weight",
tensors.get_tensor(f"model.text_model.layers.{i}.post_attention_layernorm.weight").astype(
np.float32
),
)
fout.add_tensor(
f"blk.{i}.attn_k.weight",
permute(
tensors.get_tensor(
f"model.text_model.layers.{i}.self_attn.k_proj.weight"
).astype(np.float16),
n_head,
n_kv_head
),
)
fout.add_tensor(
f"blk.{i}.attn_output.weight",
tensors.get_tensor(
f"model.text_model.layers.{i}.self_attn.o_proj.weight"
).astype(np.float16),
)
fout.add_tensor(
f"blk.{i}.attn_q.weight",
permute(
tensors.get_tensor(
f"model.text_model.layers.{i}.self_attn.q_proj.weight"
).astype(np.float16),
n_head,
n_head,
)
)
fout.add_tensor(
f"blk.{i}.attn_v.weight",
tensors.get_tensor(
f"model.text_model.layers.{i}.self_attn.v_proj.weight"
).astype(np.float16),
)
for i in range(block_count):
add_text_tensor(i)
fout.add_tensor(
"output_norm.weight",
tensors.get_tensor("model.text_model.norm.weight").astype(np.float32),
)
fout.add_tensor(
"output.weight",
tensors.get_tensor("lm_head.weight").astype(np.float32),
)
fout.write_header_to_file()
fout.write_kv_data_to_file()
fout.write_tensors_to_file()
fout.close()
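# Example invocation (script and folder names are illustrative):
#   python convert_idefics2_to_gguf.py --dir-model ./idefics2-8b
# which would produce idefics2-8b-vision-model-f16.gguf and
# idefics2-8b-text-model-f16.gguf in the current directory.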