Skip to content

Instantly share code, notes, and snippets.

@oxinabox
Created March 11, 2020 11:52
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save oxinabox/1674e5b8e0e86cfa80535ba615ac545f to your computer and use it in GitHub Desktop.
Save oxinabox/1674e5b8e0e86cfa80535ba615ac545f to your computer and use it in GitHub Desktop.
Nonnegative Matrix Factorization
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[32m\u001b[1mActivating\u001b[22m\u001b[39m environment at `~/Documents/misccode/NNMF/Convex/Project.toml`\n"
]
}
],
"source": [
"using Pkg: @pkg_str\n",
"pkg\"activate .\"\n",
"# Use EricPHanson's branch, as it has updated packages and its Compat set up correctly\n",
"#pkg\"add Plots Distributions Convex ECOS SCS https://github.com/ericphanson/SequentialConvexRelaxation.jl.git\""
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"┌ Info: Precompiling Convex [f65535da-76fb-5f13-bab9-19810c17039a]\n",
"└ @ Base loading.jl:1273\n",
"┌ Info: Precompiling SequentialConvexRelaxation [154aca32-cc25-11e9-1929-3b69ed34445f]\n",
"└ @ Base loading.jl:1273\n",
"┌ Info: Precompiling ECOS [e2685f51-7e38-5353-a97d-a921fd2c8199]\n",
"└ @ Base loading.jl:1273\n",
"┌ Info: Precompiling SCS [c946c3f1-0d1f-5ce8-9dea-7daa1f7e2d13]\n",
"└ @ Base loading.jl:1273\n"
]
},
{
"data": {
"text/plain": [
"Plots.GRBackend()"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"using Convex\n",
"using SequentialConvexRelaxation\n",
"using ECOS\n",
"using SCS\n",
"using Distributions\n",
"using LinearAlgebra\n",
"using Plots\n",
"gr(fmt=:png)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING: redefining constant true_cluster_assignment\n",
"WARNING: redefining constant cluster_distributions\n",
"WARNING: redefining constant data\n"
]
},
{
"data": {
"text/plain": [
"5×10 Array{Float64,2}:\n",
" 101.024 101.004 101.03 101.023 101.024 … 98.9384 99.9862 99.9652\n",
" 101.058 100.993 101.013 100.972 101.038 98.9286 99.9543 100.056 \n",
" 100.949 101.003 100.992 101.022 101.065 98.9416 99.9211 99.9866\n",
" 101.005 100.968 101.0 101.014 100.985 99.0079 100.023 100.029 \n",
" 101.027 100.97 100.978 100.954 100.993 99.0447 100.02 99.9276"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"const true_cluster_assignment = [fill(1, 5); fill(2, 3); fill(3, 2)]\n",
"const ndims = 5\n",
"const cluster_distributions = [ # Use IsoNormal (covariance ∝ I) because only circular distributions are allowed\n",
" MultivariateNormal(fill(101, ndims), 0.001I),\n",
" MultivariateNormal(fill(99, ndims), 0.002I),\n",
" MultivariateNormal(fill(100, ndims), 0.003I), \n",
"]\n",
"\n",
"# features × observations\n",
"const data = hcat((rand(cluster_distributions[cl]) for cl in true_cluster_assignment)...)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"image/png": ""
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"scatter(data[1, :], data[2, :]; zcolor = true_cluster_assignment, legend=:bottomleft)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"nnmf"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"\"\"\"\n",
" nnmf(V, latent_size)\n",
"\n",
"### Args\n",
" - `V` the data, a matrix of features × observations\n",
" - `latent_size` the number of latent variables. For clustering this is the number of clusters.\n",
"\n",
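"Returns `((W=W, H=H), bilinear_problem)`: `W` and `H` are Convex `Variable`s constrained to be nonnegative,\n",
"whose `.value` matrices satisfy ``W.value * H.value ≈ V``; `bilinear_problem` is the solved `BilinearProblem`.\n",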
"\"\"\"\n",
"function nnmf(V, latent_size)\n",
" all(V.>=0) || throw(ArgumentError(\"V contains negative entries\"))\n",
" \n",
" W = Variable(size(V, 1), latent_size)\n",
" H = Variable(latent_size, size(V, 2))\n",
" \n",
" # The main goal, making the residual `norm(W*H - V)` small, is handled by the BilinearConstraint,\n",
" # so the convex problem's objective is just a dummy constant.\n",
" problem = minimize(\n",
" 1,\n",
" [\n",
" W>=0,\n",
" H>=0,\n",
" ]\n",
" )\n",
" \n",
" # Constructing the bilinear equality constraint\n",
" # In general this is of the form A*P*B=C.\n",
" # Here A=W, P=I, B=H, C=V.\n",
" eye = Matrix(I, latent_size, latent_size)\n",
" reconstruct_constraint= BilinearConstraint(W, eye, H, V, λ=0.1)\n",
" eye_orth_inner = Matrix(I, size(H,2), size(H,2))\n",
" orthogonality_constraint= BilinearConstraint(H, eye_orth_inner, H', eye, λ=0.1)\n",
"\n",
" # The bilinear problem is the convex problem with the bilinear constraint\n",
" bilinear_problem = BilinearProblem(\n",
" problem,\n",
" [\n",
" reconstruct_constraint,\n",
" # orthogonality_constraint\n",
" ],\n",
" )\n",
" res = solve!(bilinear_problem, () -> SCS.Optimizer(verbose=0), iterations=10)\n",
" (W=W, H=H), bilinear_problem\n",
"end"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"┌ Warning: Bilinear constraint 1 not satisfied (gap = 0.11234117682821705)\n",
"└ @ SequentialConvexRelaxation /Users/oxinabox/.julia/packages/SequentialConvexRelaxation/9jc2O/src/SequentialConvexRelaxation.jl:408\n"
]
},
{
"data": {
"text/plain": [
"((W = Variable\n",
"size: (5, 3)\n",
"sign: real\n",
"vexity: affine\n",
"id: 140…734\n",
"value: [0.1570462845229664 11.417119780254275 1.9790771743623983; 0.021141235108979976 11.299729854032721 2.297251144320985; 0.023640789088924436 11.484763948519115 1.844308740524739; 0.2620201397676304 11.303159303534848 2.2257761979127215; 0.47901684406437545 11.362428001936705 2.0262235624224973], H = Variable\n",
"size: (3, 10)\n",
"sign: real\n",
"vexity: affine\n",
"id: 114…020\n",
"value: [0.9090761727022365 0.8018998057589781 0.8156148196829522 0.793127657006549 0.7303882764054443 0.9560068400380448 0.8835317223972944 1.108916235520905 1.0571211581975082 0.689856014912743; 8.22145778191618 8.255554482564627 8.245501346939792 8.260128763462827 8.265130415875387 8.132590595705121 8.0724488049853 8.075654836102757 8.147332052116685 8.13525181295928; 3.5356276815757015 3.3466106276014056 3.4074959247868417 3.3267055113350907 3.316543654963961 3.048534318704198 3.386482639601667 3.333416861023218 3.4313193606284322 3.536654868595232]), BilinearProblem(minimize\n",
"└─ 1\n",
"subject to\n",
"├─ >= constraint (affine)\n",
"│ ├─ 5×3 real variable (id: 140…734)\n",
"│ └─ 0\n",
"└─ >= constraint (affine)\n",
" ├─ 3×10 real variable (id: 114…020)\n",
" └─ 0\n",
"\n",
"termination status: OPTIMAL\n",
"primal status: FEASIBLE_POINT\n",
"dual status: FEASIBLE_POINT, BilinearConstraint[BilinearConstraint(Variable\n",
"size: (5, 3)\n",
"sign: real\n",
"vexity: affine\n",
"id: 140…734\n",
"value: [0.1570462845229664 11.417119780254275 1.9790771743623983; 0.021141235108979976 11.299729854032721 2.297251144320985; 0.023640789088924436 11.484763948519115 1.844308740524739; 0.2620201397676304 11.303159303534848 2.2257761979127215; 0.47901684406437545 11.362428001936705 2.0262235624224973], Bool[1 0 0; 0 1 0; 0 0 1], Variable\n",
"size: (3, 10)\n",
"sign: real\n",
"vexity: affine\n",
"id: 114…020\n",
"value: [0.9090761727022365 0.8018998057589781 0.8156148196829522 0.793127657006549 0.7303882764054443 0.9560068400380448 0.8835317223972944 1.108916235520905 1.0571211581975082 0.689856014912743; 8.22145778191618 8.255554482564627 8.245501346939792 8.260128763462827 8.265130415875387 8.132590595705121 8.0724488049853 8.075654836102757 8.147332052116685 8.13525181295928; 3.5356276815757015 3.3466106276014056 3.4074959247868417 3.3267055113350907 3.316543654963961 3.048534318704198 3.386482639601667 3.333416861023218 3.4313193606284322 3.536654868595232], [101.02364745133809 101.00418121769547 … 99.98623365604394 99.96517226835647; 101.05839042646836 100.99314613005879 … 99.95431490867371 100.05551695615708; … ; 101.00457311317994 100.96845799149115 … 100.02296898927675 100.0289186610699; 101.02662908998043 100.9704462880352 … 100.01970963605815 99.92763650541038], Variable\n",
"size: (5, 3)\n",
"sign: real\n",
"vexity: constant\n",
"id: 988…862\n",
"value: [-0.1570462845229664 -11.417119780254275 -1.9790771743623983; -0.021141235108979976 -11.299729854032721 -2.297251144320985; -0.023640789088924436 -11.484763948519115 -1.844308740524739; -0.2620201397676304 -11.303159303534848 -2.2257761979127215; -0.47901684406437545 -11.362428001936705 -2.0262235624224973], Variable\n",
"size: (3, 10)\n",
"sign: real\n",
"vexity: constant\n",
"id: 114…974\n",
"value: [-0.9090761727022365 -0.8018998057589781 -0.8156148196829522 -0.793127657006549 -0.7303882764054443 -0.9560068400380448 -0.8835317223972944 -1.108916235520905 -1.0571211581975082 -0.689856014912743; -8.22145778191618 -8.255554482564627 -8.245501346939792 -8.260128763462827 -8.265130415875387 -8.132590595705121 -8.0724488049853 -8.075654836102757 -8.147332052116685 -8.13525181295928; -3.5356276815757015 -3.3466106276014056 -3.4074959247868417 -3.3267055113350907 -3.316543654963961 -3.048534318704198 -3.386482639601667 -3.333416861023218 -3.4313193606284322 -3.536654868595232], Variable\n",
"size: (5, 5)\n",
"sign: real\n",
"vexity: constant\n",
"id: 169…881\n",
"value: [1.0 0.0 0.0 0.0 0.0; 0.0 1.0 0.0 0.0 0.0; 0.0 0.0 1.0 0.0 0.0; 0.0 0.0 0.0 1.0 0.0; 0.0 0.0 0.0 0.0 1.0], Variable\n",
"size: (10, 10)\n",
"sign: real\n",
"vexity: constant\n",
"id: 624…272\n",
"value: [1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0], 0.1)], SequentialConvexRelaxation.Result(10, false, AbstractFloat[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], Array{AbstractFloat,1}[[0.47002954935838154], [0.13459561941770679], [0.12111592265281247], [0.11550643127095846], [0.11349153619789144], [0.11283465517897813], [0.11252425804889135], [0.11240977173745709], [0.11236063100153294], [0.11234117682821705]], Any[])))"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"res, bilinear_problem = nnmf(data, 3)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.11234117682821705"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"norm(res.W.value * res.H.value .- data)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"5×10 Array{Float64,2}:\n",
" -0.0182321 -0.000390748 -0.0182582 … -0.010305 0.0236244 \n",
" -0.0166948 -0.00265244 0.00409316 0.0132875 0.00979968\n",
" 0.0147802 0.000936427 0.00941258 0.0024681 -0.016155 \n",
" 0.0315861 0.00430964 -0.00136609 -0.0180409 -0.022313 \n",
" -0.0114721 -0.00219804 0.00604102 0.0127632 0.00508257"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(res.W.value * res.H.value .- data)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"5×3 Array{Float64,2}:\n",
" 0.157046 11.4171 1.97908\n",
" 0.0211412 11.2997 2.29725\n",
" 0.0236408 11.4848 1.84431\n",
" 0.26202 11.3032 2.22578\n",
" 0.479017 11.3624 2.02622"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"res.W.value"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"3×10 Array{Float64,2}:\n",
" 0.909076 0.8019 0.815615 0.793128 … 1.10892 1.05712 0.689856\n",
" 8.22146 8.25555 8.2455 8.26013 8.07565 8.14733 8.13525 \n",
" 3.53563 3.34661 3.4075 3.32671 3.33342 3.43132 3.53665 "
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"res.H.value"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.3"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
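"# NOTE: NNMF latent indices come out in arbitrary order, so comparing them directly to the true labels\n",
"# is not permutation-invariant; this accuracy may understate the clustering quality.\n",
"predicted_cluster_assignments = vec(first.(Tuple.(findmax(res.H.value, dims=1)[end])))\n",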
"mean(predicted_cluster_assignments .== true_cluster_assignment)"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"data": {
"image/png": ""
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"scatter(data[1, :], data[2, :]; zcolor = true_cluster_assignment)"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"data": {
"image/png": ""
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"recon = res.W.value * res.H.value \n",
"scatter!(recon[1, : ], recon[2, :]; zcolor=predicted_cluster_assignments, marker=:star)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"@webio": {
"lastCommId": null,
"lastKernelId": null
},
"kernelspec": {
"display_name": "Julia 1.3.2-pre",
"language": "julia",
"name": "julia-1.3"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.3.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment