@Norod
Created April 24, 2024 17:08
Apple_OpenELM-270M_cpu_Gradio-Demo.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyNTwrv0H/WIAsYrXznTst4v",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/Norod/5a311a8e0a774b5c35919913545b7af4/apple_openelm-270m_cpu_gradio-demo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "RTmA16d-1eQV"
},
"outputs": [],
"source": [
"!pip install -q gradio"
]
},
{
"cell_type": "markdown",
"source": [
"Make sure to set the HF_TOKEN notebook secret and allow access to a read token"
],
"metadata": {
"id": "MZt9Q8sKCNQK"
}
},
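{
"cell_type": "code",
"source": [
"# Optional: log in to Hugging Face explicitly using the HF_TOKEN secret.\n",
"# Minimal sketch, assuming this runs in Colab with a secret named HF_TOKEN;\n",
"# recent Colab/huggingface_hub versions can also pick the secret up automatically.\n",
"from google.colab import userdata\n",
"from huggingface_hub import login\n",
"\n",
"login(token=userdata.get(\"HF_TOKEN\"))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},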
{
"cell_type": "code",
"source": [
"import torch\n",
"import gradio as gr\n",
"from threading import Thread\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32\n",
"low_cpu_mem_usage = True if torch.cuda.is_available() else False\n",
"\n",
"checkpoint = \"apple/OpenELM-270M\"\n",
"checkpoint_tok = \"meta-llama/Llama-2-7b-hf\"\n",
"tokenizer = AutoTokenizer.from_pretrained(checkpoint_tok)\n",
"model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch_dtype, trust_remote_code=True, low_cpu_mem_usage=low_cpu_mem_usage)\n",
"model.to(device)\n",
"\n",
"text_title = checkpoint.replace(\"/\", \" - \") + ' (' + str(model.device) + ') - Gradio Demo'\n"
],
"metadata": {
"id": "frvAniQU1iWm"
},
"execution_count": 2,
"outputs": []
},
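{
"cell_type": "code",
"source": [
"# Quick sanity check (a small sketch, nothing assumed beyond the `model` loaded above):\n",
"# report the parameter count, which should be roughly 270M for apple/OpenELM-270M.\n",
"num_params = sum(p.numel() for p in model.parameters())\n",
"print(f\"{num_params / 1e6:.1f}M parameters, dtype={next(model.parameters()).dtype}, device={model.device}\")"
],
"metadata": {},
"execution_count": null,
"outputs": []
},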
{
"cell_type": "code",
"source": [
"tokenizer"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "JI-NZMTxAstc",
"outputId": "c9be4c5f-bb9f-4da9-dcf0-947c3ef14962"
},
"execution_count": 7,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"LlamaTokenizerFast(name_or_path='meta-llama/Llama-2-7b-hf', vocab_size=32000, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '</s>'}, clean_up_tokenization_spaces=False), added_tokens_decoder={\n",
"\t0: AddedToken(\"<unk>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
"\t1: AddedToken(\"<s>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
"\t2: AddedToken(\"</s>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
"}"
]
},
"metadata": {},
"execution_count": 7
}
]
},
{
"cell_type": "code",
"source": [
"if tokenizer.pad_token == None:\n",
" tokenizer.pad_token = tokenizer.eos_token\n",
" tokenizer.pad_token_id = tokenizer.eos_token_id\n"
],
"metadata": {
"id": "zsYtkzTC58sM"
},
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#from transformers import set_seed\n",
"#set_seed(42)\n",
"#Make sure we can generate text\n",
"text = \"Hello my name is Doron and I am\"\n",
"inputs = tokenizer([text], return_tensors = \"pt\").input_ids.to(model.device)\n",
"pred_ids = model.generate(input_ids=inputs, do_sample=True, max_new_tokens=32, repetition_penalty=1.2)\n",
"#print(pred_ids)\n",
"pred_text = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)\n",
"print(pred_text[0])\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "GACVDdQ_6lIl",
"outputId": "08f13f12-879b-48ba-9fe4-de9c58949a11"
},
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
"Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Hello my name is Doron and I am an expert in teaching computer graphics. There are many different classes of people who study graphic arts, I have been doing many different types of classes since 1\n"
]
}
]
},
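{
"cell_type": "code",
"source": [
"# The warning above is emitted because only `input_ids` was passed to `generate`.\n",
"# Minimal sketch (same prompt, standard transformers arguments): pass the attention\n",
"# mask and an explicit pad_token_id to silence it.\n",
"enc = tokenizer([text], return_tensors=\"pt\").to(model.device)\n",
"pred_ids = model.generate(input_ids=enc.input_ids, attention_mask=enc.attention_mask,\n",
"                          pad_token_id=tokenizer.pad_token_id, do_sample=True,\n",
"                          max_new_tokens=32, repetition_penalty=1.2)\n",
"print(tokenizer.batch_decode(pred_ids, skip_special_tokens=True)[0])"
],
"metadata": {},
"execution_count": null,
"outputs": []
},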
{
"cell_type": "code",
"source": [
"\n",
"########################################################################\n",
"# Settings\n",
"########################################################################\n",
"\n",
"#Set the maximum number of tokens to generate\n",
"max_new_tokens = 250\n",
"\n",
"#Set a the value of the repetition penalty\n",
"#The higher the value, the less repetitive the generated text will be\n",
"#Note that `repetition_penalty` has to be a strictly positive float\n",
"repetition_penalty = 1.4\n",
"\n",
"#Set the text direction\n",
"#For languages that are written from right to left (RTL), set rtl to True\n",
"rtl = False\n",
"\n",
"########################################################################\n",
"\n",
"print(f\"Settings: max_new_tokens = {max_new_tokens}, repetition_penalty = {repetition_penalty}, rtl = {rtl}\")\n",
"\n",
"if rtl:\n",
" text_title += \" - RTL\"\n",
" text_align = 'right'\n",
" css = \"#output_text{direction: rtl} #input_text{direction: rtl}\"\n",
"else:\n",
" text_align = 'left'\n",
" css = \"\"\n",
"\n",
"\n",
"def generate(text = \"\"):\n",
" print(\"Create streamer\")\n",
" yield \"[Please wait for an answer]\"\n",
"\n",
" decode_kwargs = dict(skip_special_tokens = True, clean_up_tokenization_spaces = True)\n",
" streamer = TextIteratorStreamer(tokenizer, timeout = 5., decode_kwargs = decode_kwargs)\n",
"\n",
" inputs = tokenizer([text], return_tensors = \"pt\").input_ids.to(model.device)\n",
" print(tokenizer.decode(inputs[0], skip_special_tokens=True))\n",
"\n",
" generation_kwargs = dict(input_ids=inputs, streamer = streamer, max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty)\n",
"\n",
" print(\"Create thread\")\n",
" thread = Thread(target = model.generate, kwargs = generation_kwargs)\n",
" thread.start()\n",
" generated_text = \"\"\n",
" for new_text in streamer:\n",
" if new_text == None:\n",
" continue\n",
" if tokenizer.eos_token not in new_text:\n",
" new_text = new_text.replace(tokenizer.pad_token, \"\").replace(tokenizer.bos_token, \"\")\n",
" yield generated_text + new_text\n",
" print(new_text, end =\"\")\n",
" generated_text += new_text\n",
" else:\n",
" new_text = new_text.replace(tokenizer.eos_token, \"\\n\")\n",
" print(new_text, end =\"\")\n",
" generated_text += new_text\n",
" return generated_text\n",
" return generated_text\n",
"\n",
"demo = gr.Interface(\n",
" title = text_title,\n",
" fn = generate,\n",
" inputs = gr.Textbox(label = \"Enter your prompt here\", elem_id = \"input_text\", text_align = text_align, rtl = rtl),\n",
" outputs = gr.Textbox(type = \"text\", label = \"Generated text will appear here\", elem_id = \"output_text\", text_align = text_align, rtl = rtl),\n",
" css = css,\n",
" allow_flagging = 'never'\n",
")\n",
"\n",
"demo.queue()\n",
"demo.launch(debug = True)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 906
},
"id": "Us-cb5OY4qC9",
"outputId": "c28c3e4f-9e10-4c39-c715-1d77948bcd31"
},
"execution_count": 15,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Settings: max_new_tokens = 250, repetition_penalty = 1.4, rtl = False\n",
"Setting queue=True in a Colab notebook requires sharing enabled. Setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).\n",
"\n",
"Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().\n",
"Running on public URL: https://50b66dcd3a0b393f74.gradio.live\n",
"\n",
"This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<div><iframe src=\"https://50b66dcd3a0b393f74.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
"Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Create streamer\n",
"Question: I'm tired, what should I do?\n",
"Answer:\n",
"Create thread\n",
" Question: I'm tired, what should I do?\n",
"Answer: You are not alone. There is a lot of information out there about how to deal with stress and anxiety. But the truth is that you can't control your own life. The only thing you can control is how you respond to it. If you feel like you have no choice but to be stressed, then you need to take action. Here are some things you can do to help yourself cope with stress.\n",
"1. Take time to relax. When you get home from work, go for a walk or run around the block. This will give you a chance to clear your mind and recharge. It also helps to reduce your stress levels.\n",
"2. Get outside. Go for a walk in nature. Even if you don't live near a park, going for a walk in nature can help you de-stress.\n",
"3. Eat healthy. A diet rich in fruits and vegetables can help you manage stress. Try eating more fruit and vegetables as well as whole grains. These foods contain fiber which helps you keep your digestive system working properly.\n",
"4. Exercise. Doing physical activity releases endorphins which make you feel better.Keyboard interruption in main thread... closing server.\n",
"Killing tunnel 127.0.0.1:7860 <> https://50b66dcd3a0b393f74.gradio.live\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": []
},
"metadata": {},
"execution_count": 15
}
]
}
]
}