Skip to content

Instantly share code, notes, and snippets.

@waynejo
Created March 31, 2023 12:55
Show Gist options
  • Save waynejo/623eed33f6767c606d55c71c75f11b46 to your computer and use it in GitHub Desktop.
import java.io.FileInputStream
import scala.annotation.{tailrec, targetName}
import scala.io.StdIn
import scala.math.abs
/** A node of a parsed packet: either an integer leaf or a list of nested nodes.
  *
  * Ordering follows the packet rules implemented by `compareTreeElement`: a
  * `valid` (correctly ordered) pair sorts left-first, an `invalid` pair sorts
  * right-first, and an undecided pair compares as equal. Note that "equal" here
  * is not structural equality: e.g. `[2]` and `[[2]]` compare as 0.
  */
sealed trait TreeElement extends Ordered[TreeElement]:
  def compare(that: TreeElement): Int =
    compareTreeElement(this, that, 0) match
      case CompareResult.valid   => -1
      case CompareResult.invalid => 1
      case CompareResult.unknown => 0

/** Integer leaf of a packet, e.g. the `3` in `[1,[3]]`. */
final case class Number(v: Int) extends TreeElement

/** List node of a packet, holding its elements in order. */
final case class Child(v: Vector[TreeElement]) extends TreeElement
/** Verdict of comparing a left packet against a right packet. */
enum CompareResult:
  /** The pair is in the right order (left sorts first). */
  case valid
  /** The pair is in the wrong order (right sorts first). */
  case invalid
  /** No decision could be reached at this level; keep comparing. */
  case unknown
/** Parses a group of two consecutive input lines into a (left, right) packet pair.
  *
  * @param lines exactly two packet lines, left packet first
  * @return the parsed left and right packets
  * @throws IllegalArgumentException if `lines` does not hold exactly two lines,
  *         e.g. the trailing group of an input with an odd number of packets
  */
def buildTreePair(lines: Seq[String]): (TreeElement, TreeElement) =
  // Parsing starts past the outermost '[' (index 1); that list's node is seeded
  // as the root of the parser's stack.
  def parse(line: String): TreeElement = parseTreeElement(line, 1, List(Child(Vector())))

  lines match
    case Seq(left, right) => (parse(left), parse(right))
    case _ =>
      throw new IllegalArgumentException(
        s"Expected exactly 2 packet lines per pair, got ${lines.size}: $lines")
/** Parses one packet line into a [[TreeElement]] tree.
  *
  * Implemented as a stack machine: `acc` holds the elements that are still open,
  * innermost first. A finished number or list is popped and appended to the list
  * beneath it. The outermost list is never pushed by a `[`: callers seed `acc`
  * with its empty node and start at `index` 1, past the leading `[`. Its closing
  * `]` is then consumed without popping, since there is no parent to append to.
  *
  * @param s     packet text, e.g. `[1,[2,3],[]]`
  * @param index position in `s` of the next character to consume
  * @param acc   stack of open elements; must be non-empty, head is innermost
  * @return the outermost element once the whole string has been consumed
  * @throws Exception on an unexpected character, or if the string ends while a
  *                   nested list or number is still open (truncated input)
  */
@tailrec
def parseTreeElement(s: String, index: Int = 1, acc: List[TreeElement]): TreeElement =
  def isDigit(c: Char): Boolean = '0' <= c && c <= '9'

  (s.lift(index), acc.head, acc.tail.headOption) match
    case (None, root, None) =>
      root
    case (None, _, _) =>
      // More than the root is still open: the input was truncated or unbalanced.
      // Returning acc.head here would silently yield a partial element.
      throw new Exception(s"Unexpected end of input at index $index: unclosed elements $acc")
    case (Some(c), Number(n), Some(Child(siblings))) if !isDigit(c) =>
      // A number ends at the first non-digit: attach it to its parent list and
      // re-examine the same character (index is deliberately not advanced).
      parseTreeElement(s, index, Child(siblings :+ Number(n)) +: acc.tail.tail)
    case (Some('['), _, _) =>
      parseTreeElement(s, index + 1, Child(Vector()) +: acc)
    case (Some(']'), child, Some(Child(siblings))) =>
      parseTreeElement(s, index + 1, Child(siblings :+ child) +: acc.tail.tail)
    case (Some(']'), _, None) =>
      // Closing bracket of the outermost list: nothing to pop into.
      parseTreeElement(s, index + 1, acc)
    case (Some(','), _, _) =>
      parseTreeElement(s, index + 1, acc)
    case (Some(c), Child(_), _) if isDigit(c) =>
      parseTreeElement(s, index + 1, Number(c - '0') +: acc)
    case (Some(c), Number(number), _) if isDigit(c) =>
      parseTreeElement(s, index + 1, Number(number * 10 + (c - '0')) +: acc.tail)
    case c =>
      throw new Exception(s"Unexpected char at index $index: $c")
/** Decides whether `one` and `two` are in the right order under the packet rules.
  *
  *  - two numbers: the smaller must come first; equal numbers are undecided;
  *  - a number against a list: the number is promoted to a one-element list;
  *  - two lists: elements are compared pairwise starting at position `index`;
  *    the first decided pair wins, otherwise the list that runs out first
  *    must be on the left.
  *
  * @param one   left element
  * @param two   right element
  * @param index position to start the pairwise walk at when both sides are lists;
  *              ignored otherwise (nested comparisons always restart at 0)
  * @return `valid` if correctly ordered, `invalid` if not, `unknown` if undecided
  */
def compareTreeElement(one: TreeElement, two: TreeElement, index: Int): CompareResult =
  // Guard-free cases over the sealed hierarchy let the compiler verify
  // exhaustiveness, so no catch-all `throw` branch is needed.
  (one, two) match
    case (Number(left), Number(right)) =>
      if left < right then CompareResult.valid
      else if left > right then CompareResult.invalid
      else CompareResult.unknown
    case (Number(_), Child(_)) =>
      compareTreeElement(Child(Vector(one)), two, 0)
    case (Child(_), Number(_)) =>
      compareTreeElement(one, Child(Vector(two)), 0)
    case (Child(lefts), Child(rights)) =>
      // Walk both lists in lockstep without re-wrapping them at every step.
      @tailrec
      def walk(position: Int): CompareResult =
        (lefts.lift(position), rights.lift(position)) match
          case (Some(a), Some(b)) =>
            compareTreeElement(a, b, 0) match
              case CompareResult.unknown => walk(position + 1)
              case decided               => decided
          case (None, Some(_)) => CompareResult.valid
          case (Some(_), None) => CompareResult.invalid
          case (None, None)    => CompareResult.unknown
      walk(index)
/** Part 1: sums the 1-based positions of the pairs that are already in the right order. */
def solve13_1(input: Vector[(TreeElement, TreeElement)]): Int =
  input.zipWithIndex
    .collect {
      case ((left, right), position) if compareTreeElement(left, right, 0) == CompareResult.valid =>
        position + 1
    }
    .sum
/** Part 2: sorts every packet together with the divider packets `[[2]]` and `[[6]]`
  * and returns the product of the dividers' 1-based positions in the sorted list.
  */
def solve13_2(input: Vector[(TreeElement, TreeElement)]): Int =
  val dividers: Vector[TreeElement] =
    Vector(2, 6).map(n => Child(Vector(Child(Vector(Number(n))))))
  val ordered = (input.flatMap((left, right) => Vector(left, right)) ++ dividers).sorted
  dividers.map(divider => ordered.indexOf(divider) + 1).product
/** Entry point: reads packet pairs from `example13-2.in` and prints the part 1 and
  * part 2 answers on separate lines.
  *
  * The file is read directly and closed afterwards, rather than redirected via
  * `System.setIn`. `StdIn` reads through `scala.Console.in`, which captures
  * `System.in` once when `Console` is first initialised, so redirection only
  * worked while nothing had touched `Console` yet. The stream was also never closed.
  */
@main def solve13(): Unit =
  import scala.io.Source
  import scala.util.Using

  val inputs = Using.resource(Source.fromFile("example13-2.in")) { source =>
    source.getLines()
      .filter(_.nonEmpty) // blank lines only separate pairs
      .grouped(2)
      .map(buildTreePair)
      .toVector           // materialise before the source is closed
  }
  println(solve13_1(inputs))
  println(solve13_2(inputs))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment