Skip to content

Instantly share code, notes, and snippets.

@lan2720
Created July 15, 2021 12:14
Show Gist options
  • Save lan2720/8b7fb498c61d08d175389440ff657257 to your computer and use it in GitHub Desktop.
Save lan2720/8b7fb498c61d08d175389440ff657257 to your computer and use it in GitHub Desktop.
Whitening.ipynb
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "d78d12f3-6463-4d6e-aaa8-6f2cdc5d0dcf",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "5128dd7c-dc4f-4b37-b299-3d57a0891f5e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"相关系数:\n",
" [[1. 0.97821977]\n",
" [0.97821977 1. ]]\n",
"特征值:\n",
" [450.18259359 3.1377546 ]\n",
"特征向量:\n",
" [[-0.44159531 -0.89721434]\n",
" [-0.89721434 0.44159531]]\n",
"(100, 2)\n"
]
},
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x7f4dcb529e80>"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 720x720 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"def correlation(X):\n",
" \"\"\"\n",
" :param X: [N,d] array; each row is a sample, each column is a variable\n",
" :return: [d,d] matrix of correlation coefficients between the variables\n",
" \"\"\"\n",
" # 每个变量的均值\n",
" feature_mean = np.mean(X, axis=0)\n",
" X_bar = X - feature_mean\n",
" feature_norm = np.sum(X_bar*X_bar, axis=0)**0.5\n",
" # 每个变量进行归一化\n",
" X_norm = X_bar / feature_norm\n",
" cov = X_norm.T.dot(X_norm)\n",
" return cov\n",
"\n",
"# a = np.array([[1,3,5], [5,4,1], [3,8,6]])\n",
"x = np.random.normal(2, 10, (100,1))\n",
"y = np.random.normal(5, 4, (100,1)) + 2*x\n",
"a = np.concatenate((x,y), axis=1)\n",
"print(\"相关系数:\\n\", correlation(a))\n",
"\n",
"plt.figure(figsize=(10, 10)) \n",
"plt.gca().set_aspect(1)\n",
"\n",
"def whiten(X):\n",
" plt.scatter(X[:,0], X[:,1], label=\"origin\")\n",
" corr = correlation(X)\n",
" cov = np.cov(X, rowvar=False)\n",
" eigVals, eigVecs = np.linalg.eig(cov)\n",
" idx = np.argsort(eigVals)[::-1]\n",
" eigVals = eigVals[idx]\n",
" eigVecs = eigVecs[:,idx]\n",
" print(\"特征值:\\n\", eigVals)\n",
" print(\"特征向量:\\n\", eigVecs)\n",
" for i in range(eigVecs.shape[1]):\n",
" # 每列是一个特征向量\n",
" vec = eigVecs[:,i]\n",
" plt.arrow(0,0, vec[0]*10, vec[1]*10, width=0.3, color='black', alpha=0.7)\n",
"\n",
" decorrelated = X.dot(eigVecs)\n",
" print(decorrelated.shape)\n",
" whitened = decorrelated / np.sqrt(eigVals + 1e-5)\n",
" plt.scatter(decorrelated[:,0], decorrelated[:,1], label=\"decorrelated\")\n",
" plt.scatter(whitened[:,0], whitened[:,1], label=\"whiten\")\n",
"\n",
"whiten(a)\n",
"plt.legend()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "jupyter",
"language": "python",
"name": "jupyter"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment