Skip to content

Instantly share code, notes, and snippets.

@bnorm
Created November 11, 2017 21:11
Show Gist options
  • Save bnorm/da007a8bfaadda956e78eebf2ccb55b5 to your computer and use it in GitHub Desktop.
Save bnorm/da007a8bfaadda956e78eebf2ccb55b5 to your computer and use it in GitHub Desktop.
import java.util.*
/**
 * Small demo of [DoubleMatrix]: builds an identity matrix, a 4x4 matrix, a column vector and a
 * row vector, then prints them together with an element lookup and a few products.
 */
fun main(args: Array<String>) {
    val i = DoubleMatrix.identity(4)
    val a = DoubleMatrix.ofRows {
        row(1.0, 2.0, 3.0, 4.0)
        row(5.0, 6.0, 7.0, 8.0)
        row(9.0, 10.0, 11.0, 12.0)
        row(13.0, 14.0, 15.0, 16.0)
    }
    // Zero-based (row, column) lookup: row 1, column 2 of `a` is 7.0.
    val value = a[1, 2]
    val b = DoubleMatrix.column(1.0, 2.0, 3.0, 4.0)
    val c = DoubleMatrix.row(1.0, 2.0, 3.0, 4.0)
    println(i)
    println(a)
    println(value)
    println(b)
    println(a * b)
    println(c)
    // Row vector times column vector is 1x1, i.e. the dot product.
    println((c * b).toSingle())
}
/**
 * A dense `rows` x `columns` matrix of doubles.
 *
 * Elements are stored in column-major order: element `(r, c)` lives at `values[rows * c + r]`.
 * The constructor keeps a reference to `values` rather than copying it, so callers must not
 * mutate that array afterwards. Every arithmetic operation and [transpose] return a new matrix
 * and leave their operands untouched.
 *
 * @property rows number of rows; must be positive.
 * @property columns number of columns; must be positive.
 * @throws IllegalArgumentException if a dimension is not positive or `values.size != rows * columns`.
 */
class DoubleMatrix(
    val rows: Int,
    val columns: Int,
    private val values: DoubleArray
) {
    init {
        require(rows > 0) { "rows must be positive: $rows" }
        require(columns > 0) { "columns must be positive: $columns" }
        require(rows * columns == values.size) {
            "expected ${rows * columns} values for a ${rows}x$columns matrix but got ${values.size}"
        }
    }

    /**
     * Returns the element at ([row], [column]); both indices are zero-based.
     *
     * @throws IndexOutOfBoundsException if either index is out of range.
     */
    operator fun get(row: Int, column: Int): Double {
        checkIndex(row, rows) { "row=$row rows=$rows" }
        checkIndex(column, columns) { "column=$column columns=$columns" }
        return values[rows * column + row]
    }

    /**
     * Element-wise sum of two matrices of identical shape.
     *
     * @throws IndexOutOfBoundsException if the shapes differ.
     */
    operator fun plus(other: DoubleMatrix): DoubleMatrix {
        if (rows != other.rows || columns != other.columns) {
            throw IndexOutOfBoundsException(
                "cannot add ${rows}x$columns and ${other.rows}x${other.columns}"
            )
        }
        // Same shape means same storage layout, so the flat arrays line up index by index.
        val result = DoubleArray(values.size) { i -> values[i] + other.values[i] }
        return DoubleMatrix(rows, columns, result)
    }

    /** Adds the scalar [b] to every element. */
    operator fun plus(b: Double): DoubleMatrix {
        val result = DoubleArray(values.size) { i -> values[i] + b }
        return DoubleMatrix(rows, columns, result)
    }

    /**
     * Matrix product `this * other`, producing a `rows` x `other.columns` matrix.
     *
     * @throws IndexOutOfBoundsException if `columns != other.rows`.
     */
    operator fun times(other: DoubleMatrix): DoubleMatrix {
        if (columns != other.rows) {
            throw IndexOutOfBoundsException(
                "cannot multiply ${rows}x$columns by ${other.rows}x${other.columns}"
            )
        }
        val result = DoubleArray(rows * other.columns)
        for (column in 0 until other.columns) {
            // Column-major offsets of column `column` in `other` and in the result; they differ
            // whenever `other.rows != rows`, so each matrix needs its own stride.
            val otherBase = other.rows * column
            val resultBase = rows * column
            for (row in 0 until rows) {
                var sum = 0.0
                for (k in 0 until columns) {
                    // this[row, k] * other[k, column]
                    sum += values[rows * k + row] * other.values[otherBase + k]
                }
                result[resultBase + row] = sum
            }
        }
        return DoubleMatrix(rows, other.columns, result)
    }

    /** Multiplies every element by the scalar [a]. */
    operator fun times(a: Double): DoubleMatrix {
        val result = DoubleArray(values.size) { i -> a * values[i] }
        return DoubleMatrix(rows, columns, result)
    }

    /** Returns the `columns` x `rows` transpose of this matrix. */
    fun transpose(): DoubleMatrix {
        // Flat index i of the (columns x rows) result is its element (i % columns, i / columns),
        // which is element (i / columns, i % columns) of this matrix.
        val result = DoubleArray(values.size) { i -> values[(i % columns) * rows + i / columns] }
        return DoubleMatrix(rows = columns, columns = rows, values = result)
    }

    /** Renders the matrix row by row, e.g. `DoubleMatrix(2, 2)[[1.0, 2.0][3.0, 4.0]]`. */
    override fun toString(): String = buildString {
        append("DoubleMatrix($rows, $columns)[")
        for (row in 0 until rows) {
            append("[")
            for (column in 0 until columns) {
                if (column > 0) append(", ")
                append(values[rows * column + row])
            }
            append("]")
        }
        append("]")
    }

    /** Two matrices are equal when they have the same shape and element-wise equal values. */
    override fun equals(other: Any?): Boolean {
        if (this === other) return true
        // The class is final, so `is` is equivalent to an exact class comparison.
        if (other !is DoubleMatrix) return false
        return rows == other.rows && columns == other.columns && Arrays.equals(values, other.values)
    }

    override fun hashCode(): Int {
        var result = rows
        result = 31 * result + columns
        result = 31 * result + Arrays.hashCode(values)
        return result
    }

    /** Throws [IndexOutOfBoundsException] built from [lazyMessage] unless `0 <= value < max`. */
    private inline fun checkIndex(value: Int, max: Int, lazyMessage: () -> Any) {
        if (value < 0 || value >= max) {
            throw IndexOutOfBoundsException(lazyMessage().toString())
        }
    }

    companion object {
        /**
         * Returns the [size] x [size] identity matrix.
         *
         * @throws IllegalArgumentException if [size] is not positive.
         */
        fun identity(size: Int): DoubleMatrix {
            require(size > 0) { "size must be positive: $size" }
            // Diagonal entries are every (size + 1) slots in the flat square array.
            val values = DoubleArray(size * size) { if (it % (size + 1) == 0) 1.0 else 0.0 }
            return DoubleMatrix(size, size, values)
        }

        /** Receiver for [ofRows]; each [row] call appends one row of the matrix. */
        class RowDsl {
            internal val rows = ArrayList<DoubleArray>()

            /** Appends a row; every row must have the same length as the first one. */
            fun row(vararg values: Double) {
                require(rows.isEmpty() || rows[0].size == values.size) {
                    "every row must have ${rows[0].size} values but got ${values.size}"
                }
                rows.add(values)
            }
        }

        /**
         * Builds a matrix from the rows declared in [block], top to bottom.
         *
         * @throws IllegalArgumentException if [block] declares no rows or rows of differing length.
         */
        fun ofRows(block: RowDsl.() -> Unit): DoubleMatrix {
            val matrix = RowDsl().apply(block).rows
            require(matrix.isNotEmpty()) { "ofRows requires at least one row" }
            val rows = matrix.size
            val columns = matrix[0].size
            val result = DoubleArray(rows * columns)
            for ((r, row) in matrix.withIndex()) {
                for ((c, value) in row.withIndex()) {
                    // Row-major input -> column-major storage.
                    result[rows * c + r] = value
                }
            }
            return DoubleMatrix(rows, columns, result)
        }

        /** Receiver for [ofColumns]; each [column] call appends one column of the matrix. */
        class ColumnDsl {
            internal val columns = ArrayList<DoubleArray>()

            /** Appends a column; every column must have the same length as the first one. */
            fun column(vararg values: Double) {
                require(columns.isEmpty() || columns[0].size == values.size) {
                    "every column must have ${columns[0].size} values but got ${values.size}"
                }
                columns.add(values)
            }
        }

        /**
         * Builds a matrix from the columns declared in [block], left to right.
         *
         * @throws IllegalArgumentException if [block] declares no columns or columns of differing length.
         */
        fun ofColumns(block: ColumnDsl.() -> Unit): DoubleMatrix {
            val matrix = ColumnDsl().apply(block).columns
            require(matrix.isNotEmpty()) { "ofColumns requires at least one column" }
            val columns = matrix.size
            val rows = matrix[0].size
            val result = DoubleArray(rows * columns)
            for ((c, column) in matrix.withIndex()) {
                for ((r, value) in column.withIndex()) {
                    result[rows * c + r] = value
                }
            }
            return DoubleMatrix(rows, columns, result)
        }

        /** Returns a `values.size` x 1 column vector. */
        fun column(vararg values: Double): DoubleMatrix {
            return DoubleMatrix(values.size, 1, values)
        }

        /** Returns a 1 x `values.size` row vector. */
        fun row(vararg values: Double): DoubleMatrix {
            return DoubleMatrix(1, values.size, values)
        }
    }
}
/**
 * Scalar-on-the-left multiplication, so `2.0 * m` yields the same matrix as `m * 2.0`.
 */
operator fun Double.times(matrix: DoubleMatrix): DoubleMatrix = matrix * this
/**
 * Scalar-on-the-left addition, so `1.0 + m` yields the same matrix as `m + 1.0`.
 */
operator fun Double.plus(matrix: DoubleMatrix): DoubleMatrix = matrix + this
/**
 * Unwraps a 1x1 matrix into its only element.
 *
 * An example is the product of a row vector and a column vector.
 *
 * @throws IllegalArgumentException if this matrix is not exactly 1x1.
 */
fun DoubleMatrix.toSingle(): Double {
    require(rows == 1 && columns == 1) { "expected a 1x1 matrix but was ${rows}x$columns" }
    return this[0, 0]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment