{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# StarCoder with the Hugging Face Text Generation Inference Server\n",
"\n",
"This notebook shows you how to write a client that uses StarCoder with the \n",
"Hugging Face [Text Generation Inference Server](https://github.com/huggingface/text-generation-inference). \n",
"There is nothing StarCoder-specific here. This code should work with any\n",
"served model."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/arjun/venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"from text_generation import Client\n",
"# This will help with \"batching\".\n",
"from concurrent.futures.thread import ThreadPoolExecutor\n",
"from tqdm import tqdm"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Set the URL of the server below."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"URL = \"http://starcoder.tenant-85080e-nuprl.coreweave.cloud\"\n",
"client = Client(URL)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"This is basic code to generate from a single prompt. The `print_by_line`\n",
"function shows output line-by-line (I don't think we can do token-by-token,\n",
"in ChatGPT style in a notebook.) This function is not necessary."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def print_by_line(previous_text: str, new_text: str):\n",
" \"\"\"\n",
" A little hack to print line-by-line in a Notebook. We receive results\n",
" a few tokens at a time. This buffers output until a newline, so that\n",
" we do not print partial lines.\n",
" \"\"\"\n",
" if \"\\n\" not in new_text:\n",
" return\n",
" last_newline = previous_text.rfind(\"\\n\")\n",
" if last_newline != -1:\n",
" print(previous_text[last_newline+1:] + new_text, end=\"\")\n",
" else:\n",
" print(previous_text + new_text, end=\"\")\n",
"\n",
"DEFAULT_STOP_SEQUENCES = [ \"\\ndef\", \"\\nclass\", \"\\nif\" ]\n",
"\n",
"def generate(prompt: str,\n",
" max_new_tokens=512,\n",
" stop_sequences=DEFAULT_STOP_SEQUENCES,\n",
" echo=True):\n",
" text = \"\"\n",
" for response in client.generate_stream(prompt,\n",
" max_new_tokens=max_new_tokens,\n",
" temperature=0.2,\n",
" top_p=0.95,\n",
" stop_sequences=stop_sequences):\n",
" if not response.token.special:\n",
" if echo:\n",
" print_by_line(text, response.token.text)\n",
" text += response.token.text\n",
" if echo:\n",
" print_by_line(text, \"\\n\") # flush any remaining text\n",
" return text"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"The `generate` function blocks until it receives a completion. If you have\n",
"a batch of several prompts, this may grossly underutilize your GPU. The\n",
"`text-generation-inference` server does not natively support batch input.\n",
"But, it does create batches behind the scenes to efficiently process \n",
"concurrent requests. So, we can achieve the same effect as batching on a\n",
"single client by issuing several concurrent requests."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def generate_concurrent(\n",
" prompts: list,\n",
" max_new_tokens: int,\n",
" max_workers: int,\n",
" stop_sequences=DEFAULT_STOP_SEQUENCES):\n",
" with ThreadPoolExecutor(max_workers=max_workers) as ctx:\n",
" return list(tqdm(\n",
" ctx.map(lambda prompt: generate(prompt, max_new_tokens, stop_sequences, echo=False), prompts),\n",
" desc=\"Completions\",\n",
" total=len(prompts)))"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Examples\n",
"\n",
"### Generating a Web Page"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"<head>\n",
"<title>Search Engine</title>\n",
"<meta name=\"description\" content=\"A search engine for the web\">\n",
"<meta name=\"keywords\" content=\"search engine, web search, google, yahoo, bing\">\n",
"<meta name=\"author\" content=\"<NAME>\">\n",
"<meta name=\"viewport\" content=\"width=device-width, initial-scale=1.0\">\n",
"<link rel=\"stylesheet\" href=\"style.css\">\n",
"</head>\n",
"<body>\n",
"<div id=\"container\">\n",
"<h1>Search Engine</h1>\n",
"<form action=\"search.php\" method=\"get\">\n",
"<input type=\"text\" name=\"q\" placeholder=\"Search...\">\n",
"<input type=\"submit\" value=\"Search\">\n",
"</form>\n",
"<p>This is a search engine for the web. It is a simple PHP script that uses the Google, Yahoo and Bing search engines to search the web. It is a simple search engine that can be used for educational purposes.</p>\n",
"<p>This search engine is not affiliated with Google, Yahoo or Bing. It is a simple search engine that can be used for educational purposes.</p>\n",
"<p>This search engine is not affiliated with Google, Yahoo or Bing. It is a simple search engine that can be used for educational purposes.</p>\n",
"<p>This search engine is not affiliated with Google, Yahoo or Bing. It is a simple search engine that can be used for educational purposes.</p>\n",
"<p>This search engine is not affiliated with Google, Yahoo or Bing. It is a simple search engine that can be used for educational purposes.</p>\n",
"<p>This search engine is not affiliated with Google, Yahoo or Bing. It is a simple search engine that can be used for educational purposes.</p>\n",
"<p>This search engine is not affiliated with Google, Yahoo or Bing. It is a simple search engine that can be used for educational purposes.</p>\n",
"<p>This search engine is not affiliated with Google, Yahoo or Bing. It is a simple search engine that can be used for educational purposes.</p>\n",
"<p>This search engine is not affiliated with Google, Yahoo or Bing. It is a simple search engine that can be used\n"
]
}
],
"source": [
"html_result = generate(\n",
" \"<html><!-- A HTML page for a search engine, with a simple text box in the center -->\",\n",
" stop_sequences=[\"</html>\"]) "
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Batching\n",
"\n",
"In the cell below, `fast_results` completes much faster than `slow_results`, because\n",
"it uses more concurrent workers. There will be no point exceeding the maximum batch size\n",
"supported by the server."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Completions: 100%|██████████| 20/20 [00:19<00:00, 1.02it/s]\n",
"Completions: 100%|██████████| 20/20 [00:56<00:00, 2.80s/it]\n"
]
}
],
"source": [
"fast_results = generate_concurrent([\"# Write factorial\", \"# Write fib\"] * 10,\n",
" max_new_tokens=512,\n",
" max_workers=20,\n",
" stop_sequences=[])\n",
"slow_results = generate_concurrent(\n",
" [\"# Write factorial\", \"# Write fib\"] * 10,\n",
" max_new_tokens=512,\n",
" max_workers=5,\n",
" stop_sequences=[])"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Chat Mode\n",
"\n",
"This is a hack. Chat will only run for about 1000 tokens. But, that can be tweaked by tweaking the CHAT_PROMPT below."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"import requests\n",
"CHAT_PROMPT = requests.get(\"https://gist.githubusercontent.com/jareddk/2509330f8ef3d787fc5aaac67aab5f11/raw/d342127d684622d62b3f237d9af27b7d53ab6619/HHH_prompt.txt\").text"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Rerun the cell below to reset the chat state. Note that generate should be configured to sample differently: we want a repetition penalty and probably a higher temperature."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"chat_state = CHAT_PROMPT + \"\\n\"\n",
"def send(message):\n",
" global chat_state\n",
" message_to_send = \"\\nHuman: \" + message + \"\\n\\nAssistant:\"\n",
" result = generate(chat_state + message_to_send, \n",
" max_new_tokens=256, \n",
" stop_sequences=[\"Human:\", \"-----\"])\n",
" if result.endswith(\"Human:\"):\n",
" result = result[:-len(\"Human:\")]\n",
" elif result.endswith(\"-----\"):\n",
" result = result[:-len(\"-----\")]\n",
" else:\n",
" print(\"<stopped early>\")\n",
" chat_state += message_to_send + result"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Sure, here’s one example:\n",
"\n",
"factorial :: Integer -> Integer\n",
"factorial 0 = 1\n",
"factorial n = n * factorial (n - 1)\n",
"\n",
"-----\n"
]
}
],
"source": [
"send(\"Please write the factorial function in Haskell.\")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Sure, here’s one example:\n",
"\n",
"from flask import Flask\n",
"app = Flask(__name__)\n",
"\n",
"@app.route(\"/\")\n",
"def hello():\n",
" return \"Hello World!\"\n",
"\n",
"if __name__ == \"__main__\":\n",
" app.run()\n",
"\n",
"-----\n"
]
}
],
"source": [
"send(\"Write a web server in Python.\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}