-
-
Save shrubb/4c3f3f92325bd8f176447ff38b5f8da6 to your computer and use it in GitHub Desktop.
Competitive programming template
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random
def partition(array, start, end, pivot_idx):
    """Partition array[start:end] in place around the value at pivot_idx.

    When the call returns, the pivot value sits at the returned index,
    every element of the range to its left compares ``< pivot`` and every
    element of the range to its right compares ``>= pivot``.

    Args:
        array: mutable sequence; it is rearranged in place.
        start: first index of the range (inclusive).
        end: end of the range (exclusive). The range is expected to hold
            at least one element.
        pivot_idx: index of the element to use as pivot; expected to lie
            inside the range.

    Returns:
        The index at which the pivot ends up.
    """
    last = end - 1
    pivot = array[pivot_idx]
    # Park the pivot in the last slot so the scan only covers the rest.
    array[pivot_idx], array[last] = array[last], array[pivot_idx]

    lo, hi = start, last - 1
    # Loop invariant:
    #   everything left of lo is < pivot,
    #   everything right of hi (up to the parked pivot) is >= pivot.
    while lo <= hi:
        while lo <= hi and array[lo] < pivot:
            lo += 1
        while lo <= hi and array[hi] >= pivot:
            hi -= 1
        if lo <= hi:
            # array[lo] belongs on the right side and array[hi] on the left.
            array[lo], array[hi] = array[hi], array[lo]

    # lo is the first slot of the ">= pivot" region (or `last` itself),
    # which is exactly where the parked pivot has to go.
    array[lo], array[last] = array[last], array[lo]
    return lo
def quicksort(array, start=0, end=None):
    """Sort array[start:end] in place in ascending order.

    The pivot is chosen at random (as kth_element does), so already sorted
    or reverse-sorted input no longer triggers the worst case. Only the
    smaller side of each partition is handled recursively; the larger side
    is processed by the loop, which keeps the recursion depth logarithmic
    in the range length instead of linear.

    Note: partition() sends every element equal to the pivot to the right
    side, so inputs dominated by equal keys still need many partition
    passes, although they no longer exhaust the call stack.

    Args:
        array: mutable sequence of mutually comparable items.
        start: first index of the range to sort (inclusive).
        end: end of the range (exclusive); None means len(array).

    Returns:
        None. The sequence is modified in place.
    """
    if end is None:
        end = len(array)

    while end - start > 1:
        pivot_idx = random.randrange(start, end)
        mid = partition(array, start, end, pivot_idx)
        # array[mid] is in its final position; sort both sides of it.
        if mid - start < end - (mid + 1):
            # Left side is smaller: recurse there, keep looping on the right.
            quicksort(array, start, mid)
            start = mid + 1
        else:
            # Right side is smaller (or equal): recurse there, loop on the left.
            quicksort(array, mid + 1, end)
            end = mid
def kth_element(nums, k: int):
    """Return the k-th smallest element of nums (0-based) using quickselect.

    nums is reordered in place as a side effect; when the function returns,
    nums[k] holds the returned value.

    Args:
        nums: mutable sequence of mutually comparable items.
        k: 0-based rank of the wanted element, 0 <= k < len(nums).

    Returns:
        The element that would be at index k if nums were sorted ascending.

    Raises:
        IndexError: if k is outside range(len(nums)); this includes every
            k when nums is empty.
    """
    left, right = 0, len(nums)
    # Explicit check instead of `assert`: assertions are removed under
    # `python -O`, which would let a bad k produce a wrong answer silently.
    if not 0 <= k < right:
        raise IndexError(
            "k = {} is out of range for a sequence of length {}".format(k, right)
        )

    # Invariant: left <= k < right, i.e. the answer lies in nums[left:right].
    while right - left != 1:
        pivot_idx = random.randint(left, right - 1)
        new_pivot_idx = partition(nums, left, right, pivot_idx)
        if new_pivot_idx == k:
            return nums[new_pivot_idx]
        if new_pivot_idx > k:
            # The answer is strictly left of the pivot.
            right = new_pivot_idx
        else:
            # new_pivot_idx < k: the answer is strictly right of the pivot.
            left = new_pivot_idx + 1

    # Exactly one candidate is left, and by the invariant it is index k.
    return nums[left]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment