diff --git a/cpp/include/raft/linalg/strided_reduction.cuh b/cpp/include/raft/linalg/strided_reduction.cuh index 9480eb9fa0..a7a8594f64 100644 --- a/cpp/include/raft/linalg/strided_reduction.cuh +++ b/cpp/include/raft/linalg/strided_reduction.cuh @@ -69,7 +69,7 @@ void stridedReduction(OutType* dots, // other cases, because coalescedReduction supports arbitrary types. if constexpr (std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v || - std::is_same_v) { + std::is_same_v || std::is_same_v) { detail::stridedReduction( dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); } else { diff --git a/cpp/include/raft/stats/detail/mean.cuh b/cpp/include/raft/stats/detail/mean.cuh index 05e17d6d8a..5af1733525 100644 --- a/cpp/include/raft/stats/detail/mean.cuh +++ b/cpp/include/raft/stats/detail/mean.cuh @@ -10,24 +10,33 @@ #include #include +#include + namespace raft { namespace stats { namespace detail { -template -void mean(Type* mu, const Type* data, IdxType D, IdxType N, cudaStream_t stream) +template +void mean(OutType* mu, const InType* data, IdxType D, IdxType N, cudaStream_t stream) { - Type ratio = Type(1) / Type(N); - raft::linalg::reduce(mu, - data, - D, - N, - Type(0), - stream, - false, - raft::identity_op(), - raft::add_op(), - raft::mul_const_op(ratio)); + OutType ratio = OutType(1) / OutType(N); + auto main_op = [=]() { + if constexpr (std::is_same_v) { + return raft::identity_op(); + } else { + return raft::cast_op(); + } + }(); + raft::linalg::reduce(mu, + data, + D, + N, + OutType(0), + stream, + false, + main_op, + raft::add_op(), + raft::mul_const_op(ratio)); } template diff --git a/cpp/include/raft/stats/mean.cuh b/cpp/include/raft/stats/mean.cuh index 87310a0bb2..9dfc9b2d36 100644 --- a/cpp/include/raft/stats/mean.cuh +++ b/cpp/include/raft/stats/mean.cuh @@ -14,6 +14,8 @@ #include #include +#include + namespace raft { namespace stats { @@ -68,21 +70,22 @@ template */ /** - * @brief Compute mean of the input matrix + * @brief Compute mean of the input matrix with different input and output data types. * * Mean operation is assumed to be performed on a given column. * - * @tparam value_t the data type + * @tparam in_value_t the input data type + * @tparam out_value_t the output data type * @tparam idx_t index type * @tparam layout_t Layout type of the input matrix. * @param[in] handle the raft handle - * @param[in] data: the input matrix - * @param[out] mu: the output mean vector + * @param[in] data the input matrix + * @param[out] mu the output mean vector */ -template +template void mean(raft::resources const& handle, - raft::device_matrix_view data, - raft::device_vector_view mu) + raft::device_matrix_view data, + raft::device_vector_view mu) { static_assert( std::is_same_v || std::is_same_v, @@ -97,6 +100,26 @@ void mean(raft::resources const& handle, resource::get_cuda_stream(handle)); } +/** + * @brief Compute mean of the input matrix + * + * Mean operation is assumed to be performed on a given column. + * + * @tparam value_t the data type + * @tparam idx_t index type + * @tparam layout_t Layout type of the input matrix. + * @param[in] handle the raft handle + * @param[in] data: the input matrix + * @param[out] mu: the output mean vector + */ +template +void mean(raft::resources const& handle, + raft::device_matrix_view data, + raft::device_vector_view mu) +{ + mean(handle, data, mu); +} + /** * @brief Compute mean of the input matrix * diff --git a/cpp/tests/stats/mean.cu b/cpp/tests/stats/mean.cu index 4323d2f0bd..a21305b0dd 100644 --- a/cpp/tests/stats/mean.cu +++ b/cpp/tests/stats/mean.cu @@ -1,49 +1,74 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2018-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2018-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #include "../test_utils.cuh" #include +#include #include #include #include #include +#include + #include #include #include +#include +#include +#include + namespace raft { namespace stats { template +float toFloat(T value) +{ + return static_cast(value); +} + +template <> +inline float toFloat(half value) +{ + return __half2float(value); +} + +struct float_to_half_op { + __device__ half operator()(float x) const { return __float2half(x); } +}; + +template struct MeanInputs { - T tolerance, mean; + OutputT tolerance; + InputT mean; int rows, cols; bool rowMajor; unsigned long long int seed; - T stddev = (T)1.0; + InputT stddev = (InputT)1.0; }; -template -::std::ostream& operator<<(::std::ostream& os, const MeanInputs& dims) +template +::std::ostream& operator<<(::std::ostream& os, const MeanInputs& dims) { - return os << "{ " << dims.tolerance << ", " << dims.rows << ", " << dims.cols << ", " - << ", " << dims.rowMajor << ", " << dims.stddev << "}" << std::endl; + return os << "{ tol=" << toFloat(dims.tolerance) << ", mean=" << toFloat(dims.mean) + << ", rows=" << dims.rows << ", cols=" << dims.cols << ", rowMajor=" << dims.rowMajor + << ", stddev=" << toFloat(dims.stddev) << "}" << std::endl; } -template -class MeanTest : public ::testing::TestWithParam> { +template +class MeanTest : public ::testing::TestWithParam> { public: MeanTest() - : params(::testing::TestWithParam>::GetParam()), + : params(::testing::TestWithParam>::GetParam()), stream(resource::get_cuda_stream(handle)), rows(params.rows), cols(params.cols), - data(rows * cols, stream), - mean_act(cols, stream) + data(raft::make_device_matrix(handle, rows, cols)), + mean_act(raft::make_device_vector(handle, cols)) { } @@ -52,23 +77,31 @@ class MeanTest : public ::testing::TestWithParam> { { raft::random::RngState r(params.seed); int len = rows * cols; - normal(handle, r, data.data(), len, params.mean, params.stddev); - meanSGtest(data.data(), stream); + if constexpr (std::is_same_v) { + rmm::device_uvector data_float(len, stream); + normal(handle, r, data_float.data(), len, toFloat(params.mean), toFloat(params.stddev)); + raft::linalg::unaryOp(data.data_handle(), data_float.data(), len, float_to_half_op{}, stream); + } else if constexpr (std::is_integral_v) { + normalInt(handle, r, data.data_handle(), len, params.mean, params.stddev); + } else { + normal(handle, r, data.data_handle(), len, params.mean, params.stddev); + } + meanSGtest(); } - void meanSGtest(T* data, cudaStream_t stream) + void meanSGtest() { int rows = params.rows, cols = params.cols; if (params.rowMajor) { using layout = raft::row_major; mean(handle, - raft::make_device_matrix_view(data, rows, cols), - raft::make_device_vector_view(mean_act.data(), cols)); + raft::make_device_matrix_view(data.data_handle(), rows, cols), + raft::make_device_vector_view(mean_act.data_handle(), cols)); } else { using layout = raft::col_major; mean(handle, - raft::make_device_matrix_view(data, rows, cols), - raft::make_device_vector_view(mean_act.data(), cols)); + raft::make_device_matrix_view(data.data_handle(), rows, cols), + raft::make_device_vector_view(mean_act.data_handle(), cols)); } } @@ -76,9 +109,10 @@ class MeanTest : public ::testing::TestWithParam> { raft::resources handle; cudaStream_t stream; - MeanInputs params; + MeanInputs params; int rows, cols; - rmm::device_uvector data, mean_act; + raft::device_matrix data; + raft::device_vector mean_act; }; // Note: For 1024 samples, 256 experiments, a mean of 1.0 with stddev=1.0, the @@ -131,23 +165,66 @@ const std::vector> inputsd = {{0.15, -1.0, 1024, 32, false, 1 {1e-8, 1e-1, 1 << 27, 2, false, 1234ULL, 0.0001}, {1e-8, 1e-1, 1 << 27, 2, true, 1234ULL, 0.0001}}; +const std::vector> inputshf = { + {0.15f, -1.f, 1024, 32, false, 1234ULL}, + {0.15f, -1.f, 1024, 64, false, 1234ULL}, + {0.15f, -1.f, 1024, 128, false, 1234ULL}, + {0.15f, -1.f, 1024, 256, false, 1234ULL}, + {0.15f, -1.f, 1024, 32, true, 1234ULL}, + {0.15f, -1.f, 1024, 64, true, 1234ULL}, + {0.0001f, 0.1f, 1 << 27, 2, false, 1234ULL, 0.0001f}}; + +const std::vector> inputsi8h = {{0.95f, -5, 8096, 32, false, 1234ULL, 1}, + {0.5f, 1, 8096, 10, false, 1234ULL, 10}, + {0.15f, 0, 60000, 128, false, 1234ULL, 6}, + {0.5f, -1, 8096, 256, false, 1234ULL, 2}, + {1.0f, 8, 2000, 32, true, 1234ULL, 1}, + {0.50f, -1, 20000, 64, true, 1234ULL, 5}, + {1.0f, 6, 10024, 2, false, 1234ULL, 10}}; + typedef MeanTest MeanTestF; TEST_P(MeanTestF, Result) { - ASSERT_TRUE( - devArrMatch(params.mean, mean_act.data(), params.cols, CompareApprox(params.tolerance))); + ASSERT_TRUE(devArrMatch( + params.mean, mean_act.data_handle(), params.cols, CompareApprox(params.tolerance))); } typedef MeanTest MeanTestD; TEST_P(MeanTestD, Result) { ASSERT_TRUE(devArrMatch( - params.mean, mean_act.data(), params.cols, CompareApprox(params.tolerance))); + params.mean, mean_act.data_handle(), params.cols, CompareApprox(params.tolerance))); } INSTANTIATE_TEST_SUITE_P(MeanTests, MeanTestF, ::testing::ValuesIn(inputsf)); INSTANTIATE_TEST_SUITE_P(MeanTests, MeanTestD, ::testing::ValuesIn(inputsd)); +typedef MeanTest MeanTestHF; +TEST_P(MeanTestHF, Result) +{ + ASSERT_TRUE(devArrMatch(toFloat(params.mean), + mean_act.data_handle(), + params.cols, + CompareApprox(params.tolerance))); +} + +typedef MeanTest MeanTestI8H; +TEST_P(MeanTestI8H, Result) +{ + std::vector mean_act_h(params.cols); + raft::update_host(mean_act_h.data(), mean_act.data_handle(), params.cols, stream); + raft::resource::sync_stream(handle); + + auto expected = toFloat(params.mean); + auto tolerance = toFloat(params.tolerance); + for (int i = 0; i < params.cols; ++i) { + ASSERT_NEAR(toFloat(mean_act_h[i]), expected, tolerance) << " @col=" << i; + } +} + +INSTANTIATE_TEST_SUITE_P(MeanTests, MeanTestHF, ::testing::ValuesIn(inputshf)); + +INSTANTIATE_TEST_SUITE_P(MeanTests, MeanTestI8H, ::testing::ValuesIn(inputsi8h)); } // end namespace stats } // end namespace raft