Created
January 23, 2016 07:22
-
-
Save tma15/bc2e556ca79f055bf3cb to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package lu | |
import ( | |
"fmt" | |
) | |
// Matrix is a dense matrix stored row by row: m[i][j] is the element in row
// i, column j. Methods such as ColSize read the width from the first row, so
// all rows are expected to have the same length.
type Matrix [][]float64

// NewMatrix returns an empty (0×0), non-nil Matrix; give it a shape with
// Resize.
func NewMatrix() Matrix {
	return make(Matrix, 0)
}
func (this *Matrix) Resize(row, col int) { | |
for i := len(*this); i < row; i++ { | |
*this = append(*this, make([]float64, 0, 10)) | |
for j := len((*this)[i]); j < col; j++ { | |
(*this)[i] = append((*this)[i], 0.) | |
} | |
} | |
} | |
func (this *Matrix) Fill(n float64) { | |
for i := 0; i < len(*this); i++ { | |
for j := 0; j < len((*this)[i]); j++ { | |
(*this)[i][j] = n | |
} | |
} | |
} | |
func (this *Matrix) RowSize() int { | |
return len(*this) | |
} | |
func (this *Matrix) ColSize() int { | |
return len((*this)[0]) | |
} | |
func (this *Matrix) Mul(other *Matrix) Matrix { | |
rThis := this.RowSize() | |
cThis := this.ColSize() | |
rOther := other.RowSize() | |
cOther := other.ColSize() | |
if cThis != rOther { | |
panic("error") | |
} | |
n := NewMatrix() | |
n.Resize(rThis, cOther) | |
for i := 0; i < rThis; i++ { | |
for j := 0; j < cThis; j++ { | |
for k := 0; k < rOther; k++ { | |
n[i][j] += (*this)[i][k] * (*other)[k][j] | |
} | |
} | |
} | |
return n | |
} | |
func (this *Matrix) Det() float64 { | |
_, l, u, numSwap := LU(*this) | |
det := 1. | |
for i := 0; i < l.RowSize(); i++ { | |
det *= l[i][i] | |
} | |
for i := 0; i < l.RowSize(); i++ { | |
det *= u[i][i] | |
} | |
if numSwap%2 != 0 { | |
return -det | |
} | |
return det | |
} | |
func (this *Matrix) Print() { | |
for i := 0; i < len(*this); i++ { | |
for j := 0; j < len((*this)[i]); j++ { | |
fmt.Printf(fmt.Sprintf("%.1f, ", (*this)[i][j])) | |
} | |
fmt.Println("") | |
} | |
} | |
func LU(m Matrix) (Matrix, Matrix, Matrix, int) { | |
r := m.RowSize() | |
c := m.ColSize() | |
l := NewMatrix() | |
l.Resize(r, c) | |
u := NewMatrix() | |
u.Resize(r, c) | |
p := NewMatrix() | |
p.Resize(r, c) | |
for i := 0; i < r; i++ { | |
p[i][i] = 1. | |
} | |
/* 各行の最大値が対角成分になるようにmの行を並び替える */ | |
numSwap := 0 | |
for i := 0; i < r; i++ { | |
max := -1. | |
argmax := -1 | |
/* i行において、最大の値になる列を探す */ | |
for j := i; j < c; j++ { | |
if m[i][j] > max { | |
max = m[i][j] | |
argmax = j | |
} | |
} | |
if i != argmax { | |
p[i], p[argmax] = p[argmax], p[i] | |
numSwap += 1 | |
} | |
} | |
m = p.Mul(&m) | |
for i := 0; i < r; i++ { | |
l[i][i] = 1. | |
} | |
for j := 0; j < c; j++ { | |
/* Uを計算する */ | |
/* U[i][j] = m[i][j] - sum */ | |
/* sum = L[i][s] * U[s][j]; ただし、sは0<=s<=r */ | |
for i := 0; i <= j; i++ { /* Uは列番号以下の行番号のみ計算する */ | |
sum := 0. | |
for k := 0; k < r; k++ { | |
sum += l[i][k] * u[k][j] | |
} | |
u[i][j] = m[i][j] - sum | |
// fmt.Println(fmt.Sprintf("u[%d][%d] = %f - %f", i, j, m[i][j], sum)) | |
} | |
/* Lを計算する */ | |
/* L[i][j] = (m[i][j] - sum) / U[j][j] */ | |
/* sum = L[i][s] * U[s][j]; ただし、sは0<=s<=r */ | |
for i := j + 1; i < r; i++ { /* Lは列番号より大きな行番号のみ計算する */ | |
sum := 0. | |
for k := 0; k < r; k++ { | |
sum += l[i][k] * u[k][j] | |
} | |
l[i][j] = (m[i][j] - sum) / u[j][j] /* ピボット選択しないと都合が悪いことがある */ | |
// fmt.Println(fmt.Sprintf("l[%d][%d] = %f - %f / %f = %f", i, j, | |
// m[i][j], sum, u[j][j], l[i][j])) | |
} | |
} | |
return p, l, u, numSwap | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package lu | |
import ( | |
"fmt" | |
"testing" | |
) | |
func TestMatrix(t *testing.T) { | |
a := Matrix{} | |
a.Resize(1, 2) | |
a.Fill(0.) | |
if a.RowSize() != 1 { | |
t.Error(fmt.Sprintf("want 1 got %d", a.RowSize())) | |
} | |
if a.ColSize() != 2 { | |
t.Error(fmt.Sprintf("want 1 got %d", a.ColSize())) | |
} | |
a.Resize(3, 2) | |
a.Fill(1.) | |
if a.RowSize() != 3 { | |
t.Error(fmt.Sprintf("want 3 got %d", a.RowSize())) | |
} | |
if a.ColSize() != 2 { | |
t.Error(fmt.Sprintf("want 1 got %d", a.ColSize())) | |
} | |
} | |
func TestLU(t *testing.T) { | |
a := Matrix{ | |
{1., 1., 2.}, | |
{2., 1., 1.}, | |
{3., 2., 1.}, | |
} | |
p, l, u, _ := LU(a) | |
/* aを並び替える */ | |
a__ := p.Mul(&a) | |
/* A = LU */ | |
a_ := l.Mul(&u) | |
for i := 0; i < a.RowSize(); i++ { | |
for j := 0; j < a.ColSize(); j++ { | |
if Abs(a__[i][j]-a_[i][j]) > 1e-10 { | |
t.Error(fmt.Sprintf("%.32f %.32f", a__[i][j], a_[i][j])) | |
} | |
} | |
} | |
b := Matrix{ | |
{1., 2.}, | |
{3., 4.}, | |
} | |
d := b.Det() | |
if d != -2. { | |
t.Error(fmt.Sprintf("want -2 got %f", d)) | |
} | |
b = Matrix{ | |
{3., 5.}, | |
{2., 8.}, | |
} | |
d = b.Det() | |
if d != 14. { | |
t.Error(fmt.Sprintf("want -2 got %f", d)) | |
} | |
} | |
// Abs returns the absolute value of x; the tests use it to compare
// floating-point results within a tolerance.
func Abs(x float64) float64 {
	switch {
	case x < 0.:
		return -x
	default:
		return x
	}
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment