merged checkpoint's inference
from pathlib import Path
import json
import random

import numpy as np
import torch
from transformers import AutoTokenizer, GPTBigCodeConfig, GPTBigCodeForCausalLM

if __name__ == "__main__":
    seed = 42

    # Load the model config that was saved alongside the merged checkpoint.
    checkpoint_dir = Path(
        "/admin/home/phuc_nguyen/.cache/huggingface/hub/models--HuggingFaceBR4--starcoder2_7b_4k_smol_data_580000/snapshots/92b6c25cab25f07c367bcc6d773635700a8a287d"
    )
    with open(checkpoint_dir / "model_config.json") as f:
        config = json.load(f)

    # Seed every RNG for reproducible generation.
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Map the brrr-side config fields onto a transformers GPTBigCodeConfig.
    model_config = GPTBigCodeConfig(
        vocab_size=config["vocab_size"],
        n_positions=config["max_position_embeddings"],
        n_embd=config["hidden_size"],
        n_layer=config["num_hidden_layers"],
        n_head=config["num_attention_heads"],
        num_key_value_heads=config["num_kv_heads"],
        # NOTE: based on https://github.com/huggingface/brrr/blob/f569b93f80d03c626b24370d5ca4b1fe4f13fd76/brrr/models/fast/starcoder2.py#L194C16-L194C88
        n_inner=config.get("n_inner", 4 * config["hidden_size"]),
        activation_function=config["activation_function"],
        resid_pdrop=config["resid_pdrop"],
        embd_pdrop=config["embd_pdrop"],
        attn_pdrop=config["attn_pdrop"],
        layer_norm_epsilon=config["layer_norm_epsilon"],
        scale_attn_weights=config["scale_attn_weights"],
        bos_token_id=config["bos_token_id"],
        eos_token_id=config["eos_token_id"],
        attention_softmax_in_fp32=config["attention_softmax_in_fp32"],
        scale_attention_softmax_in_fp32=config["scale_attention_softmax_in_fp32"],
        multi_query=config["multi_query"],
        use_rotary_embeddings=config["use_rotary_embeddings"],
        # rotary_embedding_scale=brrr_model_config.rotary_embedding_scale,  # TODO
        attention_window_size=config["sliding_window_size"],
    )

    # Instantiate the model from the config (random weights, bf16), then load the
    # merged state dict on top.
    model = GPTBigCodeForCausalLM._from_config(model_config, torch_dtype=torch.bfloat16)
    checkpoint_path = Path(
        "/fsx/phuc/projects/starcoder/transformers-starcoder/src/transformers/models/gpt_bigcode/pytorch_model.pth"
    )
    state_dict = torch.load(checkpoint_path)
    model.load_state_dict(state_dict)
    model = model.to("cuda")

    tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder2-tokenizer")
    tokenizer.eos_token_id = tokenizer.pad_token_id

    # Generate from a short code prompt as a smoke test of the merged weights.
    inputs = tokenizer.encode("def print_hello_world():", return_tensors="pt").to("cuda")
    outputs = model.generate(inputs)
    print(tokenizer.decode(outputs[0], clean_up_tokenization_spaces=False))

    # Alternative: a single forward pass over the same prompt (note the ** unpacking
    # of the BatchEncoding returned by the tokenizer).
    # inputs = tokenizer("def print_hello_world():", return_tensors="pt").to("cuda")
    # outputs = model(**inputs)
    # print(f"outputs: {outputs}")