// Gist bnorm/da007a8bfaadda956e78eebf2ccb55b5, created November 11, 2017 21:11.
// GitHub page text: "Save bnorm/da007a8bfaadda956e78eebf2ccb55b5 to your computer and use it in GitHub Desktop."
// GitHub notice: this file contains bidirectional Unicode text that may be interpreted or compiled
// differently than what appears below. To review, open the file in an editor that reveals hidden
// Unicode characters. (Learn more about bidirectional Unicode characters.)
import java.util.*
/**
 * Entry point: builds several [DoubleMatrix] instances and prints them together with
 * two matrix products to standard output.
 *
 * @param args command-line arguments; not read.
 */
fun main(args: Array<String>) {
    // Constructed but not used further below.
    val i = DoubleMatrix.identity(4)

    val a = DoubleMatrix.ofRows {
        row(1.0, 2.0, 3.0, 4.0)
        row(5.0, 6.0, 7.0, 8.0)
        row(9.0, 10.0, 11.0, 12.0)
        row(13.0, 14.0, 15.0, 16.0)
    }

    // Element lookup through the indexing operator; the value is not used further below.
    val value = a[1, 2]

    val b = DoubleMatrix.column(1.0, 2.0, 3.0, 4.0)
    val c = DoubleMatrix.row(1.0, 2.0, 3.0, 4.0)

    println(a)
    println(b)
    println(a * b)
    println(c)
    // A 1x4 row times a 4x1 column is a 1x1 matrix, unwrapped here to a plain Double.
    println((c * b).toSingle())
}
/**
 * A dense matrix of doubles with fixed dimensions.
 *
 * Elements are stored column by column in one flat array: the element at `(row, column)`
 * lives at index `rows * column + row` of the backing array.
 *
 * The array handed to the constructor is stored as-is (it is not copied), so mutating it
 * afterwards changes the contents of the matrix.
 *
 * @property rows number of rows; must be positive.
 * @property columns number of columns; must be positive.
 * @param values the elements in column-by-column order; its size must equal `rows * columns`.
 * @throws IllegalArgumentException if a dimension is not positive or the array size does not match.
 */
class DoubleMatrix(
    val rows: Int,
    val columns: Int,
    private val values: DoubleArray
) {
    init {
        require(rows > 0) { "rows must be positive but was $rows" }
        require(columns > 0) { "columns must be positive but was $columns" }
        require(rows * columns == values.size) {
            "expected ${rows * columns} values for a ${rows}x$columns matrix but got ${values.size}"
        }
    }

    companion object {
        /**
         * Creates the [size] x [size] identity matrix (ones on the diagonal, zeros elsewhere).
         */
        fun identity(size: Int): DoubleMatrix {
            // In a flat size*size array the diagonal sits at every (size + 1)-th index.
            val values = DoubleArray(size * size) { if (it % (size + 1) == 0) 1.0 else 0.0 }
            return DoubleMatrix(size, size, values)
        }

        /** Receiver of the [ofRows] builder block; collects the rows added through [row]. */
        class RowDsl {
            internal val rows = ArrayList<DoubleArray>()

            /**
             * Appends one row.
             *
             * @throws IllegalArgumentException if its length differs from the first row added.
             */
            fun row(vararg values: Double) {
                require(rows.size == 0 || rows[0].size == values.size) {
                    "row has ${values.size} values but the first row has ${rows[0].size}"
                }
                rows.add(values)
            }
        }

        /**
         * Builds a matrix from rows declared inside [block].
         *
         * @throws IndexOutOfBoundsException if [block] adds no rows.
         * @throws IllegalArgumentException if the rows are empty or differ in length.
         */
        fun ofRows(block: RowDsl.() -> Unit): DoubleMatrix {
            val matrix = RowDsl().apply(block).rows
            val rows = matrix.size
            val columns = matrix[0].size
            val result = DoubleArray(rows * columns)
            for ((r, row) in matrix.withIndex()) {
                for ((c, value) in row.withIndex()) {
                    // Scatter each row into the column-by-column layout.
                    result[c * rows + r] = value
                }
            }
            return DoubleMatrix(rows, columns, result)
        }

        /** Receiver of the [ofColumns] builder block; collects the columns added through [column]. */
        class ColumnDsl {
            internal val columns = ArrayList<DoubleArray>()

            /**
             * Appends one column.
             *
             * @throws IllegalArgumentException if its length differs from the first column added.
             */
            fun column(vararg values: Double) {
                require(columns.size == 0 || columns[0].size == values.size) {
                    "column has ${values.size} values but the first column has ${columns[0].size}"
                }
                columns.add(values)
            }
        }

        /**
         * Builds a matrix from columns declared inside [block].
         *
         * @throws IndexOutOfBoundsException if [block] adds no columns.
         * @throws IllegalArgumentException if the columns are empty or differ in length.
         */
        fun ofColumns(block: ColumnDsl.() -> Unit): DoubleMatrix {
            val matrix = ColumnDsl().apply(block).columns
            val columns = matrix.size
            val rows = matrix[0].size
            val result = DoubleArray(rows * columns)
            for ((c, column) in matrix.withIndex()) {
                for ((r, value) in column.withIndex()) {
                    result[c * rows + r] = value
                }
            }
            return DoubleMatrix(rows, columns, result)
        }

        /** Creates a single-column matrix (`values.size` x 1) from [values]. */
        fun column(vararg values: Double): DoubleMatrix {
            return DoubleMatrix(values.size, 1, values)
        }

        /** Creates a single-row matrix (1 x `values.size`) from [values]. */
        fun row(vararg values: Double): DoubleMatrix {
            return DoubleMatrix(1, values.size, values)
        }
    }

    /**
     * Returns the element at the zero-based position ([row], [column]).
     *
     * @throws IndexOutOfBoundsException if either index is outside the matrix.
     */
    operator fun get(row: Int, column: Int): Double {
        checkIndex(row, rows) { "row=$row rows=$rows" }
        checkIndex(column, columns) { "column=$column columns=$columns" }
        return values[rows * column + row]
    }

    /**
     * Element-wise sum of this matrix and [other].
     *
     * @throws IndexOutOfBoundsException if the two matrices differ in dimensions.
     */
    operator fun plus(other: DoubleMatrix): DoubleMatrix {
        if (rows != other.rows) {
            throw IndexOutOfBoundsException("row count mismatch: $rows vs ${other.rows}")
        }
        if (columns != other.columns) {
            throw IndexOutOfBoundsException("column count mismatch: $columns vs ${other.columns}")
        }
        val result = DoubleArray(values.size) { values[it] + other.values[it] }
        return DoubleMatrix(rows, columns, result)
    }

    /** Returns a new matrix with [b] added to every element. */
    operator fun plus(b: Double): DoubleMatrix {
        val result = DoubleArray(values.size) { values[it] + b }
        return DoubleMatrix(rows, columns, result)
    }

    /**
     * Matrix product `this * other`; the result has [rows] rows and `other.columns` columns.
     *
     * @throws IndexOutOfBoundsException if [columns] differs from `other.rows`.
     */
    operator fun times(other: DoubleMatrix): DoubleMatrix {
        if (columns != other.rows) {
            throw IndexOutOfBoundsException(
                "cannot multiply ${rows}x$columns by ${other.rows}x${other.columns}"
            )
        }
        val result = DoubleArray(rows * other.columns)
        for (column in 0 until other.columns) {
            // Each column of `other` spans other.rows entries, while each column of the
            // result spans this.rows entries, so the two offsets differ for non-square inputs.
            val otherBase = column * other.rows
            val resultBase = column * rows
            for (row in 0 until rows) {
                var sum = 0.0
                for (x in 0 until columns) {
                    // this[row, x] * other[x, column]
                    sum += values[rows * x + row] * other.values[otherBase + x]
                }
                result[resultBase + row] = sum
            }
        }
        return DoubleMatrix(rows, other.columns, result)
    }

    /** Returns a new matrix with every element multiplied by [a]. */
    operator fun times(a: Double): DoubleMatrix {
        val result = DoubleArray(values.size) { a * values[it] }
        return DoubleMatrix(rows, columns, result)
    }

    /** Returns the transpose: a [columns] x [rows] matrix whose `(r, c)` element is this `(c, r)`. */
    fun transpose(): DoubleMatrix {
        // Flat index i of the result addresses (i % columns, i / columns) in the transposed
        // matrix, which corresponds to (i / columns, i % columns) in this one.
        val result = DoubleArray(values.size) { i ->
            values[(i % columns) * rows + i / columns]
        }
        return DoubleMatrix(rows = columns, columns = rows, values = result)
    }

    /** Renders the dimensions followed by each row in brackets, e.g. `DoubleMatrix(1, 2)[[1.0, 2.0]]`. */
    override fun toString(): String {
        return buildString {
            append("DoubleMatrix($rows, $columns)[")
            for (row in 0 until rows) {
                append("[")
                for (column in 0 until columns) {
                    if (column > 0) append(", ")
                    append(values[rows * column + row])
                }
                append("]")
            }
            append("]")
        }
    }

    /** Two matrices are equal when they have the same class, dimensions and element values. */
    override fun equals(other: Any?): Boolean {
        if (this === other) return true
        if (javaClass != other?.javaClass) return false
        other as DoubleMatrix
        return rows == other.rows &&
            columns == other.columns &&
            Arrays.equals(values, other.values)
    }

    /** Hash code derived from the dimensions and element values, consistent with [equals]. */
    override fun hashCode(): Int {
        var result = rows
        result = 31 * result + columns
        result = 31 * result + Arrays.hashCode(values)
        return result
    }

    /**
     * Throws [IndexOutOfBoundsException] with the text produced by [lazyMessage] unless
     * [value] lies in `0 until max`.
     */
    private inline fun checkIndex(value: Int, max: Int, lazyMessage: () -> Any) {
        if (value >= max || value < 0) {
            val message = lazyMessage()
            throw IndexOutOfBoundsException(message.toString())
        }
    }
}
/** Scalar-on-the-left multiplication: `2.0 * matrix` delegates to `matrix * 2.0`. */
operator fun Double.times(matrix: DoubleMatrix): DoubleMatrix = matrix.times(this)
/** Scalar-on-the-left addition: `2.0 + matrix` delegates to `matrix + 2.0`. */
operator fun Double.plus(matrix: DoubleMatrix): DoubleMatrix = matrix.plus(this)
/**
 * Returns the only element of a 1x1 matrix.
 *
 * @throws IllegalArgumentException if the matrix is not exactly 1x1.
 */
fun DoubleMatrix.toSingle(): Double {
    require(rows == 1 && columns == 1) { "expected a 1x1 matrix but was ${rows}x$columns" }
    return this[0, 0]
}
// GitHub page footer: "Sign up for free to join this conversation on GitHub.
// Already have an account? Sign in to comment."