From 8fd23971776461ad26cf461fdd3d22f62ed2fec4 Mon Sep 17 00:00:00 2001 From: Kitsu Date: Fri, 7 Oct 2022 10:02:17 +0300 Subject: [PATCH 1/2] Fix take_array unsoundness --- src/drop.rs | 6 +++--- src/from.rs | 36 ++++++++++++++++++++++++++++++------ src/lib.rs | 34 +++++++++++++++++++++++++++------- 3 files changed, 60 insertions(+), 16 deletions(-) diff --git a/src/drop.rs b/src/drop.rs index 1f4775b..f7fa343 100644 --- a/src/drop.rs +++ b/src/drop.rs @@ -89,7 +89,7 @@ mod tests { buf.push(CounterGuard::new(&mut cnt)); assert_eq!(cnt, 3); - let arr: [_; 3] = buf.into(); + let arr: [CounterGuard; 3] = buf.try_into().unwrap(); assert_eq!(cnt, 3); std::mem::drop(arr); @@ -112,11 +112,11 @@ mod tests { buf.push(CounterGuard::new(&mut cnt)); assert_eq!(cnt, 3); - let arr = buf.take_array(); + let arr = unsafe { buf.take_array() }; assert_eq!(buf.len(), 0); assert_eq!(cnt, 3); std::mem::drop(arr); assert_eq!(cnt, 0); } -} \ No newline at end of file +} diff --git a/src/from.rs b/src/from.rs index 60aa715..e56bd54 100644 --- a/src/from.rs +++ b/src/from.rs @@ -1,9 +1,33 @@ -use std::convert::From; use crate::LocalVec; +use std::convert::{From, TryFrom}; +use std::mem::MaybeUninit; -impl From> for [T; N] { +impl From> for [MaybeUninit; N] { fn from(mut local_vec: LocalVec) -> Self { - local_vec.take_array() + local_vec.take_uninit_array() + } +} + +#[derive(Debug)] +pub struct NotFull; + +impl std::fmt::Display for NotFull { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("array is not full") + } +} + +impl std::error::Error for NotFull {} + +impl TryFrom> for [T; N] { + type Error = NotFull; + + fn try_from(mut local_vec: LocalVec) -> Result { + if !local_vec.is_full() { + return Err(NotFull); + } + // Safety: checked for is_full before. + Ok(unsafe { local_vec.take_array() }) } } @@ -13,16 +37,16 @@ mod test { #[test] fn test_into_array() { - let mut vec = LocalVec::<_, 4>::new(); + let mut vec = LocalVec::::new(); vec.push(0); vec.push(1); vec.push(2); vec.push(3); - let arr: [_; 4] = vec.into(); + let arr: [u32; 4] = vec.try_into().unwrap(); assert_eq!(arr[0], 0); assert_eq!(arr[1], 1); assert_eq!(arr[2], 2); assert_eq!(arr[3], 3); } -} \ No newline at end of file +} diff --git a/src/lib.rs b/src/lib.rs index b1a10c9..74bde22 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -127,13 +127,23 @@ impl LocalVecImpl { } #[must_use = "consider using clear() instead"] - /// steal the elements stored - pub fn take_array(&mut self) -> [T; N] { - let arr: [T; N] = unsafe { + /// Takes possibly uninitialized elements. + pub fn take_uninit_array(&mut self) -> [MaybeUninit; N] { + unsafe { self.set_len(0); - std::mem::transmute_copy(&self.buf) - }; - arr + std::ptr::read(&self.buf) + } + } + + #[must_use = "consider using clear() instead"] + /// Takes the stored elements. + /// + /// # Safety + /// The container must be full. + pub unsafe fn take_array(&mut self) -> [T; N] { + debug_assert!(self.is_full()); + let prev = std::mem::replace(self, LocalVecImpl::new()); + std::mem::transmute_copy(&prev.buf) } #[inline] @@ -312,10 +322,20 @@ mod tests { #[test] fn test_take_array() { + let arr = [2; 4]; + let mut vec = LocalVecImpl::<_, 4>::from_array(arr); + assert_eq!(vec.len(), 4); + let taken = unsafe { vec.take_array() }; + assert_eq!(arr, taken); + assert_eq!(vec.len(), 0); + } + + #[test] + fn test_take_uninit_array() { let arr = [7; 4]; let mut vec = LocalVecImpl::<_, 6>::from_array(arr); assert_eq!(vec.len(), 4); - let _ = vec.take_array(); + let _ = vec.take_uninit_array(); assert_eq!(vec.len(), 0); } From 6b5e1661dbf39dcde4cba4762e910f316e867de5 Mon Sep 17 00:00:00 2001 From: Kitsu Date: Fri, 7 Oct 2022 11:05:10 +0300 Subject: [PATCH 2/2] Use Rc for test_drop_* counter --- src/drop.rs | 101 +++++++++++++++++++++++++--------------------------- 1 file changed, 49 insertions(+), 52 deletions(-) diff --git a/src/drop.rs b/src/drop.rs index f7fa343..2bd5858 100644 --- a/src/drop.rs +++ b/src/drop.rs @@ -10,113 +10,110 @@ impl Drop for LocalVec { #[cfg(test)] mod tests { - struct CounterGuard(*mut u8); + use crate::LocalVec; + use std::rc::Rc; + + #[derive(Clone)] + struct Counter(Rc<()>); - impl<'a> CounterGuard { - pub fn new(cnt: &'a mut u8) -> CounterGuard { - *cnt += 1; - CounterGuard(cnt as *mut u8) + impl Counter { + fn new() -> Self { + Self(Rc::new(())) } - } - impl<'a> Drop for CounterGuard { - fn drop(&mut self) { - unsafe { - *self.0 -= 1; - } + fn count(&self) -> usize { + Rc::strong_count(&self.0) } } - use crate::LocalVec; - #[test] fn test_drop() { - let mut cnt = 0u8; + let cnt = Counter::new(); let mut buf = LocalVec::<_, 3>::new(); - assert_eq!(cnt, 0); + assert_eq!(cnt.count(), 1); - buf.push(CounterGuard::new(&mut cnt)); - assert_eq!(cnt, 1); + buf.push(cnt.clone()); + assert_eq!(cnt.count(), 2); - buf.push(CounterGuard::new(&mut cnt)); - assert_eq!(cnt, 2); + buf.push(cnt.clone()); + assert_eq!(cnt.count(), 3); - buf.push(CounterGuard::new(&mut cnt)); - assert_eq!(cnt, 3); + buf.push(cnt.clone()); + assert_eq!(cnt.count(), 4); std::mem::drop(buf); - assert_eq!(cnt, 0); + assert_eq!(cnt.count(), 1); } #[test] fn test_drop_after_set_len() { - let mut cnt = 0u8; + let cnt = Counter::new(); let mut buf = LocalVec::<_, 3>::new(); - assert_eq!(cnt, 0); + assert_eq!(cnt.count(), 1); - buf.push(CounterGuard::new(&mut cnt)); - assert_eq!(cnt, 1); + buf.push(cnt.clone()); + assert_eq!(cnt.count(), 2); - buf.push(CounterGuard::new(&mut cnt)); - assert_eq!(cnt, 2); + buf.push(cnt.clone()); + assert_eq!(cnt.count(), 3); - buf.push(CounterGuard::new(&mut cnt)); - assert_eq!(cnt, 3); + buf.push(cnt.clone()); + assert_eq!(cnt.count(), 4); unsafe { buf.set_len(1); } std::mem::drop(buf); - assert_eq!(cnt, 2); + assert_eq!(cnt.count(), 3); } #[test] fn test_drop_after_into_array() { - let mut cnt = 0u8; + let cnt = Counter::new(); let mut buf = LocalVec::<_, 3>::new(); - assert_eq!(cnt, 0); + assert_eq!(cnt.count(), 1); - buf.push(CounterGuard::new(&mut cnt)); - assert_eq!(cnt, 1); + buf.push(cnt.clone()); + assert_eq!(cnt.count(), 2); - buf.push(CounterGuard::new(&mut cnt)); - assert_eq!(cnt, 2); + buf.push(cnt.clone()); + assert_eq!(cnt.count(), 3); - buf.push(CounterGuard::new(&mut cnt)); - assert_eq!(cnt, 3); + buf.push(cnt.clone()); + assert_eq!(cnt.count(), 4); - let arr: [CounterGuard; 3] = buf.try_into().unwrap(); - assert_eq!(cnt, 3); + let arr: [Counter; 3] = buf.try_into().unwrap(); + assert_eq!(cnt.count(), 4); std::mem::drop(arr); - assert_eq!(cnt, 0); + assert_eq!(cnt.count(), 1); } #[test] fn test_drop_after_take_array() { - let mut cnt = 0u8; + let cnt = Counter::new(); let mut buf = LocalVec::<_, 3>::new(); - assert_eq!(cnt, 0); + assert_eq!(cnt.count(), 1); - buf.push(CounterGuard::new(&mut cnt)); - assert_eq!(cnt, 1); + buf.push(cnt.clone()); + assert_eq!(cnt.count(), 2); - buf.push(CounterGuard::new(&mut cnt)); - assert_eq!(cnt, 2); + buf.push(cnt.clone()); + assert_eq!(cnt.count(), 3); - buf.push(CounterGuard::new(&mut cnt)); - assert_eq!(cnt, 3); + buf.push(cnt.clone()); + assert_eq!(cnt.count(), 4); let arr = unsafe { buf.take_array() }; assert_eq!(buf.len(), 0); - assert_eq!(cnt, 3); + assert_eq!(cnt.count(), 4); std::mem::drop(arr); - assert_eq!(cnt, 0); + assert_eq!(cnt.count(), 1); } }