Last active
March 7, 2024 08:30
-
-
Save Norod/11997c0c9a330d0eeb9a6d4791b9aa2f to your computer and use it in GitHub Desktop.
colab-hebrewgpt-gradiodemo.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"private_outputs": true, | |
"provenance": [], | |
"gpuType": "T4", | |
"machine_shape": "hm", | |
"authorship_tag": "ABX9TyMi2WD/mcCXUjnpfzsrS88a", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
}, | |
"accelerator": "GPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/Norod/11997c0c9a330d0eeb9a6d4791b9aa2f/colab-hebrewgpt-gradiodemo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"!pip install transformers tokenizers deepspeed xformers bitsandbytes accelerate gradio huggingface_hub sentencepiece" | |
], | |
"metadata": { | |
"id": "8BKOA-W_Zndy" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "cIBivOgKZY6h", | |
"cellView": "form" | |
}, | |
"outputs": [], | |
"source": [ | |
"\n", | |
"import deepspeed\n", | |
"import torch\n", | |
"from transformers import pipeline\n", | |
"import os\n", | |
"\n", | |
"model_id = 'dicta-il/dictalm-7b-instruct' #@param ['dicta-il/dictalm-7b-instruct','dicta-il/dictalm-7b','dicta-il/dictalm-rab-7b','Norod78/hebrew-gpt_neo-xl']\n", | |
"\n", | |
"text_title = model_id.replace(\"/\", \" - \") + ' - Gradio Demo'\n", | |
"\n", | |
"should_use_fast = True\n", | |
"print(f'should_use_fast = {should_use_fast}')\n", | |
"\n", | |
"local_rank = int(os.getenv('LOCAL_RANK', '0'))\n", | |
"world_size = int(os.getenv('WORLD_SIZE', '1'))\n", | |
"generator = pipeline('text-generation', model=model_id,\n", | |
" tokenizer=model_id,\n", | |
" torch_dtype = torch.float16,\n", | |
" use_fast=should_use_fast,\n", | |
" trust_remote_code=True, #Because of configuration_megatron_gpt.py in dicta-il's repo\n", | |
" device_map = \"auto\")\n", | |
"\n", | |
"# setting device on GPU if available, else CPU\n", | |
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", | |
"print('Using device:', device)\n", | |
"print()\n", | |
"\n", | |
"total_mem = 0\n", | |
"\n", | |
"#Additional Info when using cuda\n", | |
"if device.type == 'cuda':\n", | |
" print(torch.cuda.get_device_name(0))\n", | |
" print('Memory Usage:')\n", | |
" total_mem = round(torch.cuda.get_device_properties(0).total_memory/1024**3,1)\n", | |
" print('Total: ', total_mem, 'GB')\n", | |
" print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')\n", | |
" print('Cached: ', round(torch.cuda.memory_cached(0)/1024**3,1), 'GB')\n", | |
"\n", | |
"should_replace_with_kernel_inject = False\n", | |
"if (total_mem >= 12):\n", | |
" should_replace_with_kernel_inject = True\n", | |
"\n", | |
"print(f'should_replace_with_kernel_inject = {should_replace_with_kernel_inject}')\n", | |
"\n", | |
"ds_engine = deepspeed.init_inference(generator.model,\n", | |
" mp_size=world_size,\n", | |
" dtype=torch.half,\n", | |
" replace_with_kernel_inject=should_replace_with_kernel_inject)\n", | |
"generator.model = ds_engine.module" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"#Sanity, Just make sure we can generate\n", | |
"init_text = \"מהו מודל שפה?\"\n", | |
"\n", | |
"string = generator(init_text, do_sample=True, min_length=20, max_length=64, top_k=40, top_p=0.92, temperature=0.9)\n", | |
"if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:\n", | |
" print(string)" | |
], | |
"metadata": { | |
"id": "DWXB0Il-jQaB" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"top_k = 40 #@param int\n", | |
"top_p = 0.92 #@param float\n", | |
"temperature = 0.75 #@param float" | |
], | |
"metadata": { | |
"id": "bAxYXg3QcDtM" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import gradio as gr\n", | |
"from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer\n", | |
"from threading import Thread\n", | |
"import torch\n", | |
"\n", | |
"model = generator.model\n", | |
"tok = generator.tokenizer\n", | |
"\n", | |
"CUDA_AVAILABLE = torch.cuda.is_available()\n", | |
"device = torch.device(\"cuda\" if CUDA_AVAILABLE else \"cpu\")\n", | |
"\n", | |
"def generate(text = \"\"):\n", | |
" print(\"Create streamer\")\n", | |
" yield \"[אנא המתינו לתשובה]\"\n", | |
" streamer = TextIteratorStreamer(tok, timeout=5.)\n", | |
" if len(text) == 0 or 'instruct' in model_id:\n", | |
" text = text + \"\\n\"\n", | |
"\n", | |
" inputs = tok([text], return_tensors=\"pt\").to(device)\n", | |
" generation_kwargs = dict(inputs, streamer=streamer, do_sample=True, top_k=top_k, top_p=top_p, temperature=temperature, num_beams = 1 ,max_new_tokens=256, pad_token_id = model.config.eos_token_id, early_stopping=True, no_repeat_ngram_size=4)\n", | |
" thread = Thread(target=model.generate, kwargs=generation_kwargs)\n", | |
" thread.start()\n", | |
" generated_text = \"\"\n", | |
" for new_text in streamer:\n", | |
" if tok.eos_token not in new_text:\n", | |
" yield generated_text + new_text\n", | |
" print(new_text, end =\"\")\n", | |
" generated_text += new_text\n", | |
" else:\n", | |
" new_text.replace(tok.eos_token, \"\\n\")\n", | |
" print(new_text, end =\"\")\n", | |
" generated_text += new_text\n", | |
" return generated_text\n", | |
" return generated_text\n", | |
"\n", | |
"demo = gr.Interface(\n", | |
" title=text_title,\n", | |
" fn=generate,\n", | |
" inputs=gr.Textbox(label=\"כתבו כאן את הטקסט שלכם או השאירו ריק\", elem_id=\"input_text\", text_align='right', rtl=True),\n", | |
" outputs=gr.Textbox(type=\"text\", label=\"פה יופיע הטקסט שהמחולל יחולל\", elem_id=\"output_text\", text_align='right', rtl=True),\n", | |
" css=\"#output_text{direction: rtl} #input_text{direction: rtl}\",\n", | |
" allow_flagging='never'\n", | |
")\n", | |
"\n", | |
"demo.queue()\n", | |
"demo.launch(debug=True)" | |
], | |
"metadata": { | |
"id": "FhSCepWx77Vu" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment