Created
September 4, 2023 07:14
-
-
Save tatsy/c82f2e02cb71181f0df1417356766c47 to your computer and use it in GitHub Desktop.
SciPy_LLT_LDLT.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"authorship_tag": "ABX9TyO2MLi+LBWHGp6wD5Hi5Sd2", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/tatsy/c82f2e02cb71181f0df1417356766c47/scipy_llt_ldlt.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "sROE59Re_iO1", | |
"outputId": "74868632-1f3d-4a6d-83e5-ea8c35e13d88" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (1.23.5)\n", | |
"Requirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (1.10.1)\n" | |
] | |
} | |
], | |
"source": [ | |
"%pip install numpy scipy" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import numpy as np\n", | |
"import scipy as sp\n", | |
"# Import the submodule explicitly: older SciPy releases do not expose\n", | |
"# sp.linalg after a bare `import scipy`.\n", | |
"import scipy.linalg" | |
], | |
"metadata": { | |
"id": "Gl3W8Yw5_mSn" | |
}, | |
"execution_count": 2, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# M has rank 2 (its rows are in arithmetic progression), so the Gram matrix\n", | |
"# A = M^T M is symmetric positive semi-definite and exactly singular.\n", | |
"M = np.arange(1, 10).reshape((3, 3))\n", | |
"A = M.T @ M" | |
], | |
"metadata": { | |
"id": "8DJem__k_r54" | |
}, | |
"execution_count": 3, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"np.linalg.det(A)  # A is exactly singular (rank 2); the tiny nonzero value is floating-point round-off" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "2jo4QPRb_0Dn", | |
"outputId": "fbfa8a1c-ded4-469c-f6d7-93ddc5cd012d" | |
}, | |
"execution_count": 4, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"-1.5347723092418215e-12" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 4 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# bb is a row of np.arange(1, 10).reshape((3, 3)), so it lies in the range of A\n", | |
"# and A @ x = bb is consistent (solvable) even though A is singular.\n", | |
"bb = np.arange(1, 4)" | |
], | |
"metadata": { | |
"id": "7DK1iVdg_1-m" | |
}, | |
"execution_count": 5, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Direct solution with `np.linalg.solve`" | |
], | |
"metadata": { | |
"id": "LoFaclevAM_5" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# A is singular, but LU with partial pivoting still returns an answer because\n", | |
"# round-off leaves no exactly-zero pivot; it is one of infinitely many solutions.\n", | |
"xx = np.linalg.solve(A, bb)\n", | |
"print(xx)\n", | |
"print(\"residue =\", np.square(A @ xx - bb).sum())" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "PeUHO5vO_33O", | |
"outputId": "e60504d4-a44c-41ba-a5dc-a60ab6022c7b" | |
}, | |
"execution_count": 6, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"[-6.66666667e-01 -2.22044605e-15 5.00000000e-01]\n", | |
"residue = 1.0097419586828951e-28\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Cholesky decomposition (LLT)" | |
], | |
"metadata": { | |
"id": "g5XX7po3AQ-L" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"class LLT:\n", | |
"    \"\"\"Solve A @ x = bb with the Cholesky factorization A = L L^T.\n", | |
"\n", | |
"    Cholesky exists only for symmetric (Hermitian) positive-definite A. If it\n", | |
"    fails, an error is printed, `chol` is None, the exception is kept in\n", | |
"    `error`, and `solve` returns None.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self, A):\n", | |
"        self.A = A\n", | |
"        self.error = None\n", | |
"        try:\n", | |
"            # (factor, lower) pair in the form cho_solve expects.\n", | |
"            self.chol = sp.linalg.cho_factor(A)\n", | |
"        except np.linalg.LinAlgError as exc:\n", | |
"            # Raised when A is not positive definite (e.g. singular or indefinite).\n", | |
"            # Catch only this so unrelated errors are not misreported as a\n", | |
"            # factorization failure.\n", | |
"            self.chol = None\n", | |
"            self.error = exc\n", | |
"            print(\"[ERROR] Factorization failed!!\")\n", | |
"\n", | |
"    def solve(self, bb):\n", | |
"        \"\"\"Return x such that A @ x == bb, or None if factorization failed.\"\"\"\n", | |
"        if self.chol is None:\n", | |
"            print(\"[ERROR] Cannot apply LLT method for singular matrix!!\")\n", | |
"            return None\n", | |
"\n", | |
"        return sp.linalg.cho_solve(self.chol, bb)" | |
], | |
"metadata": { | |
"id": "6vJ72eyJCCkf" | |
}, | |
"execution_count": 7, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Expected to fail here: A is singular, hence not positive definite, so\n", | |
"# the Cholesky factorization breaks down and solve() returns None.\n", | |
"llt = LLT(A)\n", | |
"xx = llt.solve(bb)\n", | |
"if xx is None:\n", | |
"    print(\"Linear system could not be solved!!\")\n", | |
"else:\n", | |
"    print(xx)\n", | |
"    print(\"residue =\", ((A @ xx - bb)**2.0).sum())" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "tRZ7N3mC_8s3", | |
"outputId": "991e136f-4343-4625-9c71-c8d86828bf3b" | |
}, | |
"execution_count": 8, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"[ERROR] Factorization failed!!\n", | |
"[ERROR] Cannot apply LLT method for singular matrix!!\n", | |
"Linear system could not be solved!!\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## LDLT solver" | |
], | |
"metadata": { | |
"id": "4JkoccR_AuiA" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"class LDLT:\n", | |
"    \"\"\"Solve A @ x = bb with the Bunch-Kaufman factorization A = L D L^H.\n", | |
"\n", | |
"    Works for symmetric (Hermitian) indefinite A, where Cholesky fails.\n", | |
"    `L`, `D`, `P` are exactly what scipy.linalg.ldl returns: `L[P]` is lower\n", | |
"    triangular and `D` is block diagonal with 1x1 and 2x2 blocks.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self, A):\n", | |
"        self.A = A\n", | |
"        self.L, self.D, self.P = sp.linalg.ldl(A)\n", | |
"        # L[P] is the genuinely triangular factor; cache it for every solve().\n", | |
"        self._T = self.L[self.P, :]\n", | |
"        # Inverse permutation, used to undo the row reordering at the end of solve().\n", | |
"        self._P_inv = np.argsort(self.P)\n", | |
"\n", | |
"    def _solve_block_diagonal(self, y):\n", | |
"        \"\"\"Solve D @ z = y block by block (1x1 and 2x2 pivots).\"\"\"\n", | |
"        n = self.D.shape[0]\n", | |
"        z = np.empty_like(y, dtype=np.result_type(y, self.D))\n", | |
"        i = 0\n", | |
"        while i < n:\n", | |
"            if i + 1 < n and self.D[i + 1, i] != 0:\n", | |
"                # 2x2 pivot block: it couples two unknowns, so it must be solved, not divided.\n", | |
"                blk = slice(i, i + 2)\n", | |
"                z[blk] = np.linalg.solve(self.D[blk, blk], y[blk])\n", | |
"                i += 2\n", | |
"            else:\n", | |
"                z[i] = y[i] / self.D[i, i]\n", | |
"                i += 1\n", | |
"        return z\n", | |
"\n", | |
"    def solve(self, bb):\n", | |
"        \"\"\"Return x such that A @ x == bb; bb may have shape (n,) or (n, k).\"\"\"\n", | |
"        bb = np.asarray(bb)\n", | |
"        # A = L D L^H with T = L[P], so T D T^H x[P] = bb[P].\n", | |
"        y = sp.linalg.solve_triangular(self._T, bb[self.P], lower=True)\n", | |
"        z = self._solve_block_diagonal(y)\n", | |
"        # trans=\"C\" is the conjugate transpose; identical to \"T\" for real A.\n", | |
"        w = sp.linalg.solve_triangular(self._T, z, lower=True, trans=\"C\")\n", | |
"        # w == x[P] (x in permuted order); scatter it back with the inverse permutation.\n", | |
"        return w[self._P_inv]" | |
], | |
"metadata": { | |
"id": "vTgEPk7JAYe-" | |
}, | |
"execution_count": 9, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# A is singular, so xx is just one of infinitely many solutions; it differs\n", | |
"# from the np.linalg.solve result above although both have ~0 residue.\n", | |
"ldlt = LDLT(A)\n", | |
"xx = ldlt.solve(bb)\n", | |
"print(xx)\n", | |
"print(\"residue =\", np.square(A @ xx - bb).sum())" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "rJb6X1UeBL32", | |
"outputId": "15be3144-d0de-4f3e-d111-3818432c80bd" | |
}, | |
"execution_count": 10, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"[-1.15151515 0.96969697 0.01515152]\n", | |
"residue = 3.6070664891230765e-28\n" | |
] | |
} | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment