-
-
Save KeitaTakenouchi/6139fdfedbf91a70532a4f1e27435fa3 to your computer and use it in GitHub Desktop.
Value iteration for a Markov decision process (MDP)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import kotlin.math.abs | |
import kotlin.math.max | |
import kotlin.math.min | |
/** Number of columns in the grid; valid x coordinates are `0 until width`. */
const val width: Int = 5

/** Number of rows in the grid; valid y coordinates are `0 until height`. */
const val height: Int = 5

/** Cell that point `a` of a [State] has to reach for the per-step penalty to stop. */
val goal: Point = Point(4, 4)

/** Cell that neither point may occupy or move onto. */
val barrier: Point = Point(2, 2)
/** Program entry point: delegates to [run]. */
fun main() = run()
/**
 * Solves the grid problem with value iteration and prints a greedy trajectory.
 *
 * 1. Every state returned by [allStates] starts with value 0.
 * 2. States are swept repeatedly, each value being replaced in place by the best
 *    `reward + value(next state)` over all actions (Bellman optimality update),
 *    until a full sweep changes no value.
 * 3. Starting from `a = (0, 0)`, `b = (1, 1)`, the current state is printed and the
 *    greedy action is applied, ten times in total.
 */
fun run() {
    // The state and action sets never change, so build them once instead of on every sweep.
    val states = allStates()
    val actions = Action.all()

    // Initialize the value function to zero for every state.
    val values = states.associateWithTo(mutableMapOf<State, Int>()) { 0 }

    // One-step lookahead value of taking action `act` in state `s`.
    // `getValue` fails with a descriptive exception (instead of a bare NPE) should a
    // transition ever lead to a state that is missing from the table.
    fun actionValue(s: State, act: Action): Int =
        reward(s, act) + values.getValue(nextState(s, act))

    // Value iteration using the Bellman optimality equation (in-place sweeps).
    do {
        var delta = 0
        for (s in states) {
            val previous = values.getValue(s)
            val updated = actions.maxOf { act -> actionValue(s, act) }
            values[s] = updated
            delta = max(delta, abs(previous - updated))
        }
    } while (delta > 0) // stop once a whole sweep leaves every value unchanged

    // Show the optimal trajectory.
    var state = State(Point(0, 0), Point(1, 1))
    repeat(10) {
        println(state)
        // Greedy action; ties resolve to the first maximal action in `Action.all()` order.
        val optAction = actions.maxByOrNull { act -> actionValue(state, act) }
            ?: error("Action.all() returned no actions")
        // State transition.
        state = nextState(state, optAction)
    }
}
/**
 * Enumerates every valid [State]: all ordered pairs of distinct grid cells where
 * neither cell is the [barrier].
 *
 * Cells are visited column by column (x outer, y inner), first for `a` and then,
 * for each `a`, in the same order for `b`.
 */
fun allStates(): List<State> {
    // Every free cell of the grid, in x-major order.
    val freeCells = (0 until width)
        .flatMap { x -> (0 until height).map { y -> Point(x, y) } }
        .filter { cell -> cell != barrier }

    // Pair each cell with every other free cell; the two points may not coincide.
    return freeCells.flatMap { first ->
        freeCells
            .filter { second -> second != first }
            .map { second -> State(first, second) }
    }
}
/**
 * Applies [action] to both points of [state] at once.
 *
 * Each point moves one step in the chosen direction using its own `goXxx` rule.
 * If the move would put both points on the same cell, the move is rejected and
 * the original [state] is returned unchanged.
 */
fun nextState(state: State, action: Action): State {
    // Pick the single-point move that corresponds to the action.
    val step: (Point) -> Point = when (action) {
        Action.Up -> Point::goUp
        Action.Down -> Point::goDown
        Action.Left -> Point::goLeft
        Action.Right -> Point::goRight
    }
    val candidate = State(step(state.a), step(state.b))
    return candidate.takeUnless { it.isDuplicated() } ?: state
}
/**
 * Immediate reward for taking [action] in [state]: `0` when point `a` is on the
 * [goal], `-1` otherwise. The [action] argument does not influence the result.
 */
fun reward(state: State, action: Action): Int =
    when (state.a) {
        goal -> 0
        else -> -1
    }
/** The four moves that can be applied to a state. */
enum class Action {
    Up,
    Down,
    Left,
    Right;

    companion object {
        /** Returns a fresh array of every action, in declaration order. */
        fun all(): Array<Action> = values()
    }
}
/**
 * A configuration of the grid: the positions of the two points [a] and [b].
 */
data class State(val a: Point, val b: Point) {
    /** True when both points sit on the same cell. */
    fun isDuplicated(): Boolean = a == b

    /**
     * Renders the grid as text, one line per row (y) and two characters per cell,
     * with every row terminated by a newline.
     */
    override fun toString(): String =
        (0 until height).joinToString(separator = "") { y ->
            (0 until width).joinToString(separator = "", postfix = "\n") { x ->
                glyphAt(Point(x, y))
            }
        }

    /** Two-character marker for [cell]; precedence is A, B, goal, barrier, then empty. */
    private fun glyphAt(cell: Point): String =
        when (cell) {
            a -> " A"
            b -> " B"
            goal -> " G"
            barrier -> " I"
            else -> "|_"
        }
}
/**
 * A grid cell. Each `goXxx` method returns the cell reached by one step in that
 * direction, clamped to the grid bounds; when the resulting cell is the [barrier]
 * the point stays where it is.
 */
data class Point(val x: Int, val y: Int) {
    /** One step towards y = 0. */
    fun goUp(): Point = stepTo(x, (y - 1).coerceAtLeast(0))

    /** One step towards y = height - 1. */
    fun goDown(): Point = stepTo(x, (y + 1).coerceAtMost(height - 1))

    /** One step towards x = 0. */
    fun goLeft(): Point = stepTo((x - 1).coerceAtLeast(0), y)

    /** One step towards x = width - 1. */
    fun goRight(): Point = stepTo((x + 1).coerceAtMost(width - 1), y)

    /** Returns the cell at ([targetX], [targetY]) unless it is the barrier, in which case `this`. */
    private fun stepTo(targetX: Int, targetY: Int): Point {
        val target = Point(targetX, targetY)
        return if (target == barrier) this else target
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment