Skip to content

Instantly share code, notes, and snippets.

@brianhsu
Created March 28, 2010 05:02
Show Gist options
  • Save brianhsu/346588 to your computer and use it in GitHub Desktop.
Save brianhsu/346588 to your computer and use it in GitHub Desktop.
/*================================================================================
* 這邊是 Sparse Matrix 的實作
*===============================================================================*/
class SparseMatrix
{
/**
* 這個 Sparse Matrix 我們採用 Dictionary of Keys 的作法,用一個 Map 存非零元素,
* Key 是 (row, col) 的 Tuple,而不是 Array[(col, row, value)]。
*
* 這樣的作法有以下幾個好處:
*
* - Random Access/Modify 任何一個元素都是 O(1)
* - 程式寫起來比較簡潔
*/
private var matrix: Map[(Int, Int), Int] = Map ()
/**
* 更新矩陣內特定的一個元素
*/
private def update (row: Int, col: Int, value: Int) {
matrix += (row, col) -> value
}
/**
* 結合兩個大小相同的矩陣(好用來做加減法)
*
* 這個演算法並不好,時間複雜度是 O (row * col)
*
* 這裡只是示範 Pattern Matching,其實可以用其他的方式讓時間複雜雜度降為
* O (m * n),其中 m 和 n 分別是兩個矩陣的非零元素個數。
*
* @param that 另一個矩陣
* @param func 要進行的操作
* @return 新的矩陣
*/
private def combine (that: SparseMatrix)(func: (Int, Int) => Int): SparseMatrix = {
// 確定兩個矩陣大小是一樣的
require (that != null)
require (this.row == that.row && this.col == that.col, "Matrix size is not the same")
// 計算後的結果在這個新矩陣裡
var newMatrix = new SparseMatrix
// 掃過每一個 row 和每一個 col
for (i <- 1 to row; j <- 1 to col) {
val thisValue = this.matrix.get(i,j) // 第一個矩陣的 (row, col) 的值
val thatValue = that.matrix.get(i,j) // 第二個矩陣的 (row, col) 的值
// 把 thisValue 和 thatValue 合起來比一下
(thisValue, thatValue) match {
// 如果都不是非 0 元素就不用做事
case (None, None) =>
// 如果只有 thisValue 有值,那就在新矩陣裡填進 thisValue 的值
case (Some(x), None) => newMatrix(i,j) = x
// 如果只有 thatValue 有值,就在新矩陣裡填進 thatValue 的值
case (None, Some(y)) => newMatrix(i,j) = y
// 如果兩個都有值,就用 func 這個函式來計算新的值並填到新矩陣裡
case (Some(x), Some(y)) => newMatrix(i,j) = func (x, y)
}
}
newMatrix
}
/**
* Constructor
*
* 傳進來的是一連串的 (row, col, value) Tuple,也就是非零元素的所在位至和值
*/
def this (elements: (Int, Int, Int)*) = {
this ()
// 這沒啥好說的吧?把 elements 裡每一個 Tuple 拿出來,
// 再把舉陣裡 (row, col) 的地方設成 value
for ( (row, col, value) <- elements) {
this(row, col) = value
}
}
/**
* Getter Method
*
* 這是 Scala 裡的 Getter Method,讓你可以用 matrix(i,j) 的方式取得特定的一個
* 元素,因為底層用 Map 做,所以複雜度是 O(1)
*/
def apply (row: Int, col: Int) = matrix.getOrElse ((row,col), 0)
/**
* 計算 Matrix 有幾個 col
*
* 詳細的解釋如下
* {{{
* // 先把我們 Map 裡的所有 Key,也就是 (row, col) 的 Tuple 取出來成為一個 List。
* matrix.keys.
*
* // 再將這些 Tuple 映設成只留下 (col) 的部份,也就是說假設是 List((1,1), (1, 4), (1, 2))
* // 就會變成 List(1, 4, 2)。
* map (_._2)
*
* // 最後從 List 最左邊一直往右邊比,只留下比左邊還大的數值,最後就會是整個 List 裡最大的
* // 值。
* foldLeft(0)((x, y) => if (x > y) x else y)
* }}}
*
*/
def col = matrix.keys.map (_._2).foldLeft(0)((x,y) => if (x > y ) x else y)
/**
* 計算 Matrix 有幾個 row
*
* 和算 col 的方法一樣,只是在做 map 的時候改成拉出 (row, col) 裡 (row) 的部份
*/
def row = matrix.keys.map (_._1).foldLeft(0)((x,y) => if (x > y ) x else y)
/**
* 矩陣相加
*
* 等於將兩個矩陣用加法 combine 起來
*/
def + (that: SparseMatrix): SparseMatrix = combine(that)(_ + _)
/**
* 矩陣相減
*
* 等於將兩個矩陣用減法 combine 起來
*/
def - (that: SparseMatrix): SparseMatrix = combine(that)(_ - _)
/**
* 矩陣乘法
*/
def * (that: SparseMatrix) = {
require (that != null && this.col == that.row, "These two matrix cannot be multiplied")
// 先將相乘的對象做轉置,題目就會變成 matrixA 的第一個 row 乘 matrixB 的第一個 row
// 會是新矩陣 (1, 1) 的值,matrixA 的第一個 row 乘 matrixB 的特二個 row 會是新矩陣
// 裡 (1, 2) 的值,依此類推
val transported = that.transpose
val newMatrix = new SparseMatrix
// 計算 matrixA 的某個 row 乘 matrixB 的某個 row
def multiPart (rowA: Int, rowB: Int) = {
// 先取出 matrixA 裡面 rowA 的所有非零元素,也就是說在 Map 裡面的 key 值裡的
// row == rowA
val matrixA = this.matrix.filter {x =>
val ((row, col), value) = x
row == rowA
}
// 同上,取出 matrixB 裡面所有 rowB 的非零元素,只是這裡的寫法比較簡潔,不用另
// 外取變數名稱
val matrixB = transported.matrix.filter (x => x._1._1 == rowB)
// 然後把這兩個 row 個別元素相乘的結果放進 List 中
val productList = if (matrixA.size >= matrixB.size) {
for ((key, value) <- matrixA) yield value * matrixB.getOrElse((rowB, key._2), 0)
} else {
for ((key, value) <- matrixB) yield value * matrixA.getOrElse((rowA, key._2), 0)
}
// 再將 List 加總,變成最後的值
productList.foldLeft(0)(_+_)
}
for (i <- 1 to this.row; j <- 1 to that.col) {
newMatrix(i,j) = multiPart (i, j)
}
newMatrix
}
/**
* 轉置矩陣
*/
def transpose = {
// 先生一個新矩陣
val newMatrix = new SparseMatrix
// 把舊矩陣裡的 (row, col) 對調後填到新矩陣就好了
for ( ((row, col), value) <- this.matrix) {
newMatrix(col, row) = value
}
newMatrix
}
override def toString = {
def makeRow (row: Int) = for (i <- 1 to col) yield { matrix.getOrElse((row,i), 0) }
def makeMatrix = for (i <- 1 to row) yield { makeRow(i).mkString("\t") }
makeMatrix.mkString ("\n")
}
}
object SparseMatrix
{
// 沒啥用處,只不過是從一邊的二維陣列轉換成稀疏矩陣而已
def fromMatrix (elements: Array[Array[Int]]) = {
val newMatrix = new SparseMatrix
// 看起來像是一層迴圈,其實是兩層。XD
for (i <- 0 until elements.length;
j <- 0 until elements(i).length if elements(i)(j) != 0) {
newMatrix(i+1, j+1) = elements(i)(j)
}
newMatrix
}
}
/*===============================================================================
* 下面是 Unit Test
*==============================================================================*/
import org.scalatest.FlatSpec
import org.scalatest.matchers.ShouldMatchers
class SparseMatrixSpec extends FlatSpec with ShouldMatchers {
val matrixA = Array (Array (1, 2, 3),
Array (4, 5, 6),
Array (7, 8, 9))
val matrixB = Array (Array (10, 11, 12),
Array (13, 14, 15),
Array (16, 17, 18))
val matrixC = Array (Array (11, 13, 15),
Array (17, 19, 21),
Array (23, 25, 27))
"A sparse matrix" should "be able to created from hand" in {
val sMatrix = new SparseMatrix ((1, 1, 1), (1, 2, 2), (1, 3, 3),
(2, 1, 4), (2, 2, 5), (2, 3, 6),
(3, 1, 7), (3, 2, 8), (3, 3, 9))
sMatrix.row should be === 3
sMatrix.col should be === 3
for (i <- 1 to sMatrix.row; j <- 1 to sMatrix.col) {
matrixA(i-1)(j-1) should be === sMatrix(i,j)
}
}
"A sparse matrix" should "be able created from array of array" in {
val sMatrix = SparseMatrix.fromMatrix (matrixA)
sMatrix.row should be === 3
sMatrix.col should be === 3
for (i <- 1 to sMatrix.row; j <- 1 to sMatrix.col) {
matrixA(i-1)(j-1) should be === sMatrix(i,j)
}
}
"Two sparse matrix" should "added correctly" in {
val sMatrixA = SparseMatrix.fromMatrix (matrixA)
val sMatrixB = SparseMatrix.fromMatrix (matrixB)
val sMatrixC = sMatrixA + sMatrixB
sMatrixC.row should be === 3
sMatrixC.col should be === 3
for (i <- 1 to sMatrixC.row; j <- 1 to sMatrixC.col) {
matrixC(i-1)(j-1) should be === sMatrixC(i,j)
}
}
"Two sparse matrix" should "multiply correctly" in {
val matrixA = Array (Array( 1, 0, 2),
Array(-1, 3, 1))
val matrixB = Array (Array(3, 1),
Array(2, 1),
Array(1, 0))
val matrixC = Array (Array(5, 1),
Array(4, 2))
val sMatrixA = SparseMatrix.fromMatrix (matrixA)
val sMatrixB = SparseMatrix.fromMatrix (matrixB)
val sMatrixC = sMatrixA * sMatrixB
sMatrixC.row should be === 2
sMatrixC.col should be === 2
for (i <- 1 to sMatrixC.row; j <- 1 to sMatrixC.col) {
matrixC(i-1)(j-1) should be === sMatrixC(i,j)
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment