@alanzchen
Last active April 17, 2024 17:29
ChiMerge implementation in Python 3.
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ChiMerge (Ker92)\n",
"\n",
"ChiMerge [Ker92] is a supervised, bottom-up (i.e., merge-based) data discretization method.\n",
"\n",
"It relies on $ \\chi^2 $ analysis: the pair of adjacent intervals with the lowest $ \\chi^2 $ value is merged repeatedly until the chosen stopping criterion is satisfied.\n",
"\n",
"Here we implement a version of ChiMerge that uses a maximum number of intervals as the stopping condition.\n",
"\n",
"## Loading Iris Data"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"from collections import Counter\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"iris = pd.read_csv('http://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data', header=None)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"iris.columns = ['sepal_l', 'sepal_w', 'petal_l', 'petal_w', 'type']"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>sepal_l</th>\n",
" <th>sepal_w</th>\n",
" <th>petal_l</th>\n",
" <th>petal_w</th>\n",
" <th>type</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>5.1</td>\n",
" <td>3.5</td>\n",
" <td>1.4</td>\n",
" <td>0.2</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>4.9</td>\n",
" <td>3.0</td>\n",
" <td>1.4</td>\n",
" <td>0.2</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>4.7</td>\n",
" <td>3.2</td>\n",
" <td>1.3</td>\n",
" <td>0.2</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>4.6</td>\n",
" <td>3.1</td>\n",
" <td>1.5</td>\n",
" <td>0.2</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>5.0</td>\n",
" <td>3.6</td>\n",
" <td>1.4</td>\n",
" <td>0.2</td>\n",
" <td>Iris-setosa</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" sepal_l sepal_w petal_l petal_w type\n",
"0 5.1 3.5 1.4 0.2 Iris-setosa\n",
"1 4.9 3.0 1.4 0.2 Iris-setosa\n",
"2 4.7 3.2 1.3 0.2 Iris-setosa\n",
"3 4.6 3.1 1.5 0.2 Iris-setosa\n",
"4 5.0 3.6 1.4 0.2 Iris-setosa"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"iris.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## ChiMerge Implementation"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def chimerge(data, attr, label, max_intervals):\n",
"    distinct_vals = sorted(set(data[attr])) # Sort the distinct values\n",
"    labels = sorted(set(data[label])) # Get all possible labels\n",
"    empty_count = {l: 0 for l in labels} # Zero counts used to pad the Counter() so every label is present\n",
"    intervals = [[distinct_vals[i], distinct_vals[i]] for i in range(len(distinct_vals))] # Start with one interval per distinct value\n",
"    while len(intervals) > max_intervals: # Merge until at most max_intervals remain\n",
"        chi = []\n",
"        for i in range(len(intervals)-1):\n",
"            # Calculate the Chi2 value for each pair of adjacent intervals\n",
"            obs0 = data[data[attr].between(intervals[i][0], intervals[i][1])]\n",
"            obs1 = data[data[attr].between(intervals[i+1][0], intervals[i+1][1])]\n",
"            total = len(obs0) + len(obs1)\n",
"            count_0 = np.array([v for i, v in {**empty_count, **Counter(obs0[label])}.items()])\n",
"            count_1 = np.array([v for i, v in {**empty_count, **Counter(obs1[label])}.items()])\n",
"            count_total = count_0 + count_1\n",
"            expected_0 = count_total*sum(count_0)/total\n",
"            expected_1 = count_total*sum(count_1)/total\n",
"            chi_ = (count_0 - expected_0)**2/expected_0 + (count_1 - expected_1)**2/expected_1\n",
"            chi_ = np.nan_to_num(chi_) # Treat 0/0 (labels absent from both intervals) as 0\n",
"            chi.append(sum(chi_)) # Finally do the summation for Chi2\n",
"        min_chi = min(chi) # Find the minimal Chi2 for current iteration\n",
"        for i, v in enumerate(chi):\n",
"            if v == min_chi:\n",
"                min_chi_index = i # Index of the first pair with the minimal Chi2\n",
"                break\n",
"        new_intervals = [] # Prepare for the merged new data array\n",
"        skip = False\n",
"        done = False\n",
"        for i in range(len(intervals)):\n",
"            if skip:\n",
"                skip = False\n",
"                continue\n",
"            if i == min_chi_index and not done: # Merge the intervals\n",
"                t = intervals[i] + intervals[i+1]\n",
"                new_intervals.append([min(t), max(t)])\n",
"                skip = True\n",
"                done = True\n",
"            else:\n",
"                new_intervals.append(intervals[i])\n",
"        intervals = new_intervals\n",
"    for i in intervals:\n",
"        print('[', i[0], ',', i[1], ']', sep='')"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Interval for sepal_l\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.6/site-packages/ipykernel_launcher.py:18: RuntimeWarning: invalid value encountered in true_divide\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[4.3,4.8]\n",
"[4.9,4.9]\n",
"[5.0,5.4]\n",
"[5.5,5.7]\n",
"[5.8,7.0]\n",
"[7.1,7.9]\n",
"Interval for sepal_w\n",
"[2.0,2.2]\n",
"[2.3,2.4]\n",
"[2.5,2.8]\n",
"[2.9,2.9]\n",
"[3.0,3.3]\n",
"[3.4,4.4]\n",
"Interval for petal_l\n",
"[1.0,1.9]\n",
"[3.0,4.4]\n",
"[4.5,4.7]\n",
"[4.8,4.9]\n",
"[5.0,5.1]\n",
"[5.2,6.9]\n",
"Interval for petal_w\n",
"[0.1,0.6]\n",
"[1.0,1.3]\n",
"[1.4,1.6]\n",
"[1.7,1.7]\n",
"[1.8,1.8]\n",
"[1.9,2.5]\n"
]
}
],
"source": [
"for attr in ['sepal_l', 'sepal_w', 'petal_l', 'petal_w']:\n",
"    print('Interval for', attr)\n",
"    chimerge(data=iris, attr=attr, label='type', max_intervals=6)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
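
The statistic that drives each merge can be looked at in isolation. The sketch below is not part of the notebook: `chi2_adjacent` is a hypothetical helper name, and it assumes the per-class counts of two adjacent, non-empty intervals are already available as equal-length sequences in the same class order. It computes the Pearson chi-square over the resulting 2 x k table and counts classes that are absent from both intervals as contributing 0, which is the effect the notebook gets from `np.nan_to_num`.

```python
import numpy as np


def chi2_adjacent(count_0, count_1):
    """Pearson chi-square for two adjacent intervals (hypothetical helper).

    count_0 and count_1 hold the per-class observation counts of the two
    intervals, listed in the same class order.
    """
    observed = np.array([count_0, count_1], dtype=float)  # 2 x k table
    row_totals = observed.sum(axis=1, keepdims=True)      # size of each interval
    col_totals = observed.sum(axis=0, keepdims=True)      # size of each class
    total = observed.sum()
    expected = row_totals * col_totals / total

    # A class missing from both intervals has expected == 0. Use a dummy
    # denominator there and set its contribution to 0 instead of dividing by 0.
    safe_expected = np.where(expected > 0, expected, 1.0)
    contributions = np.where(
        expected > 0, (observed - expected) ** 2 / safe_expected, 0.0
    )
    return float(contributions.sum())


# Identical class mix in both intervals: the boundary carries no information.
print(chi2_adjacent([5, 0, 0], [3, 0, 0]))  # 0.0
# Disjoint class mix: a strong reason to keep the boundary.
print(chi2_adjacent([5, 0, 0], [0, 5, 0]))  # 10.0
```

A low value means the class distribution looks the same on both sides of the boundary, which is why the pair with the smallest value is the one merged in each iteration.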
@bhemati

bhemati commented Nov 2, 2018

Thanks for the clean code and use of libraries! This helped me a lot for my assignment.

@WellJoea

Does this merge only the first minimum-chi index in each loop?
Or would it be better to merge all the minimum-chi pairs in one loop?

@Faisuvaporn

Thanks a lot!
I spent a long time looking for an algorithm for partitioning age intervals
(one that is not equi-depth partitioning).

@jorguzb

jorguzb commented Oct 24, 2022

Thanks a lot.
