Last active
April 17, 2024 17:29
-
-
Save alanzchen/17d0c4a45d59b79052b1cd07f531689e to your computer and use it in GitHub Desktop.
ChiMerge implementation in Python 3.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# ChiMerge (Ker92)\n", | |
"\n", | |
"ChiMerge [Ker92] is a supervised, bottom-up (i.e., merge-based) data discretization method.\n", | |
"\n", | |
"It relies on $ \\chi^2 $ analysis: Adjacent intervals with the least $ \\chi^2 $ values are merged together until the chosen stopping criterion satisfies.\n", | |
"\n", | |
"Here we implement a version of ChiMerge that uses the number of maximum interval as the stopping condition.\n", | |
"\n", | |
"## Loading Iris Data" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import pandas as pd\n", | |
"from collections import Counter\n", | |
"import numpy as np" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"iris = pd.read_csv('http://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data', header=None)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"iris.columns = ['sepal_l', 'sepal_w', 'petal_l', 'petal_w', 'type']" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>sepal_l</th>\n", | |
" <th>sepal_w</th>\n", | |
" <th>petal_l</th>\n", | |
" <th>petal_w</th>\n", | |
" <th>type</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>5.1</td>\n", | |
" <td>3.5</td>\n", | |
" <td>1.4</td>\n", | |
" <td>0.2</td>\n", | |
" <td>Iris-setosa</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>4.9</td>\n", | |
" <td>3.0</td>\n", | |
" <td>1.4</td>\n", | |
" <td>0.2</td>\n", | |
" <td>Iris-setosa</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>4.7</td>\n", | |
" <td>3.2</td>\n", | |
" <td>1.3</td>\n", | |
" <td>0.2</td>\n", | |
" <td>Iris-setosa</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>4.6</td>\n", | |
" <td>3.1</td>\n", | |
" <td>1.5</td>\n", | |
" <td>0.2</td>\n", | |
" <td>Iris-setosa</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>5.0</td>\n", | |
" <td>3.6</td>\n", | |
" <td>1.4</td>\n", | |
" <td>0.2</td>\n", | |
" <td>Iris-setosa</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" sepal_l sepal_w petal_l petal_w type\n", | |
"0 5.1 3.5 1.4 0.2 Iris-setosa\n", | |
"1 4.9 3.0 1.4 0.2 Iris-setosa\n", | |
"2 4.7 3.2 1.3 0.2 Iris-setosa\n", | |
"3 4.6 3.1 1.5 0.2 Iris-setosa\n", | |
"4 5.0 3.6 1.4 0.2 Iris-setosa" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"iris.head()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## ChiMerge Implementation" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def chimerge(data, attr, label, max_intervals):\n", | |
" distinct_vals = sorted(set(data[attr])) # Sort the distinct values\n", | |
" labels = sorted(set(data[label])) # Get all possible labels\n", | |
" empty_count = {l: 0 for l in labels} # A helper function for padding the Counter()\n", | |
" intervals = [[distinct_vals[i], distinct_vals[i]] for i in range(len(distinct_vals))] # Initialize the intervals for each attribute\n", | |
" while len(intervals) > max_intervals: # While loop\n", | |
" chi = []\n", | |
" for i in range(len(intervals)-1):\n", | |
" # Calculate the Chi2 value\n", | |
" obs0 = data[data[attr].between(intervals[i][0], intervals[i][1])]\n", | |
" obs1 = data[data[attr].between(intervals[i+1][0], intervals[i+1][1])]\n", | |
" total = len(obs0) + len(obs1)\n", | |
" count_0 = np.array([v for i, v in {**empty_count, **Counter(obs0[label])}.items()])\n", | |
" count_1 = np.array([v for i, v in {**empty_count, **Counter(obs1[label])}.items()])\n", | |
" count_total = count_0 + count_1\n", | |
" expected_0 = count_total*sum(count_0)/total\n", | |
" expected_1 = count_total*sum(count_1)/total\n", | |
" chi_ = (count_0 - expected_0)**2/expected_0 + (count_1 - expected_1)**2/expected_1\n", | |
" chi_ = np.nan_to_num(chi_) # Deal with the zero counts\n", | |
" chi.append(sum(chi_)) # Finally do the summation for Chi2\n", | |
" min_chi = min(chi) # Find the minimal Chi2 for current iteration\n", | |
" for i, v in enumerate(chi):\n", | |
" if v == min_chi:\n", | |
" min_chi_index = i # Find the index of the interval to be merged\n", | |
" break\n", | |
" new_intervals = [] # Prepare for the merged new data array\n", | |
" skip = False\n", | |
" done = False\n", | |
" for i in range(len(intervals)):\n", | |
" if skip:\n", | |
" skip = False\n", | |
" continue\n", | |
" if i == min_chi_index and not done: # Merge the intervals\n", | |
" t = intervals[i] + intervals[i+1]\n", | |
" new_intervals.append([min(t), max(t)])\n", | |
" skip = True\n", | |
" done = True\n", | |
" else:\n", | |
" new_intervals.append(intervals[i])\n", | |
" intervals = new_intervals\n", | |
" for i in intervals:\n", | |
" print('[', i[0], ',', i[1], ']', sep='')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Interval for sepal_l\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/usr/local/lib/python3.6/site-packages/ipykernel_launcher.py:18: RuntimeWarning: invalid value encountered in true_divide\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[4.3,4.8]\n", | |
"[4.9,4.9]\n", | |
"[5.0,5.4]\n", | |
"[5.5,5.7]\n", | |
"[5.8,7.0]\n", | |
"[7.1,7.9]\n", | |
"Interval for sepal_w\n", | |
"[2.0,2.2]\n", | |
"[2.3,2.4]\n", | |
"[2.5,2.8]\n", | |
"[2.9,2.9]\n", | |
"[3.0,3.3]\n", | |
"[3.4,4.4]\n", | |
"Interval for petal_l\n", | |
"[1.0,1.9]\n", | |
"[3.0,4.4]\n", | |
"[4.5,4.7]\n", | |
"[4.8,4.9]\n", | |
"[5.0,5.1]\n", | |
"[5.2,6.9]\n", | |
"Interval for petal_w\n", | |
"[0.1,0.6]\n", | |
"[1.0,1.3]\n", | |
"[1.4,1.6]\n", | |
"[1.7,1.7]\n", | |
"[1.8,1.8]\n", | |
"[1.9,2.5]\n" | |
] | |
} | |
], | |
"source": [ | |
"for attr in ['sepal_l', 'sepal_w', 'petal_l', 'petal_w']:\n", | |
" print('Interval for', attr)\n", | |
" chimerge(data=iris, attr=attr, label='type', max_intervals=6)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.4" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Does this merge only the first pair with the minimum chi value in each loop iteration?
Or would it be better to merge all pairs that share the minimum chi value in one iteration?
Thank you a lot !!
I had been looking for a long time for an algorithm for partitioning age intervals
(one that is not equi-depth partitioning).
Thanks a lot,
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Thanks for the clean code and use of libraries! This helped me a lot for my assignment.