conda create -n llama
conda activate llama
conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
conda install cudatoolkit==11.8.0 -c conda-forge
python -s -m pip install bitsandbytes accelerate sentencepiece
python -s -m pip install git+https://github.com/huggingface/transformers
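Before loading the 7B weights, it can help to confirm that the freshly created env actually sees the GPU and that bitsandbytes imports cleanly. A minimal sanity check, run inside the activated llama env, might look like the following (the file name sanity_check.py is just a placeholder):

# sanity_check.py -- quick check that the env is usable before downloading/loading weights
import torch
import bitsandbytes  # noqa: F401  (importing it verifies the CUDA build loads at all)

print("torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("CUDA version:", torch.version.cuda)
print("GPU count:", torch.cuda.device_count())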
Transformers LLaMA
import time
import traceback

import torch
from transformers import (
    LlamaForCausalLM, LlamaTokenizer, BitsAndBytesConfig, TextStreamer,
    GenerationConfig
)

# Bail out early if no CUDA-capable GPU is visible.
cuda_is_available = torch.cuda.is_available()
gpu_count = torch.cuda.device_count()
if not cuda_is_available or gpu_count == 0:
    print("CUDA is not available")
    exit(1)
print("CUDA version: " + torch.version.cuda)

# Load the weights in 8-bit via bitsandbytes, keeping lm_head unquantized.
quantization_config = BitsAndBytesConfig(
    llm_int8_skip_modules=["lm_head"],
    load_in_8bit=True,
)

start_model = time.perf_counter_ns()
model_name = "./models/7B/"
tokenizer = LlamaTokenizer.from_pretrained(model_name)
streamer = TextStreamer(tokenizer)  # prints tokens to stdout as they are generated
model: LlamaForCausalLM = LlamaForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=quantization_config,
    torch_dtype=torch.float16,
    cache_dir="cache",
)


def main():
    with open("prompt.txt") as f:
        prompt = f.read()
    gen_in = tokenizer(prompt, return_tensors="pt")["input_ids"].cuda()

    with torch.no_grad():
        generation_config = GenerationConfig(
            max_new_tokens=256,
            do_sample=True,
            use_cache=True,
            pad_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.2,
            temperature=0.7,
            top_k=40,
            top_p=1.0,
            # early_stopping=True,
        )
        try:
            print("generation start")
            start_gen = time.perf_counter_ns()
            generated_ids = model.generate(
                gen_in,
                generation_config=generation_config,
                streamer=streamer,
                # stopping_criteria=
            )
        except Exception:
            # Print the traceback in red so it stands out in the streamed output.
            print("\033[91m")
            traceback.print_exc()
            print("\033[0m")
            exit(1)

        print(repr(generated_ids))
        # generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        # print(repr(generated_text))
        end_gen = time.perf_counter_ns()
        print(f"generation time: {(end_gen - start_gen) / 1e9} s")
        print(f"total time: {(end_gen - start_model) / 1e9} s")


if __name__ == "__main__":
    main()
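The script only prints the raw tensor of token ids; the commented-out batch_decode lines in main() show the intended way to get text back. A minimal sketch of that step, plus a rough throughput figure, could look like the code below. Variable names are taken from main(), and the tokens-per-second math assumes a single-sequence batch; none of this is part of the original script.

# inside main(), after model.generate(...) returns:
generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(generated_text)

new_tokens = generated_ids.shape[-1] - gen_in.shape[-1]  # tokens produced beyond the prompt
elapsed_s = (time.perf_counter_ns() - start_gen) / 1e9
print(f"{new_tokens} new tokens, {new_tokens / elapsed_s:.1f} tokens/s")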