-
-
Save pljones/8462061 to your computer and use it in GitHub Desktop.
/** * Alice in Markov Chains for West London Hack Night * * Developed in 1.5 hours by Scala team. * Post-night development by Peter L Jones */
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package wlhacknight | |
import scala.io.Source | |
import java.io.File | |
import java.util.Arrays | |
import scala.collection.mutable.ArrayBuffer | |
import scala.collection.immutable.HashMap | |
import scala.util.Random | |
/**
 * Alice in Markov Chains for West London Hack Night
 *
 * Developed in 1.5 hours by Scala team.
 * Post-night development by Peter L Jones
 * ... mostly to help me learn more bits of Scala ...
 *
 * Run with one or more text files as arguments. Every word in the files is
 * mapped to counts of the words seen immediately after it, and the resulting
 * Markov chain is walked at random to print a "song" of verses and lines.
 * Three marker tokens are used: '''line''' ends every input line, '''start'''
 * brackets the whole corpus and '''verse''' begins each new verse.
 */
object HackNight extends App {

  /**
   * A word in the chain.
   *
   * @param value        the word itself, or one of the marker tokens
   * @param isFollowedBy how many times each word was seen directly after `value`
   */
  case class Node(value: String, isFollowedBy: Map[String, Int]) {

    /**
     * Picks a follower at random, weighted by frequency: a follower seen `c`
     * times out of `n` in total is chosen with probability `c / n`.
     *
     * @throws IllegalArgumentException if this node has no followers
     */
    def link: String = {
      require(isFollowedBy.nonEmpty, s"'$value' has no followers")
      // Ascending count order; the resulting distribution does not depend on it.
      val followers = isFollowedBy.toList.sortWith(_._2 < _._2)
      val target = Random.nextInt(followers.map(_._2).sum) + 1
      // Running totals of the counts: the first follower whose total reaches
      // `target` owns that slot of the 1..n range (no mutable accumulator needed).
      val runningTotals = followers.scanLeft(0)(_ + _._2).tail
      followers(runningTotals.indexWhere(_ >= target))._1
    }

    /** True for the marker tokens, which are not real words. */
    def isTerminal: Boolean = value match {
      case "'''line'''" | "'''verse'''" | "'''start'''" => true
      case _                                          => false
    }
  }

  /**
   * A random walk through `chain` that begins with a follower of `start`.
   *
   * Every call to `iterator` starts a fresh walk. Leading marker tokens are
   * skipped. The walk then yields words up to and including the first marker
   * it reaches, and ends there.
   *
   * @param chain every known word mapped to its node; each follower reached by
   *              the walk is looked up here, so it must be a key
   * @param start the word or marker whose followers begin the walk
   */
  class Chain(private val chain: Map[String, Node], private val start: String) extends Iterable[Node] {
    def iterator: Iterator[Node] = new ChainIterator

    override def toString: String = chain.toString

    class ChainIterator extends Iterator[Node] {
      // The node `next` will return: stroll past leading markers so it is a real word.
      private var upcoming: Node = chain(chain(start).link)
      while (upcoming.isTerminal) upcoming = chain(upcoming.link)

      // Set once a marker has been returned; the walk never resumes after that.
      private var finished = false

      def hasNext: Boolean = !finished

      def next(): Node = {
        if (finished) throw new NoSuchElementException("There is nothing here to see.")
        val emitted = upcoming
        finished = emitted.isTerminal
        // Only step forward while the walk continues, so a marker's follower is
        // never looked up (or a random number consumed) after the walk is over.
        if (!finished) upcoming = chain(emitted.link)
        emitted
      }
    }
  }

  /**
   * Adds each adjacent pair `(word, nextWord)` of `words` to `nodeMap`,
   * incrementing the count of `nextWord` in `word`'s node (creating the node
   * if needed). The last word gets no node of its own unless it also occurs
   * earlier with a successor.
   *
   * @param words   the token stream, in reading order
   * @param nodeMap the chain built so far
   * @return `nodeMap` extended with every pair in `words`; unchanged when
   *         `words` has fewer than two elements
   */
  @scala.annotation.tailrec
  def somekindofparsingfunction(words: List[String], nodeMap: Map[String, Node]): Map[String, Node] =
    words match {
      case word :: (rest @ (nextWord :: _)) =>
        val counts = nodeMap.get(word).fold(Map.empty[String, Int])(_.isFollowedBy)
        val updated = Node(word, counts.updated(nextWord, counts.getOrElse(nextWord, 0) + 1))
        somekindofparsingfunction(rest, nodeMap + (word -> updated))
      case _ =>
        // Zero or one word left: there is no pair to record.
        nodeMap
    }

  /**
   * Reads every file and tokenises it line by line.
   *
   * Each line is lower-cased and stripped of everything except letters,
   * apostrophes, commas and spaces, then split on spaces. Tokens that are
   * empty or a lone apostrophe/comma are dropped. Lines with no tokens left
   * are skipped entirely; every remaining line is ended with '''line'''.
   *
   * @param files paths of the text files to read
   * @return one token list per non-empty line, in file order
   */
  def allLines(files: Iterable[String]): List[List[String]] = {
    val unwanted = "[^a-z', ]".r
    val junkTokens = Set("", "'", ",", " ")
    files.toList.flatMap { file =>
      val source = Source.fromFile(file)
      // getLines is lazy, so read everything before closing the file.
      val rawLines = try source.getLines().toList finally source.close()
      rawLines
        .map(line => unwanted.replaceAllIn(line.toLowerCase, "").trim)
        .map(_.split(" ").toList.map(_.trim).filterNot(junkTokens))
        .filter(_.nonEmpty)
        .map(_ :+ "'''line'''")
    }
  }

  // The input files' words, each line ended by a '''line''' marker.
  private val corpus: List[String] = allLines(args).flatten

  if (corpus.isEmpty) {
    Console.err.println("No words found in: " + args.mkString(", "))
    sys.exit(1)
  }

  // Bracket the corpus with '''start''' on both sides:
  // - The leading marker makes '''start''' a key whose follower is the first word.
  // - The trailing marker gives the final '''line''' a follower.
  // Together, every token a walk can reach is a key with at least one follower.
  // Nothing in the input produces '''verse''', so it is registered with the
  // same followers as '''start''': each new verse begins the way the song does.
  val nodeMap: Map[String, Node] = {
    val parsed = somekindofparsingfunction(
      ("'''start'''" :: corpus) :+ "'''start'''",
      HashMap[String, Node]())
    parsed + ("'''verse'''" -> parsed("'''start'''").copy(value = "'''verse'''"))
  }
  println(s"$nodeMap\n\n")

  /**
   * Walks the chain from `start` for at most `size` nodes.
   *
   * @return the last node visited (a marker when the line ended naturally,
   *         otherwise the word where it was cut off) and the words before the
   *         first marker, joined by single spaces
   */
  def getLine(start: String, size: Int): (Node, String) = {
    require(size > 0, "a line needs room for at least one node")
    val nodesInLine = new Chain(nodeMap, start).iterator.take(size).toList
    (nodesInLine.last, nodesInLine.takeWhile(!_.isTerminal).map(_.value).mkString(" "))
  }

  // Where the next line's walk begins: the previous line's last node, or a marker.
  var nextStart = "'''start'''"
  // 5-9 verses, each drawn from 5-10 non-empty lines of 12-23 nodes.
  val song = Iterator.continually {
    val linesInVerse = Iterator.continually {
      val (lastNode, line) = getLine(nextStart, 12 + Random.nextInt(12))
      nextStart = lastNode.value
      (lastNode, line)
    }.filter(_._2 != "").take(5 + Random.nextInt(6)).toList
    nextStart = "'''verse'''"
    // Keep only the leading lines that ended on a marker rather than being cut short.
    linesInVerse.takeWhile(_._1.isTerminal).map(_._2).mkString("\n")
  }.take(5 + Random.nextInt(5)).mkString("\n\n")
  println(song)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Yes, far more code now. This means it must be better.
The big news is "class Chain" and the associated ChainIterator to make it "easier" to get the next word in the song.
The other big news is that it now actually weights the words by frequency when picking the next word.
In other news, I've tried hard to get rid of as many vars as I could -- and added a few to keep a decent balance as this is Scala... but the added ones are to support the state of iterators. Really. Apart from "sum"... which could be computed with a def and some tail recursion instead of how I've done it, I suppose...
The output is now less incomprehensible and far more recognisably influenced by the input.