Skip to content

Instantly share code, notes, and snippets.

@mlconnor
Last active February 27, 2023 19:17
Show Gist options
  • Save mlconnor/4533390fb02e6f437e8bdc65c6fca851 to your computer and use it in GitHub Desktop.
Save mlconnor/4533390fb02e6f437e8bdc65c6fca851 to your computer and use it in GitHub Desktop.
This is a Python Jupyter notebook built for Amazon SageMaker notebooks. To launch, go to AWS Console | SageMaker | Notebooks | Create Notebook Instance (ml.g5.2xlarge, config volume size - 500G) | Open Jupyter Lab | Upload. Based on https://www.philschmid.de/fine-tune-flan-t5
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "820cc2a7-0080-415d-9768-9774f98ac20c",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# %pip (not !pip) installs into the kernel's own environment\n",
"%pip install -q huggingface_hub transformers accelerate bitsandbytes sentencepiece gradio"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "6dacb0ab-dfcf-4fac-904f-141a2e0889fb",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fri Feb 24 13:42:13 2023 \n",
"+-----------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 515.65.01 Driver Version: 515.65.01 CUDA Version: 11.7 |\n",
"|-------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
"| | | MIG M. |\n",
"|===============================+======================+======================|\n",
"| 0 NVIDIA A10G On | 00000000:00:1E.0 Off | 0 |\n",
"| 0% 23C P8 22W / 300W | 0MiB / 23028MiB | 0% Default |\n",
"| | | N/A |\n",
"+-------------------------------+----------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=============================================================================|\n",
"| No running processes found |\n",
"+-----------------------------------------------------------------------------+\n"
]
}
],
"source": [
"!nvidia-smi"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "f7dc53c3-754d-465f-95b8-c9e9273ff08c",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"HF_MODEL_ID = \"philschmid/flan-t5-xxl-sharded-fp16\"\n",
"TMP_DIR = \"tmp\""
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "a1999ea4-71b0-418d-98dc-0e353b4854c1",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/json": {
"ascii": false,
"bar_format": null,
"colour": null,
"elapsed": 0.006432056427001953,
"initial": 0,
"n": 0,
"ncols": null,
"nrows": null,
"postfix": null,
"prefix": "Fetching 23 files",
"rate": null,
"total": 23,
"unit": "it",
"unit_divisor": 1000,
"unit_scale": false
},
"application/vnd.jupyter.widget-view+json": {
"model_id": "c8b77f822d364f47b95d2a6c9180feeb",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Fetching 23 files: 0%| | 0/23 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"['flan-t5-xxl-sharded-fp16/pytorch_model.bin.index.json',\n",
" 'flan-t5-xxl-sharded-fp16/pytorch_model-00011-of-00012.bin',\n",
" 'flan-t5-xxl-sharded-fp16/tokenizer_config.json',\n",
" 'flan-t5-xxl-sharded-fp16/tokenizer.json',\n",
" 'flan-t5-xxl-sharded-fp16/pytorch_model-00009-of-00012.bin',\n",
" 'flan-t5-xxl-sharded-fp16/config.json',\n",
" 'flan-t5-xxl-sharded-fp16/.gitattributes',\n",
" 'flan-t5-xxl-sharded-fp16/pytorch_model-00001-of-00012.bin',\n",
" 'flan-t5-xxl-sharded-fp16/pytorch_model-00006-of-00012.bin',\n",
" 'flan-t5-xxl-sharded-fp16/pytorch_model-00010-of-00012.bin',\n",
" 'flan-t5-xxl-sharded-fp16/pytorch_model-00007-of-00012.bin',\n",
" 'flan-t5-xxl-sharded-fp16/pytorch_model-00008-of-00012.bin',\n",
" 'flan-t5-xxl-sharded-fp16/README.md',\n",
" 'flan-t5-xxl-sharded-fp16/createEndpoint.png',\n",
" 'flan-t5-xxl-sharded-fp16/pytorch_model-00012-of-00012.bin',\n",
" 'flan-t5-xxl-sharded-fp16/special_tokens_map.json',\n",
" 'flan-t5-xxl-sharded-fp16/pytorch_model-00004-of-00012.bin',\n",
" 'flan-t5-xxl-sharded-fp16/pytorch_model-00002-of-00012.bin',\n",
" 'flan-t5-xxl-sharded-fp16/requirements.txt',\n",
" 'flan-t5-xxl-sharded-fp16/pytorch_model-00005-of-00012.bin',\n",
" 'flan-t5-xxl-sharded-fp16/pytorch_model-00003-of-00012.bin',\n",
" 'flan-t5-xxl-sharded-fp16/handler.py',\n",
" 'flan-t5-xxl-sharded-fp16/spiece.model']"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# distutils is deprecated and removed in Python 3.12; shutil.copytree with\n",
"# dirs_exist_ok=True (Python 3.8+) replaces distutils.dir_util.copy_tree.\n",
"import os\n",
"import shutil\n",
"\n",
"from huggingface_hub import snapshot_download\n",
"\n",
"# exist_ok=True avoids the check-then-create race of os.path.exists + makedirs\n",
"os.makedirs(TMP_DIR, exist_ok=True)\n",
"\n",
"# Local directory named after the repo, e.g. 'flan-t5-xxl-sharded-fp16'\n",
"model_dir = HF_MODEL_ID.split(\"/\")[-1]\n",
"os.makedirs(model_dir, exist_ok=True)\n",
"\n",
"# Download every file in the repo into TMP_DIR, then copy the snapshot\n",
"# into the stable model_dir so later cells can load from a fixed path.\n",
"snapshot_dir = snapshot_download(repo_id=HF_MODEL_ID, cache_dir=TMP_DIR)\n",
"shutil.copytree(snapshot_dir, str(model_dir), dirs_exist_ok=True)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "85d7b100-ac26-4968-840c-2e8d8a214b9a",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"===================================BUG REPORT===================================\n",
"Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n",
"================================================================================\n"
]
},
{
"data": {
"application/json": {
"ascii": false,
"bar_format": null,
"colour": null,
"elapsed": 0.0073816776275634766,
"initial": 0,
"n": 0,
"ncols": null,
"nrows": null,
"postfix": null,
"prefix": "Loading checkpoint shards",
"rate": null,
"total": 12,
"unit": "it",
"unit_divisor": 1000,
"unit_scale": false
},
"application/vnd.jupyter.widget-view+json": {
"model_id": "cb61f74a88234c02ab20c7bad1a090d8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/12 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from transformers import AutoModelForSeq2SeqLM, AutoTokenizer\n",
"import torch\n",
"\n",
"# NOTE(review): this loads from the Hub ID (re-fetched into the HF cache),\n",
"# not the local copy made by the download cell above — point model_dir at\n",
"# that local directory to reuse it; confirm which source is intended.\n",
"model_dir = HF_MODEL_ID\n",
"\n",
"# load_in_8bit uses bitsandbytes int8 quantization so the sharded fp16\n",
"# checkpoint fits on a single A10G; device_map=\"auto\" places the shards.\n",
"model = AutoModelForSeq2SeqLM.from_pretrained(model_dir, device_map=\"auto\", load_in_8bit=True)\n",
"tokenizer = AutoTokenizer.from_pretrained(model_dir)"
]
},
{
"cell_type": "code",
"execution_count": 62,
"id": "89f5f683-5ead-49a8-8940-ea27ed736bb6",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def generate(input_text, max_new_tokens=100):\n",
"    \"\"\"Tokenize a prompt, run generation, and decode the result.\n",
"\n",
"    Args:\n",
"        input_text: prompt string for the model.\n",
"        max_new_tokens: cap on generated tokens (default 100, matching\n",
"            the previous hard-coded value).\n",
"\n",
"    Returns:\n",
"        The decoded generation with special tokens stripped.\n",
"    \"\"\"\n",
"    input_ids = tokenizer(input_text, return_tensors=\"pt\").input_ids.to('cuda')\n",
"    output = model.generate(input_ids, max_new_tokens=max_new_tokens)\n",
"    return tokenizer.decode(output[0], skip_special_tokens=True)"
]
},
{
"cell_type": "code",
"execution_count": 63,
"id": "1a34ed55-0e3e-490c-87ae-5de6344c55b1",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"'Wie alt sind Sie?'"
]
},
"execution_count": 63,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"input_text = \"\"\"translate English to German: How old are you?\"\"\"\n",
"generate(input_text)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "f3348868-7b56-4431-bb7e-94a714f76353",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"'212 degrees'"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"input_text = \"\"\"Please answer the following question. \n",
"What is the boiling point of water in Fahrenheit?\"\"\"\n",
"generate(input_text)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "0c1943be-2a9b-4708-828a-475a98b43fc1",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"'The cafeteria has 23 - 20 = 3 apples. They have 3 + 6 = 9 apples. Therefore, the answer is 9.'"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"input_text = \"\"\"\n",
"Answer the following question by reasoning step by step.\n",
"The cafeteria had 23 apples. If they used 20 for lunch and bought 6 more, how many apples do they have?\n",
"\"\"\"\n",
"generate(input_text)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "0c7c226c-29a4-4586-9524-533785e2be54",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"'George Washington died in 1799. Obama was born in 1961. The final answer: no.'"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"input_text = \"\"\"\n",
"Q: Can Obama have a conversation with George Washington?\n",
"Give the rationale before answering\n",
"\"\"\"\n",
"generate(input_text)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "3d39b5d3-ec24-4992-ae40-8d4c88709d63",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"'Monkeys are a type of animal. Monkeys are a type of mammal. Mammals can write. The answer is yes.'"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"input_text = \"\"\"\n",
"Q: Answer the following yes/no question by\n",
"reasoning step-by-step.\n",
"Could a dandelion suffer from hepatitis?\n",
"A: Hepatitis only affects organisms with livers.\n",
"Dandelions don’t have a liver. The answer is no.\n",
"Q: Answer the following yes/no question by\n",
"reasoning step-by-step.\n",
"Can a monkey write a post?\n",
"A: \n",
"\"\"\"\n",
"\n",
"generate(input_text)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "e73d1085-2c55-4b26-8a39-49841a20e494",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"'yes'"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"input_text = \"\"\"\n",
"\"Premise: Alberto is the CTO of a top NLP company. Hypothesis: Alberto is a tech expert. Does the premise entail the hypothesis?\"\n",
"\"\"\"\n",
"\n",
"generate(input_text)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "95604c2a-def7-4ec6-9e1f-096a7a1d879a",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"'A tweet is 140 characters long. A haiku is a form of poetry that uses 5-7-5 syllables. The answer is no.'"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"input_text = \"\"\"\n",
"Q: Answer the following yes/no question by\n",
"reasoning step-by-step.\n",
"Could a dandelion suffer from hepatitis?\n",
"A: Hepatitis only affects organisms with livers.\n",
"Dandelions don’t have a liver. The answer is no.\n",
"Q: Answer the following yes/no question by\n",
"reasoning step-by-step.\n",
"Can you write a whole Haiku in a single tweet?\n",
"A:\"\"\"\n",
"generate(input_text)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "3a41a3ec-eb87-42d1-a5d5-5dd8a62a31c5",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"'The Eiffel Tower is a wrought iron tower in the centre of Paris, France.'"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"input_text = \"\"\"\n",
"Summarize the following text: The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side. During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930. It was the first structure to reach a height of 300 metres. Due to the addition of a broadcasting aerial at the top of the tower in 1957, it is now taller than the Chrysler Building by 5.2 metres (17 ft). Excluding transmitters, the Eiffel Tower is the second tallest free-standing structure in France after the Millau Viaduct.\n",
"\"\"\"\n",
"generate(input_text)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "3547fdc9-7fcb-4987-b652-1469b0b93377",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"'How tall is the Eiffel Tower?'"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"input_text = \"\"\"\n",
"Generate a question for the following text: The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side. During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930. It was the first structure to reach a height of 300 metres. Due to the addition of a broadcasting aerial at the top of the tower in 1957, it is now taller than the Chrysler Building by 5.2 metres (17 ft). Excluding transmitters, the Eiffel Tower is the second tallest free-standing structure in France after the Millau Viaduct.\n",
"\"\"\"\n",
"generate(input_text)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "1950b448-0548-4393-ab28-49a7da7af51f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'iron'"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"input_text = \"\"\"\n",
"The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side. During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930. It was the first structure to reach a height of 300 metres. Due to the addition of a broadcasting aerial at the top of the tower in 1957, it is now taller than the Chrysler Building by 5.2 metres (17 ft). Excluding transmitters, the Eiffel Tower is the second tallest free-standing structure in France after the Millau Viaduct.\n",
"Q: Which material is the tower made of?\n",
"\"\"\"\n",
"generate(input_text)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "eb0d5103-9eeb-493a-b5d7-a4cd94b327a5",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"\"Politicians lie all the time, so it's not crazy to say you've never seen one. The answer: (A).\""
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"input_text = \"\"\"\n",
"Q: Which statement is sarcastic?\n",
"Options:\n",
"(A) Wow thats crazy, I've never seen a politician lying!\n",
"(B) Wow thats crazy, I've never seen Obama lying!\n",
"A: Let's think step by step.\n",
"\"\"\"\n",
"generate(input_text)\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "2b6d63ac-e118-4cc3-a934-a0f805dff29a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'yes'"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"input_text = \"\"\"\n",
"Movie review: This movie is the best\n",
"RomCom since Pretty Woman.\n",
"Did this critic like the movie?\n",
"OPTIONS\n",
"- yes\n",
"- no\"\"\"\n",
"generate(input_text)"
]
},
{
"cell_type": "code",
"execution_count": 64,
"id": "97255744-141c-49d2-847e-7952db384d4d",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"\"Understand the game's mechanics.\""
]
},
"execution_count": 64,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"input_text = \"\"\"Summarize the following text: The Elder Scrolls V Skyrim is an action role-playing game, playable from either a first or third-person perspective. The player may freely roam over the land of Skyrim which is an open world environment consisting of wilderness expanses, dungeons, cities, towns, fortresses, and villages. Players may navigate the game world more quickly by riding horses or by utilizing a fast-travel system which allows them to warp to previously discovered locations. The game's main quest can be completed or ignored at the player's preference after the first stage of the quest is finished. However, some quests rely on the main storyline being at least partially completed. Non-player characters (NPCs) populate the world and can be interacted with in a number of ways: the player may engage them in conversation, marry an eligible NPC, kill them or engage in a nonlethal \"brawl\".\"\"\"\n",
"generate(input_text)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c67bcb0f-2634-4655-8d61-39f8b313c532",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "e0a0c3ae-a5cc-48ef-a2e6-05e7a1e03c4b",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "conda_pytorch_p39",
"language": "python",
"name": "conda_pytorch_p39"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment