Skip to content

Instantly share code, notes, and snippets.

@suzusuzu
Created October 5, 2019 23:07
Show Gist options
  • Save suzusuzu/148f0e7f44f98fba4f6cf70d64295e8c to your computer and use it in GitHub Desktop.
Save suzusuzu/148f0e7f44f98fba4f6cf70d64295e8c to your computer and use it in GitHub Desktop.
caratheodory.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "caratheodory.ipynb",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/suzusuzu/148f0e7f44f98fba4f6cf70d64295e8c/caratheodory.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "LmGv945mqc5b",
"colab_type": "code",
"colab": {}
},
"source": [
"import numpy as np\n",
"from scipy.linalg import null_space\n",
"import itertools\n",
"\n",
"def caratheodory(p, u):\n",
"    \"\"\"Reduce a weighted point set to at most d + 1 points with the same weighted sum.\n",
"\n",
"    Repeatedly finds an affine dependency among the active points and shifts\n",
"    weight along it until one weight reaches zero, so that sum(w) and w @ p\n",
"    stay unchanged (up to rounding) while the number of active points shrinks.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    p : ndarray of shape (n, d)\n",
"        Input points, one per row.\n",
"    u : array_like of shape (n,)\n",
"        Non-negative weight of each point; the input is not modified.\n",
"\n",
"    Returns\n",
"    -------\n",
"    points : ndarray\n",
"        The selected rows, p[indices].\n",
"    w : ndarray of shape (n,)\n",
"        Updated weights, indexed like u; eliminated points have weight 0.\n",
"    indices : ndarray of int\n",
"        Row numbers of the selected points (at most d + 1 of them).\n",
"    \"\"\"\n",
"    d = p.shape[1]\n",
"    w = np.array(u, dtype=float)\n",
"    indices = np.arange(p.shape[0])\n",
"    while indices.shape[0] > d + 1:\n",
"        active = p[indices]\n",
"        # Columns are p_i - p_0 for i >= 1. There are more than d of them in\n",
"        # d dimensions, so the matrix always has a non-trivial null space.\n",
"        diffs = (active[1:] - active[0]).T\n",
"        v = null_space(diffs)[:, 0]\n",
"        # Prepend the coefficient of p_0 so that sum(v) == 0 and v @ active == 0.\n",
"        v = np.insert(v, 0, -np.sum(v))\n",
"        positive = np.flatnonzero(v > 0.0)\n",
"        ratios = w[indices[positive]] / v[positive]\n",
"        pivot = positive[np.argmin(ratios)]\n",
"        # Largest step along v that keeps every weight non-negative.\n",
"        shifted = w[indices] - ratios.min() * v\n",
"        # In exact arithmetic the pivot weight is now 0 and the rest are >= 0;\n",
"        # enforce that so rounding can neither keep the pivot alive (stalling\n",
"        # the loop) nor leave tiny negative weights behind.\n",
"        shifted[pivot] = 0.0\n",
"        w[indices] = np.maximum(shifted, 0.0)\n",
"        indices = np.flatnonzero(w > 0)\n",
"    return p[indices], w, indices\n",
"\n",
"def flatten(l):\n",
"    \"\"\"Return a single list holding the items of every iterable in l, in order.\"\"\"\n",
"    return [item for group in l for item in group]\n",
"\n",
"def fast_caratheodory(p, u, k):\n",
"    \"\"\"Caratheodory reduction accelerated by clustering the points.\n",
"\n",
"    Each round splits the active points into k contiguous clusters, replaces\n",
"    every cluster by its weighted mean, runs caratheodory() on those k means,\n",
"    keeps only the points of the surviving clusters and rescales their\n",
"    weights so each cluster carries the weight assigned to its mean.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    p : ndarray of shape (n, d)\n",
"        Input points, one per row.\n",
"    u : array_like of shape (n,)\n",
"        Non-negative weight of each point; the input is not modified.\n",
"    k : int\n",
"        Number of clusters per round. Values below d + 2 are raised to\n",
"        d + 2, because with fewer clusters no cluster can be discarded.\n",
"\n",
"    Returns\n",
"    -------\n",
"    points : ndarray\n",
"        The selected rows, p[indices].\n",
"    w : ndarray of shape (n,)\n",
"        Updated weights, indexed like u; eliminated points have weight 0.\n",
"    indices : ndarray of int\n",
"        Row numbers of the selected points (at most d + 1 of them).\n",
"    \"\"\"\n",
"    total = p.shape[0]\n",
"    d = p.shape[1]\n",
"    w = np.array(u, dtype=float)\n",
"    indices = np.arange(total)\n",
"    # With k <= d + 1 the inner call returns every cluster unchanged and the\n",
"    # loop below would never make progress.\n",
"    k = max(k, d + 2)\n",
"    while indices.shape[0] > d + 1:\n",
"        n = indices.shape[0]\n",
"        k = min(k, n)\n",
"        # Cut positions 0..n-1 into k contiguous, non-empty runs whose sizes\n",
"        # differ by at most one.\n",
"        bounds = np.insert(np.cumsum([(n + i) // k for i in range(k)]), 0, 0)\n",
"        clusters = [np.arange(bounds[i], bounds[i + 1]) for i in range(k)]\n",
"        cluster_w = np.array([w[indices[c]].sum() for c in clusters])\n",
"        weighted_sums = np.array([w[indices[c]] @ p[indices[c]] for c in clusters])\n",
"        cluster_means = weighted_sums / cluster_w[:, None]\n",
"        _, reduced_w, kept = caratheodory(cluster_means, cluster_w)\n",
"        for ci in kept:\n",
"            members = indices[clusters[ci]]\n",
"            # Keep the relative weights inside the cluster, but make them\n",
"            # sum to the weight the reduction gave to the cluster mean.\n",
"            w[members] = reduced_w[ci] * w[members] / cluster_w[ci]\n",
"        indices = indices[np.concatenate([clusters[ci] for ci in kept])]\n",
"        # Zero every point that is no longer active. The mask spans the whole\n",
"        # weight vector, not just the first n entries.\n",
"        dropped = np.ones(total, dtype=bool)\n",
"        dropped[indices] = False\n",
"        w[dropped] = 0.0\n",
"    return p[indices], w, indices\n",
" \n",
"def caratheodory_matrix(a):\n",
"    \"\"\"Build a small matrix s with (approximately) s.T @ s == a.T @ a.\n",
"\n",
"    Every row a_i is mapped to its flattened outer product a_i a_i^T, the\n",
"    resulting points are reduced with caratheodory() under uniform weights,\n",
"    and each selected row is rescaled by sqrt(n * weight).\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    a : array_like of shape (n, d)\n",
"        Input matrix.\n",
"\n",
"    Returns\n",
"    -------\n",
"    s : ndarray of shape (m, d)\n",
"        Rescaled selected rows of a, one per index returned by\n",
"        caratheodory().\n",
"    \"\"\"\n",
"    a = np.asarray(a, dtype=float)\n",
"    n, d = a.shape\n",
"    # Row i of p is the flattened d x d outer product of a[i] with itself.\n",
"    p = np.einsum('ij,ik->ijk', a, a).reshape(n, d * d)\n",
"    u = np.ones(n) / n\n",
"    c, w, indices = caratheodory(p, u)\n",
"    # Scaling row i by sqrt(n * w_i) makes s.T @ s the w-weighted sum of the\n",
"    # outer products times n, i.e. a.T @ a.\n",
"    return np.sqrt(n * w[indices])[:, None] * a[indices]\n",
"\n",
"def fast_caratheodory_matrix(a, k):\n",
"    \"\"\"Build a small matrix s with (approximately) s.T @ s == a.T @ a, using clustering.\n",
"\n",
"    Every row a_i is mapped to its flattened outer product a_i a_i^T, the\n",
"    resulting points are reduced with fast_caratheodory() under uniform\n",
"    weights, and each selected row is rescaled by sqrt(n * weight).\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    a : array_like of shape (n, d)\n",
"        Input matrix.\n",
"    k : int\n",
"        Number of clusters per round, passed through to fast_caratheodory().\n",
"\n",
"    Returns\n",
"    -------\n",
"    s : ndarray of shape (m, d)\n",
"        Rescaled selected rows of a, one per index returned by\n",
"        fast_caratheodory().\n",
"    \"\"\"\n",
"    a = np.asarray(a, dtype=float)\n",
"    n, d = a.shape\n",
"    # Row i of p is the flattened d x d outer product of a[i] with itself.\n",
"    p = np.einsum('ij,ik->ijk', a, a).reshape(n, d * d)\n",
"    u = np.ones(n) / n\n",
"    c, w, indices = fast_caratheodory(p, u, k)\n",
"    # Scaling row i by sqrt(n * w_i) makes s.T @ s the w-weighted sum of the\n",
"    # outer products times n, i.e. a.T @ a.\n",
"    return np.sqrt(n * w[indices])[:, None] * a[indices]\n",
"\n"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "aQSUiQ-hah_w",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 238
},
"outputId": "85d8402a-b445-4c9d-9fc5-c2771bdbf22c"
},
"source": [
"mat = np.random.random((1000, 3))\n",
"print(mat.T @ mat)\n",
"%time mat_ = caratheodory_matrix(mat)\n",
"print(mat_.T @ mat_)\n",
"# The compressed matrix must have at most d*d + 1 rows and reproduce the\n",
"# Gram matrix of the full data.\n",
"assert mat_.shape[0] <= mat.shape[1] ** 2 + 1\n",
"np.testing.assert_allclose(mat_.T @ mat_, mat.T @ mat, rtol=1e-6)\n",
"%time mat__ = fast_caratheodory_matrix(mat, 100)\n",
"print(mat__.T @ mat__)\n",
"assert mat__.shape[0] <= mat.shape[1] ** 2 + 1\n",
"np.testing.assert_allclose(mat__.T @ mat__, mat.T @ mat, rtol=1e-6)"
],
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"text": [
"[[334.36317998 251.34114435 259.17179814]\n",
" [251.34114435 338.5228183 255.16522369]\n",
" [259.17179814 255.16522369 345.22660839]]\n",
"CPU times: user 6.92 s, sys: 1.82 s, total: 8.74 s\n",
"Wall time: 4.45 s\n",
"[[334.36317998 251.34114435 259.17179814]\n",
" [251.34114435 338.5228183 255.16522369]\n",
" [259.17179814 255.16522369 345.22660839]]\n",
"CPU times: user 140 ms, sys: 87.1 ms, total: 227 ms\n",
"Wall time: 114 ms\n",
"[[334.36317998 251.34114435 259.17179814]\n",
" [251.34114435 338.5228183 255.16522369]\n",
" [259.17179814 255.16522369 345.22660839]]\n"
],
"name": "stdout"
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment