Skip to content

Instantly share code, notes, and snippets.

@MizunagiKB
Last active August 14, 2021 13:06
Show Gist options
  • Save MizunagiKB/d5049508483758f88b4d8822f63bc8b1 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "rinna_japanese_gpt2_test.ipynb",
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "TPU"
},
"cells": [
{
"cell_type": "code",
"metadata": {
"id": "qvYgsXJmysqI"
},
"source": [
"# https://huggingface.co/rinna/japanese-gpt2-medium\n",
"# Install pinned dependencies. %pip (rather than !pip) installs into the\n",
"# running kernel's environment, so the packages are importable afterwards.\n",
"%pip install -q transformers==4.9.2\n",
"%pip install -q sentencepiece==0.1.96\n",
"%pip install -q ipywidgets==7.6.3"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "ci9jRdlUzTfQ"
},
"source": [
"# setup model\n",
"# Hugging Face model id for rinna's Japanese GPT-2 (medium) checkpoint.\n",
"RINNA_GPT2_MODEL = \"rinna/japanese-gpt2-medium\"\n",
"\n",
"from transformers import T5Tokenizer, AutoModelForCausalLM, pipeline\n",
"\n",
"# NOTE(review): T5Tokenizer is used even though the model is GPT-2 —\n",
"# presumably because rinna's checkpoint ships a SentencePiece vocabulary;\n",
"# confirm against the model card linked in the first cell.\n",
"tokenizer = T5Tokenizer.from_pretrained(RINNA_GPT2_MODEL)\n",
"# Lower-case input text at tokenization time (tokenizer-level flag).\n",
"tokenizer.do_lower_case = True\n",
"\n",
"model = AutoModelForCausalLM.from_pretrained(RINNA_GPT2_MODEL)\n"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "lSfNxaWqnWsD"
},
"source": [
"# testcode\n",
"from ipywidgets import Layout, Button, IntText, Textarea, VBox, HBox, Label\n",
"\n",
"def f(v):\n",
"    # Button callback: encode the prompt from input_text, sample three\n",
"    # continuations up to input_max_length tokens, and print the results.\n",
"    # `v` is the Button instance passed by ipywidgets (unused).\n",
"    if input_text.value == \"\":\n",
"        return\n",
"\n",
"    input_button.disabled = True\n",
"    input_button.description = \"変換中\"\n",
"\n",
"    try:\n",
"        # Named `input_ids` -- the original used `input`, shadowing the builtin.\n",
"        input_ids = tokenizer.encode(input_text.value, return_tensors=\"pt\")\n",
"        output = model.generate(input_ids, do_sample=True, max_length=input_max_length.value, num_return_sequences=3)\n",
"\n",
"        print(\"----\\n\")\n",
"        print(\"入力内容)\\n{:s}\\n\".format(input_text.value))\n",
"        # Decode once (the original had an extra batch_decode call whose\n",
"        # result was discarded).\n",
"        for idx, s in enumerate(tokenizer.batch_decode(output)):\n",
"            print(\"出力結果 {:d})\\n{:s}\\n\".format(idx, s.replace(\"</s>\", \"</s>\\n\")))\n",
"    finally:\n",
"        # Restore the button even if generation raises, so the UI never\n",
"        # gets stuck disabled.\n",
"        input_button.description = \"実行\"\n",
"        input_button.disabled = False\n",
"\n",
"\n",
"# Multiline prompt input.\n",
"input_text = Textarea(\n",
"    value=\"\",\n",
"    placeholder=\"\",\n",
"    description=\"\",\n",
"    disabled=False,\n",
"    layout=Layout(width=\"60%\", height=\"8em\")\n",
")\n",
"\n",
"# Maximum generation length in tokens.\n",
"input_max_length = IntText(\n",
"    value=120,\n",
"    disabled=False,\n",
"    layout=Layout(width=\"60%\")\n",
")\n",
"\n",
"input_button = Button(\n",
"    description=\"実行\",\n",
"    layout=Layout(width=\"60%\")\n",
")\n",
"\n",
"input_button.on_click(f)\n",
"# Layout: a 10%-wide label column next to each widget; the VBox is the\n",
"# cell's last expression, so it renders as the cell output.\n",
"VBox(\n",
"    [\n",
"        HBox([Label(\"入力内容\", layout=Layout(width=\"10%\")), input_text]),\n",
"        HBox([Label(\"出力長\", layout=Layout(width=\"10%\")), input_max_length]),\n",
"        HBox([Label(\"\", layout=Layout(width=\"10%\")), input_button])\n",
"    ]\n",
")"
],
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment