Created
June 8, 2020 08:01
-
-
Save CM-Mr-Mo/30df1d193fbce5d6ce0e152a0a24d903 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Install with the %pip magic so packages land in the environment of the running kernel;\n", | |
"# `!pip` runs whichever pip is first on the shell PATH, which may be a different environment.\n", | |
"%pip install -U pip\n", | |
"%pip install ktrain\n", | |
"%pip install mecab-python3" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"%reload_ext autoreload\n", | |
"%autoreload 2\n", | |
"%matplotlib inline\n", | |
"import os\n", | |
"\n", | |
"# GPU selection via environment variables: order devices by PCI bus ID and\n", | |
"# make only device 0 visible to this process.\n", | |
"os.environ.update({\n", | |
"    \"CUDA_DEVICE_ORDER\": \"PCI_BUS_ID\",\n", | |
"    \"CUDA_VISIBLE_DEVICES\": \"0\",\n", | |
"})" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"!wget https://www.rondhuit.com/download/ldcc-20140209.tar.gz" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true, | |
"jupyter": { | |
"outputs_hidden": true | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"!tar -zxvf ldcc-20140209.tar.gz" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from pathlib import Path\n", | |
"datapath = Path('text/')\n", | |
"# One sub-directory per topic. Sort the names: iterdir() yields entries in arbitrary\n", | |
"# order, whereas load_files() numbers its labels by sorted folder name, so an unsorted\n", | |
"# list used as class names would not line up with the integer targets.\n", | |
"topics = sorted(entry.name for entry in datapath.iterdir() if entry.is_dir())\n", | |
"topics" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from sklearn.datasets import load_files\n", | |
"dataset = load_files(datapath, categories=topics, encoding='utf8')\n", | |
"# load_files() reads every file inside each topic folder, not only articles.\n", | |
"# NOTE(review): the corpus is expected to ship a LICENSE.txt in each folder - confirm\n", | |
"# against the extracted archive. Skip those files so they are not used as labelled samples.\n", | |
"keep = [i for i, name in enumerate(dataset.filenames) if Path(name).name != 'LICENSE.txt']\n", | |
"# For the remaining documents, drop the first two lines and keep the rest as the text.\n", | |
"dataset.data = [dataset.data[i].split('\\n', 2)[2] for i in keep]\n", | |
"# Keep labels and filenames aligned with the filtered documents.\n", | |
"dataset.target = dataset.target[keep]\n", | |
"dataset.filenames = dataset.filenames[keep]\n", | |
"dataset.data[0]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from sklearn.model_selection import train_test_split\n", | |
"X, y = dataset.data, dataset.target\n", | |
"# Hold out 10% of the documents, stratified by label, with a fixed seed.\n", | |
"X_train, X_test, y_train, y_test = train_test_split(\n", | |
"    X,\n", | |
"    y,\n", | |
"    test_size=0.1,\n", | |
"    stratify=y,\n", | |
"    random_state=42,\n", | |
")\n", | |
"len(y_train), len(y_test)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import ktrain\n", | |
"from ktrain import text\n", | |
"MODEL_NAME = 'cl-tohoku/bert-base-japanese-whole-word-masking'\n", | |
"# y_train / y_test hold the label indices produced by load_files(), so take the class\n", | |
"# names from dataset.target_names, the list those indices refer to.\n", | |
"t = text.Transformer(MODEL_NAME, maxlen=128, classes=dataset.target_names)\n", | |
"trn = t.preprocess_train(X_train, y_train)\n", | |
"val = t.preprocess_test(X_test, y_test)\n", | |
"model = t.get_classifier()\n", | |
"learner = ktrain.get_learner(model, train_data=trn, val_data=val, batch_size=32) # lower bs if OOM occurs\n", | |
"#learner.fit_onecycle(5e-5, 3)\n", | |
"learner.autofit(5e-5)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"learner.validate(class_names=t.get_classes())" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"learner.view_top_losses()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"predictor = ktrain.get_predictor(learner.model, t)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"X_test[0]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"y_test[0]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"predictor.predict(X_test[0])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"predictor.save('model/livedoor')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"p = ktrain.load_predictor('model/livedoor')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"p.predict(X_test[0])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "conda_tensorflow_p36", | |
"language": "python", | |
"name": "conda_tensorflow_p36" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.6" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment