Skip to content

Instantly share code, notes, and snippets.

@jyc
Created July 30, 2023 20:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jyc/431098b0af565943f03877593abf59cd to your computer and use it in GitHub Desktop.
Save jyc/431098b0af565943f03877593abf59cd to your computer and use it in GitHub Desktop.
# Report the CUDA compiler version available in this environment.
!nvcc --version
# The commented lines below are presumably the output captured from a previous run:
#nvcc: NVIDIA (R) Cuda compiler driver
#Copyright (c) 2005-2022 NVIDIA Corporation
#Built on Tue_Mar__8_18:18:20_PST_2022
#Cuda compilation tools, release 11.6, V11.6.124
#Build cuda_11.6.r11.6/compiler.31057947_0
# Install/upgrade the dependencies from prebuilt wheels only (`--only-binary :all:`
# forbids source builds). transformers and bitsandbytes are pinned to exact versions;
# tokenizers has a minimum version; everything else floats to the latest release.
%pip install -U --only-binary :all: xformers transformers==4.31.0 torch accelerate einops bitsandbytes==0.39.1 peft sentencepiece boto3 'tokenizers>=0.13.3'
# Setting HF_HOME and cache_dir doesn't work...
# ...so instead replace ~/.cache/huggingface with a symlink to /storage/huggingface-cache.
# NOTE: the `rm -rf` discards anything already stored under ~/.cache/huggingface.
# The three commands must run in this order: create target, remove old path, link.
!mkdir -p /storage/huggingface-cache
!rm -rf ~/.cache/huggingface
!ln -sf /storage/huggingface-cache ~/.cache/huggingface
import torch
from tqdm.notebook import tqdm
import math
from math import exp
import transformers
from IPython.display import display, HTML
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
import html
import re
from instruct_pipeline import InstructionTextGenerationPipeline
from transformers import AutoModelForCausalLM, AutoTokenizer
import os

# Model repository and the exact commit to load, so reruns fetch the same files.
model_id = 'meta-llama/Llama-2-70b-hf'
revision = 'aa66797ee725f1d03e506cf21cccb19327280c35'  # https://huggingface.co/meta-llama/Llama-2-70b-hf/commit/aa66797ee725f1d03e506cf21cccb19327280c35

# Hugging Face access token. It is read from the environment instead of being
# embedded in the notebook: a token committed to a shared/public file is leaked
# and must be revoked. Set HF_TOKEN (or HUGGING_FACE_HUB_TOKEN) before running.
hf_auth = os.environ.get('HF_TOKEN') or os.environ.get('HUGGING_FACE_HUB_TOKEN')
if not hf_auth:
    raise RuntimeError(
        'No Hugging Face token found: set the HF_TOKEN (or HUGGING_FACE_HUB_TOKEN) '
        'environment variable to a token with access to ' + model_id
    )

# NOTE(review): the tokenizer/config calls pass the token as `use_auth_token`
# while the model call passes it as `token`. The keyword names are kept exactly
# as they were; confirm against the pinned transformers version before unifying.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_id,
    revision=revision,
    use_auth_token=hf_auth,
)

model_config = transformers.AutoConfig.from_pretrained(
    model_id,
    revision=revision,
    use_auth_token=hf_auth,
)

model = transformers.AutoModelForCausalLM.from_pretrained(
    model_id,
    revision=revision,
    trust_remote_code=True,
    config=model_config,
    # Alternatives that were tried and left disabled:
    # quantization_config=bnb_config,
    # torch_dtype=torch.bfloat16,
    load_in_8bit=True,  # otherwise the kernel crashes
    device_map='auto',
    token=hf_auth,
)

# Re-tie shared weights after loading, then switch to inference mode.
model.tie_weights()
model.eval()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment