diff --git a/library/core/src/array/drain.rs b/library/core/src/array/drain.rs index 1c6137191324c..f73fb6b114beb 100644 --- a/library/core/src/array/drain.rs +++ b/library/core/src/array/drain.rs @@ -1,6 +1,6 @@ use crate::marker::{Destruct, PhantomData}; use crate::mem::{ManuallyDrop, SizedTypeProperties, conjure_zst}; -use crate::ptr::{NonNull, drop_in_place, from_raw_parts_mut, null_mut}; +use crate::ptr::{drop_in_place, slice_from_raw_parts_mut}; impl<'l, 'f, T, U, const N: usize, F: FnMut(T) -> U> Drain<'l, 'f, T, N, F> { /// This function returns a function that lets you index the given array in const. @@ -18,15 +18,8 @@ impl<'l, 'f, T, U, const N: usize, F: FnMut(T) -> U> Drain<'l, 'f, T, N, F> { #[rustc_const_unstable(feature = "array_try_map", issue = "79711")] pub(super) const unsafe fn new(array: &'l mut ManuallyDrop<[T; N]>, f: &'f mut F) -> Self { // dont drop the array, transfers "ownership" to Self - let ptr: NonNull = NonNull::from_mut(array).cast(); - // SAFETY: - // Adding `slice.len()` to the starting pointer gives a pointer - // at the end of `slice`. `end` will never be dereferenced, only checked - // for direct pointer equality with `ptr` to check if the drainer is done. - unsafe { - let end = if T::IS_ZST { null_mut() } else { ptr.as_ptr().add(N) }; - Self { ptr, end, f, l: PhantomData } - } + let end = array.as_mut_ptr_range().end; + Self { end, remaining: N, f, l: PhantomData } } } @@ -35,20 +28,26 @@ impl<'l, 'f, T, U, const N: usize, F: FnMut(T) -> U> Drain<'l, 'f, T, N, F> { #[unstable(feature = "array_try_map", issue = "79711")] pub(super) struct Drain<'l, 'f, T, const N: usize, F> { // FIXME(const-hack): This is essentially a slice::IterMut<'static>, replace when possible. - /// The pointer to the next element to return, or the past-the-end location - /// if the drainer is empty. - /// - /// This address will be used for all ZST elements, never changed. + /// Pointer to the past-the-end element. /// As we "own" this array, we dont need to store any lifetime. - ptr: NonNull, - /// For non-ZSTs, the non-null pointer to the past-the-end element. - /// For ZSTs, this is null. end: *mut T, + /// The number of elements still to be drained. + remaining: usize, f: &'f mut F, l: PhantomData<&'l mut [T; N]>, } +impl Drain<'_, '_, T, N, F> { + /// Returns a pointer to the next element to be drained, or the past-the-end element if there + /// are no remaining elements to be drained. + const fn ptr(&mut self) -> *mut T { + // SAFETY: By the type invariants, self.remaining is always the number of elements prior to + // self.end that are still to be drained. + unsafe { self.end.sub(self.remaining) } + } +} + #[rustc_const_unstable(feature = "array_try_map", issue = "79711")] #[unstable(feature = "array_try_map", issue = "79711")] impl const FnOnce<(usize,)> for &mut Drain<'_, '_, T, N, F> @@ -73,15 +72,14 @@ where &mut self, (_ /* ignore argument */,): (usize,), ) -> Self::Output { + let p = self.ptr(); + // decrement before moving; if `f` panics, we drop the rest. + self.remaining -= 1; if T::IS_ZST { // its UB to call this more than N times, so returning more ZSTs is valid. // SAFETY: its a ZST? we conjur. (self.f)(unsafe { conjure_zst::() }) } else { - // increment before moving; if `f` panics, we drop the rest. - let p = self.ptr; - // SAFETY: caller guarantees never called more than N times (see `Drain::new`) - self.ptr = unsafe { self.ptr.add(1) }; // SAFETY: we are allowed to move this. (self.f)(unsafe { p.read() }) } @@ -91,18 +89,9 @@ where #[unstable(feature = "array_try_map", issue = "79711")] impl const Drop for Drain<'_, '_, T, N, F> { fn drop(&mut self) { - if !T::IS_ZST { - // SAFETY: we cant read more than N elements - let slice = unsafe { - from_raw_parts_mut::<[T]>( - self.ptr.as_ptr(), - // SAFETY: `start <= end` - self.end.offset_from_unsigned(self.ptr.as_ptr()), - ) - }; + let slice = slice_from_raw_parts_mut(self.ptr(), self.remaining); - // SAFETY: By the type invariant, we're allowed to drop all these. (we own it, after all) - unsafe { drop_in_place(slice) } - } + // SAFETY: By the type invariant, we're allowed to drop all these. (we own it, after all) + unsafe { drop_in_place(slice) } } } diff --git a/library/coretests/tests/array.rs b/library/coretests/tests/array.rs index 2b4429092e98b..86c8128d214e2 100644 --- a/library/coretests/tests/array.rs +++ b/library/coretests/tests/array.rs @@ -313,26 +313,34 @@ fn array_map() { #[test] #[cfg_attr(not(panic = "unwind"), ignore = "test requires unwinding support")] fn array_map_drop_safety() { - static DROPPED: AtomicUsize = AtomicUsize::new(0); - struct DropCounter; - impl Drop for DropCounter { + static OLD_DROPPED: AtomicUsize = AtomicUsize::new(0); + static NEW_DROPPED: AtomicUsize = AtomicUsize::new(0); + struct OldDropCounter; + struct NewDropCounter; + impl Drop for OldDropCounter { fn drop(&mut self) { - DROPPED.fetch_add(1, Ordering::SeqCst); + OLD_DROPPED.fetch_add(1, Ordering::SeqCst); + } + } + impl Drop for NewDropCounter { + fn drop(&mut self) { + NEW_DROPPED.fetch_add(1, Ordering::SeqCst); } } let num_to_create = 5; let success = std::panic::catch_unwind(|| { - let items = [0; 10]; + let items = [const { OldDropCounter }; 8]; let mut nth = 0; let _ = items.map(|_| { assert!(nth < num_to_create); nth += 1; - DropCounter + NewDropCounter }); }); assert!(success.is_err()); - assert_eq!(DROPPED.load(Ordering::SeqCst), num_to_create); + assert_eq!(OLD_DROPPED.load(Ordering::SeqCst), 8); + assert_eq!(NEW_DROPPED.load(Ordering::SeqCst), num_to_create); } #[test]