Skip to content

Instantly share code, notes, and snippets.

@KeitaTakenouchi
Created July 1, 2021 14:39
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save KeitaTakenouchi/6139fdfedbf91a70532a4f1e27435fa3 to your computer and use it in GitHub Desktop.
Save KeitaTakenouchi/6139fdfedbf91a70532a4f1e27435fa3 to your computer and use it in GitHub Desktop.
Value iteration for a Markov decision process (MDP)
package main
import kotlin.math.abs
import kotlin.math.max
import kotlin.math.min
/** Number of grid columns; x coordinates range over `0 until width`. */
const val width: Int = 5

/** Number of grid rows; y coordinates range over `0 until height`. */
const val height: Int = 5
/** Cell agent A must reach; any state with A on this cell earns reward 0. */
val goal: Point = Point(4, 4)

/** Impassable cell: a step onto it is cancelled, and no state places an agent here. */
val barrier: Point = Point(2, 2)
/** Program entry point: solves the grid MDP and prints the resulting trajectory. */
fun main(): Unit = run()
/**
 * Solves the two-agent grid MDP with value iteration, then prints the greedy trajectory
 * that starts with A at (0, 0) and B at (1, 1).
 *
 * The Bellman sweeps live in [computeOptimalValues]. The greedy rollout lives in
 * [printTrajectory]. Both share the one-step look-ahead in [qValue].
 */
fun run() {
    val values = computeOptimalValues()
    printTrajectory(values, start = State(Point(0, 0), Point(1, 1)))
}

/** One-step look-ahead: immediate reward plus the current value of the successor state. */
private fun qValue(values: Map<State, Int>, state: State, action: Action): Int =
    reward(state, action) + values.getValue(nextState(state, action))

/**
 * Runs undiscounted value iteration over [allStates] until a full sweep changes no value.
 *
 * Updates are applied in place. A state visited later in a sweep therefore already sees
 * the values refreshed earlier in that same sweep.
 *
 * NOTE(review): there is no discount factor. Termination therefore relies on every state
 * being able to reach one whose `a` is on [goal]. Otherwise its value keeps dropping every
 * sweep and this loop never exits.
 */
private fun computeOptimalValues(): Map<State, Int> {
    // The state set is fixed, so enumerate it once instead of once per sweep.
    val states = allStates()
    val values = states.associateWithTo(mutableMapOf<State, Int>()) { 0 }
    do {
        var delta = 0
        for (state in states) {
            val previous = values.getValue(state)
            val updated = Action.all().maxOf { action -> qValue(values, state, action) }
            values[state] = updated
            delta = max(delta, abs(previous - updated))
        }
    } while (delta > 0)
    return values
}

/** Greedy action for [state]; ties go to the earliest action in [Action.all]. */
private fun bestAction(values: Map<State, Int>, state: State): Action =
    checkNotNull(Action.all().maxByOrNull { action -> qValue(values, state, action) }) {
        "Action.all() returned no actions"
    }

/**
 * Prints up to [maxSteps] states of the greedy rollout from [start]. The rollout stops
 * after printing the first state in which A stands on [goal].
 */
private fun printTrajectory(values: Map<State, Int>, start: State, maxSteps: Int = 10) {
    var state = start
    repeat(maxSteps) {
        println(state)
        // A has reached the goal; the trajectory is complete.
        if (state.a == goal) return
        state = nextState(state, bestAction(values, state))
    }
}
/**
 * Enumerates every joint state: both agents on distinct cells, neither on [barrier].
 *
 * A is the outer index and B the inner one. Each is ordered by x first, then y.
 */
fun allStates(): List<State> {
    // Every non-barrier cell, x-major then y.
    val freeCells = (0 until width)
        .flatMap { x -> (0 until height).map { y -> Point(x, y) } }
        .filter { cell -> cell != barrier }
    return freeCells.flatMap { first ->
        freeCells
            .filter { second -> second != first }
            .map { second -> State(first, second) }
    }
}
/**
 * Applies [action] to both agents simultaneously.
 *
 * Each agent moves with the matching [Point] step, which is clamped to the grid and
 * blocked by the barrier. If the two agents would land on the same cell, the move is
 * cancelled and [state] is returned unchanged.
 */
fun nextState(state: State, action: Action): State {
    val step: (Point) -> Point = when (action) {
        Action.Up -> Point::goUp
        Action.Down -> Point::goDown
        Action.Left -> Point::goLeft
        Action.Right -> Point::goRight
    }
    val moved = State(step(state.a), step(state.b))
    return moved.takeUnless { it.isDuplicated() } ?: state
}
/**
 * Immediate reward: 0 when A stands on [goal], -1 for any other state.
 *
 * [action] does not affect the result.
 */
fun reward(state: State, action: Action): Int = when (state.a) {
    goal -> 0
    else -> -1
}
/** Moves available to the agents; both agents take the same move each step. */
enum class Action {
    Up, Down, Left, Right;

    companion object {
        /**
         * All actions in declaration order: Up, Down, Left, Right.
         *
         * Delegates to the enum's generated `values()`, so a newly declared constant is
         * included automatically instead of having to be listed here a second time.
         */
        fun all(): Array<Action> = Action.values()
    }
}
/**
 * Joint position of the two agents.
 *
 * [a] is the agent whose position decides the reward. [b] moves with the same action
 * as [a] on every step.
 */
data class State(val a: Point, val b: Point) {
    /** Whether both agents occupy the same cell; such states are never produced. */
    fun isDuplicated(): Boolean = a == b

    /**
     * Renders the grid one row per line, starting from y = 0.
     *
     * "A" and "B" mark the agents, "G" the goal, "I" the barrier and "|_" an empty cell.
     * An agent standing on the goal hides the "G".
     */
    override fun toString(): String = buildString {
        for (y in 0 until height) {
            for (x in 0 until width) append(labelFor(Point(x, y)))
            append("\n")
        }
    }

    /** Two-character label for one cell; precedence is A, then B, then goal, then barrier. */
    private fun labelFor(cell: Point): String = when (cell) {
        a -> " A"
        b -> " B"
        goal -> " G"
        barrier -> " I"
        else -> "|_"
    }
}
/**
 * Grid cell with x as the column and y as the row.
 *
 * Up decreases y, so y = 0 is the top row. Each `go*` step is clamped to the grid. A step
 * that would land on [barrier] leaves the point where it is.
 */
data class Point(val x: Int, val y: Int) {
    fun goUp(): Point = unlessBarrier(copy(y = (y - 1).coerceAtLeast(0)))

    fun goDown(): Point = unlessBarrier(copy(y = (y + 1).coerceAtMost(height - 1)))

    fun goLeft(): Point = unlessBarrier(copy(x = (x - 1).coerceAtLeast(0)))

    fun goRight(): Point = unlessBarrier(copy(x = (x + 1).coerceAtMost(width - 1)))

    /** Returns [target], or this point unchanged when [target] is the barrier cell. */
    private fun unlessBarrier(target: Point): Point = if (target == barrier) this else target
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment