From 0a5c889808283a31f6e1e360b4fad587124dd6b1 Mon Sep 17 00:00:00 2001 From: Emil Ohlsson Date: Tue, 15 Apr 2025 13:43:45 +0200 Subject: [PATCH 1/2] Transition QAI8 tests to lazy initialization This change moves the QAI8 static initializations to lazy, C++17 compliant, initializations. This change also makes use of kernel interfaces, as to make sure they're exercised Signed-off-by: Emil Ohlsson --- CMakeLists.txt | 3 +- .../matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp | 398 +++++++++--------- 2 files changed, 194 insertions(+), 207 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 5529fc5b..500e3e53 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -399,9 +399,10 @@ if(KLEIDIAI_BUILD_TESTS) ) set_source_files_properties( - test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp test/tests/matmul_clamp_f16_bf16p_bf16p_test.cpp + test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp + test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp test/tests/matmul_test.cpp PROPERTIES COMPILE_FLAGS "-Wpedantic") endif() diff --git a/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp b/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp index d1529ea0..5bdaf5c6 100644 --- a/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp +++ b/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp @@ -133,16 +133,69 @@ struct MatMulIndirectKernel { imatmul; }; -const static RhsPackKernel rhs_pack = { - .get_n_step = kai_get_n_step_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme, - .get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme, - .get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme, - .get_scale_offset = kai_get_scale_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme, - .get_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme, - .get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme, - .pack = kai_run_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme, +/// Make sure that interface matches +static const kai_matmul_clamp_qai8_qai8p_qsi8cxp_ukernel& +get_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot_interface() { + static kai_matmul_clamp_qai8_qai8p_qsi8cxp_ukernel ukernel; + + ukernel.get_m_step = kai_get_m_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot; + ukernel.get_n_step = kai_get_n_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot; + ukernel.get_nr = kai_get_nr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot; + ukernel.get_kr = kai_get_kr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot; + ukernel.get_sr = kai_get_sr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot; + ukernel.get_lhs_offset = kai_get_lhs_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot; + ukernel.get_rhs_packed_offset = kai_get_rhs_packed_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot; + ukernel.get_dst_offset = kai_get_dst_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot; + ukernel.get_dst_size = kai_get_dst_size_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot; + ukernel.run_matmul = kai_run_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot; + + return ukernel; }; +/// Make sure that interface matches +static const kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel& +get_imatmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot_interface() { + static kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel ukernel; + + ukernel.get_m_step = kai_get_m_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; + ukernel.get_n_step = kai_get_n_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; + ukernel.get_lhs_packed_offset = + kai_get_lhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; + ukernel.get_rhs_packed_offset = + kai_get_rhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; + ukernel.get_dst_offset = kai_get_dst_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; + ukernel.get_dst_size = kai_get_dst_size_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; + ukernel.run_imatmul = kai_run_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; + + return ukernel; +}; + +static const RhsPackKernel& get_rhs_pack() { + static RhsPackKernel ukernel; + + ukernel.get_n_step = kai_get_n_step_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; + ukernel.get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; + ukernel.get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; + ukernel.get_scale_offset = kai_get_scale_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; + ukernel.get_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; + ukernel.get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; + ukernel.pack = kai_run_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; + + return ukernel; +} + +static const LhsPackKernel& get_lhs_pack() { + static LhsPackKernel ukernel; + + ukernel.get_m_step = kai_get_m_step_lhs_pack_x8p2vlx4_x8_sme; + ukernel.get_lhs_offset = kai_get_lhs_offset_lhs_pack_x8p2vlx4_x8_sme; + ukernel.get_packed_lhs_offset = kai_get_lhs_packed_offset_lhs_pack_x8p2vlx4_x8_sme; + ukernel.get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_x8p2vlx4_x8_sme; + ukernel.pack = kai_run_lhs_pack_x8p2vlx4_x8_sme; + + return ukernel; +} + struct MatMulVariant { std::string_view name; ///< Test identification MatMulShape acc_pack; ///< Accumulator shape for packing (mr/nr/kr) @@ -167,137 +220,103 @@ struct IndirectMatMulVariant { MatMulIndirectKernel matmul; ///< Matmul kernel interface }; -const std::array gemm_variants = { - MatMulVariant{ - .name = "matmul_qai8_qai8p_qsi8cxp", - .acc_pack{ - .m = 2 * get_sme_vector_length(), - .n = 2 * get_sme_vector_length(), - .k = sizeof(int32_t) / sizeof(int8_t), - }, - .acc_step{ - .m = 2 * get_sme_vector_length(), - .n = 2 * get_sme_vector_length(), - .k = sizeof(int32_t) / sizeof(int8_t), - }, - - .is_supported = cpu_has_sme2, - - .lhs_pack = - LhsPackKernel{ - .get_m_step = kai_get_m_step_lhs_pack_x8p2vlx4_x8_sme, - .get_lhs_offset = kai_get_lhs_offset_lhs_pack_x8p2vlx4_x8_sme, - .get_packed_lhs_offset = kai_get_lhs_packed_offset_lhs_pack_x8p2vlx4_x8_sme, - .get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_x8p2vlx4_x8_sme, - .pack = kai_run_lhs_pack_x8p2vlx4_x8_sme, - }, - .rhs_pack = rhs_pack, - .matmul = - MatMulKernel{ - .get_m_step = kai_get_m_step_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, - .get_n_step = kai_get_n_step_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, - .get_mr = kai_get_mr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, - .get_nr = kai_get_nr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, - .get_kr = kai_get_kr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, - .get_sr = kai_get_sr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, - .get_packed_lhs_offset = - kai_get_lhs_packed_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, - .get_packed_rhs_offset = - kai_get_rhs_packed_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, - .get_dst_offset = kai_get_dst_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, - .get_dst_size = kai_get_dst_size_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, - .matmul = kai_run_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, - }, - }, -}; +static const std::array& get_gemm_variants() { + static std::array variants; + + variants[0].name = "matmul_qai8_qai8p_qsi8cxp"; + variants[0].acc_pack.m = 2 * get_sme_vector_length(); + variants[0].acc_pack.n = 2 * get_sme_vector_length(); + variants[0].acc_pack.k = sizeof(int32_t) / sizeof(int8_t); + variants[0].acc_step.m = 2 * get_sme_vector_length(); + variants[0].acc_step.n = 2 * get_sme_vector_length(); + variants[0].acc_step.k = sizeof(int32_t) / sizeof(int8_t); + variants[0].is_supported = cpu_has_sme2; + variants[0].lhs_pack = get_lhs_pack(); + variants[0].rhs_pack = get_rhs_pack(); + variants[0].matmul.get_m_step = kai_get_m_step_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; + variants[0].matmul.get_n_step = kai_get_n_step_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; + variants[0].matmul.get_mr = kai_get_mr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; + variants[0].matmul.get_nr = kai_get_nr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; + variants[0].matmul.get_kr = kai_get_kr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; + variants[0].matmul.get_sr = kai_get_sr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; + variants[0].matmul.get_packed_lhs_offset = + kai_get_lhs_packed_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; + variants[0].matmul.get_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; + variants[0].matmul.get_dst_offset = + kai_get_dst_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; + variants[0].matmul.get_dst_size = kai_get_dst_size_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; + variants[0].matmul.matmul = kai_run_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa; + + return variants; +} -const std::array indirect_gemm_variants = { - IndirectMatMulVariant{ - .name = "indirect_matmul_qai8_qai8p_qsi8cxp", - .acc_pack{ - .m = 2 * get_sme_vector_length(), - .n = 2 * get_sme_vector_length(), - .k = sizeof(int32_t) / sizeof(int8_t), - }, - .acc_step{ - .m = 2 * get_sme_vector_length(), - .n = 2 * get_sme_vector_length(), - .k = sizeof(int32_t) / sizeof(int8_t), - }, - - .is_supported = cpu_has_sme2, - - .lhs_pack = - LhsPackIndirectKernel{ - .get_m_step = kai_get_m_step_lhs_imatmul_pack_x8p2vlx4_x8p_sme, - .get_packed_lhs_offset = kai_get_lhs_packed_offset_lhs_imatmul_pack_x8p2vlx4_x8p_sme, - .get_packed_lhs_size = kai_get_lhs_packed_size_lhs_imatmul_pack_x8p2vlx4_x8p_sme, - .pack = kai_run_lhs_imatmul_pack_x8p2vlx4_x8p_sme, - }, - .rhs_pack = - RhsPackIndirectKernel{ - .get_n_step = kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme, - .get_rhs_offset = kai_get_rhs_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme, - .get_bias_offset = kai_get_bias_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme, - .get_scale_offset = kai_get_scale_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme, - .get_packed_rhs_offset = - kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme, - .get_packed_rhs_size = kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme, - .pack = kai_run_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme, - }, - .matmul = - MatMulIndirectKernel{ - .get_m_step = kai_get_m_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, - .get_n_step = kai_get_n_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, - .get_packed_lhs_offset = - kai_get_lhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, - .get_packed_rhs_offset = - kai_get_rhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, - .get_dst_offset = kai_get_dst_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, - .get_dst_size = kai_get_dst_size_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, - .imatmul = kai_run_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, - }, - }, -}; +static const std::array& get_indirect_gemm_variants() { + static std::array variants; + static const kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel& ukernel = + get_imatmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot_interface(); + + variants[0].name = "indirect_matmul_qai8_qai8p_qsi8cxp"; + variants[0].acc_pack.m = 2 * get_sme_vector_length(); + variants[0].acc_pack.n = 2 * get_sme_vector_length(); + variants[0].acc_pack.k = sizeof(int32_t) / sizeof(int8_t); + variants[0].acc_step.m = 2 * get_sme_vector_length(); + variants[0].acc_step.n = 2 * get_sme_vector_length(); + variants[0].acc_step.k = sizeof(int32_t) / sizeof(int8_t); + variants[0].is_supported = cpu_has_sme2; + variants[0].lhs_pack.get_m_step = kai_get_m_step_lhs_imatmul_pack_x8p2vlx4_x8p_sme; + variants[0].lhs_pack.get_packed_lhs_offset = kai_get_lhs_packed_offset_lhs_imatmul_pack_x8p2vlx4_x8p_sme; + variants[0].lhs_pack.get_packed_lhs_size = kai_get_lhs_packed_size_lhs_imatmul_pack_x8p2vlx4_x8p_sme; + variants[0].lhs_pack.pack = kai_run_lhs_imatmul_pack_x8p2vlx4_x8p_sme; + variants[0].rhs_pack.get_n_step = kai_get_n_step_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; + variants[0].rhs_pack.get_rhs_offset = kai_get_rhs_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; + variants[0].rhs_pack.get_bias_offset = kai_get_bias_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; + variants[0].rhs_pack.get_scale_offset = kai_get_scale_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; + variants[0].rhs_pack.get_packed_rhs_offset = + kai_get_rhs_packed_offset_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; + variants[0].rhs_pack.get_packed_rhs_size = + kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; + variants[0].rhs_pack.pack = kai_run_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; + variants[0].matmul.get_m_step = ukernel.get_m_step; + variants[0].matmul.get_n_step = ukernel.get_n_step; + variants[0].matmul.get_packed_lhs_offset = ukernel.get_lhs_packed_offset; + variants[0].matmul.get_packed_rhs_offset = ukernel.get_rhs_packed_offset; + variants[0].matmul.get_dst_offset = ukernel.get_dst_offset; + variants[0].matmul.get_dst_size = ukernel.get_dst_size; + variants[0].matmul.imatmul = ukernel.run_imatmul; + + return variants; +} -const std::array gemv_variants = { - MatMulVariant{ - .name = "matmul_qai8_qai8_qsi8cxp", - .acc_pack{ - .m = 1, - .n = 2 * get_sme_vector_length(), - .k = sizeof(int32_t) / sizeof(int8_t), - }, - .acc_step{ - .m = 1, - .n = 16 * get_sme_vector_length(), - .k = sizeof(int32_t) / sizeof(int8_t), - }, - - .is_supported = cpu_has_sme2, - - .lhs_pack = std::nullopt, - .rhs_pack = rhs_pack, - .matmul = MatMulKernel{ - .get_m_step = kai_get_m_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, - .get_n_step = kai_get_n_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, - .get_mr = []() -> size_t { return 1; }, - .get_nr = kai_get_nr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, - .get_kr = kai_get_kr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, - .get_sr = kai_get_sr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, - .get_packed_lhs_offset = kai_get_lhs_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, - .get_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, - .get_dst_offset = kai_get_dst_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, - .get_dst_size = kai_get_dst_size_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, - .matmul = - [](size_t m, size_t n, size_t k, const void* lhs, const void* rhs, void* dst, size_t dst_stride_row, - size_t dst_stride_col, const kai_matmul_requantize32_params* quant_param) { - kai_run_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot( - m, n, k, lhs, rhs, dst, dst_stride_row, dst_stride_col, quant_param); - }, - }, - }, -}; +static const std::array& get_gemv_variants() { + static std::array variants; + static const kai_matmul_clamp_qai8_qai8p_qsi8cxp_ukernel& ukernel = + get_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot_interface(); + + variants[0].name = "matmul_qai8_qai8_qsi8cxp"; + variants[0].acc_pack.m = 1; + variants[0].acc_pack.n = 2 * get_sme_vector_length(); + variants[0].acc_pack.k = sizeof(int32_t) / sizeof(int8_t); + variants[0].acc_step.m = 1; + variants[0].acc_step.n = 16 * get_sme_vector_length(); + variants[0].acc_step.k = sizeof(int32_t) / sizeof(int8_t); + variants[0].is_supported = cpu_has_sme2; + variants[0].lhs_pack = std::nullopt; + variants[0].rhs_pack = get_rhs_pack(); + variants[0].matmul.get_m_step = ukernel.get_m_step; + variants[0].matmul.get_n_step = ukernel.get_n_step; + variants[0].matmul.get_mr = []() -> size_t { return 1; }; + variants[0].matmul.get_nr = ukernel.get_nr; + variants[0].matmul.get_kr = ukernel.get_kr; + variants[0].matmul.get_sr = ukernel.get_sr; + variants[0].matmul.get_packed_lhs_offset = nullptr; + variants[0].matmul.get_packed_rhs_offset = ukernel.get_rhs_packed_offset; + variants[0].matmul.get_dst_offset = ukernel.get_dst_offset; + variants[0].matmul.get_dst_size = ukernel.get_dst_size; + variants[0].matmul.matmul = ukernel.run_matmul; + + return variants; +} constexpr uint32_t seed = 0; ///< Random seed used for tests @@ -344,35 +363,6 @@ struct TestReference { Buffer packed_rhs; }; -/// Make sure that interface matches -static const kai_matmul_clamp_qai8_qai8p_qsi8cxp_ukernel matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot_interface - [[maybe_unused]] = { - .get_m_step = kai_get_m_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, - .get_n_step = kai_get_n_step_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, - .get_nr = kai_get_nr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, - .get_kr = kai_get_kr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, - .get_sr = kai_get_sr_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, - .get_lhs_offset = kai_get_lhs_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, - .get_rhs_packed_offset = kai_get_rhs_packed_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, - .get_dst_offset = kai_get_dst_offset_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, - .get_dst_size = kai_get_dst_size_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, - .run_matmul = kai_run_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot, -}; - -/// Make sure that interface matches -static const kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel - imatmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot_interface [[maybe_unused]] = { - .get_m_step = kai_get_m_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, - .get_n_step = kai_get_n_step_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, - .get_lhs_packed_offset = - kai_get_lhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, - .get_rhs_packed_offset = - kai_get_rhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, - .get_dst_offset = kai_get_dst_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, - .get_dst_size = kai_get_dst_size_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, - .run_imatmul = kai_run_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, -}; - static constexpr int8_t padding_value = 0; // Functionality for hashing generated test data. @@ -539,30 +529,27 @@ static const TestReference& get_test_reference(const TestDataId& test_data_id) { rhs_qsi8_t.data(), rhs_scales.data(), lhs_scale, dst_scale, bias_qsi32.data(), lhs_zero_point, shape.n, shape.k, pack_shape.n, pack_shape.k); - const TestReference& reference = g_data[test_data_id] = { - .clamp = {.min = dst_qai8_clamp_min, .max = dst_qai8_clamp_max}, - - .qa_lhs = {.scale = lhs_scale, .zero_point = lhs_zero_point}, - .qa_dst = {.scale = dst_scale, .zero_point = dst_zero_point}, + TestReference& reference = g_data[test_data_id]; + reference.clamp.min = dst_qai8_clamp_min; + reference.clamp.max = dst_qai8_clamp_max; + reference.qa_lhs.scale = lhs_scale; + reference.qa_lhs.zero_point = lhs_zero_point; + reference.qa_dst.scale = dst_scale; + reference.qa_dst.zero_point = dst_zero_point; + reference.lhs_qai8 = std::move(lhs_qai8); + reference.lhs_qai8_scales = std::move(lhs_qai8_scales); + reference.lhs_qai8_zero_points = std::move(lhs_qai8_zero_points); + reference.lhs_qai8_indirect = std::move(lhs_qai8_indirect); + reference.lhs_qai8_indirect_packed = std::move(lhs_qai8_indirect_packed); + reference.lhs_qai8_indirect_padding = std::move(lhs_padding); + reference.lhs_qai8_indirect_offset = indirection_base; + reference.rhs_qsi8 = std::move(rhs_qsi8); + reference.rhs_scales = std::move(rhs_scales); + reference.bias_qsi32 = std::move(bias_qsi32); + reference.dst_qsi8_clamped = std::move(ref_dst_qsi8_clamped); + reference.packed_lhs = std::move(packed_lhs); + reference.packed_rhs = std::move(packed_rhs); - .lhs_qai8 = std::move(lhs_qai8), - .lhs_qai8_scales = std::move(lhs_qai8_scales), - .lhs_qai8_zero_points = std::move(lhs_qai8_zero_points), - .lhs_qai8_indirect = std::move(lhs_qai8_indirect), - .lhs_qai8_indirect_packed = std::move(lhs_qai8_indirect_packed), - .lhs_qai8_indirect_padding = std::move(lhs_padding), - .lhs_qai8_indirect_offset = indirection_base, - - .rhs_qsi8 = std::move(rhs_qsi8), - .rhs_scales = std::move(rhs_scales), - - .bias_qsi32 = std::move(bias_qsi32), - - .dst_qsi8_clamped = std::move(ref_dst_qsi8_clamped), - - .packed_lhs = std::move(packed_lhs), - .packed_rhs = std::move(packed_rhs), - }; return reference; } @@ -611,10 +598,9 @@ static void test_rhs_pack( const auto imp_scale_offset = variant.rhs_pack.get_scale_offset(output_area.start_col()); const auto imp_packed_rhs_offset = variant.rhs_pack.get_packed_rhs_offset(output_area.start_col(), shape.k); - const kai_rhs_pack_qsi8cx_params imp_pack_rhs_params{ - .lhs_zero_point = reference.qa_lhs.zero_point, - .scale_multiplier = reference.qa_lhs.scale / reference.qa_dst.scale, - }; + kai_rhs_pack_qsi8cx_params imp_pack_rhs_params{}; + imp_pack_rhs_params.lhs_zero_point = reference.qa_lhs.zero_point; + imp_pack_rhs_params.scale_multiplier = reference.qa_lhs.scale / reference.qa_dst.scale; variant.rhs_pack.pack( 1, output_area.width(), shape.k, variant.acc_pack.n, variant.acc_pack.k, 1, shape.n * sizeof(int8_t), @@ -692,11 +678,10 @@ static void test_matmul( variant.matmul.get_dst_offset(output_area.start_row(), output_area.start_col(), shape.n * sizeof(int8_t)); ASSERT_EQ(imp_dst_offset, output_area.start_row() * shape.n + output_area.start_col()); - const kai_matmul_requantize32_params imp_main_params{ - .min_value = reference.clamp.min, - .max_value = reference.clamp.max, - .output_zero_point = reference.qa_dst.zero_point, - }; + kai_matmul_requantize32_params imp_main_params{}; + imp_main_params.min_value = reference.clamp.min; + imp_main_params.max_value = reference.clamp.max; + imp_main_params.output_zero_point = reference.qa_dst.zero_point; variant.matmul.matmul( output_area.height(), output_area.width(), shape.k, lhs_data.data() + imp_lhs_offset, @@ -854,11 +839,10 @@ static Buffer matmul( Buffer dst(dst_size); // Calculate geffective uantization parameters - kai_matmul_requantize32_params requantization{ - .min_value = reference.clamp.min, - .max_value = reference.clamp.max, - .output_zero_point = reference.qa_dst.zero_point, - }; + kai_matmul_requantize32_params requantization{}; + requantization.min_value = reference.clamp.min; + requantization.max_value = reference.clamp.max; + requantization.output_zero_point = reference.qa_dst.zero_point; // Call matmul kernel variant.imatmul( @@ -940,7 +924,8 @@ static constexpr std::array shapes{ INSTANTIATE_TEST_SUITE_P( matmul_clamp_qai8_qai8p_qsi8cxp, MatMulQuantizedTest, testing::Combine( - testing::ValuesIn(gemm_variants), testing::ValuesIn(shapes), + testing::ValuesIn(get_gemm_variants()), // + testing::ValuesIn(shapes), // testing::ValuesIn({ // clang-format off MatrixPortion( 0, 0, 1, 1), // Full matrix. @@ -960,7 +945,7 @@ INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P( matmul_clamp_qai8_qai8_qsi8cxp, MatMulQuantizedTest, testing::Combine( - testing::ValuesIn(gemv_variants), + testing::ValuesIn(get_gemv_variants()), testing::ValuesIn({ // clang-format off MatMulShape{ 1, 1, 1}, @@ -1000,7 +985,8 @@ INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P( indirect_matmul_clamp_qai8_qai8p_qsi8cxp, IndirectMatMulQuantizedTest, testing::Combine( - testing::ValuesIn(indirect_gemm_variants), testing::ValuesIn(shapes), + testing::ValuesIn(get_indirect_gemm_variants()), // + testing::ValuesIn(shapes), // testing::ValuesIn({ // clang-format off // (Start row , start col , height , width) -- GitLab From 503d531c02ad3a57a9f63ae2160d19178a48cb92 Mon Sep 17 00:00:00 2001 From: Emil Ohlsson Date: Wed, 16 Apr 2025 11:28:19 +0200 Subject: [PATCH 2/2] Move local code into anonymous namespace Signed-off-by: Emil Ohlsson --- .../matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp | 33 +++++++++++-------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp b/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp index 5bdaf5c6..9ac28a93 100644 --- a/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp +++ b/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp @@ -47,6 +47,9 @@ namespace kai::test { +// Ensure static linkage for all functionality local to this test file +namespace { + using Buffer = std::vector; using IndirectionBuffer = std::vector; @@ -134,7 +137,7 @@ struct MatMulIndirectKernel { }; /// Make sure that interface matches -static const kai_matmul_clamp_qai8_qai8p_qsi8cxp_ukernel& +const kai_matmul_clamp_qai8_qai8p_qsi8cxp_ukernel& get_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot_interface() { static kai_matmul_clamp_qai8_qai8p_qsi8cxp_ukernel ukernel; @@ -153,7 +156,7 @@ get_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot_interface() { }; /// Make sure that interface matches -static const kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel& +const kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel& get_imatmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot_interface() { static kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel ukernel; @@ -170,7 +173,7 @@ get_imatmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot_interface() { return ukernel; }; -static const RhsPackKernel& get_rhs_pack() { +const RhsPackKernel& get_rhs_pack() { static RhsPackKernel ukernel; ukernel.get_n_step = kai_get_n_step_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme; @@ -184,7 +187,7 @@ static const RhsPackKernel& get_rhs_pack() { return ukernel; } -static const LhsPackKernel& get_lhs_pack() { +const LhsPackKernel& get_lhs_pack() { static LhsPackKernel ukernel; ukernel.get_m_step = kai_get_m_step_lhs_pack_x8p2vlx4_x8_sme; @@ -220,7 +223,7 @@ struct IndirectMatMulVariant { MatMulIndirectKernel matmul; ///< Matmul kernel interface }; -static const std::array& get_gemm_variants() { +const std::array& get_gemm_variants() { static std::array variants; variants[0].name = "matmul_qai8_qai8p_qsi8cxp"; @@ -251,7 +254,7 @@ static const std::array& get_gemm_variants() { return variants; } -static const std::array& get_indirect_gemm_variants() { +const std::array& get_indirect_gemm_variants() { static std::array variants; static const kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel& ukernel = get_imatmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot_interface(); @@ -288,7 +291,7 @@ static const std::array& get_indirect_gemm_variants() return variants; } -static const std::array& get_gemv_variants() { +const std::array& get_gemv_variants() { static std::array variants; static const kai_matmul_clamp_qai8_qai8p_qsi8cxp_ukernel& ukernel = get_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot_interface(); @@ -363,7 +366,7 @@ struct TestReference { Buffer packed_rhs; }; -static constexpr int8_t padding_value = 0; +constexpr int8_t padding_value = 0; // Functionality for hashing generated test data. // This is particularly useful for portion testing @@ -398,11 +401,11 @@ struct HashTestDataId { }; // NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables) -static std::unordered_map g_data; +std::unordered_map g_data; // NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables) /// Generate test reference data -static const TestReference& get_test_reference(const TestDataId& test_data_id) { +const TestReference& get_test_reference(const TestDataId& test_data_id) { // ============================================================ // Generates input and reference output data // ============================================================ @@ -554,7 +557,7 @@ static const TestReference& get_test_reference(const TestDataId& test_data_id) { } /// Test LHS packing -static void test_lhs_pack( +void test_lhs_pack( const MatMulShape& shape, const MatMulVariant& variant, const Rect& output_area, const TestReference& reference) { KAI_ASSUME(variant.lhs_pack.has_value()); @@ -587,7 +590,7 @@ static void test_lhs_pack( } /// Test RHS packing -static void test_rhs_pack( +void test_rhs_pack( const MatMulShape& shape, const MatMulVariant& variant, const Rect& output_area, const TestReference& reference) { const auto imp_packed_rhs_size = variant.rhs_pack.get_packed_rhs_size(shape.n, shape.k); ASSERT_EQ(imp_packed_rhs_size, reference.packed_rhs.size()); @@ -627,7 +630,7 @@ static void test_rhs_pack( ASSERT_EQ(mismatches, 0) << "There are an unexpected amount of mismatches in RHS packing"; } -static void compare_matmul_result( +void compare_matmul_result( const MatMulShape& shape, const Rect& output_area, const Buffer& actual, const Buffer& reference) { size_t mismatches = 0; bool printed_row = false; @@ -661,7 +664,7 @@ static void compare_matmul_result( } /// Test MatMul of GEMM/GEMV like kernel -static void test_matmul( +void test_matmul( const MatMulShape& shape, const MatMulVariant& variant, const Rect& output_area, const TestReference& reference) { const auto imp_dst_size = variant.matmul.get_dst_size(shape.m, shape.n); ASSERT_EQ(imp_dst_size, reference.dst_qsi8_clamped.size()); @@ -691,6 +694,8 @@ static void test_matmul( compare_matmul_result(shape, output_area, imp_dst, reference.dst_qsi8_clamped); } +} // namespace + using MatMulQuantizedTest = testing::TestWithParam>; using IndirectMatMulQuantizedTest = testing::TestWithParam>; -- GitLab