package com.aniketkadam.heaps

import org.jetbrains.annotations.TestOnly

/**
 * A binary max-heap stored in array order inside a [MutableList].
 *
 * The children of the element at index `i` live at `2i + 1` (left) and `2i + 2` (right),
 * and the largest element according to [Comparable.compareTo] is kept at index 0.
 * The class performs no internal synchronization.
 */
class MaxHeap<T : Comparable<T>> {
    private val items: MutableList<T> = mutableListOf()

    /**
     * Inserts every element of [items], one at a time, in iteration order.
     */
    fun insertAll(items: List<T>) {
        // `items` here is the parameter; it shadows the backing list property.
        items.forEach(::insert)
    }

    /**
     * Appends [item] and sifts it up until its parent is not smaller than it.
     */
    fun insert(item: T) {
        items.add(item)
        bubbleUp(items.lastIndex)
    }

    /**
     * Returns the index of the parent of [index].
     *
     * A parent `p` has children at `2p + 1` and `2p + 2`, so integer division `(c - 1) / 2`
     * recovers `p` from either child `c`. For the root (index 0) this yields 0, because
     * Kotlin's integer division truncates toward zero.
     */
    private fun getParentIndex(index: Int): Int = (index - 1) / 2

    /**
     * Moves the element at [index] toward the root while it is strictly greater than its parent.
     */
    private fun bubbleUp(index: Int) {
        var child = index
        while (child > 0) {
            val parent = getParentIndex(child)
            // Equal elements are left in place; only a strictly larger child moves up.
            if (items[child] <= items[parent]) return
            swap(child, parent)
            child = parent
        }
    }

    /**
     * Moves the element at [index] (the root by default) down while it is smaller than its
     * larger child.
     */
    private fun bubbleDown(index: Int = 0) {
        var parent = index
        while (true) {
            val left = 2 * parent + 1
            // No left child means no children at all: the element is in place.
            if (left >= items.size) return
            val right = left + 1
            // When both children exist and compare equal, the right child is chosen.
            val larger = if (right < items.size && items[left] <= items[right]) right else left
            if (items[parent] >= items[larger]) return
            swap(parent, larger)
            parent = larger
        }
    }

    /** Exchanges the elements stored at positions [i] and [j]. */
    private fun swap(i: Int, j: Int) {
        val tmp = items[i]
        items[i] = items[j]
        items[j] = tmp
    }

    /**
     * Removes and returns the largest element.
     *
     * @throws IndexOutOfBoundsException if the heap is empty.
     */
    fun extract(): T {
        val max = items[0]
        val last = items.removeAt(items.lastIndex)
        // If the heap held a single element, `last` was the root itself and the list is now
        // empty; writing it back into slot 0 would throw, so only refill a non-empty heap.
        if (items.isNotEmpty()) {
            items[0] = last
            bubbleDown()
        }
        return max
    }

    /**
     * Returns the backing list in heap (array) order, for tests.
     *
     * This is the internal list typed as a read-only [List], not a copy, so it reflects
     * later inserts and extractions.
     */
    @TestOnly
    fun shape(): List<T> = items
}