Created October 5, 2017 09:23
-
-
Save benqua/a60169693561d6939c3c00964953ae7e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ml.dmlc.mxnet._ | |
import ml.dmlc.mxnet.optimizer.SGD | |
import ml.dmlc.mxnet.module._ | |
import com.typesafe.scalalogging.Logger | |
import scala.collection.immutable.ListMap | |
/** Minimal reproduction of a shape problem when computing a metric on a
  * `SoftmaxOutput` with `multi_output = true` whose input is a spatially
  * cropped convolution output.
  *
  * Pipeline: data (N, 1, H, W) -> 1x1 Convolution with `numFilter` outputs
  * -> `slice` that crops H/W down to the label's spatial size
  * -> SoftmaxOutput. Shapes are logged at each step; the run is reported to
  * fail at `mod.updateMetric`.
  */
object BugShape {
  def main(args: Array[String]): Unit = {
    val logger = Logger("TrainModuleUNet")

    /** Describes the first output shape of `s` when its "data" input has
      * shape `inShape`, or "<not inferred>" if no output shape comes back.
      * NOTE(review): `Symbol.inferShape` may return null sequences when
      * inference is incomplete — TODO confirm; wrapped in Option so this
      * diagnostic never throws.
      */
    def outShape(s: Symbol, inShape: Shape): String = {
      val (_, outShapes, _) = s.inferShape(Map("data" -> inShape))
      Option(outShapes).flatMap(_.headOption).fold("<not inferred>")(_.toString)
    }

    val batchSize = 2
    // Shared by the convolution's num_filter and the slice's channel bound.
    val numFilter = 2
    val dataShape = Shape(batchSize, 1, 1052, 1052)
    val labelShape = Shape(batchSize, 1, 868, 868)

    val dataNDA = NDArray.ones(dataShape, Context.cpu())
    val labelNDA = NDArray.ones(labelShape, Context.cpu())
    // One index per example in the batch: `until` is exclusive, giving
    // exactly batchSize entries (an inclusive `to` would give one too many).
    val exampleIndex = (0L until batchSize.toLong).toIndexedSeq
    val batch = new DataBatch(IndexedSeq(dataNDA), IndexedSeq(labelNDA), exampleIndex, 0)

    // Symmetric crop that trims the data's spatial dims to the label's,
    // derived from the shapes instead of repeating the literal sizes.
    val cropH = (dataShape(2) - labelShape(2)) / 2
    val cropW = (dataShape(3) - labelShape(3)) / 2
    val sliceBegin = Shape(0, 0, cropH, cropW)
    val sliceEnd = Shape(batchSize, numFilter, dataShape(2) - cropH, dataShape(3) - cropW)

    val input = Symbol.Variable("data")
    val label = Symbol.Variable("softmax_label")
    val conv = Symbol.Convolution()()(
      Map("data" -> input, "num_filter" -> numFilter, "kernel" -> "(1,1)"))
    val sliced = Symbol.slice()()(
      Map("data" -> conv, "begin" -> sliceBegin, "end" -> sliceEnd))
    val so = Symbol.SoftmaxOutput()()(
      Map("data" -> sliced, "label" -> label, "multi_output" -> true))

    logger.info(s"data shape $dataShape, label shape $labelShape")
    logger.info(s"sliced shape: ${outShape(sliced, dataShape)}")
    logger.info(s"so shape: ${outShape(so, dataShape)}")

    val mod = new Module(so)
    mod.bind(
      dataShapes = ListMap("data" -> dataShape),
      labelShapes = Some(ListMap("softmax_label" -> labelShape)))
    logger.info(s"mod output shape: ${mod.outputShapes}")
    mod.initParams(new Xavier())
    mod.initOptimizer(optimizer = new SGD())
    mod.forward(batch)
    val outputMerged = mod.getOutputsMerged
    logger.info(s"mod output merged: ${outputMerged(0).shape}")

    // NOTE(review): reported failure point. With multi_output = true, MXNet's
    // SoftmaxOutput describes labels without the class axis, i.e. (N, H, W),
    // whereas labelShape here is (N, 1, H, W) — worth checking as the cause.
    mod.updateMetric(new Accuracy(), batch.label) // fails here
    logger.info("updateMetric done")
  }
}
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment