diff --git a/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp b/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp
index f84db17f885ac4564f5007065c12c93eb0a72d13..638b1dc3364ae880e45f9cf8ad21b4115c0a73f6 100644
--- a/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp
+++ b/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp
@@ -35,174 +35,187 @@ namespace kai::test {
-namespace {
-
-struct GemmVariant {
-    size_t acc_height;
-    size_t acc_width;
-    size_t acc_fanin;
-
-    bool (*fn_is_supported)();
-
-    size_t (*fn_pack_lhs_get_m_step)(size_t mr);
-    size_t (*fn_pack_lhs_get_lhs_offset)(size_t m_idx, size_t lhs_stride);
-    size_t (*fn_pack_lhs_get_packed_lhs_offset)(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr);
-    size_t (*fn_pack_lhs_get_packed_lhs_size)(size_t m, size_t k, size_t mr, size_t kr, size_t sr);
-    void (*fn_pack_lhs_run)(
-        size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride,
-        void* lhs_packed);
-
-    size_t (*fn_pack_rhs_get_n_step)();
-    size_t (*fn_pack_rhs_get_rhs_offset)(size_t n_idx);
-    size_t (*fn_pack_rhs_get_bias_offset)(size_t n_idx);
-    size_t (*fn_pack_rhs_get_scale_offset)(size_t n_idx);
-    size_t (*fn_pack_rhs_get_packed_rhs_offset)(size_t n_idx, size_t k);
-    size_t (*fn_pack_rhs_get_packed_rhs_size)(size_t n, size_t k);
-    void (*fn_pack_rhs_run)(
-        size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs,
-        const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes,
-        const kai_rhs_pack_qsi8cx_params* params);
-
-    size_t (*fn_main_get_m_step)();
-    size_t (*fn_main_get_n_step)();
-    size_t (*fn_main_get_mr)();
-    size_t (*fn_main_get_nr)();
-    size_t (*fn_main_get_kr)();
-    size_t (*fn_main_get_sr)();
-    size_t (*fn_main_get_packed_lhs_offset)(size_t m_idx, size_t k);
-    size_t (*fn_main_get_packed_rhs_offset)(size_t n_idx, size_t k);
-    size_t (*fn_main_get_dst_offset)(size_t m_idx, size_t n_idx, size_t dst_stride);
-    size_t (*fn_main_get_dst_size)(size_t m, size_t n);
-    void (*fn_main_run)(
-        size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row,
-        size_t dst_stride_col, const kai_matmul_requantize32_params* params);
+using Buffer = std::vector<uint8_t>;
+
+struct LhsPackKernel {
+    std::function<size_t(size_t mr)> get_m_step;
+    std::function<size_t(size_t m_idx, size_t lhs_stride)> get_lhs_offset;
+    std::function<size_t(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr)> get_packed_lhs_offset;
+    std::function<size_t(size_t m, size_t k, size_t mr, size_t kr, size_t sr)> get_packed_lhs_size;
+    std::function<void(
+        size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride,
+        void* lhs_packed)>
+        pack;
+};
+
+struct RhsPackKernel {
+    std::function<size_t(size_t n_idx)> get_n_step;
+    std::function<size_t(size_t n_idx)> get_rhs_offset;
+    std::function<size_t(size_t n_idx)> get_bias_offset;
+    std::function<size_t(size_t n_idx)> get_scale_offset;
+    std::function<size_t(size_t n_idx, size_t k)> get_packed_rhs_offset;
+    std::function<size_t(size_t n, size_t k)> get_packed_rhs_size;
+    std::function<void(
+        size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs,
+        const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes,
+        const kai_rhs_pack_qsi8cx_params* params)>
+        pack;
+};
+
+struct MatMulKernel {
+    std::function<size_t()> get_m_step;
+    std::function<size_t()> get_n_step;
+    std::function<size_t()> get_mr;
+    std::function<size_t()> get_nr;
+    std::function<size_t()> get_kr;
+    std::function<size_t()> get_sr;
+    std::function<size_t(size_t m_idx, size_t k)> get_packed_lhs_offset;
+    std::function<size_t(size_t n_idx, size_t k)> get_packed_rhs_offset;
+    std::function<size_t(size_t m_idx, size_t n_idx, size_t dst_stride)> get_dst_offset;
+    std::function<size_t(size_t m, size_t n)> get_dst_size;
+    std::function<void(
+        size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row,
+        size_t dst_stride_col, const kai_matmul_requantize32_params* params)>
+        matmul;
 };
 
-const std::array gemm_variants = {
-    GemmVariant{
-        .acc_height = 2 * get_sme_vector_length<int32_t>(),
-        .acc_width = 2 * get_sme_vector_length<int32_t>(),
-        .acc_fanin = sizeof(int32_t) / sizeof(int8_t),
-
-        .fn_is_supported = cpu_has_sme2,
-
-        .fn_pack_lhs_get_m_step = kai_get_m_step_lhs_pack_x8p2vlx4_x8_sme,
-        .fn_pack_lhs_get_lhs_offset = kai_get_lhs_offset_lhs_pack_x8p2vlx4_x8_sme,
-        .fn_pack_lhs_get_packed_lhs_offset = kai_get_lhs_packed_offset_lhs_pack_x8p2vlx4_x8_sme,
-        .fn_pack_lhs_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_x8p2vlx4_x8_sme,
-        .fn_pack_lhs_run = kai_run_lhs_pack_x8p2vlx4_x8_sme,
-
-        .fn_pack_rhs_get_n_step = kai_get_n_step_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme,
-        .fn_pack_rhs_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme,
-        .fn_pack_rhs_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme,
-        .fn_pack_rhs_get_scale_offset = kai_get_scale_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme,
-        .fn_pack_rhs_get_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme,
-        .fn_pack_rhs_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme,
-        .fn_pack_rhs_run = kai_run_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme,
-
-        .fn_main_get_m_step = kai_get_m_step_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa,
-        .fn_main_get_n_step = kai_get_n_step_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa,
-        .fn_main_get_mr = kai_get_mr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa,
-        .fn_main_get_nr = kai_get_nr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa,
-        .fn_main_get_kr = kai_get_kr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa,
-        .fn_main_get_sr = kai_get_sr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa,
-        .fn_main_get_packed_lhs_offset =
-            kai_get_lhs_packed_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa,
-        .fn_main_get_packed_rhs_offset =
-            kai_get_rhs_packed_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa,
-        .fn_main_get_dst_offset = kai_get_dst_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa,
-        .fn_main_get_dst_size = kai_get_dst_size_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa,
-        .fn_main_run = kai_run_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa,
-    },
+const static RhsPackKernel rhs_pack = {
+    .get_n_step = kai_get_n_step_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme,
+    .get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme,
+    .get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme,
+    .get_scale_offset = kai_get_scale_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme,
+    .get_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme,
+    .get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme,
+    .pack = kai_run_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme,
 };
 
-constexpr float output_clamp_rate = 0.1F;  // Clamping 10% the range of the output.
-
-const std::array gemm_shapes = {
-    MatMulShape{1, 1, 1},  //
-    MatMulShape{
-        2 * get_sme_vector_length<int32_t>(), 2 * get_sme_vector_length<int32_t>(),
-        sizeof(int32_t) / sizeof(int8_t)},  //
-    MatMulShape{20, 30, 40},   //
-    MatMulShape{1, 49, 21},    //
-    MatMulShape{23, 1, 43},    //
-    MatMulShape{32, 14, 1},    //
-    MatMulShape{123, 85, 45},  //
-    MatMulShape{130, 130, 6},
+struct MatMulVariant {
+    std::string_view name;  ///< Test identification
+    MatMulShape acc;        ///< Accumulator shape
+
+    std::function<bool()> is_supported;  ///< HW support check
+
+    std::optional<LhsPackKernel> lhs_pack;  ///< LHS packing kernel interface
+    RhsPackKernel rhs_pack;                 ///< RHS packing kernel interface
+    MatMulKernel matmul;                    ///< Matmul kernel interface
 };
 
-const std::array output_portions = {
-    MatrixPortion(0, 0, 1, 1),        // Full matrix.
-    MatrixPortion(0, 0, 0.25, 0.25),  // Top-left corner.
-    MatrixPortion(0.75, 0.75, 1, 1),  // Bottom-right corner.
-};
 
+const std::array gemm_variants = {
+    MatMulVariant{
+        .name = "matmul_qai8_qai8p_qsi8cxp",
+        .acc{
+            .m = 2 * get_sme_vector_length<int32_t>(),
+            .n = 2 * get_sme_vector_length<int32_t>(),
+            .k = sizeof(int32_t) / sizeof(int8_t),
+        },
+
+        .is_supported = cpu_has_sme2,
+
+        .lhs_pack =
+            LhsPackKernel{
+                .get_m_step = kai_get_m_step_lhs_pack_x8p2vlx4_x8_sme,
+                .get_lhs_offset = kai_get_lhs_offset_lhs_pack_x8p2vlx4_x8_sme,
+                .get_packed_lhs_offset = kai_get_lhs_packed_offset_lhs_pack_x8p2vlx4_x8_sme,
+                .get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_x8p2vlx4_x8_sme,
+                .pack = kai_run_lhs_pack_x8p2vlx4_x8_sme,
+            },
+        .rhs_pack = rhs_pack,
+        .matmul =
+            MatMulKernel{
+                .get_m_step = kai_get_m_step_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa,
+                .get_n_step = kai_get_n_step_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa,
+                .get_mr = kai_get_mr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa,
+                .get_nr = kai_get_nr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa,
+                .get_kr = kai_get_kr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa,
+                .get_sr = kai_get_sr_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa,
+                .get_packed_lhs_offset =
+                    kai_get_lhs_packed_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa,
+                .get_packed_rhs_offset =
+                    kai_get_rhs_packed_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa,
+                .get_dst_offset = kai_get_dst_offset_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa,
+                .get_dst_size = kai_get_dst_size_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa,
+                .matmul = kai_run_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa,
+            },
+    },
+};
 
-void run_test(const MatMulShape& shape, const GemmVariant& variant, const MatrixPortion& output_portion) {
-    const uint64_t seed = 0;
+constexpr uint64_t seed = 0;               ///< Random seed used for tests
+constexpr float output_clamp_rate = 0.1F;  ///< Clamping range as a ratio of the output range
 
-    if (!variant.fn_is_supported()) {
-        GTEST_SKIP();
+/// Value range
+template <typename T>
+struct Range {
+    T min;
+    T max;
+
+    [[nodiscard]] T range() const {
+        return max - min;
     }
+};
 
-    // ============================================================
-    // Test the packing and scheduling parameters
-    // ============================================================
+/// Quantization parameters
+struct Quant {
+    float scale;
+    int32_t zero_point;
+};
 
-    const auto imp_mr = variant.fn_main_get_mr();
-    const auto imp_nr = variant.fn_main_get_nr();
-    const auto imp_kr = variant.fn_main_get_kr();
-    const auto imp_sr = variant.fn_main_get_sr();
+/// Reference test data
+struct TestReference {
+    Range<int8_t> clamp;
 
-    ASSERT_EQ(imp_mr, variant.acc_height);
-    ASSERT_EQ(imp_nr, variant.acc_width);
-    ASSERT_EQ(imp_kr, variant.acc_fanin);
-    ASSERT_EQ(imp_sr, 1);
+    Quant qa_lhs;
+    Quant qa_dst;
 
-    const auto imp_m_step = variant.fn_main_get_m_step();
-    const auto imp_n_step = variant.fn_main_get_n_step();
+    Buffer lhs_qai8;
+    Buffer lhs_qai8_scales;
+    Buffer lhs_qai8_zero_points;
 
-    ASSERT_EQ(imp_m_step, variant.acc_height);
-    ASSERT_EQ(imp_n_step, variant.acc_width);
+    Buffer rhs_qsi8;
+    Buffer rhs_scales;
 
-    // ============================================================
-    // Calculates the output area under test
-    // ============================================================
+    Buffer bias_qsi32;
+
+    Buffer dst_qsi8_clamped;
 
-    const auto output_area = output_portion.compute_portion(shape.m, shape.n, variant.acc_height, variant.acc_width);
+    Buffer packed_lhs;
+    Buffer packed_rhs;
+};
 
+/// Generate test reference data
+TestReference get_test_reference(const MatMulShape& shape, const MatMulVariant& variant) {
     // ============================================================
     // Generates input and reference output data
     // ============================================================
 
     // Generates the input data in floating-point.
-    const auto lhs_f32 = fill_random<float>(shape.m * shape.k, seed + 0);
-    const auto rhs_f32 = fill_random<float>(shape.k * shape.n, seed + 1);
-    const auto bias_f32 = fill_random<float>(shape.n, seed + 2);
+    const auto lhs_f32 = fill_random<float>(shape.m * shape.k, seed);
+    const auto rhs_f32 = fill_random<float>(shape.k * shape.n, seed);
+    const auto bias_f32 = fill_random<float>(shape.n, seed);
 
     // Quantizes the input data.
     //   * LHS: 8-bit asymmetric per-matrix quantization.
     //   * RHS: 8-bit symmetric per-channel quantization.
     //   * Bias: 32-bit symmetric per-channel quantization.
-    const auto [lhs_qai8, lhs_qai8_scales, lhs_qai8_zero_points] =
+    auto [lhs_qai8, lhs_qai8_scales, lhs_qai8_zero_points] =
         quantize_asymmetric_per_block_dynamic<float, int8_t, float, int32_t>(
             lhs_f32.data(), 1, shape.m * shape.k, shape.m * shape.k);
     const auto lhs_scale = read_array<float>(lhs_qai8_scales.data(), 0);
     const auto lhs_zero_point = read_array<int32_t>(lhs_qai8_zero_points.data(), 0);
 
     const auto rhs_f32_t = transpose<float>(rhs_f32.data(), shape.k, shape.n);
-    const auto [rhs_qsi8_t, rhs_scales] =
+    auto [rhs_qsi8_t, rhs_scales] =
         quantize_symmetric_per_block_dynamic<float, int8_t, float>(rhs_f32_t.data(), shape.n, shape.k, shape.k);
-    const auto rhs_qsi8 = transpose<int8_t>(rhs_qsi8_t.data(), shape.n, shape.k);
+    auto rhs_qsi8 = transpose<int8_t>(rhs_qsi8_t.data(), shape.n, shape.k);
 
-    const auto bias_scale = mul<float>(&lhs_scale, 1, 1, rhs_scales.data(), 1, shape.n);
-    const auto bias_qsi32 =
-        quantize_symmetric_per_block<float, int32_t, float>(bias_f32.data(), bias_scale.data(), shape.n, 1, 1);
+    const auto bias_scales = mul<float>(&lhs_scale, 1, 1, rhs_scales.data(), 1, shape.n);
+    auto bias_qsi32 =
+        quantize_symmetric_per_block<float, int32_t, float>(bias_f32.data(), bias_scales.data(), shape.n, 1, 1);
 
     // Runs the reference implementation of matmul to produce floating-point result.
     const auto ref_dst_f32 =
         matmul_nt_t_quantized<int8_t, float, int32_t, int8_t, float, int32_t, int32_t, float, int32_t, float>(
             shape.m, shape.n, shape.k, lhs_qai8.data(), &lhs_scale, &lhs_zero_point, shape.m, shape.k,
-            rhs_qsi8_t.data(), rhs_scales.data(), nullptr, 1, shape.k, bias_qsi32.data(), bias_scale.data(), nullptr,
+            rhs_qsi8_t.data(), rhs_scales.data(), nullptr, 1, shape.k, bias_qsi32.data(), bias_scales.data(), nullptr,
             1);
 
     // Computes the output quantization information and clamping limits.
@@ -233,7 +246,7 @@ void run_test(const MatMulShape& shape, const GemmVariant& variant, const Matrix
     // Clamps and quantizes the reference output matrix.
     const auto ref_dst_f32_clamped =
         clamp<float>(ref_dst_f32.data(), shape.m * shape.n, ref_dst_f32_clamp_min, ref_dst_f32_clamp_max);
-    const auto ref_dst_qsi8_clamped = quantize_asymmetric_per_block<float, int8_t, float, int32_t>(
+    auto ref_dst_qsi8_clamped = quantize_asymmetric_per_block<float, int8_t, float, int32_t>(
         ref_dst_f32_clamped.data(), &dst_scale, &dst_zero_point, 1, shape.m * shape.n, shape.m * shape.n);
 
     // Runs the reference implementation of the packing functions.
@@ -241,148 +254,224 @@ void run_test(const MatMulShape& shape, const GemmVariant& variant, const Matrix
     // Runs the reference implementation of the packing functions.
     //
     // The reference packing functions cannot be executed earlier
     // because we need the reference floating-point output first to have
     // the quantization information.
-    const auto ref_packed_lhs =
-        reorder_block<int8_t>(lhs_qai8.data(), shape.m, shape.k, variant.acc_height, variant.acc_fanin);
-
-    const auto ref_packed_rhs = matmul_pack_rhs_nxk_static_quantized(
+    auto packed_lhs = reorder_block<int8_t>(lhs_qai8.data(), shape.m, shape.k, variant.acc.m, variant.acc.k);
+    auto packed_rhs = matmul_pack_rhs_nxk_static_quantized(
         rhs_qsi8_t.data(), rhs_scales.data(), lhs_scale, dst_scale, bias_qsi32.data(), lhs_zero_point, shape.n, shape.k,
-        variant.acc_width, variant.acc_fanin);
+        variant.acc.n, variant.acc.k);
 
-    // ============================================================
-    // Runs the optimized implementation and checks for correctness
-    // ============================================================
+    return {
+        .clamp = {.min = dst_qai8_clamp_min, .max = dst_qai8_clamp_max},
+
+        .qa_lhs = {.scale = lhs_scale, .zero_point = lhs_zero_point},
+        .qa_dst = {.scale = dst_scale, .zero_point = dst_zero_point},
+
+        .lhs_qai8 = std::move(lhs_qai8),
+        .lhs_qai8_scales = std::move(lhs_qai8_scales),
+        .lhs_qai8_zero_points = std::move(lhs_qai8_zero_points),
+
+        .rhs_qsi8 = std::move(rhs_qsi8),
+        .rhs_scales = std::move(rhs_scales),
+
+        .bias_qsi32 = std::move(bias_qsi32),
+
+        .dst_qsi8_clamped = std::move(ref_dst_qsi8_clamped),
+
+        .packed_lhs = std::move(packed_lhs),
+        .packed_rhs = std::move(packed_rhs),
+    };
+}
+
+/// Test LHS packing
+void test_lhs_pack(
+    const MatMulShape& shape, const MatMulVariant& variant, const Rect& output_area, const TestReference& reference) {
+    KAI_ASSUME(variant.lhs_pack.has_value());
 
-    // Runs the optimized implementation of LHS packing.
     const auto imp_packed_lhs_size =
-        variant.fn_pack_lhs_get_packed_lhs_size(shape.m, shape.k, variant.acc_height, variant.acc_fanin, 1);
-    ASSERT_EQ(imp_packed_lhs_size, ref_packed_lhs.size());
-    std::vector<uint8_t> imp_packed_lhs(imp_packed_lhs_size);
-
-    {
-        const auto imp_lhs_offset =
-            variant.fn_pack_lhs_get_lhs_offset(output_area.start_row(), shape.k * sizeof(int8_t));
-        const auto imp_packed_lhs_offset =
-            variant.fn_pack_lhs_get_packed_lhs_offset(output_area.start_row(), shape.k, imp_mr, imp_kr, imp_sr);
-
-        variant.fn_pack_lhs_run(
-            output_area.height(), shape.k, imp_mr, imp_kr, imp_sr, 0, lhs_qai8.data() + imp_lhs_offset,
-            shape.k * sizeof(int8_t), imp_packed_lhs.data() + imp_packed_lhs_offset);
-
-        const auto imp_packed_lhs_end_offset = output_area.end_row() < shape.m
-            ? variant.fn_pack_lhs_get_packed_lhs_offset(output_area.end_row(), shape.k, imp_mr, imp_kr, imp_sr)
-            : imp_packed_lhs_size;
-
-        for (size_t i = 0; i < ref_packed_lhs.size(); ++i) {
-            if (i >= imp_packed_lhs_offset && i < imp_packed_lhs_end_offset) {
-                ASSERT_EQ(imp_packed_lhs[i], ref_packed_lhs[i]);
-            } else {
-                ASSERT_EQ(imp_packed_lhs[i], 0);
-            }
+        variant.lhs_pack->get_packed_lhs_size(shape.m, shape.k, variant.acc.m, variant.acc.k, 1);
+    ASSERT_EQ(imp_packed_lhs_size, reference.packed_lhs.size());
+
+    Buffer imp_packed_lhs(imp_packed_lhs_size);
+    const auto imp_lhs_offset = variant.lhs_pack->get_lhs_offset(output_area.start_row(), shape.k * sizeof(int8_t));
+    const auto imp_packed_lhs_offset =
+        variant.lhs_pack->get_packed_lhs_offset(output_area.start_row(), shape.k, variant.acc.m, variant.acc.k, 1);
+
+    variant.lhs_pack->pack(
+        output_area.height(), shape.k, variant.acc.m, variant.acc.k, 1, 0, reference.lhs_qai8.data() + imp_lhs_offset,
+        shape.k * sizeof(int8_t), imp_packed_lhs.data() + imp_packed_lhs_offset);
+
+    const auto imp_packed_lhs_end_offset = output_area.end_row() < shape.m
+        ? variant.lhs_pack->get_packed_lhs_offset(output_area.end_row(), shape.k, variant.acc.m, variant.acc.k, 1)
+        : imp_packed_lhs_size;
+
+    for (size_t i = 0; i < reference.packed_lhs.size(); ++i) {
+        if (i >= imp_packed_lhs_offset && i < imp_packed_lhs_end_offset) {
+            ASSERT_EQ(imp_packed_lhs[i], reference.packed_lhs[i]);
+        } else {
+            ASSERT_EQ(imp_packed_lhs[i], 0);
         }
     }
+}
 
-    // Runs the optimized implementation of RHS packing.
-    const auto imp_packed_rhs_size = variant.fn_pack_rhs_get_packed_rhs_size(shape.n, shape.k);
-    ASSERT_EQ(imp_packed_rhs_size, ref_packed_rhs.size());
-    std::vector<uint8_t> imp_packed_rhs(imp_packed_rhs_size);
-
-    {
-        const auto imp_rhs_offset = variant.fn_pack_rhs_get_rhs_offset(output_area.start_col());
-        const auto imp_bias_offset = variant.fn_pack_rhs_get_bias_offset(output_area.start_col());
-        const auto imp_scale_offset = variant.fn_pack_rhs_get_scale_offset(output_area.start_col());
-        const auto imp_packed_rhs_offset = variant.fn_pack_rhs_get_packed_rhs_offset(output_area.start_col(), shape.k);
-
-        const kai_rhs_pack_qsi8cx_params imp_pack_rhs_params{
-            .lhs_zero_point = lhs_zero_point,
-            .scale_multiplier = lhs_scale / dst_scale,
-        };
-
-        variant.fn_pack_rhs_run(
-            1, output_area.width(), shape.k, imp_nr, imp_kr, imp_sr, shape.n * sizeof(int8_t),
-            rhs_qsi8.data() + imp_rhs_offset, bias_qsi32.data() + imp_bias_offset, rhs_scales.data() + imp_scale_offset,
-            imp_packed_rhs.data() + imp_packed_rhs_offset, 0, &imp_pack_rhs_params);
-
-        const auto imp_packed_rhs_end_offset = output_area.end_col() < shape.n
-            ? variant.fn_pack_rhs_get_packed_rhs_offset(output_area.end_col(), shape.k)
-            : imp_packed_rhs_size;
-
-        for (size_t i = 0; i < ref_packed_rhs.size(); ++i) {
-            if (i >= imp_packed_rhs_offset && i < imp_packed_rhs_end_offset) {
-                ASSERT_EQ(imp_packed_rhs[i], ref_packed_rhs[i]);
-            } else {
-                ASSERT_EQ(imp_packed_rhs[i], 0);
-            }
+/// Test RHS packing
+void test_rhs_pack(
+    const MatMulShape& shape, const MatMulVariant& variant, const Rect& output_area, const TestReference& reference) {
+    const auto imp_packed_rhs_size = variant.rhs_pack.get_packed_rhs_size(shape.n, shape.k);
+    ASSERT_EQ(imp_packed_rhs_size, reference.packed_rhs.size());
+    Buffer imp_packed_rhs(imp_packed_rhs_size);
+
+    const auto imp_rhs_offset = variant.rhs_pack.get_rhs_offset(output_area.start_col());
+    const auto imp_bias_offset = variant.rhs_pack.get_bias_offset(output_area.start_col());
+    const auto imp_scale_offset = variant.rhs_pack.get_scale_offset(output_area.start_col());
+    const auto imp_packed_rhs_offset = variant.rhs_pack.get_packed_rhs_offset(output_area.start_col(), shape.k);
+
+    const kai_rhs_pack_qsi8cx_params imp_pack_rhs_params{
+        .lhs_zero_point = reference.qa_lhs.zero_point,
+        .scale_multiplier = reference.qa_lhs.scale / reference.qa_dst.scale,
+    };
+
+    variant.rhs_pack.pack(
+        1, output_area.width(), shape.k, variant.acc.n, variant.acc.k, 1, shape.n * sizeof(int8_t),
+        reference.rhs_qsi8.data() + imp_rhs_offset, reference.bias_qsi32.data() + imp_bias_offset,
+        reference.rhs_scales.data() + imp_scale_offset, imp_packed_rhs.data() + imp_packed_rhs_offset, 0,
+        &imp_pack_rhs_params);
+
+    const auto imp_packed_rhs_end_offset = output_area.end_col() < shape.n
+        ? variant.rhs_pack.get_packed_rhs_offset(output_area.end_col(), shape.k)
+        : imp_packed_rhs_size;
+
+    for (size_t i = 0; i < reference.packed_rhs.size(); ++i) {
+        if (i >= imp_packed_rhs_offset && i < imp_packed_rhs_end_offset) {
+            ASSERT_EQ(imp_packed_rhs[i], reference.packed_rhs[i]);
+        } else {
+            ASSERT_EQ(imp_packed_rhs[i], 0);
         }
     }
+}
+
+/// Test MatMul of GEMM-like kernel
+void test_matmul(
+    const MatMulShape& shape, const MatMulVariant& variant, const Rect& output_area, const TestReference& reference) {
+    const auto imp_dst_size = variant.matmul.get_dst_size(shape.m, shape.n);
+    ASSERT_EQ(imp_dst_size, reference.dst_qsi8_clamped.size());
 
-    // Runs the optimized implementation of GEMM kernel.
-    const auto imp_dst_size = variant.fn_main_get_dst_size(shape.m, shape.n);
-    ASSERT_EQ(imp_dst_size, ref_dst_qsi8_clamped.size());
-
-    std::vector<uint8_t> imp_dst(imp_dst_size);
-
-    {
-        const auto imp_packed_lhs_offset = variant.fn_main_get_packed_lhs_offset(output_area.start_row(), shape.k);
-        const auto imp_packed_rhs_offset = variant.fn_main_get_packed_rhs_offset(output_area.start_col(), shape.k);
-        const auto imp_dst_offset =
-            variant.fn_main_get_dst_offset(output_area.start_row(), output_area.start_col(), shape.n * sizeof(int8_t));
-        ASSERT_EQ(imp_dst_offset, output_area.start_row() * shape.n + output_area.start_col());
-
-        const kai_matmul_requantize32_params imp_main_params{
-            .min_value = dst_qai8_clamp_min,
-            .max_value = dst_qai8_clamp_max,
-            .output_zero_point = dst_zero_point,
-        };
-
-        variant.fn_main_run(
-            output_area.height(), output_area.width(), shape.k, imp_packed_lhs.data() + imp_packed_lhs_offset,
-            imp_packed_rhs.data() + imp_packed_rhs_offset, imp_dst.data() + imp_dst_offset, shape.n * sizeof(int8_t),
-            sizeof(int8_t), &imp_main_params);
-
-        for (size_t y = 0; y < shape.m; ++y) {
-            for (size_t x = 0; x < shape.n; ++x) {
-                const auto i = y * shape.n + x;
-                const auto in_area = y >= output_area.start_row() && y < output_area.end_row() &&
-                    x >= output_area.start_col() && x < output_area.end_col();
-
-                const int32_t imp_value = read_array<int8_t>(imp_dst.data(), i);
-                const int32_t ref_value = in_area ? read_array<int8_t>(ref_dst_qsi8_clamped.data(), i) : 0;
-                const auto error = std::abs(imp_value - ref_value);
-                const auto threshold = in_area ? 1 : 0;
-
-                if (error > threshold) {
-                    ASSERT_EQ(imp_value, ref_value);
-                }
+    Buffer imp_dst(imp_dst_size);
+    const auto [imp_lhs_offset, lhs_data] = [&]() -> std::tuple<size_t, const Buffer&> {
+        if (variant.lhs_pack.has_value()) {
+            return {variant.matmul.get_packed_lhs_offset(output_area.start_row(), shape.k), reference.packed_lhs};
+        }
+        return {output_area.start_row() * shape.k, reference.lhs_qai8};
+    }();
+    const size_t imp_packed_rhs_offset = variant.matmul.get_packed_rhs_offset(output_area.start_col(), shape.k);
+    const size_t imp_dst_offset =
+        variant.matmul.get_dst_offset(output_area.start_row(), output_area.start_col(), shape.n * sizeof(int8_t));
+    ASSERT_EQ(imp_dst_offset, output_area.start_row() * shape.n + output_area.start_col());
+
+    const kai_matmul_requantize32_params imp_main_params{
+        .min_value = reference.clamp.min,
+        .max_value = reference.clamp.max,
+        .output_zero_point = reference.qa_dst.zero_point,
+    };
+
+    variant.matmul.matmul(
+        output_area.height(), output_area.width(), shape.k, lhs_data.data() + imp_lhs_offset,
+        reference.packed_rhs.data() + imp_packed_rhs_offset, imp_dst.data() + imp_dst_offset, shape.n * sizeof(int8_t),
+        sizeof(int8_t), &imp_main_params);
+
+    for (size_t y = 0; y < shape.m; ++y) {
+        for (size_t x = 0; x < shape.n; ++x) {
+            const auto i = y * shape.n + x;
+            const auto in_area = y >= output_area.start_row() && y < output_area.end_row() &&
+                x >= output_area.start_col() && x < output_area.end_col();
+
+            const auto imp_value = read_array<int8_t>(imp_dst.data(), i);
+            const auto ref_value = in_area ? read_array<int8_t>(reference.dst_qsi8_clamped.data(), i) : 0;
+            const auto error = std::abs(imp_value - ref_value);
+            const auto threshold = in_area ? 1 : 0;
+
+            if (error > threshold) {
+                ASSERT_EQ(imp_value, ref_value);
             }
         }
     }
 }
 
-using ThisTest = testing::TestWithParam<std::tuple<GemmVariant, MatMulShape, MatrixPortion>>;
+using ThisTest = testing::TestWithParam<std::tuple<MatMulVariant, MatMulShape, MatrixPortion>>;
+
+static std::string test_description(
+    const MatMulVariant& variant,  //
+    const MatMulShape& shape,      //
+    const MatrixPortion& portion) {
+    std::stringstream sstream;
+    sstream << "Method_" << variant.name << "__M_"                                   //
+            << shape.m << "__N_" << shape.n << "__K_" << shape.k                     //
+            << "__PortionStartRow_" << static_cast<int>(portion.start_row() * 1000)  //
+            << "__PortionStartCol_" << static_cast<int>(portion.start_col() * 1000)  //
+            << "__PortionHeight_" << static_cast<int>(portion.height() * 1000)       //
+            << "__PortionWidth_" << static_cast<int>(portion.width() * 1000);
+    return sstream.str();
+}
 
 TEST_P(ThisTest, EndToEnd) {
     const auto& [variant, shape, output_portion] = GetParam();
 
-    run_test(shape, variant, output_portion);
-}
+    if (!variant.is_supported()) {
+        GTEST_SKIP() << "CPU features are not supported by current CPU";
+    }
+
+    TestReference reference = get_test_reference(shape, variant);
+
+    // Check scheduling parameters
+    const auto imp_mr = variant.matmul.get_mr();
+    const auto imp_nr = variant.matmul.get_nr();
+    const auto imp_kr = variant.matmul.get_kr();
+    const auto imp_sr = variant.matmul.get_sr();
 
-}  // namespace
+    ASSERT_EQ(imp_mr, variant.acc.m);
+    ASSERT_EQ(imp_nr, variant.acc.n);
+    ASSERT_EQ(imp_kr, variant.acc.k);
+    ASSERT_EQ(imp_sr, 1);
+
+    const auto imp_m_step = variant.matmul.get_m_step();
+    const auto imp_n_step = variant.matmul.get_n_step();
+
+    ASSERT_EQ(imp_m_step, variant.acc.m);
+    ASSERT_EQ(imp_n_step, variant.acc.n);
+
+    // Test kernels
+    const auto output_area = output_portion.compute_portion(shape.m, shape.n, variant.acc.m, variant.acc.n);
+    if (variant.lhs_pack.has_value()) {
+        test_lhs_pack(shape, variant, output_area, reference);
+    }
+    test_rhs_pack(shape, variant, output_area, reference);
+    test_matmul(shape, variant, output_area, reference);
+}
 
 INSTANTIATE_TEST_SUITE_P(
     matmul_clamp_qai8_qai8p_qsi8cxp, ThisTest,
     testing::Combine(
-        testing::ValuesIn(gemm_variants), testing::ValuesIn(gemm_shapes), testing::ValuesIn(output_portions)),
-    [](const auto& info) {
-        const auto shape = std::get<MatMulShape>(info.param);
-        const auto portion = std::get<MatrixPortion>(info.param);
-
-        std::stringstream sstream;
-        sstream << "__M_" << shape.m << "__N_" << shape.n << "__K_" << shape.k            //
-                << "__PortionStartRow_" << static_cast<int>(portion.start_row() * 1000)  //
-                << "__PortionStartCol_" << static_cast<int>(portion.start_col() * 1000)  //
-                << "__PortionHeight_" << static_cast<int>(portion.height() * 1000)       //
-                << "__PortionWidth_" << static_cast<int>(portion.width() * 1000);
-        return sstream.str();
+        testing::ValuesIn(gemm_variants),
+        testing::ValuesIn({
+            MatMulShape{1, 1, 1},  //
+            MatMulShape{
+                2 * get_sme_vector_length<int32_t>(), 2 * get_sme_vector_length<int32_t>(),
+                sizeof(int32_t) / sizeof(int8_t)},  //
+            MatMulShape{20, 30, 40},   //
+            MatMulShape{1, 49, 21},    //
+            MatMulShape{23, 1, 43},    //
+            MatMulShape{32, 14, 1},    //
+            MatMulShape{123, 85, 45},  //
+            MatMulShape{130, 130, 6},
+        }),
+        testing::ValuesIn({
+            MatrixPortion(0, 0, 1, 1),        // Full matrix.
+            MatrixPortion(0, 0, 0.25, 0.25),  // Top-left corner.
+            MatrixPortion(0.75, 0.75, 1, 1)   // Bottom-right corner.
+        })),
+    [](const auto& info) -> std::string {
+        return test_description(
+            std::get<MatMulVariant>(info.param),  //
+            std::get<MatMulShape>(info.param),    //
+            std::get<MatrixPortion>(info.param));
     });
 
 }  // namespace kai::test
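
The std::optional<LhsPackKernel> hook introduced above is what keeps this harness open to kernels with no LHS packing stage: test_matmul falls back to indexing the unpacked reference.lhs_qai8 when lhs_pack is empty. A minimal sketch of such an entry, assuming a hypothetical GEMV-style kernel (the variant name and the omitted kai_* entry points below are placeholders, not existing KleidiAI symbols):

    // Sketch only: a variant without an LHS packing stage. test_lhs_pack() is
    // skipped, and test_matmul() reads the unpacked quantized LHS directly.
    MatMulVariant{
        .name = "matmul_qai8_qai8_qsi8cxp",  // hypothetical kernel name
        .acc{.m = 1, .n = 2 * get_sme_vector_length<int32_t>(), .k = sizeof(int32_t) / sizeof(int8_t)},
        .is_supported = cpu_has_sme2,
        .lhs_pack = std::nullopt,  // matmul consumes the raw LHS buffer
        .rhs_pack = rhs_pack,
        .matmul = MatMulKernel{/* getters and run function of the hypothetical kernel */},
    },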