快速排序:为什么右边先移动?任意枢轴值选取

2025-06-06 13:08:31

快速排序

gist 可运行代码

Partition的含义是让某个基准元素归位

排序算法——快速排序(Quicksort)基准值的三种选取和优化方法

左侧基准

右侧先走

指针小于(严格)

i < j

值等于(非严格)

a[j] >= pivot,a[i] <= pivot

如果值比较的时候是严格小于或者大于,那么遇到相等数值的时候,指针是无法移动的

1,2,3,1,1

^ ^

| |

int partition(vector a, int l, int r) {

if (l >= r) return;

/* int mid = (l + r) / 2;

swap(a[mid], a[l]); */

int pivot = a[l];

int i = l, j = r;

while (l < r) {

while (a[j] >= pivot && i < j) j--;

while (a[i] <= pivot && i < j) i++;

swap(a[i], a[j]);

}

swap(a[i], a[l]);

return i;

}

void quicksort(vector a) {

qs(a, 0, a.size() - 1);

}

void qs(vector a, int l, int r) {

if (l >= r) return;

int p = partition(a, l, r);

qs(a, p + 1, r);

qs(a, l, p - 1);

}

以下分析以pivot = a[l]为例

为什么指针相遇之后需要交换,而不是覆盖

如果不交换,直接把 pivot 放入相关位置的话,会有一个元素a[i]被覆盖掉

那么就需要交换,关键问题在于,a[i]可以直接交换吗

这个和下一个为什么右边的先移动这个问题有关

为什么右边的先移动

在 Wiki 当中,算法发明者的原始实现版本就不是右边先移动的

另外,国外版本通常以最右侧值作为枢值,也就要左边先移动了

有一些解释,但是根本原因,还是在于 Partition 函数的作用

我们扫描数组的最终目的,是找到一个位置,安放基准值。

更准确地说,是把基准值和某个值交换位置,这个交换不可以破坏 ij 已扫描过的区间有序性

其实不管哪一边先走,都可以满足如下语义

如果i != j,a[l:i](闭区间) 所有元素 <= a[l]

如果i != j,a[j:r](闭区间) 所有元素 >= a[l]

但是一旦指针相遇,语义就不确定了

如果让 j 先走,相遇的时候有两种情况

while (l < r) {

// [1] j向左碰到了i,此时i这处一定是检验过的

// 因此满足 a[i](即a[j]) <= a[l]

while (a[j] <= pivot && i < j) j--;

// [2] i向右碰到了j,我们假设这种情况可以成立

// 那么此时j已经停在了一个 a[j]<=a[l] 的地方

// 那么此时如果指针相撞,i == j

// a[i] <= a[l] 的条件自然是满足的

while (a[i] <= pivot && i < j) i++;

swap(a[i], a[j]);

}

我们也可以让左边先走

只不过这个时候需要单独验证,最后的 pivot 位置是否满足要求

下面我们用 Python 脚本验证一下(便于打印)

import numpy as np

import random

from tqdm import trange

a = [3, 1, 2, 5, 6, 1, 7, 3, 4, 2, 6, 8, 1, 3, 2, 6, 1]

SHOW = False

def show_ptr(a, l, r, pos):

_list = [str(_) for _ in a]

_list[pos] = f'"{_list[pos]}"'

_list[l] = f'[{_list[l]}'

_list[r] = f'{_list[r]}]'

print(' '.join(_list))

def show_change(a, l, r, i, j):

_list = [str(_) for _ in a]

_list[i] = f'({_list[i]})'

_list[j] = f'({_list[j]})'

_list[l] = f'[{_list[l]}'

_list[r] = f'{_list[r]}]'

print(' '.join(_list))

def qs_right_first(a, l, r):

if l >= r:

return

pivot = a[l]

i = l

j = r

while i < j:

while i < j and a[j] >= pivot:

j -= 1

while i < j and a[i] <= pivot:

i += 1

if SHOW:

show_change(a, l, r, i, j)

a[i], a[j] = a[j], a[i]

# # wrong

# pivot, a[i] = a[i], pivot

# right

a[l], a[i] = a[i], a[l]

if SHOW:

show_ptr(a, l, r, i)

qs_right_first(a, i+1, r)

qs_right_first(a, l, i-1)

def qs_left_first(a, l, r):

if l >= r:

return

pivot = a[l]

i = l

j = r

while i < j:

while i < j and a[i] <= pivot:

i += 1

while i < j and a[j] >= pivot:

j -= 1

if SHOW:

show_change(a, l, r, i, j)

a[i], a[j] = a[j], a[i]

# if i,j does not meet exchange is unproblematic

# However, if they meet, we need to check

if a[i] > pivot:

i = i-1

if SHOW:

print(

f'cannot put pivot a[{l}] = {a[l]} at a[{i +1 }] = {a[i + 1]}, i--')

show_ptr(a, l, r, i)

# move the pivot to its location

a[l], a[i] = a[i], a[l]

if SHOW:

show_ptr(a, l, r, i)

qs_right_first(a, i+1, r)

qs_right_first(a, l, i-1)

if __name__ == "__main__":

K = 10

BOUND = 20

random.seed(781935)

SHOW = True

a = [random.randint(0, BOUND) for _ in range(K)]

a = np.array(a)

# a = np.array([7,1,8,3,5])

_a = a.copy()

# qs_right_first(_a, 0, len(_a) - 1)

qs_left_first(_a, 0, len(_a) - 1)

a.sort()

print((a == _a).all())

K = 1000

BOUND = 50

random.seed(781935)

SHOW = False

ans = True

for i in trange(500):

a = [random.randint(0, BOUND) for _ in range(K)]

a = np.array(a)

_a = a.copy()

# qs_right_first(_a, 0, len(_a) - 1)

qs_left_first(_a, 0, len(_a) - 1)

a.sort()

ans = np.logical_and(ans, ((a == _a).all()))

print('qs_left_first', ans)

[17 10 3 5 10 10 10 0 (18) (12)]

# 18,12 交换之后,i先走一步,撞到了j(在len-1处)

[17 10 3 5 10 10 10 0 12 ((18))]

# 此时从while退出,不知道这个位置能不能作为pivot的归位处,因此需要进行判断

cannot put pivot a[0] = 17 at a[9] = 18, i--

# 向后一步,这一片区域都是i扫过的区域,因此pivot一定可以归位

[17 10 3 5 10 10 10 0 "12" 18]

[12 10 3 5 10 10 10 0 "17" 18]

[12 10 3 5 10 10 10 ((0))] 17 18

[0 10 3 5 10 10 10 "12"] 17 18

[((0)) 10 3 5 10 10 10] 12 17 18

["0" 10 3 5 10 10 10] 12 17 18

0 [10 3 ((5)) 10 10 10] 12 17 18

0 [5 3 "10" 10 10 10] 12 17 18

0 5 3 10 [((10)) 10 10] 12 17 18

0 5 3 10 ["10" 10 10] 12 17 18

0 5 3 10 10 [((10)) 10] 12 17 18

0 5 3 10 10 ["10" 10] 12 17 18

0 [5 ((3))] 10 10 10 10 12 17 18

0 [3 "5"] 10 10 10 10 12 17 18

更进一步,是否可以任意选取枢轴值

以枢轴值为标准扫描

枢轴值归位

分析

i,j 相遇,左右两侧的序列一定是满足要求的

最后返回的也一定是i=j这个位置,其他所有位置的值都已经通过了检验

另外,任意取pivot,a[pivot_ind]一定是在原位的

因为移动只对应i != j的情况,而i,j不会在pivot处停留

而如果相撞在pivot_ind处,自然也没有影响

不过,这个位置的值本身和pivot的关系,没有通过检验

如果相等,i这个位置就自然满足要求

如果不相等,我们需要把a[i]换成pivot,因为我们的扫描过程是以枢轴值为标准的

a[i]需要和pivot进行交换

代码

pivot is a[9] = 2

[(6) 1 (1) 2 7 8 3 7 9 2]

[1 1 ((6)) 2 7 8 3 7 9 2]

a[2] = 6 can be safely swapped

[1 1 "6" 2 7 8 3 7 9 "2"]

[1 1 "2" 2 7 8 3 7 9 6]

---

pivot is a[5] = 8

1 1 2 [2 7 8 3 7 (9) (6)]

1 1 2 [2 7 8 3 7 6 ((9))]

smaller pivot a[5] = 8 at the left of a[9] = 9, i = max(l, i-1)

1 1 2 [2 7 "8" 3 7 "6" 9]

1 1 2 [2 7 6 3 7 "8" 9]

---

pivot is a[5] = 6

1 1 2 [2 (7) 6 (3) 7] 8 9

1 1 2 [2 3 6 ((7)) 7] 8 9

smaller pivot a[5] = 6 at the left of a[6] = 7, i = max(l, i-1)

1 1 2 [2 3 ""6"" 7 7] 8 9

1 1 2 [2 3 "6" 7 7] 8 9

---

pivot is a[6] = 7

1 1 2 2 3 6 [7 ((7))] 8 9

a[7] = 7 can be safely swapped

1 1 2 2 3 6 ["7" "7"] 8 9

1 1 2 2 3 6 [7 "7"] 8 9

---

pivot is a[3] = 2

1 1 2 [2 ((3))] 6 7 7 8 9

smaller pivot a[3] = 2 at the left of a[4] = 3, i = max(l, i-1)

1 1 2 [""2"" 3] 6 7 7 8 9

1 1 2 ["2" 3] 6 7 7 8 9

---

pivot is a[0] = 1

[1 ((1))] 2 2 3 6 7 7 8 9

a[1] = 1 can be safely swapped

["1" "1"] 2 2 3 6 7 7 8 9

[1 "1"] 2 2 3 6 7 7 8 9

---

================pass================

smaller or greater

🧐 你发现了吗,所有的提示都是 smaller

smaller pivot a[3] = 2 at the left of a[4] = 3

如果我们让右边先走

while i < j:

while i < j and a[j] >= a[pivot_ind]:

j -= 1

while i < j and a[i] <= a[pivot_ind]:

i += 1

就会全部变成 greater

pivot is a[9] = 7

[(10) 4 6 1 1 2 7 8 (3) 7]

[3 4 6 1 1 ((2)) 7 8 10 7]

greater pivot a[9] = 7 at the right of a[5] = 2, i = min(i+1, r)

[3 4 6 1 1 2 "7" 8 10 "7"]

[3 4 6 1 1 2 "7" 8 10 7]

---

pivot is a[7] = 8

3 4 6 1 1 2 7 [8 (10) (7)]

3 4 6 1 1 2 7 [8 ((7)) 10]

a[8] = 7 can be safely swapped

3 4 6 1 1 2 7 ["8" "7" 10]

3 4 6 1 1 2 7 [7 "8" 10]

---

pivot is a[4] = 1

[((3)) 4 6 1 1 2] 7 7 8 10

a[0] = 3 can be safely swapped

["3" 4 6 1 "1" 2] 7 7 8 10

["1" 4 6 1 3 2] 7 7 8 10

---

pivot is a[3] = 1

1 [((4)) 6 1 3 2] 7 7 8 10

a[1] = 4 can be safely swapped

1 ["4" 6 "1" 3 2] 7 7 8 10

1 ["1" 6 4 3 2] 7 7 8 10

---

pivot is a[4] = 3

1 1 [(6) 4 3 (2)] 7 7 8 10

1 1 [((2)) 4 3 6] 7 7 8 10

greater pivot a[4] = 3 at the right of a[2] = 2, i = min(i+1, r)

1 1 [2 "4" "3" 6] 7 7 8 10

1 1 [2 "3" 4 6] 7 7 8 10

---

pivot is a[4] = 4

1 1 2 3 [((4)) 6] 7 7 8 10

a[4] = 4 can be safely swapped

1 1 2 3 [""4"" 6] 7 7 8 10

1 1 2 3 ["4" 6] 7 7 8 10

---