Skip to content

Instantly share code, notes, and snippets.

@benqua
Created October 5, 2017 09:23
Show Gist options
  • Save benqua/a60169693561d6939c3c00964953ae7e to your computer and use it in GitHub Desktop.
Save benqua/a60169693561d6939c3c00964953ae7e to your computer and use it in GitHub Desktop.
import ml.dmlc.mxnet._
import ml.dmlc.mxnet.optimizer.SGD
import ml.dmlc.mxnet.module._
import com.typesafe.scalalogging.Logger
import scala.collection.immutable.ListMap
/**
 * Minimal reproduction of a shape problem when evaluating an MXNet [[Module]] whose output
 * is a multi-output `SoftmaxOutput` over a centre-cropped convolution.
 *
 * The network is built as follows:
 *  - A `kernel = (1,1)` convolution with `numFilter` channels is applied to a
 *    single-channel `dataSide` x `dataSide` input.
 *  - The result is sliced to the centre `labelSide` x `labelSide` window.
 *  - The sliced result is fed to `SoftmaxOutput` with `multi_output = true`.
 *
 * The label has a single channel. The program then:
 *  - logs the inferred and actual output shapes;
 *  - runs one forward pass on a batch of ones;
 *  - calls `updateMetric` with an `Accuracy` metric, which the original author reported
 *    as the failing call.
 */
object BugShape {

  /**
   * Infers the first output shape of `s` when its "data" argument has shape `inShape`.
   *
   * @param s       symbol whose output shape is wanted
   * @param inShape shape bound to the symbol's "data" argument
   * @return the shape of the symbol's first output
   * @throws IllegalStateException if `inferShape` yields no output shapes
   */
  private def outShape(s: Symbol, inShape: Shape): Shape = {
    val (_, outShapes, _) = s.inferShape(Map("data" -> inShape))
    // NOTE(review): Symbol.inferShape appears to return null shape lists when inference is
    // incomplete — confirm for the MXNet version in use. Wrapping in Option turns that (or an
    // empty output list) into a descriptive error instead of an NPE / index error.
    Option(outShapes).flatMap(_.headOption).getOrElse(
      throw new IllegalStateException(s"cannot infer output shape for data shape $inShape"))
  }

  def main(args: Array[String]): Unit = {
    val logger = Logger("TrainModuleUNet")

    val batchSize = 2
    val dataSide = 1052 // height and width of the input images
    val labelSide = 868 // height and width of the labels; the network output is cropped to this
    val numFilter = 2   // convolution output channels; also the channel end of the slice

    val dataShape = Shape(batchSize, 1, dataSide, dataSide)
    val labelShape = Shape(batchSize, 1, labelSide, labelSide)

    val dataNDA = NDArray.ones(dataShape, Context.cpu())
    val labelNDA = NDArray.ones(labelShape, Context.cpu())
    // Exactly one index per example in the batch: 0, 1, ..., batchSize - 1.
    val batchIndex = IndexedSeq.tabulate(batchSize)(_.toLong)
    val batch = new DataBatch(IndexedSeq(dataNDA), IndexedSeq(labelNDA), batchIndex, 0)

    // Symmetric crop of the last two axes to [sliceDiff, dataSide - sliceDiff), i.e. a
    // labelSide-wide window when (dataSide - labelSide) is even, as it is here.
    val sliceDiff = (dataSide - labelSide) / 2
    val sliceBegin = Shape(0, 0, sliceDiff, sliceDiff)
    val sliceEnd = Shape(batchSize, numFilter, dataSide - sliceDiff, dataSide - sliceDiff)

    val input = Symbol.Variable("data")
    val label = Symbol.Variable("softmax_label")
    val conv = Symbol.Convolution()()(
      Map("data" -> input, "num_filter" -> numFilter, "kernel" -> "(1,1)"))
    val sliced = Symbol.slice()()(Map("data" -> conv, "begin" -> sliceBegin, "end" -> sliceEnd))
    val so = Symbol.SoftmaxOutput()()(
      Map("data" -> sliced, "label" -> label, "multi_output" -> true))

    logger.info(s"data shape $dataShape, label shape $labelShape")
    logger.info(s"sliced shape: ${outShape(sliced, dataShape)}")
    logger.info(s"so shape: ${outShape(so, dataShape)}")

    val mod = new Module(so)
    mod.bind(
      dataShapes = ListMap("data" -> dataShape),
      labelShapes = Some(ListMap("softmax_label" -> labelShape)))
    logger.info(s"mod output shape: ${mod.outputShapes}")
    mod.initParams(new Xavier())
    mod.initOptimizer(optimizer = new SGD())

    mod.forward(batch)
    val outputMerged = mod.getOutputsMerged
    logger.info(s"mod output merged: ${outputMerged(0).shape}")

    // The output has numFilter channels while the label (labelShape) has one.
    // NOTE(review): presumably this channel mismatch is what Accuracy rejects here — confirm.
    mod.updateMetric(new Accuracy(), batch.label) // fails here
    logger.info("updateMetric done")
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment