Skip to content

Instantly share code, notes, and snippets.

@yukoba
Last active June 11, 2017 12:56
Show Gist options
  • Save yukoba/af768495e8ba07cf126b51a788ff58c7 to your computer and use it in GitHub Desktop.
Save yukoba/af768495e8ba07cf126b51a788ff58c7 to your computer and use it in GitHub Desktop.
A minimal float tensor library: only `+` is implemented, and it supports NumPy-style broadcasting.
package jp.yukoba.tensor
import java.util
import com.github.fommil.netlib.BLAS
import org.scalatest.FunSuite
/**
 * A dense, strided tensor of `Float` values.
 *
 * The element at index `(i0, i1, ..., ik)` is stored at
 * `data(i0 * strides(0) + i1 * strides(1) + ... + ik * strides(k))`.
 * A 0-dimensional tensor (empty `shape`) is a scalar held in `data(0)`.
 *
 * @param data    backing storage; must be non-empty
 * @param shape   extent of each dimension
 * @param strides step in `data` for a unit move along each dimension (same length as `shape`)
 */
class TensorFloat(val data: Array[Float], val shape: Array[Int], val strides: Array[Int]) {
  assert(data.length > 0)

  /** Number of dimensions (0 for a scalar). */
  def ndim: Int = shape.length

  /** Number of logical elements; 1 for a scalar because the empty product is 1. */
  def size: Int = shape.product

  /**
   * Renders the tensor as nested, parenthesised rows,
   * e.g. `TensorFloat((1.0, 2.0), (3.0, 4.0))`.
   */
  override def toString: String = {
    // Renders dimension `dim` starting at flat offset `offset`. Innermost elements are plain
    // numbers; each sub-tensor of an outer dimension is wrapped in parentheses. Building with
    // mkString (instead of deleting a trailing ", ") keeps zero-extent dimensions well-formed.
    def render(dim: Int, offset: Int): String =
      (0 until shape(dim)).map { i =>
        val idx = offset + strides(dim) * i
        if (dim == ndim - 1) data(idx).toString else "(" + render(dim + 1, idx) + ")"
      }.mkString(", ")

    val body = if (ndim == 0) data(0).toString else render(0, 0)
    s"TensorFloat($body)"
  }

  /**
   * Computes a NumPy-style broadcast of two shapes.
   *
   * Shapes are aligned on their trailing dimensions. The shorter one is padded on the left with
   * size-1 dimensions. A size-1 dimension is stretched to match the other operand by giving it
   * stride 0, so the same element is re-read along that axis.
   *
   * @return (broadcast shape, strides for operand 1, strides for operand 2), each array having
   *         the broadcast shape's length
   * @throws AssertionError if two aligned dimensions differ and neither is 1
   */
  def broadcast(shape1: Array[Int], shape2: Array[Int],
                strides1: Array[Int], strides2: Array[Int]): (Array[Int], Array[Int], Array[Int]) = {
    if (util.Arrays.equals(shape1, shape2)) {
      (shape1, strides1, strides2)
    } else if (shape1.length > shape2.length) {
      // Normalise so operand 1 is never the longer one, then swap the strides back.
      val (outShape, swapped1, swapped2) = broadcast(shape2, shape1, strides2, strides1)
      (outShape, swapped2, swapped1)
    } else if (shape1.length == 0) {
      // A scalar broadcasts to any shape by re-reading its single element (all strides 0).
      (shape2, new Array[Int](shape2.length), strides2)
    } else {
      val pad = shape2.length - shape1.length
      val outShape = Array.fill(pad)(1) ++ shape1
      // Padded leading dimensions have extent 1 (or get stretched below), so stride 0 is right.
      val outStrides1 = new Array[Int](pad) ++ strides1
      val outStrides2 = strides2.clone()
      for (i <- outShape.indices if outShape(i) != shape2(i)) {
        if (outShape(i) == 1) {
          outShape(i) = shape2(i)
          outStrides1(i) = 0
        } else if (shape2(i) == 1) {
          outStrides2(i) = 0
        } else {
          assert(false,
            s"shapes ${shape1.mkString("(", ", ", ")")} and ${shape2.mkString("(", ", ", ")")} " +
              "cannot be broadcast together")
        }
      }
      (outShape, outStrides1, outStrides2)
    }
  }

  /**
   * Element-wise addition with NumPy-style broadcasting.
   *
   * @return a freshly allocated tensor with the broadcast shape and contiguous row-major strides
   * @throws AssertionError if the shapes cannot be broadcast together
   */
  def +(that: TensorFloat): TensorFloat = {
    val data1 = data
    val data2 = that.data
    val (outShape, strides1, strides2) = broadcast(shape, that.shape, strides, that.strides)
    if (outShape.isEmpty) {
      new TensorFloat(Array(data1(0) + data2(0)), outShape, strides1)
    } else {
      val out = new Array[Float](outShape.product)
      val lastDim = outShape.length - 1
      var outIdx = 0
      // Visits every index of the broadcast shape in row-major order; results are written
      // sequentially into `out`.
      def loop(dim: Int, idx1: Int, idx2: Int): Unit = {
        if (dim == lastDim) {
          // Plain loop instead of BLAS scopy/saxpy, as the original TODO suggested: no native
          // call in the hot path, and no reliance on BLAS accepting a zero (broadcast) increment.
          val n = outShape(dim)
          val s1 = strides1(dim)
          val s2 = strides2(dim)
          var i = 0
          while (i < n) {
            out(outIdx + i) = data1(idx1 + i * s1) + data2(idx2 + i * s2)
            i += 1
          }
          outIdx += n
        } else {
          for (i <- 0 until outShape(dim)) {
            loop(dim + 1, idx1 + i * strides1(dim), idx2 + i * strides2(dim))
          }
        }
      }
      loop(0, 0, 0)
      // Contiguous row-major strides are suffix products of the shape: (a, b, c) -> (b*c, c, 1).
      // (The previous `shape.tail :+ 1` was only correct for 1-D and 2-D results.)
      new TensorFloat(out, outShape, outShape.scanRight(1)(_ * _).tail)
    }
  }
}
/** Factory methods for [[TensorFloat]] plus the implicit lift of `Float` scalars. */
object TensorFloat {
  /**
   * Creates a tensor from literal values.
   *
   * Exactly one value yields a 0-dimensional (scalar) tensor; two or more values yield a
   * contiguous 1-D vector. At least one value is required.
   */
  def apply(values: Float*): TensorFloat = {
    assert(values.nonEmpty)
    values match {
      case Seq(only) => new TensorFloat(Array(only), Array(), Array())
      case _         => TensorFloat(values.toArray)
    }
  }

  /**
   * Wraps a non-empty array as a contiguous 1-D vector.
   * The array is used as backing storage directly, without copying.
   */
  def apply(ary: Array[Float]): TensorFloat = {
    assert(ary.nonEmpty)
    new TensorFloat(ary, Array(ary.length), Array(1))
  }

  /**
   * Builds a 2-D matrix from a non-empty array of rows that all have the same length.
   * The rows are flattened into a new row-major backing array.
   */
  def apply(ary: Array[Array[Float]]): TensorFloat = {
    assert(ary.nonEmpty)
    val cols = ary(0).length
    assert(ary.forall(row => row.length == cols))
    // Row-major layout: one step along a row skips `cols` elements, one step along a column skips 1.
    new TensorFloat(ary.flatten, Array(ary.length, cols), Array(cols, 1))
  }

  /**
   * Implicitly lifts a `Float` to a scalar (0-dimensional) tensor, so that mixed expressions
   * such as `10f + tensor` type-check.
   */
  implicit class ScalaFloat(val v: Float) extends TensorFloat(Array(v), Array(), Array())
}
/**
 * Checks element-wise addition, scalar lifting and broadcasting.
 * Results are compared through their rendered `toString` form.
 */
class TensorFloatTest extends FunSuite {
  test("vector") {
    val t1 = TensorFloat(1f, 2f, 3f)
    val t2 = TensorFloat(4f, 5f, 6f)
    val t3 = t1 + t2
    assert(t3.shape.toSeq === Seq(3))
    assert(t3.toString === "TensorFloat(5.0, 7.0, 9.0)")
    // 10f is lifted to a scalar tensor by the implicit TensorFloat.ScalaFloat class.
    val t4 = 10f + t3
    assert(t4.toString === "TensorFloat(15.0, 17.0, 19.0)")
  }

  test("matrix") {
    val t1 = TensorFloat(Array(Array(1f, 2f), Array(3f, 4f)))
    val t2 = TensorFloat(Array(Array(2f, 3f), Array(4f, 5f)))
    val t3 = t1 + t2
    assert(t3.shape.toSeq === Seq(2, 2))
    assert(t3.toString === "TensorFloat((3.0, 5.0), (7.0, 9.0))")
    val t4 = 10f + t3
    assert(t4.toString === "TensorFloat((13.0, 15.0), (17.0, 19.0))")
  }

  test("broadcast") {
    val t1 = TensorFloat(Array(Array(1f, 2f), Array(3f, 4f)))
    val t2 = TensorFloat(5f, 6f)
    val t3 = t1 + t2
    assert(t3.shape.toSeq === Seq(2, 2))
    assert(t3.toString === "TensorFloat((6.0, 8.0), (8.0, 10.0))")
  }

  test("broadcast column against row") {
    // A (2, 1) column plus a length-2 row stretches both operands to (2, 2).
    val col = new TensorFloat(Array(1f, 2f), Array(2, 1), Array(1, 1))
    val row = TensorFloat(10f, 20f)
    assert((col + row).toString === "TensorFloat((11.0, 21.0), (12.0, 22.0))")
  }

  test("incompatible shapes are rejected") {
    intercept[AssertionError] {
      TensorFloat(1f, 2f, 3f) + TensorFloat(1f, 2f)
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment