From b2674c1b1d94a965872f92f2e6cce8c984a967ad Mon Sep 17 00:00:00 2001 From: Michael Kozlov Date: Thu, 1 Aug 2024 16:45:14 +0100 Subject: [PATCH] Use variant/visit for fn_matmul Currently MatMulMethod struct reqires a new field for each new kai_run_matmul_* method, which is not nicely scalable. Use std::variant and std::visit to encapsulate different possible matmul methods and provide calling of the correct method. Signed-off-by: Michael Kozlov --- test/tests/matmul_test.cpp | 77 +++++++++++++------------------------- 1 file changed, 25 insertions(+), 52 deletions(-) diff --git a/test/tests/matmul_test.cpp b/test/tests/matmul_test.cpp index ac534645..52f5c1b8 100644 --- a/test/tests/matmul_test.cpp +++ b/test/tests/matmul_test.cpp @@ -221,44 +221,10 @@ struct MatMulMethod { /// @return The size in bytes of the destination matrix buffer. std::function fn_get_dst_size; - /// Performs F16 matrix multiplication with RHS packing followed by - /// clamp operation. - /// - /// @param[in] m Size of the matrix in M dimension. - /// @param[in] n Size of the matrix in N dimension. - /// @param[in] k Size of the matrix in K dimension. - /// @param[in] lhs LHS data buffer. - /// @param[in] packed_rhs Packed RHS data buffer. - /// @param[out] dst Output data buffer. - /// @param[in] lhs_stride LHS row stride. - /// @param[in] dst_stride Output row stride. - /// @param[in] clamp_min Lower bound of the output data. - /// @param[in] clamp_max Upper bound of the output data. - std::function - fn_matmul_f16_f16_f16p; - - /// Performs F32 matrix multiplication with LHS & RHS packing - /// followed by clamp operation. - /// - /// @param[in] m Number of output rows to be computed. - /// @param[in] n Number of output columns to be computed. - /// @param[in] k Common dimension of the LHS and RHS operands. - /// @param[in] packed_lhs Packed LHS matrix buffer. - /// @param[in] packed_rhs Packed RHS matrix buffer. - /// @param[out] dst Output matrix buffer. - /// @param[in] dst_stride_row Row stride in bytes of the output matrix. - /// @param[in] dst_stride_col Column stride in bytes of the output matrix. - /// @param[in] clamp_min Minimum value to clamp the final result. - /// @param[in] clamp_max Maximum value to clamp the final result. - std::function - fn_matmul_f32_f32p_f32p; + std::variant< + std::monostate, std::function, + std::function> + fn_matmul; /// Gets a value indicating whether pre-processing the RHS matrix is needed. [[nodiscard]] bool is_pack_rhs_needed() const { @@ -295,7 +261,7 @@ struct MatMulMethod { } [[nodiscard]] bool has_main_kernel() const { - return fn_matmul_f16_f16_f16p != nullptr || fn_matmul_f32_f32p_f32p != nullptr; + return !std::holds_alternative(fn_matmul); } void main_kernel( @@ -304,15 +270,24 @@ struct MatMulMethod { KAI_UNUSED(bias); KAI_UNUSED(rhs_stride); - if (fn_matmul_f16_f16_f16p) { - fn_matmul_f16_f16_f16p( - m, n, k, lhs, lhs_stride, rhs, dst, dst_stride, sizeof(Float16), static_cast(clamp_min), - static_cast(clamp_max)); - } else if (fn_matmul_f32_f32p_f32p) { - fn_matmul_f32_f32p_f32p(m, n, k, lhs, rhs, dst, dst_stride, sizeof(float), clamp_min, clamp_max); - } else { - KAI_ERROR("Main kernel is not available!"); - } + std::visit( + [&](auto&& matmul) { + using T = std::decay_t; + if constexpr (std::is_same_v< + T, + std::function< + decltype(kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla)>>) { + matmul( + m, n, k, lhs, lhs_stride, rhs, dst, dst_stride, sizeof(Float16), + static_cast(clamp_min), static_cast(clamp_max)); + + } else if constexpr ( + std::is_same_v< + T, std::function>) { + matmul(m, n, k, lhs, rhs, dst, dst_stride, sizeof(float), clamp_min, clamp_max); + } + }, + fn_matmul); } }; @@ -364,8 +339,7 @@ static const std::array matmul_methods = { .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, - .fn_matmul_f16_f16_f16p = kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, - .fn_matmul_f32_f32p_f32p = nullptr, + .fn_matmul = kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, }, MatMulMethod{ @@ -415,8 +389,7 @@ static const std::array matmul_methods = { .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - .fn_matmul_f16_f16_f16p = nullptr, - .fn_matmul_f32_f32p_f32p = kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + .fn_matmul = kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, }, }; -- GitLab