Skip to content

Instantly share code, notes, and snippets.

@srp
Created November 9, 2011 03:49
Show Gist options
  • Save srp/1350285 to your computer and use it in GitHub Desktop.
Save srp/1350285 to your computer and use it in GitHub Desktop.
Scala script that computes cell values and policies for a gridworld by value iteration
// To run: scala gridworld.scala
// Factor by which Cell.update multiplies the best expected neighbour value
// (it plays the role of the discount factor in the update formula).
val l: Double = 1.0
// Constant that Cell.update adds on every re-estimate of an updatable cell
// (the per-step reward term of the update formula).
val R: Int = -200
/** One square of the gridworld.
  *
  * A cell whose `initValue` is `Some(0.0)` is re-estimated by `update` and gets a
  * direction from `policy`; every other cell (`None`, or a non-zero initial value)
  * keeps its value and reports the policy "---".
  *
  * `x0`, `y0` and `world` are placeholders after construction and must be assigned
  * by the caller before any movement, `update` or `policy` call.
  *
  * @param initValue starting value; `None` marks a cell that cannot be entered
  * @param value     current value; it is always reset to `initValue` during
  *                  construction, so an argument passed here is ignored
  */
class Cell(val initValue: Option[Double], var value: Option[Double] = None) {
  var x0: Int = -9999
  var y0: Int = -9999
  var world: IndexedSeq[IndexedSeq[Cell]] = null
  value = initValue

  override def toString: String = "Cell(" + value + ")"

  /** Returns `c` if it holds a value, otherwise this cell (the move is blocked). */
  def tryMove(c: Cell): Cell = if (c.value.isDefined) c else this

  // Neighbour lookups clamp at the grid edges, so moving off the grid stays in place.
  // The upper bounds come from `world` itself rather than fixed grid dimensions.
  def up: Cell = tryMove(world(math.max(y0 - 1, 0))(x0))
  def down: Cell = tryMove(world(math.min(y0 + 1, world.length - 1))(x0))
  def left: Cell = tryMove(world(y0)(math.max(x0 - 1, 0)))
  def right: Cell = tryMove(world(y0)(math.min(x0 + 1, world(y0).length - 1)))

  /** Value of `c`, failing with a descriptive message when the cell holds none. */
  private def valueOf(c: Cell): Double =
    c.value.getOrElse(throw new NoSuchElementException("cell at (" + c.x0 + ", " + c.y0 + ") has no value"))

  /** Weighted value of a move: 0.8 for landing on `c1`, 0.1 each for `c2` and `c3`.
    *
    * @throws NoSuchElementException if any of the three cells holds no value
    */
  def moveVal(c1: Cell, c2: Cell, c3: Cell): Double =
    0.8 * valueOf(c1) + 0.1 * valueOf(c2) + 0.1 * valueOf(c3)

  /** The four moves with their weighted values, in the order up, down, left, right. */
  private def actionValues: Seq[(String, Double)] = Seq(
    "up" -> moveVal(up, left, right),
    "down" -> moveVal(down, left, right),
    "left" -> moveVal(left, up, down),
    "right" -> moveVal(right, up, down))

  /** Re-estimates `value` as (best move value) * l + R; a no-op unless `initValue` is `Some(0.0)`.
    *
    * Declared without a parameter list so existing `cell.update` call sites keep compiling.
    */
  def update: Unit = {
    if (initValue == Some(0.0)) {
      value = Some(actionValues.map(_._2).reduce(_ max _) * l + R)
    }
  }

  /** Name of the best move, or "---" for cells that are never updated.
    *
    * When several moves share the best value the one listed last wins
    * (up, down, left, right order), hence `>=` rather than `>`.
    */
  def policy: String =
    if (initValue == Some(0.0)) {
      actionValues.reduceLeft((best, next) => if (next._2 >= best._2) next else best)._1
    } else "---"
}
// The grid, one inner sequence per row with row 0 first. Cells created with
// Some(0.0) are the ones Cell.update re-estimates; the 100 and -100 cells keep
// their value, and the None cell cannot be entered.
val world: IndexedSeq[IndexedSeq[Cell]] = {
  // Builds one grid row from the initial values of its cells, left to right.
  def row(initValues: Option[Double]*): IndexedSeq[Cell] =
    IndexedSeq(initValues.map(v => new Cell(v)): _*)

  IndexedSeq(
    row(Some(0.0), Some(0.0), Some(0.0), Some(100.0)),
    row(Some(0.0), None, Some(0.0), Some(-100.0)),
    row(Some(0.0), Some(0.0), Some(0.0), Some(0.0)))
}
// Tell every cell where it sits and which grid it belongs to, so that its
// neighbour lookups work. The ranges come from the grid itself, so every cell
// is wired even if rows or columns are added to `world`.
for {
  y <- world.indices
  x <- world(y).indices
} {
  val cell = world(y)(x)
  cell.x0 = x
  cell.y0 = y
  cell.world = world
}
// Sweep the grid repeatedly, updating cells in place in row-major order, until
// a full sweep leaves every cell value unchanged. `loops` counts the sweeps.
var isChanged = true
var loops = 0
while (isChanged) {
  loops += 1
  // The update runs before the `||`, so every cell is updated on every sweep
  // even after a change has already been seen.
  isChanged = world.flatten.foldLeft(false) { (changedSoFar, cell) =>
    val before = cell.value
    cell.update
    changedSoFar || before != cell.value
  }
}
// Report: number of sweeps, then the grid of values, then the grid of policies,
// one row per line.
println("Iterations: " + loops)
println()
println("Values:")
world.foreach(row => println(row))
println()
println("Policies:")
world.foreach(row => println(row.map(cell => cell.policy)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment