Navigation Menu

Skip to content

Instantly share code, notes, and snippets.

@thirdwing
Created June 3, 2016 22:17
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save thirdwing/ceaf2f8725349d98249e7631227b5f35 to your computer and use it in GitHub Desktop.
Save thirdwing/ceaf2f8725349d98249e7631227b5f35 to your computer and use it in GitHub Desktop.
# Metric history shared with the plotting callback: the callback appends each
# training-metric value to `logger$train` and redraws the curve from it.
# NOTE(review): presumably an mxnet reference-class object, so appends made
# inside the callback persist in this top-level `logger` — confirm.
logger <- mx.metric.logger$new()
#' Build an mxnet callback that live-plots a training metric
#'
#' Returns a closure meant to be passed as a training callback (e.g.
#' `epoch.end.callback`). Whenever it fires with `nbatch` a multiple of
#' `period` and a metric present in the training environment, it reads the
#' current training metric, appends its value to `logger$train`, and redraws
#' the whole accumulated history as a thick red line.
#'
#' @param period Fire only when `nbatch %% period == 0`.
#'   NOTE(review): when used as an epoch-end callback, mxnet appears to pass
#'   `nbatch = 0`, in which case every call plots regardless of `period`.
#'   Confirm against the mxnet version in use.
#' @param logger Object whose `train` field accumulates metric values across
#'   calls. It must have reference semantics (an `mx.metric.logger` or an
#'   environment) so appends made inside the callback persist. If `NULL`, a
#'   private environment is used so the curve still accumulates.
#' @param ylim Length-2 numeric y-axis range. Defaults to `c(0.5, 1)`, which
#'   suits accuracy-like metrics; values outside it are clipped from view.
#' @return A function `(iteration, nbatch, env, verbose = TRUE)` that always
#'   returns `TRUE`.
mx.callback.plot.train.metric <- function(period, logger = NULL,
                                          ylim = c(0.5, 1)) {
  # Assigning `$train` onto NULL would build a throwaway local list on every
  # call, losing the history; an environment keeps state between calls.
  if (is.null(logger)) {
    logger <- new.env()
  }
  force(period)
  force(ylim)

  function(iteration, nbatch, env, verbose = TRUE) {
    if (nbatch %% period == 0 && !is.null(env$metric)) {
      n_rounds <- env$end.round
      result <- env$metric$get(env$train.metric)
      # Empty canvas spanning every round, so the curve grows left to right.
      plot(NULL, xlim = c(0, n_rounds), ylim = ylim,
           xlab = "", ylab = paste0("Train-", result$name))
      logger$train <- c(logger$train, result$value)
      lines(logger$train, lwd = 3, col = "red")
    }
    TRUE
  }
}
# Fit the feed-forward network and live-plot training accuracy after each
# epoch via `logger`. NOTE(review): `softmax`, `train.x` and `train.y` are
# not defined in this snippet — presumably created earlier in the script.
# `ctx = mx.gpu()` means a GPU context is required as written.
# NOTE(review): as an epoch-end callback `nbatch` is likely 0, so the period
# of 100 passed to the plotting callback probably has no effect — confirm.
mx.set.seed(0)
model <- mx.model.FeedForward.create(
  softmax,
  X = train.x,
  y = train.y,
  ctx = mx.gpu(),
  num.round = 10,
  array.batch.size = 100,
  learning.rate = 0.05,
  momentum = 0.9,
  eval.metric = mx.metric.accuracy,
  initializer = mx.init.uniform(0.07),
  epoch.end.callback = mx.callback.plot.train.metric(100, logger)
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment