@cjgb · Created August 12, 2016 14:48
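## Character-level LSTM language model in R with mxnet: reads all text
## files in a directory as one long character stream, trains a
## next-character LSTM on it, and then samples generated text.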
library(mxnet)

## Model and training hyperparameters
batch.size     <- 32       # sequences per minibatch
seq.len        <- 64       # characters per training sequence
num.hidden     <- 128      # LSTM hidden-state size
num.embed      <- 128      # character-embedding size
num.lstm.layer <- 1        # stacked LSTM layers
num.round      <- 1        # training epochs
learning.rate  <- 0.1
wd             <- 0.00001  # weight decay
clip_gradient  <- 1        # gradient-clipping threshold
update.period  <- 1        # how often (in batches) to update parameters
## Read every file under dir.boe and build a character-level dataset:
## a (seq.len x num.seq) integer matrix of 0-based character ids, plus
## the char -> id dictionary and the id -> char lookup table.
## (max.vocab and dic are accepted for compatibility but unused here.)
make.data <- function(dir.boe, seq.len = 32, max.vocab = 10000, dic = NULL) {
    text <- lapply(dir(dir.boe, full.names = TRUE), readLines)
    text <- lapply(text, paste, collapse = "\n")
    text <- paste(text, collapse = "\n")

    chars <- unique(strsplit(text, '')[[1]])
    dic <- as.list(seq_along(chars))
    names(dic) <- chars
    lookup.table <- as.list(chars)

    char.lst <- strsplit(text, '')[[1]]
    num.seq <- floor(length(char.lst) / seq.len)
    char.lst <- char.lst[1:(num.seq * seq.len)]   # drop the trailing partial sequence
    data <- matrix(match(char.lst, chars) - 1, seq.len, num.seq)  # 0-based ids for mxnet

    list(data = data, dic = dic, lookup.table = lookup.table)
}
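## As an illustration (hypothetical corpus size): with seq.len = 64 and
## a corpus of 640,000 characters, the returned data matrix is
## 64 x 10000, one column per training sequence.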
## Labels are the inputs shifted one step: each character's target is
## the next character in the column-major character stream.
get.label <- function(X) {
    label <- c(X[-1], X[1])    # shift left by one, wrapping the last label around
    matrix(label, nrow(X), ncol(X))
}
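## A worked example of the shift: get.label(matrix(0:5, 2, 3)) returns
## matrix(c(1, 2, 3, 4, 5, 0), 2, 3), i.e. every entry is labelled with
## the id that follows it in the stream.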
ret <- make.data(".", seq.len = seq.len)
X <- ret$data
dic <- ret$dic
lookup.table <- ret$lookup.table
vocab <- length(dic)

## 90/10 train/validation split along the sequence (column) axis
train.val.fraction <- 0.9
train.cols <- floor(ncol(X) * train.val.fraction)
## Trim columns so the number of sequences is a multiple of batch.size
drop.tail <- function(x, batch.size) {
    nstep <- floor(ncol(x) / batch.size)
    x[, 1:(nstep * batch.size)]
}
X.train.data <- X[, 1:train.cols]
X.train.data <- drop.tail(X.train.data, batch.size)
X.train.label <- get.label(X.train.data)
X.train <- list(data = X.train.data, label = X.train.label)

X.val.data <- X[, -(1:train.cols)]
X.val.data <- drop.tail(X.val.data, batch.size)
X.val.label <- get.label(X.val.data)
X.val <- list(data = X.val.data, label = X.val.label)
## Train the character-level LSTM; both the input alphabet and the
## output labels range over the `vocab` distinct characters.
model <- mx.lstm(X.train, X.val,
                 ctx = mx.cpu(),
                 num.round = num.round,
                 update.period = update.period,
                 num.lstm.layer = num.lstm.layer,
                 seq.len = seq.len,
                 num.hidden = num.hidden,
                 num.embed = num.embed,
                 num.label = vocab,
                 batch.size = batch.size,
                 input.size = vocab,
                 initializer = mx.init.uniform(0.1),
                 learning.rate = learning.rate,
                 wd = wd,
                 clip_gradient = clip_gradient)
## Generate n characters from the trained model, starting at `start`.
## With random.sample = TRUE, each character is drawn from the (squared,
## hence sharpened) predictive distribution; otherwise take the argmax.
get.sample <- function(n, start = "<", random.sample = TRUE) {

    make.output <- function(prob, sample = FALSE) {
        prob <- as.numeric(as.array(prob))
        if (!sample)
            return(which.max(prob))
        sample(seq_along(prob), 1, prob = prob^2)
    }

    ## Build a single-step inference model from the trained parameters
    infer.model <- mx.lstm.inference(num.lstm.layer = num.lstm.layer,
                                     input.size = vocab,
                                     num.hidden = num.hidden,
                                     num.embed = num.embed,
                                     num.label = vocab,
                                     arg.params = model$arg.params,
                                     ctx = mx.cpu())

    out <- start
    last.id <- dic[[start]]
    for (i in 1:(n - 1)) {
        ## feed the previous character (0-based id), carrying the hidden
        ## state across steps, and keep the updated model
        ret <- mx.lstm.forward(infer.model, last.id - 1, FALSE)
        infer.model <- ret$model
        last.id <- make.output(ret$prob, random.sample)
        out <- paste0(out, lookup.table[[last.id]])
    }
    out
}
cat(get.sample(1000, start = "A", random.sample = TRUE))
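## To keep a longer sample around, one option (hypothetical file name,
## plain base R) is to write it to disk:
## writeLines(get.sample(5000, start = "A", random.sample = TRUE),
##            con = "sample.txt")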