{"cells":[{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["from datasets import load_dataset, DatasetDict"]},{"attachments":{},"cell_type":"markdown","metadata":{},"source":[" ## Parameters & Setup"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["dataset_path='omgbobbyg/spock'\n","pretrained_model = \"distilbert/distilgpt2\"\n","finetuned_modelname = \"distilgpt2-spock\"\n","huggingface_username = \"omgbobbyg\"\n","huggingface_reponame = f\"{huggingface_username}/{finetuned_modelname}\" "]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["#setup parameters regarding GPU availibility on the machine and recycle used memory\n","import torch; \n","import gc;\n","\n","is_gpu_available = torch.cuda.is_available()\n","device = 'cuda' if is_gpu_available else 'cpu'\n","if is_gpu_available:\n"," print(\"GPU available for notebook\")\n"," torch.cuda.empty_cache()\n"," print(\"GPU Memory cleaned\")\n","else:\n"," print(\"No GPU available for notebook\")\n"," \n","gc.collect()\n"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["#setup our logging level and file\n","import logging \n","logger = logging.getLogger(__name__)\n","logger.setLevel(logging.INFO)\n","\n","# Configure the logger if it hasn't been configured before\n","if not logger.handlers:\n"," handler = logging.FileHandler('training.log')\n"," handler.setLevel(logging.INFO)\n"," formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')\n"," handler.setFormatter(formatter)\n"," logger.addHandler(handler)\n","\n"]},{"attachments":{},"cell_type":"markdown","metadata":{},"source":["## Dataset"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["dataset = load_dataset(dataset_path)\n","print(dataset)"]},{"cell_type":"markdown","metadata":{},"source":[" ## Tokenization"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["from transformers import AutoTokenizer, GPT2LMHeadModel, AutoConfig, AutoModelForCausalLM\n","\n","model = AutoModelForCausalLM.from_pretrained(pretrained_model)\n","# Get the maximum context size\n","max_length = model.config.max_position_embeddings\n","print(f\"Maximum context size: {max_length}\")\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["from transformers import GPT2Tokenizer\n","tokenizer = AutoTokenizer.from_pretrained(pretrained_model)\n","\n","def tokenize_function(examples):\n"," return tokenizer(examples[\"dialogue\"],max_length=max_length)\n","\n","\n","# Apply the tokenization function to the entire dataset\n","tokenized_dataset = dataset.map(\n"," tokenize_function,\n"," batched=True,\n"," batch_size=10,\n"," remove_columns=dataset[\"train\"].column_names\n",")\n","\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["sample_token = tokenizer.encode(\"Live long and prosper.\")\n","print(sample_token)"]},{"attachments":{},"cell_type":"markdown","metadata":{},"source":["## Data Collator"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["#We need to create data collator to manage the batches, we can use DataCollatorForLanguageModeling\n","from transformers import DataCollatorForLanguageModeling\n","tokenizer.pad_token =\"<pad>\" #tokenizer.eos_token\n","data_collator = DataCollatorForLanguageModeling(tokenizer,mlm=False)\n","# Iterate over the generator\n","out = 
## Setup the Trainer

```python
# Now we train the model using the Trainer API
from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    finetuned_modelname,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=3,
    weight_decay=0.01,
    fp16=is_gpu_available,
    push_to_hub=True,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    hub_model_id=huggingface_reponame
)

trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"]
)
```

## Evaluate the Performance of the Base Model

```python
import math

# Calculate and report perplexity: the exponential of the evaluation loss
initial_results = trainer.evaluate()
print(initial_results)
# Log the results to file
logger.info(f"Baseline {pretrained_model} Results: Perplexity: {math.exp(initial_results['eval_loss']):.2f}")
print(f"Baseline {pretrained_model} Results: Perplexity: {math.exp(initial_results['eval_loss']):.2f}")
```

```python
# Set up our test prompts
test_prompt = "What is the meaning of life?"
test_prompt2 = "Where did that planet go??"
test_prompt3 = "What is the best way to cook a turkey?"
```

```python
# Use the model in a pipeline to generate text
from transformers import pipeline

text_generator = pipeline("text-generation", model=model, tokenizer=tokenizer, device=device)

# do_sample=True is needed for temperature to have any effect
result = text_generator(test_prompt, max_length=100, num_return_sequences=1, do_sample=True, temperature=1.0)
print(f"Baseline {pretrained_model} generated result: {test_prompt}...{result[0]['generated_text']}")
logger.info(f"Baseline {pretrained_model} generated result: {test_prompt}...{result[0]['generated_text']}")
```

```python
result = text_generator(test_prompt2, max_length=100, num_return_sequences=1, do_sample=True, temperature=1.0)
print(f"Baseline {pretrained_model} generated result: {test_prompt2}...{result[0]['generated_text']}")
logger.info(f"Baseline {pretrained_model} generated result: {test_prompt2}...{result[0]['generated_text']}")
```

```python
result = text_generator(test_prompt3, max_length=100, num_return_sequences=1, do_sample=True, temperature=1.0)
print(f"Baseline {pretrained_model} generated result: {test_prompt3}...{result[0]['generated_text']}")
logger.info(f"Baseline {pretrained_model} generated result: {test_prompt3}...{result[0]['generated_text']}")
```

## Fine-Tune the Model

```python
trainer.train()
```
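Because `push_to_hub=True` was set in the training arguments, a checkpoint is uploaded at each save; an explicit final push also uploads the tokenizer and an auto-generated model card to the `hub_model_id` repo. A minimal sketch, assuming you are already authenticated (e.g. via `huggingface-cli login` or `notebook_login()`):

```python
# Upload the final weights, tokenizer, and model card to the Hub repo
# configured in hub_model_id above
trainer.push_to_hub(commit_message="Fine-tuned on Spock dialogue")
```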
## Evaluate the Performance of the Fine-Tuned Model

```python
# Calculate and report perplexity
eval_results = trainer.evaluate()
perplexity = math.exp(eval_results['eval_loss'])
eval_results['perplexity'] = perplexity

logger.info(f"Fine-tuned {finetuned_modelname} Results: Perplexity: {perplexity:.2f}")
print(f"Fine-tuned {finetuned_modelname} Results: Perplexity: {perplexity:.2f}")
```

```python
# Prompt Test 1
text_generator = pipeline("text-generation", model=model, tokenizer=tokenizer, device=device)
result = text_generator(test_prompt, max_length=100, num_return_sequences=1, do_sample=True, temperature=0.7)
print(f"Fine-tuned {finetuned_modelname} generated result: {test_prompt}...{result[0]['generated_text']}")
logger.info(f"Fine-tuned {finetuned_modelname} generated result: {test_prompt}...{result[0]['generated_text']}")
```

```python
# Prompt Test 2
result = text_generator(test_prompt2, max_length=100, num_return_sequences=1, do_sample=True, temperature=1.0)
print(f"Fine-tuned {finetuned_modelname} generated result: {test_prompt2}...{result[0]['generated_text']}")
logger.info(f"Fine-tuned {finetuned_modelname} generated result: {test_prompt2}...{result[0]['generated_text']}")
```

```python
# Prompt Test 3
result = text_generator(test_prompt3, max_length=100, num_return_sequences=1, do_sample=True, temperature=1.0)
print(f"Fine-tuned {finetuned_modelname} generated result: {test_prompt3}...{result[0]['generated_text']}")
logger.info(f"Fine-tuned {finetuned_modelname} generated result: {test_prompt3}...{result[0]['generated_text']}")
```
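Once pushed, the fine-tuned model can be loaded anywhere straight from the Hub by repo id, with no local checkpoint required. A minimal sketch using the `huggingface_reponame` value from the setup cell (assumes the repo is public or you are authenticated):

```python
from transformers import pipeline

# Load the fine-tuned model directly from the Hub by its repo id
spock_generator = pipeline("text-generation", model="omgbobbyg/distilgpt2-spock")
result = spock_generator("Live long and", max_length=50, do_sample=True, temperature=0.7)
print(result[0]["generated_text"])
```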