Skip to content

Instantly share code, notes, and snippets.

@ZXYFrank
Created January 30, 2023 02:28
Show Gist options
  • Save ZXYFrank/50e2c033e2bdaa5d566f87fdf71f823d to your computer and use it in GitHub Desktop.
Save ZXYFrank/50e2c033e2bdaa5d566f87fdf71f823d to your computer and use it in GitHub Desktop.
quick sort: arbitrary pivoting
import numpy as np
import random
from tqdm import trange
a = [3, 1, 2, 5, 6, 1, 7, 3, 4, 2, 6, 8, 1, 3, 2, 6, 1]
SHOW = False
def show_ptrs(a, l, r, *pos):
_list = [str(_) for _ in a]
for p in pos:
_list[p] = f'"{_list[p]}"'
_list[l] = f'[{_list[l]}'
_list[r] = f'{_list[r]}]'
print(' '.join(_list))
def show_change(a, l, r, i, j):
_list = [str(_) for _ in a]
_list[i] = f'({_list[i]})'
_list[j] = f'({_list[j]})'
_list[l] = f'[{_list[l]}'
_list[r] = f'{_list[r]}]'
print(' '.join(_list))
def qs_right_first(a, l, r):
if l >= r:
return
pivot = a[l]
i = l
j = r
while i < j:
while i < j and a[j] >= pivot:
j -= 1
while i < j and a[i] <= pivot:
i += 1
if SHOW:
show_change(a, l, r, i, j)
a[i], a[j] = a[j], a[i]
a[l], a[i] = a[i], a[l]
if SHOW:
show_ptrs(a, l, r, i)
qs_right_first(a, i+1, r)
qs_right_first(a, l, i-1)
def qs_left_first(a, l, r):
if l >= r:
return
pivot = a[l]
i = l
j = r
while i < j:
while i < j and a[i] <= pivot:
i += 1
while i < j and a[j] >= pivot:
j -= 1
if SHOW:
show_change(a, l, r, i, j)
a[i], a[j] = a[j], a[i]
# if i,j does not meet exchange is unproblematic
# However, if they meet, we need to check
if a[i] > pivot:
i = i-1
if SHOW:
print(
f'cannot put pivot a[{l}] = {a[l]} at a[{i +1 }] = {a[i + 1]}, i--')
show_ptrs(a, l, r, i)
# move the pivot to its location
a[l], a[i] = a[i], a[l]
if SHOW:
show_ptrs(a, l, r, i)
qs_right_first(a, i+1, r)
qs_right_first(a, l, i-1)
def qs_lomuto(a, l, r):
if l >= r:
return
pivot = a[l]
pos = l
for i in range(l + 1, r + 1):
if a[i] <= pivot:
pos += 1
if SHOW:
show_change(a, l, r, pos, i)
a[pos], a[i] = a[i], a[pos]
# move the pivot to its location
a[l], a[pos] = a[pos], a[l]
if SHOW:
show_ptrs(a, l, r, pos)
qs_lomuto(a, pos+1, r)
qs_lomuto(a, l, pos-1)
def qs_anywhere(a, l, r):
if l >= r:
return
i = l
j = r
pivot_ind = random.randint(l, r)
if SHOW:
print(f'pivot is a[{pivot_ind}] = {a[pivot_ind]}')
while i < j:
while i < j and a[i] <= a[pivot_ind]:
i += 1
while i < j and a[j] >= a[pivot_ind]:
j -= 1
if SHOW:
show_change(a, l, r, i, j)
a[i], a[j] = a[j], a[i]
# if i,j does not meet, exchange is unproblematic
# However, if they meet, we need to check whether swaping the pivot and a[i] will INTRODUCE a[i] to a wrong interval
if a[pivot_ind] < a[i] and pivot_ind < i:
if SHOW:
print(
f'smaller pivot a[{pivot_ind}] = {a[pivot_ind]} at the left of a[{i}] = {a[i]}, i = max(l, i-1)')
i = max(l, i-1)
elif a[i] < a[pivot_ind] and i < pivot_ind:
if SHOW:
print(
f'greater pivot a[{pivot_ind}] = {a[pivot_ind]} at the right of a[{i}] = {a[i]}, i = min(i+1, r)')
i = min(i+1, r)
else:
if SHOW:
print(f'a[{i}] = {a[i]} can be safely swapped')
if SHOW:
show_ptrs(a, l, r, i, pivot_ind)
a[i], a[pivot_ind] = a[pivot_ind], a[i]
if SHOW:
show_ptrs(a, l, r, i)
print('---')
qs_anywhere(a, i+1, r)
qs_anywhere(a, l, i-1)
if __name__ == "__main__":
# K = 10
# BOUND = 20
# random.seed(781935)
# SHOW = True
# a = [random.randint(0, BOUND) for _ in range(K)]
# a = np.array(a)
# # a = np.array([7,1,8,3,5])
# _a = a.copy()
# # qs_right_first(_a, 0, len(_a) - 1)
# qs_left_first(_a, 0, len(_a) - 1)
# a.sort()
# print((a == _a).all())
# a = [random.randint(0, BOUND) for _ in range(K)]
# a = np.array(a)
# _a = a.copy()
# qs_lomuto(_a, 0, len(_a) - 1)
# a.sort()
# print((a == _a).all())
# K = 1000
# BOUND = 50
# random.seed(781935)
# SHOW = False
# ans = True
# for i in trange(500):
# a = [random.randint(0, BOUND) for _ in range(K)]
# a = np.array(a)
# _a = a.copy()
# # qs_right_first(_a, 0, len(_a) - 1)
# qs_left_first(_a, 0, len(_a) - 1)
# a.sort()
# ans = np.logical_and(ans, ((a == _a).all()))
# print('qs_left_first', ans)
# K = 1000
# BOUND = 50
# random.seed(781935)
# SHOW = False
# ans = True
# for i in trange(500):
# a = [random.randint(0, BOUND) for _ in range(K)]
# a = np.array(a)
# _a = a.copy()
# qs_lomuto(_a, 0, len(_a) - 1)
# a.sort()
# ans = np.logical_and(ans, ((a == _a).all()))
# print('qs_lomuto', ans)
# K = 15
# BOUND = 20
# random.seed(781935)
# SHOW = True
# a = [random.randint(0, BOUND) for _ in range(K)]
# a = np.array(a)
# _a = a.copy()
# # qs_right_first(_a, 0, len(_a) - 1)
# qs_anywhere(_a, 0, len(_a) - 1)
# a.sort()
# print((a == _a).all())
K = 10
BOUND = 10
random.seed(781935)
SHOW = True
ans = True
for i in trange(5):
a = [random.randint(0, BOUND) for _ in range(K)]
a = np.array(a)
_a = a.copy()
qs_anywhere(_a, 0, len(_a) - 1)
ori_a = a.copy()
a.sort()
ans = np.logical_and(ans, ((a == _a).all()))
if ans == False:
break
print('================pass================')
print('qs_anywhere', ans)
K = 1000
BOUND = 500
random.seed(781935)
SHOW = False
ans = True
for i in trange(500):
a = [random.randint(0, BOUND) for _ in range(K)]
a = np.array(a)
_a = a.copy()
qs_anywhere(_a, 0, len(_a) - 1)
ori_a = a.copy()
a.sort()
ans = np.logical_and(ans, ((a == _a).all()))
if ans == False:
break
print('qs_anywhere', ans)
@ZXYFrank
Copy link
Author

快速排序

Partition的含义是让某个基准元素归位

排序算法——快速排序(Quicksort)基准值的三种选取和优化方法

  • 左侧基准
  • 右侧先走
  • 指针小于(严格)
    • i < j
  • 值等于(非严格)
    • a[j] >= pivot,a[i] <= pivot

如果值比较的时候是严格小于或者大于,那么遇到相等数值的时候,指针是无法移动的

1,2,3,1,1
^       ^
|       |
int partition(vector<int> a, int l, int r) {
    if (l >= r) return;
    /* int mid = (l + r) / 2;
    swap(a[mid], a[l]); */
    int pivot = a[l];
    int i = l, j = r;

    while (l < r) {
        while (a[j] >= pivot && i < j) j--;
        while (a[i] <= pivot && i < j) i++;
        swap(a[i], a[j]);
    }
    swap(a[i], a[l]);
    return i;
}
void quicksort(vector<int> a) {
    qs(a, 0, a.size() - 1);
}
void qs(vector<int> a, int l, int r) {
    if (l >= r) return;
    int p = partition(a, l, r);
    qs(a, p + 1, r);
    qs(a, l, p - 1);
}

问题

以下分析以pivot = a[l]为例

为什么指针相遇之后需要交换,而不是覆盖

如果不交换,直接把 pivot 放入相关位置的话,会有一个元素a[i]被覆盖掉

那么就需要交换,关键问题在于,a[i]可以直接交换吗

这个和下一个为什么右边的先移动这个问题有关

为什么右边的先移动

Wiki 当中,算法发明者的原始实现版本就不是右边先移动的

另外,国外版本通常以最右侧值作为枢值,也就要左边先移动了

有一些解释,但是根本原因,还是在于 Partition 函数的作用

我们扫描数组的最终目的,是找到一个位置,安放基准值。

更准确地说,是把基准值和某个值交换位置,这个交换不可以破坏 ij 已扫描过的区间有序性

其实不管哪一边先走,都可以满足如下语义

如果i != ja[l:i](闭区间) 所有元素 <= a[l]
如果i != ja[j:r](闭区间) 所有元素 >= a[l]

但是一旦指针相遇,语义就不确定了

如果让 j 先走,相遇的时候有两种情况

while (l < r) {
    // [1] j向左碰到了i,此时i这处一定是检验过的
    // 因此满足 a[i](即a[j]) <= a[l]
    while (a[j] <= pivot && i < j) j--;
    // [2] i向右碰到了j,我们假设这种情况可以成立
    // 那么此时j已经停在了一个 a[j]<=a[l] 的地方
    // 那么此时如果指针相撞,i == j
    // a[i] <= a[l] 的条件自然是满足的
    while (a[i] <= pivot && i < j) i++;
    swap(a[i], a[j]);
}

我们也可以让左边先走

只不过这个时候需要单独验证,最后的 pivot 位置是否满足要求

下面我们用 Python 脚本验证一下(便于打印)

import numpy as np
import random
from tqdm import trange

a = [3, 1, 2, 5, 6, 1, 7, 3, 4, 2, 6, 8, 1, 3, 2, 6, 1]

SHOW = False


def show_ptr(a, l, r, pos):
    _list = [str(_) for _ in a]
    _list[pos] = f'"{_list[pos]}"'
    _list[l] = f'[{_list[l]}'
    _list[r] = f'{_list[r]}]'
    print(' '.join(_list))


def show_change(a, l, r, i, j):
    _list = [str(_) for _ in a]
    _list[i] = f'({_list[i]})'
    _list[j] = f'({_list[j]})'
    _list[l] = f'[{_list[l]}'
    _list[r] = f'{_list[r]}]'
    print(' '.join(_list))


def qs_right_first(a, l, r):
    if l >= r:
        return
    pivot = a[l]
    i = l
    j = r
    while i < j:
        while i < j and a[j] >= pivot:
            j -= 1
        while i < j and a[i] <= pivot:
            i += 1
        if SHOW:
            show_change(a, l, r, i, j)
        a[i], a[j] = a[j], a[i]

    # # wrong
    # pivot, a[i] = a[i], pivot
    # right
    a[l], a[i] = a[i], a[l]

    if SHOW:
        show_ptr(a, l, r, i)
    qs_right_first(a, i+1, r)
    qs_right_first(a, l, i-1)


def qs_left_first(a, l, r):
    if l >= r:
        return
    pivot = a[l]
    i = l
    j = r
    while i < j:
        while i < j and a[i] <= pivot:
            i += 1
        while i < j and a[j] >= pivot:
            j -= 1
        if SHOW:
            show_change(a, l, r, i, j)
        a[i], a[j] = a[j], a[i]

    # if i,j does not meet exchange is unproblematic
    # However, if they meet, we need to check
    if a[i] > pivot:
        i = i-1
        if SHOW:
            print(
                f'cannot put pivot a[{l}] = {a[l]} at a[{i +1 }] = {a[i + 1]}, i--')
            show_ptr(a, l, r, i)

    # move the pivot to its location
    a[l], a[i] = a[i], a[l]

    if SHOW:
        show_ptr(a, l, r, i)
    qs_right_first(a, i+1, r)
    qs_right_first(a, l, i-1)


if __name__ == "__main__":
    K = 10
    BOUND = 20
    random.seed(781935)
    SHOW = True
    a = [random.randint(0, BOUND) for _ in range(K)]
    a = np.array(a)

    # a = np.array([7,1,8,3,5])

    _a = a.copy()
    # qs_right_first(_a, 0, len(_a) - 1)
    qs_left_first(_a, 0, len(_a) - 1)
    a.sort()
    print((a == _a).all())

    K = 1000
    BOUND = 50
    random.seed(781935)
    SHOW = False
    ans = True
    for i in trange(500):
        a = [random.randint(0, BOUND) for _ in range(K)]
        a = np.array(a)
        _a = a.copy()
        # qs_right_first(_a, 0, len(_a) - 1)
        qs_left_first(_a, 0, len(_a) - 1)
        a.sort()
        ans = np.logical_and(ans, ((a == _a).all()))
    print('qs_left_first', ans)
[17 10 3 5 10 10 10 0 (18) (12)]
# 18,12 交换之后,i先走一步,撞到了j(在len-1处)
[17 10 3 5 10 10 10 0 12 ((18))]
# 此时从while退出,不知道这个位置能不能作为pivot的归位处,因此需要进行判断
cannot put pivot a[0] = 17 at a[9] = 18, i--
# 向后一步,这一片区域都是i扫过的区域,因此pivot一定可以归位
[17 10 3 5 10 10 10 0 "12" 18]

[12 10 3 5 10 10 10 0 "17" 18]
[12 10 3 5 10 10 10 ((0))] 17 18
[0 10 3 5 10 10 10 "12"] 17 18
[((0)) 10 3 5 10 10 10] 12 17 18
["0" 10 3 5 10 10 10] 12 17 18
0 [10 3 ((5)) 10 10 10] 12 17 18
0 [5 3 "10" 10 10 10] 12 17 18
0 5 3 10 [((10)) 10 10] 12 17 18
0 5 3 10 ["10" 10 10] 12 17 18
0 5 3 10 10 [((10)) 10] 12 17 18
0 5 3 10 10 ["10" 10] 12 17 18
0 [5 ((3))] 10 10 10 10 12 17 18
0 [3 "5"] 10 10 10 10 12 17 18

更进一步,是否可以任意选取枢轴值

  • 以枢轴值为标准扫描
  • 枢轴值归位

分析

i,j 相遇,左右两侧的序列一定是满足要求的

最后返回的也一定是i=j这个位置,其他所有位置的值都已经通过了检验

另外,任意取pivota[pivot_ind]一定是在原位的

  • 因为移动只对应i != j的情况,而i,j不会在pivot处停留
  • 而如果相撞在pivot_ind处,自然也没有影响

不过,这个位置的值本身和pivot的关系,没有通过检验

  • 如果相等,i这个位置就自然满足要求
  • 如果不相等,我们需要把a[i]换成pivot,因为我们的扫描过程是以枢轴值为标准的

a[i]需要和pivot进行交换

代码

pivot is a[9] = 2
[(6) 1 (1) 2 7 8 3 7 9 2]
[1 1 ((6)) 2 7 8 3 7 9 2]
a[2] = 6 can be safely swapped
[1 1 "6" 2 7 8 3 7 9 "2"]
[1 1 "2" 2 7 8 3 7 9 6]
---
pivot is a[5] = 8
1 1 2 [2 7 8 3 7 (9) (6)]
1 1 2 [2 7 8 3 7 6 ((9))]
smaller pivot a[5] = 8 at the left of a[9] = 9, i = max(l, i-1)
1 1 2 [2 7 "8" 3 7 "6" 9]
1 1 2 [2 7 6 3 7 "8" 9]
---
pivot is a[5] = 6
1 1 2 [2 (7) 6 (3) 7] 8 9
1 1 2 [2 3 6 ((7)) 7] 8 9
smaller pivot a[5] = 6 at the left of a[6] = 7, i = max(l, i-1)
1 1 2 [2 3 ""6"" 7 7] 8 9
1 1 2 [2 3 "6" 7 7] 8 9
---
pivot is a[6] = 7
1 1 2 2 3 6 [7 ((7))] 8 9
a[7] = 7 can be safely swapped
1 1 2 2 3 6 ["7" "7"] 8 9
1 1 2 2 3 6 [7 "7"] 8 9
---
pivot is a[3] = 2
1 1 2 [2 ((3))] 6 7 7 8 9
smaller pivot a[3] = 2 at the left of a[4] = 3, i = max(l, i-1)
1 1 2 [""2"" 3] 6 7 7 8 9
1 1 2 ["2" 3] 6 7 7 8 9
---
pivot is a[0] = 1
[1 ((1))] 2 2 3 6 7 7 8 9
a[1] = 1 can be safely swapped
["1" "1"] 2 2 3 6 7 7 8 9
[1 "1"] 2 2 3 6 7 7 8 9
---
================pass================

smaller or greater

🧐 你发现了吗,所有的提示都是 smaller

smaller pivot a[3] = 2 at the left of a[4] = 3

如果我们让右边先走

while i < j:
    while i < j and a[j] >= a[pivot_ind]:
        j -= 1
    while i < j and a[i] <= a[pivot_ind]:
        i += 1

就会全部变成 greater

pivot is a[9] = 7
[(10) 4 6 1 1 2 7 8 (3) 7]
[3 4 6 1 1 ((2)) 7 8 10 7]
greater pivot a[9] = 7 at the right of a[5] = 2, i = min(i+1, r)
[3 4 6 1 1 2 "7" 8 10 "7"]
[3 4 6 1 1 2 "7" 8 10 7]
---
pivot is a[7] = 8
3 4 6 1 1 2 7 [8 (10) (7)]
3 4 6 1 1 2 7 [8 ((7)) 10]
a[8] = 7 can be safely swapped
3 4 6 1 1 2 7 ["8" "7" 10]
3 4 6 1 1 2 7 [7 "8" 10]
---
pivot is a[4] = 1
[((3)) 4 6 1 1 2] 7 7 8 10
a[0] = 3 can be safely swapped
["3" 4 6 1 "1" 2] 7 7 8 10
["1" 4 6 1 3 2] 7 7 8 10
---
pivot is a[3] = 1
1 [((4)) 6 1 3 2] 7 7 8 10
a[1] = 4 can be safely swapped
1 ["4" 6 "1" 3 2] 7 7 8 10
1 ["1" 6 4 3 2] 7 7 8 10
---
pivot is a[4] = 3
1 1 [(6) 4 3 (2)] 7 7 8 10
1 1 [((2)) 4 3 6] 7 7 8 10
greater pivot a[4] = 3 at the right of a[2] = 2, i = min(i+1, r)
1 1 [2 "4" "3" 6] 7 7 8 10
1 1 [2 "3" 4 6] 7 7 8 10
---
pivot is a[4] = 4
1 1 2 3 [((4)) 6] 7 7 8 10
a[4] = 4 can be safely swapped
1 1 2 3 [""4"" 6] 7 7 8 10
1 1 2 3 ["4" 6] 7 7 8 10
---

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment