@myersguo
Created June 2, 2023 05:47
gpt-practise.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyOyPyhRfyrFgDYowwuTYpwQ",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/myersguo/1b93310cfce70de695e4c99184236802/gpt-practise.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"1. 安装 Transformers 和 PyTorch。我们需要这两个库来加载和使用GPT-neo模型。\n",
"\n"
],
"metadata": {
"id": "4uwV8-9kl88Z"
}
},
{
"cell_type": "code",
"source": [
"pip install transformers torch"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "OYoTYCIQjINn",
"outputId": "da9726ea-3ab2-4db8-bbbc-6fd5b0118a98"
},
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.29.2)\n",
"Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.0.1+cu118)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.12.0)\n",
"Requirement already satisfied: huggingface-hub<1.0,>=0.14.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.15.1)\n",
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.22.4)\n",
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (23.1)\n",
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0)\n",
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2022.10.31)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.27.1)\n",
"Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.13.3)\n",
"Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.65.0)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch) (4.5.0)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.11.1)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.1)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.2)\n",
"Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch) (2.0.0)\n",
"Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch) (3.25.2)\n",
"Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch) (16.0.5)\n",
"Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.14.1->transformers) (2023.4.0)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (2.1.2)\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (1.26.15)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2022.12.7)\n",
"Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.12)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.4)\n",
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)\n"
]
}
]
},
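{
"cell_type": "markdown",
"source": [
"A reproducibility sketch, not part of the original steps: the install log above resolved transformers 4.29.2 and torch 2.0.1, so those versions could be pinned explicitly. The `+cu118` build tag is Colab-specific and omitted here."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"%pip install transformers==4.29.2 torch==2.0.1"
],
"metadata": {},
"execution_count": null,
"outputs": []
},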
{
"cell_type": "markdown",
"source": [
"2. 导入所需的库"
],
"metadata": {
"id": "_zldf-aRl8KZ"
}
},
{
"cell_type": "code",
"source": [
"import torch \n",
"from transformers import AutoTokenizer, AutoModelWithLMHead"
],
"metadata": {
"id": "FFydzHRajxLy"
},
"execution_count": 2,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"加载GPT-neO-1.3B的tokenizer和模型。我们使用`AutoTokenizer`和`AutoModelWithLMHead`来自动下载模型。"
],
"metadata": {
"id": "w6JBB1TFmE5h"
}
},
{
"cell_type": "code",
"source": [
"tokenizer = AutoTokenizer.from_pretrained(\"EleutherAI/gpt-neo-1.3B\")\n",
"model = AutoModelWithLMHead.from_pretrained(\"EleutherAI/gpt-neo-1.3B\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "9931UiMZjzxq",
"outputId": "fca38266-1f93-44bb-f1f4-cc95d59ada20"
},
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.10/dist-packages/transformers/models/auto/modeling_auto.py:1352: FutureWarning: The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use `AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and `AutoModelForSeq2SeqLM` for encoder-decoder models.\n",
" warnings.warn(\n"
]
}
]
},
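{
"cell_type": "markdown",
"source": [
"A minimal alternative sketch, not used in the rest of this notebook: the high-level `pipeline` API loads the tokenizer and model together as a single text-generation object."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Sketch: pipeline(\"text-generation\", ...) bundles tokenizer and model loading.\n",
"from transformers import pipeline\n",
"\n",
"generator = pipeline(\"text-generation\", model=\"EleutherAI/gpt-neo-1.3B\")\n",
"# generator(\"Hello, my name is\", max_length=20) returns a list of dicts\n",
"# with a \"generated_text\" key."
],
"metadata": {},
"execution_count": null,
"outputs": []
},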
{
"cell_type": "markdown",
"source": [
"为 prompt 添加输入文本。我们将输入\"Hello, my name is\"作为prompt。"
],
"metadata": {
"id": "eTOqX1mCmHbf"
}
},
{
"cell_type": "code",
"source": [
"prompt = \"Hello, my name is\""
],
"metadata": {
"id": "IO9JGkN6j2a4"
},
"execution_count": 25,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"5. 将prompt转换为模型的输入格式"
],
"metadata": {
"id": "Bb497dItmOSE"
}
},
{
"cell_type": "code",
"source": [
"input_ids = tokenizer.encode(prompt, return_tensors=\"pt\")"
],
"metadata": {
"id": "A1mSQrKqktgt"
},
"execution_count": 26,
"outputs": []
},
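{
"cell_type": "markdown",
"source": [
"A small sketch: calling the tokenizer directly (rather than `encode`) returns a `BatchEncoding` holding both `input_ids` and an `attention_mask`; the mask is useful for the generation step below."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# The tokenizer's __call__ returns input_ids plus an attention_mask.\n",
"inputs = tokenizer(prompt, return_tensors=\"pt\")\n",
"print(inputs[\"input_ids\"])       # token ids, shape (1, sequence_length)\n",
"print(inputs[\"attention_mask\"])  # 1 marks real tokens, 0 marks padding"
],
"metadata": {},
"execution_count": null,
"outputs": []
},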
{
"cell_type": "markdown",
"source": [
"6. 用模型生成文本"
],
"metadata": {
"id": "Owi7Fwy7mQ6j"
}
},
{
"cell_type": "code",
"source": [
"output = model.generate(input_ids, max_length=20, temperature=0.8)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "l1CAgPwCky-t",
"outputId": "aa48adbb-700e-41ed-8ef4-0a8bb16aba78"
},
"execution_count": 27,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
"Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
]
}
]
},
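{
"cell_type": "markdown",
"source": [
"The warning above suggests passing the `attention_mask` and an explicit `pad_token_id`. A sketch of the same call with both supplied, reusing `inputs` from the tokenizer sketch above:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Supplying attention_mask and pad_token_id avoids the open-end generation warning.\n",
"output = model.generate(\n",
"    inputs[\"input_ids\"],\n",
"    attention_mask=inputs[\"attention_mask\"],\n",
"    max_length=20,\n",
"    do_sample=True,\n",
"    temperature=0.8,\n",
"    pad_token_id=tokenizer.eos_token_id,\n",
")"
],
"metadata": {},
"execution_count": null,
"outputs": []
},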
{
"cell_type": "markdown",
"source": [
"7. 将输出解码为文本"
],
"metadata": {
"id": "Ht_PFUAtmUUP"
}
},
{
"cell_type": "code",
"source": [
"output_text = tokenizer.decode(output[0], skip_special_tokens=True)"
],
"metadata": {
"id": "-G1s47x8k5fI"
},
"execution_count": 28,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"8. 打印结果"
],
"metadata": {
"id": "5WjwG4PEmZOT"
}
},
{
"cell_type": "code",
"source": [
"print(output_text)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "0sGtHBUfk7KJ",
"outputId": "06f45873-6c0b-4c64-cd80-f347a4fdd617"
},
"execution_count": 29,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"what color is human's skin?\n",
"\n",
"I am a white male, and I am not a\n"
]
}
]
}
]
}