Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save alberduris/06f5095ddbc293501d65c8d2741899f3 to your computer and use it in GitHub Desktop.
Save alberduris/06f5095ddbc293501d65c8d2741899f3 to your computer and use it in GitHub Desktop.
[SkLearn TrainTestSplit OneHot Behaviour] #JupyterNotebook #CodeSnippet #@todo: BUG: `stratify=data_holder.Y_data` instead of `stratify=data_holder.Y_data` (as written both sides are identical; the intended alternative label array is not recorded) #@bug: Strange behaviour? If `stratify` is given one-hot labels, the split differs from the one obtained with string- or int-encoded labels #@bug: See SkLearn_TrainTestSplit_OneHot_Behaviour.ipynb #Others
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
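"# Sklearn `train_test_split`: does `stratify` behave differently with one-hot labels?\n",
"\n",
"The same 20 samples are split three times with `stratify=...`, `test_size=0.3` and `random_state=2019`, passing the labels as strings, as integer codes and as one-hot rows. In the saved run the string and integer splits are identical, while the one-hot split differs. This appears to be expected rather than a bug: for 2-D `y`, scikit-learn's `StratifiedShuffleSplit` turns each row into a string (e.g. `'0.0 0.0 1.0'`) before finding the classes, so the classes are ordered differently and the same seed permutes them differently. The one-hot split is still stratified (2 test samples per class in the saved run)."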
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n"
]
}
],
"source": [
"import numpy as np\n",
"\n",
"from keras.utils import np_utils\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.preprocessing import LabelEncoder"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Seed NumPy's global RNG so the random labels (and therefore every split below)\n",
"# are reproducible under Restart Kernel -> Run All.\n",
"SEED = 2019\n",
"N_SAMPLES = 20\n",
"np.random.seed(SEED)\n",
"\n",
"# Sample ids; these are what the splits below partition into train/test.\n",
"ids = np.arange(N_SAMPLES)\n",
"\n",
"classes = ['label1', 'label2', 'label3']\n",
"\n",
"# Draw one string label per sample.\n",
"labels = np.random.choice(classes, size=N_SAMPLES)\n",
"\n",
"# String labels -> integer codes 0..k-1 (LabelEncoder sorts the classes it sees).\n",
"label_encoder = LabelEncoder()\n",
"encoded_labels = label_encoder.fit_transform(labels)\n",
"\n",
"# Integer codes -> one-hot rows, one column per encoded class.\n",
"one_hot_labels = np_utils.to_categorical(encoded_labels)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Labels: ['label3' 'label3' 'label2' 'label3' 'label2' 'label2' 'label3' 'label3'\n",
" 'label3' 'label1' 'label2' 'label3' 'label1' 'label1' 'label2' 'label1'\n",
" 'label2' 'label3' 'label1' 'label2']\n",
"Encoded labels: [2 2 1 2 1 1 2 2 2 0 1 2 0 0 1 0 1 2 0 1]\n",
"OneHot labels: \n",
"[[0. 0. 1.]\n",
" [0. 0. 1.]\n",
" [0. 1. 0.]\n",
" [0. 0. 1.]\n",
" [0. 1. 0.]\n",
" [0. 1. 0.]\n",
" [0. 0. 1.]\n",
" [0. 0. 1.]\n",
" [0. 0. 1.]\n",
" [1. 0. 0.]\n",
" [0. 1. 0.]\n",
" [0. 0. 1.]\n",
" [1. 0. 0.]\n",
" [1. 0. 0.]\n",
" [0. 1. 0.]\n",
" [1. 0. 0.]\n",
" [0. 1. 0.]\n",
" [0. 0. 1.]\n",
" [1. 0. 0.]\n",
" [0. 1. 0.]]\n"
]
}
],
"source": [
"print(f'Labels: {labels}')\n",
"print(f'Encoded labels: {encoded_labels}')\n",
"print(f'OneHot labels: \\n{one_hot_labels}')"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([ 6, 5, 7, 11, 12, 3, 18, 4, 15, 10, 8, 1, 14, 16]),\n",
" array([ 2, 13, 0, 17, 9, 19]))"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Baseline: stratify on the raw string labels.\n",
"tr_ids, te_ids, tr_labels, te_labels = train_test_split(\n",
"    ids, labels,\n",
"    test_size=0.3, random_state=2019, stratify=labels,\n",
")\n",
"tr_ids, te_ids"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([ 6, 5, 7, 11, 12, 3, 18, 4, 15, 10, 8, 1, 14, 16]),\n",
" array([ 2, 13, 0, 17, 9, 19]))"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Repeat the split, stratifying on the integer codes produced by LabelEncoder.\n",
"tr_ids, te_ids, tr_labels, te_labels = train_test_split(\n",
"    ids, encoded_labels,\n",
"    test_size=0.3, random_state=2019, stratify=encoded_labels,\n",
")\n",
"tr_ids, te_ids"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([19, 2, 15, 18, 5, 7, 11, 1, 17, 13, 10, 4, 8, 6]),\n",
" array([ 3, 9, 12, 0, 16, 14]))"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Repeat the split, stratifying on the one-hot rows.\n",
"# NOTE(review): per scikit-learn's StratifiedShuffleSplit source, a 2-D y is mapped row-wise to\n",
"# strings (e.g. '0.0 0.0 1.0') before np.unique, so the classes are ordered differently than the\n",
"# string/int labels and the same random_state gives a different permutation. Confirm for the\n",
"# installed scikit-learn version.\n",
"tr_ids, te_ids, tr_labels, te_labels = train_test_split(\n",
"    ids, one_hot_labels,\n",
"    test_size=0.3, random_state=2019, stratify=one_hot_labels,\n",
")\n",
"\n",
"# The split is still stratified: each class's test count is within 1 of its exact proportional share.\n",
"expected_te_counts = one_hot_labels.sum(axis=0) * len(te_ids) / len(ids)\n",
"assert np.all(np.abs(te_labels.sum(axis=0) - expected_te_counts) < 1), 'one-hot split is not stratified'\n",
"tr_ids, te_ids"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([[0., 0., 1.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [0., 0., 1.],\n",
" [0., 1., 0.],\n",
" [0., 1., 0.]], dtype=float32),\n",
" array(['label3', 'label1', 'label1', 'label3', 'label2', 'label2'],\n",
" dtype='<U6'))"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Decode the one-hot test rows back to their string class names for comparison.\n",
"te_labels, label_encoder.classes_[te_labels.argmax(axis=1)]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment