Python 算法教程(6):快速排序详解

前言

快速排序(Quick Sort)是 20 世纪十大算法之一,由 Tony Hoare 于 1960 年提出。它是实际应用中最常用的排序算法,Java 的 Arrays.sort()、C 的 qsort() 都基于快速排序的变体。

一、算法原理深度解析

1.1 分治思想

快速排序的核心是分治法(Divide and Conquer):

  1. 分解 - 选择基准值,将数组分为两部分
  2. 解决 - 递归排序两个子数组
  3. 合并 - 无需合并,原地排序完成

1.2 分区策略(关键!)

数组:[10, 80, 30, 90, 40, 50, 70] 基准:50

L 指针从左找大于基准的数
R 指针从右找小于基准的数
交换 L 和 R 指向的元素
重复直到 L 和 R 相遇

过程演示:
初始:[10, 80, 30, 90, 40, 50, 70]
      L↑                    R↑
L 停在 80(>50),R 停在 50(=50)
交换:[10, 50, 30, 90, 40, 80, 70]
继续:[10, 50, 30, 40, 90, 80, 70]
相遇在 90 位置

最终分区:[10, 30, 40, 50] [90, 80, 70]
          小于 50     基准    大于 50

二、多种代码实现

2.1 简洁版本(适合学习)

def quick_sort_simple(arr):
    if len(arr) <= 1:
        return arr
    pivot = arr[len(arr) // 2]
    left = [x for x in arr if x < pivot]
    middle = [x for x in arr if x == pivot]
    right = [x for x in arr if x > pivot]
    return quick_sort_simple(left) + middle + quick_sort_simple(right)

# 测试
arr = [10, 80, 30, 90, 40, 50, 70]
print(quick_sort_simple(arr))  # [10, 30, 40, 50, 70, 80, 90]

优点:代码简洁,易于理解
缺点:需要 O(n) 额外空间,不是原地排序

2.2 原地分区版本(推荐)

def quick_sort(arr):
    def partition(left, right):
        if left >= right:
            return
        pivot_idx = (left + right) // 2
        pivot = arr[pivot_idx]
        i, j = left, right
        
        while i <= j:
            while arr[i] < pivot:
                i += 1
            while arr[j] > pivot:
                j -= 1
            if i <= j:
                arr[i], arr[j] = arr[j], arr[i]
                i += 1
                j -= 1
        
        partition(left, j)
        partition(i, right)
    
    partition(0, len(arr) - 1)
    return arr

优点:原地排序,空间 O(log n)
缺点:代码稍复杂

2.3 三路快排(处理重复元素)

def quick_sort_3way(arr):
    def sort_3way(left, right):
        if left >= right:
            return
        pivot = arr[left]
        lt = left      # arr[left...lt] < pivot
        gt = right     # arr[gt...right] > pivot
        i = left + 1   # arr[lt+1...i] = pivot
        
        while i <= gt:
            if arr[i] < pivot:
                arr[lt], arr[i] = arr[i], arr[lt]
                lt += 1
                i += 1
            elif arr[i] > pivot:
                arr[i], arr[gt] = arr[gt], arr[i]
                gt -= 1
            else:
                i += 1
        
        sort_3way(left, lt - 1)
        sort_3way(gt + 1, right)
    
    sort_3way(0, len(arr) - 1)
    return arr

优点:大量重复元素时性能最优
适用场景:数组中有大量重复元素

三、复杂度分析

3.1 时间复杂度

情况 时间复杂度 说明 发生概率
最好 O(n log n) 每次平分 理想情况
平均 O(n log n) 随机数据 常见
最坏 O(n²) 已有序 选择固定基准

3.2 空间复杂度

  • 原地版本:O(log n) - 递归调用栈
  • 简洁版本:O(n) - 需要额外数组

四、性能优化策略

4.1 基准值选择优化

# 三数取中法(避免最坏情况)
def median_of_three(arr, left, right):
    mid = (left + right) // 2
    if arr[left] > arr[mid]:
        arr[left], arr[mid] = arr[mid], arr[left]
    if arr[left] > arr[right]:
        arr[left], arr[right] = arr[right], arr[left]
    if arr[mid] > arr[right]:
        arr[mid], arr[right] = arr[right], arr[mid]
    return mid  # 返回中间值的索引

4.2 小数组使用插入排序

# 当子数组长度 < 10 时,使用插入排序
def insertion_sort_small(arr, left, right):
    for i in range(left + 1, right + 1):
        key = arr[i]
        j = i - 1
        while j >= left and arr[j] > key:
            arr[j + 1] = arr[j]
            j -= 1
        arr[j + 1] = key

五、性能测试

import time
import random

def performance_test():
    sizes = [100, 1000, 5000]
    
    for size in sizes:
        arr = [random.randint(1, 10000) for _ in range(size)]
        
        start = time.time()
        quick_sort(arr.copy())
        end = time.time()
        
        print(f"数组大小:{size}, 耗时:{end-start:.4f}秒")

performance_test()
# 示例输出:
# 数组大小:100, 耗时:0.0002 秒
# 数组大小:1000, 耗时:0.0025 秒
# 数组大小:5000, 耗时:0.0142 秒

六、LeetCode 实战

6.1 第 K 个最大元素

# LeetCode 215. 数组中的第 K 个最大元素
def findKthLargest(nums, k):
    def partition(left, right):
        pivot = nums[right]
        store = left
        for i in range(left, right):
            if nums[i] > pivot:
                nums[store], nums[i] = nums[i], nums[store]
                store += 1
        nums[store], nums[right] = nums[right], nums[store]
        return store
    
    def quick_select(left, right, k_smallest):
        if left == right:
            return nums[left]
        pivot_idx = partition(left, right)
        if k_smallest == pivot_idx:
            return nums[k_smallest]
        elif k_smallest < pivot_idx:
            return quick_select(left, pivot_idx - 1, k_smallest)
        else:
            return quick_select(pivot_idx + 1, right, k_smallest)
    
    return quick_select(0, len(nums) - 1, k - 1)

# 测试
nums = [3,2,1,5,6,4]
k = 2
print(findKthLargest(nums, k))  # 输出:5

七、面试考点

7.1 高频问题

  1. 快排的时间复杂度? - 最好 O(n log n),最坏 O(n²),平均 O(n log n)
  2. 如何避免最坏情况? - 随机选择基准、三数取中
  3. 快排是稳定的吗? - 不稳定,相等元素可能交换
  4. 空间复杂度? - O(log n) 递归栈
  5. 何时退化为 O(n²)? - 已有序且选择端点为基准

7.2 与其他排序对比

排序算法    平均时间    最坏时间    空间    稳定性
快速排序    O(n log n)  O(n²)      O(log n)  不稳定
归并排序    O(n log n)  O(n log n)  O(n)     稳定
堆排序      O(n log n)  O(n log n)  O(1)     不稳定
插入排序    O(n²)       O(n²)      O(1)     稳定

选择建议:
- 一般场景:快速排序
- 要求稳定:归并排序
- 空间紧张:堆排序
- 小规模:插入排序

八、可视化代码

def quick_sort_visual(arr, depth=0):
    indent = "  " * depth
    print(f"{indent}排序:{arr}")
    
    if len(arr) <= 1:
        return arr
    
    pivot = arr[len(arr) // 2]
    left = [x for x in arr if x < pivot]
    middle = [x for x in arr if x == pivot]
    right = [x for x in arr if x > pivot]
    
    print(f"{indent}基准:{pivot}")
    print(f"{indent}左:{left}, 中:{middle}, 右:{right}")
    
    result = quick_sort_visual(left, depth + 1) + middle + quick_sort_visual(right, depth + 1)
    return result

# 测试
arr = [10, 80, 30, 90, 40, 50, 70]
quick_sort_visual(arr)

总结

快速排序是必须掌握的经典算法:

  • ✅ 理解分治思想
  • ✅ 掌握原地分区实现
  • ✅ 了解优化策略(三数取中、三路快排)
  • ✅ 能够手写代码
  • ✅ 理解时间复杂度分析

练习建议:在 LeetCode 上刷 5-10 道快速排序相关题目,如第 K 大元素、颜色分类等。

下一篇:堆排序 - 利用堆数据结构的排序算法,适合 TopK 问题。

发表回复

后才能评论