{"cells":[{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["from datasets import load_dataset, DatasetDict"]},{"attachments":{},"cell_type":"markdown","metadata":{},"source":[" ## Parameters & Setup"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["dataset_path='omgbobbyg/spock'\n","pretrained_model = \"distilbert/distilgpt2\"\n","finetuned_modelname = \"distilgpt2-spock\"\n","huggingface_username = \"omgbobbyg\"\n","huggingface_reponame = f\"{huggingface_username}/{finetuned_modelname}\" "]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["#setup parameters regarding GPU availibility on the machine and recycle used memory\n","import torch; \n","import gc;\n","\n","is_gpu_available = torch.cuda.is_available()\n","device = 'cuda' if is_gpu_available else 'cpu'\n","if is_gpu_available:\n"," print(\"GPU available for notebook\")\n"," torch.cuda.empty_cache()\n"," print(\"GPU Memory cleaned\")\n","else:\n"," print(\"No GPU available for notebook\")\n"," \n","gc.collect()\n"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["#setup our logging level and file\n","import logging \n","logger = logging.getLogger(__name__)\n","logger.setLevel(logging.INFO)\n","\n","# Configure the logger if it hasn't been configured before\n","if not logger.handlers:\n"," handler = logging.FileHandler('training.log')\n"," handler.setLevel(logging.INFO)\n"," formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')\n"," handler.setFormatter(formatter)\n"," logger.addHandler(handler)\n","\n"]},{"attachments":{},"cell_type":"markdown","metadata":{},"source":["## Dataset"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["dataset = load_dataset(dataset_path)\n","print(dataset)"]},{"cell_type":"markdown","metadata":{},"source":[" ## Tokenization"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["from transformers import AutoTokenizer, GPT2LMHeadModel, AutoConfig, AutoModelForCausalLM\n","\n","model = AutoModelForCausalLM.from_pretrained(pretrained_model)\n","# Get the maximum context size\n","max_length = model.config.max_position_embeddings\n","print(f\"Maximum context size: {max_length}\")\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["from transformers import GPT2Tokenizer\n","tokenizer = AutoTokenizer.from_pretrained(pretrained_model)\n","\n","def tokenize_function(examples):\n"," return tokenizer(examples[\"dialogue\"],max_length=max_length)\n","\n","\n","# Apply the tokenization function to the entire dataset\n","tokenized_dataset = dataset.map(\n"," tokenize_function,\n"," batched=True,\n"," batch_size=10,\n"," remove_columns=dataset[\"train\"].column_names\n",")\n","\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["sample_token = tokenizer.encode(\"Live long and prosper.\")\n","print(sample_token)"]},{"attachments":{},"cell_type":"markdown","metadata":{},"source":["## Data Collator"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["#We need to create data collator to manage the batches, we can use DataCollatorForLanguageModeling\n","from transformers import DataCollatorForLanguageModeling\n","tokenizer.pad_token =\"<pad>\" #tokenizer.eos_token\n","data_collator = DataCollatorForLanguageModeling(tokenizer,mlm=False)\n","# Iterate over the generator\n","out = 
## Setup the Trainer

```python
# Now we train the model using the Trainer API
from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    finetuned_modelname,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=3,
    weight_decay=0.01,
    fp16=is_gpu_available,
    push_to_hub=True,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    hub_model_id=huggingface_reponame
)

trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"]
)
```

## Evaluate the Performance of the Base Model

```python
import math

# Calculate and report perplexity, i.e. exp(cross-entropy loss)
initial_results = trainer.evaluate()
print(initial_results)
# Log the results to file
logger.info(f"Baseline {pretrained_model} Results: Perplexity: {math.exp(initial_results['eval_loss']):.2f}")
print(f"Baseline {pretrained_model} Results: Perplexity: {math.exp(initial_results['eval_loss']):.2f}")
```

```python
# Set up our test prompts
test_prompt = "What is the meaning of life?"
test_prompt2 = "Where did that planet go??"
test_prompt3 = "What is the best way to cook a turkey?"
```

```python
# Use the model in a pipeline to generate text.
# do_sample=True is required for temperature to have any effect.
from transformers import pipeline
text_generator = pipeline("text-generation", model=model, tokenizer=tokenizer, device=device)

result = text_generator(test_prompt, max_length=100, num_return_sequences=1, do_sample=True, temperature=1.0)
print(f"Baseline {pretrained_model} generated result: {test_prompt}...{result[0]['generated_text']}")
logger.info(f"Baseline {pretrained_model} generated result: {test_prompt}...{result[0]['generated_text']}")
```

```python
result = text_generator(test_prompt2, max_length=100, num_return_sequences=1, do_sample=True, temperature=1.0)
print(f"Baseline {pretrained_model} generated result: {test_prompt2}...{result[0]['generated_text']}")
logger.info(f"Baseline {pretrained_model} generated result: {test_prompt2}...{result[0]['generated_text']}")
```

```python
result = text_generator(test_prompt3, max_length=100, num_return_sequences=1, do_sample=True, temperature=1.0)
print(f"Baseline {pretrained_model} generated result: {test_prompt3}...{result[0]['generated_text']}")
logger.info(f"Baseline {pretrained_model} generated result: {test_prompt3}...{result[0]['generated_text']}")
```

## Fine-Tune the Model

```python
trainer.train()
```
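With `push_to_hub=True`, checkpoints are uploaded at each save; the final model and tokenizer can also be pushed explicitly once training completes. A short sketch, not in the original notebook, assuming you are already authenticated (e.g. via `huggingface-cli login` or `notebook_login()`):

```python
# Push the final fine-tuned model and tokenizer to the Hub repo
# configured in hub_model_id above
trainer.push_to_hub(commit_message="Fine-tuned on Spock dialogue")
```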
## Evaluate the Performance of the Fine-Tuned Model

```python
# Calculate and report perplexity for the fine-tuned model
eval_results = trainer.evaluate()
perplexity = math.exp(eval_results['eval_loss'])
eval_results['perplexity'] = perplexity

logger.info(f"Fine-tuned {finetuned_modelname} Results: Perplexity: {perplexity:.2f}")
print(f"Fine-tuned {finetuned_modelname} Results: Perplexity: {perplexity:.2f}")
```

```python
# Prompt Test 1
text_generator = pipeline("text-generation", model=model, tokenizer=tokenizer, device=device)
result = text_generator(test_prompt, max_length=100, num_return_sequences=1, do_sample=True, temperature=0.7)
print(f"Fine-tuned {finetuned_modelname} generated result: {test_prompt}...{result[0]['generated_text']}")
logger.info(f"Fine-tuned {finetuned_modelname} generated result: {test_prompt}...{result[0]['generated_text']}")
```

```python
# Prompt Test 2
result = text_generator(test_prompt2, max_length=100, num_return_sequences=1, do_sample=True, temperature=1.0)
print(f"Fine-tuned {finetuned_modelname} generated result: {test_prompt2}...{result[0]['generated_text']}")
logger.info(f"Fine-tuned {finetuned_modelname} generated result: {test_prompt2}...{result[0]['generated_text']}")
```

```python
# Prompt Test 3
result = text_generator(test_prompt3, max_length=100, num_return_sequences=1, do_sample=True, temperature=1.0)
print(f"Fine-tuned {finetuned_modelname} generated result: {test_prompt3}...{result[0]['generated_text']}")
logger.info(f"Fine-tuned {finetuned_modelname} generated result: {test_prompt3}...{result[0]['generated_text']}")
```
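Once pushed, the fine-tuned model can be loaded straight from the Hub by its repo id. A minimal sketch, assuming the upload above succeeded (the prompt here is purely illustrative):

```python
# Load the published model from the Hub and generate from a sample prompt
from transformers import pipeline
spock_generator = pipeline("text-generation", model=huggingface_reponame)
result = spock_generator("Captain, the logical course of action is", max_length=60, do_sample=True)
print(result[0]["generated_text"])
```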