Created
January 30, 2023 02:28
-
-
Save ZXYFrank/50e2c033e2bdaa5d566f87fdf71f823d to your computer and use it in GitHub Desktop.
quick sort: arbitrary pivoting
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import random | |
from tqdm import trange | |
# Sample input list; the ``__main__`` section rebinds ``a`` before sorting.
a = [
    3, 1, 2, 5, 6, 1, 7, 3, 4,
    2, 6, 8, 1, 3, 2, 6, 1,
]
# Trace switch read by the qs_* functions: when truthy they print every
# swap and pivot placement through show_change / show_ptrs.
SHOW = False
def show_ptrs(a, l, r, *pos):
    """Print ``a`` on one line with the segment and some positions marked.

    Every index in ``pos`` is wrapped in double quotes (an index listed
    twice is quoted twice), then ``[`` is put before element ``l`` and
    ``]`` after element ``r``.  ``a`` itself is not modified.
    """
    cells = list(map(str, a))
    for idx in pos:
        cells[idx] = '"' + cells[idx] + '"'
    # Brackets go on last so they end up outside any quoting.
    cells[l] = '[' + cells[l]
    cells[r] = cells[r] + ']'
    print(' '.join(cells))
def show_change(a, l, r, i, j):
    """Print ``a`` on one line, marking the two elements about to be swapped.

    Elements ``i`` and ``j`` are wrapped in parentheses (twice when
    ``i == j``), then ``[`` is put before element ``l`` and ``]`` after
    element ``r``.  ``a`` itself is not modified.
    """
    cells = [str(value) for value in a]
    for idx in (i, j):
        cells[idx] = '(' + cells[idx] + ')'
    # Brackets go on last so they end up outside the parentheses.
    cells[l] = '[' + cells[l]
    cells[r] = cells[r] + ']'
    print(' '.join(cells))
def qs_right_first(a, l, r):
    """Sort ``a[l..r]`` (both ends inclusive) in place with quicksort.

    The pivot is the first element of the segment.  Two cursors walk
    towards each other and the right cursor always moves first, so the
    element at the meeting point is never greater than the pivot and can
    be swapped with it directly.

    When the module-level ``SHOW`` flag is truthy, each swap and the final
    pivot placement are printed.
    """
    if r <= l:
        return

    key = a[l]
    left, right = l, r
    while left < right:
        # Right cursor first: skip everything that may stay on the right.
        while left < right and key <= a[right]:
            right -= 1
        # Then the left cursor: skip everything that may stay on the left.
        while left < right and key >= a[left]:
            left += 1
        if SHOW:
            show_change(a, l, r, left, right)
        a[left], a[right] = a[right], a[left]

    # The cursors have met; drop the pivot into its final slot.
    a[l], a[left] = a[left], a[l]
    if SHOW:
        show_ptrs(a, l, r, left)

    qs_right_first(a, left + 1, r)
    qs_right_first(a, l, left - 1)
def qs_left_first(a, l, r):
    """Sort ``a[l..r]`` (both ends inclusive) in place with quicksort.

    The pivot is the first element of the segment.  Unlike
    ``qs_right_first``, the left cursor moves first, so the element at the
    meeting point may be greater than the pivot; in that case the pivot is
    placed one slot to the left instead.

    When the module-level ``SHOW`` flag is truthy, each swap, the
    meeting-point correction and the final pivot placement are printed.
    """
    if l >= r:
        return

    pivot = a[l]
    i = l
    j = r
    while i < j:
        # Left cursor first: skip everything that may stay on the left.
        while i < j and a[i] <= pivot:
            i += 1
        # Then the right cursor: skip everything that may stay on the right.
        while i < j and a[j] >= pivot:
            j -= 1
        if SHOW:
            show_change(a, l, r, i, j)
        a[i], a[j] = a[j], a[i]

    # While the cursors are apart the swap above is always safe.  Once they
    # meet, a[i] has not been checked against the pivot: if it is greater,
    # swapping it to position l would break the left part, so step back one
    # slot (a[i - 1] is known to be <= pivot, and i > l here because the
    # left cursor always moves past a[l] itself).
    if a[i] > pivot:
        i = i - 1
        if SHOW:
            print(
                f'cannot put pivot a[{l}] = {a[l]} at a[{i + 1}] = {a[i + 1]}, i--')
            show_ptrs(a, l, r, i)

    # Move the pivot to its final location.
    a[l], a[i] = a[i], a[l]
    if SHOW:
        show_ptrs(a, l, r, i)

    # Recurse with the same left-first scheme (the original delegated the
    # sub-segments to qs_right_first).
    qs_left_first(a, i + 1, r)
    qs_left_first(a, l, i - 1)
def qs_lomuto(a, l, r):
    """Sort ``a[l..r]`` (both ends inclusive) in place with quicksort.

    Single-scan partition with the first element as pivot: ``boundary``
    tracks the last slot known to hold a value ``<=`` the pivot, and every
    such value found by the scan is swapped just behind it.

    When the module-level ``SHOW`` flag is truthy, each swap and the final
    pivot placement are printed.
    """
    if r <= l:
        return

    key = a[l]
    boundary = l
    scan = l + 1
    while scan <= r:
        if a[scan] <= key:
            boundary += 1
            if SHOW:
                show_change(a, l, r, boundary, scan)
            a[boundary], a[scan] = a[scan], a[boundary]
        scan += 1

    # a[l+1..boundary] <= pivot < a[boundary+1..r]; seat the pivot between.
    a[l], a[boundary] = a[boundary], a[l]
    if SHOW:
        show_ptrs(a, l, r, boundary)

    qs_lomuto(a, boundary + 1, r)
    qs_lomuto(a, l, boundary - 1)
def qs_anywhere(a, l, r):
    """Sort ``a[l..r]`` (both ends inclusive) in place with quicksort.

    The pivot index is drawn with ``random.randint(l, r)`` and the pivot
    stays where it is while two cursors scan towards each other (left
    cursor first).  After the cursors meet, the slot that receives the
    pivot is nudged one step left or right if swapping with the meeting
    element would put that element on the wrong side.

    When the module-level ``SHOW`` flag is truthy, the chosen pivot, each
    swap, the adjustment decision and the final placement are printed.
    """
    if r <= l:
        return

    pivot_ind = random.randint(l, r)
    # A cursor can only rest on pivot_ind when i == j (where the swap is a
    # no-op), so the pivot value can be read once up front.
    pivot = a[pivot_ind]
    if SHOW:
        print(f'pivot is a[{pivot_ind}] = {pivot}')

    i, j = l, r
    while i < j:
        while i < j and pivot >= a[i]:
            i += 1
        while i < j and pivot <= a[j]:
            j -= 1
        if SHOW:
            show_change(a, l, r, i, j)
        a[i], a[j] = a[j], a[i]

    # While the cursors are apart the swap above is always safe.  Once they
    # meet, check whether swapping the pivot with a[i] would carry a[i]
    # into the wrong interval.
    at_meet = a[i]
    if pivot < at_meet and pivot_ind < i:
        # A larger element would travel left of the pivot: use slot i-1.
        if SHOW:
            print(
                f'smaller pivot a[{pivot_ind}] = {pivot} at the left of a[{i}] = {at_meet}, i = max(l, i-1)')
        i = max(l, i - 1)
    elif at_meet < pivot and i < pivot_ind:
        # A smaller element would travel right of the pivot: use slot i+1.
        if SHOW:
            print(
                f'greater pivot a[{pivot_ind}] = {pivot} at the right of a[{i}] = {at_meet}, i = min(i+1, r)')
        i = min(i + 1, r)
    elif SHOW:
        print(f'a[{i}] = {a[i]} can be safely swapped')

    if SHOW:
        show_ptrs(a, l, r, i, pivot_ind)
    a[i], a[pivot_ind] = a[pivot_ind], a[i]
    if SHOW:
        show_ptrs(a, l, r, i)
        print('---')

    qs_anywhere(a, i + 1, r)
    qs_anywhere(a, l, i - 1)
if __name__ == "__main__":
    # Ad-hoc checks: each experiment sorts random data with one of the
    # qs_* functions and compares the result with numpy's own sort.
    # The experiments below that are commented out are kept for reference.

    # --- disabled: single traced run of qs_left_first ---
    # K = 10
    # BOUND = 20
    # random.seed(781935)
    # SHOW = True
    # a = [random.randint(0, BOUND) for _ in range(K)]
    # a = np.array(a)
    # # a = np.array([7,1,8,3,5])
    # _a = a.copy()
    # # qs_right_first(_a, 0, len(_a) - 1)
    # qs_left_first(_a, 0, len(_a) - 1)
    # a.sort()
    # print((a == _a).all())

    # --- disabled: single run of qs_lomuto ---
    # a = [random.randint(0, BOUND) for _ in range(K)]
    # a = np.array(a)
    # _a = a.copy()
    # qs_lomuto(_a, 0, len(_a) - 1)
    # a.sort()
    # print((a == _a).all())

    # --- disabled: 500 silent rounds of qs_left_first ---
    # K = 1000
    # BOUND = 50
    # random.seed(781935)
    # SHOW = False
    # ans = True
    # for i in trange(500):
    # a = [random.randint(0, BOUND) for _ in range(K)]
    # a = np.array(a)
    # _a = a.copy()
    # # qs_right_first(_a, 0, len(_a) - 1)
    # qs_left_first(_a, 0, len(_a) - 1)
    # a.sort()
    # ans = np.logical_and(ans, ((a == _a).all()))
    # print('qs_left_first', ans)

    # --- disabled: 500 silent rounds of qs_lomuto ---
    # K = 1000
    # BOUND = 50
    # random.seed(781935)
    # SHOW = False
    # ans = True
    # for i in trange(500):
    # a = [random.randint(0, BOUND) for _ in range(K)]
    # a = np.array(a)
    # _a = a.copy()
    # qs_lomuto(_a, 0, len(_a) - 1)
    # a.sort()
    # ans = np.logical_and(ans, ((a == _a).all()))
    # print('qs_lomuto', ans)

    # --- disabled: single traced run of qs_anywhere ---
    # K = 15
    # BOUND = 20
    # random.seed(781935)
    # SHOW = True
    # a = [random.randint(0, BOUND) for _ in range(K)]
    # a = np.array(a)
    # _a = a.copy()
    # # qs_right_first(_a, 0, len(_a) - 1)
    # qs_anywhere(_a, 0, len(_a) - 1)
    # a.sort()
    # print((a == _a).all())

    # Experiment 1: five small arrays (K values in 0..BOUND) sorted by
    # qs_anywhere with tracing switched on.
    K = 10
    BOUND = 10
    # Seeds Python's `random`, which supplies both the test data here and
    # the pivot indices inside qs_anywhere.
    random.seed(781935)
    SHOW = True
    ans = True
    for i in trange(5):
        a = [random.randint(0, BOUND) for _ in range(K)]
        a = np.array(a)
        _a = a.copy()
        qs_anywhere(_a, 0, len(_a) - 1)
        # Copy of the unsorted input; not read anywhere in the visible code.
        ori_a = a.copy()
        a.sort()
        # ans stays true only while every round so far matched numpy's sort.
        ans = np.logical_and(ans, ((a == _a).all()))
        if ans == False:
            break
        # NOTE(review): nesting reconstructed - assumed to print once per
        # passing round; confirm against the original file.
        print('================pass================')
    print('qs_anywhere', ans)

    # Experiment 2: 500 arrays of 1000 values each, tracing switched off.
    K = 1000
    BOUND = 500
    random.seed(781935)
    SHOW = False
    ans = True
    for i in trange(500):
        a = [random.randint(0, BOUND) for _ in range(K)]
        a = np.array(a)
        _a = a.copy()
        qs_anywhere(_a, 0, len(_a) - 1)
        # Copy of the unsorted input; not read anywhere in the visible code.
        ori_a = a.copy()
        a.sort()
        ans = np.logical_and(ans, ((a == _a).all()))
        # Stop at the first mismatch.
        if ans == False:
            break
    print('qs_anywhere', ans)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
快速排序
Partition 的含义是让某个基准元素归位。
参考:排序算法——快速排序(Quicksort)基准值的三种选取和优化方法
循环条件为 `i < j`,比较条件为 `a[j] >= pivot`、`a[i] <= pivot`。
如果值比较的时候是严格小于或者大于,那么遇到相等数值的时候,指针是无法移动的。
问题
以下分析以 `pivot = a[l]` 为例。
为什么指针相遇之后需要交换,而不是覆盖?
如果不交换,直接把 pivot 放入相关位置的话,会有一个元素 `a[i]` 被覆盖掉。
那么就需要交换,关键问题在于,`a[i]` 可以直接交换吗?
这个问题和下一个问题“为什么右边的先移动”有关。
为什么右边的先移动
在 Wiki 当中,算法发明者的原始实现版本就不是右边先移动的
另外,国外版本通常以最右侧值作为枢值,也就要左边先移动了
有一些解释,但是根本原因,还是在于 Partition 函数的作用
我们扫描数组的最终目的,是找到一个位置,安放基准值。
更准确地说,是把基准值和某个值交换位置,这个交换不可以破坏 ij 已扫描过的区间有序性
其实不管哪一边先走,都可以满足如下语义
但是一旦指针相遇,语义就不确定了
如果让 j 先走,相遇的时候有两种情况
我们也可以让左边先走
只不过这个时候需要单独验证,最后的 pivot 位置是否满足要求
下面我们用 Python 脚本验证一下(便于打印)
更进一步,是否可以任意选取枢轴值
分析
i,j 相遇,左右两侧的序列一定是满足要求的
最后返回的也一定是
i=j
这个位置,其他所有位置的值都已经通过了检验另外,任意取
pivot
,a[pivot_ind]
一定是在原位的i != j
的情况,而i,j
不会在pivot
处停留pivot_ind
处,自然也没有影响不过,这个位置的值本身和
pivot
的关系,没有通过检验i
这个位置就自然满足要求a[i]
换成pivot
,因为我们的扫描过程是以枢轴值为标准的a[i]
需要和pivot
进行交换代码
smaller or greater
🧐 你发现了吗,所有的提示都是 smaller
如果我们让右边先走,就会全部变成 greater。