package jhm

case class SolverUsingFoldLeft( digits: String, total: Int, maxDigitsPerNumber: Int = 5
                              ) extends JHMSolver(digits, total, maxDigitsPerNumber) {

  /** An in-progress split of a prefix of the digit string.
    *
    * @param partialStr       digits of the number currently being built ("" before any digit is read)
    * @param partialNumber    numeric value of `partialStr`
    * @param subTotal         sum of the numbers already completed
    * @param numbersInReverse the completed numbers, most recently completed first
    */
  case class PartialSolution( partialStr: String, partialNumber: Int,
                              subTotal: Int, numbersInReverse: List[Int] )
  {
    /** Sum of the completed numbers plus the number still being built. */
    def runningSubTotal: Int = subTotal + partialNumber
    /** True only for the seed state, before any digit has been consumed. */
    def isInitialSolution: Boolean = partialStr == ""
    /** A lone "0" must not be extended, so numbers never carry leading zeros. */
    def isZero: Boolean = partialStr == "0"
    /** Whether the current number has exactly this solver's constructor `maxDigitsPerNumber` digits. */
    def isMaxLength: Boolean = partialStr.length == maxDigitsPerNumber
  }

  /** Finds every way to split `digits`, keeping their order, into numbers that sum to `total`.
    *
    * Each number has at most `maxDigitsPerNumber` digits and no leading zeros (a lone "0" is
    * allowed). Digits are consumed left to right with `foldLeft`, keeping only partial splits
    * whose running sum does not exceed `total`. An empty `digits` yields `List(Nil)` when
    * `total` is 0 and no solutions otherwise.
    *
    * @param digits             the digit string to split
    * @param total              the required sum
    * @param maxDigitsPerNumber the longest number allowed; defaults to this solver's constructor value
    * @return every split, each listing its numbers in left-to-right order
    * @throws NumberFormatException if `digits` contains a character that is not a decimal digit
    */
  override def solve(digits: String, total: Int,
                     maxDigitsPerNumber: Int = this.maxDigitsPerNumber): List[List[Int]] = {

    // The number being built may absorb the next digit unless nothing has been started yet,
    // it is a lone "0" (no leading zeros), or it already has `maxDigitsPerNumber` digits.
    // Deliberately uses this method's parameter rather than the constructor field.
    def canExtend(par: PartialSolution): Boolean =
      !par.isInitialSolution && !par.isZero && par.partialStr.length < maxDigitsPerNumber

    // Completes the number being built (if any) and begins a new one.
    def startNewPartialNumber(par: PartialSolution, newPartialStr: String,
                              newPartialNumber: Int): PartialSolution = PartialSolution(
      partialStr = newPartialStr,
      partialNumber = newPartialNumber,
      subTotal = par.runningSubTotal,
      numbersInReverse =
        if (par.isInitialSolution) List.empty // nothing built yet: don't record a spurious 0
        else par.partialNumber :: par.numbersInReverse
    )

    def buildUpSolutions(partialSolutions: List[PartialSolution], nextDigit: Char): List[PartialSolution] = {
      val nextDigitAsStr = nextDigit.toString
      // Parsed once per digit; throws NumberFormatException for a non-digit character.
      val nextDigitAsInt = nextDigitAsStr.toInt

      def expandSolution(par: PartialSolution): List[PartialSolution] = {
        // Sums are compared as Long so they cannot wrap around and slip past the pruning.
        // Every state that is kept has runningSubTotal <= total, so its Int fields stay in range.
        if (par.runningSubTotal.toLong + nextDigitAsInt > total)
          List.empty // extending the current number instead would be even bigger: prune both
        else {
          val parSolWithStartOfNewNumber = startNewPartialNumber(par, nextDigitAsStr, nextDigitAsInt)
          if (!canExtend(par)) List(parSolWithStartOfNewNumber)
          else {
            val largerNumber = 10L * par.partialNumber + nextDigitAsInt
            if (par.subTotal + largerNumber > total) List(parSolWithStartOfNewNumber)
            else List(
              parSolWithStartOfNewNumber,
              PartialSolution(
                partialStr = par.partialStr + nextDigit,
                partialNumber = largerNumber.toInt,
                subTotal = par.subTotal,
                numbersInReverse = par.numbersInReverse
              )
            )
          }
        }
      }

      partialSolutions flatMap expandSolution
    }

    val initialPartialSolution = PartialSolution(
      partialStr = "", partialNumber = 0, subTotal = 0, numbersInReverse = List.empty)
    val allSolutions = digits.foldLeft(List(initialPartialSolution))(buildUpSolutions)

    // Closing off each state's last number is the same as starting an empty new one; this
    // also avoids inventing a 0 when `digits` is empty.
    allSolutions
      .map(par => startNewPartialNumber(par, "", 0))
      .filter(_.subTotal == total)
      .map(_.numbersInReverse.reverse)
  }
}