Skip to content

Instantly share code, notes, and snippets.

@tma15
Created January 23, 2016 07:22
Show Gist options
  • Save tma15/bc2e556ca79f055bf3cb to your computer and use it in GitHub Desktop.
Save tma15/bc2e556ca79f055bf3cb to your computer and use it in GitHub Desktop.
package lu
import (
"fmt"
)
// Matrix is a dense matrix of float64 values stored as a slice of rows:
// element (i, j) is addressed as m[i][j].
type Matrix [][]float64
// NewMatrix returns an empty, non-nil Matrix with no rows.
func NewMatrix() Matrix {
	return make(Matrix, 0)
}
// Resize grows the matrix so that it has at least row rows and every row has
// at least col columns. New elements are zero and existing values are kept.
//
// Resize never shrinks: rows beyond row and columns beyond col are left in
// place. A non-positive col adds empty rows.
func (m *Matrix) Resize(row, col int) {
	// Widen the rows that already exist; without this, growing a matrix
	// would leave the old rows shorter than the new ones.
	for i := range *m {
		if missing := col - len((*m)[i]); missing > 0 {
			(*m)[i] = append((*m)[i], make([]float64, missing)...)
		}
	}

	width := col
	if width < 0 {
		width = 0
	}
	for i := len(*m); i < row; i++ {
		*m = append(*m, make([]float64, width))
	}
}
// Fill overwrites every element of the matrix with n. The shape of the matrix
// is left unchanged.
func (m *Matrix) Fill(n float64) {
	for _, row := range *m {
		for j := range row {
			row[j] = n
		}
	}
}
// RowSize returns the number of rows in the matrix.
func (m *Matrix) RowSize() int {
	rows := *m
	return len(rows)
}
// ColSize returns the number of columns, measured as the length of the first
// row. A matrix with no rows has no first row and reports zero columns.
func (m *Matrix) ColSize() int {
	if len(*m) == 0 {
		return 0
	}
	return len((*m)[0])
}
// Mul returns the matrix product of the receiver and other as a newly
// allocated matrix; neither operand is modified.
//
// Column counts are taken from the first row of each operand, and an operand
// with no rows counts as having zero columns. Rows are assumed to all have
// that same length. Mul panics with the value "error" when the receiver's
// column count differs from other's row count.
func (m *Matrix) Mul(other *Matrix) Matrix {
	a, b := *m, *other

	inner := 0
	if len(a) > 0 {
		inner = len(a[0])
	}
	if inner != len(b) {
		panic("error")
	}
	cols := 0
	if len(b) > 0 {
		cols = len(b[0])
	}

	out := make(Matrix, len(a))
	for i := range out {
		out[i] = make([]float64, cols)
		// The result has as many columns as other, so j ranges over cols;
		// only the summation index k ranges over the shared dimension.
		for j := 0; j < cols; j++ {
			for k := 0; k < inner; k++ {
				out[i][j] += a[i][k] * b[k][j]
			}
		}
	}
	return out
}
// Det returns the determinant of the matrix, computed from the factorization
// returned by LU: the product of the diagonal entries of L and of U, negated
// when the reported number of row swaps is odd.
func (m *Matrix) Det() float64 {
	_, lower, upper, swaps := LU(*m)

	product := 1.0
	for i := range lower {
		product *= lower[i][i]
	}
	for i := range lower {
		product *= upper[i][i]
	}

	// An even number of row swaps leaves the sign unchanged.
	if swaps%2 == 0 {
		return product
	}
	return -product
}
// Print writes the matrix to standard output, one row per line. Each element
// is formatted with one decimal place and followed by ", ", including the
// last element of a row.
func (m *Matrix) Print() {
	for _, row := range *m {
		for _, v := range row {
			// Format directly: routing the text through Sprintf and then
			// using it as a Printf format string would re-interpret it.
			fmt.Printf("%.1f, ", v)
		}
		fmt.Println()
	}
}
// LU factors the square matrix m into P, L and U such that P*m = L*U, using
// Gaussian elimination with partial (row) pivoting.
//
// The return values are, in order:
//   - P, a permutation matrix: row i of P*m is the input row that pivoting
//     moved to position i;
//   - L, lower triangular with ones on the diagonal and the elimination
//     multipliers below it;
//   - U, upper triangular;
//   - the number of row exchanges performed (an odd count means det(P) = -1).
//
// m itself is not modified. An empty matrix yields three empty matrices and a
// swap count of zero. LU panics if m is not square. When a column has no
// non-zero pivot candidate the matrix is singular; that column is skipped, so
// the matching diagonal entry of U is zero instead of NaN or Inf.
func LU(m Matrix) (Matrix, Matrix, Matrix, int) {
	n := len(m)
	for _, row := range m {
		if len(row) != n {
			panic("lu: matrix must be square")
		}
	}

	// work is a private copy of m that is reduced in place; perm[i] records
	// which input row currently sits at position i.
	work := make(Matrix, n)
	perm := make([]int, n)
	for i, row := range m {
		work[i] = append(make([]float64, 0, n), row...)
		perm[i] = i
	}

	numSwap := 0
	for col := 0; col < n; col++ {
		// Pick the row at or below the diagonal whose entry in this column
		// has the largest magnitude. Comparing magnitudes (not signed
		// values) keeps negative-only columns usable as pivots.
		pivot, best := col, 0.0
		for row := col; row < n; row++ {
			v := work[row][col]
			if v < 0 {
				v = -v
			}
			if v > best {
				pivot, best = row, v
			}
		}
		if best == 0 {
			// Nothing to eliminate: every candidate is zero.
			continue
		}
		if pivot != col {
			work[col], work[pivot] = work[pivot], work[col]
			perm[col], perm[pivot] = perm[pivot], perm[col]
			numSwap++
		}

		// Eliminate below the pivot, storing each multiplier in the slot it
		// zeroes so that work ends up holding L (below the diagonal) and U.
		for row := col + 1; row < n; row++ {
			factor := work[row][col] / work[col][col]
			work[row][col] = factor
			for k := col + 1; k < n; k++ {
				work[row][k] -= factor * work[col][k]
			}
		}
	}

	p := make(Matrix, n)
	l := make(Matrix, n)
	u := make(Matrix, n)
	for i := 0; i < n; i++ {
		p[i] = make([]float64, n)
		l[i] = make([]float64, n)
		u[i] = make([]float64, n)

		p[i][perm[i]] = 1
		l[i][i] = 1
		copy(l[i][:i], work[i][:i])
		copy(u[i][i:], work[i][i:])
	}
	return p, l, u, numSwap
}
package lu
import (
"fmt"
"testing"
)
// TestMatrix checks that Resize produces the requested shape, both for a fresh
// matrix and when rows are added to an existing one, as reported by RowSize
// and ColSize.
func TestMatrix(t *testing.T) {
	a := Matrix{}

	a.Resize(1, 2)
	a.Fill(0.)
	if got := a.RowSize(); got != 1 {
		t.Errorf("after Resize(1, 2): RowSize() = %d, want 1", got)
	}
	if got := a.ColSize(); got != 2 {
		t.Errorf("after Resize(1, 2): ColSize() = %d, want 2", got)
	}

	a.Resize(3, 2)
	a.Fill(1.)
	if got := a.RowSize(); got != 3 {
		t.Errorf("after Resize(3, 2): RowSize() = %d, want 3", got)
	}
	if got := a.ColSize(); got != 2 {
		t.Errorf("after Resize(3, 2): ColSize() = %d, want 2", got)
	}
}
// TestLU checks that the factors returned by LU satisfy P*A = L*U for a 3x3
// matrix, and that Det returns the expected determinant for two 2x2 matrices.
func TestLU(t *testing.T) {
	a := Matrix{
		{1., 1., 2.},
		{2., 1., 1.},
		{3., 2., 1.},
	}
	p, l, u, _ := LU(a)
	// Apply the row permutation to a, then compare with the product L*U.
	permuted := p.Mul(&a)
	product := l.Mul(&u)
	for i := 0; i < a.RowSize(); i++ {
		for j := 0; j < a.ColSize(); j++ {
			if Abs(permuted[i][j]-product[i][j]) > 1e-10 {
				t.Errorf("(P*A)[%d][%d] = %.32f, (L*U)[%d][%d] = %.32f",
					i, j, permuted[i][j], i, j, product[i][j])
			}
		}
	}

	dets := []struct {
		m    Matrix
		want float64
	}{
		{Matrix{{1., 2.}, {3., 4.}}, -2.},
		{Matrix{{3., 5.}, {2., 8.}}, 14.},
	}
	for i, tc := range dets {
		t.Run(fmt.Sprintf("det_%d", i), func(t *testing.T) {
			// Compare with a tolerance: the result comes out of
			// floating-point elimination, so exact equality is fragile.
			if got := tc.m.Det(); Abs(got-tc.want) > 1e-10 {
				t.Errorf("Det() = %f, want %f", got, tc.want)
			}
		})
	}
}
// Abs returns x with its sign flipped when x is less than zero, and x itself
// otherwise.
func Abs(x float64) float64 {
	magnitude := x
	if magnitude < 0. {
		magnitude = -magnitude
	}
	return magnitude
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment