Last active
November 1, 2019 15:45
-
-
Save Gobi03/107394032c63cd11500cbc1892999ba7 to your computer and use it in GitHub Desktop.
社内勉強会用、ビームサーチの例 (for an internal study session: a beam-search example)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.Scanner | |
import scala.xml.XML | |
import scala.collection.mutable.{PriorityQueue, Set} | |
import scala.util.control.Breaks.{breakable,break} | |
import Math.abs | |
/**
 * Entry point. Give the puzzle's `<div.board>` element as input: the board is read via
 * [[BoardInputReader]], solved with beam search, and the resulting click sequence is
 * printed by [[JSCodeGenerator]].
 */
object Main extends App with Tools {
  override val H = 4
  override val W = 4
  val board = (new BoardInputReader(H, W)).get()

  /**
   * Beam search over puzzle states.
   *
   * Each layer expands at most `beamWidth` not-yet-visited nodes, taken best-first (lowest
   * `eval`) from the current frontier; all of their unvisited children form the next
   * frontier. The search stops at the first completed board it dequeues, or when the
   * frontier becomes empty.
   *
   * @param initialNode start state
   * @param beamWidth   maximum number of nodes expanded per layer (a value <= 0 expands nothing)
   * @return tile numbers to click, in order, to reach the solved board; `Nil` when the start
   *         board is already solved or when the frontier is exhausted without a solution
   */
  def beamSearch(initialNode: Node, beamWidth: Int): List[Int] = {
    var res: List[Int] = Nil
    val done: Set[List[Byte]] = Set()
    var nowState = new PriorityQueue[Node]
    nowState.enqueue(initialNode)
    breakable {
      // Without this emptiness check an exhausted frontier produces another empty
      // frontier on every iteration and the loop never terminates.
      while (nowState.nonEmpty) {
        val nextState = new PriorityQueue[Node]
        var expanded = 0
        while (expanded < beamWidth && nowState.nonEmpty) {
          val now = nowState.dequeue()
          val key = now.serialize
          // The same board may have been enqueued more than once; expand only its first copy.
          if (!done(key)) {
            done += key
            if (now.isCompleted) {
              // history is built by prepending, so reverse it into click order.
              res = now.history.reverse
              break
            }
            for {
              p <- make4dir(now.hole)
              nextNode = now.move(p)
              if !done(nextNode.serialize)
            } nextState.enqueue(nextNode)
            // Duplicates skipped above do not consume beam width.
            expanded += 1
          }
        }
        nowState = nextState
      }
    }
    res
  }

  val ans = beamSearch(mkRootNode(board), 100)
  JSCodeGenerator.run(ans)
}
/**
 * Shared helpers for an `H` x `W` sliding-tile puzzle.
 *
 * Board convention used throughout: `board(y)(x)` holds a tile number, or `-1` for the
 * empty cell. Tile `n` belongs at `(x, y) = (n % W, n / W)`, so the solved board has tiles
 * `0 .. H*W-2` in row-major order with the bottom-right cell left for the hole.
 */
trait Tools {
  val H: Int
  val W: Int

  /**
   * In-bounds orthogonal neighbours of `p = (x, y)`, in the fixed order
   * `(x-1, y)`, `(x, y+1)`, `(x+1, y)`, `(x, y-1)`.
   */
  def make4dir(p: (Int, Int)): List[(Int, Int)] = {
    def inRange(q: (Int, Int)): Boolean = {
      val (qx, qy) = q
      0 <= qx && qx < W && 0 <= qy && qy < H
    }
    val (x, y) = p
    List((x - 1, y), (x, y + 1), (x + 1, y), (x, y - 1)).filter(inRange)
  }

  /** Copies `ar` row by row, so writes to the result never touch `ar`. */
  def copy2DArray(ar: Array[Array[Int]]): Array[Array[Int]] = ar.map(_.clone)

  /**
   * Wraps `board` in a root search node (depth 1, empty history).
   *
   * If several cells hold `-1`, the last one in row-major order is taken as the hole.
   *
   * @throws IllegalArgumentException if `board` has no `-1` cell; without a hole there is
   *         nowhere to slide tiles and no move can ever be generated
   */
  def mkRootNode(board: Array[Array[Int]]): Node = {
    val holes = for {
      y <- 0 until H
      x <- 0 until W
      if board(y)(x) == -1
    } yield (x, y)
    require(holes.nonEmpty, "board has no empty cell (-1)")
    Node(1, Nil, board, holes.last)
  }

  /** Manhattan (L1) distance between two `(x, y)` points. */
  def manhattanDist(p1: (Int, Int), p2: (Int, Int)): Int =
    abs(p1._1 - p2._1) + abs(p1._2 - p2._2)

  /** Goal cell `(x, y)` of tile `n` in row-major order. */
  def numberToCoord(n: Int): (Int, Int) = (n % W, n / W)

  /**
   * One search state.
   *
   * @param depth   1 for the root; each move adds 1
   * @param history tile numbers moved so far, most recent first
   * @param board   current board (`move` copies it rather than mutating it)
   * @param hole    `(x, y)` of the empty cell
   */
  case class Node(
    depth: Int,
    history: List[Int],
    board: Array[Array[Int]],
    hole: (Int, Int)
  ) extends Ordered[Node] {

    /**
     * Heuristic cost: sum over all tiles (the empty cell is skipped) of the Manhattan
     * distance from the tile's cell to its goal cell.
     *
     * Cached because `compare` reads it on every priority-queue comparison; the board is
     * not mutated after the node is built.
     */
    lazy val eval: Int =
      (for {
        y <- 0 until H
        x <- 0 until W
        n = board(y)(x)
        if n >= 0
      } yield manhattanDist((x, y), numberToCoord(n))).sum

    /**
     * Reversed on purpose: a lower `eval` compares as greater, so the max-heap
     * `PriorityQueue[Node]` dequeues the most promising node first.
     */
    def compare(that: Node): Int = Integer.compare(that.eval, this.eval)

    /** True when tiles `0 .. H*W-2` are in row-major order; the last cell is not checked. */
    def isCompleted: Boolean =
      (0 until H * W - 1).forall(i => board(i / W)(i % W) == i)

    /**
     * Slides the tile at `clicked` into the empty cell and returns the resulting node;
     * this node is left unchanged.
     *
     * Assumes `clicked` is orthogonally adjacent to the empty cell; this is not checked.
     */
    def move(clicked: (Int, Int)): Node = {
      val (ex, ey) = hole
      val (cx, cy) = clicked
      val num = board(cy)(cx)
      val nextBoard = copy2DArray(board)
      nextBoard(ey)(ex) = num
      nextBoard(cy)(cx) = -1
      Node(
        depth = depth + 1,
        history = num :: history,
        board = nextBoard,
        hole = clicked
      )
    }

    /**
     * Row-major snapshot of the board, used as a visited-set key.
     *
     * NOTE: `toByte` wraps values modulo 256, so on boards with 257 or more cells distinct
     * boards can share a key (e.g. tile 255 and the empty cell -1 both become byte -1).
     */
    def serialize: List[Byte] = board.flatten.toList.map(_.toByte)
  }
}
/**
 * Builds the puzzle board from a `<div.board>` element read on stdin.
 *
 * @param H        number of rows
 * @param W        number of columns
 * @param tileSize pixel size of one tile in the page's inline styles; pixel offsets are
 *                 divided by it to obtain grid coordinates (default 80, the value the
 *                 reader has always used)
 */
class BoardInputReader(H: Int, W: Int, tileSize: Int = 80) {

  /**
   * Reads one line from stdin, parses it as XML, and returns `board(y)(x)`: the i-th parsed
   * tile position receives tile number `i` (for `i` in `0 until H*W-1`), and every other
   * cell is `-1`.
   *
   * @throws IllegalArgumentException if fewer than `H * W - 1` tile positions are parsed
   */
  def get(): Array[Array[Int]] = {
    // System.in is deliberately not closed: closing the Scanner would close stdin.
    val sc = new Scanner(System.in)
    val xml = XML.loadString(sc.nextLine())
    val positions = parse(xml)
    val tileCount = H * W - 1
    require(
      positions.length >= tileCount,
      s"expected at least $tileCount tiles but parsed ${positions.length}"
    )
    val board = Array.fill(H)(Array.fill(W)(-1))
    for (i <- 0 until tileCount) {
      val (x, y) = positions(i)
      board(y)(x) = i
    }
    board
  }

  /**
   * Extracts a grid position `(x, y)` from the inline `style` of each grandchild `div` of
   * `a`, in document order.
   *
   * Parsing steps:
   *  1. take the text after the last `,`;
   *  2. split it on `;` and keep the 3rd and 4th pieces;
   *  3. take the last space-separated token of each, drop a trailing `px`, and parse the
   *     integer;
   *  4. divide by `tileSize`. The first value becomes x and the second becomes y.
   *
   * NOTE(review): this is tied to the page's exact style-attribute layout; confirm the
   * declaration order if the page changes.
   */
  private def parse(a: scala.xml.Elem): Seq[(Int, Int)] =
    (a \ "div" \ "div").map { tile =>
      val offsets = (tile \ "@style").toString
        .split(",").last
        .split(";")
        .slice(2, 4)
        // stripSuffix also accepts a unitless CSS `0`, which `take(length - 2)` rejected.
        .map(_.split(" ").last.stripSuffix("px").toInt)
      (offsets(0) / tileSize, offsets(1) / tileSize)
    }
}
/** Turns a solution into a browser JavaScript snippet that replays it. */
object JSCodeGenerator {

  /**
   * Prints to stdout a single JS block that binds `nyan` to
   * `document.querySelectorAll("div.nyan")`. For each tile number `n` in `history`, the
   * block schedules `nyan[n].click()`; the k-th click (1-based) fires `k * intervalMs`
   * milliseconds after the snippet runs.
   *
   * @param history    tile numbers to click, first click first
   * @param intervalMs delay between consecutive clicks in milliseconds (default 50)
   */
  def run(history: List[Int], intervalMs: Int = 50): Unit = {
    val nyanDec = "const nyan = document.querySelectorAll(\"div.nyan\");"
    val clicks = history.zipWithIndex
      .map { case (num, k) =>
        s"setTimeout(() => {nyan[${num}].click()}, ${(k + 1) * intervalMs})"
      }
      .mkString(";")
    println("{" + nyanDec + clicks + "}")
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment