Algorithms_4th学习:第2.3节 快速排序(Quicksort)—— C++ 视角详解
一、快速排序是什么?
快速排序是实际应用中使用最广泛的排序算法。它的优点:
- 实现不复杂
- 对各种类型的输入数据都表现良好
- 速度比其他所有排序算法都快(典型情况下)
- 原地排序:只需要一个很小的辅助栈,不需要额外的大数组
- 平均时间复杂度 O(NlogN)O(N \log N)O(NlogN)
它唯一的缺点是:实现细节上比较脆弱,稍有不慎就可能退化到 O(N2)O(N^2)O(N2)。
二、基本思路:分治法
快速排序和归并排序都用了"分而治之"的思想,但方向相反:
归并排序:
先把数组分成两半(各自独立)
递归排好左右两半
最后合并(需要额外工作)
快速排序:
先把数组重新排列(分区)
使得分区点左边的都 <= 分区点
使得分区点右边的都 >= 分区点
然后递归排左右两半(不需要合并)
用 ASCII 图示:
原始数组:
[Q U I C K S O R T E X A M P L E]
随机打乱后(防止最坏情况):
[K R A T E L E P U I M Q C X O S]
分区(以 K 为基准):
[E C A I E | K | L P U T M Q R X O S]
^
K 已归位,左边都 <= K,右边都 >= K
递归排左边: [A C E E I]
递归排右边: [L M O P Q R S T U X]
最终结果:
[A C E E I K L M O P Q R S T U X]
三、核心:分区(Partition)操作
分区是快速排序的灵魂。目标是选一个"基准元素",把数组重新排列成:
[左边:都 <= 基准] [基准] [右边:都 >= 基准]
分区过程图解
初始状态:以 a[lo] = K 为基准
i 从左往右扫,j 从右往左扫
[K R A T E L E P U I M Q C X O S]
^ ^
lo=0 hi=15
v(基准=K)
i=1 j=15
步骤1:i 向右扫,找到 a[i]=R >= K(停下)
j 向左扫,找到 a[j]=C <= K(停下)
交换 R 和 C:
[K C A T E L E P U I M Q R X O S]
i=1 j=12
步骤2:i 继续右扫,找到 a[i]=T >= K(停下)
j 继续左扫,找到 a[j]=I <= K(停下)
交换 T 和 I:
[K C A I E L E P U T M Q R X O S]
i=3 j=9
步骤3:i 继续右扫,找到 a[i]=L >= K(停下,i=5)
j 继续左扫,找到 a[j]=E <= K(停下,j=6)
交换 L 和 E:
[K C A I E E L P U T M Q R X O S]
i=5 j=6
步骤4:i 继续右扫,i=6;j 继续左扫,j=5
i >= j,循环结束!
最后:交换 a[lo]=K 与 a[j]=E:
[E C A I E K L P U T M Q R X O S]
^
j=5,K 已归位
分区的 C++ 实现
#include <vector>
#include <algorithm>
// 将 a[lo..hi] 进行分区
// 返回基准元素最终所在的位置 j
// 分区后保证:a[lo..j-1] <= a[j] <= a[j+1..hi]
int partition(std::vector<int>& a, int lo, int hi) {
int i = lo; // 左扫描指针(从 lo+1 开始扫)
int j = hi + 1; // 右扫描指针(从 hi 开始扫)
int v = a[lo]; // 基准元素,选第一个
while (true) {
// 向右扫描:找到第一个 >= v 的元素
// 注意:先 ++i 再比较,所以从 lo+1 开始
while (a[++i] < v) {
if (i == hi) break; // 防止右越界
}
// 向左扫描:找到第一个 <= v 的元素
// 注意:先 --j 再比较,所以从 hi 开始
while (v < a[--j]) {
if (j == lo) break; // 防止左越界(其实多余,因为 a[lo]=v 不会小于 v)
}
// 两个指针相遇或交叉,分区完成
if (i >= j) break;
// 交换:把不该在左边的和不该在右边的互换
std::swap(a[i], a[j]);
}
// 把基准元素放到它最终的位置 j
std::swap(a[lo], a[j]);
return j; // 返回基准元素的最终位置
}
四、完整快速排序(C++ 完整可运行代码)
#include <iostream>
#include <vector>
#include <algorithm> // std::swap
#include <random> // std::shuffle
#include <numeric> // std::iota
// ============================================================
// 分区函数
// 以 a[lo] 为基准,重排 a[lo..hi]
// 返回基准的最终位置 j
// 保证:a[lo..j-1] <= a[j] <= a[j+1..hi]
// ============================================================
int partition(std::vector<int>& a, int lo, int hi) {
int i = lo;
int j = hi + 1;
int v = a[lo]; // 基准值
while (true) {
// 从左往右找第一个 >= v 的
while (a[++i] < v)
if (i == hi) break;
// 从右往左找第一个 <= v 的
while (v < a[--j])
if (j == lo) break;
// 两指针相遇,退出
if (i >= j) break;
std::swap(a[i], a[j]);
}
// 基准归位
std::swap(a[lo], a[j]);
return j;
}
// ============================================================
// 递归排序函数
// 对 a[lo..hi] 进行快速排序
// ============================================================
void quicksortHelper(std::vector<int>& a, int lo, int hi) {
if (hi <= lo) return; // 子数组只有一个元素或为空,无需排序
int j = partition(a, lo, hi); // 分区,j 是基准的最终位置
quicksortHelper(a, lo, j - 1); // 递归排左半段
quicksortHelper(a, j + 1, hi); // 递归排右半段
}
// ============================================================
// 对外接口:先随机打乱,再排序
// 随机打乱是为了避免最坏情况(原始有序数组导致 O(N^2))
// ============================================================
void quicksort(std::vector<int>& a) {
// 使用 C++11 的随机数引擎进行 Fisher-Yates 洗牌
std::random_device rd;
std::mt19937 rng(rd());
std::shuffle(a.begin(), a.end(), rng);
quicksortHelper(a, 0, (int)a.size() - 1);
}
int main() {
std::vector<int> data = {5, 2, 8, 3, 1, 9, 4, 7, 6, 0};
std::cout << "排序前: ";
for (int x : data) std::cout << x << " ";
std::cout << "\n";
quicksort(data);
std::cout << "排序后: ";
for (int x : data) std::cout << x << " ";
std::cout << "\n";
return 0;
}
五、递归调用树(以 QUICKSORTEXAMPLE 为例)
六、性能分析
最好情况
每次分区都把数组恰好分成两半,递推关系为:
CN=2CN/2+NC_N = 2C_{N/2} + NCN=2CN/2+N
这和归并排序相同,解为:
CN∼Nlog2NC_N \sim N \log_2 NCN∼Nlog2N
平均情况(命题 K)
命题 K:快速排序对 NNN 个不同元素的数组,平均需要约 2NlnN2N \ln N2NlnN 次比较(以及约 16\frac{1}{6}61 倍的交换次数)。
推导过程(看不懂可以跳过,记住结论即可):
设 CNC_NCN 为平均比较次数,则有递推式:
CN=(N−1)+1N∑k=0N−1(Ck+CN−1−k)C_N = (N-1) + \frac{1}{N}\sum_{k=0}^{N-1}(C_k + C_{N-1-k})CN=(N−1)+N1k=0∑N−1(Ck+CN−1−k)
- 第一项 (N−1)(N-1)(N−1):每次分区要扫描整个数组,约需 N−1N-1N−1 次比较
- 求和项:基准落在位置 kkk 的概率为 1/N1/N1/N,左边 kkk 个、右边 N−1−kN-1-kN−1−k 个
两边乘以 NNN,整理后得:
NCN=N(N−1)+2(C0+C1+⋯+CN−1)NC_N = N(N-1) + 2(C_0 + C_1 + \cdots + C_{N-1})NCN=N(N−1)+2(C0+C1+⋯+CN−1)
再对 N−1N-1N−1 写同样的式子相减,变形得:
CNN+1=CN−1N+2N+1\frac{C_N}{N+1} = \frac{C_{N-1}}{N} + \frac{2}{N+1}N+1CN=NCN−1+N+12
用望远镜求和(telescoping):
CN∼2(N+1)(13+14+⋯+1N+1)∼2NlnNC_N \sim 2(N+1)\left(\frac{1}{3} + \frac{1}{4} + \cdots + \frac{1}{N+1}\right) \sim 2N \ln NCN∼2(N+1)(31+41+⋯+N+11)∼2NlnN
由于 2NlnN≈1.39Nlog2N2N \ln N \approx 1.39 N \log_2 N2NlnN≈1.39Nlog2N,平均比较次数只比最好情况多约 39%。
最坏情况(命题 L)
命题 L:最坏情况下(每次分区后一侧为空),快速排序需要约 N2/2N^2/2N2/2 次比较:
N+(N−1)+(N−2)+⋯+2+1=N(N−1)2≈N22N + (N-1) + (N-2) + \cdots + 2 + 1 = \frac{N(N-1)}{2} \approx \frac{N^2}{2}N+(N−1)+(N−2)+⋯+2+1=2N(N−1)≈2N2
但是,随机打乱数组后,出现最坏情况的概率极低,可以安全忽略。
书中有一个生动的比喻:对百万元素数组,快速排序退化到插入排序级别的概率,比你的电脑在排序期间被雷劈中的概率还低。
标准差约为 0.65N0.65N0.65N
这意味着随着 NNN 增大,实际运行时间高度集中在均值附近,不会大幅偏离。
七、四个实现注意事项
1. 原地分区(不开额外数组)
如果开辅助数组来做分区,逻辑简单但需要额外的 O(N)O(N)O(N) 空间,而且每次递归都分配的话,会严重拖慢速度。
2. 指针不越界
当最小元素或最大元素恰好是基准时,指针可能越界。代码中通过:
if (i == hi) break; // 右指针不超过 hi
if (j == lo) break; // 左指针不低于 lo
来防止越界。
3. 保持随机性
先随机打乱,使得每次分区的位置对任何输入都是"随机的"。另一种等价做法是在 partition() 中随机选基准,而不是固定选 a[lo]。
4. 正确处理相等元素
相等元素必须让两侧扫描都停下来(遇到 >=v 停右扫,遇到 <=v 停左扫),即使这会产生一些"不必要的交换"。
如果不这样做,当所有元素相同时,每次分区都只划分出 1 和 N-1 两段,退化成 O(N2)O(N^2)O(N2)。
八、三大改进
改进一:小数组改用插入排序
快速排序在小数组上由于递归开销,反而比插入排序慢。
把 if (hi <= lo) return; 改为:
if (hi <= lo + M) {
insertionSort(a, lo, hi);
return;
}
阈值 MMM 一般取 5 到 15 之间。
改进二:三数取中作为基准
随机选基准的改进版:取 a[lo]、a[mid]、a[hi] 三个元素,用它们的中位数作为基准,能更好地把数组分成两半。
示例:三个候选值 3, 9, 5
中位数是 5,用 5 作基准,效果比随机选好
改进三:三路分区(处理大量重复元素)
这是最重要的改进,接下来单独讲。
九、三路快速排序(Dijkstra 荷兰国旗问题)
背景
当数组中有大量重复元素时(比如按出生年份排人事档案),标准快速排序效率下降,因为重复的元素也会被反复分区。
三路分区将数组分成三段:
[ < v ] [ = v ] [ > v ]
所有等于基准的元素直接归位,不参与后续递归。
三路分区过程图解(荷兰国旗问题)
用颜色 B/R/W(蓝/红/白)演示,v = R:
初始:
lt=0 i=0 gt=11
[R B W W R W B R R W B R]
维护不变量:
a[lo..lt-1] 全部 < v(蓝色)
a[lt..i-1] 全部 = v(红色)
a[i..gt] 未检查
a[gt+1..hi] 全部 > v(白色)
当 a[i] < v:swap(a[lt], a[i]),lt++,i++
当 a[i] > v:swap(a[i], a[gt]),gt--
当 a[i] = v:i++
过程:
步骤1: a[0]=R=v,i++ → lt=0 i=1 gt=11
步骤2: a[1]=B<v,swap(0,1) → [B R ...] lt=1 i=2
步骤3: a[2]=W>v,swap(2,11) → [...W] gt=10
步骤4: a[2]=R=v,i++ → i=3
...(省略中间步骤)
最终:
[B B B R R R R R W W W W]
<R 部分 | =R 部分 | >R 部分
三路快速排序 C++ 完整代码
#include <iostream>
#include <vector>
#include <algorithm> // std::swap
#include <random> // std::shuffle
// ============================================================
// 三路快速排序核心函数
// 将 a[lo..hi] 分成三段:< v, = v, > v
// 等于 v 的元素直接归位,不参与后续递归
// ============================================================
void quicksort3way(std::vector<int>& a, int lo, int hi) {
if (hi <= lo) return;
int lt = lo; // [lo..lt-1] < v
int i = lo + 1; // [lt..i-1] = v(待检查从 i 开始)
int gt = hi; // [gt+1..hi] > v
int v = a[lo]; // 基准值
// i 从左往右扫,直到与 gt 相遇
while (i <= gt) {
if (a[i] < v) {
// a[i] 比基准小:把它换到左边(< 区域末尾)
std::swap(a[lt], a[i]);
lt++; // 左边扩展一位
i++; // i 也前进(换过来的那个等于 v)
} else if (a[i] > v) {
// a[i] 比基准大:把它换到右边(> 区域开头)
std::swap(a[i], a[gt]);
gt--; // 右边扩展一位
// 注意:i 不前进,因为换过来的元素还没检查
} else {
// a[i] == v:直接跳过,归入 = 区域
i++;
}
}
// 此时:a[lo..lt-1] < v = a[lt..gt] < a[gt+1..hi]
// 递归排左边(< v 的部分)和右边(> v 的部分)
// 等于 v 的部分 a[lt..gt] 已经在最终位置,不需要再动
quicksort3way(a, lo, lt - 1);
quicksort3way(a, gt + 1, hi);
}
// 对外接口:先随机打乱,再排序
void quicksort3waySort(std::vector<int>& a) {
std::random_device rd;
std::mt19937 rng(rd());
std::shuffle(a.begin(), a.end(), rng);
quicksort3way(a, 0, (int)a.size() - 1);
}
int main() {
// 含大量重复元素的数组,三路排序优势明显
std::vector<int> data = {3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 3, 2, 3};
std::cout << "排序前: ";
for (int x : data) std::cout << x << " ";
std::cout << "\n";
quicksort3waySort(data);
std::cout << "排序后: ";
for (int x : data) std::cout << x << " ";
std::cout << "\n";
return 0;
}
三路快速排序 (3-Way Quicksort) 详解
一、核心思想
普通快速排序每次只把数组分成两段(< 基准 和 >= 基准),遇到大量重复元素时效率退化。
三路快速排序把数组分成三段:
[ < v ] [ = v ] [ > v ]
等于基准的元素一次归位,不再参与后续递归,对重复元素多的数组效率提升显著。
二、三个指针的含义
index: lo lt i gt hi
array: [ < v | = v | ????? | > v ]
lt-1 lt gt gt+1
| 指针 | 含义 | 初始值 |
|---|---|---|
lt |
[lo, lt-1] 全部 < v |
lo |
i |
[lt, i-1] 全部 = v,当前检查位置 |
lo+1 |
gt |
[gt+1, hi] 全部 > v |
hi |
扫描区间 [i, gt] 是未处理区域,随着 i 前进、gt 后退逐渐缩小至空。
三、主循环三种情况
情况 A:a[i] < v
交换 a[lt] 和 a[i],然后 lt++,i++
之前: [ <v ... <v | =v ... =v | a[i]<v | ??? | >v ]
lt i gt
之后: [ <v ... <v | a[i] | =v ... =v | ??? | >v ]
lt i gt
换过来的那个一定是 = v(因为 lt 指向的是 = 区域起点),所以 i 可以安全前进。
情况 B:a[i] > v
交换 a[i] 和 a[gt],然后 gt--
之前: [ <v | =v | a[i]>v | ??? | a[gt] | >v ]
lt i gt
之后: [ <v | =v | a[gt] | ??? | >v ]
lt i gt
注意:i 不动,因为换过来的 a[gt] 还没检查!
情况 C:a[i] == v
直接 i++,扩展 = 区域
之前: [ <v | =v | a[i]=v | ??? | >v ]
lt i gt
之后: [ <v | =v ... =v | ??? | >v ]
lt i gt
四、完整流程图(Mermaid)
五、示例数据逐步追踪
初始数组(已随机打乱,假设打乱后为下列顺序):
原始: 3 1 4 1 5 9 2 6 5 3 5 3 2 3
以下演示对子数组 [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 3, 2, 3] 第一轮分区,基准 v = 3:
初始状态:
index: 0 1 2 3 4 5 6 7 8 9 10 11 12 13
value: 3 1 4 1 5 9 2 6 5 3 5 3 2 3
^ ^
lt,i gt
v = 3
步骤1: a[i=0]=3 == v => i++
步骤2: a[i=1]=1 < v => swap(a[lt=0], a[i=1]) lt++ i++
1 3 4 1 5 9 2 6 5 3 5 3 2 3
^ ^
lt i
步骤3: a[i=2]=4 > v => swap(a[i=2], a[gt=13]) gt--
1 3 3 1 5 9 2 6 5 3 5 3 2 4
^ ^ ^
lt i gt
步骤4: a[i=2]=3 == v => i++
步骤5: a[i=3]=1 < v => swap(a[lt=1], a[i=3]) lt++ i++
1 1 3 3 5 9 2 6 5 3 5 3 2 4
^ ^ ^
lt i gt
... (继续直到 i > gt)
最终分区结果示意:
[ 1 1 2 2 | 3 3 3 3 | 4 5 5 5 6 9 ]
< 3 段 = 3 段 > 3 段
六、核心代码标注版
void quicksort3way(std::vector<int>& a, int lo, int hi) {
if (hi <= lo) return; // 递归出口:区间为空或只有1个元素
int lt = lo; // [lo..lt-1] 全是 < v
int i = lo + 1; // [lt..i-1] 全是 = v,i 是待检查元素
int gt = hi; // [gt+1..hi] 全是 > v
int v = a[lo]; // 取第一个元素为基准(已提前随机打乱)
while (i <= gt) { // 未处理区 [i..gt] 不为空时继续
if (a[i] < v) {
std::swap(a[lt], a[i]); // 把小元素换到左边
lt++; // 左段扩展
i++; // 换来的必是 =v,直接跳过
} else if (a[i] > v) {
std::swap(a[i], a[gt]); // 把大元素换到右边
gt--; // 右段扩展
// i 不动!换来的元素未知,需要下轮检查
} else {
i++; // 等于基准,直接并入 =v 段
}
}
// 循环结束后三段已完全分好:
// a[lo..lt-1] < v
// a[lt..gt] = v <- 已在最终位置,不再递归
// a[gt+1..hi] > v
quicksort3way(a, lo, lt - 1); // 递归排 < v 的部分
quicksort3way(a, gt + 1, hi); // 递归排 > v 的部分
}
七、递归树结构(ASCII)
以 [3,1,4,1,5,9,2,6,5,3,5,3,2,3] 为例(示意,实际基准随打乱结果变化):
quicksort3way([3,1,4,1,5,9,2,6,5,3,5,3,2,3])
│
│ 第一轮:v=3,分为三段
│ < 3: [1,1,2,2] = 3: [3,3,3,3] > 3: [4,5,5,5,6,9]
│ ^已归位,不递归^
├─── quicksort3way([1,1,2,2])
│ │ v=1,分为三段
│ │ < 1: [] = 1: [1,1] > 1: [2,2]
│ ├─── quicksort3way([]) <- 空,返回
│ └─── quicksort3way([2,2])
│ │ v=2,全部 = 2,无需递归
│ └─── done
└─── quicksort3way([4,5,5,5,6,9])
│ v=4,分为三段
│ < 4: [] = 4: [4] > 4: [5,5,5,6,9]
├─── quicksort3way([]) <- 空,返回
└─── quicksort3way([5,5,5,6,9])
│ v=5,分为三段
│ < 5: [] = 5: [5,5,5] > 5: [6,9]
└─── quicksort3way([6,9])
│ v=6,分为三段
│ < 6: [] = 6: [6] > 6: [9]
└─── done
八、关键细节总结
| 问题 | 解答 |
|---|---|
为何 a[i] < v 时 i 可以前进? |
a[lt] 一定是 = v(lt 是 = 区域起点),交换后 a[i] 变成 = v,可跳过 |
为何 a[i] > v 时 i 不前进? |
a[gt] 是未知值,刚换过来需要重新检查 |
| 为何要提前随机打乱? | 避免有序输入时基准总在端点,导致递归树退化为 O(n^2) |
| 等于基准的元素为何不再递归? | 分区后 a[lt..gt] 已在最终正确位置,无需移动 |
| 时间复杂度优势在哪? | 重复元素多时,= v 段很大,递归规模大幅缩减,接近 O(n) |
九、三路 vs 普通快排对比
普通快排(大量重复时):
level 0: [3,1,4,1,5,3,2,3,3,3,3,3,3,3] <- 13个3
level 1: [1,1,2] + [3] + [3,3,3,3,3,3,3,3,3,3,4,5]
level 2: ... 每层只缩小一个3,O(n^2)
三路快排:
level 0: [3,1,4,1,5,3,2,3,3,3,3,3,3,3]
分出 [1,1,2] | [3,3,3,3,3,3,3,3,3] | [4,5]
10个3一次全部归位!只需递归两个小段
本文档结合 Sedgewick《算法》三路快排实现,适用于 C++17 及以上标准。
十、三路分区的递归树演示
以 [R B W W R W B R R W B R](R=1, B=0, W=2)为例:
可以看到,所有 R 在一次分区后就全部归位,不再参与递归,效率极高。
十一、信息熵与三路排序的最优性
信息熵的定义
设数组有 kkk 种不同的键值,第 iii 种出现的概率为 pi=fi/Np_i = f_i / Npi=fi/N,则香农熵为:
H=−∑i=1kpilog2piH = -\sum_{i=1}^{k} p_i \log_2 p_iH=−i=1∑kpilog2pi
直觉解释:HHH 越大,键值越"杂乱",排序所需信息越多;HHH 越小(重复多),排序所需信息越少。
命题 M(下界)
任何基于比较的排序算法,排序 NNN 个元素所需的比较次数不能少于:
NH−NNH - NNH−N
其中 HHH 是键值的香农熵。
命题 N(三路排序的上界)
三路快速排序对 NNN 个元素排序所需的比较次数约为:
2(ln2)⋅N⋅H≈1.386⋅NH2(\ln 2) \cdot N \cdot H \approx 1.386 \cdot NH2(ln2)⋅N⋅H≈1.386⋅NH
这与下界只差一个常数因子(约 39%),因此三路快速排序是熵最优的。
两个极端情况对比
| 情形 | 熵 HHH | 三路排序时间 | 归并排序时间 |
|---|---|---|---|
| 所有元素都不同(pi=1/Np_i = 1/Npi=1/N) | log2N\log_2 Nlog2N | O(NlogN)O(N \log N)O(NlogN) | O(NlogN)O(N \log N)O(NlogN) |
| 所有元素都相同(k=1,p1=1k=1, p_1=1k=1,p1=1) | 000 | O(N)O(N)O(N) | O(NlogN)O(N \log N)O(NlogN) |
| 只有常数种不同值 | 常数 | O(N)O(N)O(N) | O(NlogN)O(N \log N)O(NlogN) |
| 三路排序在有重复元素时远优于归并排序,在没有重复时与归并相当。 |
十二、带三项改进的完整快速排序(C++ 完整代码)
#include <iostream>
#include <vector>
#include <algorithm> // std::swap, std::min
#include <random>
// ============================================================
// 改进一:小数组用插入排序
// ============================================================
void insertionSort(std::vector<int>& a, int lo, int hi) {
for (int i = lo + 1; i <= hi; i++) {
int key = a[i];
int j = i - 1;
while (j >= lo && a[j] > key) {
a[j + 1] = a[j];
j--;
}
a[j + 1] = key;
}
}
const int CUTOFF = 10; // 小于等于 10 个元素时用插入排序
// ============================================================
// 改进二:三数取中,选出最优基准
// 同时把三个候选值排好,充当"哨兵",省去部分边界检查
// ============================================================
int medianOfThree(std::vector<int>& a, int lo, int hi) {
int mid = lo + (hi - lo) / 2;
// 排列 a[lo], a[mid], a[hi] 使得 a[lo] <= a[mid] <= a[hi]
if (a[mid] < a[lo]) std::swap(a[mid], a[lo]);
if (a[hi] < a[lo]) std::swap(a[hi], a[lo]);
if (a[hi] < a[mid]) std::swap(a[hi], a[mid]);
// 把中位数放到 lo 位置作为基准
std::swap(a[mid], a[lo]);
return a[lo];
}
// ============================================================
// 改进三:三路分区(处理重复元素)
// 将改进一、二、三合并在一起
// ============================================================
void quicksortImproved(std::vector<int>& a, int lo, int hi) {
// 改进一:小数组切换插入排序
if (hi - lo < CUTOFF) {
insertionSort(a, lo, hi);
return;
}
// 改进二:三数取中
int v = medianOfThree(a, lo, hi);
// 改进三:三路分区
int lt = lo, i = lo + 1, gt = hi;
while (i <= gt) {
if (a[i] < v) {
std::swap(a[lt++], a[i++]);
} else if (a[i] > v) {
std::swap(a[i], a[gt--]);
} else {
i++;
}
}
quicksortImproved(a, lo, lt - 1);
quicksortImproved(a, gt + 1, hi);
}
void quicksortFull(std::vector<int>& a) {
std::random_device rd;
std::mt19937 rng(rd());
std::shuffle(a.begin(), a.end(), rng);
quicksortImproved(a, 0, (int)a.size() - 1);
}
int main() {
// 测试1:普通数组
std::vector<int> data1 = {5, 2, 8, 3, 1, 9, 4, 7, 6, 0};
quicksortFull(data1);
std::cout << "普通数组排序: ";
for (int x : data1) std::cout << x << " ";
std::cout << "\n";
// 测试2:大量重复元素
std::vector<int> data2 = {3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 3, 2, 3};
quicksortFull(data2);
std::cout << "重复元素排序: ";
for (int x : data2) std::cout << x << " ";
std::cout << "\n";
// 测试3:已排序数组(不加随机打乱会退化)
std::vector<int> data3(20);
for (int i = 0; i < 20; i++) data3[i] = i;
quicksortFull(data3);
std::cout << "已排序数组: ";
for (int x : data3) std::cout << x << " ";
std::cout << "\n";
return 0;
}
十三、各版本快速排序性能对比
| 版本 | 平均情况 | 最坏情况 | 额外空间 | 稳定性 | 备注 |
|---|---|---|---|---|---|
| 标准快速排序 | ∼1.39Nlog2N\sim 1.39 N \log_2 N∼1.39Nlog2N | ∼N2/2\sim N^2/2∼N2/2 | O(logN)O(\log N)O(logN) 栈 | 不稳定 | 随机打乱后极少触发最坏情况 |
| 三路快速排序 | ∼1.386NH\sim 1.386 NH∼1.386NH | ∼N2/2\sim N^2/2∼N2/2 | O(logN)O(\log N)O(logN) 栈 | 不稳定 | 重复多时接近 O(N)O(N)O(N) |
| 带插入排序优化 | 实际最快 | 同上 | O(logN)O(\log N)O(logN) 栈 | 不稳定 | 工程首选 |
| 归并排序 | ∼NlogN\sim N \log N∼NlogN | ∼NlogN\sim N \log N∼NlogN | O(N)O(N)O(N) | 稳定 | 不受输入影响 |
其中 HHH 是香农熵,NNN 是元素个数。
十四、快速排序 vs 归并排序:为什么快速排序更快?
虽然两者平均都是 O(NlogN)O(N \log N)O(NlogN),但快速排序通常快 20%~30%,原因:
- 内层循环极短:分区的内循环只做"递增索引 + 比较",非常简单,CPU 流水线效率高
- 数据移动少:归并排序每次都要把数据复制到辅助数组再复制回来;快速排序只在原地交换
- 缓存友好:快速排序访问模式是顺序扫描,对 CPU 缓存非常友好
虽然快速排序比较次数比归并多约 39%,但每次比较所需的时间远少于归并排序中的数据移动,综合下来更快。
快速排序(Quicksort)——Q&A 与习题详解
原文出自《算法》第四版 2.3 节,以下从 C++ 角度进行中文解读,力求通俗易懂。
一、Q & A 解读
Q1:能不能直接把数组从中间一分为二,而不是让切分元素"自由落地"?
通俗理解:
快速排序每次选一个"基准值"(pivot),把数组分成"比它小"和"比它大"两部分。这个基准值最终落在哪个位置是不确定的——有时偏左,有时偏右,导致两边不均衡。
有人就想:能不能每次都精确地从中间切开,让两边一样长?这样递归深度最小,效率最高。
答案:
这个问题困扰了专家超过十年。要做到"从中间切",就必须先找到数组的中位数(median)。虽然理论上可以在线性时间 O(n)O(n)O(n) 内找到中位数,但那些算法本身就以快速排序的切分为基础,实现复杂、常数因子大,额外开销远远超过"均匀切分"带来的约 39% 的性能收益,得不偿失。
结论: 理论上可行,工程上不划算,实际中不用。
Q2:随机打乱数组会占用不少时间,真的值得吗?
通俗理解:
快速排序在排序前先把数组随机打乱一遍,这步额外花时间,有必要吗?
答案:
非常值得。原因有两点:
- 避免最坏情况。 如果不打乱,对于已经有序(或逆序)的数组,每次切分都极度不均衡(一边 0 个,一边 n−1n-1n−1 个),时间复杂度退化到 O(n2)O(n^2)O(n2)。打乱之后,这种情况出现的概率极低。
- 让运行时间可预测。 随机化之后,无论输入是什么数据,期望时间都是 O(nlogn)O(n \log n)O(nlogn),不会被"坏数据"拖垮。
这是 Hoare 在 1960 年提出快速排序时就建议的做法,是最早的随机化算法之一。
Q3:为什么要特别关注"相等键值"的元素?
通俗理解:
数组里有很多重复值时,该怎么处理?
答案:
这个问题在实际应用中影响很大,但被忽视了数十年。
- 老版本实现(遇到相等元素不停扫描):对于大量重复值的数组,时间复杂度退化为 O(n2)O(n^2)O(n2)。
- 改进版本(Algorithm 2.5,即三路切分):对大量重复值的数组达到 O(nlogn)O(n \log n)O(nlogn)。
- 最优版本(熵最优排序,三路快排):对大量重复值的数组可达到线性时间 O(n)O(n)O(n)。
实际项目中经常会遇到重复数据(比如按年龄、性别、状态码排序),所以这个优化很重要。
二、C++ 完整实现
2.1 标准快速排序(含随机打乱)
#include <iostream>
#include <vector>
#include <algorithm> // std::swap
#include <random> // std::mt19937, std::shuffle
#include <functional> // std::function
// ---------- 切分函数 ----------
// 以 arr[lo] 为基准,把 arr[lo..hi] 分成三段:
// arr[lo..j-1] <= arr[j] <= arr[j+1..hi]
// 返回基准元素最终所在的下标 j
int partition(std::vector<int>& arr, int lo, int hi) {
int pivot = arr[lo]; // 选最左边的元素作为基准
int i = lo; // 左指针,从 lo 出发向右走
int j = hi + 1; // 右指针,从 hi+1 出发向左走
while (true) {
// 左指针右移,直到找到 >= pivot 的元素
while (arr[++i] < pivot) {
if (i == hi) break;
}
// 右指针左移,直到找到 <= pivot 的元素
while (arr[--j] > pivot) {
if (j == lo) break;
}
// 两指针相遇,切分完成
if (i >= j) break;
// 交换:让小的留左边,大的留右边
std::swap(arr[i], arr[j]);
}
// 把基准放到正确位置(下标 j)
std::swap(arr[lo], arr[j]);
return j;
}
// ---------- 递归快速排序 ----------
void quicksort(std::vector<int>& arr, int lo, int hi) {
if (lo >= hi) return; // 子数组只有 0 或 1 个元素,无需排序
int j = partition(arr, lo, hi); // 切分,返回基准下标
quicksort(arr, lo, j - 1); // 递归排左半部分
quicksort(arr, j + 1, hi); // 递归排右半部分
}
// ---------- 对外接口(含随机打乱) ----------
void sort(std::vector<int>& arr) {
// 随机打乱,防止最坏情况
std::mt19937 rng(std::random_device{}());
std::shuffle(arr.begin(), arr.end(), rng);
quicksort(arr, 0, static_cast<int>(arr.size()) - 1);
}
int main() {
std::vector<int> arr = {5, 3, 8, 1, 9, 2, 7, 4, 6};
std::cout << "排序前: ";
for (int x : arr) std::cout << x << " ";
std::cout << "\n";
sort(arr);
std::cout << "排序后: ";
for (int x : arr) std::cout << x << " ";
std::cout << "\n";
return 0;
}
2.2 三路快排(处理大量重复键值,熵最优)
三路快排把数组分成三段:
[ < pivot ][ == pivot ][ > pivot ]
lo lt~gt hi
#include <iostream>
#include <vector>
#include <algorithm> // std::swap
#include <random> // std::mt19937, std::shuffle
// ---------- 三路切分递归排序 ----------
// 切分后数组变成:
// arr[lo..lt-1] < pivot
// arr[lt..gt] == pivot (这些元素已在最终位置,无需再排)
// arr[gt+1..hi] > pivot
void quicksort3way(std::vector<int>& arr, int lo, int hi) {
if (lo >= hi) return;
int lt = lo; // lt 左边的元素都 < pivot
int gt = hi; // gt 右边的元素都 > pivot
int i = lo + 1; // i 是当前待处理的元素
int pivot = arr[lo];
while (i <= gt) {
if (arr[i] < pivot) {
// 当前元素比基准小,换到左边区域
std::swap(arr[lt], arr[i]);
++lt;
++i;
} else if (arr[i] > pivot) {
// 当前元素比基准大,换到右边区域
// 注意:换过来的 arr[gt] 还没检查,i 不动
std::swap(arr[i], arr[gt]);
--gt;
} else {
// 当前元素等于基准,直接跳过
++i;
}
}
// arr[lt..gt] 全部等于 pivot,已就位,只需递归两侧
quicksort3way(arr, lo, lt - 1);
quicksort3way(arr, gt + 1, hi);
}
void sort3way(std::vector<int>& arr) {
std::mt19937 rng(std::random_device{}());
std::shuffle(arr.begin(), arr.end(), rng);
quicksort3way(arr, 0, static_cast<int>(arr.size()) - 1);
}
int main() {
// 包含大量重复键值的数组
std::vector<int> arr = {3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 3, 5};
std::cout << "排序前: ";
for (int x : arr) std::cout << x << " ";
std::cout << "\n";
sort3way(arr);
std::cout << "排序后: ";
for (int x : arr) std::cout << x << " ";
std::cout << "\n";
return 0;
}
三、习题详解
习题 2.3.1 — 演示 partition 过程
对数组 E A S Y Q U E S T I O N 演示切分(以第一个元素 E 为基准):
字母按字典序比较大小:A < E < I < N < O < Q < S < T < U < Y
初始状态:
pivot = E
E A S Y Q U E S T I O N
^ ^
lo hi
i=lo j=hi+1
过程(ASCII 演示):
初始: E A S Y Q U E S T I O N
pivot = E
i 右移找 >= E:i 停在 index 1(A < E,继续;S >= E,停)
实际 i 停在 index 2(S)
j 左移找 <= E:j 停在 index 6(E == E,停)
i=2 < j=6,交换 arr[2] 和 arr[6]:
E A E Y Q U S S T I O N
i 右移:Y > E,停在 index 3
j 左移:U > E,S > E,S > E,E == E,停在 index 2...
j=2,i=3,i >= j,退出循环
交换 arr[lo] 和 arr[j],即 arr[0] 和 arr[2]:
E A E Y Q U S S T I O N
^
j=2,基准 E 就位
结果:
左边 [E A] <= E
基准 E 在下标 2
右边 [Y Q U S S T I O N] >= E
注意:由于随机打乱会改变顺序,习题通常要求忽略初始打乱,直接用原始顺序演示。
习题 2.3.2 — 演示完整快速排序过程
对 E A S Y Q U E S T I O N 演示(忽略初始打乱):
递归树(每行是一次切分,| 表示基准就位):
E A S Y Q U E S T I O N
|
E A E | Y Q U S S T I O N
| |
A | E I | Y Q U S S T O N
|
N Q U S S T O | Y
|
N Q | S S T O | U
|
N | Q ...依此类推
完整展开篇幅较长,核心是:每次切分后基准就位,左右子数组独立递归。
习题 2.3.3 — 最大元素最多被交换几次?
结论: 最多被交换 ⌊log2N⌋\lfloor \log_2 N \rfloor⌊log2N⌋ 次。
推理:
最大元素只会在切分时被交换到右侧,之后在子数组中永远是"最大",每次切分后都会移到子数组的最右边,不再参与左边的切分。每次交换对应一层递归,最多有 ⌊log2N⌋\lfloor \log_2 N \rfloor⌊log2N⌋ 层,所以最多交换这么多次。
习题 2.3.4 — 不打乱时,哪些数组触发最坏情况?
不打乱时,若始终选 arr[lo] 为基准,触发最坏情况的条件是:每次切分都极度不均衡(一边为空)。
六种示例(N=10):
1. 已升序: 1 2 3 4 5 6 7 8 9 10
2. 已降序: 10 9 8 7 6 5 4 3 2 1
3. 除第一个外升序:5 1 2 3 4 6 7 8 9 10
4. 除最后一个外升序:1 2 3 4 5 6 7 8 10 9
5. 前半升后半升但错位:1 3 5 7 9 2 4 6 8 10
6. 所有元素相同: 5 5 5 5 5 5 5 5 5 5
最坏情况比较次数:
Cworst(N)=N(N−1)2=O(N2)C_{worst}(N) = \frac{N(N-1)}{2} = O(N^2)Cworst(N)=2N(N−1)=O(N2)
习题 2.3.5 — 只有两种键值时的排序
只有两种键值(比如 0 和 1),本质上就是荷兰国旗问题的简化版,一趟扫描即可:
#include <iostream>
#include <vector>
#include <algorithm> // std::swap
// 只有两种键值(0 和 1)的排序
// 思路:维护一个边界 boundary
// [0, boundary) 全是 0
// [boundary, i) 全是 1
void sortTwoKeys(std::vector<int>& arr) {
int boundary = 0; // boundary 左边全放 0
for (int i = 0; i < static_cast<int>(arr.size()); ++i) {
if (arr[i] == 0) {
// 遇到 0,换到前面,边界右移
std::swap(arr[boundary], arr[i]);
++boundary;
}
// 遇到 1,什么都不做,继续往右扫
}
}
int main() {
std::vector<int> arr = {1, 0, 1, 1, 0, 0, 1, 0, 1};
std::cout << "排序前: ";
for (int x : arr) std::cout << x << " ";
std::cout << "\n";
sortTwoKeys(arr);
std::cout << "排序后: ";
for (int x : arr) std::cout << x << " ";
std::cout << "\n";
return 0;
}
时间复杂度:O(N)O(N)O(N),空间复杂度:O(1)O(1)O(1)。
习题 2.3.6 — 计算精确比较次数 CNC_NCN 并与近似值对比
快速排序期望比较次数的精确递推式:
CN=(N+1)+1N∑k=0N−1(Ck+CN−1−k),C0=C1=0C_N = (N+1) + \frac{1}{N} \sum_{k=0}^{N-1}(C_k + C_{N-1-k}), \quad C_0 = C_1 = 0CN=(N+1)+N1k=0∑N−1(Ck+CN−1−k),C0=C1=0
近似值为:
CN≈2NlnNC_N \approx 2N \ln NCN≈2NlnN
#include <iostream>
#include <vector>
#include <cmath> // std::log
int main() {
// 计算 C[N] 的精确值
// 递推:C[n] = (n+1) + (2/n) * sum(C[0]..C[n-1])
// 等价形式:n*C[n] = n*(n+1) + 2*sum(C[0]..C[n-1])
// 用前缀和加速
int maxN = 10000;
std::vector<double> C(maxN + 1, 0.0);
// C[0] = 0, C[1] = 0(已初始化)
double prefixSum = 0.0; // 存储 sum(C[0]..C[k-1])
for (int n = 2; n <= maxN; ++n) {
// C[n] = (n+1) + (2.0/n) * prefixSum
C[n] = (n + 1) + (2.0 / n) * prefixSum;
prefixSum += C[n];
}
// 输出对比
std::cout << "N\t\tC_N (精确)\t2N*ln(N) (近似)\t误差%\n";
std::cout << "------------------------------------------------------------\n";
for (int N : {100, 1000, 10000}) {
double exact = C[N];
double approx = 2.0 * N * std::log(static_cast<double>(N));
double errPct = (approx - exact) / exact * 100.0;
std::cout << N << "\t\t"
<< static_cast<long long>(exact) << "\t\t"
<< static_cast<long long>(approx) << "\t\t"
<< errPct << "%\n";
}
return 0;
}
预期输出(近似):
N C_N (精确) 2N*ln(N) (近似) 误差%
------------------------------------------------------------
100 919 921 ~0.2%
1000 11,393 13,816 ~21% (近似值偏大)
10000 132,878 184,207 ~38%
2NlnN2N \ln N2NlnN 是上界估计,实际精确值略小,两者在大 N 时差距约 39%。
习题 2.3.7 — 大小为 0、1、2 的子数组期望数量
对 N 个不同键值的数组执行快速排序,期望产生的子数组数量:
设切分后子数组大小为 kkk(k=0,1,2,…k = 0, 1, 2, \ldotsk=0,1,2,…),可以用概率分析推导:
大小为 0 的子数组(空子数组):
E[size-0]≈N−1+2N≈NE[\text{size-0}] \approx N - 1 + \frac{2}{N} \approx NE[size-0]≈N−1+N2≈N
大小为 1 的子数组:
E[size-1]≈N+13⋅2≈2(N+1)3E[\text{size-1}] \approx \frac{N+1}{3} \cdot 2 \approx \frac{2(N+1)}{3}E[size-1]≈3N+1⋅2≈32(N+1)
大小为 2 的子数组:
E[size-2]≈N+16E[\text{size-2}] \approx \frac{N+1}{6}E[size-2]≈6N+1
不想推导的话,下面程序直接用实验验证:
#include <iostream>
#include <vector>
#include <random>
#include <numeric> // std::iota
// 统计各种子数组大小出现次数的全局计数器
long long cnt0 = 0, cnt1 = 0, cnt2 = 0;
int partition(std::vector<int>& arr, int lo, int hi) {
int pivot = arr[lo];
int i = lo, j = hi + 1;
while (true) {
while (arr[++i] < pivot) if (i == hi) break;
while (arr[--j] > pivot) if (j == lo) break;
if (i >= j) break;
std::swap(arr[i], arr[j]);
}
std::swap(arr[lo], arr[j]);
return j;
}
void quicksort(std::vector<int>& arr, int lo, int hi) {
if (lo > hi) { ++cnt0; return; } // 空子数组
if (lo == hi) { ++cnt1; return; } // 大小为 1
if (hi - lo == 1) { // 大小为 2
++cnt2;
if (arr[lo] > arr[hi]) std::swap(arr[lo], arr[hi]);
return;
}
int j = partition(arr, lo, hi);
quicksort(arr, lo, j - 1);
quicksort(arr, j + 1, hi);
}
int main() {
const int N = 1000;
const int TRIALS = 10000;
long long total0 = 0, total1 = 0, total2 = 0;
std::mt19937 rng(42);
for (int t = 0; t < TRIALS; ++t) {
std::vector<int> arr(N);
std::iota(arr.begin(), arr.end(), 1); // 填入 1..N
std::shuffle(arr.begin(), arr.end(), rng); // 随机打乱
cnt0 = cnt1 = cnt2 = 0;
quicksort(arr, 0, N - 1);
total0 += cnt0;
total1 += cnt1;
total2 += cnt2;
}
std::cout << "N = " << N << ",实验次数 = " << TRIALS << "\n\n";
std::cout << "大小为 0 的子数组期望数量:" << total0 / TRIALS << "\n";
std::cout << "大小为 1 的子数组期望数量:" << total1 / TRIALS << "\n";
std::cout << "大小为 2 的子数组期望数量:" << total2 / TRIALS << "\n";
// 理论值对比
std::cout << "\n理论值对比:\n";
std::cout << "大小为 0(理论 ~N):" << N << "\n";
std::cout << "大小为 1(理论 ~2(N+1)/3):" << 2*(N+1)/3 << "\n";
std::cout << "大小为 2(理论 ~(N+1)/6):" << (N+1)/6 << "\n";
return 0;
}
习题 2.3.8 — 所有元素相等时的比较次数
若 N 个元素全部相等:
标准快速排序(不停扫):每次切分,左右指针各走一遍,共约 N24\frac{N^2}{4}4N2 次比较,O(N2)O(N^2)O(N2)。
三路快排:第一次切分后所有元素都进入"等于"区间,直接结束,只需 N−1N-1N−1 次比较,O(N)O(N)O(N)。
标准版的比较次数公式:
CN≈N24C_N \approx \frac{N^2}{4}CN≈4N2
习题 2.3.9 — 两种或三种键值时快速排序的行为
只有两种键值(如 0 和 1):
标准快速排序每次切分后,"等于基准"的元素可能分散在两侧,导致大量重复元素仍需排序,时间复杂度接近 O(N2)O(N^2)O(N2)。
三路快排可以在 O(N)O(N)O(N) 内完成(大量相等元素一次性归位)。
只有三种键值(如 1、2、3):
类似,标准版接近 O(N2)O(N^2)O(N2),三路快排接近 O(N)O(N)O(N)(三堆各自归位)。
习题 2.3.10 — 切比雪夫不等式界定概率
快速排序比较次数的均值与方差(N=106N = 10^6N=106):
均值:
μ≈2NlnN=2×106×ln(106)≈2.77×107\mu \approx 2N \ln N = 2 \times 10^6 \times \ln(10^6) \approx 2.77 \times 10^7μ≈2NlnN=2×106×ln(106)≈2.77×107
方差(理论结果):
σ2≈(0.21)N2≈2.1×1011\sigma^2 \approx (0.21) N^2 \approx 2.1 \times 10^{11}σ2≈(0.21)N2≈2.1×1011
标准差:
σ≈2.1×1011≈4.58×105\sigma \approx \sqrt{2.1 \times 10^{11}} \approx 4.58 \times 10^5σ≈2.1×1011≈4.58×105
题目问:比较次数超过 0.1N2=10110.1 N^2 = 10^{11}0.1N2=1011 的概率?
超出量(偏差):
δ=1011−2.77×107≈1011\delta = 10^{11} - 2.77 \times 10^7 \approx 10^{11}δ=1011−2.77×107≈1011
偏差是标准差的倍数:
k=δσ≈10114.58×105≈2.18×105k = \frac{\delta}{\sigma} \approx \frac{10^{11}}{4.58 \times 10^5} \approx 2.18 \times 10^5k=σδ≈4.58×1051011≈2.18×105
由切比雪夫不等式:
P(∣X−μ∣≥kσ)≤1k2P(|X - \mu| \geq k\sigma) \leq \frac{1}{k^2}P(∣X−μ∣≥kσ)≤k21
代入得:
P(CN≥1011)≤1(2.18×105)2≈14.75×1010≈2.1×10−11P(C_N \geq 10^{11}) \leq \frac{1}{(2.18 \times 10^5)^2} \approx \frac{1}{4.75 \times 10^{10}} \approx 2.1 \times 10^{-11}P(CN≥1011)≤(2.18×105)21≈4.75×10101≈2.1×10−11
结论: 概率极小(约 2×10−112 \times 10^{-11}2×10−11),快速排序在实践中非常安全。
习题 2.3.11 — 扫描过等值元素时退化为 O(N2)O(N^2)O(N2)
若修改扫描规则:遇到等于基准的元素不停下,继续扫描,会发生什么?
分析:
以全部元素相等(N 个 5)为例:
- i 会一路扫到最右边(hi),j 一路扫到最左边(lo)
- 两指针相遇时,基准落在最左端或最右端,子数组长度为 N−1N-1N−1 和 000
- 每次切分只缩减 1 个元素,共需 N 次切分
- 总比较次数:N+(N−1)+…+1=N(N−1)2=O(N2)N + (N-1) + \ldots + 1 = \frac{N(N-1)}{2} = O(N^2)N+(N−1)+…+1=2N(N−1)=O(N2)
结论: 只要键值种类是常数(不随 N 增长),这种实现都会退化到 O(N2)O(N^2)O(N2)。
习题 2.3.12 — 三路快排切分演示
对 B A B A B A B A C A D A B R A 演示第一次三路切分:
数组:B A B A B A B A C A D A B R A
pivot = B(第一个元素)
lt=0, gt=14, i=1
逐步扫描:
i=1: A < B → swap(arr[lt=0], arr[i=1]) → A B B A B A B A C A D A B R A, lt=1, i=2
i=2: B == B → i=3
i=3: A < B → swap(arr[lt=1], arr[i=3]) → A A B B B A B A C A D A B R A, lt=2, i=4
i=4: B == B → i=5
i=5: A < B → swap(arr[lt=2], arr[i=5]) → A A A B B B B A C A D A B R A, lt=3, i=6
i=6: B == B → i=7
i=7: A < B → swap(arr[lt=3], arr[i=7]) → A A A A B B B B C A D A B R A, lt=4, i=8
i=8: C > B → swap(arr[i=8], arr[gt=14]) → A A A A B B B B A A D A B R C, gt=13
i=8: A < B → swap(arr[lt=4], arr[i=8]) → A A A A A B B B B A D A B R C, lt=5, i=9
i=9: A < B → swap(arr[lt=5], arr[i=9]) → A A A A A A B B B B D A B R C, lt=6, i=10
i=10: D > B → swap(arr[i=10], arr[gt=13]) → A A A A A A B B B B R A B D C, gt=12
i=10: R > B → swap(arr[i=10], arr[gt=12]) → A A A A A A B B B B B A R D C, gt=11
i=10: B == B → i=11
i=11: A < B → swap(arr[lt=6], arr[i=11]) → A A A A A A A B B B B B R D C, lt=7, i=12
i=12 > gt=11,退出循环
最终结果:
A A A A A A A | B B B B B | R D C
[lt=0..6] [lt=7..gt=11] [gt+1=12..14]
全 < B 全 == B 全 > B
习题 2.3.13 — 递归深度
| 情形 | 递归深度 |
|---|---|
| 最好情况(每次均匀切分) | O(logN)O(\log N)O(logN) |
| 平均情况 | O(logN)O(\log N)O(logN) |
| 最坏情况(每次极度不均) | O(N)O(N)O(N) |
最坏情况(如已排序数组不打乱):递归深度达到 NNN,每层一个调用帧,可能导致栈溢出。
随机打乱后,最坏情况概率极低,期望深度为:
E[depth]≈2lnNE[\text{depth}] \approx 2 \ln NE[depth]≈2lnN
习题 2.3.14 — 第 i 大和第 j 大元素被比较的概率
结论: 在快速排序中,第 iii 大和第 j 大元素(i<ji < ji<j)被直接比较的概率为:
P(i 与 j 被比较)=2j−i+1P(\text{i 与 j 被比较}) = \frac{2}{j - i + 1}P(i 与 j 被比较)=j−i+12
推理:
考虑元素集合 {i,i+1,…,j}\{i, i+1, \ldots, j\}{i,i+1,…,j},共 j−i+1j - i + 1j−i+1 个元素。只有当 iii 或 jjj 是这个集合中最先被选为基准的元素时,它们才会直接比较。因为:
- 若选了 iii 或 jjj:两者必然直接比较(一个是基准,另一个在同一子数组里)
- 若选了 iii 和 jjj 之间的某个元素:iii 和 jjj 被分到不同子数组,永远不比较
iii 或 jjj 最先被选为基准的概率(共 j−i+1j-i+1j−i+1 种等可能的选择):
P=2j−i+1P = \frac{2}{j - i + 1}P=j−i+12
用此推导期望比较次数(命题 K):
CN=∑1≤i<j≤N2j−i+1C_N = \sum_{1 \leq i < j \leq N} \frac{2}{j - i + 1}CN=1≤i<j≤N∑j−i+12
令 d=j−id = j - id=j−i(距离从 1 到 N−1N-1N−1),对于每个距离 ddd 有 N−dN - dN−d 对:
CN=∑d=1N−1(N−d)⋅2d+1=2N∑d=1N−11d+1−2∑d=1N−1dd+1C_N = \sum_{d=1}^{N-1} (N-d) \cdot \frac{2}{d+1} = 2N \sum_{d=1}^{N-1} \frac{1}{d+1} - 2\sum_{d=1}^{N-1} \frac{d}{d+1}CN=d=1∑N−1(N−d)⋅d+12=2Nd=1∑N−1d+11−2d=1∑N−1d+1d
当 NNN 足够大时,第一项主导,近似为:
CN≈2NlnNC_N \approx 2N \ln NCN≈2NlnN
这正是快速排序期望比较次数的经典结论。
快速排序练习详解 (2.3.15 - 2.3.19)
2.3.15 螺母和螺栓问题 (Nuts and Bolts)
题目理解
你有 NNN 个螺母和 NNN 个螺栓混在一起,需要快速找到配对的螺母和螺栓。规则如下:
- 每个螺母恰好匹配一个螺栓,反之亦然
- 你可以把一个螺母和一个螺栓拧在一起,判断哪个大哪个小(或者刚好匹配)
- 不能直接比较两个螺母,也不能直接比较两个螺栓
这就像你面前有一堆大小不同的螺母和螺栓,你只能通过"试拧"来判断大小关系。
核心思路
这个问题的关键约束是:同类元素之间不能比较。这让我们没法直接排序螺母或者螺栓,但可以用类似快速排序的分治策略:
- 随机选一个螺栓作为"基准"
- 用这个螺栓去和所有螺母比较,把螺母分成三组:比它小的、匹配的、比它大的
- 找到匹配的螺母后,用这个螺母去和所有螺栓比较,把螺栓也分成三组
- 现在,小螺母组和小螺栓组配对,大螺母组和大螺栓组配对
- 递归处理两个子问题
算法流程图
时间复杂度
- 和快速排序完全类似,期望时间复杂度为 O(NlogN)O(N \log N)O(NlogN)
- 最坏情况 O(N2)O(N^2)O(N2),但通过随机化可以避免
完整代码
#include <iostream>
#include <vector>
#include <algorithm> // for std::swap, std::shuffle
#include <random>
#include <ctime>
// 比较函数:模拟"用螺母试螺栓"或"用螺栓试螺母"
// 返回值: -1 表示 a < b, 0 表示匹配, 1 表示 a > b
int compare(int a, int b) {
if (a < b) return -1;
if (a > b) return 1;
return 0;
}
// 用一个基准元素对数组进行分区
// pivot: 基准值(来自另一个数组)
// arr: 要被分区的数组
// lo, hi: 分区范围
// 返回基准元素最终所在的位置
int partition(std::vector<int>& arr, int lo, int hi, int pivot) {
int i = lo; // 扫描指针
// 第一步:找到 pivot 在 arr 中的匹配元素,把它换到末尾暂存
for (int k = lo; k <= hi; k++) {
if (compare(arr[k], pivot) == 0) {
std::swap(arr[k], arr[hi]); // 把匹配的元素换到最后
break;
}
}
// 第二步:以 pivot 为基准,将 arr[lo..hi-1] 分区
// 小于 pivot 的放左边,大于 pivot 的放右边
for (int j = lo; j < hi; j++) {
if (compare(arr[j], pivot) < 0) {
std::swap(arr[i], arr[j]);
i++;
}
}
// 第三步:把暂存在末尾的匹配元素放到正确位置
std::swap(arr[i], arr[hi]);
return i; // 返回匹配元素的最终位置
}
// 核心递归函数:同时对螺母和螺栓进行匹配排序
void matchNutsAndBolts(std::vector<int>& nuts, std::vector<int>& bolts,
int lo, int hi) {
if (lo >= hi) return; // 基本情况:0或1个元素,无需处理
// 第一步:选 bolts[hi] 作为基准,对 nuts 进行分区
// 分区后,与 bolts[hi] 匹配的螺母在位置 pivotIndex
int pivotIndex = partition(nuts, lo, hi, bolts[hi]);
// 第二步:用匹配到的螺母 nuts[pivotIndex] 作为基准,对 bolts 进行分区
// 这保证了 bolts 的分区方式和 nuts 一致
partition(bolts, lo, hi, nuts[pivotIndex]);
// 此时 nuts[pivotIndex] 和 bolts[pivotIndex] 已经配对
// 第三步:递归处理左半部分和右半部分
matchNutsAndBolts(nuts, bolts, lo, pivotIndex - 1);
matchNutsAndBolts(nuts, bolts, pivotIndex + 1, hi);
}
int main() {
// 螺母和螺栓的大小(实际中我们不知道具体值,只能通过比较)
std::vector<int> nuts = {5, 3, 7, 1, 9, 2, 8, 4, 6, 10};
std::vector<int> bolts = {8, 2, 6, 10, 4, 1, 9, 3, 5, 7};
int n = nuts.size();
std::cout << "=== 螺母和螺栓匹配问题 ===" << std::endl;
std::cout << "匹配前:" << std::endl;
std::cout << "螺母: ";
for (int x : nuts) std::cout << x << " ";
std::cout << std::endl;
std::cout << "螺栓: ";
for (int x : bolts) std::cout << x << " ";
std::cout << std::endl;
// 随机打乱以获得期望 O(N log N) 性能
std::mt19937 rng(static_cast<unsigned>(time(nullptr)));
std::shuffle(nuts.begin(), nuts.end(), rng);
std::shuffle(bolts.begin(), bolts.end(), rng);
// 执行匹配
matchNutsAndBolts(nuts, bolts, 0, n - 1);
std::cout << "\n匹配后(已排序配对):" << std::endl;
std::cout << "螺母: ";
for (int x : nuts) std::cout << x << " ";
std::cout << std::endl;
std::cout << "螺栓: ";
for (int x : bolts) std::cout << x << " ";
std::cout << std::endl;
// 验证配对
std::cout << "\n配对结果:" << std::endl;
for (int i = 0; i < n; i++) {
std::cout << "螺母 " << nuts[i] << " <-> 螺栓 " << bolts[i]
<< (nuts[i] == bolts[i] ? " [匹配]" : " [错误!]") << std::endl;
}
return 0;
}
分区过程示例
假设螺栓 = {5, 3, 7, 1, 9},螺母 = {3, 9, 1, 5, 7},选螺栓 9 为基准:
用螺栓 9 对螺母分区:
[3, 9, 1, 5, 7]
找到匹配的螺母 9,换到末尾: [3, 7, 1, 5, 9]
扫描分区:
3 < 9 -> 左边
7 < 9 -> 左边
1 < 9 -> 左边
5 < 9 -> 左边
结果: [3, 7, 1, 5, | 9] pivotIndex = 4
用螺母 9 对螺栓分区:
[5, 3, 7, 1, 9]
找到匹配的螺栓 9,换到末尾: [5, 3, 7, 1, 9]
扫描分区: 全部 < 9
结果: [5, 3, 7, 1, | 9] pivotIndex = 4
配对成功: 螺母9 <-> 螺栓9
递归处理左半边: [3,7,1,5] vs [5,3,7,1]
2.3.16 最佳情况数组 (Best Case)
题目理解
要构造一个数组,使得快速排序(Algorithm 2.5,不做初始随机打乱)的每次分区都能产生大小差最多为1的两个子数组——也就是每次都恰好"对半分"。这是快速排序的最佳情况,此时比较次数约为 NlogNN \log NNlogN。
核心思路
逆向思维:我们不是去排序一个数组,而是逆向构造一个数组,使得标准快速排序的分区过程恰好每次都选中中位数。
方法:从一个已排序的数组出发,递归地把"将成为基准的元素"放到子数组的起始位置。
具体步骤:
- 从排序好的数组
[0, 1, 2, ..., N-1]开始 - 对于每个子数组
[lo, hi],中间位置 mid=lo+(hi−lo)/2mid = lo + (hi - lo) / 2mid=lo+(hi−lo)/2 - 将
mid位置的元素与lo位置的元素交换(因为标准快排选a[lo]作为基准) - 递归处理左右子数组
算法过程图
构造过程(N=7, 数组下标 0~6):
初始排好序: [0, 1, 2, 3, 4, 5, 6]
第一层: mid = 3, swap(arr[0], arr[3])
[3, 1, 2, 0, 4, 5, 6]
基准=3, 左子[1,2,0], 右子[4,5,6]
左子[lo=1,hi=3]: mid=2, swap(arr[1], arr[2])
[3, 2, 1, 0, 4, 5, 6]
右子[lo=4,hi=6]: mid=5, swap(arr[4], arr[5])
[3, 2, 1, 0, 5, 4, 6]
继续递归直到子数组大小<=1...
最终结果: [3, 1, 0, 2, 5, 4, 6]
排序这个数组时,每次分区都会把数组分成大小相等(或差1)的两半。
完整代码
#include <iostream>
#include <vector>
#include <algorithm>
// 构造最佳情况数组的核心递归函数
// arr: 当前正在构造的数组
// lo, hi: 当前处理的子数组范围
void bestCaseHelper(std::vector<int>& arr, int lo, int hi) {
if (hi <= lo) return; // 子数组大小 <= 1,无需处理
// 计算中间位置:这个位置的元素将成为分区基准
int mid = lo + (hi - lo) / 2;
// 把中间元素交换到 lo 位置
// 因为标准快排选 a[lo] 作为基准
// 这样分区时 a[lo] 恰好是中位数,能把数组对半分
std::swap(arr[lo], arr[mid]);
// 递归构造左子数组和右子数组
bestCaseHelper(arr, lo + 1, mid); // 左半边: [lo+1, mid]
bestCaseHelper(arr, mid + 1, hi); // 右半边: [mid+1, hi]
}
// 生成快速排序最佳情况数组
std::vector<int> bestCase(int n) {
// 从有序数组开始
std::vector<int> arr(n);
for (int i = 0; i < n; i++) {
arr[i] = i;
}
// 递归调整位置,使得每次分区都对半分
bestCaseHelper(arr, 0, n - 1);
return arr;
}
// ============= 验证:标准快排分区 =============
int partitionCount = 0; // 记录分区调用次数
int partition(std::vector<int>& a, int lo, int hi) {
partitionCount++;
int v = a[lo]; // 基准元素
int i = lo, j = hi + 1;
while (true) {
// 从左向右找到第一个 >= v 的元素
while (a[++i] < v) {
if (i == hi) break;
}
// 从右向左找到第一个 <= v 的元素
while (v < a[--j]) {
if (j == lo) break;
}
if (i >= j) break;
std::swap(a[i], a[j]);
}
std::swap(a[lo], a[j]);
// 打印分区信息
int leftSize = j - lo; // 左子数组大小
int rightSize = hi - j; // 右子数组大小
std::cout << " 分区 [" << lo << "," << hi << "]"
<< " 基准=" << v
<< " 位置=" << j
<< " 左大小=" << leftSize
<< " 右大小=" << rightSize
<< " 差=" << std::abs(leftSize - rightSize)
<< std::endl;
return j;
}
void quicksort(std::vector<int>& a, int lo, int hi) {
if (hi <= lo) return;
int j = partition(a, lo, hi);
quicksort(a, lo, j - 1);
quicksort(a, j + 1, hi);
}
int main() {
for (int n : {7, 15, 16}) {
std::cout << "=== N = " << n << " ===" << std::endl;
std::vector<int> arr = bestCase(n);
std::cout << "最佳情况数组: ";
for (int x : arr) std::cout << x << " ";
std::cout << std::endl;
// 验证:用标准快排排序,看每次分区是否对半分
std::cout << "分区过程:" << std::endl;
partitionCount = 0;
quicksort(arr, 0, n - 1);
std::cout << "排序结果: ";
for (int x : arr) std::cout << x << " ";
std::cout << std::endl;
std::cout << "总分区次数: " << partitionCount << std::endl;
std::cout << std::endl;
}
return 0;
}
为什么这是最佳情况?
快速排序的递推关系:
C(N)=2C(N/2)+NC(N) = 2C(N/2) + NC(N)=2C(N/2)+N
当每次分区都对半分时,递归树的深度最小(为 lgN\lg NlgN),每层的总比较次数为 NNN,所以总比较次数为:
C(N)≈NlgNC(N) \approx N \lg NC(N)≈NlgN
这比平均情况的 ∼1.39NlgN\sim 1.39 N \lg N∼1.39NlgN 要好约 39%。
2.3.17 哨兵 (Sentinels)
题目理解
标准快排的分区循环中有两个边界检查:
while (a[++i] < v) if (i == hi) break; // 检查右边界
while (v < a[--j]) if (j == lo) break; // 检查左边界
每次内循环迭代都要做边界检查,虽然每次只花一点点时间,但在 NNN 很大时积少成多。哨兵技术可以去掉这两个检查:
左边界:a[lo] 就是基准 v,所以 j 向左扫描时 a[lo] == v 不会小于 v,自然停下。左边界检查本来就是多余的。
右边界:在打乱数组后,把整个数组最大的元素放到 a[N-1]。这个最大元素永远不会比任何基准小,所以 i 向右扫描时一定会在这里停下,不会越界。
完整代码
#include <iostream>
#include <vector>
#include <algorithm>
#include <random>
#include <ctime>
// 简单插入排序(用于小数组优化,此处用于辅助)
void insertionSort(std::vector<int>& a, int lo, int hi) {
for (int i = lo + 1; i <= hi; i++) {
int key = a[i];
int j = i - 1;
while (j >= lo && a[j] > key) {
a[j + 1] = a[j];
j--;
}
a[j + 1] = key;
}
}
// 不带边界检查的分区函数
// 前提:a[a.size()-1] 是整个数组的最大值(右哨兵)
// a[lo] 是基准(天然的左哨兵)
int partition(std::vector<int>& a, int lo, int hi) {
int v = a[lo]; // 基准元素,同时是左哨兵
int i = lo;
int j = hi + 1;
while (true) {
// 从左向右扫描,找到第一个 >= v 的元素
// 不需要检查 i == hi,因为右哨兵保证 i 不会越界
while (a[++i] < v) { }
// 从右向左扫描,找到第一个 <= v 的元素
// 不需要检查 j == lo,因为 a[lo] == v 保证 j 不会越过 lo
while (v < a[--j]) { }
if (i >= j) break;
std::swap(a[i], a[j]);
}
std::swap(a[lo], a[j]);
return j;
}
void quicksort(std::vector<int>& a, int lo, int hi) {
if (hi <= lo) return;
int j = partition(a, lo, hi);
quicksort(a, lo, j - 1);
quicksort(a, j + 1, hi);
}
void sort(std::vector<int>& a) {
int n = a.size();
if (n <= 1) return;
// 随机打乱数组
std::mt19937 rng(static_cast<unsigned>(time(nullptr)));
std::shuffle(a.begin(), a.end(), rng);
// *** 关键:找到最大元素,放到数组末尾作为右哨兵 ***
int maxIndex = 0;
for (int i = 1; i < n; i++) {
if (a[i] > a[maxIndex]) {
maxIndex = i;
}
}
// 把最大元素换到 a[n-1]
std::swap(a[maxIndex], a[n - 1]);
// 开始快排(a[n-1]是最大值,作为所有子数组的右哨兵)
quicksort(a, 0, n - 1);
}
int main() {
std::vector<int> a = {38, 27, 43, 3, 9, 82, 10, 55, 1, 100, 64, 17};
int n = a.size();
std::cout << "排序前: ";
for (int x : a) std::cout << x << " ";
std::cout << std::endl;
sort(a);
std::cout << "排序后: ";
for (int x : a) std::cout << x << " ";
std::cout << std::endl;
// 验证
bool sorted = true;
for (int i = 1; i < n; i++) {
if (a[i] < a[i - 1]) {
sorted = false;
break;
}
}
std::cout << (sorted ? "验证通过!" : "排序错误!") << std::endl;
return 0;
}
为什么哨兵能工作?
数组: [..., lo, ..., hi, ..., MAX]
^ ^
基准v 右哨兵(最大值)
左哨兵原理:
j 从右向左扫描,遇到 <= v 的元素就停
a[lo] = v, 所以 j 最远到 lo 就会停 (因为 v <= v)
-> 不需要 j == lo 的边界检查
右哨兵原理:
i 从左向右扫描,遇到 >= v 的元素就停
a[N-1] = MAX >= v, 所以 i 最远到 N-1 就会停
对于内部子数组,右边相邻子数组的最左元素 > v, 也能当哨兵
-> 不需要 i == hi 的边界检查
2.3.18 三取样切分 (Median-of-3 Partitioning)
题目理解
标准快排选 a[lo] 作为基准,运气不好时可能选到最小或最大元素,导致极度不平衡的分区。
三取样切分的改进:从子数组中取三个元素(通常是 a[lo], a[mid], a[hi]),选它们的中位数作为基准。这样:
- 基准更可能接近真正的中位数
- 分区更均衡
- 减少最坏情况出现的概率
三取样中位数选择图解
假设 a[lo]=7, a[mid]=3, a[hi]=5
排序这三个: 3, 5, 7
中位数 = 5 = a[hi]
安排后:
a[lo]=3 (最小) a[mid]=5 (中位数) a[hi]=7 (最大)
把中位数 a[mid] 换到 a[lo]:
a[lo]=5 (基准) ... a[hi]=7 (天然右哨兵!)
额外好处:a[lo] 和 a[hi] 已经在正确的一侧
a[lo]=3 < 5(基准), a[hi]=7 > 5(基准)
完整代码
#include <iostream>
#include <vector>
#include <algorithm>
#include <random>
#include <chrono>
#include <ctime>
// 交换并返回三个元素的中位数的下标
// 同时将三个元素排好序:a[lo] <= a[mid] <= a[hi]
int medianOf3(std::vector<int>& a, int lo, int mid, int hi) {
// 用最多3次比较将 a[lo], a[mid], a[hi] 排好序
if (a[mid] < a[lo]) std::swap(a[lo], a[mid]);
if (a[hi] < a[lo]) std::swap(a[lo], a[hi]);
if (a[hi] < a[mid]) std::swap(a[mid], a[hi]);
// 现在 a[lo] <= a[mid] <= a[hi]
// 中位数在 mid 位置
return mid;
}
// 使用三取样中位数的分区函数
int partition(std::vector<int>& a, int lo, int hi) {
int mid = lo + (hi - lo) / 2;
// 选三个元素的中位数,排好序后 a[lo]<=a[mid]<=a[hi]
medianOf3(a, lo, mid, hi);
// 把中位数(a[mid])换到 a[lo+1],作为基准
// a[lo] 已经 <= 中位数,充当左哨兵
// a[hi] 已经 >= 中位数,充当右哨兵
std::swap(a[mid], a[lo + 1]);
int v = a[lo + 1]; // 基准
int i = lo + 1; // 从 lo+1 开始(a[lo] 已经在正确一侧)
int j = hi; // 从 hi 开始(a[hi] 已经在正确一侧,还当右哨兵)
while (true) {
while (a[++i] < v) { } // a[hi] >= v, 右哨兵保证不越界
while (a[--j] > v) { } // a[lo] <= v, 左哨兵保证不越界
if (i >= j) break;
std::swap(a[i], a[j]);
}
// 把基准放到正确位置
std::swap(a[lo + 1], a[j]);
return j;
}
// 插入排序:用于小数组
void insertionSort(std::vector<int>& a, int lo, int hi) {
for (int i = lo + 1; i <= hi; i++) {
int key = a[i];
int j = i - 1;
while (j >= lo && a[j] > key) {
a[j + 1] = a[j];
j--;
}
a[j + 1] = key;
}
}
void quicksort(std::vector<int>& a, int lo, int hi) {
// 小数组切换到插入排序
if (hi - lo + 1 <= 10) {
insertionSort(a, lo, hi);
return;
}
int j = partition(a, lo, hi);
quicksort(a, lo, j - 1);
quicksort(a, j + 1, hi);
}
void sort(std::vector<int>& a) {
int n = a.size();
if (n <= 1) return;
// 随机打乱
std::mt19937 rng(static_cast<unsigned>(time(nullptr)));
std::shuffle(a.begin(), a.end(), rng);
quicksort(a, 0, n - 1);
}
// ============= 倍增测试 =============
// 标准快排(用于对比)
namespace StandardQS {
int partitionStd(std::vector<int>& a, int lo, int hi) {
int v = a[lo];
int i = lo, j = hi + 1;
while (true) {
while (a[++i] < v) if (i == hi) break;
while (v < a[--j]) if (j == lo) break;
if (i >= j) break;
std::swap(a[i], a[j]);
}
std::swap(a[lo], a[j]);
return j;
}
void quicksort(std::vector<int>& a, int lo, int hi) {
if (hi <= lo) return;
int j = partitionStd(a, lo, hi);
quicksort(a, lo, j - 1);
quicksort(a, j + 1, hi);
}
void sort(std::vector<int>& a) {
std::mt19937 rng(42);
std::shuffle(a.begin(), a.end(), rng);
quicksort(a, 0, a.size() - 1);
}
}
int main() {
std::cout << "=== 三取样切分 vs 标准快排 倍增测试 ===" << std::endl;
std::cout << " N 标准快排(ms) 三取样(ms) 加速比" << std::endl;
std::mt19937 rng(12345);
for (int n = 1000; n <= 1000000; n *= 2) {
// 生成随机数组
std::vector<int> base(n);
for (int i = 0; i < n; i++) base[i] = rng() % (n * 10);
// 测试标准快排
std::vector<int> a1 = base;
auto t1 = std::chrono::high_resolution_clock::now();
StandardQS::sort(a1);
auto t2 = std::chrono::high_resolution_clock::now();
double stdTime = std::chrono::duration<double, std::milli>(t2 - t1).count();
// 测试三取样快排
std::vector<int> a2 = base;
auto t3 = std::chrono::high_resolution_clock::now();
sort(a2);
auto t4 = std::chrono::high_resolution_clock::now();
double med3Time = std::chrono::duration<double, std::milli>(t4 - t3).count();
double ratio = stdTime / med3Time;
std::cout << " " << n
<< "\t\t" << stdTime
<< "\t\t" << med3Time
<< "\t\t" << ratio
<< std::endl;
}
return 0;
}
性能分析
三取样切分相比标准快排的改进:
| 指标 | 标准快排 | 三取样切分 |
|---|---|---|
| 平均比较次数 | ∼1.39NlgN\sim 1.39 N \lg N∼1.39NlgN | ∼1.19NlgN\sim 1.19 N \lg N∼1.19NlgN |
| 交换次数 | ∼0.23NlgN\sim 0.23 N \lg N∼0.23NlgN | ∼0.15NlgN\sim 0.15 N \lg N∼0.15NlgN |
| 最坏情况概率 | 较高 | 大幅降低 |
三取样使比较次数减少约 14%,交换次数减少约 35%。
2.3.19 五取样切分 (Median-of-5 Partitioning)
题目理解
把三取样推广到五取样:从子数组中随机抽5个元素,取它们的中位数作为基准。这样基准更接近真正的中位数,分区更均衡。
关键技巧:把5个样本排好序后,最小的两个放到子数组的左端,最大的两个放到右端,中位数参与分区。这样左端的两个元素天然 <= 基准,右端的两个元素天然 >= 基准,它们不需要参与分区扫描。
额外挑战:设计一个最多用6次比较就能找到5个元素中位数的算法(标准是7次,但6次是可能的)。
五取样中位数排序网络
找5个元素的中位数,最少需要6次比较。思路:
5个元素: a, b, c, d, e
第1步: 比较 (a,b) 和 (c,d),让小的在前
if a>b then swap(a,b) // 比较1
if c>d then swap(c,d) // 比较2
现在 a<=b, c<=d
第2步: 比较两个"小的":a 和 c
if a>c then swap(a,c), swap(b,d) // 比较3
现在 a <= c <= d, a <= b
a 一定不是中位数(最多第2大),可以丢弃
第3步: e 和 b 比较
if b>e then swap(b,e) // 比较4 (不一定需要swap)
... 继续比较得到中位数
(完整的6次比较网络见代码)
完整代码
#include <iostream>
#include <vector>
#include <algorithm>
#include <random>
#include <chrono>
#include <ctime>
// 辅助: 条件交换,保证 a[i] <= a[j]
// 每次调用消耗1次比较
inline void compSwap(std::vector<int>& a, int i, int j) {
if (a[i] > a[j]) std::swap(a[i], a[j]);
}
// 找5个元素的中位数,使用6次比较
// 输入: 数组 a 中的5个位置 p, q, r, s, t
// 输出: 将中位数放到 r 位置,较小的两个在 p,q 较大的两个在 s,t
// 返回中位数所在位置(即 r)
//
// 使用的排序网络(6次比较):
// 1. compSwap(p, q) -> a[p] <= a[q]
// 2. compSwap(s, t) -> a[s] <= a[t]
// 3. compSwap(p, s) -> a[p] 是 {p,q,s,t} 中最小的
// 同时 swap(q, t) -> 保持配对关系
// 4. compSwap(q, r) -> 处理第5个元素
// 5. compSwap(r, s) -> 缩小候选范围
// 6. compSwap(q, r) -> 最终中位数在 r
int medianOf5(std::vector<int>& a, int p, int q, int r, int s, int t) {
// 比较1: 让 a[p] <= a[q]
if (a[p] > a[q]) std::swap(a[p], a[q]);
// 比较2: 让 a[s] <= a[t]
if (a[s] > a[t]) std::swap(a[s], a[t]);
// 比较3: 比较两对中较小的: a[p] vs a[s]
// 让 a[p] <= a[s],同时交换配对元素保持关系
if (a[p] > a[s]) {
std::swap(a[p], a[s]);
std::swap(a[q], a[t]);
}
// 此时: a[p] 是最小的,可以排除
// a[p] <= a[q], a[p] <= a[s] <= a[t]
// 比较4: 把第5个元素 r 和 q 比较
if (a[q] > a[r]) std::swap(a[q], a[r]);
// a[q] <= a[r]
// 比较5: a[r] vs a[s]
if (a[r] > a[s]) std::swap(a[r], a[s]);
// a[r] <= a[s]
// 比较6: a[q] vs a[r]
if (a[q] > a[r]) std::swap(a[q], a[r]);
// 中位数现在在 r 位置
return r;
}
// 使用五取样中位数的分区函数
int partition(std::vector<int>& a, int lo, int hi) {
int n = hi - lo + 1;
if (n < 5) {
// 元素不足5个,退化为简单选基准
int v = a[lo];
int i = lo, j = hi + 1;
while (true) {
while (a[++i] < v) if (i == hi) break;
while (v < a[--j]) if (j == lo) break;
if (i >= j) break;
std::swap(a[i], a[j]);
}
std::swap(a[lo], a[j]);
return j;
}
// 随机选5个不同的位置
// 简化:选 lo, lo+n/4, lo+n/2, lo+3*n/4, hi
int p = lo;
int q = lo + n / 4;
int r = lo + n / 2;
int s = lo + 3 * n / 4;
int t = hi;
// 找中位数,结果放在 r 位置
// 同时: a[p],a[q] <= 中位数 <= a[s],a[t]
medianOf5(a, p, q, r, s, t);
// 把中位数换到 lo+2 位置作为基准
// 把两个小元素放到 lo, lo+1 (它们已经 <= 基准)
// 把两个大元素放到 hi-1, hi (它们已经 >= 基准)
// 注意: p=lo, t=hi 已经在正确位置
// a[q] 放到 lo+1
std::swap(a[lo + 1], a[q]);
// a[r] (中位数) 放到 lo+2
std::swap(a[lo + 2], a[r]);
// a[s] 放到 hi-1
std::swap(a[hi - 1], a[s]);
// 现在:
// a[lo], a[lo+1] <= 基准 = a[lo+2]
// a[hi-1], a[hi] >= 基准
int v = a[lo + 2]; // 基准
int i = lo + 2; // 从基准位置开始向右
int j = hi - 1; // 从 hi-1 开始向左
while (true) {
while (a[++i] < v) { } // a[hi-1] >= v, 右哨兵
while (a[--j] > v) { } // a[lo+1] <= v, 左哨兵
if (i >= j) break;
std::swap(a[i], a[j]);
}
// 把基准放到正确位置
std::swap(a[lo + 2], a[j]);
return j;
}
// 插入排序:用于小数组
void insertionSort(std::vector<int>& a, int lo, int hi) {
for (int i = lo + 1; i <= hi; i++) {
int key = a[i];
int j = i - 1;
while (j >= lo && a[j] > key) {
a[j + 1] = a[j];
j--;
}
a[j + 1] = key;
}
}
void quicksort(std::vector<int>& a, int lo, int hi) {
if (hi - lo + 1 <= 15) {
insertionSort(a, lo, hi);
return;
}
int j = partition(a, lo, hi);
quicksort(a, lo, j - 1);
quicksort(a, j + 1, hi);
}
void sort(std::vector<int>& a) {
int n = a.size();
if (n <= 1) return;
std::mt19937 rng(static_cast<unsigned>(time(nullptr)));
std::shuffle(a.begin(), a.end(), rng);
quicksort(a, 0, n - 1);
}
// ============= 标准快排和三取样快排(用于对比) =============
namespace StandardQS {
void sort(std::vector<int>& a) {
std::mt19937 rng(42);
std::shuffle(a.begin(), a.end(), rng);
// 简单递归快排
std::function<void(int, int)> qs = [&](int lo, int hi) {
if (hi <= lo) return;
int v = a[lo], i = lo, j = hi + 1;
while (true) {
while (a[++i] < v) if (i == hi) break;
while (v < a[--j]) if (j == lo) break;
if (i >= j) break;
std::swap(a[i], a[j]);
}
std::swap(a[lo], a[j]);
qs(lo, j - 1);
qs(j + 1, hi);
};
qs(0, a.size() - 1);
}
}
namespace Med3QS {
int med3(std::vector<int>& a, int lo, int mid, int hi) {
if (a[mid] < a[lo]) std::swap(a[lo], a[mid]);
if (a[hi] < a[lo]) std::swap(a[lo], a[hi]);
if (a[hi] < a[mid]) std::swap(a[mid], a[hi]);
return mid;
}
void sort(std::vector<int>& a) {
std::mt19937 rng(42);
std::shuffle(a.begin(), a.end(), rng);
std::function<void(int, int)> qs = [&](int lo, int hi) {
if (hi - lo + 1 <= 10) {
for (int i = lo + 1; i <= hi; i++) {
int key = a[i]; int j = i - 1;
while (j >= lo && a[j] > key) { a[j+1] = a[j]; j--; }
a[j+1] = key;
}
return;
}
int mid = lo + (hi - lo) / 2;
med3(a, lo, mid, hi);
std::swap(a[mid], a[lo + 1]);
int v = a[lo + 1], i = lo + 1, j = hi;
while (true) {
while (a[++i] < v) {}
while (a[--j] > v) {}
if (i >= j) break;
std::swap(a[i], a[j]);
}
std::swap(a[lo + 1], a[j]);
qs(lo, j - 1);
qs(j + 1, hi);
};
qs(0, a.size() - 1);
}
}
int main() {
std::cout << "=== 五取样切分 vs 三取样 vs 标准快排 倍增测试 ===" << std::endl;
std::cout << " N\t\t标准(ms)\t三取样(ms)\t五取样(ms)" << std::endl;
std::mt19937 rng(12345);
for (int n = 10000; n <= 1000000; n *= 2) {
std::vector<int> base(n);
for (int i = 0; i < n; i++) base[i] = rng() % (n * 10);
// 标准快排
std::vector<int> a1 = base;
auto t1 = std::chrono::high_resolution_clock::now();
StandardQS::sort(a1);
auto t2 = std::chrono::high_resolution_clock::now();
double stdTime = std::chrono::duration<double, std::milli>(t2 - t1).count();
// 三取样
std::vector<int> a2 = base;
auto t3 = std::chrono::high_resolution_clock::now();
Med3QS::sort(a2);
auto t4 = std::chrono::high_resolution_clock::now();
double med3Time = std::chrono::duration<double, std::milli>(t4 - t3).count();
// 五取样
std::vector<int> a3 = base;
auto t5 = std::chrono::high_resolution_clock::now();
sort(a3);
auto t6 = std::chrono::high_resolution_clock::now();
double med5Time = std::chrono::duration<double, std::milli>(t6 - t5).count();
std::cout << " " << n
<< "\t\t" << stdTime
<< "\t\t" << med3Time
<< "\t\t" << med5Time
<< std::endl;
}
// 验证正确性
std::cout << "\n=== 正确性验证 ===" << std::endl;
std::vector<int> test = {64, 25, 12, 22, 11, 90, 55, 33, 77, 44, 88, 1, 99, 7, 36};
std::cout << "排序前: ";
for (int x : test) std::cout << x << " ";
std::cout << std::endl;
sort(test);
std::cout << "排序后: ";
for (int x : test) std::cout << x << " ";
std::cout << std::endl;
bool ok = true;
for (size_t i = 1; i < test.size(); i++) {
if (test[i] < test[i - 1]) { ok = false; break; }
}
std::cout << (ok ? "验证通过!" : "排序错误!") << std::endl;
return 0;
}
三种方案对比
比较次数系数 (乘以 N lg N)
+-------+-------+-------+
| 标准 | 三取样 | 五取样 |
+---------------+-------+-------+-------+
| 比较次数 | 1.39 | 1.19 | 1.09 |
| 交换次数 | 0.23 | 0.15 | 0.12 |
| 额外开销 | 无 | 3比较 | 6比较 |
+---------------+-------+-------+-------+
各方法适用场景总结
总结
| 练习 | 核心思想 | 时间复杂度 |
|---|---|---|
| 2.3.15 螺母螺栓 | 用快排分治思想解决受限比较问题 | 期望 O(NlogN)O(N \log N)O(NlogN) |
| 2.3.16 最佳情况 | 逆向构造使每次分区对半分的数组 | O(NlogN)O(N \log N)O(NlogN) |
| 2.3.17 哨兵 | 利用天然边界消除内循环边界检查 | O(NlogN)O(N \log N)O(NlogN),常数更小 |
| 2.3.18 三取样 | 取3个元素中位数作基准,分区更均衡 | ∼1.19NlgN\sim 1.19 N \lg N∼1.19NlgN 次比较 |
| 2.3.19 五取样 | 取5个元素中位数作基准,进一步优化 | ∼1.09NlgN\sim 1.09 N \lg N∼1.09NlgN 次比较 |
快速排序的这些优化技巧在实际工程中非常重要。Java 的 Arrays.sort() (基本类型) 和 C++ 的 std::sort() 都综合使用了三取样/五取样中位数、小数组插入排序、尾递归优化等技术。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐
所有评论(0)