Skip to content

Instantly share code, notes, and snippets.

@Yukishita26
Last active February 25, 2020 03:34
Show Gist options
  • Save Yukishita26/4751e295149416867673f815a57a7969 to your computer and use it in GitHub Desktop.
Save Yukishita26/4751e295149416867673f815a57a7969 to your computer and use it in GitHub Desktop.
Comparison of Gaussian Process Regression (GPR) and Kernel Ridge Regression (KRR) on the Boston housing data
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.datasets import load_boston\n",
"boston = load_boston()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"from sklearn.kernel_ridge import KernelRidge\n",
"from sklearn.model_selection import cross_val_score\n",
"from sklearn.preprocessing import MinMaxScaler\n",
"\n",
"input_x = pd.DataFrame(boston['data'], columns=boston['feature_names'])\n",
"input_y = boston['target']\n",
"mms = MinMaxScaler().fit(input_x)\n",
"train_x, test_x, train_y, test_y = train_test_split(mms.transform(input_x), input_y, test_size=0.2, random_state=0)"
]
},
{
"cell_type": "code",
"execution_count": 271,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.gaussian_process import GaussianProcessRegressor\n",
"from sklearn.gaussian_process.kernels import RBF\n",
"from sklearn.gaussian_process.kernels import WhiteKernel\n",
"\n",
"krr = KernelRidge(kernel='rbf', alpha=1.0, gamma=1.0)\n",
"gpr = GaussianProcessRegressor(alpha=1.0, kernel=RBF(length_scale=1.0), optimizer=None)"
]
},
{
"cell_type": "code",
"execution_count": 272,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Wall time: 44 ms\n"
]
},
{
"data": {
"text/plain": [
"KernelRidge(alpha=1.0, coef0=1, degree=3, gamma=1.0, kernel='rbf',\n",
" kernel_params=None)"
]
},
"execution_count": 272,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"krr.fit(train_x, train_y)"
]
},
{
"cell_type": "code",
"execution_count": 273,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Wall time: 51 ms\n"
]
},
{
"data": {
"text/plain": [
"GaussianProcessRegressor(alpha=1.0, copy_X_train=True,\n",
" kernel=RBF(length_scale=1), n_restarts_optimizer=0,\n",
" normalize_y=False, optimizer=None, random_state=None)"
]
},
"execution_count": 273,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"gpr.fit(train_x, train_y)"
]
},
{
"cell_type": "code",
"execution_count": 274,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0.8363809363920776, 0.649983463876857)"
]
},
"execution_count": 274,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"krr.score(train_x, train_y), krr.score(test_x, test_y)"
]
},
{
"cell_type": "code",
"execution_count": 275,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0.8146886102829334, 0.6398643616279389)"
]
},
"execution_count": 275,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"gpr.score(train_x, train_y), gpr.score(test_x, test_y)"
]
},
{
"cell_type": "code",
"execution_count": 277,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"238.8410776079586"
]
},
"execution_count": 277,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"((krr.predict(train_x) - gpr.predict(train_x))**2).sum()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 302,
"metadata": {},
"outputs": [],
"source": [
"krr = KernelRidge(kernel='rbf', alpha=1.0, gamma=1.0)\n",
"gpr = GaussianProcessRegressor(alpha=1.0, kernel=RBF(length_scale=1.0/np.sqrt(2.0)), optimizer=None)"
]
},
{
"cell_type": "code",
"execution_count": 303,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Wall time: 11 ms\n"
]
},
{
"data": {
"text/plain": [
"KernelRidge(alpha=1.0, coef0=1, degree=3, gamma=1.0, kernel='rbf',\n",
" kernel_params=None)"
]
},
"execution_count": 303,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"krr.fit(train_x, train_y)"
]
},
{
"cell_type": "code",
"execution_count": 304,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Wall time: 20 ms\n"
]
},
{
"data": {
"text/plain": [
"GaussianProcessRegressor(alpha=1.0, copy_X_train=True,\n",
" kernel=RBF(length_scale=0.707), n_restarts_optimizer=0,\n",
" normalize_y=False, optimizer=None, random_state=None)"
]
},
"execution_count": 304,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"gpr.fit(train_x, train_y)"
]
},
{
"cell_type": "code",
"execution_count": 305,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0.8363809363920776, 0.649983463876857)"
]
},
"execution_count": 305,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"krr.score(train_x, train_y), krr.score(test_x, test_y)"
]
},
{
"cell_type": "code",
"execution_count": 306,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0.8363809363920777, 0.6499834638768573)"
]
},
"execution_count": 306,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"gpr.score(train_x, train_y), gpr.score(test_x, test_y)"
]
},
{
"cell_type": "code",
"execution_count": 307,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.5646977168968247e-25"
]
},
"execution_count": 307,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"((krr.predict(train_x) - gpr.predict(train_x))**2).sum()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment