// core/slice/sort/stable/quicksort.rs
//! This module contains a stable quicksort and partition implementation.

3use crate::mem::MaybeUninit;
4use crate::slice::sort::shared::FreezeMarker;
5use crate::slice::sort::shared::pivot::choose_pivot;
6use crate::slice::sort::shared::smallsort::StableSmallSortTypeImpl;
7use crate::{intrinsics, ptr};
8
/// Sorts `v` recursively using quicksort.
/// `scratch.len()` must be at least `max(v.len() - v.len() / 2, SMALL_SORT_GENERAL_SCRATCH_LEN)`
/// otherwise the implementation may abort.
///
/// `limit` when initialized with `c*log(v.len())` for some c ensures we do not
/// overflow the stack or go quadratic.
#[inline(never)]
pub fn quicksort<T, F: FnMut(&T, &T) -> bool>(
    mut v: &mut [T],
    scratch: &mut [MaybeUninit<T>],
    mut limit: u32,
    mut left_ancestor_pivot: Option<&T>,
    is_less: &mut F,
) {
    // Only the right side of each partition is handled by actual recursion;
    // the left side re-enters this loop with `v` shrunk, bounding stack depth.
    loop {
        let len = v.len();

        // Base case: below the threshold, hand off to the dedicated small-sort.
        if len <= T::small_sort_threshold() {
            T::small_sort(v, scratch, is_less);
            return;
        }

        if limit == 0 {
            // We have had too many bad pivots, switch to O(n log n) fallback
            // algorithm. In our case that is driftsort in eager mode.
            crate::slice::sort::stable::drift::sort(v, scratch, true, is_less);
            return;
        }
        limit -= 1;

        let pivot_pos = choose_pivot(v, is_less);

        // SAFETY: We only access the temporary copy for Freeze types, otherwise
        // self-modifications via `is_less` would not be observed and this would
        // be unsound. Our temporary copy does not escape this scope.
        // We use `MaybeUninit` to avoid re-tag issues. FIXME: use `MaybeDangling`.
        let pivot_copy = unsafe { ptr::read((&raw const v[pivot_pos]).cast::<MaybeUninit<T>>()) };
        let pivot_ref =
            // SAFETY: We created the value in an init state.
            (!has_direct_interior_mutability::<T>()).then_some(unsafe { &*pivot_copy.as_ptr() });

        // We choose a pivot, and check if this pivot is equal to our left
        // ancestor. If true, we do a partition putting equal elements on the
        // left and do not recurse on it. This gives O(n log k) sorting for k
        // distinct values, a strategy borrowed from pdqsort. For types with
        // interior mutability we can't soundly create a temporary copy of the
        // ancestor pivot, and use left_partition_len == 0 as our method for
        // detecting when we re-use a pivot, which means we do at most three
        // partition operations with pivot p instead of the optimal two.
        let mut perform_equal_partition = false;
        if let Some(la_pivot) = left_ancestor_pivot {
            // `!is_less(la_pivot, x)` means `x <= la_pivot`; combined with the
            // partition invariant this detects a pivot equal to the ancestor.
            perform_equal_partition = !is_less(la_pivot, &v[pivot_pos]);
        }

        let mut left_partition_len = 0;
        if !perform_equal_partition {
            left_partition_len = stable_partition(v, scratch, pivot_pos, false, is_less);
            // An empty left side means every element is >= pivot; retry with
            // an equal-partition so progress is guaranteed.
            perform_equal_partition = left_partition_len == 0;
        }

        if perform_equal_partition {
            // `!is_less(b, a)` turns the strict `<` into `<=`, so elements
            // equal to the pivot land in the left partition, which is already
            // in its final position and is skipped below.
            let mid_eq = stable_partition(v, scratch, pivot_pos, true, &mut |a, b| !is_less(b, a));
            v = &mut v[mid_eq..];
            left_ancestor_pivot = None;
            continue;
        }

        // Process left side with the next loop iter, right side with recursion.
        let (left, right) = v.split_at_mut(left_partition_len);
        quicksort(right, scratch, limit, pivot_ref, is_less);
        v = left;
    }
}
82
/// Partitions `v` using pivot `p = v[pivot_pos]` and returns the number of
/// elements less than `p`. The relative order of elements that compare < p and
/// those that compare >= p is preserved - it is a stable partition.
///
/// If `is_less` is not a strict total order or panics, `scratch.len() < v.len()`,
/// or `pivot_pos >= v.len()`, the result and `v`'s state is sound but unspecified.
fn stable_partition<T, F: FnMut(&T, &T) -> bool>(
    v: &mut [T],
    scratch: &mut [MaybeUninit<T>],
    pivot_pos: usize,
    pivot_goes_left: bool,
    is_less: &mut F,
) -> usize {
    let len = v.len();

    // Abort rather than risk out-of-bounds writes below; both conditions are
    // caller bugs, not recoverable states.
    if intrinsics::unlikely(scratch.len() < len || pivot_pos >= len) {
        core::intrinsics::abort()
    }

    let v_base = v.as_ptr();
    let scratch_base = scratch.as_mut_ptr().cast_init();

    // The core idea is to write the values that compare as less-than to the left
    // side of `scratch`, while the values that compared as greater or equal than
    // `v[pivot_pos]` go to the right side of `scratch` in reverse. See
    // PartitionState for details.

    // SAFETY: see individual comments.
    unsafe {
        // SAFETY: we made sure the scratch has length >= len and that pivot_pos
        // is in-bounds. v and scratch are disjoint slices.
        let pivot = v_base.add(pivot_pos);
        let mut state = PartitionState::new(v_base, scratch_base, len);

        let mut pivot_in_scratch = ptr::null_mut();
        let mut loop_end_pos = pivot_pos;

        // The outer loop runs at most twice: once up to pivot_pos, then (after
        // placing the pivot without comparing it to itself) up to len.
        // SAFETY: this loop is equivalent to calling state.partition_one
        // exactly len times.
        loop {
            // Ideally the outer loop won't be unrolled, to save binary size,
            // but we do want the inner loop to be unrolled for small types, as
            // this gave significant performance boosts in benchmarks. Unrolling
            // through for _ in 0..UNROLL_LEN { .. } instead of manually improves
            // compile times but has a ~10-20% performance penalty on opt-level=s.
            if const { size_of::<T>() <= 16 } {
                const UNROLL_LEN: usize = 4;
                let unroll_end = v_base.add(loop_end_pos.saturating_sub(UNROLL_LEN - 1));
                while state.scan < unroll_end {
                    state.partition_one(is_less(&*state.scan, &*pivot));
                    state.partition_one(is_less(&*state.scan, &*pivot));
                    state.partition_one(is_less(&*state.scan, &*pivot));
                    state.partition_one(is_less(&*state.scan, &*pivot));
                }
            }

            // Tail loop: handles the remaining < UNROLL_LEN elements (or all
            // of them for large T).
            let loop_end = v_base.add(loop_end_pos);
            while state.scan < loop_end {
                state.partition_one(is_less(&*state.scan, &*pivot));
            }

            if loop_end_pos == len {
                break;
            }

            // We avoid comparing pivot with itself, as this could create deadlocks for
            // certain comparison operators. We also store its location for later.
            pivot_in_scratch = state.partition_one(pivot_goes_left);

            loop_end_pos = len;
        }

        // `pivot` must be copied into its correct position again, because a
        // comparison operator might have modified it.
        if has_direct_interior_mutability::<T>() {
            ptr::copy_nonoverlapping(pivot, pivot_in_scratch, 1);
        }

        // SAFETY: partition_one being called exactly len times guarantees that scratch
        // is initialized with a permuted copy of `v`, and that num_left <= v.len().
        // Copying scratch[0..num_left] and scratch[num_left..v.len()] back is thus
        // sound, as the values in scratch will never be read again, meaning our copies
        // semantically act as moves, permuting `v`.

        // Copy all the elements < p directly from swap to v.
        let v_base = v.as_mut_ptr();
        ptr::copy_nonoverlapping(scratch_base, v_base, state.num_left);

        // Copy the elements >= p in reverse order.
        for i in 0..len - state.num_left {
            ptr::copy_nonoverlapping(
                scratch_base.add(len - 1 - i),
                v_base.add(state.num_left + i),
                1,
            );
        }

        state.num_left
    }
}
183
/// Bookkeeping for the out-of-place partition in `stable_partition`:
/// elements that compare less-than grow from the front of the scratch buffer
/// (`scratch_base + num_left`), the rest grow from the back in reverse order
/// (`scratch_rev`).
struct PartitionState<T> {
    // The start of the scratch auxiliary memory.
    scratch_base: *mut T,
    // The current element that is being looked at, scans left to right through slice.
    scan: *const T,
    // Counts the number of elements that went to the left side, also works around:
    // https://github.com/rust-lang/rust/issues/117128
    num_left: usize,
    // Reverse scratch output pointer.
    scratch_rev: *mut T,
}
195
impl<T> PartitionState<T> {
    /// Creates the partition state: scanning starts at `scan`, the forward
    /// output at `scratch`, and the reverse output one past its end.
    ///
    /// # Safety
    ///
    /// `scan` and `scratch` must point to valid disjoint buffers of length `len`. The
    /// scan buffer must be initialized.
    unsafe fn new(scan: *const T, scratch: *mut T, len: usize) -> Self {
        // SAFETY: See function safety comment.
        unsafe { Self { scratch_base: scratch, scan, num_left: 0, scratch_rev: scratch.add(len) } }
    }

    /// Depending on the value of `towards_left` this function will write a value
    /// to the growing left or right side of the scratch memory. This forms the
    /// branchless core of the partition. Returns the scratch location the
    /// scanned element was written to.
    ///
    /// # Safety
    ///
    /// This function may be called at most `len` times. If it is called exactly
    /// `len` times the scratch buffer then contains a copy of each element from
    /// the scan buffer exactly once - a permutation, and num_left <= len.
    unsafe fn partition_one(&mut self, towards_left: bool) -> *mut T {
        // SAFETY: see individual comments.
        unsafe {
            // SAFETY: in-bounds because this function is called at most len times, and thus
            // right now is incremented at most len - 1 times. Similarly, num_left < len and
            // num_right < len, where num_right == i - num_left at the start of the ith
            // iteration (zero-indexed).
            self.scratch_rev = self.scratch_rev.sub(1);

            // SAFETY: now we have scratch_rev == base + len - (i + 1). This means
            // scratch_rev + num_left == base + len - 1 - num_right < base + len.
            // Selecting the destination base by `towards_left` keeps this branchless.
            let dst_base = if towards_left { self.scratch_base } else { self.scratch_rev };
            let dst = dst_base.add(self.num_left);
            ptr::copy_nonoverlapping(self.scan, dst, 1);

            // `num_left` only advances for left-bound elements; `scan` always does.
            self.num_left += towards_left as usize;
            self.scan = self.scan.add(1);
            dst
        }
    }
}
236
/// Specialization-based probe for whether a type is `Freeze` (no direct
/// interior mutability); consumed by `has_direct_interior_mutability`.
/// NOTE(review): relies on the unstable `specialization` feature via the
/// `default fn` impl below — core-internal only.
trait IsFreeze {
    fn is_freeze() -> bool;
}
240
impl<T> IsFreeze for T {
    /// Conservative blanket default: treat every type as potentially having
    /// interior mutability unless the specialized impl below overrides this.
    default fn is_freeze() -> bool {
        false
    }
}
impl<T: FreezeMarker> IsFreeze for T {
    /// Specialized override: `FreezeMarker` types are known to be free of
    /// direct interior mutability.
    fn is_freeze() -> bool {
        true
    }
}
251
252#[must_use]
253fn has_direct_interior_mutability<T>() -> bool {
254    // If a type has interior mutability it may alter itself during comparison
255    // in a way that must be preserved after the sort operation concludes.
256    // Otherwise a type like Mutex<Option<Box<str>>> could lead to double free.
257    !T::is_freeze()
258}