Transformers LLaMA
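Minimal setup and script for running LLaMA 7B through Hugging Face Transformers with 8-bit bitsandbytes quantization and streamed output.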
# Fresh conda environment for LLaMA
conda create -n llama
conda activate llama

# PyTorch built against CUDA 11.8, plus the matching CUDA toolkit
conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
conda install cudatoolkit==11.8.0 -c conda-forge

# bitsandbytes (8-bit quantization), accelerate (device_map support),
# sentencepiece (LLaMA tokenizer); transformers from git for current LLaMA support
python -s -m pip install bitsandbytes accelerate sentencepiece
python -s -m pip install git+https://github.com/huggingface/transformers
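The script below expects weights in Hugging Face format under ./models/7B/. If you are starting from the original Meta checkpoints, transformers ships a conversion script; the path and flags below are from the git version at the time of writing, so check --help against your checkout (output layout may differ):

python src/transformers/models/llama/convert_llama_weights_to_hf.py \
    --input_dir /path/to/llama --model_size 7B --output_dir ./models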
import time
import traceback

from transformers import (
    LlamaForCausalLM, LlamaTokenizer, BitsAndBytesConfig, TextStreamer,
    GenerationConfig
)
import torch

# Bail out early if there is no usable GPU
cuda_is_available = torch.cuda.is_available()
gpu_count = torch.cuda.device_count()
if not cuda_is_available or gpu_count == 0:
    print("CUDA is not available")
    exit(1)
print("CUDA version: " + torch.version.cuda)

# Load weights in 8-bit, but skip quantizing the output head (lm_head)
quantization_config = BitsAndBytesConfig(
    llm_int8_skip_modules=["lm_head"],
    load_in_8bit=True,
)

start_model = time.perf_counter_ns()

model_name = "./models/7B/"
tokenizer = LlamaTokenizer.from_pretrained(model_name)
streamer = TextStreamer(tokenizer)  # prints tokens to stdout as they are generated
model: LlamaForCausalLM = LlamaForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=quantization_config,
    torch_dtype=torch.float16,
    cache_dir="cache",
)
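
# Optional addition (not in the original gist): report how much memory the
# quantized weights actually take. get_memory_footprint() exists on recent
# transformers models; drop this line if your installed version lacks it.
print(f"model size: {model.get_memory_footprint() / 2**30:.2f} GiB")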
def main():
    with open("prompt.txt") as f:
        prompt = f.read()

    gen_in = tokenizer(prompt, return_tensors="pt")["input_ids"].cuda()

    with torch.no_grad():
        generation_config = GenerationConfig(
            max_new_tokens=256,
            do_sample=True,
            use_cache=True,
            pad_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.2,
            temperature=0.7,
            top_k=40,
            top_p=1.0,
            # early_stopping=True,
        )
        try:
            print('generation start')
            start_gen = time.perf_counter_ns()
            generated_ids = model.generate(
                gen_in,
                generation_config=generation_config,
                streamer=streamer,  # stream tokens to stdout as they arrive
                # stopping_criteria=
            )
        except Exception:
            # print the traceback in red
            print("\033[91m")
            traceback.print_exc()
            print("\033[0m")
            exit(1)
        print(repr(generated_ids))
        # generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        # print(repr(generated_text))

    end_gen = time.perf_counter_ns()
    print(f"generation time: {(end_gen - start_gen) / 1e9} s")
    print(f"total time: {(end_gen - start_model) / 1e9} s")
if __name__ == "__main__":
    main()
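To run: put your prompt in prompt.txt next to the script, then execute it inside the llama environment, e.g. python -s llama.py (the filename is whatever you saved the gist as).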