diff --git a/src/Windows/libraries/formatters/include/m/formatters/FILETIME.h b/src/Windows/libraries/formatters/include/m/formatters/FILETIME.h index c86afcf5..4f1a5ef8 100644 --- a/src/Windows/libraries/formatters/include/m/formatters/FILETIME.h +++ b/src/Windows/libraries/formatters/include/m/formatters/FILETIME.h @@ -102,8 +102,6 @@ struct std::formatter st.wDay, st.wHour, st.wMinute, - st.wHour, - st.wMinute, st.wSecond, st.wMilliseconds); else @@ -116,8 +114,6 @@ struct std::formatter st.wDay, st.wHour, st.wMinute, - st.wHour, - st.wMinute, st.wSecond, st.wMilliseconds); } diff --git a/src/Windows/libraries/formatters/test/test_FILETIME.cpp b/src/Windows/libraries/formatters/test/test_FILETIME.cpp index 73a5a83b..f419b24a 100644 --- a/src/Windows/libraries/formatters/test/test_FILETIME.cpp +++ b/src/Windows/libraries/formatters/test/test_FILETIME.cpp @@ -22,7 +22,7 @@ constexpr FILETIME negative_ft{0x0, 0x81e00000}; TEST(FILETIME, lbasic) { auto s = std::format(L"{}", fmtFILETIME(ft1)); - EXPECT_EQ(s, L"{ Tu 2029-02-20 23:41:23.041 }"s); + EXPECT_EQ(s, L"{ Tu 2029-02-20 23:41:22.111 }"s); } TEST(FILETIME, lnegative) @@ -34,7 +34,7 @@ TEST(FILETIME, lnegative) TEST(FILETIME, basic) { auto s = std::format("{}", fmtFILETIME(ft1)); - EXPECT_EQ(s, "{ Tu 2029-02-20 23:41:23.041 }"s); + EXPECT_EQ(s, "{ Tu 2029-02-20 23:41:22.111 }"s); } TEST(FILETIME, negative) diff --git a/src/Windows/libraries/formatters/test/test_FILE_NOTIFY_EXTENDED_INFORMATION.cpp b/src/Windows/libraries/formatters/test/test_FILE_NOTIFY_EXTENDED_INFORMATION.cpp index 0e457163..02687c44 100644 --- a/src/Windows/libraries/formatters/test/test_FILE_NOTIFY_EXTENDED_INFORMATION.cpp +++ b/src/Windows/libraries/formatters/test/test_FILE_NOTIFY_EXTENDED_INFORMATION.cpp @@ -108,5 +108,5 @@ TEST(FILE_NOTIFY_EXTENDED_INFORMATION, first) auto s = std::format(L"{}", *p); EXPECT_EQ( s, - L"{ NextEntryOffset: 0, Action: FILE_ACTION_ADDED, CreationTime: { Su 2043-05-31 11:40:11.040 }, LastModificationTime: { Su 2043-05-31 11:40:11.040 }, LastChangeTime: { Su 2043-05-31 11:40:11.040 }, LastAccessTime: { Mo 1601-01-01 00:00:00.000 }, AllocatedLength: 0, FileSize: 0, FileAttributes: 0, ReparsePointTag: 0, FileId: 0, ParentFileId: 0, FileName: \"README.TXT\" }"s); + L"{ NextEntryOffset: 0, Action: FILE_ACTION_ADDED, CreationTime: { Su 2043-05-31 11:40:44.848 }, LastModificationTime: { Su 2043-05-31 11:40:44.848 }, LastChangeTime: { Su 2043-05-31 11:40:44.848 }, LastAccessTime: { Mo 1601-01-01 00:00:00.000 }, AllocatedLength: 0, FileSize: 0, FileAttributes: 0, ReparsePointTag: 0, FileId: 0, ParentFileId: 0, FileName: \"README.TXT\" }"s); } diff --git a/src/include/m/utility/string_inserter.h b/src/include/m/utility/string_inserter.h index f4d0265f..b24203f4 100644 --- a/src/include/m/utility/string_inserter.h +++ b/src/include/m/utility/string_inserter.h @@ -10,6 +10,7 @@ #include #include +#include #include namespace m @@ -72,8 +73,11 @@ namespace m [[nodiscard]] constexpr value_type& operator*() noexcept { + // m::math::add fail-stops on overflow rather than wrapping m_index + 1 to a + // small size: a wrap would shrink the string and leave the subsequent index + // out of bounds (a wild write). if (m_string->size() <= m_index) - m_string->resize(m_index + 1); + m_string->resize(m::math::add(m_index, size_type{1}, size_type{})); return (*m_string)[m_index]; } diff --git a/src/include/m/utility/stringish.h b/src/include/m/utility/stringish.h index 895ce482..335caedf 100644 --- a/src/include/m/utility/stringish.h +++ b/src/include/m/utility/stringish.h @@ -18,38 +18,44 @@ namespace m requires(character) class basic_sstring; + // These predicates use std::decay_t (not remove_cvref_t) on the queried type so that a + // string-literal argument, whose deduced type is an array (e.g. char const[N]) when bound + // to a forwarding reference, decays to its pointer form (char const*) before being matched. + // For every other accepted type decay_t is equivalent to remove_cvref_t, so this only adds + // the array forms (mapped to their already-accepted pointer types); it never accepts + // anything new beyond that. template concept stringish = - std::same_as, std::basic_string> || - std::same_as, std::basic_string_view> || - std::same_as, m::basic_sstring> || - std::same_as, TChar*> || std::same_as, TChar const*>; + std::same_as, std::basic_string> || + std::same_as, std::basic_string_view> || + std::same_as, m::basic_sstring> || + std::same_as, TChar*> || std::same_as, TChar const*>; template concept any_stringish = - std::same_as, std::basic_string> || - std::same_as, std::basic_string> || - std::same_as, std::basic_string> || - std::same_as, std::basic_string> || - std::same_as, std::basic_string> || - std::same_as, std::basic_string_view> || - std::same_as, std::basic_string_view> || - std::same_as, std::basic_string_view> || - std::same_as, std::basic_string_view> || - std::same_as, std::basic_string_view> || - std::same_as, m::basic_sstring> || - std::same_as, m::basic_sstring> || - std::same_as, m::basic_sstring> || - std::same_as, m::basic_sstring> || - std::same_as, m::basic_sstring> || - std::same_as, char*> || std::same_as, wchar_t*> || - std::same_as, char8_t*> || std::same_as, char16_t*> || - std::same_as, char32_t*> || - std::same_as, char const*> || - std::same_as, wchar_t const*> || - std::same_as, char8_t const*> || - std::same_as, char16_t const*> || - std::same_as, char32_t const*>; + std::same_as, std::basic_string> || + std::same_as, std::basic_string> || + std::same_as, std::basic_string> || + std::same_as, std::basic_string> || + std::same_as, std::basic_string> || + std::same_as, std::basic_string_view> || + std::same_as, std::basic_string_view> || + std::same_as, std::basic_string_view> || + std::same_as, std::basic_string_view> || + std::same_as, std::basic_string_view> || + std::same_as, m::basic_sstring> || + std::same_as, m::basic_sstring> || + std::same_as, m::basic_sstring> || + std::same_as, m::basic_sstring> || + std::same_as, m::basic_sstring> || + std::same_as, char*> || std::same_as, wchar_t*> || + std::same_as, char8_t*> || std::same_as, char16_t*> || + std::same_as, char32_t*> || + std::same_as, char const*> || + std::same_as, wchar_t const*> || + std::same_as, char8_t const*> || + std::same_as, char16_t const*> || + std::same_as, char32_t const*>; template struct stringish_char_type; @@ -57,11 +63,11 @@ namespace m template struct stringish_char_type< T, - std::enable_if_t, std::basic_string> || - std::same_as, std::basic_string_view> || - std::same_as, m::basic_sstring> || - std::same_as, char*> || - std::same_as, char const*>>> + std::enable_if_t, std::basic_string> || + std::same_as, std::basic_string_view> || + std::same_as, m::basic_sstring> || + std::same_as, char*> || + std::same_as, char const*>>> { using type = char; }; @@ -69,11 +75,11 @@ namespace m template struct stringish_char_type< T, - std::enable_if_t, std::basic_string> || - std::same_as, std::basic_string_view> || - std::same_as, m::basic_sstring> || - std::same_as, wchar_t*> || - std::same_as, wchar_t const*>>> + std::enable_if_t, std::basic_string> || + std::same_as, std::basic_string_view> || + std::same_as, m::basic_sstring> || + std::same_as, wchar_t*> || + std::same_as, wchar_t const*>>> { using type = wchar_t; }; @@ -81,11 +87,11 @@ namespace m template struct stringish_char_type< T, - std::enable_if_t, std::basic_string> || - std::same_as, std::basic_string_view> || - std::same_as, m::basic_sstring> || - std::same_as, char8_t*> || - std::same_as, char8_t const*>>> + std::enable_if_t, std::basic_string> || + std::same_as, std::basic_string_view> || + std::same_as, m::basic_sstring> || + std::same_as, char8_t*> || + std::same_as, char8_t const*>>> { using type = char8_t; }; @@ -93,11 +99,11 @@ namespace m template struct stringish_char_type< T, - std::enable_if_t, std::basic_string> || - std::same_as, std::basic_string_view> || - std::same_as, m::basic_sstring> || - std::same_as, char16_t*> || - std::same_as, char16_t const*>>> + std::enable_if_t, std::basic_string> || + std::same_as, std::basic_string_view> || + std::same_as, m::basic_sstring> || + std::same_as, char16_t*> || + std::same_as, char16_t const*>>> { using type = char16_t; }; @@ -105,11 +111,11 @@ namespace m template struct stringish_char_type< T, - std::enable_if_t, std::basic_string> || - std::same_as, std::basic_string_view> || - std::same_as, m::basic_sstring> || - std::same_as, char32_t*> || - std::same_as, char32_t const*>>> + std::enable_if_t, std::basic_string> || + std::same_as, std::basic_string_view> || + std::same_as, m::basic_sstring> || + std::same_as, char32_t*> || + std::same_as, char32_t const*>>> { using type = char32_t; }; diff --git a/src/libraries/arefc_ptr/include/m/arefc_ptr/arefc_ptr.h b/src/libraries/arefc_ptr/include/m/arefc_ptr/arefc_ptr.h index c8be4181..2a2502bd 100644 --- a/src/libraries/arefc_ptr/include/m/arefc_ptr/arefc_ptr.h +++ b/src/libraries/arefc_ptr/include/m/arefc_ptr/arefc_ptr.h @@ -58,7 +58,13 @@ namespace m constexpr void operator()(aggregate* aggptr) const noexcept { - aggregate::deallocate(aggptr); + // The unique_ptr returned by aggregate::allocate() owns raw storage + // whose object has NOT been constructed yet (construction happens later + // in mmake_arefc_ex via the caller's callback). So on cleanup we must + // only deallocate, never run the object's destructor on unconstructed + // storage. Once the object is constructed, ownership is released to the + // arefc_ptr and destruction becomes the refcount's responsibility. + aggregate::deallocate(aggptr, /* do_destroy */ false); } }; @@ -500,7 +506,7 @@ namespace m ~arefc_ptr() { reset(); } arefc_ptr& - operator=(arefc_ptr& other) noexcept + operator=(arefc_ptr const& other) noexcept { if (this != &other) { @@ -516,11 +522,11 @@ namespace m arefc_ptr& operator=(arefc_ptr const& other) noexcept { - if (this != &other) - { - reset(); - put(other.addref()); - } + // No self-assignment guard: a different specialization arefc_ptr can never + // be the same object as *this, and comparing the unrelated pointer types would + // be ill-formed. (When U == T the non-template copy-assignment is selected.) + reset(); + put(other.addref()); return *this; } @@ -548,10 +554,9 @@ namespace m } void - reset(T* ptr_in = nullptr) noexcept + reset() noexcept { - auto const ptr = m_ptr.exchange(increment_ref(ptr_in), std::memory_order_acq_rel); - decrement_ref(ptr); + reset(nullptr); } T& @@ -588,47 +593,20 @@ namespace m return v; } - bool - compare_exchange_strong(arefc_ptr& expected, arefc_ptr const& desired) noexcept - { - // The trick here is to not mess up the reference counting! - // - // it would seem trivial to just "pass through" the m_ptr values but that's - // only part of the story. - // - - T* e = expected.get(); - T* d = desired.get(); - - T* old_e = e; // save a copy so we don't have to re-load - - // Pre-increment d's refcount so that if the CAS succeeds, m_ptr holds - // a valid reference to d without a window where the refcount is zero. - increment_ref(d); - - if (m_ptr.compare_exchange_strong(e, d, std::memory_order_acq_rel)) - { - M_INTERNAL_ERROR_CHECK(e == old_e); - - // Account for the fact that m_ptr no longer refers to `e` - decrement_ref(e); - return true; - } - - // The CAS failed: m_ptr still holds its current value (now captured in `e`). - // `d` was pre-incremented above but was never stored in m_ptr, so we must - // undo that increment to avoid a permanent ref-count leak on `desired`. - decrement_ref(d); - - // Update `expected` to reflect the actual current value of m_ptr. - expected.reset(e); - - return false; - } - private: constexpr arefc_ptr(T* ptr) noexcept: m_ptr(ptr) {} + // Raw-pointer reset: `ptr_in` must already point just past a control area + // (i.e. be a pointer obtained from an arefc_ptr-managed object) or be null. + // Private because passing an arbitrary raw pointer is unsafe; external callers + // use the no-arg reset(). + void + reset(T* ptr_in) noexcept + { + auto const ptr = m_ptr.exchange(increment_ref(ptr_in), std::memory_order_acq_rel); + decrement_ref(ptr); + } + arefc_ptr_impl::control_area_t* get_control_area() const { @@ -726,9 +704,13 @@ namespace m auto a = aggregate_type::allocate(extra_bytes_required); auto const object_span = a->get_object_byte_span(); + + // If `fn` throws, `a`'s deleter deallocates the raw storage without running + // a destructor on the never-constructed object (RAII cleanup). auto const ptr = std::invoke(std::forward(fn), object_span, std::forward(args)...); + // Construction succeeded: hand ownership to the arefc_ptr. a.release(); arefc_ptr retval(ptr); return retval; diff --git a/src/libraries/arefc_ptr/test/test_arefc_ptr.cpp b/src/libraries/arefc_ptr/test/test_arefc_ptr.cpp index 72415624..243796a5 100644 --- a/src/libraries/arefc_ptr/test/test_arefc_ptr.cpp +++ b/src/libraries/arefc_ptr/test/test_arefc_ptr.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -69,6 +70,17 @@ namespace int value; }; + // Constructor throws; destructor counts calls. Used to verify that a failed + // construction inside mmake_arefc_ex does NOT run the (never-constructed) + // object's destructor. + struct ThrowingCtor + { + inline static int s_dtor_calls = 0; + + ThrowingCtor() { throw std::runtime_error("boom"); } + ~ThrowingCtor() { ++s_dtor_calls; } + }; + } // namespace // ============================================================================ @@ -278,6 +290,20 @@ TEST(AreFcPtr_CopyAssign, OverwritesExisting) EXPECT_EQ(LifetimeTracker::s_live_count, 0); } +TEST(AreFcPtr_CopyAssign, FromConstSource) +{ + // Copy-assignment must bind to a const source. + auto p1 = m::mmake_arefc(); + p1->value = 5; + m::arefc_ptr const& cref = p1; + + m::arefc_ptr p2; + p2 = cref; // must compile and share ownership + + EXPECT_EQ(p2.get(), p1.get()); + EXPECT_EQ(p2->value, 5); +} + // ============================================================================ // Move assignment // ============================================================================ @@ -367,6 +393,19 @@ TEST(AreFcPtr_TypeCoercion, ToBaseKeepsObjectAlive) } } +TEST(AreFcPtr_TypeCoercion, ConvertingAssignFromDerived) +{ + // Exercises the converting operator=(arefc_ptr const&) (U != T). This template + // must compile (it previously had an ill-formed self-check comparing unrelated + // pointer types) and must share ownership through the base pointer. + auto derived = m::mmake_arefc(); + m::arefc_ptr base; + base = derived; // converting copy-assignment + + EXPECT_EQ(base.get(), static_cast(derived.get())); + EXPECT_NE(base.get(), nullptr); +} + // ============================================================================ // mmake_arefc_ex — extra bytes // ============================================================================ @@ -388,6 +427,23 @@ TEST(AreFcPtr_MakeEx, FnIsCalledAndObjectIsAccessible) EXPECT_EQ(p->value, 42); } +TEST(AreFcPtr_MakeEx, ThrowingConstructorDoesNotDestroyUnconstructed) +{ + // If the constructing callback throws, the object was never constructed, so + // mmake_arefc_ex must deallocate WITHOUT running the object's destructor. + ThrowingCtor::s_dtor_calls = 0; + + EXPECT_THROW( + { + auto p = m::mmake_arefc_ex( + 0, [](m::byte_span s) -> ThrowingCtor* { return ::new (s.data()) ThrowingCtor(); }); + (void)p; + }, + std::runtime_error); + + EXPECT_EQ(ThrowingCtor::s_dtor_calls, 0); +} + // ============================================================================ // Over-aligned type — exercises big_control_area path // ============================================================================ @@ -411,110 +467,6 @@ TEST(AreFcPtr_OverAligned, CopySharesObjectAndLifetime) EXPECT_EQ(p2->m_x, 100); } -// ============================================================================ -// compare_exchange_strong — success case -// ============================================================================ - -TEST(AreFcPtr_CAS, SuccessUpdatesTarget) -{ - auto p1 = m::mmake_arefc(); - p1->value = 1; - auto p2 = m::mmake_arefc(); - p2->value = 2; - - m::arefc_ptr target = p1; - m::arefc_ptr expected = p1; - - // target == expected, so CAS should swap target to p2. - bool ok = target.compare_exchange_strong(expected, p2); - - EXPECT_TRUE(ok); - EXPECT_EQ(target.get(), p2.get()); - // On success, `expected` is not modified. - EXPECT_EQ(expected.get(), p1.get()); -} - -TEST(AreFcPtr_CAS, SuccessNoRefLeak) -{ - // Verify that a successful CAS leaves every object with the correct - // final refcount — none should linger when their last holder is dropped. - LifetimeTracker::s_live_count = 0; - { - auto p1 = m::mmake_arefc(1); - auto p2 = m::mmake_arefc(2); - EXPECT_EQ(LifetimeTracker::s_live_count, 2); - - m::arefc_ptr target = p1; - m::arefc_ptr expected = p1; - - bool ok = target.compare_exchange_strong(expected, p2); - EXPECT_TRUE(ok); - - target.reset(); - expected.reset(); - EXPECT_EQ(LifetimeTracker::s_live_count, 2); // p1 and p2 still held by locals - } - EXPECT_EQ(LifetimeTracker::s_live_count, 0); -} - -// ============================================================================ -// compare_exchange_strong — failure case -// ============================================================================ - -TEST(AreFcPtr_CAS, FailureUpdatesExpected) -{ - auto p1 = m::mmake_arefc(); - p1->value = 1; - auto p2 = m::mmake_arefc(); - p2->value = 2; - auto p3 = m::mmake_arefc(); - p3->value = 3; - - m::arefc_ptr target = p3; // holds p3 - m::arefc_ptr expected = p1; // but we think it holds p1 — mismatch - - bool ok = target.compare_exchange_strong(expected, p2); - - EXPECT_FALSE(ok); - EXPECT_EQ(target.get(), p3.get()); // target unchanged - EXPECT_EQ(expected.get(), p3.get()); // expected updated to actual current value -} - -TEST(AreFcPtr_CAS, FailureNoRefLeakOnDesired) -{ - // This test exposes a ref-count leak bug that was present on the CAS failure - // path: increment_ref(desired) was called unconditionally before the CAS, - // but if the CAS failed, the increment was never balanced. - // - // After the fix (decrement_ref(d) on the failure path), `desired`'s refcount - // must return to exactly 1 (held only by the local `p2` variable) so that - // when `p2` drops at the end of the scope, the object is destroyed. - LifetimeTracker::s_live_count = 0; - { - auto p1 = m::mmake_arefc(1); - auto p2 = m::mmake_arefc(2); // this is `desired` - auto p3 = m::mmake_arefc(3); - EXPECT_EQ(LifetimeTracker::s_live_count, 3); - - m::arefc_ptr target = p3; - m::arefc_ptr expected = p1; - - bool ok = target.compare_exchange_strong(expected, p2); - EXPECT_FALSE(ok); - - // Drop the CAS temporaries so they hold no extra references. - target.reset(); - expected.reset(); - - // Still 3 live objects — p1, p2, p3 each held by one local variable. - EXPECT_EQ(LifetimeTracker::s_live_count, 3); - } - // All locals dropped — all 3 objects must be destroyed. - // If the ref-count leak is present, p2 will not be destroyed here and - // s_live_count will remain 1. - EXPECT_EQ(LifetimeTracker::s_live_count, 0); -} - // ============================================================================ // Thread safety — concurrent copies and drops must not corrupt the ref count // ============================================================================ diff --git a/src/libraries/const_string/include/m/const_string/const_string.h b/src/libraries/const_string/include/m/const_string/const_string.h index e2553e60..c5b549e0 100644 --- a/src/libraries/const_string/include/m/const_string/const_string.h +++ b/src/libraries/const_string/include/m/const_string/const_string.h @@ -95,7 +95,7 @@ namespace m std::strong_ordering operator<=>(basic_const_string const& r) const { - return operator<=>(this->view(), r.view()); + return this->view() <=> r.view(); } private: @@ -277,6 +277,15 @@ namespace m using u16const_string = basic_const_string; using u32const_string = basic_const_string; + // get_buffer_ptr() locates the trailing character buffer at (&m_size + 1), + // which only yields the correct address when m_size is the sole data member + // (i.e. the object is exactly one std::size_t in size). Lock that invariant. + static_assert(sizeof(const_string) == sizeof(std::size_t)); + static_assert(sizeof(wconst_string) == sizeof(std::size_t)); + static_assert(sizeof(u8const_string) == sizeof(std::size_t)); + static_assert(sizeof(u16const_string) == sizeof(std::size_t)); + static_assert(sizeof(u32const_string) == sizeof(std::size_t)); + template m::arefc_ptr make_const_string(StringishT&& str) diff --git a/src/libraries/debugging/include/m/debugging/dbg_format.h b/src/libraries/debugging/include/m/debugging/dbg_format.h index 111a6ce5..d583e0d9 100644 --- a/src/libraries/debugging/include/m/debugging/dbg_format.h +++ b/src/libraries/debugging/include/m/debugging/dbg_format.h @@ -87,10 +87,15 @@ namespace m template class output_debug_string_iter { - static constexpr inline std::array long_line_chars{{'.', '.', '.', '\n'}}; + // Includes a trailing null so that copying long_line_chars in full + // (size() elements) writes the ellipsis, newline, and terminator + // into the buffer. long_line_suffix is a view of just the visible + // characters (excluding the null) used for the limit computation. + static constexpr inline std::array long_line_chars{ + {'.', '.', '.', '\n', '\0'}}; static constexpr inline std::basic_string_view long_line_suffix = - std::basic_string_view(long_line_chars.data(), long_line_chars.size()); + std::basic_string_view(long_line_chars.data(), long_line_chars.size() - 1); static constexpr inline size_t limit = length - (long_line_suffix.size() + 1); @@ -119,8 +124,8 @@ namespace m { // If we're at the end of line, add the ellipsis, call // OutputDebugStringW() and restart. - std::copy_n(long_line_suffix.begin(), - long_line_suffix.size() + 1, + std::copy_n(long_line_chars.begin(), + long_line_chars.size(), &_buffer->_array[_index + 1]); write_to_debugger(&_buffer->_array[0]); @@ -142,9 +147,9 @@ namespace m { // If we're at the end of line, add the ellipsis, call // OutputDebugStringW() and restart. - std::copy_n(long_line_suffix.begin(), - long_line_suffix.size() + 1, - &_buffer->_array[_index]); + std::copy_n(long_line_chars.begin(), + long_line_chars.size(), + &_buffer->_array[_index + 1]); write_to_debugger(&_buffer->_array[0]); diff --git a/src/libraries/inplace_vector/include/m/inplace_vector/inplace_vector.h b/src/libraries/inplace_vector/include/m/inplace_vector/inplace_vector.h index 2d090969..bd1c0e75 100644 --- a/src/libraries/inplace_vector/include/m/inplace_vector/inplace_vector.h +++ b/src/libraries/inplace_vector/include/m/inplace_vector/inplace_vector.h @@ -28,6 +28,7 @@ #include #include +#include #include #include @@ -109,6 +110,29 @@ namespace m { std::construct_at(ptr, std::forward(value)) } -> std::same_as; }; + // Exposition-only synth-three-way (http://eel.is/c++draft/expos.only.entity): + // uses operator<=> when available, otherwise synthesizes a weak_ordering from + // operator< so that element types providing only < and == still order correctly. + inline constexpr auto synth_three_way = [](T const& t, U const& u) + requires requires { + { t < u } -> std::convertible_to; + { u < t } -> std::convertible_to; + } + { + if constexpr (std::three_way_comparable_with) + { + return t <=> u; + } + else + { + if (t < u) + return std::weak_ordering::less; + if (u < t) + return std::weak_ordering::greater; + return std::weak_ordering::equivalent; + } + }; + // Types implementing the `inplace_vector`'s storage namespace storage { @@ -479,12 +503,12 @@ namespace m constexpr reference back() { - return inplace_vector_impl::index(*this, size() - size_type{1}); + return inplace_vector_impl::index(*this, m::math::subtract(size(), 1, std::size_t{})); } constexpr const_reference back() const { - return inplace_vector_impl::index(*this, size() - size_type{1}); + return inplace_vector_impl::index(*this, m::math::subtract(size(), 1, std::size_t{})); } // [containers.sequences.inplace_vector.data], data access @@ -597,7 +621,7 @@ namespace m { M_IV_EXPECT(size() < capacity() && "inplace_vector out-of-memory"); std::construct_at(end(), std::forward(args)...); - unsafe_set_size(size() + size_type{1}); + unsafe_set_size(m::math::add(size(), 1, std::size_t{})); return back(); } @@ -669,7 +693,8 @@ namespace m { if constexpr (std::ranges::sized_range) { - if (size() + std::ranges::size(rnge) > capacity()) [[unlikely]] + if (m::math::add(size(), std::ranges::size(rnge), std::size_t{}) > capacity()) + [[unlikely]] throw std::bad_alloc(); } for (auto&& e: rnge) @@ -702,8 +727,9 @@ namespace m // internal_assert_valid_iterator_pair(first, last); if constexpr (std::random_access_iterator) { - if (size() + static_cast(std::distance(first, last)) > capacity()) - [[unlikely]] + if (m::math::add(size(), + static_cast(std::distance(first, last)), + std::size_t{}) > capacity()) [[unlikely]] throw std::bad_alloc{}; } auto b = end(); @@ -809,7 +835,8 @@ namespace m { internal_unsafe_destroy(std::move(new_first + (last - first), end(), new_first), end()); - unsafe_set_size(size() - static_cast(last - first)); + unsafe_set_size( + m::math::subtract(size(), static_cast(last - first), std::size_t{})); } return new_first; } @@ -837,7 +864,7 @@ namespace m else if (sz > N) [[unlikely]] throw std::bad_alloc{}; else if (sz > size()) - insert(end(), sz - size(), c); + insert(end(), m::math::subtract(sz, size(), std::size_t{}), c); else { internal_unsafe_destroy(begin() + sz, end()); @@ -883,7 +910,7 @@ namespace m { M_IV_EXPECT(size() > 0 && "pop_back from empty inplace_vector!"); internal_unsafe_destroy(end() - 1, end()); - unsafe_set_size(size() - 1); + unsafe_set_size(m::math::subtract(size(), 1, std::size_t{})); } constexpr inplace_vector(const inplace_vector& x) @@ -967,29 +994,18 @@ namespace m insert_range(begin(), il); } - constexpr friend int /*synth-three-way-result*/ + constexpr friend auto operator<=>(inplace_vector const& x, inplace_vector const& y) - { - if (x.size() < y.size()) - return -1; - if (x.size() > y.size()) - return +1; - - bool all_equal = true; - bool all_less = true; - for (size_type idx = 0; idx < x.size(); ++idx) - { - if (x[idx] < y[idx]) - all_equal = false; - if (x[idx] == y[idx]) - all_less = false; + requires requires(T const& a) { + inplace_vector_impl::synth_three_way(a, a); } - - if (all_equal) - return 0; - if (all_less) - return -1; - return 1; + { + return std::lexicographical_compare_three_way( + x.begin(), + x.end(), + y.begin(), + y.end(), + inplace_vector_impl::synth_three_way); } }; diff --git a/src/libraries/inplace_vector/test/test_inplace_vector.cpp b/src/libraries/inplace_vector/test/test_inplace_vector.cpp index 44ccfcf8..4fa67003 100644 --- a/src/libraries/inplace_vector/test/test_inplace_vector.cpp +++ b/src/libraries/inplace_vector/test/test_inplace_vector.cpp @@ -3,6 +3,7 @@ #include +#include #include #include #include @@ -18,6 +19,52 @@ using namespace std::chrono_literals; using namespace std::string_literals; using namespace std::string_view_literals; +namespace +{ + // A type with only operator< and operator== (no operator<=>); used to exercise + // the synthesized weak_ordering path of inplace_vector's operator<=>. + struct less_only + { + int m_v; + + friend constexpr bool + operator<(less_only const& a, less_only const& b) + { + return a.m_v < b.m_v; + } + + friend constexpr bool + operator==(less_only const& a, less_only const& b) + { + return a.m_v == b.m_v; + } + }; + + // A sized range that reports a dishonestly huge size() while iterating as empty. + // Used to drive inplace_vector::append_range's integer-overflow guard: the size + // check (size() + ranges::size(rnge)) wraps SIZE_MAX before any iteration happens. + struct lying_huge_range + { + static int const* + begin() + { + return nullptr; + } + + static int const* + end() + { + return nullptr; + } + + static std::size_t + size() + { + return std::numeric_limits::max(); + } + }; +} // namespace + struct SomeStruct { int x; @@ -232,3 +279,112 @@ TEST(InplaceVector, CountMoves) EXPECT_EQ(ch[8].m_s, "India"s); } + +TEST(InplaceVector, ThreeWayEqual) +{ + m::inplace_vector a; + m::inplace_vector b; + + a.assign({1, 2, 3}); + b.assign({1, 2, 3}); + + EXPECT_TRUE((a <=> b) == std::strong_ordering::equal); + EXPECT_TRUE(a == b); +} + +TEST(InplaceVector, ThreeWayLessByElement) +{ + m::inplace_vector a; + m::inplace_vector b; + + a.assign({1, 2, 3}); + b.assign({1, 9, 3}); + + // Same size, differ at the second element: a < b. + EXPECT_TRUE((a <=> b) == std::strong_ordering::less); + EXPECT_TRUE(a < b); + EXPECT_TRUE(b > a); + EXPECT_FALSE(a == b); +} + +TEST(InplaceVector, ThreeWayGreaterByElement) +{ + m::inplace_vector a; + m::inplace_vector b; + + a.assign({1, 9, 0}); + b.assign({1, 2, 9}); + + // First difference is at index 1 where 9 > 2, so a > b regardless of later elements. + EXPECT_TRUE((a <=> b) == std::strong_ordering::greater); + EXPECT_TRUE(a > b); +} + +TEST(InplaceVector, ThreeWayPrefixIsLess) +{ + m::inplace_vector a; + m::inplace_vector b; + + a.assign({1, 2}); + b.assign({1, 2, 3}); + + // A proper prefix orders before the longer sequence. + EXPECT_TRUE((a <=> b) == std::strong_ordering::less); + EXPECT_TRUE(a < b); + EXPECT_TRUE(b > a); +} + +TEST(InplaceVector, ThreeWayEmptyOrdering) +{ + m::inplace_vector empty; + m::inplace_vector nonempty; + + nonempty.assign({0}); + + EXPECT_TRUE((empty <=> empty) == std::strong_ordering::equal); + EXPECT_TRUE((empty <=> nonempty) == std::strong_ordering::less); + EXPECT_TRUE(empty < nonempty); +} + +TEST(InplaceVector, ThreeWaySynthFromLessOnly) +{ + // A type with only operator< and operator== (no operator<=>) must still order + // via the synthesized weak_ordering path. + m::inplace_vector a; + m::inplace_vector b; + + a.push_back(less_only{1}); + a.push_back(less_only{2}); + b.push_back(less_only{1}); + b.push_back(less_only{5}); + + EXPECT_TRUE((a <=> b) == std::weak_ordering::less); + EXPECT_TRUE(a < b); +} + +TEST(InplaceVector, AppendRangeIntegerOverflowThrows) +{ + // append_range checks size() + ranges::size(rnge) against capacity(). When the + // sum overflows SIZE_MAX, m::math::add throws std::overflow_error rather than + // wrapping and slipping past the capacity guard (the bug this check fixes). + m::inplace_vector v; + v.push_back(1); // size() == 1, so 1 + SIZE_MAX overflows. + + EXPECT_THROW(v.append_range(lying_huge_range{}), std::overflow_error); + + // The vector must be unchanged after the failed append. + EXPECT_EQ(v.size(), 1u); + EXPECT_EQ(v[0], 1); +} + +TEST(InplaceVector, AppendRangeOverCapacityThrows) +{ + // A range that fits in SIZE_MAX arithmetic but exceeds capacity still throws + // bad_alloc (the ordinary capacity guard, distinct from the overflow guard). + m::inplace_vector v; + + std::array const src{1, 2, 3, 4, 5}; + + EXPECT_THROW(v.append_range(src), std::bad_alloc); +} + diff --git a/src/libraries/math/include/m/math/math.h b/src/libraries/math/include/m/math/math.h index 9f21018a..37d3a2c7 100644 --- a/src/libraries/math/include/m/math/math.h +++ b/src/libraries/math/include/m/math/math.h @@ -247,14 +247,12 @@ namespace m static constexpr ResultT try_negate(common_type_t v) { - constexpr common_type_t biggest_negative_as_positive = - m::try_cast(-((std::numeric_limits::min)() + 1)); - - if (v >= biggest_negative_as_positive) - throw std::overflow_error(std::format( - "m::math overflow: negative value magnitude exceeds maximum representable in target type")); - - return -m::try_cast(v); + // v is the magnitude (>= 0) of a negative result; produce -v in + // ResultT. Delegate to the unsigned->signed unary negate helper, + // which correctly admits the full negative range of ResultT + // (including its most-negative value) rather than stopping one + // short of it. + return unary_safe_math_helper::negate(v); } static constexpr ResultT @@ -706,21 +704,17 @@ namespace m static_cast((std::numeric_limits::max)()) + 1; auto quot = abs_min / promoted_r; - if (quot > static_cast((std::numeric_limits::max)())) - { - throw std::overflow_error("integer overflow"); - } - return m::try_cast(-static_cast(quot)); + // Delegate negation to the unsigned->signed unary helper, + // which admits the full negative range of ResultT + // (including its most-negative value) rather than + // rejecting it one short. + return unary_safe_math_helper::negate(quot); } auto abs_l = static_cast(-promoted_l); auto quot = abs_l / promoted_r; - if (quot > static_cast((std::numeric_limits::max)())) - { - throw std::overflow_error("integer overflow"); - } - return m::try_cast(-static_cast(quot)); + return unary_safe_math_helper::negate(quot); } } }; @@ -957,23 +951,19 @@ namespace m (std::numeric_limits::min)()) + 1)) + 1; auto l_promoted = static_cast(l); auto quot = l_promoted / abs_min; - - if (quot > static_cast((std::numeric_limits::max)())) - { - throw std::overflow_error("integer overflow"); - } - return m::try_cast(-static_cast(quot)); + + // Delegate negation to the unsigned->signed unary helper, + // which admits the full negative range of ResultT + // (including its most-negative value) rather than + // rejecting it one short. + return unary_safe_math_helper::negate(quot); } auto l_promoted = static_cast(l); auto abs_r = static_cast(-static_cast(r)); auto quot = l_promoted / abs_r; - - if (quot > static_cast((std::numeric_limits::max)())) - { - throw std::overflow_error("integer overflow"); - } - return m::try_cast(-static_cast(quot)); + + return unary_safe_math_helper::negate(quot); } else { @@ -1335,6 +1325,164 @@ namespace m }; + // + // Handle (signed [op] signed) -> unsigned + // + template + requires m::is_integral_non_bool_v && m::is_integral_non_bool_v && + m::is_integral_non_bool_v && std::is_signed_v && + std::is_signed_v && std::is_unsigned_v + struct safe_math_helper + { + // |v| as a uintmax_t, correct even when v is intmax_t's most + // negative value (whose magnitude is not representable in intmax_t). + static constexpr uintmax_t + abs_to_unsigned(intmax_t v) + { + if (v == (std::numeric_limits::min)()) + { + static_assert(((-(std::numeric_limits::max)()) - 1) == + (std::numeric_limits::min)()); + return static_cast((std::numeric_limits::max)()) + 1; + } + + return static_cast(v < 0 ? -v : v); + } + + static constexpr ResultT + add(LeftT l, RightT r) + { + // Compute l + r in ℤ, then require the result to be representable + // in the unsigned ResultT (i.e. non-negative and in range). + auto const pl = static_cast(l); + auto const pr = static_cast(r); + + if (pl >= 0 && pr >= 0) + { + // Both non-negative: add in unsigned space with overflow check. + auto const ul = static_cast(pl); + auto const ur = static_cast(pr); + auto const sum = ul + ur; + + if (sum < ul || sum < ur) + throw std::overflow_error("integer overflow"); + + return m::try_cast(sum); + } + + if (pl < 0 && pr < 0) + { + // Sum of two negatives is negative: not representable. + throw std::overflow_error("integer overflow"); + } + + // Mixed signs: the magnitudes partially cancel, so the sum is + // representable in intmax_t without overflow. + intmax_t const sum = pl + pr; + + if (sum < 0) + throw std::overflow_error("integer overflow"); + + return m::try_cast(static_cast(sum)); + } + + static constexpr ResultT + subtract(LeftT l, RightT r) + { + // Compute l - r in ℤ; the result must be non-negative. + auto const pl = static_cast(l); + auto const pr = static_cast(r); + + if (pl < pr) + throw std::overflow_error("integer overflow"); + + // pl >= pr, so the mathematical result is non-negative. Compute + // the magnitude in unsigned space, handling the case where the + // difference exceeds intmax_t's positive range. + uintmax_t result{}; + + if (pr >= 0) + { + // pl >= pr >= 0: both non-negative. + result = static_cast(pl) - static_cast(pr); + } + else + { + auto const abs_r = abs_to_unsigned(pr); + + if (pl >= 0) + { + // result = pl + |r|; guard the unsigned sum. + auto const upl = static_cast(pl); + result = upl + abs_r; + + if (result < upl) + throw std::overflow_error("integer overflow"); + } + else + { + // pl < 0 and pl >= pr, so |pl| <= |r|: result = |r| - |pl|. + result = abs_r - abs_to_unsigned(pl); + } + } + + return m::try_cast(result); + } + + static constexpr ResultT + multiply(LeftT l, RightT r) + { + if (l == 0 || r == 0) + return 0; + + auto const pl = static_cast(l); + auto const pr = static_cast(r); + + // A negative product cannot be represented in an unsigned type. + if ((pl < 0) != (pr < 0)) + throw std::overflow_error(std::format( + "m::math::multiply overflow: negative value cannot be represented in unsigned result type")); + + auto const abs_l = abs_to_unsigned(pl); + auto const abs_r = abs_to_unsigned(pr); + auto const prod = abs_l * abs_r; + + if (prod / abs_l != abs_r || prod / abs_r != abs_l) + throw std::overflow_error("integer overflow"); + + return m::try_cast(prod); + } + + static constexpr ResultT + divide(LeftT l, RightT r) + { + if (r == 0) + throw std::overflow_error(std::format( + "m::math::divide overflow: division by zero")); + + auto const pl = static_cast(l); + auto const pr = static_cast(r); + + auto const abs_l = abs_to_unsigned(pl); + auto const abs_r = abs_to_unsigned(pr); + auto const quot = abs_l / abs_r; + + // Integer division truncates toward zero. If the signs differ the + // mathematical quotient is negative unless it truncates to zero; + // a non-zero negative quotient is not representable in unsigned. + if ((pl < 0) != (pr < 0)) + { + if (quot == 0) + return 0; + + throw std::overflow_error(std::format( + "m::math::divide overflow: negative value cannot be represented in unsigned result type")); + } + + return m::try_cast(quot); + } + }; + // Unary ops, signed -> signed template requires m::is_integral_non_bool_v && m::is_integral_non_bool_v && diff --git a/src/libraries/math/test/CMakeLists.txt b/src/libraries/math/test/CMakeLists.txt index 9396c9d5..cfc8faf7 100644 --- a/src/libraries/math/test/CMakeLists.txt +++ b/src/libraries/math/test/CMakeLists.txt @@ -11,6 +11,7 @@ add_executable( safe_integers_addition_functor.cpp exercise_negation.cpp signed_signed_to_signed.cpp + signed_signed_to_unsigned.cpp signed_unsigned_to_unsigned.cpp test_addition.cpp test_subtraction.cpp @@ -18,6 +19,8 @@ add_executable( test_multiplication.cpp test_negation.cpp test_intermediate_overflow.cpp + test_constexpr.cpp + test_edge_cases.cpp ) target_compile_features(${TEST_EXE_NAME} PUBLIC ${M_CXX_STD}) diff --git a/src/libraries/math/test/signed_signed_to_unsigned.cpp b/src/libraries/math/test/signed_signed_to_unsigned.cpp new file mode 100644 index 00000000..b216f9ea --- /dev/null +++ b/src/libraries/math/test/signed_signed_to_unsigned.cpp @@ -0,0 +1,183 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#include + +#include +#include +#include + +#include + +// +// Tests for the (signed [op] signed) -> unsigned specialization. +// +// Per the library model: the operation is performed in ℤ and the result is +// then required to be representable in the unsigned ResultT (i.e. non-negative +// and within range), otherwise an overflow_error is thrown. +// + +// ============================================================================ +// add +// ============================================================================ + +TEST(SignedSignedToUnsigned, AddBothNonNegative) +{ + EXPECT_EQ(m::math::add(int32_t{20}, int32_t{22}, uint32_t{}), 42u); + EXPECT_EQ(m::math::add(int32_t{0}, int32_t{0}, uint32_t{}), 0u); +} + +TEST(SignedSignedToUnsigned, AddMixedSignsNonNegativeResult) +{ + EXPECT_EQ(m::math::add(int32_t{-5}, int32_t{12}, uint32_t{}), 7u); + EXPECT_EQ(m::math::add(int32_t{12}, int32_t{-12}, uint32_t{}), 0u); +} + +TEST(SignedSignedToUnsigned, AddNegativeResultThrows) +{ + EXPECT_THROW(m::math::add(int32_t{-5}, int32_t{2}, uint32_t{}), std::overflow_error); + EXPECT_THROW(m::math::add(int32_t{-1}, int32_t{-1}, uint32_t{}), std::overflow_error); +} + +TEST(SignedSignedToUnsigned, AddWidenAvoidsFalseOverflow) +{ + constexpr auto max32 = (std::numeric_limits::max)(); + // max32 + max32 fits in uint32_t (and certainly in uint64_t). + EXPECT_EQ(m::math::add(max32, max32, uint64_t{}), + static_cast(max32) + static_cast(max32)); + EXPECT_EQ(m::math::add(max32, max32, uint32_t{}), + static_cast(max32) + static_cast(max32)); +} + +TEST(SignedSignedToUnsigned, AddNarrowResultOverflow) +{ + // 200 + 100 = 300 does not fit in uint8_t. + EXPECT_THROW(m::math::add(int32_t{200}, int32_t{100}, uint8_t{}), std::overflow_error); + // 200 + 55 = 255 fits exactly. + EXPECT_EQ(m::math::add(int32_t{200}, int32_t{55}, uint8_t{}), uint8_t{255}); +} + +TEST(SignedSignedToUnsigned, AddMostNegativeOperand) +{ + constexpr auto min64 = (std::numeric_limits::min)(); + constexpr auto max64 = (std::numeric_limits::max)(); + // min64 + max64 = -1 -> negative -> throws. + EXPECT_THROW(m::math::add(min64, max64, uint64_t{}), std::overflow_error); + // min64 + (min64 magnitude as positive is unrepresentable in int64, but the + // sum with a large positive can still be valid): + // min64 + max64 + 1 conceptually = 0; do it as two operands: + EXPECT_EQ(m::math::add(int64_t{min64 + 1}, max64, uint64_t{}), 0u); +} + +// ============================================================================ +// subtract +// ============================================================================ + +TEST(SignedSignedToUnsigned, SubtractBasic) +{ + EXPECT_EQ(m::math::subtract(int32_t{50}, int32_t{8}, uint32_t{}), 42u); + EXPECT_EQ(m::math::subtract(int32_t{8}, int32_t{8}, uint32_t{}), 0u); +} + +TEST(SignedSignedToUnsigned, SubtractNegativeResultThrows) +{ + EXPECT_THROW(m::math::subtract(int32_t{8}, int32_t{50}, uint32_t{}), std::overflow_error); + EXPECT_THROW(m::math::subtract(int32_t{-5}, int32_t{1}, uint32_t{}), std::overflow_error); +} + +TEST(SignedSignedToUnsigned, SubtractMinusNegativeIsAddition) +{ + EXPECT_EQ(m::math::subtract(int32_t{100}, int32_t{-50}, uint32_t{}), 150u); + EXPECT_EQ(m::math::subtract(int32_t{0}, int32_t{-42}, uint32_t{}), 42u); +} + +TEST(SignedSignedToUnsigned, SubtractLargeMagnitudeSpansBeyondIntmax) +{ + constexpr auto max64 = (std::numeric_limits::max)(); + constexpr auto min64 = (std::numeric_limits::min)(); + // max64 - min64 = 2^64 - 1, which fits exactly in uint64_t. + EXPECT_EQ(m::math::subtract(max64, min64, uint64_t{}), + (std::numeric_limits::max)()); +} + +TEST(SignedSignedToUnsigned, SubtractBothNegative) +{ + // -10 - (-30) = 20. + EXPECT_EQ(m::math::subtract(int32_t{-10}, int32_t{-30}, uint32_t{}), 20u); + // -30 - (-10) = -20 -> throws. + EXPECT_THROW(m::math::subtract(int32_t{-30}, int32_t{-10}, uint32_t{}), std::overflow_error); +} + +// ============================================================================ +// multiply +// ============================================================================ + +TEST(SignedSignedToUnsigned, MultiplyBasic) +{ + EXPECT_EQ(m::math::multiply(int32_t{6}, int32_t{7}, uint32_t{}), 42u); + EXPECT_EQ(m::math::multiply(int32_t{0}, int32_t{12345}, uint32_t{}), 0u); +} + +TEST(SignedSignedToUnsigned, MultiplyTwoNegativesIsPositive) +{ + EXPECT_EQ(m::math::multiply(int32_t{-6}, int32_t{-7}, uint32_t{}), 42u); +} + +TEST(SignedSignedToUnsigned, MultiplyOppositeSignsThrows) +{ + EXPECT_THROW(m::math::multiply(int32_t{-6}, int32_t{7}, uint32_t{}), std::overflow_error); + EXPECT_THROW(m::math::multiply(int32_t{6}, int32_t{-7}, uint32_t{}), std::overflow_error); +} + +TEST(SignedSignedToUnsigned, MultiplyOverflowThrows) +{ + constexpr auto max64 = (std::numeric_limits::max)(); + EXPECT_THROW(m::math::multiply(max64, max64, uint64_t{}), std::overflow_error); +} + +TEST(SignedSignedToUnsigned, MultiplyNarrowResultOverflow) +{ + EXPECT_THROW(m::math::multiply(int32_t{20}, int32_t{20}, uint8_t{}), std::overflow_error); + EXPECT_EQ(m::math::multiply(int32_t{15}, int32_t{17}, uint8_t{}), uint8_t{255}); +} + +// ============================================================================ +// divide +// ============================================================================ + +TEST(SignedSignedToUnsigned, DivideBasic) +{ + EXPECT_EQ(m::math::divide(int32_t{84}, int32_t{2}, uint32_t{}), 42u); + EXPECT_EQ(m::math::divide(int32_t{0}, int32_t{5}, uint32_t{}), 0u); +} + +TEST(SignedSignedToUnsigned, DivideTwoNegativesIsPositive) +{ + EXPECT_EQ(m::math::divide(int32_t{-84}, int32_t{-2}, uint32_t{}), 42u); +} + +TEST(SignedSignedToUnsigned, DivideByZeroThrows) +{ + EXPECT_THROW(m::math::divide(int32_t{1}, int32_t{0}, uint32_t{}), std::overflow_error); +} + +TEST(SignedSignedToUnsigned, DivideOppositeSignsNonZeroThrows) +{ + EXPECT_THROW(m::math::divide(int32_t{-84}, int32_t{2}, uint32_t{}), std::overflow_error); + EXPECT_THROW(m::math::divide(int32_t{84}, int32_t{-2}, uint32_t{}), std::overflow_error); +} + +TEST(SignedSignedToUnsigned, DivideOppositeSignsTruncatingToZeroIsOk) +{ + // -3 / 5 truncates toward zero to 0, which is representable. + EXPECT_EQ(m::math::divide(int32_t{-3}, int32_t{5}, uint32_t{}), 0u); + EXPECT_EQ(m::math::divide(int32_t{3}, int32_t{-5}, uint32_t{}), 0u); +} + +TEST(SignedSignedToUnsigned, DivideMostNegativeNumerator) +{ + constexpr auto min64 = (std::numeric_limits::min)(); + // min64 / -1 = 2^63, which fits in uint64_t (unlike the signed-result case). + EXPECT_EQ(m::math::divide(min64, int64_t{-1}, uint64_t{}), + static_cast((std::numeric_limits::max)()) + 1); +} diff --git a/src/libraries/math/test/test_constexpr.cpp b/src/libraries/math/test/test_constexpr.cpp new file mode 100644 index 00000000..41928ee1 --- /dev/null +++ b/src/libraries/math/test/test_constexpr.cpp @@ -0,0 +1,85 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#include + +#include +#include +#include + +#include + +// +// The public math operations are all declared constexpr. These tests verify +// that they are genuinely usable in a constant-expression context (and produce +// the correct values there), which is a behavioral guarantee that the +// value-based runtime tests do not exercise. +// + +// ============================================================================ +// add +// ============================================================================ + +static_assert(m::math::add(uint32_t{10}, uint32_t{20}, uint32_t{}) == 30u); +static_assert(m::math::add(int32_t{-5}, int32_t{12}, int32_t{}) == 7); +static_assert(m::math::add(uint32_t{100}, int32_t{-40}, uint32_t{}) == 60u); +static_assert(m::math::add(int32_t{-40}, uint32_t{100}, int32_t{}) == 60); +static_assert(m::math::add(int32_t{-5}, int32_t{12}, uint32_t{}) == 7u); +static_assert(m::math::add(uint8_t{200}, uint8_t{200}, uint16_t{}) == 400u); + +// ============================================================================ +// subtract +// ============================================================================ + +static_assert(m::math::subtract(uint32_t{50}, uint32_t{8}, uint32_t{}) == 42u); +static_assert(m::math::subtract(int32_t{5}, int32_t{8}, int32_t{}) == -3); +static_assert(m::math::subtract(uint32_t{100}, int32_t{-50}, uint32_t{}) == 150u); +static_assert(m::math::subtract(int32_t{50}, int32_t{8}, uint32_t{}) == 42u); + +// ============================================================================ +// multiply +// ============================================================================ + +static_assert(m::math::multiply(uint32_t{6}, uint32_t{7}, uint32_t{}) == 42u); +static_assert(m::math::multiply(int32_t{-4}, int32_t{5}, int32_t{}) == -20); +static_assert(m::math::multiply(int32_t{-6}, int32_t{-7}, uint32_t{}) == 42u); +static_assert(m::math::multiply(uint32_t{0}, int32_t{-9}, uint32_t{}) == 0u); + +// ============================================================================ +// divide +// ============================================================================ + +static_assert(m::math::divide(uint32_t{20}, uint32_t{4}, uint32_t{}) == 5u); +static_assert(m::math::divide(int32_t{-20}, int32_t{4}, int32_t{}) == -5); +static_assert(m::math::divide(uint32_t{100}, int32_t{-10}, int32_t{}) == -10); + +// ============================================================================ +// negate +// ============================================================================ + +static_assert(m::math::negate(int32_t{7}, int32_t{}) == -7); +static_assert(m::math::negate(int32_t{-7}, int32_t{}) == 7); +static_assert(m::math::negate(int32_t{-7}, uint32_t{}) == 7u); +static_assert(m::math::negate(uint32_t{0}, int32_t{}) == 0); + +// ============================================================================ +// Runtime mirrors so the constant-expression guarantees are reported by the +// test runner as an executed test as well. +// ============================================================================ + +TEST(ConstexprEvaluation, ResultsUsableInConstantExpressions) +{ + // Each of these is a compile-time constant; capturing them in constexpr + // locals re-confirms constant-expression usability at this point too. + constexpr auto sum = m::math::add(int32_t{-5}, int32_t{12}, int32_t{}); + constexpr auto diff = m::math::subtract(int32_t{5}, int32_t{8}, int32_t{}); + constexpr auto prod = m::math::multiply(int32_t{-4}, int32_t{5}, int32_t{}); + constexpr auto quot = m::math::divide(int32_t{-20}, int32_t{4}, int32_t{}); + constexpr auto neg = m::math::negate(int32_t{7}, int32_t{}); + + EXPECT_EQ(sum, 7); + EXPECT_EQ(diff, -3); + EXPECT_EQ(prod, -20); + EXPECT_EQ(quot, -5); + EXPECT_EQ(neg, -7); +} diff --git a/src/libraries/math/test/test_edge_cases.cpp b/src/libraries/math/test/test_edge_cases.cpp new file mode 100644 index 00000000..eccd8429 --- /dev/null +++ b/src/libraries/math/test/test_edge_cases.cpp @@ -0,0 +1,250 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#include + +#include +#include +#include + +#include + +// +// Additional edge-case coverage that the per-operation test files do not +// exercise: +// +// * producing exactly the most-negative representable result from a +// mixed-sign add/subtract (the "magnitude == |ResultT::min|" branches), +// * the 64-bit INT_MIN special-case branches of the mixed-sign multiply +// specializations (existing tests only reach these with 32-bit inputs, +// which do not take the intmax_t::min code paths), +// * multiplication that narrows to a smaller result type, and +// * negate boundaries at the full 64-bit width. +// + +// ============================================================================ +// Mixed-sign add: most-negative result boundary +// (signed + unsigned -> signed) and (unsigned + signed -> signed) +// ============================================================================ + +TEST(MixedSignMostNegative, SignedUnsignedAddProducesResultMin) +{ + constexpr auto min8 = (std::numeric_limits::min)(); // -128 + + // -200 + 72 = -128 = INT8_MIN: hits the "magnitude == |min|" branch. + EXPECT_EQ(m::math::add(int16_t{-200}, uint16_t{72}, int8_t{}), min8); + + // -201 + 72 = -129: one below INT8_MIN, must overflow. + EXPECT_THROW(m::math::add(int16_t{-201}, uint16_t{72}, int8_t{}), std::overflow_error); + + // -100 + 30 = -70: ordinary negative result just inside range. + EXPECT_EQ(m::math::add(int16_t{-100}, uint16_t{30}, int8_t{}), int8_t{-70}); +} + +TEST(MixedSignMostNegative, UnsignedSignedAddProducesResultMin) +{ + constexpr auto min8 = (std::numeric_limits::min)(); // -128 + + // 72 + (-200) = -128 = INT8_MIN. + EXPECT_EQ(m::math::add(uint16_t{72}, int16_t{-200}, int8_t{}), min8); + + // 72 + (-201) = -129: overflow. + EXPECT_THROW(m::math::add(uint16_t{72}, int16_t{-201}, int8_t{}), std::overflow_error); +} + +// ============================================================================ +// Mixed-sign subtract: most-negative result boundary +// (signed - unsigned -> signed) +// ============================================================================ + +TEST(MixedSignMostNegative, SignedUnsignedSubtractProducesResultMin) +{ + constexpr auto min8 = (std::numeric_limits::min)(); // -128 + + // -28 - 100 = -128 = INT8_MIN. + EXPECT_EQ(m::math::subtract(int16_t{-28}, uint16_t{100}, int8_t{}), min8); + + // -29 - 100 = -129: overflow. + EXPECT_THROW(m::math::subtract(int16_t{-29}, uint16_t{100}, int8_t{}), std::overflow_error); +} + +// ============================================================================ +// 64-bit INT_MIN special-case branches of the mixed-sign multiply helpers. +// The result type matches the 64-bit input width so the most-negative product +// is representable; existing tests only use 32-bit inputs and therefore never +// take the intmax_t::min branch. +// ============================================================================ + +TEST(IntMin64Multiply, SignedUnsignedToSigned) +{ + constexpr auto min64 = (std::numeric_limits::min)(); + + // INT64_MIN * 1 = INT64_MIN: the product equals |INT64_MIN| exactly. + EXPECT_EQ(m::math::multiply(min64, uint64_t{1}, int64_t{}), min64); + + // INT64_MIN * 2 overflows the 64-bit signed range. + EXPECT_THROW(m::math::multiply(min64, uint64_t{2}, int64_t{}), std::overflow_error); + + // INT64_MIN * 0 = 0 (early-out before the special case). + EXPECT_EQ(m::math::multiply(min64, uint64_t{0}, int64_t{}), 0); +} + +TEST(IntMin64Multiply, UnsignedSignedToSigned) +{ + constexpr auto min64 = (std::numeric_limits::min)(); + + // 1 * INT64_MIN = INT64_MIN. + EXPECT_EQ(m::math::multiply(uint64_t{1}, min64, int64_t{}), min64); + + // 2 * INT64_MIN overflows. + EXPECT_THROW(m::math::multiply(uint64_t{2}, min64, int64_t{}), std::overflow_error); + + // 0 * INT64_MIN = 0. + EXPECT_EQ(m::math::multiply(uint64_t{0}, min64, int64_t{}), 0); +} + +TEST(IntMin64Multiply, SignedSignedToSigned) +{ + constexpr auto min64 = (std::numeric_limits::min)(); + + // INT64_MIN * 1 = INT64_MIN exercises the abs(INT_MIN) handling on the + // left operand of the signed*signed helper. + EXPECT_EQ(m::math::multiply(min64, int64_t{1}, int64_t{}), min64); + + // 1 * INT64_MIN = INT64_MIN exercises it on the right operand. + EXPECT_EQ(m::math::multiply(int64_t{1}, min64, int64_t{}), min64); + + // INT64_MIN * -1 overflows (result would be INT64_MAX + 1). + EXPECT_THROW(m::math::multiply(min64, int64_t{-1}, int64_t{}), std::overflow_error); +} + +// ============================================================================ +// Multiplication that narrows to a smaller result type. +// ============================================================================ + +TEST(MultiplyNarrowing, UnsignedToUnsigned) +{ + // 15 * 17 = 255 fits exactly in uint8_t. + EXPECT_EQ(m::math::multiply(uint32_t{15}, uint32_t{17}, uint8_t{}), uint8_t{255}); + + // 20 * 20 = 400 does not fit in uint8_t. + EXPECT_THROW(m::math::multiply(uint32_t{20}, uint32_t{20}, uint8_t{}), std::overflow_error); +} + +TEST(MultiplyNarrowing, SignedToSignedPositive) +{ + // 12 * 10 = 120 fits in int8_t. + EXPECT_EQ(m::math::multiply(int32_t{12}, int32_t{10}, int8_t{}), int8_t{120}); + + // 13 * 10 = 130 exceeds INT8_MAX (127). + EXPECT_THROW(m::math::multiply(int32_t{13}, int32_t{10}, int8_t{}), std::overflow_error); +} + +TEST(MultiplyNarrowing, SignedToSignedNegative) +{ + constexpr auto min8 = (std::numeric_limits::min)(); // -128 + + // -12 * 10 = -120 fits in int8_t. + EXPECT_EQ(m::math::multiply(int32_t{-12}, int32_t{10}, int8_t{}), int8_t{-120}); + + // -16 * 8 = -128 = INT8_MIN fits exactly. + EXPECT_EQ(m::math::multiply(int32_t{-16}, int32_t{8}, int8_t{}), min8); + + // -16 * 9 = -144 is below INT8_MIN. + EXPECT_THROW(m::math::multiply(int32_t{-16}, int32_t{9}, int8_t{}), std::overflow_error); +} + +// ============================================================================ +// 64-bit divide: negative dividend via the (signed / unsigned -> signed) +// general branch (non-min, but with a 64-bit-wide magnitude). +// ============================================================================ + +TEST(Divide64, SignedUnsignedNegativeDividend) +{ + constexpr auto max64 = (std::numeric_limits::max)(); + + // -max64 / 1 = -max64 (largest-magnitude representable negative result). + EXPECT_EQ(m::math::divide(int64_t{-max64}, uint64_t{1}, int64_t{}), -max64); + + // -1000000000000 / 1000000 = -1000000. + EXPECT_EQ(m::math::divide(int64_t{-1000000000000LL}, uint64_t{1000000}, int64_t{}), + int64_t{-1000000}); + + // Truncation toward zero for a small-magnitude negative quotient. + EXPECT_EQ(m::math::divide(int64_t{-7}, uint64_t{2}, int64_t{}), int64_t{-3}); +} + +// ============================================================================ +// Divide producing exactly the most-negative result. Both the +// (signed / unsigned) -> signed and (unsigned / signed) -> signed helpers +// must yield ResultT::min() when the magnitude of the quotient equals +// |ResultT::min()|, rather than rejecting it as overflow. +// ============================================================================ + +TEST(DivideMostNegativeResult, SignedUnsignedToSignedAtMin) +{ + constexpr auto min64 = (std::numeric_limits::min)(); + + // INT64_MIN / 1 = INT64_MIN: magnitude is 2^63 == |INT64_MIN|. + EXPECT_EQ(m::math::divide(min64, uint64_t{1}, int64_t{}), min64); + + // INT64_MIN / 2 = -2^62, comfortably in range. + EXPECT_EQ(m::math::divide(min64, uint64_t{2}, int64_t{}), min64 / 2); + + // Most-negative result narrowed to int8_t: -128 / 1 = -128 is OK, + // but a magnitude one larger must overflow. + EXPECT_EQ(m::math::divide(int16_t{-128}, uint16_t{1}, int8_t{}), + (std::numeric_limits::min)()); + EXPECT_THROW(m::math::divide(int16_t{-129}, uint16_t{1}, int8_t{}), std::overflow_error); +} + +TEST(DivideMostNegativeResult, UnsignedSignedToSignedAtMin) +{ + constexpr auto min64 = (std::numeric_limits::min)(); + constexpr uint64_t abs_min = uint64_t{1} << 63; // |INT64_MIN| = 2^63 + + // 2^63 / -1 = -2^63 = INT64_MIN: hits the divisor == RightT::min-adjacent + // negative branch and produces the most-negative result. + EXPECT_EQ(m::math::divide(abs_min, int64_t{-1}, int64_t{}), min64); + + // 2^63 / INT64_MIN = -1 (divisor is RightT::min special case). + EXPECT_EQ(m::math::divide(abs_min, min64, int64_t{}), int64_t{-1}); + + // Narrowed: 128 / -1 = -128 = INT8_MIN is OK; 129 / -1 = -129 overflows. + EXPECT_EQ(m::math::divide(uint16_t{128}, int16_t{-1}, int8_t{}), + (std::numeric_limits::min)()); + EXPECT_THROW(m::math::divide(uint16_t{129}, int16_t{-1}, int8_t{}), std::overflow_error); +} + + +// ============================================================================ +// negate boundaries at full 64-bit width (unsigned -> signed and signed -> +// signed), which the existing negate tests only probe at narrower widths. +// ============================================================================ + +TEST(Negate64Boundary, UnsignedToSignedAtMin) +{ + constexpr auto min64 = (std::numeric_limits::min)(); + constexpr uint64_t abs_min = uint64_t{1} << 63; // |INT64_MIN| = 2^63 + + // Negating 2^63 yields exactly INT64_MIN. + EXPECT_EQ(m::math::negate(abs_min, int64_t{}), min64); + + // One above the boundary cannot be represented. + EXPECT_THROW(m::math::negate(abs_min + 1, int64_t{}), std::overflow_error); + + // Just inside the boundary. + EXPECT_EQ(m::math::negate(abs_min - 1, int64_t{}), -(std::numeric_limits::max)()); +} + +TEST(Negate64Boundary, SignedToSignedAtMin) +{ + constexpr auto min64 = (std::numeric_limits::min)(); + constexpr auto max64 = (std::numeric_limits::max)(); + + // INT64_MIN cannot be negated within int64_t. + EXPECT_THROW(m::math::negate(min64, int64_t{}), std::overflow_error); + + // INT64_MAX negates to -INT64_MAX. + EXPECT_EQ(m::math::negate(max64, int64_t{}), -max64); +} diff --git a/src/libraries/math/test/test_subtraction.cpp b/src/libraries/math/test/test_subtraction.cpp index 5dee98f0..4b2cc62e 100644 --- a/src/libraries/math/test/test_subtraction.cpp +++ b/src/libraries/math/test/test_subtraction.cpp @@ -92,6 +92,21 @@ TEST(SubtractionUnsignedUnsignedToSigned, NegativeResults) EXPECT_EQ(m::math::subtract(uint32_t{0}, uint32_t{100}, int32_t{}), -100); } +TEST(SubtractionUnsignedUnsignedToSigned, MostNegativeBoundary) +{ + // Regression: the two most-negative representable results must be produced, + // not rejected. 0 - 128 = -128 = int8 min; 0 - 127 = -127. + EXPECT_EQ(m::math::subtract(uint32_t{0}, uint32_t{128}, int8_t{}), int8_t{-128}); + EXPECT_EQ(m::math::subtract(uint32_t{0}, uint32_t{127}, int8_t{}), int8_t{-127}); + // Just past the boundary must still overflow. + EXPECT_THROW(m::math::subtract(uint32_t{0}, uint32_t{129}, int8_t{}), std::overflow_error); + + // Same boundary for unsigned - signed -> signed, which shares the code path. + EXPECT_EQ(m::math::subtract(uint32_t{0}, int32_t{128}, int8_t{}), int8_t{-128}); + EXPECT_EQ(m::math::subtract(uint32_t{0}, int32_t{127}, int8_t{}), int8_t{-127}); + EXPECT_THROW(m::math::subtract(uint32_t{0}, int32_t{129}, int8_t{}), std::overflow_error); +} + TEST(SubtractionUnsignedUnsignedToSigned, OverflowCases) { // Very large unsigned values might overflow signed result diff --git a/src/libraries/pil/src/buffered/registry_key_value_operations.cpp b/src/libraries/pil/src/buffered/registry_key_value_operations.cpp index 4446c357..87687e4b 100644 --- a/src/libraries/pil/src/buffered/registry_key_value_operations.cpp +++ b/src/libraries/pil/src/buffered/registry_key_value_operations.cpp @@ -331,6 +331,7 @@ namespace m::pil::impl::buffered { value_vector.resize(value_span.size()); vnv.m_reg_value_type = value_type; + vnv.m_value = std::move(value_vector); break; } diff --git a/src/libraries/pil/src/key_path.cpp b/src/libraries/pil/src/key_path.cpp index 6a58425a..13528b63 100644 --- a/src/libraries/pil/src/key_path.cpp +++ b/src/libraries/pil/src/key_path.cpp @@ -184,7 +184,7 @@ namespace m::pil key_path key_path::parent_path() const { - if (auto const i = m_value.try_find_first_of(wregistry_delimiter); i.has_value()) + if (auto const i = m_value.try_find_last_of(wregistry_delimiter); i.has_value()) return key_path{m_root_key, m_value.substr(0, i.value())}; return key_path{}; diff --git a/src/libraries/pil/src/registry_key.cpp b/src/libraries/pil/src/registry_key.cpp index 3fb6a742..81daa505 100644 --- a/src/libraries/pil/src/registry_key.cpp +++ b/src/libraries/pil/src/registry_key.cpp @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +#include #include #include #include @@ -29,7 +30,7 @@ namespace m::pil key& key::operator=(key const& other) { - m_key.reset(other.m_key.get()); + m_key = other.m_key; return *this; } @@ -248,17 +249,46 @@ namespace m::pil std::vector retval; - // Turn the value into a UTF-16 string and scan through looking for - // the embedded null characters - char16_t const* cursor = reinterpret_cast(value.m_bytes.data()); - std::size_t remaining = value.m_bytes.size() / sizeof(char16_t); + // The REG_MULTI_SZ format is a sequence of null-terminated UTF-16 + // strings, terminated by an empty string (that is, an extra trailing + // null character). + // + // The raw value lives in a std::vector, which only + // guarantees 1-byte alignment, so we cannot reinterpret the buffer in + // place as char16_t: that would be undefined behavior (and can fault) + // on platforms that require char16_t to be 2-byte aligned. Copy the + // bytes into a properly aligned std::u16string first. + auto const byte_count = value.m_bytes.size(); + + // An odd byte count cannot be a well-formed sequence of char16_t; + // reject it rather than silently truncating the trailing byte. + if ((byte_count % sizeof(char16_t)) != 0) + throw std::runtime_error( + "REG_MULTI_SZ value has an odd byte count and is not valid UTF-16"); - for (;;) + std::u16string buffer(byte_count / sizeof(char16_t), u'\0'); + std::memcpy(buffer.data(), value.m_bytes.data(), byte_count); + + std::u16string_view remaining(buffer); + + while (!remaining.empty()) { - std::ignore = cursor; - std::ignore = remaining; - // do the scanning in the future - break; + auto const null_pos = remaining.find(u'\0'); + + std::u16string_view const token = + (null_pos == std::u16string_view::npos) ? remaining : remaining.substr(0, null_pos); + + // An empty string marks the end of the sequence. + if (token.empty()) + break; + + retval.push_back(to_registry_string(token)); + + // Advance past the string and its null terminator (if present). + if (null_pos == std::u16string_view::npos) + remaining = {}; + else + remaining.remove_prefix(null_pos + 1); } return retval; @@ -405,6 +435,8 @@ namespace m::pil if (s.size() != bytes.size()) bytes.resize(s.size()); + vt = type; + break; } diff --git a/src/libraries/sstring/include/m/sstring/sstring.h b/src/libraries/sstring/include/m/sstring/sstring.h index f6f3c87d..65c65be5 100644 --- a/src/libraries/sstring/include/m/sstring/sstring.h +++ b/src/libraries/sstring/include/m/sstring/sstring.h @@ -7,6 +7,7 @@ #include #include +#include #include #include #include @@ -263,7 +264,7 @@ namespace m } constexpr int - compare(basic_sstring s) const + compare(basic_sstring const& s) const { return compare(s.view()); } @@ -281,7 +282,7 @@ namespace m } constexpr int - compare(size_type pos1, size_t count1, basic_sstring str) const + compare(size_type pos1, size_t count1, basic_sstring const& str) const { return view().compare(pos1, count1, str.view()); } @@ -294,11 +295,11 @@ namespace m } constexpr int - compare(size_type pos1, - size_type count1, - basic_sstring str, - size_type pos2, - size_type count2) const + compare(size_type pos1, + size_type count1, + basic_sstring const& str, + size_type pos2, + size_type count2) const { return view().compare(pos1, count1, str.view(), pos2, count2); } @@ -351,7 +352,8 @@ namespace m // the m_c_str ptr into the member and leave it there into the future. // - if ((m_offset_and_size.m_offset + m_offset_and_size.m_size) == v.size()) + if (m::math::add(m_offset_and_size.m_offset, m_offset_and_size.m_size, std::size_t{}) == + v.size()) { local_c_str_ptr = v.data() + m_offset_and_size.m_offset; m_c_str.store(local_c_str_ptr, std::memory_order_release); @@ -373,7 +375,15 @@ namespace m return basic_sstring({view(), other.view()}); } + // Constrain with the stringish concept (a pure type predicate) rather than a + // requires-expression over to_basic_string_view_t(): the latter returns auto and + // would force its body to be instantiated during constraint checking, turning a + // missing view_converter specialization into a hard error (and breaking the + // self-contained totally_ordered static_asserts below, since the needed + // specializations live in a higher layer). stringish already decays array operands + // (string literals) to their pointer form. template + requires(m::stringish) basic_sstring operator+(StringishT&& other) const { @@ -404,14 +414,20 @@ namespace m throw std::out_of_range("start"); } - auto const max_len = v.size() - start; + auto const max_len = m::math::subtract(v.size(), start, std::size_t{}); if (length > max_len) length = max_len; if (length == 0) return basic_sstring{}; - return basic_sstring{m_arefc, offset_and_size{.m_offset = start, .m_size = length}}; + // start is relative to this view; translate to an absolute offset + // into the shared base buffer so substrings of substrings are correct. + return basic_sstring{ + m_arefc, + offset_and_size{ + .m_offset = m::math::add(m_offset_and_size.m_offset, start, std::size_t{}), + .m_size = length}}; } basic_sstring @@ -420,7 +436,9 @@ namespace m auto const v = view(); if (count > v.size()) count = v.size(); - return basic_sstring{m_arefc, offset_and_size{.m_offset = 0, .m_size = count}}; + return basic_sstring{ + m_arefc, + offset_and_size{.m_offset = m_offset_and_size.m_offset, .m_size = count}}; } basic_sstring @@ -429,7 +447,9 @@ namespace m auto const v = view(); if (count > v.size()) count = v.size(); - auto const offset = v.size() - count; + auto const offset = m::math::add(m_offset_and_size.m_offset, + m::math::subtract(v.size(), count, std::size_t{}), + std::size_t{}); return basic_sstring{m_arefc, offset_and_size{.m_offset = offset, .m_size = count}}; } @@ -442,7 +462,8 @@ namespace m if (split_point == view_type::npos) return std::make_pair(*this, basic_sstring{}); - return std::make_pair(substr(0, split_point), substr(split_point + 1)); + return std::make_pair(substr(0, split_point), + substr(m::math::add(split_point, 1, std::size_t{}))); } std::pair @@ -454,8 +475,9 @@ namespace m if (split_point == view_type::npos) return std::make_pair(*this, basic_sstring{}); - return std::make_pair(substr(0, split_point), - substr(split_point + view_to_find.size())); + return std::make_pair( + substr(0, split_point), + substr(m::math::add(split_point, view_to_find.size(), std::size_t{}))); } std::pair @@ -467,7 +489,8 @@ namespace m if (split_point == view_type::npos) return std::make_pair(*this, basic_sstring{}); - return std::make_pair(substr(0, split_point), substr(split_point + 1)); + return std::make_pair(substr(0, split_point), + substr(m::math::add(split_point, 1, std::size_t{}))); } bool @@ -530,10 +553,13 @@ namespace m last() const { M_INTERNAL_ERROR_CHECK(view().size() > 0); - return view().data()[view().size() - 1]; + return view().data()[m::math::subtract(view().size(), 1, std::size_t{})]; } + // See the note on operator+ above: stringish is a pure type predicate, so it does + // not instantiate to_basic_string_view_t's body during constraint checking. template + requires(m::stringish) constexpr bool operator==(StringishT&& r) const noexcept { @@ -552,8 +578,7 @@ namespace m [[nodiscard]] constexpr comparison_category_type operator<=>(view_type const& r) const noexcept { - auto const v = m::to_basic_string_view_t(r); - auto const compare_result = compare(v); + auto const compare_result = compare(r); return static_cast(compare_result <=> 0); } @@ -584,10 +609,12 @@ namespace m // And the offset can't be beyond the size of the whole string // minus the size of the substring. - M_INTERNAL_ERROR_CHECK(m_offset_and_size.m_offset <= - v.size() - m_offset_and_size.m_size); + M_INTERNAL_ERROR_CHECK( + m_offset_and_size.m_offset <= + m::math::subtract(v.size(), m_offset_and_size.m_size, std::size_t{})); - if (v.size() - m_offset_and_size.m_size == m_offset_and_size.m_offset) + if (m::math::subtract(v.size(), m_offset_and_size.m_size, std::size_t{}) == + m_offset_and_size.m_offset) { // If the end of the substring lines up with the end of // the whole string, if there is a m_c_str value, it @@ -618,7 +645,7 @@ namespace m if (cstr != nullptr) return cstr; - auto up = std::make_unique(size + 1); + auto up = std::make_unique(m::math::add(size, 1, std::size_t{})); if (size != 0) std::copy_n(ptr, size, up.get()); diff --git a/src/libraries/threadpool/include/m/threadpool/work_queue.h b/src/libraries/threadpool/include/m/threadpool/work_queue.h index b7124e1a..97af4910 100644 --- a/src/libraries/threadpool/include/m/threadpool/work_queue.h +++ b/src/libraries/threadpool/include/m/threadpool/work_queue.h @@ -236,6 +236,36 @@ namespace m return do_wait_for(std::chrono::duration_cast(dur)); } + /// + /// Closes the work queue, cancelling any not-yet-started work and + /// synchronously draining any in-flight work before returning. + /// + /// Not-yet-started work items are moved to the canceled terminal state + /// (their callbacks will never run), which also unblocks any thread + /// waiting in `work_item::wait()`/`wait_for()`/`wait_until()` on such an + /// item. In-flight callbacks are allowed to run to completion before + /// `close()` returns. + /// + /// `close()` lets an owner perform the drain deterministically on + /// its own thread. If `close()` is not called, the destructor performs + /// the same drain. Calling `close()` more than once is harmless. + /// + /// `close()` does not put the queue into a permanently-closed state: + /// it neither rejects nor synchronizes against `enqueue()`. The caller + /// is responsible for ensuring that no `enqueue()` happens during or + /// after `close()` (for example, by establishing a happens-before + /// relationship that retires all producers first). The guarantee is + /// therefore conditional: provided no work is enqueued once `close()` + /// begins, no further callbacks will execute against this queue after + /// `close()` returns. Work enqueued concurrently with, or after, + /// `close()` may still run. + /// + void + close() + { + do_close(); + } + // // "wait" and "wait_until" are omitted, intentionally. This is because // waiting indefinitely for a potentially large number of work items @@ -261,6 +291,9 @@ namespace m virtual std::shared_ptr do_enqueue(std::packaged_task&& task, m::wsstring const& description) = 0; + + virtual void + do_close() = 0; }; } // namespace m diff --git a/src/libraries/threadpool/src/Platforms/linux/work_queue.cpp b/src/libraries/threadpool/src/Platforms/linux/work_queue.cpp index 41a43461..fce4b9de 100644 --- a/src/libraries/threadpool/src/Platforms/linux/work_queue.cpp +++ b/src/libraries/threadpool/src/Platforms/linux/work_queue.cpp @@ -32,4 +32,11 @@ namespace m::threadpool_impl M_NOT_IMPLEMENTED("sorry no linux work queue"); } + void + work_queue::perform_platform_teardown() noexcept + { + // Nothing was ever initialized on this platform, so there is nothing + // to drain. + } + } // namespace m::threadpool_impl diff --git a/src/libraries/threadpool/src/Platforms/linux/work_queue.h b/src/libraries/threadpool/src/Platforms/linux/work_queue.h index c24a6e9f..e6578fac 100644 --- a/src/libraries/threadpool/src/Platforms/linux/work_queue.h +++ b/src/libraries/threadpool/src/Platforms/linux/work_queue.h @@ -40,6 +40,9 @@ namespace m::threadpool_impl void perform_platform_initialization() override; + void + perform_platform_teardown() noexcept override; + void on_new_work_item(std::shared_ptr const& wi) override; }; diff --git a/src/libraries/threadpool/src/Platforms/windows/work_queue.cpp b/src/libraries/threadpool/src/Platforms/windows/work_queue.cpp index c3fdd3c2..b06ef434 100644 --- a/src/libraries/threadpool/src/Platforms/windows/work_queue.cpp +++ b/src/libraries/threadpool/src/Platforms/windows/work_queue.cpp @@ -20,6 +20,15 @@ namespace m::threadpool_impl m::threadpool_impl::work_queue_base(wqep, description) {} + work_queue::~work_queue() + { + // If the owner did not call close(), drain here. Draining is idempotent + // and, because callbacks hold no ownership of the queue, this destructor + // can only run on the owning thread, never on a threadpool callback + // thread, so the synchronous wait below cannot deadlock against itself. + perform_platform_teardown(); + } + void work_queue::on_new_work_item(std::shared_ptr const&) { @@ -30,7 +39,7 @@ namespace m::threadpool_impl work_queue::perform_platform_initialization() { auto callback_context_ptr = std::make_unique(); - callback_context_ptr->m_work_queue = weak_from_this(); + callback_context_ptr->m_work_queue = this; auto wrk = win32::threadpool::tp_work( &work_queue::static_tp_work_callback, callback_context_ptr.get(), nullptr); @@ -38,19 +47,51 @@ namespace m::threadpool_impl using std::swap; swap(wrk, m_tp_work); - // Yes this induces a cycle. Can be fixed by having an explicit - // close() protocol or by having the pointer back to the queue - // be a weak reference which is a performance problem for each - // queue entry. Solvable/contained. swap(callback_context_ptr, m_callback_context); } + void + work_queue::perform_platform_teardown() noexcept + { + // Cancel pending callbacks and wait for any in-flight callback to + // finish. After this returns no callback will touch this queue. Safe to + // call repeatedly (close() then destructor). + if (m_platform_initialized) + { + m_tp_work.wait_for_callbacks(true); + + // wait_for_callbacks(true) cancels the threadpool callbacks that + // had not yet started, so any work items still sitting in + // m_ready_queue will never run on their own. Move each of them to + // the canceled terminal state and wake their waiters so that + // work_item::wait*() cannot block forever on never-started work + // (which is what close()'s "cancel not-yet-started work" contract + // promises). No callback can race us here: once + // wait_for_callbacks(true) has returned, none are running and none + // will start. + auto l = std::unique_lock(m_mutex); + + while (!m_ready_queue.empty()) + { + auto const wi = m_ready_queue.front(); + m_ready_queue.pop_front(); + wi->cancel_if_queued(); + } + + l.unlock(); + + // Wake queue-level waiters (wait_for) now that the ready queue is + // empty and no work remains in flight. + m_cv.notify_all(); + } + } + void CALLBACK work_queue::static_tp_work_callback(PTP_CALLBACK_INSTANCE, PVOID context, PTP_WORK) noexcept { auto const cctx = reinterpret_cast(context); - cctx->m_work_queue.lock()->tp_work_callback(); + cctx->m_work_queue->tp_work_callback(); } void diff --git a/src/libraries/threadpool/src/Platforms/windows/work_queue.h b/src/libraries/threadpool/src/Platforms/windows/work_queue.h index 5a237877..2c493fd3 100644 --- a/src/libraries/threadpool/src/Platforms/windows/work_queue.h +++ b/src/libraries/threadpool/src/Platforms/windows/work_queue.h @@ -49,13 +49,11 @@ namespace m::threadpool_impl /// we shall concern ourselves with at this time). /// /// - class work_queue : - public m::threadpool_impl::work_queue_base, - public std::enable_shared_from_this + class work_queue : public m::threadpool_impl::work_queue_base { public: work_queue() = default; - ~work_queue() = default; + ~work_queue() override; work_queue(work_queue const&) = delete; work_queue(work_queue&&) noexcept = delete; work_queue(m::work_queue_execution_policy wqep, std::wstring description); @@ -72,12 +70,20 @@ namespace m::threadpool_impl void perform_platform_initialization() override; + void + perform_platform_teardown() noexcept override; + void on_new_work_item(std::shared_ptr const& wi) override; struct callback_context { - std::weak_ptr m_work_queue; + // Raw back-pointer to the owning queue. It is kept valid for the + // duration of every callback by `perform_platform_teardown()`, + // which waits for all in-flight callbacks before the queue is + // destroyed. Because the callback holds no ownership, the queue's + // destructor can never run on a threadpool callback thread. + m::threadpool_impl::work_queue* m_work_queue{}; }; static void CALLBACK diff --git a/src/libraries/threadpool/src/work_item.cpp b/src/libraries/threadpool/src/work_item.cpp index 737289cb..5e5b21a7 100644 --- a/src/libraries/threadpool/src/work_item.cpp +++ b/src/libraries/threadpool/src/work_item.cpp @@ -75,6 +75,25 @@ namespace m::work_queue_impl return false; } + bool + work_item::cancel_if_queued() + { + { + auto l = std::unique_lock(m_mutex); + + // Only a not-yet-started item can be cancelled here. If it is + // already running or terminal, leave it alone. + if (m_work_item_state != work_item_state::queued) + return false; + + m_work_item_state = work_item_state::canceled; + } + + // Wake any waiters now that the state is terminal. + m_state_cv.notify_all(); + return true; + } + uint64_t work_item::do_id() { @@ -97,7 +116,10 @@ namespace m::work_queue_impl auto l = std::unique_lock(m_mutex); if (m_work_item_state == work_item_state::canceled) + { + m_state_cv.notify_all(); return; + } M_INTERNAL_ERROR_CHECK(m_work_item_state == work_item_state::queued); @@ -122,42 +144,42 @@ namespace m::work_queue_impl m_work_item_times.m_end_time = m::clock_type::now(); m_work_item_state = work_item_state::done; } + + // Wake any waiters now that the state is terminal. The packaged_task + // future goes ready inside m_packaged_task() above, strictly before the + // state transition, so waiters must observe the state itself rather than + // the future to guarantee the item is "done" when wait() returns. + m_state_cv.notify_all(); } void work_item::do_wait() { - m_future.wait(); + auto l = std::unique_lock(m_mutex); + m_state_cv.wait(l, [this] { + return m_work_item_state == work_item_state::done || + m_work_item_state == work_item_state::canceled; + }); } bool work_item::do_wait_for(std::chrono::milliseconds const& d) { - auto const future_status = m_future.wait_for(d); - - switch (future_status) - { - default: M_UNREACHABLE_CODE(); break; - - case std::future_status::deferred: return true; - case std::future_status::ready: return true; - case std::future_status::timeout: return false; - } + auto l = std::unique_lock(m_mutex); + return m_state_cv.wait_for(l, d, [this] { + return m_work_item_state == work_item_state::done || + m_work_item_state == work_item_state::canceled; + }); } bool work_item::do_wait_until(m::time_point_type const& tp) { - auto const future_status = m_future.wait_until(tp); - - switch (future_status) - { - default: M_UNREACHABLE_CODE(); break; - - case std::future_status::deferred: return true; - case std::future_status::ready: return true; - case std::future_status::timeout: return false; - } + auto l = std::unique_lock(m_mutex); + return m_state_cv.wait_until(l, tp, [this] { + return m_work_item_state == work_item_state::done || + m_work_item_state == work_item_state::canceled; + }); } } // namespace m::work_queue_impl diff --git a/src/libraries/threadpool/src/work_item.h b/src/libraries/threadpool/src/work_item.h index 35573cfe..5815b906 100644 --- a/src/libraries/threadpool/src/work_item.h +++ b/src/libraries/threadpool/src/work_item.h @@ -5,9 +5,11 @@ #include #include +#include #include #include #include +#include #include #include #include @@ -42,6 +44,13 @@ namespace m::work_queue_impl do_work(); } + // Transition a still-queued item to the canceled terminal state and + // wake any waiters. No-op (returns false) if the item has already + // started running or already reached a terminal state. Used by queue + // teardown to release waiters blocked on work that will never start. + bool + cancel_if_queued(); + protected: utc_time_point_type do_enqueue_time() override; @@ -82,6 +91,7 @@ namespace m::work_queue_impl work_item_id_type m_id; // immutable once constructed m::wsstring m_description; // immutable once constructed std::mutex m_mutex; + std::condition_variable m_state_cv; // signaled when m_work_item_state becomes terminal work_item_times m_work_item_times; work_item_state m_work_item_state; std::packaged_task m_packaged_task; diff --git a/src/libraries/threadpool/src/work_queue_base.cpp b/src/libraries/threadpool/src/work_queue_base.cpp index 81ec718b..6d6686ff 100644 --- a/src/libraries/threadpool/src/work_queue_base.cpp +++ b/src/libraries/threadpool/src/work_queue_base.cpp @@ -71,4 +71,14 @@ namespace m::threadpool_impl return wi; } + void + work_queue_base::do_close() + { + // The drain itself is platform-specific. It cancels not-yet-started + // work and waits for in-flight callbacks; it must never run on a + // threadpool callback thread, which both `close()` and the owning + // destructor guarantee. + perform_platform_teardown(); + } + } // namespace m::threadpool_impl diff --git a/src/libraries/threadpool/src/work_queue_base.h b/src/libraries/threadpool/src/work_queue_base.h index abf0cdb3..a94f532c 100644 --- a/src/libraries/threadpool/src/work_queue_base.h +++ b/src/libraries/threadpool/src/work_queue_base.h @@ -72,9 +72,24 @@ namespace m::threadpool_impl std::shared_ptr do_enqueue(std::packaged_task&& task, m::wsstring const& description) override; + void + do_close() override; + virtual void perform_platform_initialization() = 0; + /// + /// The `perform_platform_teardown()` member function is overridden by + /// platform-specific implementations to cancel any not-yet-started work + /// and synchronously wait for any in-flight callbacks to complete. + /// + /// It must be safe to call more than once and must be called from a + /// thread that is not a threadpool callback thread (which the public + /// `close()` contract and the destructor both guarantee). + /// + virtual void + perform_platform_teardown() noexcept = 0; + /// /// The on_new_work_item() pure virtual member function is overridden by /// platform-specific implementations to dispatch to a platform-specific diff --git a/src/libraries/threadpool/test/test_work_queue.cpp b/src/libraries/threadpool/test/test_work_queue.cpp index c3eef33d..40771763 100644 --- a/src/libraries/threadpool/test/test_work_queue.cpp +++ b/src/libraries/threadpool/test/test_work_queue.cpp @@ -65,14 +65,63 @@ TEST(WorkQueue, QueueN20) q->wait_for(5s); } +// Moderately sized functional check that the queue drains a batch and sets +// every flag. Kept small (and quiet) so the normal test pass stays around +// 1-2 seconds; the heavy throughput soak lives in DISABLED_QueueNBigStress. TEST(WorkQueue, QueueNBig) +{ + auto q = m::threadpool->create_work_queue(); + constexpr std::size_t n = 10'000; + + auto work_items = + std::unique_ptr[]>(new std::shared_ptr[n]()); + // Value-initialize the atomics (trailing ()) so every flag starts at a + // well-defined 0; reading an uninitialized atomic would be UB on exactly + // the path this test is trying to detect (a work item that never ran). + auto flags_unique_ptr = std::unique_ptr[]>(new std::atomic[n]()); + auto flags = flags_unique_ptr.get(); + + auto const before_queue = m::clock_type::now(); + + for (std::size_t i = 0; i < n; i++) + { + work_items[i] = q->enqueue([p = &flags[i]] { *p = 1; }); + } + + auto const after_queue = m::clock_type::now(); + + constexpr auto d = 250ms; + + while (!q->wait_for(d)) + m::println("After {}, {} queue items still running", d, q->running()); + + auto const after_wait = m::clock_type::now(); + + q.reset(); + + // Once the queue has drained, verify the flags are all set + for (std::size_t i = 0; i < n; i++) + EXPECT_EQ(flags[i], 1); + + m::println("It took {} to queue, and then {} for the work to finish", + after_queue - before_queue, + after_wait - after_queue); +} + +// Heavy throughput stress run. Disabled by default so it is not part of the +// normal test pass (which should stay around 1-2 seconds); run explicitly with +// --gtest_also_run_disabled_tests when you want the big soak. +TEST(WorkQueue, DISABLED_QueueNBigStress) { auto q = m::threadpool->create_work_queue(); constexpr std::size_t n = 1'300'000; auto work_items = std::unique_ptr[]>(new std::shared_ptr[n]()); - auto flags_unique_ptr = std::unique_ptr[]>(new std::atomic[n]); + // Value-initialize the atomics (trailing ()) so every flag starts at a + // well-defined 0; reading an uninitialized atomic would be UB on exactly + // the path this test is trying to detect (a work item that never ran). + auto flags_unique_ptr = std::unique_ptr[]>(new std::atomic[n]()); auto flags = flags_unique_ptr.get(); auto const before_queue = m::clock_type::now(); @@ -335,4 +384,73 @@ TEST(WorkQueue, CreateWorkQueueWithDescription) auto wi = q->enqueue([] {}); q->wait_for(5s); -} \ No newline at end of file +} + +// --------------------------------------------------------------------------- +// close() teardown tests +// --------------------------------------------------------------------------- + +TEST(WorkQueue, CloseOnIdleQueueIsSafe) +{ + auto q = m::threadpool->create_work_queue(); + q->close(); +} + +TEST(WorkQueue, CloseAfterDrainIsSafe) +{ + auto q = m::threadpool->create_work_queue(); + for (int i = 0; i < 10; ++i) + static_cast(q->enqueue([] {})); + + EXPECT_TRUE(q->wait_for(5s)); + q->close(); + + EXPECT_EQ(q->running(), 0u); +} + +TEST(WorkQueue, CloseDrainsInFlightWork) +{ + std::latch started(1); + std::latch release(1); + + auto q = m::threadpool->create_work_queue(); + auto wi = q->enqueue([&]() { + started.count_down(); + release.wait(); + }); + + started.wait(); // ensure the callback is actually in flight + release.count_down(); + + q->close(); // must synchronously wait for the in-flight callback to finish + + EXPECT_EQ(q->running(), 0u); +} + +TEST(WorkQueue, CloseIsIdempotent) +{ + auto q = m::threadpool->create_work_queue(); + static_cast(q->enqueue([] {})); + EXPECT_TRUE(q->wait_for(5s)); + + q->close(); + q->close(); +} + +TEST(WorkQueue, DestroyWithoutCloseDrains) +{ + std::latch started(1); + std::latch release(1); + + auto q = m::threadpool->create_work_queue(); + static_cast(q->enqueue([&]() { + started.count_down(); + release.wait(); + })); + + started.wait(); + release.count_down(); + + // No explicit close(): the destructor must synchronously drain. + q.reset(); +} diff --git a/src/libraries/tracing/include/m/tracing/message_queue.h b/src/libraries/tracing/include/m/tracing/message_queue.h index 7b900e34..545c4794 100644 --- a/src/libraries/tracing/include/m/tracing/message_queue.h +++ b/src/libraries/tracing/include/m/tracing/message_queue.h @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -83,8 +84,11 @@ namespace m void wake_waiters() noexcept; + std::uint64_t + wake_generation() const noexcept; + void - wait() noexcept; + wait(std::uint64_t last_wake_generation) noexcept; void enqueue(m::not_null msg) noexcept; @@ -106,6 +110,17 @@ namespace m mutable std::mutex m_mutex; std::condition_variable m_cv; std::queue m_queue; + + // Monotonic wake generation. wake_waiters() advances this under + // m_mutex and broadcasts; each waiter passes the generation it + // sampled (via wake_generation()) into wait() and blocks only while + // the generation is unchanged. Because every waiter compares against + // its own baseline rather than consuming a single shared flag, a + // broadcast wake is observed independently by all waiters (no lost + // wakeups with multiple waiters), and a wake issued before a thread + // reaches wait() is still seen because that thread sampled an older + // generation. This closes the lost-wakeup race during sink teardown. + std::uint64_t m_wake_generation = 0; }; } // namespace tracing } // namespace m diff --git a/src/libraries/tracing/include/m/tracing/ostream_sink.h b/src/libraries/tracing/include/m/tracing/ostream_sink.h index 3c00f782..542a134d 100644 --- a/src/libraries/tracing/include/m/tracing/ostream_sink.h +++ b/src/libraries/tracing/include/m/tracing/ostream_sink.h @@ -221,6 +221,13 @@ namespace m { while (!m_stop.load(std::memory_order_acquire)) { + // Sample the wake generation before draining and before + // checking the termination flags. Passing this baseline to + // wait() means any wake_waiters() that races in after this + // point advances the generation and makes wait() return + // immediately rather than block, closing the teardown race. + auto const wake_gen = m_message_queue.wake_generation(); + while (!m_stop.load(std::memory_order_acquire) && !m_message_queue.empty()) { auto env = m_message_queue.dequeue(); @@ -232,7 +239,7 @@ namespace m m_stop.load(std::memory_order_acquire)) break; - m_message_queue.wait(); + m_message_queue.wait(wake_gen); } } }; diff --git a/src/libraries/tracing/src/envelope.cpp b/src/libraries/tracing/src/envelope.cpp index cdc41077..a252a6bf 100644 --- a/src/libraries/tracing/src/envelope.cpp +++ b/src/libraries/tracing/src/envelope.cpp @@ -15,20 +15,27 @@ namespace m::tracing {} envelope::envelope(envelope&& other) noexcept: - m_message_source{other.m_message_source}, m_imessage{} - { - using std::swap; - - swap(m_imessage, other.m_imessage); - } + m_message_source{std::exchange(other.m_message_source, nullptr)}, + m_imessage{std::exchange(other.m_imessage, nullptr)} + {} void envelope::operator=(envelope&& other) noexcept { - using std::swap; + if (this == &other) + return; - swap(m_imessage, other.m_imessage); - swap(m_message_source, other.m_message_source); + // Release any message we currently hold back to its source before taking + // ownership of other's message. A plain swap would move our live message + // into other, whose destructor only nulls the pointer (it does not return + // the slot), permanently leaking that message buffer from the pool. + return_to_sender(); + + // Take ownership of other's message and source, leaving the moved-from + // envelope in a consistent empty state (both pointers null) so it does + // not retain a dangling source. + m_message_source = std::exchange(other.m_message_source, nullptr); + m_imessage = std::exchange(other.m_imessage, nullptr); } void diff --git a/src/libraries/tracing/src/message_queue.cpp b/src/libraries/tracing/src/message_queue.cpp index c9330825..b0d888f1 100644 --- a/src/libraries/tracing/src/message_queue.cpp +++ b/src/libraries/tracing/src/message_queue.cpp @@ -102,18 +102,35 @@ namespace m::tracing return frame.succeeded(std::nullopt); } + std::uint64_t + message_queue::wake_generation() const noexcept + { + tr_frame frame(__FUNCTION__, this); + auto l = std::unique_lock(m_mutex); + return frame.succeeded(m_wake_generation); + } + void - message_queue::wait() noexcept + message_queue::wait(std::uint64_t last_wake_generation) noexcept { tr_frame frame(__FUNCTION__, this); auto l = std::unique_lock(m_mutex); - if (m_queue.empty()) + // Block only while the queue is empty and no wake has been broadcast + // since the caller sampled wake_generation(). Comparing against the + // caller-supplied baseline (rather than consuming a single shared flag) + // lets every waiter observe the same broadcast independently, so + // concurrent waiters are not subject to lost wakeups, and a wake issued + // before this call is still seen because last_wake_generation is older. + if (m_queue.empty() && m_wake_generation == last_wake_generation) { frame.write(L"Queue is empty, waiting"); - m_cv.wait(l); + m_cv.wait(l, [this, last_wake_generation] { + return !m_queue.empty() || m_wake_generation != last_wake_generation; + }); frame.write(L"Woke from wait, queue now has {} entries", m_queue.size()); } + frame.succeeded(); } @@ -121,8 +138,13 @@ namespace m::tracing message_queue::wake_waiters() noexcept { tr_frame frame(__FUNCTION__, this); - // In some world, we might see if we need to wake anyone - // but in fact, we can just tell the cv to wake anyone. + // Advance the wake generation under the lock so a wake issued before a + // thread reaches wait() is observed (that waiter sampled an older + // generation) rather than lost, then broadcast to every waiter. + { + auto l = std::unique_lock(m_mutex); + ++m_wake_generation; + } m_cv.notify_all(); frame.succeeded(); } diff --git a/src/libraries/tracing/src/monitor_class_impl.cpp b/src/libraries/tracing/src/monitor_class_impl.cpp index ccf4db44..c13b7db5 100644 --- a/src/libraries/tracing/src/monitor_class_impl.cpp +++ b/src/libraries/tracing/src/monitor_class_impl.cpp @@ -4,6 +4,7 @@ #include #include +#include #include #include #include @@ -23,26 +24,39 @@ namespace m::tracing_impl m_pool = std::make_shared(); m::dbg_format("Constructing monitor at {}", reinterpret_cast(this)); - constexpr std::size_t raw_message_count = 64; // It's tricky to construct the messages since each takes a pool. We have to allocate // the storage, then construct them. - - struct message_array_type - { - alignas(m::tracing::message) - std::array m_data; + m_message_storage = std::make_unique(); + m_raw_messages = reinterpret_cast(m_message_storage.get()); + + // Each message must be placement-constructed from the pool, so the + // standard uninitialized_* algorithms (which require a default or copy + // constructor) don't apply. Guard the loop with a unique_ptr-based RAII + // rollback (the same deleter-as-cleanup idiom used by mmake_arefc_ex): + // if a message constructor throws, the deleter destroys the messages + // already constructed (in reverse order) during unwinding, so the + // backing storage is never freed with live objects still in it. On + // success we dismiss the guard with release(). + std::size_t constructed = 0; + + auto const rollback_deleter = [&constructed](m::tracing::message* base) noexcept { + for (std::size_t i = constructed; i-- > 0;) + std::destroy_at(&base[i]); }; - auto p = new message_array_type; - - m_raw_messages = reinterpret_cast(p); + std::unique_ptr rollback(m_raw_messages, + rollback_deleter); for (std::size_t i = 0; i < raw_message_count; i++) { ::new (&m_raw_messages[i]) m::tracing::message(m_pool); + constructed = i + 1; } + rollback.release(); // all messages constructed; dismiss the rollback + + // enqueue() is noexcept, so this loop cannot throw and needs no rollback. for (std::size_t i = 0; i < raw_message_count; i++) { m::tracing::envelope item(m::not_null(this), &m_raw_messages[i]); @@ -54,6 +68,27 @@ namespace m::tracing_impl { for (auto&& s: m_sink_shims) s->close(m::tracing::close_flush_option::normal); + + // Every message slot must be back in the queue by the time we tear down. + // Dispatch is synchronous, so a slot is only ever checked out for the + // duration of a single log call; if any are missing here a message is still + // in flight (held by a live envelope) and destroying the backing storage + // would leave that envelope dangling. Fail fast rather than corrupt the heap. + M_INTERNAL_ERROR_CHECK(m_message_queue.size() == raw_message_count); + + // The messages were placement-constructed, so run their destructors + // explicitly (in reverse construction order) while m_pool is still alive; + // each returns its pooled buffer to the pool. Then free the byte block + // through its real type via the unique_ptr. + if (m_raw_messages != nullptr) + { + for (std::size_t i = raw_message_count; i-- > 0;) + std::destroy_at(&m_raw_messages[i]); + + m_raw_messages = nullptr; + } + + m_message_storage.reset(); } m::not_null diff --git a/src/libraries/tracing/src/monitor_class_impl.h b/src/libraries/tracing/src/monitor_class_impl.h index fc80165d..0a8eae58 100644 --- a/src/libraries/tracing/src/monitor_class_impl.h +++ b/src/libraries/tracing/src/monitor_class_impl.h @@ -29,6 +29,7 @@ #include #include #include +#include #include #include #include @@ -141,14 +142,28 @@ namespace m::tracing_impl using channel_sink_shim_map_type = std::multimap, std::less<>>; + // Number of preallocated message slots backing m_message_queue. + static constexpr std::size_t raw_message_count = 64; + + // Raw storage for the message slots. The messages take a pool argument so + // they cannot be default-constructed as an array; instead we allocate this + // correctly-aligned byte block and placement-new each message into it. Keeping + // the block in a typed unique_ptr lets the destructor free it through the same + // type it was allocated with rather than through the message* alias. + struct message_array_type + { + alignas(m::tracing::message) + std::array m_data; + }; + std::atomic m_topology_version; std::mutex m_mutex; channel_map_type m_channels; channel_sink_shim_map_type m_channel_sink_shims; std::vector> m_sink_shims; m::tracing::message_queue m_message_queue; + std::unique_ptr m_message_storage; m::tracing::message* m_raw_messages{}; - // std::unique_ptr m_raw_messages; bool m_closed_sinks; std::shared_ptr m_pool; diff --git a/src/libraries/tracing/test/CMakeLists.txt b/src/libraries/tracing/test/CMakeLists.txt index ad7fb408..5544b75a 100644 --- a/src/libraries/tracing/test/CMakeLists.txt +++ b/src/libraries/tracing/test/CMakeLists.txt @@ -7,6 +7,7 @@ set(TEST_EXE_NAME test_tracing) add_executable(${TEST_EXE_NAME} exercise_tracing.cpp test_fast_sink.cpp + test_monitor_teardown.cpp test_slow_sink.cpp ) diff --git a/src/libraries/tracing/test/test_monitor_teardown.cpp b/src/libraries/tracing/test/test_monitor_teardown.cpp new file mode 100644 index 00000000..6a6e75b3 --- /dev/null +++ b/src/libraries/tracing/test/test_monitor_teardown.cpp @@ -0,0 +1,248 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +// +// Teardown harness for the tracing monitor. +// +// The production monitor is a process-lifetime singleton that is normally only +// destroyed by CRT static destructors at process exit / DLL_PROCESS_DETACH, +// which is awkward to drive from a unit test. These tests instead build +// standalone monitors through the public make_monitor_class() factory so the +// exact same construction/teardown path can be exercised on demand. +// +// What we are demonstrating: +// 1. A create / use / destroy cycle leaks nothing (CRT debug heap diff == 0). +// 2. After a multithreaded burst of logging is fully quiesced (threads joined), +// teardown still frees everything cleanly. +// 3. The destructor's "all message slots returned" tripwire actually fires +// when a sink retains a message past teardown. +// + +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(_WIN32) && defined(_DEBUG) +#define M_TRACING_CRT_LEAK_CHECK 1 +#include +#else +#define M_TRACING_CRT_LEAK_CHECK 0 +#endif + +using namespace m::string_view_literals; + +namespace +{ + // A minimal sink that processes (and discards) every message without ever + // retaining ownership. This is the well-behaved case: every message slot is + // returned to the queue as soon as the log call unwinds. + class counting_sink : public m::tracing::sink + { + public: + explicit counting_sink(m::not_null monitor): + sink(L"counting_sink"_sl, monitor) + {} + + ~counting_sink() override = default; + + std::uint64_t + count() const + { + return m_count.load(std::memory_order_relaxed); + } + + protected: + m::tracing::on_message_disposition + on_message(m::tracing::may_forward_message_option, m::tracing::envelope&) override + { + m_count.fetch_add(1, std::memory_order_relaxed); + return m::tracing::on_message_disposition::message_processed; + } + + bool + could_forward_message(m::tracing::envelope const&) override + { + return false; + } + + void + close(m::tracing::close_flush_option) noexcept override + {} + + private: + std::atomic m_count{0}; + }; + + // A deliberately misbehaving sink that steals the envelope (taking ownership + // of the message slot) and reports the message as forwarded, so the message + // allocator does not return the slot to the queue. This models an async + // forwarding sink that drops a message on the floor, which must trip the + // monitor's teardown invariant. + class retaining_sink : public m::tracing::sink + { + public: + explicit retaining_sink(m::not_null monitor): + sink(L"retaining_sink"_sl, monitor) + {} + + ~retaining_sink() override = default; + + protected: + m::tracing::on_message_disposition + on_message(m::tracing::may_forward_message_option, m::tracing::envelope& env) override + { + auto l = std::unique_lock(m_mutex); + m_held.emplace_back(std::move(env)); + return m::tracing::on_message_disposition::message_forwarded; + } + + bool + could_forward_message(m::tracing::envelope const&) override + { + return true; + } + + void + close(m::tracing::close_flush_option) noexcept override + {} + + private: + std::mutex m_mutex; + std::vector m_held; + }; + + // One full create / register / log / destroy cycle against a standalone + // monitor. Everything is scoped so it is fully torn down on return. + void + exercise_monitor_lifecycle(int message_count) + { + auto monitor = m::tracing::make_monitor_class(); + + auto snk = std::make_shared(monitor.get()); + auto reg = monitor->register_sink(m::tracing::diagnostic_channel_name, snk); + + auto src = monitor->make_source(m::tracing::event_kind::verbose); + + for (int i = 0; i < message_count; i++) + src->wlog(m::tracing::event_kind::information, L"teardown harness message {}", i); + } +} // namespace + +// Demonstrates that a single-threaded create/use/destroy cycle leaks nothing. +TEST(TracingMonitorTeardown, SingleThreadCycleLeaksNothing) +{ + // Warm up once so any one-time/global allocations (lazy statics, TLS, first + // pool growth) happen before we take the baseline checkpoint. + exercise_monitor_lifecycle(8); + +#if M_TRACING_CRT_LEAK_CHECK + _CrtMemState before{}; + _CrtMemState after{}; + _CrtMemState diff{}; + + _CrtMemCheckpoint(&before); + exercise_monitor_lifecycle(32); + _CrtMemCheckpoint(&after); + + // Capture the comparison without any intervening allocations, then assert. + int const leaked = _CrtMemDifference(&diff, &before, &after); + EXPECT_EQ(0, leaked) << "monitor teardown leaked heap blocks"; +#else + // Without the CRT debug heap we can still exercise the path; a leak or + // double-free would surface under an external tool (e.g. ASan). + exercise_monitor_lifecycle(32); +#endif +} + +// Demonstrates that after a concurrent burst of logging is fully quiesced +// (all worker threads joined), the monitor tears down cleanly with no leak and +// without tripping the teardown invariant. +TEST(TracingMonitorTeardown, QuiesceThenTeardownMultiThread) +{ + auto run_burst = [] { + auto monitor = m::tracing::make_monitor_class(); + + auto snk = std::make_shared(monitor.get()); + auto reg = monitor->register_sink(m::tracing::diagnostic_channel_name, snk); + + constexpr int thread_count = 8; + constexpr int messages_per_thread = 200; + std::atomic go{false}; + + std::vector workers; + workers.reserve(thread_count); + + for (int t = 0; t < thread_count; t++) + { + workers.emplace_back([&monitor, &go] { + while (!go.load(std::memory_order_acquire)) + std::this_thread::yield(); + + auto src = monitor->make_source(m::tracing::event_kind::verbose); + for (int i = 0; i < messages_per_thread; i++) + src->wlog(m::tracing::event_kind::information, L"burst {}", i); + }); + } + + go.store(true, std::memory_order_release); + + for (auto&& w: workers) + w.join(); + + // All dispatch is now quiesced; monitor is destroyed on return. + }; + + // Warm up, then measure. + run_burst(); + +#if M_TRACING_CRT_LEAK_CHECK + _CrtMemState before{}; + _CrtMemState after{}; + _CrtMemState diff{}; + + _CrtMemCheckpoint(&before); + run_burst(); + _CrtMemCheckpoint(&after); + + int const leaked = _CrtMemDifference(&diff, &before, &after); + EXPECT_EQ(0, leaked) << "multithreaded monitor teardown leaked heap blocks"; +#else + run_burst(); +#endif +} + +// Demonstrates that the teardown invariant fires when a message slot is still +// checked out (retained by a sink) at the time the monitor is destroyed. +TEST(TracingMonitorTeardownDeathTest, RetainedMessageTripsInvariant) +{ + EXPECT_DEATH( + { + auto monitor = m::tracing::make_monitor_class(); + + auto snk = std::make_shared(monitor.get()); + auto reg = monitor->register_sink(m::tracing::diagnostic_channel_name, snk); + + auto src = monitor->make_source(m::tracing::event_kind::verbose); + + // The retaining sink steals this message's slot and never returns it, + // so destroying the monitor must trip the "all slots returned" check. + src->wlog(m::tracing::event_kind::information, L"steal this slot"); + }, + ".*"); +} diff --git a/src/libraries/utf/include/m/utf/decode.h b/src/libraries/utf/include/m/utf/decode.h index f77d95ef..6ebf9114 100644 --- a/src/libraries/utf/include/m/utf/decode.h +++ b/src/libraries/utf/include/m/utf/decode.h @@ -169,6 +169,11 @@ namespace m // Check for non-shortest-length encoding if (ch < 0x00000800) throw utf_invalid_encoding_error("Non-shortest Utf-8 encoding"); + + // Reject UTF-16 surrogate code points (U+D800..U+DFFF); they are + // not valid Unicode scalar values and must not appear in UTF-8. + if ((ch >= 0xd800) && (ch <= 0xdfff)) + throw utf_invalid_encoding_error("surrogate code point in Utf-8"); } else if ((b1 & std::byte{0xf8}) == std::byte{0xf0}) { @@ -392,6 +397,14 @@ namespace m return iter_decode_result{}; // throw utf_invalid_encoding_error("Non-shortest Utf-8 encoding"); } + + // Reject UTF-16 surrogate code points (U+D800..U+DFFF); they are + // not valid Unicode scalar values and must not appear in UTF-8. + if ((ch >= 0xd800) && (ch <= 0xdfff)) + { + ec = std::make_error_code(std::errc::illegal_byte_sequence); + return iter_decode_result{}; + } } else if ((b1 & std::byte{0xf8}) == std::byte{0xf0}) { @@ -659,17 +672,17 @@ namespace m else if (ch1 <= 0xdbff) { if (first == last) - throw std::runtime_error("utf-16 sequence truncated"); + throw utf_sequence_truncated_error("utf-16 sequence truncated"); auto const ch2 = details::to_utf16le(*first++); if ((ch2 < 0xdc00) || (ch2 > 0xdfff)) - throw std::runtime_error("utf-16 invalid surrogate pair"); + throw utf_invalid_encoding_error("utf-16 invalid surrogate pair"); ch = ((((ch1 - 0xd800) * 1024) + (ch2 - 0xdc00)) + 0x10000); } else if (ch1 <= 0xdfff) - throw std::runtime_error("invalid UTF-16 encoding"); + throw utf_invalid_encoding_error("invalid UTF-16 encoding"); else ch = ch1; @@ -733,7 +746,11 @@ namespace m auto const ch = details::to_utf32le(*first++); if (ch > 0x10ffff) - throw std::runtime_error("invalid UTF-32 character"); + throw utf_invalid_encoding_error("invalid UTF-32 character"); + + // Surrogate code points (U+D800..U+DFFF) are not valid Unicode scalar values. + if ((ch >= 0xd800) && (ch <= 0xdfff)) + throw utf_invalid_encoding_error("surrogate code point in UTF-32"); return iter_decode_result{.it = first, .ch = ch}; } @@ -754,6 +771,13 @@ namespace m // throw std::runtime_error("invalid UTF-32 character"); } + // Surrogate code points (U+D800..U+DFFF) are not valid Unicode scalar values. + if ((ch >= 0xd800) && (ch <= 0xdfff)) + { + ec = std::make_error_code(std::errc::illegal_byte_sequence); + return iter_decode_result{}; + } + return iter_decode_result{.it = first, .ch = ch}; } diff --git a/src/libraries/utf/include/m/utf/encode.h b/src/libraries/utf/include/m/utf/encode.h index e6419ea2..c86dcb84 100644 --- a/src/libraries/utf/include/m/utf/encode.h +++ b/src/libraries/utf/include/m/utf/encode.h @@ -10,6 +10,8 @@ #include #include +#include + namespace m { namespace utf @@ -21,8 +23,8 @@ namespace m { using byte_t = OutCharT; - if ((ch >= 0x110000) || ((ch >= 0xdc00) && (ch <= 0xdfff))) - throw std::runtime_error("invalid character"); + if ((ch >= 0x110000) || ((ch >= 0xd800) && (ch <= 0xdfff))) + throw utf_invalid_encoding_error("invalid character"); if (ch < 0x00000080) { @@ -74,7 +76,7 @@ namespace m { using byte_t = OutCharT; - if ((ch >= 0x110000) || ((ch >= 0xdc00) && (ch <= 0xdfff))) + if ((ch >= 0x110000) || ((ch >= 0xd800) && (ch <= 0xdfff))) { ec = std::make_error_code(std::errc::illegal_byte_sequence); return it; @@ -126,8 +128,8 @@ namespace m constexpr std::size_t compute_encoded_utf8_size(char32_t ch) { - if ((ch >= 0x0011'0000) || ((ch >= 0xdc00) && (ch <= 0xdfff))) - throw std::runtime_error("invalid character"); + if ((ch >= 0x0011'0000) || ((ch >= 0xd800) && (ch <= 0xdfff))) + throw utf_invalid_encoding_error("invalid character"); if (ch < 0x0000'0080) return 1; @@ -176,8 +178,8 @@ namespace m constexpr OutIterT encode_utf16le(char32_t ch, OutIterT it) { - if ((ch >= 0x110000) || ((ch >= 0xdc00) && (ch <= 0xdfff))) - throw std::runtime_error("invalid character"); + if ((ch >= 0x110000) || ((ch >= 0xd800) && (ch <= 0xdfff))) + throw utf_invalid_encoding_error("invalid character"); if (ch < 0x10000) { @@ -199,7 +201,7 @@ namespace m constexpr OutIterT encode_utf16le(char32_t ch, OutIterT it, std::error_code& ec) { - if ((ch >= 0x110000) || ((ch >= 0xdc00) && (ch <= 0xdfff))) + if ((ch >= 0x110000) || ((ch >= 0xd800) && (ch <= 0xdfff))) { ec = std::make_error_code(std::errc::illegal_byte_sequence); return it; @@ -226,8 +228,8 @@ namespace m { using word_t = OutCharT; - if ((ch >= 0x110000) || ((ch >= 0xdc00) && (ch <= 0xdfff))) - throw std::runtime_error("invalid character"); + if ((ch >= 0x110000) || ((ch >= 0xd800) && (ch <= 0xdfff))) + throw utf_invalid_encoding_error("invalid character"); if (ch < 0x10000) { @@ -245,8 +247,8 @@ namespace m constexpr std::size_t compute_encoded_utf16_count(char32_t ch) { - if ((ch >= 0x0011'0000) || ((ch >= 0xdc00) && (ch <= 0xdfff))) - throw std::runtime_error("invalid character"); + if ((ch >= 0x0011'0000) || ((ch >= 0xd800) && (ch <= 0xdfff))) + throw utf_invalid_encoding_error("invalid character"); if (ch < 0x0001'0000) return 1; @@ -257,7 +259,7 @@ namespace m constexpr std::size_t compute_encoded_utf16_count(char32_t ch, std::error_code& ec) { - if ((ch >= 0x0011'0000) || ((ch >= 0xdc00) && (ch <= 0xdfff))) + if ((ch >= 0x0011'0000) || ((ch >= 0xd800) && (ch <= 0xdfff))) { ec = std::make_error_code(std::errc::illegal_byte_sequence); return 0; @@ -312,8 +314,8 @@ namespace m constexpr OutIterT encode_utf16be(char32_t ch, OutIterT it) { - if ((ch >= 0x110000) || ((ch >= 0xdc00) && (ch <= 0xdfff))) - throw std::runtime_error("invalid character"); + if ((ch >= 0x110000) || ((ch >= 0xd800) && (ch <= 0xdfff))) + throw utf_invalid_encoding_error("invalid character"); if (ch < 0x10000) { @@ -334,7 +336,7 @@ namespace m constexpr OutIterT encode_utf16be(char32_t ch, OutIterT it, std::error_code& ec) { - if ((ch >= 0x110000) || ((ch >= 0xdc00) && (ch <= 0xdfff))) + if ((ch >= 0x110000) || ((ch >= 0xd800) && (ch <= 0xdfff))) { ec = std::make_error_code(std::errc::illegal_byte_sequence); return it; @@ -384,8 +386,8 @@ namespace m constexpr OutIterT encode_utf32le(char32_t ch, OutIterT it) { - if ((ch >= 0x110000) || ((ch >= 0xdc00) && (ch <= 0xdfff))) - throw std::runtime_error("invalid character"); + if ((ch >= 0x110000) || ((ch >= 0xd800) && (ch <= 0xdfff))) + throw utf_invalid_encoding_error("invalid character"); *it++ = OutValueT{static_cast((ch >> 0) & 0xff)}; *it++ = OutValueT{static_cast((ch >> 8) & 0xff)}; @@ -401,7 +403,7 @@ namespace m constexpr OutIterT encode_utf32le(char32_t ch, OutIterT it, std::error_code& ec) { - if ((ch >= 0x110000) || ((ch >= 0xdc00) && (ch <= 0xdfff))) + if ((ch >= 0x110000) || ((ch >= 0xd800) && (ch <= 0xdfff))) { ec = std::make_error_code(std::errc::illegal_byte_sequence); return it; @@ -422,8 +424,8 @@ namespace m constexpr OutIterT encode_utf32(char32_t ch, OutIterT it) { - if ((ch >= 0x110000) || ((ch >= 0xdc00) && (ch <= 0xdfff))) - throw std::runtime_error("invalid character"); + if ((ch >= 0x110000) || ((ch >= 0xd800) && (ch <= 0xdfff))) + throw utf_invalid_encoding_error("invalid character"); *it++ = static_cast(ch); @@ -436,7 +438,7 @@ namespace m constexpr OutIterT encode_utf32(char32_t ch, OutIterT it, std::error_code& ec) { - if ((ch >= 0x110000) || ((ch >= 0xdc00) && (ch <= 0xdfff))) + if ((ch >= 0x110000) || ((ch >= 0xd800) && (ch <= 0xdfff))) { ec = std::make_error_code(std::errc::illegal_byte_sequence); return it; @@ -451,8 +453,8 @@ namespace m constexpr std::size_t compute_encoded_utf32_bytes(char32_t ch) { - if ((ch >= 0x0011'0000) || ((ch >= 0xdc00) && (ch <= 0xdfff))) - throw std::runtime_error("invalid character"); + if ((ch >= 0x0011'0000) || ((ch >= 0xd800) && (ch <= 0xdfff))) + throw utf_invalid_encoding_error("invalid character"); return 4; } @@ -460,7 +462,7 @@ namespace m constexpr std::size_t compute_encoded_utf32_bytes(char32_t ch, std::error_code& ec) { - if ((ch >= 0x0011'0000) || ((ch >= 0xdc00) && (ch <= 0xdfff))) + if ((ch >= 0x0011'0000) || ((ch >= 0xd800) && (ch <= 0xdfff))) { ec = std::make_error_code(std::errc::illegal_byte_sequence); return 0; @@ -473,8 +475,8 @@ namespace m constexpr std::size_t compute_encoded_utf32_count(char32_t ch) { - if ((ch >= 0x0011'0000) || ((ch >= 0xdc00) && (ch <= 0xdfff))) - throw std::runtime_error("invalid character"); + if ((ch >= 0x0011'0000) || ((ch >= 0xd800) && (ch <= 0xdfff))) + throw utf_invalid_encoding_error("invalid character"); return 1; } @@ -482,7 +484,7 @@ namespace m constexpr std::size_t compute_encoded_utf32_count(char32_t ch, std::error_code& ec) { - if ((ch >= 0x0011'0000) || ((ch >= 0xdc00) && (ch <= 0xdfff))) + if ((ch >= 0x0011'0000) || ((ch >= 0xd800) && (ch <= 0xdfff))) { ec = std::make_error_code(std::errc::illegal_byte_sequence); return 0; diff --git a/src/libraries/utf/src/decode_utf8.cpp b/src/libraries/utf/src/decode_utf8.cpp index 1d0a5470..69aca859 100644 --- a/src/libraries/utf/src/decode_utf8.cpp +++ b/src/libraries/utf/src/decode_utf8.cpp @@ -81,6 +81,14 @@ namespace m rv.m_char = k_invalid_character; return rv; } + + // Reject UTF-16 surrogate code points (U+D800..U+DFFF); they are + // not valid Unicode scalar values and must not appear in UTF-8. + if ((rv.m_char >= 0xd800) && (rv.m_char <= 0xdfff)) + { + rv.m_char = k_invalid_character; + return rv; + } } else if ((b1 & std::byte{0xf8}) == std::byte{0xf0}) {