Skip to content

Instantly share code, notes, and snippets.

@yamasakih
Last active January 9, 2019 13:02
Show Gist options
  • Save yamasakih/cbe5a75ddb035350b0e652f6a98158da to your computer and use it in GitHub Desktop.
Save yamasakih/cbe5a75ddb035350b0e652f6a98158da to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"金子先生の[demo_opt_gtm_with_k3nerror.ipynb](https://github.com/hkaneko1985/gtm-generativetopographicmapping/blob/master/Python/demo_opt_gtm_with_k3nerror.ipynb)を参考に行なっています"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.figure as figure\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import numpy as np\n",
"from sklearn.datasets import load_iris\n",
"\n",
"from gtm import gtm\n",
"from k3nerror import k3nerror"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# --- grid-search candidates ---\n",
"# map shape; the same value is used for both dimensions (single candidate: 30)\n",
"candidates_of_shape_of_map = np.arange(30, 31, dtype=int)\n",
"# shape of the RBF-center grid, also used for both dimensions: 2, 4, ..., 20\n",
"candidates_of_shape_of_rbf_centers = np.arange(2, 22, 2, dtype=int)\n",
"# variance of the RBFs: 2^-5, 2^-3, ..., 2^3\n",
"candidates_of_variance_of_rbfs = 2 ** np.arange(-5, 4, 2, dtype=float)\n",
"# lambda in the EM algorithm: 0 followed by 2^-4, ..., 2^-1\n",
"candidates_of_lambda_in_em_algorithm = np.append(0, 2 ** np.arange(-4, 0, dtype=float))\n",
"\n",
"# --- fixed settings handed to gtm(...) and k3nerror(...) ---\n",
"number_of_iterations = 300\n",
"display_flag = 0  # presumably silences gtm's progress display -- confirm in gtm\n",
"k_in_k3nerror = 10  # third argument of k3nerror(...)\n",
"\n",
"# --- data ---\n",
"iris = load_iris()\n",
"color = iris.target  # class labels, used to split the scatter plot by class\n",
"\n",
"# autoscaling: subtract each column's mean and divide by its sample standard deviation (ddof=1)\n",
"input_dataset = (iris.data - iris.data.mean(axis=0)) / iris.data.std(axis=0, ddof=1)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1, 250]\n",
"[2, 250]\n",
"[3, 250]\n",
"[4, 250]\n",
"[5, 250]\n",
"[6, 250]\n",
"[7, 250]\n",
"[8, 250]\n",
"[9, 250]\n",
"[10, 250]\n",
"[11, 250]\n",
"[12, 250]\n",
"[13, 250]\n",
"[14, 250]\n",
"[15, 250]\n",
"[16, 250]\n",
"[17, 250]\n",
"[18, 250]\n",
"[19, 250]\n",
"[20, 250]\n",
"[21, 250]\n",
"[22, 250]\n",
"[23, 250]\n",
"[24, 250]\n",
"[25, 250]\n",
"[26, 250]\n",
"[27, 250]\n",
"[28, 250]\n",
"[29, 250]\n",
"[30, 250]\n",
"[31, 250]\n",
"[32, 250]\n",
"[33, 250]\n",
"[34, 250]\n",
"[35, 250]\n",
"[36, 250]\n",
"[37, 250]\n",
"[38, 250]\n",
"[39, 250]\n",
"[40, 250]\n",
"[41, 250]\n",
"[42, 250]\n",
"[43, 250]\n",
"[44, 250]\n",
"[45, 250]\n",
"[46, 250]\n",
"[47, 250]\n",
"[48, 250]\n",
"[49, 250]\n",
"[50, 250]\n",
"[51, 250]\n",
"[52, 250]\n",
"[53, 250]\n",
"[54, 250]\n",
"[55, 250]\n",
"[56, 250]\n",
"[57, 250]\n",
"[58, 250]\n",
"[59, 250]\n",
"[60, 250]\n",
"[61, 250]\n",
"[62, 250]\n",
"[63, 250]\n",
"[64, 250]\n",
"[65, 250]\n",
"[66, 250]\n",
"[67, 250]\n",
"[68, 250]\n",
"[69, 250]\n",
"[70, 250]\n",
"[71, 250]\n",
"[72, 250]\n",
"[73, 250]\n",
"[74, 250]\n",
"[75, 250]\n",
"[76, 250]\n",
"[77, 250]\n",
"[78, 250]\n",
"[79, 250]\n",
"[80, 250]\n",
"[81, 250]\n",
"[82, 250]\n",
"[83, 250]\n",
"[84, 250]\n",
"[85, 250]\n",
"[86, 250]\n",
"[87, 250]\n",
"[88, 250]\n",
"[89, 250]\n",
"[90, 250]\n",
"[91, 250]\n",
"[92, 250]\n",
"[93, 250]\n",
"[94, 250]\n",
"[95, 250]\n",
"[96, 250]\n",
"[97, 250]\n",
"[98, 250]\n",
"[99, 250]\n",
"[100, 250]\n",
"[101, 250]\n",
"[102, 250]\n",
"[103, 250]\n",
"[104, 250]\n",
"[105, 250]\n",
"[106, 250]\n",
"[107, 250]\n",
"[108, 250]\n",
"[109, 250]\n",
"[110, 250]\n",
"[111, 250]\n",
"[112, 250]\n",
"[113, 250]\n",
"[114, 250]\n",
"[115, 250]\n",
"[116, 250]\n",
"[117, 250]\n",
"[118, 250]\n",
"[119, 250]\n",
"[120, 250]\n",
"[121, 250]\n",
"[122, 250]\n",
"[123, 250]\n",
"[124, 250]\n",
"[125, 250]\n",
"[126, 250]\n",
"[127, 250]\n",
"[128, 250]\n",
"[129, 250]\n",
"[130, 250]\n",
"[131, 250]\n",
"[132, 250]\n",
"[133, 250]\n",
"[134, 250]\n",
"[135, 250]\n",
"[136, 250]\n",
"[137, 250]\n",
"[138, 250]\n",
"[139, 250]\n",
"[140, 250]\n",
"[141, 250]\n",
"[142, 250]\n",
"[143, 250]\n",
"[144, 250]\n",
"[145, 250]\n",
"[146, 250]\n",
"[147, 250]\n",
"[148, 250]\n",
"[149, 250]\n",
"[150, 250]\n",
"[151, 250]\n",
"[152, 250]\n",
"[153, 250]\n",
"[154, 250]\n",
"[155, 250]\n",
"[156, 250]\n",
"[157, 250]\n",
"[158, 250]\n",
"[159, 250]\n",
"[160, 250]\n",
"[161, 250]\n",
"[162, 250]\n",
"[163, 250]\n",
"[164, 250]\n",
"[165, 250]\n",
"[166, 250]\n",
"[167, 250]\n",
"[168, 250]\n",
"[169, 250]\n",
"[170, 250]\n",
"[171, 250]\n",
"[172, 250]\n",
"[173, 250]\n",
"[174, 250]\n",
"[175, 250]\n",
"[176, 250]\n",
"[177, 250]\n",
"[178, 250]\n",
"[179, 250]\n",
"[180, 250]\n",
"[181, 250]\n",
"[182, 250]\n",
"[183, 250]\n",
"[184, 250]\n",
"[185, 250]\n",
"[186, 250]\n",
"[187, 250]\n",
"[188, 250]\n",
"[189, 250]\n",
"[190, 250]\n",
"[191, 250]\n",
"[192, 250]\n",
"[193, 250]\n",
"[194, 250]\n",
"[195, 250]\n",
"[196, 250]\n",
"[197, 250]\n",
"[198, 250]\n",
"[199, 250]\n",
"[200, 250]\n",
"[201, 250]\n",
"[202, 250]\n",
"[203, 250]\n",
"[204, 250]\n",
"[205, 250]\n",
"[206, 250]\n",
"[207, 250]\n",
"[208, 250]\n",
"[209, 250]\n",
"[210, 250]\n",
"[211, 250]\n",
"[212, 250]\n",
"[213, 250]\n",
"[214, 250]\n",
"[215, 250]\n",
"[216, 250]\n",
"[217, 250]\n",
"[218, 250]\n",
"[219, 250]\n",
"[220, 250]\n",
"[221, 250]\n",
"[222, 250]\n",
"[223, 250]\n",
"[224, 250]\n",
"[225, 250]\n",
"[226, 250]\n",
"[227, 250]\n",
"[228, 250]\n",
"[229, 250]\n",
"[230, 250]\n",
"[231, 250]\n",
"[232, 250]\n",
"[233, 250]\n",
"[234, 250]\n",
"[235, 250]\n",
"[236, 250]\n",
"[237, 250]\n",
"[238, 250]\n",
"[239, 250]\n",
"[240, 250]\n",
"[241, 250]\n",
"[242, 250]\n",
"[243, 250]\n",
"[244, 250]\n",
"[245, 250]\n",
"[246, 250]\n",
"[247, 250]\n",
"[248, 250]\n",
"[249, 250]\n",
"[250, 250]\n",
"CPU times: user 31min 18s, sys: 1min 17s, total: 32min 35s\n",
"Wall time: 5min 29s\n"
]
}
],
"source": [
"%%time\n",
"# Grid search: fit one GTM per hyperparameter combination and record its k3n-error.\n",
"# Each row appended to parameters_and_k3nerror is\n",
"# [shape_of_map, shape_of_rbf_centers, variance_of_rbfs, lambda_in_em_algorithm, k3n-error].\n",
"parameters_and_k3nerror = []\n",
"all_calculation_numbers = (len(candidates_of_shape_of_map)\n",
"                           * len(candidates_of_shape_of_rbf_centers)\n",
"                           * len(candidates_of_variance_of_rbfs)\n",
"                           * len(candidates_of_lambda_in_em_algorithm))\n",
"calculation_number = 0\n",
"for shape_of_map_grid in candidates_of_shape_of_map:\n",
"    for shape_of_rbf_centers_grid in candidates_of_shape_of_rbf_centers:\n",
"        for variance_of_rbfs_grid in candidates_of_variance_of_rbfs:\n",
"            for lambda_in_em_algorithm_grid in candidates_of_lambda_in_em_algorithm:\n",
"                calculation_number += 1\n",
"                # progress indicator: [current, total]\n",
"                print([calculation_number, all_calculation_numbers])\n",
"                # construct and fit a GTM model; each shape value is used for both dimensions\n",
"                model = gtm([shape_of_map_grid, shape_of_map_grid],\n",
"                            [shape_of_rbf_centers_grid, shape_of_rbf_centers_grid],\n",
"                            variance_of_rbfs_grid, lambda_in_em_algorithm_grid,\n",
"                            number_of_iterations, display_flag)\n",
"                model.fit(input_dataset)\n",
"                if model.success_flag:\n",
"                    # calculate responsibilities\n",
"                    responsibilities = model.responsibility(input_dataset)\n",
"                    # calculate the mean of responsibilities\n",
"                    means = responsibilities.dot(model.map_grids)\n",
"                    # calculate k3n-error\n",
"                    k3nerror_of_gtm = k3nerror(input_dataset, means, k_in_k3nerror)\n",
"                else:\n",
"                    # Fit failed: store a huge sentinel so this combination can never be the minimum.\n",
"                    # NOTE(review): 10 ** 100 is a Python int beyond the int64 range, so a later\n",
"                    # np.array(parameters_and_k3nerror) likely falls back to object dtype -- confirm.\n",
"                    k3nerror_of_gtm = 10 ** 100\n",
"                parameters_and_k3nerror.append(\n",
"                    [shape_of_map_grid, shape_of_rbf_centers_grid, variance_of_rbfs_grid,\n",
"                     lambda_in_em_algorithm_grid, k3nerror_of_gtm])\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Optimized GTM: refit with the hyperparameters that gave the smallest k3n-error.\n",
"parameters_and_k3nerror = np.array(parameters_and_k3nerror)\n",
"# Column 4 holds the k3n-error; cast to float so the comparison does not depend on the array dtype.\n",
"k3nerrors = parameters_and_k3nerror[:, 4].astype(float)\n",
"# np.argmin returns the first occurrence of the minimum.\n",
"optimized_hyperparameter_number = int(np.argmin(k3nerrors))\n",
"best_parameters = parameters_and_k3nerror[optimized_hyperparameter_number]\n",
"\n",
"# The rows mix ints and floats, so values read back from the array are floats (or objects).\n",
"# Cast explicitly so gtm receives integer shapes and float scalars.\n",
"shape_of_map = [int(best_parameters[0])] * 2\n",
"shape_of_rbf_centers = [int(best_parameters[1])] * 2\n",
"variance_of_rbfs = float(best_parameters[2])\n",
"lambda_in_em_algorithm = float(best_parameters[3])\n",
"\n",
"# construct GTM model\n",
"model = gtm(shape_of_map, shape_of_rbf_centers, variance_of_rbfs, lambda_in_em_algorithm,\n",
"            number_of_iterations, display_flag)\n",
"model.fit(input_dataset)\n",
"\n",
"# calculate responsibilities\n",
"responsibilities = model.responsibility(input_dataset)\n",
"\n",
"# calculate the mean of responsibilities (these are the coordinates plotted below)\n",
"means = responsibilities.dot(model.map_grids)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.22380413497668755"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# k3n-error of the refitted model, using the same k as the grid search\n",
"k3nerror(input_dataset, means, k_in_k3nerror)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x1a27f94668>"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 360x360 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(5, 5))\n",
"# one scatter series per iris class; points belonging to the other classes are masked out\n",
"for class_index in range(3):\n",
"    other_classes = color != class_index\n",
"    z1 = np.ma.masked_where(other_classes, means[:, 0])\n",
"    z2 = np.ma.masked_where(other_classes, means[:, 1])\n",
"    plt.scatter(z1, z2, c=f'C{class_index}', label=iris.target_names[class_index])\n",
"# fixed axis limits of [-1.1, 1.1] on both axes\n",
"plt.xlim(-1.1, 1.1)\n",
"plt.ylim(-1.1, 1.1)\n",
"plt.xlabel('z1 (mean)')\n",
"plt.ylabel('z2 (mean)')\n",
"plt.title('GTM iris dataset')\n",
"plt.legend()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best parameters\n",
"Shape of map = 30\n",
"Shape of RBF centers = 4\n",
"Variance of RBFs = 8.0\n",
"Lambda in EM algorithm = 0.25\n"
]
}
],
"source": [
"# report the hyperparameters selected by the grid search\n",
"print('Best parameters')\n",
"# both entries of each shape list are equal, so the first one is enough\n",
"print(f'Shape of map = {shape_of_map[0]}')\n",
"print(f'Shape of RBF centers = {shape_of_rbf_centers[0]}')\n",
"print(f'Variance of RBFs = {variance_of_rbfs}')\n",
"print(f'Lambda in EM algorithm = {lambda_in_em_algorithm}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# EOF "
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment