Skip to content

Instantly share code, notes, and snippets.

@lucidfrontier45
Forked from Yukishita26/KRR_GPR_test.ipynb
Last active February 24, 2020 03:37
Show Gist options
  • Save lucidfrontier45/2aeb965dd03dc5b82837eceaf194460c to your computer and use it in GitHub Desktop.
Save lucidfrontier45/2aeb965dd03dc5b82837eceaf194460c to your computer and use it in GitHub Desktop.
Comparison of GPR and KRR
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import load_boston\n",
"from sklearn.preprocessing import MinMaxScaler\n",
"from sklearn.kernel_ridge import KernelRidge\n",
"from sklearn.gaussian_process import GaussianProcessRegressor\n",
"from sklearn.gaussian_process.kernels import RBF\n",
"from sklearn.base import BaseEstimator, RegressorMixin\n",
"from sklearn.metrics import mean_squared_error\n",
"import GPy\n",
"\n",
"# NOTE: load_boston was deprecated in scikit-learn 1.0 and removed in 1.2,\n",
"# so this cell (and its import above) requires scikit-learn < 1.2.\n",
"boston = load_boston()\n",
"\n",
"# MinMaxScaler maps each feature to [0, 1] (its default feature_range),\n",
"# presumably so that one shared RBF length scale suits every feature.\n",
"scaler = MinMaxScaler()\n",
"X = scaler.fit_transform(boston.data)\n",
"y = boston.target"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"class MyKernelRidge(RegressorMixin, BaseEstimator):\n",
"    \"\"\"Minimal kernel ridge regression with a fixed RBF kernel.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    alpha : float, default=1.0\n",
"        Ridge penalty added to the diagonal of the Gram matrix.\n",
"    gamma : float, default=1.0\n",
"        Passed to ``RBF`` as ``length_scale``. NOTE: this is *not* the\n",
"        ``gamma`` of ``sklearn.kernel_ridge.KernelRidge`` (whose RBF kernel\n",
"        is exp(-gamma * ||x - x'||^2)); the equivalent KernelRidge gamma\n",
"        is 1 / (2 * gamma**2).\n",
"\n",
"    Attributes\n",
"    ----------\n",
"    X_fit_ : training inputs stored by ``fit`` (needed again by ``predict``).\n",
"    kernel_ : the ``RBF`` kernel built in ``fit``.\n",
"    dual_coef_ : solution of (K + alpha * I) dual_coef_ = y.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, alpha=1.0, gamma=1.0):\n",
"        # scikit-learn convention: __init__ only stores hyperparameters.\n",
"        # Building the kernel here would leave it stale after set_params()\n",
"        # (e.g. inside GridSearchCV), so it is built in fit() instead.\n",
"        self.alpha = alpha\n",
"        self.gamma = gamma\n",
"\n",
"    def fit(self, X, y):\n",
"        \"\"\"Solve (K + alpha * I) dual_coef_ = y with K = RBF(X, X); return self.\"\"\"\n",
"        self.X_fit_ = X\n",
"        self.kernel_ = RBF(length_scale=self.gamma)\n",
"        K = self.kernel_(X, X)\n",
"        self.dual_coef_ = np.linalg.solve(K + self.alpha * np.eye(len(X)), y)\n",
"        return self\n",
"\n",
"    def predict(self, X):\n",
"        \"\"\"Return RBF(X, X_fit_) @ dual_coef_ using the kernel built in fit().\"\"\"\n",
"        return self.kernel_(X, self.X_fit_).dot(self.dual_coef_)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# NOTE: KernelRidge's \"rbf\" kernel is exp(-gamma * ||x - x'||^2), while\n",
"# RBF(length_scale=l) is exp(-||x - x'||^2 / (2 * l**2)). So gamma=1.0 below\n",
"# is NOT the same kernel as length_scale=1.0 (that would be gamma=0.5).\n",
"krr = KernelRidge(kernel=\"rbf\", alpha=1.0, gamma=1.0).fit(X, y)\n",
"\n",
"# MyKernelRidge forwards its `gamma` argument to RBF as length_scale.\n",
"krr2 = MyKernelRidge(alpha=1.0, gamma=1.0).fit(X, y)\n",
"\n",
"# optimizer=None keeps the kernel hyperparameters fixed (no marginal-likelihood fit).\n",
"gpr = GaussianProcessRegressor(\n",
"    kernel=RBF(length_scale=1.0),\n",
"    alpha=1.0,\n",
"    optimizer=None,\n",
").fit(X, y)\n",
"\n",
"# GPy takes the targets as an (n_samples, 1) array.\n",
"gpy_kernel = GPy.kern.RBF(input_dim=X.shape[1], variance=1., lengthscale=1.)\n",
"gpr2 = GPy.models.GPRegression(X, y[:, np.newaxis], gpy_kernel)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# In-sample predictions of every model on the training inputs X.\n",
"y_krr = krr.predict(X)\n",
"y_krr2 = krr2.predict(X)\n",
"y_gpr = gpr.predict(X)\n",
"\n",
"# GPy's predict returns (mean, variance); the mean has shape (n_samples, 1),\n",
"# so take its single column to get a 1-D array like the sklearn predictions.\n",
"gpy_mean = gpr2.predict(X)[0]\n",
"y_gpr2 = gpy_mean[:, 0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
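"The results of `KernelRidge` and `GaussianProcessRegressor` are very different.\n",
"\n",
"This is a kernel-parameterization mismatch, not a difference between the two methods. `KernelRidge(kernel=\"rbf\", gamma=g)` uses $k(x, x') = \\exp(-g \\lVert x - x' \\rVert^2)$, whereas `RBF(length_scale=l)` uses $k(x, x') = \\exp(-\\lVert x - x' \\rVert^2 / (2 l^2))$. The two kernels coincide only when $g = 1 / (2 l^2)$, so `gamma=1.0` corresponds to `length_scale` $= 1/\\sqrt{2}$, not `length_scale=1.0`."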
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.5927075657366099"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# MSE between the two models' in-sample predictions (not against y).\n",
"mean_squared_error(y_true=y_gpr, y_pred=y_krr)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
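"`GaussianProcessRegressor` and my own kernel ridge implementation (`MyKernelRidge`) are consistent: the MSE between their predictions is at floating-point round-off level. This is expected because `MyKernelRidge` passes its `gamma` argument to `RBF` as `length_scale`. Both models therefore use exactly the same kernel, and with the same kernel the GP posterior mean $K_*(K + \\alpha I)^{-1} y$ is the kernel ridge prediction."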
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2.1368152844150655e-28"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# MSE between the GPR and MyKernelRidge in-sample predictions.\n",
"mean_squared_error(y_true=y_gpr, y_pred=y_krr2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`GaussianProcessRegressor` and GPy's `GPRegression` are also consistent (GPy's default Gaussian noise variance of 1.0 plays the role of `alpha=1.0`)."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.4301826259702018e-16"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# MSE between the scikit-learn GPR and GPy GPRegression in-sample predictions.\n",
"mean_squared_error(y_true=y_gpr, y_pred=y_gpr2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment