@Dris101
Last active July 21, 2020 09:13
Various test code for tensorMmul_bp

Python code to generate random tensors and calculate tensordot and gradients using TensorFlow

import tensorflow as tf
import numpy as np
from pathlib import Path

g = tf.random.Generator.from_seed(123456)
commonSize = 2
for aRank in range(2,7):
    # a and b share their last `commonSize` dimension sizes, so contraction axis -2 always matches
    common = g.uniform(shape=[commonSize], minval=1, maxval=5, dtype=tf.dtypes.int32)
    aShape = tf.concat([g.uniform(shape=[aRank - commonSize], minval=1, maxval=5, dtype=tf.dtypes.int32), common], 0)
    a = g.normal(shape=aShape)
    for bRank in range(2,7):
        testDir = "Test_" + str(aRank) + "_" + str(bRank)
        testPath = Path(testDir)
        testPath.mkdir(exist_ok=True)
        bShape = tf.concat([g.uniform(shape=[bRank - commonSize], minval=1, maxval=5, dtype=tf.dtypes.int32), common], 0)
        b = g.normal(shape=bShape)
        with tf.GradientTape(persistent=True) as tape:
            tape.watch(a)  # a and b are plain Tensors, so they must be watched explicitly
            tape.watch(b)
            c = tf.tensordot(a, b, axes=[[-2], [-2]])

        # Gradients of sum(c) with respect to a and b
        [dc_da, dc_db] = tape.gradient(c, [a, b])
        del tape

        np.save(testPath / "a", a.numpy())
        np.save(testPath / "b", b.numpy())
        np.save(testPath / "c", c.numpy())
        np.save(testPath / "dcda", dc_da.numpy())
        np.save(testPath / "dcdb", dc_db.numpy())

Minimal Scala override of TensorMmul (some imports elided)

TensorMmulCustom

import java.{util => ju}
import scala.collection.JavaConverters._

class TensorMmulCustom(sd: SameDiff, x: SDVariable, y: SDVariable, dimensionsX: Seq[Int], dimensionsY: Seq[Int])
    extends TensorMmul(sd, x, y, dimensionsX.toArray, dimensionsY.toArray, false, false, false) {

  override def doDiff(gradients: ju.List[SDVariable]): ju.List[SDVariable] = {

    // Delegate gradient computation to the tensormmul_bp custom op defined below
    val op = new TensorMmulBp(this.getSameDiff(), larg(), rarg(), gradients.get(0), dimensionsX.toArray, dimensionsY.toArray)
    val outputs = op.outputVariables()
    outputs.toList.asJava
  }
}
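
A minimal usage sketch of my own (shapes are illustrative; the Nd4j and DataType imports are assumptions, as elsewhere in this gist): the custom op is wired into a SameDiff graph like any other, and doDiff above is only invoked once gradients are requested, e.g. via sd.calculateGradients as in the ScalaTest further down.

import org.nd4j.autodiff.samediff.SameDiff
import org.nd4j.linalg.api.buffer.DataType
import org.nd4j.linalg.factory.Nd4j

object TensorMmulCustomUsage extends App {
  val sd = SameDiff.create()
  val a  = sd.`var`("A", Nd4j.create(DataType.FLOAT, 2L, 3L, 4L))
  val b  = sd.`var`("B", Nd4j.create(DataType.FLOAT, 5L, 3L, 6L))

  // Contract axis 1 of A with axis 1 of B (both size 3); C then has shape [2, 4, 5, 6]
  val c = new TensorMmulCustom(sd, a, b, Seq(1), Seq(1)).outputVariable().rename("C")
  sd.setLossVariables(c)

  println(c.eval().shapeInfoToString())
}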

tensormmul_bp op

class TensorMmulBp(sd: SameDiff, x: SDVariable, y: SDVariable, eps: SDVariable, dimensionsX: Array[Int], dimensionsY: Array[Int])
    extends DynamicCustomOp("tensormmul_bp", sd, Array(x, y, eps)) {

  // IArgs layout: number of X dims, the X dims, number of Y dims, the Y dims
  addIArgument(dimensionsX.size)
  addIArgument(dimensionsX: _*)
  addIArgument(dimensionsY.size)
  addIArgument(dimensionsY: _*)

  override def opName() = "tensormmul_bp"

  override def doDiff(i_v1: ju.List[SDVariable]): ju.List[SDVariable] =
    throw new UnsupportedOperationException("Differentiation of " + getClass().getName() + " not supported")

  override def calculateOutputDataTypes(dataTypes: ju.List[DataType]): ju.List[DataType] = {

    Preconditions.checkState(
      dataTypes != null && dataTypes.size() == 3,
      "Expected exactly 3 inputs to tensormmul_bp op, got %s",
      dataTypes
    )
    Preconditions.checkState(
      dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType() && dataTypes.get(2).isFPType(),
      "Inputs to tensormmul_bp op must all be floating point types: got %s",
      dataTypes
    )
    dataTypes.subList(0, 2)
  }
}
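
To make the IArgs layout concrete, the sketch below (again my own, with illustrative shapes; eps has the shape of the forward tensordot output) constructs the op and prints its integer arguments in the order built in the class body.

import org.nd4j.autodiff.samediff.SameDiff
import org.nd4j.linalg.api.buffer.DataType
import org.nd4j.linalg.factory.Nd4j

object IArgLayout extends App {
  val sd  = SameDiff.create()
  val a   = sd.`var`("a", Nd4j.create(DataType.FLOAT, 2L, 3L, 4L))
  val b   = sd.`var`("b", Nd4j.create(DataType.FLOAT, 5L, 3L, 6L))
  val eps = sd.`var`("eps", Nd4j.create(DataType.FLOAT, 2L, 4L, 5L, 6L))

  val op = new TensorMmulBp(sd, a, b, eps, Array(1), Array(1))

  // Prints [1, 1, 1, 1]: |dimsX|, dimsX..., |dimsY|, dimsY...
  println(op.iArgs().mkString("[", ", ", "]"))
}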

ScalaTest spec comparing the ND4J results against the saved TensorFlow outputs. Note that tape.gradient(c, [a, b]) in the Python script differentiates sum(c), and SameDiff likewise backpropagates from the non-scalar loss variable C starting from an all-ones gradient, so the two sets of gradients should be directly comparable.

import org.scalatest._
import org.scalatest.wordspec._
import org.scalatest.matchers.must.Matchers

abstract class CoreTestSpec extends AnyWordSpec with Matchers

class TensorMmulCustomTensorFlowTest extends CoreTestSpec {
  import scala.collection.JavaConverters._

  import java.io._

  import org.nd4j.autodiff.samediff.SameDiff
  import org.nd4j.linalg.api.ndarray.INDArray
  import org.nd4j.linalg.factory.Nd4j
  import org.scalatest.exceptions.TestFailedException

  case class TestResult(aShape: Seq[Long], bShape: Seq[Long], calcPass: Boolean, gradPass: Boolean, gradCrash: Boolean) {
    override def toString() =
      s"""a=[${aShape.mkString(",")}], b=[${bShape.mkString(",")}] ${if (calcPass) "Calc pass" else ""} ${if (gradPass) "Grad pass"
      else ""} ${if (gradCrash) "Grad crash" else ""}"""
  }

  val baseDir       = new File("""<PATH TO BASEDIR FROM PYTHON HERE>""")
  val contractIndex = 2 // Contract axis (rank - 2), i.e. axis -2 as in the Python code
  val testResults = baseDir.listFiles().filter(p => p.isDirectory() && p.getName().startsWith("Test")) map { dir =>
    info("Checking " + dir.getPath() + "...")
    val a    = Nd4j.createFromNpyFile(new File(dir, "a.npy"))
    val b    = Nd4j.createFromNpyFile(new File(dir, "b.npy"))
    val c    = Nd4j.createFromNpyFile(new File(dir, "c.npy"))
    val dcda = Nd4j.createFromNpyFile(new File(dir, "dcda.npy"))
    val dcdb = Nd4j.createFromNpyFile(new File(dir, "dcdb.npy"))

    val sd  = SameDiff.create()
    val aSd = sd.`var`("A", a)
    val bSd = sd.`var`("B", b)

    var result = TestResult(a.shape, b.shape, true, true, false)
    try {
      val cSd = new TensorMmulCustom(
        sd,
        aSd,
        bSd,
        Array(a.rank - contractIndex),
        Array(b.rank - contractIndex)
      ).outputVariable().rename("C")
      sd.setLossVariables(cSd)
      val grads = sd.calculateGradients(null, "A", "B").asScala
      grads("A") mustBe dcda
      grads("B") mustBe dcdb
    } catch {
      case ex: TestFailedException =>
        info(ex.getMessage())
        result = result.copy(gradPass = false)

      case _: Throwable =>
        info(s"Grads failed for a=${a.shapeInfoToString()}, b=${b.shapeInfoToString()}")
        result = result.copy(gradPass = false, gradCrash = true)
    }

    try {
      val output = sd.output(Map.empty[String, INDArray].asJava, "C").asScala
      output("C") mustBe c
      info("Calc Pass")
    } catch {
      case ex: TestFailedException =>
        info(s"TensorMmul calc error for a=${a.shapeInfoToString()}, b=${b.shapeInfoToString()}")
        info(ex.getMessage())
        result = result.copy(calcPass = false)

      case _: Throwable =>
        info(s"TensorMmul calc crashed for a=${a.shapeInfoToString()}, b=${b.shapeInfoToString()}")
        result = result.copy(calcPass = false)
    }

    result
  }

  val (ok, fail) = testResults.partition(r => r.calcPass && r.gradPass)
  info("\nPass\n" + ok.mkString("\n"))
  info("\nFail\n" + fail.mkString("\n"))
}