# main.R — from GitHub Gist ewoo/914b2504a6f325e41583
# (forked from jameskyle/main.R), created November 29, 2015.
library(MDPtoolbox)
library(Matrix)
library(ggplot2)
library(grid)
library(gridExtra)
library(doMC)
# Parallel backend for the %dopar% loops below: roughly three quarters of the
# detected cores, rounded down to a whole number and never fewer than one.
# detectCores() can return NA when the count cannot be determined, hence
# na.rm = TRUE.
cores <- max(1L, floor(parallel::detectCores() * 0.75), na.rm = TRUE)
registerDoMC(cores = cores)
# NOTE(review): set.seed() here does not by itself make results computed on
# forked %dopar% workers reproducible; consider doRNG if that matters.
set.seed(1234)
# Value iteration adapted from MDPtoolbox::mdp_value_iteration.
#
# Unlike the packaged version, this variant never stops early on convergence:
# it always performs exactly `max_iter` Bellman backups, so callers can sweep
# the iteration budget. The first iteration at which the span of the value
# update fell below the epsilon threshold is reported as `converged`.
#
# Args:
#   P         transition probabilities: an S x S x A array or a list of A
#             S x S matrices.
#   R         rewards in any form accepted by MDPtoolbox::mdp_computePR.
#   discount  discount factor in ]0; 1].
#   epsilon   convergence threshold (default 0.01); must be >= 0.
#   max_iter  number of backups to run (default 5000); must be > 0 and is
#             capped at 5000 (with a warning) when discount < 1.
#   V0        initial value vector of length S (default all zeros).
# Returns:
#   list(V, policy, iter, time (difftime in seconds), epsilon, discount,
#        converged (first converged iteration, or -1 if never)).
# Errors:
#   stop()s on invalid arguments. (The previous version printed the message
#   and then failed with "object 'V' not found".)
mdp_value_iteration <- function (P, R, discount, epsilon, max_iter, V0)
{
  start <- Sys.time()
  iter_cap <- 5000

  # --- validation (defaults are detected with missing(), not nargs(): nargs()
  # counts supplied arguments regardless of which ones they are, so a call
  # naming max_iter but not epsilon used to clobber max_iter and leave
  # epsilon undefined) ---
  if (discount <= 0 || discount > 1) {
    stop("MDP Toolbox ERROR: Discount rate must be in ]0; 1]", call. = FALSE)
  }
  if (!missing(epsilon) && epsilon < 0) {
    stop("MDP Toolbox ERROR: epsilon must be upper than 0", call. = FALSE)
  }
  if (!missing(max_iter) && max_iter <= 0) {
    stop("MDP Toolbox ERROR: The maximum number of iteration must be upper than 0",
         call. = FALSE)
  }
  S <- if (is.list(P)) dim(P[[1]])[1] else dim(P)[1]
  if (!missing(V0) && length(V0) != S) {
    stop("MDP Toolbox ERROR: V0 must have the same dimension as P", call. = FALSE)
  }

  if (discount == 1) {
    warning("MDP Toolbox WARNING: check conditions of convergence. ",
            "With no discount, convergence is not always assumed.",
            call. = FALSE)
  }

  # --- defaults ---
  PR <- mdp_computePR(P, R)
  if (missing(V0)) {
    V0 <- numeric(S)
  }
  if (missing(epsilon)) {
    epsilon <- 0.01
  }
  if (missing(max_iter)) {
    max_iter <- iter_cap
  } else if (discount != 1 && max_iter > iter_cap) {
    # `&&` short-circuits; the original `&` evaluated an undefined bound when
    # discount == 1 and max_iter was supplied.
    warning(paste("MDP Toolbox WARNING: max_iter is bounded by ", iter_cap),
            call. = FALSE)
    max_iter <- iter_cap
  }

  # Stopping threshold on the span of successive value functions (same
  # formula as the original).
  thresh <- if (discount != 1) epsilon * (1 - discount) / discount else epsilon

  V <- V0
  converged <- -1
  iter <- 0
  repeat {
    iter <- iter + 1
    Vprev <- V
    bellman <- mdp_bellman_operator(P, PR, discount, V)
    V <- bellman[[1]]
    policy <- bellman[[2]]
    # Record only the FIRST iteration that meets the threshold; the loop keeps
    # running to max_iter on purpose (early stop deliberately disabled).
    if (converged < 0 && mdp_span(V - Vprev) < thresh) {
      converged <- iter
    }
    # `>=` also terminates for a fractional max_iter.
    if (iter >= max_iter) {
      break
    }
  }

  list(V = V,
       policy = policy,
       iter = iter,
       # Fixed units: callers read time[[1]] as seconds.
       time = difftime(Sys.time(), start, units = "secs"),
       epsilon = epsilon,
       discount = discount,
       converged = converged)
}
# Q-learning adapted from MDPtoolbox::mdp_Q_learning.
#
# Performs N single-step Q updates along a simulated trajectory:
#   * the current state is re-drawn uniformly every 100 steps;
#   * the action is greedy w.r.t. Q with probability 1 - 1/log(n + 2),
#     otherwise uniform over the A actions;
#   * the successor state is drawn from P by inverse-CDF sampling;
#   * Q[s, a] moves towards r + discount * max(Q[s', ]) with step 1/sqrt(n + 2).
# The loop also stops as soon as more than `max.time` seconds have elapsed.
#
# Args:
#   P         S x S x A array or list of A S x S matrices.
#   R         list of A S x S matrices, an S x S x A array, or an S x A matrix.
#   discount  discount factor in ]0; 1].
#   N         number of updates (default 10000); must be > 0.
#   max.time  wall-clock budget in seconds (default 1800).
# Returns a list with Q, V (row maxima of Q), policy (row argmax of Q),
# mean_discrepancy (mean |dQ| over each completed block of 100 updates),
# discount, iter (updates performed), max.iter (= N), time (difftime in
# seconds) and max.time (TRUE when the time budget ended the loop early).
# Errors: stop()s on an invalid discount or N (previously the message was
# printed and the function then failed with "object 'Q' not found").
mdp_Q_learning <- function (P, R, discount, N, max.time=1800)
{
  # ganked from MDPtoolbox
  start <- Sys.time()
  if (discount <= 0 || discount > 1) {
    stop("MDP Toolbox ERROR: Discount rate must be in ]0; 1]", call. = FALSE)
  }
  # missing() rather than nargs(): nargs() also counts a supplied max.time,
  # which left N undefined for calls like mdp_Q_learning(P, R, 0.9, max.time = 60).
  if (missing(N)) {
    N <- 10000
  } else if (N <= 0) {
    stop("MDP Toolbox ERROR: N must a positive integer", call. = FALSE)
  }
  if (is.list(P)) {
    S <- dim(P[[1]])[1]
    A <- length(P)
  } else {
    S <- dim(P)[1]
    A <- dim(P)[3]
  }
  Q <- matrix(0, S, A)
  mean_discrepancy <- NULL
  # Buffer of the latest |dQ| values, averaged and reset every 100 updates.
  # (The original indexed it by n %% 100 + 1, so its first block always
  # contained an NA and the first mean was NA.)
  window <- numeric(100)
  filled <- 0
  max.time.exceeded <- FALSE
  state <- sample(1:S, 1, replace = TRUE)
  iters <- 1
  # 1:N is safe here: N has been validated as > 0.
  for (n in 1:N) {
    if (n %% 100 == 0) {
      state <- sample(1:S, 1, replace = TRUE)
    }
    if (runif(1) < (1 - (1 / log(n + 2)))) {
      a <- which.max(Q[state, ])
    } else {
      a <- sample(1:A, 1, replace = TRUE)
    }
    # Inverse-CDF draw of the successor from row `state` of P for action a.
    p_s_new <- runif(1)
    p <- 0
    s_new <- 0
    while ((p < p_s_new) && (s_new < S)) {
      s_new <- s_new + 1
      if (is.list(P)) {
        p <- p + P[[a]][state, s_new]
      } else {
        p <- p + P[state, s_new, a]
      }
    }
    if (is.list(R)) {
      r <- R[[a]][state, s_new]
    } else if (length(dim(R)) == 3) {
      r <- R[state, s_new, a]
    } else {
      r <- R[state, a]
    }
    delta <- r + discount * max(Q[s_new, ]) - Q[state, a]
    dQ <- (1 / sqrt(n + 2)) * delta
    Q[state, a] <- Q[state, a] + dQ
    state <- s_new
    filled <- filled + 1
    window[filled] <- abs(dQ)
    if (filled == 100) {
      mean_discrepancy <- c(mean_discrepancy, mean(window))
      filled <- 0
    }
    iters <- n
    # Compare in seconds explicitly: a bare difftime switches to minutes after
    # 60 s, so the original `t[[1]] > max.time` check was 60x too lax.
    elapsed <- as.numeric(difftime(Sys.time(), start, units = "secs"))
    if (elapsed > max.time) {
      max.time.exceeded <- TRUE
      break
    }
  }
  list(
    Q = Q,
    V = apply(Q, 1, max),
    policy = apply(Q, 1, which.max),
    mean_discrepancy = mean_discrepancy,
    discount = discount,
    iter = iters,
    max.iter = N,
    time = difftime(Sys.time(), start, units = "secs"),
    max.time = max.time.exceeded
  )
}
# Policy iteration adapted from MDPtoolbox::mdp_policy_iteration.
#
# Alternates policy evaluation and greedy improvement for exactly `max_iter`
# iterations; the packaged early stop on an unchanged policy is deliberately
# disabled so callers can sweep the iteration budget.
#
# Args:
#   P          S x S x A array or list of A S x S matrices.
#   R          rewards in any form accepted by MDPtoolbox::mdp_computePR.
#   discount   discount factor in ]0; 1].
#   max_iter   number of iterations (default 1000); must be > 0.
#   policy0    initial policy of length S (default: the greedy policy of one
#              Bellman backup from V = 0).
#   eval_type  0 (default) evaluates each policy with mdp_eval_policy_matrix;
#              any other value uses mdp_eval_policy_iterative.
# Returns list(V, policy, iter, time (difftime in seconds), discount), where V
# is the evaluated value of the returned policy.
# Errors: stop()s on invalid arguments (previously a printed banner string was
# returned in place of the result list).
mdp_policy_iteration <- function (P, R, discount, max_iter, policy0, eval_type)
{
  # Modified from MDPtoolbox package
  start <- Sys.time()
  if (discount <= 0 || discount > 1) {
    stop("MDP Toolbox ERROR: Discount rate must be in ]0; 1]", call. = FALSE)
  }
  S <- if (is.list(P)) dim(P[[1]])[1] else dim(P)[1]
  if (!missing(policy0) && length(policy0) != S) {
    stop("MDP Toolbox ERROR: policy must have the same dimension as P", call. = FALSE)
  }
  # missing() rather than nargs(): with nargs(), a 4-argument call such as
  # mdp_policy_iteration(P, R, d, max_iter = 0) skipped this check and looped
  # forever, and a named eval_type left max_iter undefined.
  if (missing(max_iter)) {
    max_iter <- 1000
  } else if (max_iter <= 0) {
    stop("MDP Toolbox ERROR: The maximum number of iteration must be geater than 0",
         call. = FALSE)
  }
  PR <- mdp_computePR(P, R)
  if (missing(eval_type)) {
    eval_type <- 0
  }
  if (missing(policy0)) {
    policy0 <- mdp_bellman_operator(P, PR, discount, numeric(S))[[2]]
  }

  policy <- policy0
  iter <- 0
  repeat {
    iter <- iter + 1
    if (eval_type == 0) {
      V <- mdp_eval_policy_matrix(P, PR, discount, policy)
    } else {
      V <- mdp_eval_policy_iterative(P, PR, discount, policy)
    }
    # On the last iteration keep the policy that V was computed for.
    if (iter >= max_iter) {
      break
    }
    policy <- mdp_bellman_operator(P, PR, discount, V)[[2]]
  }

  list(V = V,
       policy = policy,
       iter = iter,
       time = difftime(Sys.time(), start, units = "secs"),
       discount = discount)
}
# Forest-management MDP with 4 states and 3 actions.
#
# Returns list(P, R):
#   P  4 x 4 x 3 array; P[from, to, action] with actions 1 = wait, 2 = cut,
#      3 = cultivate. States in row/column order: Cleared, Young Forest,
#      Old Forest, Farm. Column labels are set on dimension 2.
#   R  4 x 3 reward matrix R[state, action], columns labelled R1..R3.
# NOTE(review): the first state is labelled "Fire" although the comments
# call it "Cleared"; the label is kept unchanged here.
forest <- function() {
  # 4 x 4 matrix from 16 probabilities listed row by row (one row per
  # origin state, one column per destination state).
  by_row <- function(...) matrix(c(...), 4, 4, byrow = TRUE)

  wait <- by_row(
    0.1, 0.9, 0.0, 0.0,  # Cleared      -> mostly Young Forest
    0.1, 0.0, 0.9, 0.0,  # Young Forest -> mostly Old Forest
    0.1, 0.0, 0.9, 0.0,  # Old Forest   -> mostly stays
    0.1, 0.0, 0.0, 0.9   # Farm         -> mostly stays
  )
  cut <- by_row(
    1, 0, 0, 0,
    1, 0, 0, 0,
    1, 0, 0, 0,
    1, 0, 0, 0
  )
  cultivate <- by_row(
    0.1, 0, 0, 0.9,
    1,   0, 0, 0,
    1,   0, 0, 0,
    0.1, 0, 0, 0.9
  )

  # Stacking the three matrices column-major fills P[, , 1..3] in order.
  transitions <- array(c(wait, cut, cultivate), dim = c(4, 4, 3))
  dimnames(transitions) <- list(
    NULL, c("Fire", "Young Forest", "Old Forest", "Farm"), NULL
  )

  payoff <- matrix(
    c(
      # Waiting, Cutting, Cultivating
      0.0, 0.0, 0.0,  # empty field
      2.0, 1.0, 1.0,  # Young Forest
      4.0, 2.0, 2.0,  # Old Forest
      2.0, 0.0, 4.0   # Farm
    ),
    nrow = 4, ncol = 3, byrow = TRUE,
    dimnames = list(NULL, c("R1", "R2", "R3"))
  )

  list(P = transitions, R = payoff)
}
# Run value iteration, policy iteration and Q-learning on the forest MDP and
# score every resulting policy with rewards() (start state 1, 100 plays per
# rollout, 100 rollouts).
# Returns list(results = named list of per-solver model lists,
#              rewards = the table produced by rewards()).
forest.calc <- function() {
  message("Collecting forest management calculations...")
  world <- forest()
  solvers <- list(
    Value.Iteration = value.iteration,
    Policy.Iteration = policy.iteration,
    QLearning = q.learning
  )
  # lapply keeps the names and runs the solvers in the listed order.
  results <- lapply(solvers, function(run) run(world))
  list(
    results = results,
    rewards = rewards(world, results, 1, max.plays = 100, reps = 100)
  )
}
# Load the tic-tac-toe MDP from `data.dir`/R.RData and `data.dir`/P.RData.
#
# Args:
#   data.dir  directory holding the two .RData files
#             (default "data/tictactoe", the previously hard-coded path).
# Returns list(P, R) built from the objects `p` and `r` stored in those files.
# The files are loaded into a private environment and the objects are looked
# up there only, so a missing object fails loudly instead of silently picking
# up a global `p`/`r`.
tictactoe <- function(data.dir = "data/tictactoe") {
  store <- new.env(parent = emptyenv())
  load(file.path(data.dir, "R.RData"), envir = store)
  load(file.path(data.dir, "P.RData"), envir = store)
  list(P = get("p", envir = store, inherits = FALSE),
       R = get("r", envir = store, inherits = FALSE))
}
# Sweep value iteration over discount factors 0.1, 0.2, ..., 0.9 (one %dopar%
# task each). For each discount, fits are re-run from scratch with an
# iteration budget of 1..10, keeping every fit until a larger budget no longer
# changes the reported iteration count.
# Returns a single flat list of the fitted models, in discount order.
value.iteration <- function(mat) {
  message("Calculating value iteration...")
  per.discount <- foreach(disc = seq(0.1, 0.9, by = .1)) %dopar% {
    fits <- list()
    previous <- -1
    for (budget in 1:10) {
      fit <- mdp_value_iteration(mat$P, mat$R, discount = disc,
                                 epsilon = 0.01, max_iter = budget)
      if (fit$iter == previous) {
        break
      }
      previous <- fit$iter
      fits[[length(fits) + 1]] <- fit
    }
    fits
  }
  # Concatenate the per-discount lists into one list of fits.
  do.call(c, per.discount)
}
# Sweep policy iteration over discount factors 0.1, 0.2, ..., 0.9 (one
# %dopar% task each). For each discount, fits are re-run from scratch with an
# iteration budget of 1..10, keeping every fit until a larger budget no longer
# changes the reported iteration count.
# Returns a single flat list of the fitted models, in discount order.
policy.iteration <- function(mat) {
  message("Calculating policy iteration...")
  per.discount <- foreach(disc = seq(0.1, 0.9, by = .1)) %dopar% {
    fits <- list()
    previous <- -1
    for (budget in 1:10) {
      fit <- mdp_policy_iteration(mat$P, mat$R, discount = disc,
                                  max_iter = budget)
      if (fit$iter == previous) {
        break
      }
      previous <- fit$iter
      fits[[length(fits) + 1]] <- fit
    }
    fits
  }
  # Concatenate the per-discount lists into one list of fits.
  do.call(c, per.discount)
}
# Sweep Q-learning over discount factors 0.1, 0.5 and 0.9 (one %dopar% task
# each). For each discount, one model is trained per update budget
# N = 1, 301, ..., 3001, each limited to `max.time` seconds.
# Returns a single flat list of the trained models, in discount order.
q.learning <- function(mat, max.time=1800) {
  message("Calculating Q Learning....")
  budgets <- seq(1, 3001, by = 300)
  per.discount <- foreach(disc = c(.1, .5, .9)) %dopar% {
    lapply(budgets, function(n.updates) {
      mdp_Q_learning(mat$P,
                     mat$R,
                     discount = disc,
                     N = n.updates,
                     max.time = max.time)
    })
  }
  # Concatenate the per-discount lists into one list of models.
  do.call(c, per.discount)
}
# Draw a successor state index, using row `state` of the transition matrix P
# as sampling weights over the columns.
next.state <- function(P, state) {
  weights <- as.vector(P[state, ])
  sample.int(length(weights), size = 1L, prob = weights)
}
# Roll out `policy` in `mat` and return the per-step rewards.
#
# Args:
#   mat        list with P (S x S x A array or list of A S x S matrices) and
#              R (indexed as R[state, action]).
#   state      starting state index.
#   policy     vector mapping state index -> action index.
#   rewards    rewards already collected (default NULL); the rollout continues
#              until there are max.plays of them.
#   max.plays  total number of rewards wanted (default 10).
# Returns `rewards` followed by the newly collected rewards, or `rewards`
# unchanged when it already holds max.plays or more entries.
#
# Iterative version of the original recursion. The recursion nested one call
# per play, so it hit R's expression-nesting limit for rollouts of a few
# thousand plays. It also copied the reward vector on every step, and it never
# terminated when `rewards` started out longer than max.plays.
simulate <- function(mat, state, policy, rewards=NULL, max.plays=10) {
  n.new <- max.plays - length(rewards)
  if (n.new <= 0) {
    return(rewards)
  }
  new.rewards <- numeric(n.new)
  for (k in seq_along(new.rewards)) {
    action <- policy[state]
    new.rewards[k] <- as.numeric(mat$R[state, action])
    if (is.list(mat$P)) {
      P <- mat$P[[action]]
    } else {
      P <- mat$P[, , action]
    }
    # A successor is drawn after every play, including the last one, exactly
    # as the recursive version did, so seeded runs consume the same random
    # numbers and reproduce the same results.
    state <- next.state(P, state)
  }
  c(rewards, new.rewards)
}
# Distinct `iter` values across a list of fitted models, as a flat vector
# (NULL for an empty list).
num.iters <- function(results) {
  pluck.iter <- function(model) model$iter
  unlist(unique(lapply(results, pluck.iter)))
}
# Distinct `discount` values across a list of fitted models, as a flat vector
# (NULL for an empty list).
num.discounts <- function(results) {
  pluck.discount <- function(model) model$discount
  unlist(unique(lapply(results, pluck.discount)))
}
# Score every fitted policy by simulation.
#
# Args:
#   world      the MDP (list with P and R) passed to simulate().
#   results    named list: model label -> list of fits, each carrying
#              policy, iter, discount and time.
#   state      start state for every rollout (previously ignored; 1 was
#              always used).
#   max.plays  rewards collected per rollout (default 10).
#   reps       rollouts per fit (default 1000).
# Returns a data.frame with one row per fit (iter, discount, reward = mean
# total reward over the rollouts, model, time.secs), or NULL if there are no
# fits.
rewards <- function(world, results, state, max.plays=10, reps=1000) {
  message("Calculating rewards...")
  # A difftime's [[1]] is expressed in whatever unit R chose (secs, mins, ...),
  # so convert explicitly; non-difftime values are kept as before.
  time.in.secs <- function(t) {
    if (inherits(t, "difftime")) as.numeric(t, units = "secs") else t[[1]]
  }
  # Collect one-row data frames and bind once at the end (avoids O(n^2) rbind).
  rows <- list()
  for (model in names(results)) {
    for (v in results[[model]]) {
      sums <- replicate(reps, sum(simulate(world, state, v$policy,
                                           max.plays = max.plays)))
      rows[[length(rows) + 1]] <- data.frame(iter = v$iter,
                                             discount = v$discount,
                                             reward = mean(sums),
                                             model = model,
                                             time.secs = time.in.secs(v$time))
    }
  }
  if (length(rows) == 0) {
    return(NULL)
  }
  do.call(rbind, rows)
}
# Run value iteration, policy iteration and Q-learning on the tic-tac-toe MDP
# and score every resulting policy with rewards() (start state 1, 100 plays
# per rollout, 100 rollouts).
# Returns list(results = named list of per-solver model lists,
#              rewards = the table produced by rewards()).
tictactoe.calc <- function() {
  message("Collecting tic-tac-toe calculations...")
  t.data <- tictactoe()
  results <- list(Value.Iteration = value.iteration(t.data),
                  Policy.Iteration = policy.iteration(t.data),
                  QLearning = q.learning(t.data))
  # Simulate in the tic-tac-toe world itself; the original passed `f.data`,
  # which does not exist in this function.
  reward.table <- rewards(t.data, results, 1, max.plays = 100, reps = 100)
  list(results = results, rewards = reward.table)
}
# Plot mean reward and wall-clock time against iteration count for one model.
#
# Args:
#   title       plot title; lower-cased with spaces replaced by "_" to form
#               the file name stem.
#   data        data frame with columns iter, reward, time.secs and
#               discount.factor (as produced by plots.preproc()).
#   outdir      directory that receives the PNG files.
#   multi       write both panels stacked into "<outdir>/<stem>.png".
#   individual  also write "<stem>_reward.png" and "<stem>_time.png".
# Returns (invisibly) the two ggplot objects so callers can reuse them.
my.plot <- function(title, data, outdir, multi = TRUE, individual = FALSE) {
  p1 <- ggplot(data, aes(x = iter, y = reward, colour = discount.factor)) +
    geom_point() +
    geom_line() +
    labs(title = "Reward Per Iteration",
         x = "Iterations",
         y = "Mean Reward",
         colour = "Discount")
  p2 <- ggplot(data, aes(x = iter, y = time.secs, colour = discount.factor)) +
    geom_point() +
    geom_line() +
    labs(title = "Time Per Iteration",
         x = "Iterations",
         y = "Time in Seconds",
         colour = "Discount")
  fname <- tolower(gsub(" ", "_", title))
  if (individual) {
    ggsave(filename = sprintf("%s/%s_reward.png", outdir, fname), plot = p1)
    ggsave(filename = sprintf("%s/%s_time.png", outdir, fname), plot = p2)
  }
  if (multi) {
    png(sprintf("%s/%s.png", outdir, fname))
    # Close the device even if grid.arrange() fails, so later output is not
    # drawn into a stale, still-open PNG device.
    on.exit(dev.off(), add = TRUE)
    grid.arrange(p1, p2, ncol = 1, top = textGrob(title))
  }
  invisible(list(reward = p1, time = p2))
}
# Prepare a rewards table for plotting. Adds `discount.factor` (the discount
# as a factor, used for colour) and keeps only the rows whose discount matches
# one of `discounts`.
#
# Args:
#   data       data frame with a numeric `discount` column.
#   discounts  discount values to keep (default 0.1, 0.5, 0.9).
#   tol        absolute tolerance used when matching discounts.
# Discounts are matched with a tolerance instead of `==`, because values
# produced by arithmetic (e.g. 0.7 + 0.2) are not bit-identical to the literal
# 0.9. Rows with an NA discount are dropped rather than turned into NA rows.
plots.preproc <- function(data, discounts = c(0.1, 0.5, 0.9), tol = 1e-8) {
  d2 <- data
  d2$discount.factor <- as.factor(d2$discount)
  keep <- vapply(
    data$discount,
    function(x) !is.na(x) && any(abs(x - discounts) < tol),
    logical(1)
  )
  d2[keep, ]
}
# Entry point. Computes the forest and tic-tac-toe experiments, writes the
# plots under doc/graphs/<problem>/, and saves the raw results to
# results/Forest.RData and results/TicTacToe.RData.
# Returns (invisibly) list(forest, tictactoe) with both result sets.
main <- function() {
  # Plot title -> model label used in the rewards table.
  plot.models <- c("Value Iteration" = "Value.Iteration",
                   "Policy Iteration" = "Policy.Iteration",
                   "Q Learning" = "QLearning")

  # Run one problem's calculations, plot them into `outdir`, return the results.
  run.problem <- function(label, outdir, calc) {
    # showWarnings = FALSE: re-running must not warn about existing directories.
    dir.create(outdir, recursive = TRUE, showWarnings = FALSE)
    res <- calc()
    plot.data <- plots.preproc(res$rewards)
    message(sprintf("Plotting %s Results...", label))
    for (title in names(plot.models)) {
      my.plot(title, plot.data[plot.data$model == plot.models[[title]], ], outdir)
    }
    res
  }

  forest.results <- run.problem("Forest", "doc/graphs/forest/", forest.calc)
  # The original announced "Plotting Forest Results..." here as well.
  tictactoe.results <- run.problem("Tic-Tac-Toe", "doc/graphs/tictactoe/",
                                   tictactoe.calc)

  dir.create("results", showWarnings = FALSE)
  # save() stores objects under their variable names, so bind the results to
  # `tictactoe` / `forest` to keep the object names in the .RData files.
  tictactoe <- tictactoe.results
  forest <- forest.results
  save(tictactoe, file = "results/TicTacToe.RData")
  save(forest, file = "results/Forest.RData")
  invisible(list(forest = forest.results, tictactoe = tictactoe.results))
}
# End of gist ewoo/914b2504a6f325e41583.