Skip to content

Instantly share code, notes, and snippets.

@badbye
Forked from thirdwing/callback.plot.R
Created June 4, 2016 01:35
Show Gist options
  • Save badbye/360c8d3145cbbe4619bd67ea39d3eaa1 to your computer and use it in GitHub Desktop.
Save badbye/360c8d3145cbbe4619bd67ea39d3eaa1 to your computer and use it in GitHub Desktop.
# Metric history shared with the plotting callback: the callback appends each
# training-metric value to `logger$train`. `$new()` suggests a reference-class
# object, so those appends should be visible here after training — confirm
# against the mxnet version in use.
logger <- mx.metric.logger$new()
# Build a callback that redraws the training-metric curve each time it fires.
#
# period: the callback only acts when `nbatch %% period == 0` and
#   `env$metric` is non-NULL.
# logger: optional metric logger; when supplied, every metric value is
#   appended to `logger$train` so the caller can read the history later.
#   When NULL, the history is kept inside the returned closure instead.
# ylim: y-axis range of the plot. Defaults to the previously hard-coded
#   c(0.5, 1), which suits accuracy-like metrics.
#
# Returns function(iteration, nbatch, env, verbose = TRUE), which reads
# `env$metric`, `env$train.metric` and `env$end.round` and always returns
# TRUE (presumably read by the trainer as "keep going").
mx.callback.plot.train.metric <- function(period, logger = NULL,
                                          ylim = c(0.5, 1)) {
  # Closure-local history used when no logger is given. Previously,
  # `logger$train <- ...` on a NULL logger created a per-call local list,
  # so the plotted series never had more than one point.
  history <- numeric(0)
  function(iteration, nbatch, env, verbose = TRUE) {
    if (nbatch %% period == 0 && !is.null(env$metric)) {
      n_round <- env$end.round
      result <- env$metric$get(env$train.metric)
      if (is.null(logger)) {
        # `<<-` writes to the factory's environment so the history
        # persists across invocations of this callback.
        history <<- c(history, result$value)
        values <- history
      } else {
        logger$train <- c(logger$train, result$value)
        values <- logger$train
      }
      # Empty canvas spanning rounds 0..n_round and the requested y range.
      plot(c(0, n_round), ylim, type = "n",
           xlab = "", ylab = paste0("Train-", result$name))
      lines(values, lwd = 3, col = "red")
    }
    TRUE
  }
}
# Seeded training run: fit the `softmax` network on train.x / train.y on the
# GPU with SGD-style settings, passing the plotting callback (period 100,
# shared `logger`) as the epoch-end callback.
mx.set.seed(0)
train_plot_callback <- mx.callback.plot.train.metric(100, logger)
model <- mx.model.FeedForward.create(
  softmax,
  X = train.x,
  y = train.y,
  ctx = mx.gpu(),
  num.round = 10,
  array.batch.size = 100,
  learning.rate = 0.05,
  momentum = 0.9,
  eval.metric = mx.metric.accuracy,
  initializer = mx.init.uniform(0.07),
  epoch.end.callback = train_plot_callback
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment