Skip to content

Instantly share code, notes, and snippets.

@pacman100
Last active December 8, 2023 09:55
Show Gist options
  • Save pacman100/5aac746b0a7bdee5dca23e2f27cc4fb0 to your computer and use it in GitHub Desktop.
Save pacman100/5aac746b0a7bdee5dca23e2f27cc4fb0 to your computer and use it in GitHub Desktop.
from accelerate import Accelerator
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
import contextlib
# Hub identifier passed to both AutoModelForCausalLM.from_pretrained and
# AutoTokenizer.from_pretrained in main(). The trailing comment names an
# alternative checkpoint that can be swapped in.
MODEL_NAME = "meta-llama/Llama-2-70b-chat-hf" #"HuggingFaceH4/zephyr-7b-beta"
def main():
    """Generate one chat completion per process with an Accelerate-prepared model.

    Loads ``MODEL_NAME`` and its tokenizer, wraps the model with
    ``accelerator.prepare``, renders a fixed set of chat prompts through the
    tokenizer's chat template, and has every process sample a completion for
    one of those prompts. Each process prints its own decoded output followed
    by a separator line.

    Returns:
        None. Results are written to stdout.
    """
    accelerator = Accelerator()
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    # Only touch the DeepSpeed config when a DeepSpeed plugin is active.
    # NOTE(review): presumably needed because no dataloader is passed to
    # prepare(), so the micro-batch size must be given explicitly - confirm.
    deepspeed_plugin = accelerator.state.deepspeed_plugin
    if deepspeed_plugin is not None:
        deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = 1

    model = accelerator.prepare(model)

    conversations = [
        [{"role": "user", "content": "Explain Deep Learning like a Pirate."}],
        [{"role": "user", "content": "Why is it important to eat socks daily?"}],
        [{"role": "user", "content": "Write a tweet about the latest model by Google Gemini which is topping all the benchmarks"}],
        [{"role": "user", "content": "How do I convert a Python dictionary into a string representation?"}],
    ]
    # Render every conversation to a plain prompt string (tokenize=False) that
    # ends with the generation prompt for the assistant turn.
    sample_texts = [
        tokenizer.apply_chat_template(
            conversation, add_generation_prompt=True, tokenize=False
        )
        for conversation in conversations
    ]
    accelerator.print(sample_texts)

    # Wrap around the prompt list so that launching more processes than there
    # are prompts does not raise IndexError.
    prompt = sample_texts[accelerator.process_index % len(sample_texts)]
    inputs = tokenizer(prompt, return_tensors="pt").to(accelerator.device)

    # `hasattr(...) is not None` is always True because hasattr returns a bool;
    # look the plugin up with getattr so the nullcontext branch is reachable
    # whether the attribute is missing or present-but-None.
    if getattr(accelerator.state, "fsdp_plugin", None) is not None:
        ctx = FSDP.summon_full_params(model, writeback=False, recurse=False)
    else:
        ctx = contextlib.nullcontext()

    unwrapped_model = accelerator.unwrap_model(model)
    with ctx:
        outputs = unwrapped_model.generate(
            **inputs,
            do_sample=True,
            temperature=0.2,
            top_p=0.95,
            eos_token_id=tokenizer.eos_token_id,
            # The EOS token doubles as the padding token for generation.
            pad_token_id=tokenizer.eos_token_id,
            max_new_tokens=128,
            synced_gpus=True,
        )

    # Plain print (not accelerator.print) so every process reports its output.
    print(f"{accelerator.process_index=} {tokenizer.decode(outputs[0], skip_special_tokens=False)}")
    print("-" * 100)
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment