Skip to content

Instantly share code, notes, and snippets.

@7shi
Created February 29, 2024 03:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save 7shi/224018f51d07aeedbda89d69e5252da5 to your computer and use it in GitHub Desktop.
Save 7shi/224018f51d07aeedbda89d69e5252da5 to your computer and use it in GitHub Desktop.
[py] Gemma test
# https://huggingface.co/google/gemma-2b
# https://note.com/ngc_shj/n/n81cde9550b37
from datetime import datetime
from transformers import AutoTokenizer, AutoModelForCausalLM
# Load the Gemma 2B base model and its tokenizer from the Hugging Face Hub.
model_id = "google/gemma-2b"
# Gemma is implemented natively in transformers, so trust_remote_code is not
# needed. Enabling it would let the Hub repository execute arbitrary Python on
# this machine if it ever shipped custom modeling code.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",  # load in the checkpoint's dtype instead of float32
    # Let accelerate place the weights on the available device(s);
    # pass device_map="cuda" instead to force the GPU.
    device_map="auto",
    low_cpu_mem_usage=True,
)
# Sampling settings forwarded to model.generate() by q().
generation_params = {
    "do_sample": True,  # sample instead of greedy decoding
    "temperature": 0.8,
    "top_p": 0.95,
    "top_k": 40,
    "max_new_tokens": 256,  # cap on generated tokens (prompt excluded)
    "repetition_penalty": 1.1,  # values > 1 discourage repeated tokens
}
def q(prompt, *, show_timestamps=False):
    """Generate a continuation of *prompt* and print it with timing stats.

    Uses the module-level ``tokenizer``, ``model`` and ``generation_params``.
    Prints the generated text, a separator, the prompt and output token
    counts, the output throughput in tokens per second, and the wall-clock
    generation time.

    Args:
        prompt: Text prompt to continue.
        show_timestamps: If true, also print the wall-clock start and end
            times of the generation.

    Returns:
        The decoded continuation (prompt tokens excluded, special tokens
        removed), stripped of surrounding whitespace.
    """
    start = datetime.now()
    if show_timestamps:
        print("generation start:", start.strftime("%Y/%m/%d %H:%M:%S"))
    # Calling the tokenizer (rather than .encode) also yields an
    # attention_mask, which generate() expects. Move the tensors to the
    # model's device: with device_map="auto" the weights may live on a GPU
    # while freshly tokenized tensors are created on the CPU.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    input_ids = inputs["input_ids"]
    output_ids = model.generate(**inputs, **generation_params)
    # generate() returns prompt + continuation; keep only the new tokens.
    new_ids = output_ids[0][input_ids.shape[1]:]
    # Drop special tokens (e.g. the end-of-sequence marker) from the text.
    output = tokenizer.decode(new_ids, skip_special_tokens=True).strip()
    end = datetime.now()
    if show_timestamps:
        print("generation end:", end.strftime("%Y/%m/%d %H:%M:%S"))

    total_time = end - start
    input_tokens = len(input_ids[0])
    output_tokens = len(new_ids)
    seconds = total_time.total_seconds()
    # Guard against a zero-length interval (coarse clock resolution).
    tps = output_tokens / seconds if seconds > 0 else float("inf")

    print(output)
    print("--------")
    print(f"prompt tokens = {input_tokens:.7g}")
    print(f"output tokens = {output_tokens:.7g} ({tps:f} [tps])")
    print(f" total time = {total_time}")
    return output
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment