diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.c
index 9b27c742b6eb4e542752db28bce2ba6b59cf27e0..ae71f09c1559fbb40f79a433348f140d7da968a2 100644
--- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.c
+++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.c
@@ -47,7 +47,7 @@ inline static size_t kai_rhs_packed_stride(size_t k) {
     KAI_ASSERT((k_internal % 2) == 0);
 
-    return kai_nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs);
+    return kai_nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias);
 }
 
 size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(void) {
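Note: the hunk above fixes the packed-RHS stride to include the per-column bias that the packing routine already writes at the end of each block of nr columns; without the kai_num_bytes_bias term, offset and size queries land short of the true block boundary. A minimal sketch of the corrected arithmetic, with illustrative constants (f32 scale, i32 row-sum, f32 bias, nr = 4) that are assumptions here, not the kernel's actual definitions:

    #include <cassert>
    #include <cstddef>

    // Hypothetical stand-ins for the kernel's file-local constants.
    static const size_t nr = 4;                    // columns packed per block
    static const size_t num_bytes_multiplier = 4;  // per-column f32 scale
    static const size_t num_bytes_sum = 4;         // per-column i32 row-sum correction
    static const size_t num_bytes_bias = 4;        // per-column f32 bias (the previously missing term)

    // Bytes per block of nr packed columns: int4 data stores two values per
    // byte (hence k_internal / 2), followed by the per-column trailers.
    static size_t rhs_packed_stride(size_t k_internal) {
        assert(k_internal % 2 == 0);
        return nr * ((k_internal / 2) + num_bytes_multiplier + num_bytes_sum + num_bytes_bias);
    }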
diff --git a/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp b/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp
index 6519d028bbcf19b4a6c74b9cebf5c0d6c33aa958..403b85eb1d1881b0382e82e646442aafb2560da6 100644
--- a/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp
+++ b/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp
@@ -34,6 +34,7 @@
 #include "test/common/bfloat16.hpp"
 #include "test/common/cpu_info.hpp"
 #include "test/common/int4.hpp"
+#include "test/common/matrix_portion.hpp"
 #include "test/common/memory.hpp"
 #include "test/common/round.hpp"
 #include "test/common/test_suite.hpp"
@@ -66,27 +67,41 @@ static const std::array
-using MatMulTestParams_withBL = std::tuple<size_t, MatMulShape, size_t>;
+using MatMulTestParams_withBL = std::tuple<size_t, MatMulShape, size_t, MatrixPortion>;
 
 class UkernelVariantTest_withBL : public ::testing::TestWithParam<MatMulTestParams_withBL> {};
 class MatMulTest_f32_qmatmul_clamp_f32_qai8dxp_qsi4c32p : public UkernelVariantTest_withBL {};
 
 TEST_P(MatMulTest_f32_qmatmul_clamp_f32_qai8dxp_qsi4c32p, Offset_RHS) {
-    const auto& [variant_index, matmul_shape, bl] = GetParam();
+    const auto& [variant_index, matmul_shape, bl, portion] = GetParam();
     const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi4c32p.at(variant_index);
 
     const size_t K = matmul_shape.k;
 
+    const auto nr = ukernel_variant.interface.get_nr();
+    const auto kr = ukernel_variant.interface.get_kr();
+    const auto sr = ukernel_variant.interface.get_sr();
+
+    kai_datatype scale_dt = kai_datatype::kai_dt_bf16;
+
     auto n_step = ukernel_variant.interface.get_n_step();
     auto a_tmp = ukernel_variant.interface.get_rhs_packed_offset(n_step, K, bl) / n_step;
     auto b_tmp = ukernel_variant.interface.get_rhs_packed_offset(n_step * 16, K, bl) / (n_step * 16);
     ASSERT_EQ(a_tmp, b_tmp);
+
+    auto rhs_packed_offset_kxn =
+        kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(n_step, K, nr, kr, sr, bl, scale_dt);
+    auto rhs_packed_offset_nxk =
+        kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(n_step, K, nr, kr, sr, bl, scale_dt);
+
+    ASSERT_EQ(rhs_packed_offset_kxn, rhs_packed_offset_nxk);
+
+    auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(n_step, K, bl);
+    ASSERT_EQ(rhs_packed_offset_kxn, rhs_matmul_offset);
 }
 
 TEST_P(MatMulTest_f32_qmatmul_clamp_f32_qai8dxp_qsi4c32p, Offset_LHS) {
-    const auto& [variant_index, matmul_shape, bl] = GetParam();
+    const auto& [variant_index, matmul_shape, bl, portion] = GetParam();
     const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi4c32p.at(variant_index);
 
     const size_t K = matmul_shape.k;
@@ -99,7 +114,7 @@ TEST_P(MatMulTest_f32_qmatmul_clamp_f32_qai8dxp_qsi4c32p, Offset_LHS) {
 }
 
 TEST_P(MatMulTest_f32_qmatmul_clamp_f32_qai8dxp_qsi4c32p, EndToEnd_RHS_nxk) {
-    const auto& [variant_index, matmul_shape, bl] = GetParam();
+    const auto& [variant_index, matmul_shape, bl, portion] = GetParam();
     const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi4c32p.at(variant_index);
 
     if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) {
@@ -121,6 +136,7 @@ TEST_P(MatMulTest_f32_qmatmul_clamp_f32_qai8dxp_qsi4c32p, EndToEnd_RHS_nxk) {
     const auto ref_lhs = fill_random<float>(M * K, seed + 0);
     const auto ref_rhs = fill_random<float>(N * K, seed + 1);
     const auto ref_biases = fill_random<float>(N, seed + 2);
+    kai_datatype scale_dt = kai_datatype::kai_dt_bf16;
 
     // Runs the reference implementation.
     // * Quantizes the LHS matrix using 8-bit asymmetric quantization.
@@ -136,11 +152,31 @@ TEST_P(MatMulTest_f32_qmatmul_clamp_f32_qai8dxp_qsi4c32p, EndToEnd_RHS_nxk) {
         ref_rhs_scales.data(), nullptr, bl, ref_biases.data(), std::numeric_limits<float>::lowest(),
         std::numeric_limits<float>::max());
 
+    auto m_step = ukernel_variant.interface.get_m_step();
+    ASSERT_TRUE(m_step % mr == 0);
+
+    auto n_step = ukernel_variant.interface.get_n_step();
+    ASSERT_TRUE(n_step % nr == 0);
+
+    const auto rect = portion.compute_portion(M, N, m_step, n_step);
+    if (rect.height() == 0 || rect.width() == 0) {
+        GTEST_SKIP();
+    }
+
+    const auto lhs_start_row = rect.start_row();
+    size_t lhs_stride = K * sizeof(float);
+
     // Runs the LHS packing micro-kernel.
     const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr);
     std::vector<uint8_t> imp_packed_lhs(imp_packed_lhs_size);
+
+    auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, lhs_stride);
+    auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, K, mr, kr, sr);
+
     kai_run_lhs_quant_pack_qai8dxp_f32(
-        M, K, mr, kr, sr, 0, reinterpret_cast<const float*>(ref_lhs.data()), K * sizeof(float), imp_packed_lhs.data());
+        rect.height() /* m */, K, mr, kr, sr, 0 /* m_idx_start */,
+        reinterpret_cast<const float*>(ref_lhs.data() + lhs_offset), lhs_stride,
+        imp_packed_lhs.data() + lhs_packed_offset);
 
     // Runs the RHS packing micro-kernel.
     // * Generates the 4-bit unsigned symmetric quantized input for the micro-kernel.
@@ -153,8 +189,13 @@ TEST_P(MatMulTest_f32_qmatmul_clamp_f32_qai8dxp_qsi4c32p, EndToEnd_RHS_nxk) {
     const size_t ref_rhs_scales_stride = round_up_division(K, bl) * sizeof(uint16_t);
 
     const auto imp_packed_rhs_size =
-        kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, bl, kai_datatype::kai_dt_bf16);
+        kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, bl, scale_dt);
     std::vector<uint8_t> imp_packed_rhs(imp_packed_rhs_size);
+
+    const auto packed_rhs_start_row = rect.start_col();
+    auto rhs_packed_offset =
+        kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(packed_rhs_start_row, K, nr, kr, sr, bl, scale_dt);
+
     constexpr kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params{
         .lhs_zero_point = 1, .rhs_zero_point = 8, .scale_dt = kai_datatype::kai_dt_bf16};
     kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
@@ -162,20 +203,30 @@
         reinterpret_cast<const float*>(ref_biases.data()),
         reinterpret_cast<const float*>(ref_rhs_scales.data()), ref_rhs_scales_stride, imp_packed_rhs.data(), 0,
         &params);
 
-    // Runs the GEMM micro-kernel.
+    const auto dst_stride = N * sizeof(float);
+    const auto dst_offset = ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride);
+    const auto ref_dst_offset = rect.start_row() * dst_stride + rect.start_col() * sizeof(float);
+    ASSERT_EQ(dst_offset, ref_dst_offset);
+
     const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N);
     ASSERT_EQ(imp_dst_size, ref_dst.size());
+
+    // Runs the GEMM micro-kernel.
     std::vector<uint8_t> imp_dst(imp_dst_size);
     ukernel_variant.interface.run_matmul(
-        M, N, K, bl, imp_packed_lhs.data(), imp_packed_rhs.data(), reinterpret_cast<float*>(imp_dst.data()),
+        rect.height(), rect.width(), K, bl, imp_packed_lhs.data() + lhs_packed_offset,
+        imp_packed_rhs.data() + rhs_packed_offset, reinterpret_cast<float*>(imp_dst.data() + dst_offset),
         N * sizeof(float), sizeof(float), std::numeric_limits<float>::lowest(), std::numeric_limits<float>::max());
 
-    // Compares the output of the micro-kernels against the output of the reference implementation.
-    for (size_t y = 0; y < M; ++y) {
-        for (size_t x = 0; x < N; ++x) {
-            const auto imp_value = read_array<float>(imp_dst.data(), y * N + x);
-            const auto ref_value = read_array<float>(ref_dst.data(), y * N + x);
-            const auto rel_error = ref_value != 0 ? std::abs((imp_value - ref_value) / ref_value) : std::abs(imp_value);
+    // Compares the output of the micro-kernels against the output of the reference implementation for the portion
+    // tested.
+    for (size_t y = 0; y < rect.height(); ++y) {
+        for (size_t x = 0; x < rect.width(); ++x) {
+            const auto imp_value =
+                read_array<float>(imp_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col()));
+            const auto ref_value =
+                read_array<float>(ref_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col()));
+            const auto rel_error = ref_value != 0 ? std::abs((imp_value - ref_value) / ref_value) : std::abs(imp_value);
 
             if (rel_error > 0.0001F) {
                 ASSERT_EQ(imp_value, ref_value);
@@ -185,7 +236,7 @@ TEST_P(MatMulTest_f32_qmatmul_clamp_f32_qai8dxp_qsi4c32p, EndToEnd_RHS_nxk) {
 }
 
 TEST_P(MatMulTest_f32_qmatmul_clamp_f32_qai8dxp_qsi4c32p, EndToEnd_RHS_kxn) {
-    const auto& [variant_index, matmul_shape, bl] = GetParam();
+    const auto& [variant_index, matmul_shape, bl, portion] = GetParam();
     const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi4c32p.at(variant_index);
 
     if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) {
@@ -207,6 +258,7 @@ TEST_P(MatMulTest_f32_qmatmul_clamp_f32_qai8dxp_qsi4c32p, EndToEnd_RHS_kxn) {
     const auto ref_lhs = fill_random<float>(M * K, seed + 0);
     const auto ref_rhs_transposed = fill_random<float>(N * K, seed + 1);
     const auto ref_biases = fill_random<float>(N, seed + 2);
+    kai_datatype scale_dt = kai_datatype::kai_dt_bf16;
 
     // Transposed(nxk) RHS dimensions
     const size_t ref_rhs_qsi4_nxk_stride = K;
@@ -234,11 +286,31 @@ TEST_P(MatMulTest_f32_qmatmul_clamp_f32_qai8dxp_qsi4c32p, EndToEnd_RHS_kxn) {
         ref_rhs_scales.data(), nullptr, bl, ref_biases.data(), std::numeric_limits<float>::lowest(),
         std::numeric_limits<float>::max());
 
+    auto m_step = ukernel_variant.interface.get_m_step();
+    ASSERT_TRUE(m_step % mr == 0);
+
+    auto n_step = ukernel_variant.interface.get_n_step();
+    ASSERT_TRUE(n_step % nr == 0);
+
+    const auto rect = portion.compute_portion(M, N, m_step, n_step);
+    if (rect.height() == 0 || rect.width() == 0) {
+        GTEST_SKIP();
+    }
+
+    const auto lhs_start_row = rect.start_row();
+    size_t lhs_stride = K * sizeof(float);
+
     // Runs the LHS packing micro-kernel.
     const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr);
     std::vector<uint8_t> imp_packed_lhs(imp_packed_lhs_size);
+
+    auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, lhs_stride);
+    auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, K, mr, kr, sr);
+
     kai_run_lhs_quant_pack_qai8dxp_f32(
-        M, K, mr, kr, sr, 0, reinterpret_cast<const float*>(ref_lhs.data()), K * sizeof(float), imp_packed_lhs.data());
+        rect.height() /* m */, K, mr, kr, sr, 0 /* m_idx_start */,
+        reinterpret_cast<const float*>(ref_lhs.data() + lhs_offset), lhs_stride,
+        imp_packed_lhs.data() + lhs_packed_offset);
 
     // Runs the RHS packing micro-kernel.
     // * Generates the 4-bit unsigned symmetric quantized input for the micro-kernel.
@@ -249,8 +321,12 @@ TEST_P(MatMulTest_f32_qmatmul_clamp_f32_qai8dxp_qsi4c32p, EndToEnd_RHS_kxn) {
     const size_t ref_rhs_qsu4_stride = round_up_division(N, 2);
     const size_t ref_rhs_scales_stride = round_up_division(K, bl) * sizeof(uint16_t);
 
+    const auto packed_rhs_start_row = rect.start_col();
+    auto rhs_packed_offset =
+        kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(packed_rhs_start_row, K, nr, kr, sr, bl, scale_dt);
+
     const auto imp_packed_rhs_size =
-        kai_get_rhs_packed_size_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, bl, kai_datatype::kai_dt_bf16);
+        kai_get_rhs_packed_size_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, bl, scale_dt);
     std::vector<uint8_t> imp_packed_rhs(imp_packed_rhs_size);
     constexpr kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_params params{
         .lhs_zero_point = 1, .rhs_zero_point = 8, .scale_dt = kai_datatype::kai_dt_bf16};
@@ -259,20 +335,28 @@
         reinterpret_cast<const float*>(ref_biases.data()), ref_rhs_scales.data(), ref_rhs_scales_stride,
         imp_packed_rhs.data(), 0, &params);
 
+    const auto dst_stride = N * sizeof(float);
+    const auto dst_offset = ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride);
+    const auto ref_dst_offset = rect.start_row() * dst_stride + rect.start_col() * sizeof(float);
+    ASSERT_EQ(dst_offset, ref_dst_offset);
+
     // Runs the GEMM micro-kernel.
     const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N);
     ASSERT_EQ(imp_dst_size, ref_dst.size());
     std::vector<uint8_t> imp_dst(imp_dst_size);
     ukernel_variant.interface.run_matmul(
-        M, N, K, bl, imp_packed_lhs.data(), imp_packed_rhs.data(), reinterpret_cast<float*>(imp_dst.data()),
+        rect.height(), rect.width(), K, bl, imp_packed_lhs.data() + lhs_packed_offset,
+        imp_packed_rhs.data() + rhs_packed_offset, reinterpret_cast<float*>(imp_dst.data() + dst_offset),
         N * sizeof(float), sizeof(float), std::numeric_limits<float>::lowest(), std::numeric_limits<float>::max());
 
     // Compares the output of the micro-kernels against the output of the reference implementation.
-    for (size_t y = 0; y < M; ++y) {
-        for (size_t x = 0; x < N; ++x) {
-            const auto imp_value = read_array<float>(imp_dst.data(), y * N + x);
-            const auto ref_value = read_array<float>(ref_dst.data(), y * N + x);
-            const auto rel_error = ref_value != 0 ? std::abs((imp_value - ref_value) / ref_value) : std::abs(imp_value);
+    for (size_t y = 0; y < rect.height(); ++y) {
+        for (size_t x = 0; x < rect.width(); ++x) {
+            const auto imp_value =
+                read_array<float>(imp_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col()));
+            const auto ref_value =
+                read_array<float>(ref_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col()));
+            const auto rel_error = ref_value != 0 ? std::abs((imp_value - ref_value) / ref_value) : std::abs(imp_value);
 
             if (rel_error > 0.0001F) {
                 ASSERT_EQ(imp_value, ref_value);
@@ -291,15 +375,26 @@ INSTANTIATE_TEST_SUITE_P(
             MatMulShape{17, 25, 64},   //
             MatMulShape{15, 31, 128},  //
             MatMulShape{1, 25, 64}),
-        testing::Values(32, 64)),
+        testing::Values(32, 64),
+        testing::Values(
+            MatrixPortion(0, 0, 1, 1),     // Full matrix.
+            MatrixPortion(0, 0, 1, 0.25),  // Leftmost portion.
+            MatrixPortion(0, 0.75, 1, 1),  // Rightmost portion.
+            MatrixPortion(0, 0.5, 1, 0.8)  // Somewhere in the middle.
+            )),
     [](const auto& info) {
         const auto variant_idx = std::get<0>(info.param);
         const std::string name{variants_kai_matmul_clamp_f32_qai8dxp_qsi4c32p.at(variant_idx).name};
         const auto shape = std::get<MatMulShape>(info.param);
         const auto bl = std::get<2>(info.param);
+        const auto portion = std::get<3>(info.param);
 
         std::stringstream sstream;
-        sstream << name << "__M_" << shape.m << "__N_" << shape.n << "__K_" << shape.k << "__BL_" << bl;
+        sstream << name << "__M_" << shape.m << "__N_" << shape.n << "__K_" << shape.k << "__BL_" << bl
+                << "__PortionStartRow_" << static_cast<int>(portion.start_row() * 1000)  //
+                << "__PortionStartCol_" << static_cast<int>(portion.start_col() * 1000)  //
+                << "__PortionHeight_" << static_cast<int>(portion.height() * 1000)       //
+                << "__PortionWidth_" << static_cast<int>(portion.width() * 1000);
         return sstream.str();
     });
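Note: every end-to-end test in both files now follows the same tiled-dispatch pattern: pick a rectangle of the M x N output via MatrixPortion::compute_portion(), pack only the LHS rows it needs, and advance the packed-buffer and destination pointers through the offset helpers before calling the kernel. A condensed sketch of that call sequence follows; the UkernelIface struct and run_tile() are hypothetical stand-ins for illustration (modelled on the per-channel qsi4cxp interface above; the qsi4c32p flavour additionally threads the block length bl through), not the KleidiAI API:

    #include <cassert>
    #include <cfloat>
    #include <cstddef>
    #include <cstdint>

    // Hypothetical mirror of the micro-kernel interface used by the tests.
    struct UkernelIface {
        size_t (*get_lhs_packed_offset)(size_t m_idx, size_t k);
        size_t (*get_rhs_packed_offset)(size_t n_idx, size_t k);
        size_t (*get_dst_offset)(size_t m_idx, size_t n_idx, size_t dst_stride);
        void (*run_matmul)(
            size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst,
            size_t dst_stride_row, size_t dst_stride_col, float clamp_min, float clamp_max);
    };

    // Output tile selected by MatrixPortion::compute_portion(M, N, m_step, n_step).
    struct Rect {
        size_t start_row, start_col, height, width;
    };

    // Run the micro-kernel on one output tile, mirroring the tests above.
    void run_tile(
        const UkernelIface& iface, const Rect& rect, size_t K, size_t N, const uint8_t* lhs_packed,
        const uint8_t* rhs_packed, uint8_t* dst) {
        const size_t dst_stride = N * sizeof(float);
        const size_t lhs_packed_offset = iface.get_lhs_packed_offset(rect.start_row, K);
        const size_t rhs_packed_offset = iface.get_rhs_packed_offset(rect.start_col, K);
        const size_t dst_offset = iface.get_dst_offset(rect.start_row, rect.start_col, dst_stride);

        // Row-major f32 output: the offset must match the naive byte address,
        // which is exactly the ASSERT_EQ(dst_offset, ref_dst_offset) check above.
        assert(dst_offset == rect.start_row * dst_stride + rect.start_col * sizeof(float));

        iface.run_matmul(
            rect.height, rect.width, K,                     // tile shape instead of full M, N
            lhs_packed + lhs_packed_offset,                 // packed LHS, advanced to the tile's rows
            rhs_packed + rhs_packed_offset,                 // packed RHS, advanced to the tile's columns
            reinterpret_cast<float*>(dst + dst_offset),     // destination tile
            dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);  // full-matrix row stride, clamp bounds
    }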
diff --git a/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp b/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp
index 3a5d0b587348cdc1e9fee6391e966fa4bdff795d..402af683a0e7d014b5205b3f1c54c85298d4fe42 100644
--- a/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp
+++ b/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp
@@ -35,6 +35,7 @@
 #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon.h"
 #include "test/common/cpu_info.hpp"
 #include "test/common/int4.hpp"
+#include "test/common/matrix_portion.hpp"
 #include "test/common/memory.hpp"
 #include "test/common/round.hpp"
 #include "test/common/test_suite.hpp"
@@ -46,16 +47,21 @@
 #include "test/reference/transpose.hpp"
 
 namespace kai::test {
+/// Matrix multiplication test information.
+using MatMulTestParams_with_portion = std::tuple<size_t, MatMulShape, MatrixPortion>;
+class UkernelVariantTest_with_portions : public ::testing::TestWithParam<MatMulTestParams_with_portion> {};
 
 enum class RhsPackType { NxK, KxN };
 
 using ukernel_rhs_pack_function = std::function<void(
    size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const float* bias,
    const float* scale, void* rhs_packed, size_t extra_bytes, const void* params)>;
 using ukernel_get_rhs_packed_size = std::function<size_t(size_t n, size_t k, size_t nr, size_t kr, size_t sr)>;
+using ukernel_get_rhs_packed_offset = std::function<size_t(size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr)>;
 
 template <typename T>
 struct UkernelVariantCustom : public UkernelVariant<T> {
     ukernel_rhs_pack_function run_rhs_pack;
     ukernel_get_rhs_packed_size get_rhs_packed_size;
+    ukernel_get_rhs_packed_offset get_rhs_packed_offset;
     RhsPackType rhs_pack_type;
     bool signed_integer_support;
 
@@ -64,10 +70,12 @@ struct UkernelVariantCustom : public UkernelVariant<T> {
     UkernelVariantCustom(
         T interface, std::string_view name, const std::function<bool(void)>& fn_is_supported,
         ukernel_rhs_pack_function run_rhs_pack, ukernel_get_rhs_packed_size get_rhs_packed_size,
-        const RhsPackType pack_type, const bool signed_integer_support) :
+        ukernel_get_rhs_packed_offset get_rhs_packed_offset, const RhsPackType pack_type,
+        const bool signed_integer_support) :
         UkernelVariant<T>(interface, name, fn_is_supported),
         run_rhs_pack(std::move(run_rhs_pack)),
         get_rhs_packed_size(std::move(get_rhs_packed_size)),
+        get_rhs_packed_offset(std::move(get_rhs_packed_offset)),
         rhs_pack_type(pack_type),
         signed_integer_support(signed_integer_support) {
     }
@@ -78,101 +86,104 @@ static const std::array
         ref_rhs_scales.data(), nullptr, K, ref_biases.data(), std::numeric_limits<float>::lowest(),
         std::numeric_limits<float>::max());
 
+    auto m_step = ukernel_variant.interface.get_m_step();
+    ASSERT_TRUE(m_step % mr == 0);
+
+    auto n_step = ukernel_variant.interface.get_n_step();
+    ASSERT_TRUE(n_step % nr == 0);
+
+    const auto rect = portion.compute_portion(M, N, m_step, n_step);
+    if (rect.height() == 0 || rect.width() == 0) {
+        GTEST_SKIP();
+    }
+
+    const auto lhs_start_row = rect.start_row();
+    size_t lhs_stride = K * sizeof(float);
+
     // Runs the LHS packing micro-kernel.
     const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr);
     std::vector<uint8_t> imp_packed_lhs(imp_packed_lhs_size);
+
+    auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, lhs_stride);
+    auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, K, mr, kr, sr);
+    ASSERT_EQ(lhs_packed_offset, ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K));
+
     kai_run_lhs_quant_pack_qai8dxp_f32(
-        M, K, mr, kr, sr, 0, reinterpret_cast<const float*>(ref_lhs.data()), K * sizeof(float), imp_packed_lhs.data());
+        rect.height() /* m */, K, mr, kr, sr, 0 /* m_idx_start */,
+        reinterpret_cast<const float*>(ref_lhs.data() + lhs_offset), lhs_stride,
+        imp_packed_lhs.data() + lhs_packed_offset);
 
     // Runs the RHS packing micro-kernel.
     // * Generates the 4-bit unsigned symmetric quantized input for the micro-kernel.
@@ -228,6 +260,9 @@ TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_nxk_qsi4cx) {
         ref_rhs_qsi4.data(), N, K, K, round_up_multiple(K, 2), round_up_division(N * round_up_multiple(K, 2), 2));
 
     const auto imp_packed_rhs_size = ukernel_variant.get_rhs_packed_size(N, K, nr, kr, sr);
+    const auto packed_rhs_start_row = rect.start_col();
+    auto rhs_packed_offset = ukernel_variant.get_rhs_packed_offset(packed_rhs_start_row, K, nr, kr, sr);
+    ASSERT_EQ(rhs_packed_offset, ukernel_variant.interface.get_rhs_packed_offset(packed_rhs_start_row, K));
     std::vector<uint8_t> imp_packed_rhs(imp_packed_rhs_size);
 
     const kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params params{.lhs_zero_point = 1, .rhs_zero_point = 0};
@@ -235,20 +270,29 @@
         1, N, K, nr, kr, sr, ref_rhs_qsi4_padded.data(), reinterpret_cast<const float*>(ref_biases.data()),
         reinterpret_cast<const float*>(ref_rhs_scales.data()), imp_packed_rhs.data(), 0, &params);
 
+    const auto dst_stride = N * sizeof(float);
+    const auto dst_offset = ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride);
+    const auto ref_dst_offset = rect.start_row() * dst_stride + rect.start_col() * sizeof(float);
+    ASSERT_EQ(dst_offset, ref_dst_offset);
+
     // Runs the GEMM micro-kernel.
     const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N);
     ASSERT_EQ(imp_dst_size, ref_dst.size());
     std::vector<uint8_t> imp_dst(imp_dst_size);
     ukernel_variant.interface.run_matmul(
-        M, N, K, imp_packed_lhs.data(), imp_packed_rhs.data(), reinterpret_cast<float*>(imp_dst.data()),
+        rect.height(), rect.width(), K, imp_packed_lhs.data() + lhs_packed_offset,
+        imp_packed_rhs.data() + rhs_packed_offset, reinterpret_cast<float*>(imp_dst.data() + dst_offset),
         N * sizeof(float), sizeof(float), std::numeric_limits<float>::lowest(), std::numeric_limits<float>::max());
 
-    // Compares the output of the micro-kernels against the output of the reference implementation.
-    for (size_t y = 0; y < M; ++y) {
-        for (size_t x = 0; x < N; ++x) {
-            const auto imp_value = read_array<float>(imp_dst.data(), y * N + x);
-            const auto ref_value = read_array<float>(ref_dst.data(), y * N + x);
-            const auto rel_error = ref_value != 0 ? std::abs((imp_value - ref_value) / ref_value) : std::abs(imp_value);
+    // Compares the output of the micro-kernels against the output of the reference implementation for the portion
+    // tested.
+    for (size_t y = 0; y < rect.height(); ++y) {
+        for (size_t x = 0; x < rect.width(); ++x) {
+            const auto imp_value =
+                read_array<float>(imp_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col()));
+            const auto ref_value =
+                read_array<float>(ref_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col()));
+            const auto rel_error = ref_value != 0 ? std::abs((imp_value - ref_value) / ref_value) : std::abs(imp_value);
 
             if (rel_error > 0.0001F) {
                 ASSERT_EQ(imp_value, ref_value);
@@ -256,8 +300,9 @@ TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_nxk_qsi4cx) {
         }
     }
 }
+
 TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_nxk_qsu4cx) {
-    const auto& [variant_index, matmul_shape] = GetParam();
+    const auto& [variant_index, matmul_shape, portion] = GetParam();
     const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi4cxp.at(variant_index);
 
     if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) {
@@ -297,11 +342,32 @@ TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_nxk_qsu4cx) {
         ref_rhs_scales.data(), nullptr, K, ref_biases.data(), std::numeric_limits<float>::lowest(),
         std::numeric_limits<float>::max());
 
+    auto m_step = ukernel_variant.interface.get_m_step();
+    ASSERT_TRUE(m_step % mr == 0);
+
+    auto n_step = ukernel_variant.interface.get_n_step();
+    ASSERT_TRUE(n_step % nr == 0);
+
+    const auto rect = portion.compute_portion(M, N, m_step, n_step);
+    if (rect.height() == 0 || rect.width() == 0) {
+        GTEST_SKIP();
+    }
+
+    const auto lhs_start_row = rect.start_row();
+    size_t lhs_stride = K * sizeof(float);
+
     // Runs the LHS packing micro-kernel.
     const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr);
     std::vector<uint8_t> imp_packed_lhs(imp_packed_lhs_size);
+
+    auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, lhs_stride);
+    auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, K, mr, kr, sr);
+    ASSERT_EQ(lhs_packed_offset, ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K));
+
     kai_run_lhs_quant_pack_qai8dxp_f32(
-        M, K, mr, kr, sr, 0, reinterpret_cast<const float*>(ref_lhs.data()), K * sizeof(float), imp_packed_lhs.data());
+        rect.height() /* m */, K, mr, kr, sr, 0 /* m_idx_start */,
+        reinterpret_cast<const float*>(ref_lhs.data() + lhs_offset), lhs_stride,
+        imp_packed_lhs.data() + lhs_packed_offset);
 
     const auto ref_rhs_qsu4 = cast_qsu4_qsi4(ref_rhs_qsi4.data(), N * K);
 
     // Runs the RHS packing micro-kernel.
@@ -311,6 +377,9 @@ TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_nxk_qsu4cx) {
         ref_rhs_qsu4.data(), N, K, K, round_up_multiple(K, 2), round_up_division(N * round_up_multiple(K, 2), 2));
 
     const auto imp_packed_rhs_size = ukernel_variant.get_rhs_packed_size(N, K, nr, kr, sr);
+    const auto packed_rhs_start_row = rect.start_col();
+    auto rhs_packed_offset = ukernel_variant.get_rhs_packed_offset(packed_rhs_start_row, K, nr, kr, sr);
+    ASSERT_EQ(rhs_packed_offset, ukernel_variant.interface.get_rhs_packed_offset(packed_rhs_start_row, K));
     std::vector<uint8_t> imp_packed_rhs(imp_packed_rhs_size);
 
     const kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params params{.lhs_zero_point = 1, .rhs_zero_point = 8};
@@ -318,6 +387,10 @@ TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_nxk_qsu4cx) {
         1, N, K, nr, kr, sr, ref_rhs_qsu4_padded.data(), reinterpret_cast<const float*>(ref_biases.data()),
         reinterpret_cast<const float*>(ref_rhs_scales.data()), imp_packed_rhs.data(), 0, &params);
 
+    const auto dst_stride = N * sizeof(float);
+    const auto dst_offset = ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride);
+    const auto ref_dst_offset = rect.start_row() * dst_stride + rect.start_col() * sizeof(float);
+    ASSERT_EQ(dst_offset, ref_dst_offset);
     // Runs the GEMM micro-kernel.
     const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N);
     ASSERT_EQ(imp_dst_size, ref_dst.size());
@@ -341,7 +414,7 @@ TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_nxk_qsu4cx) {
 }
 
 TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_kxn_qsi4cx) {
-    const auto& [variant_index, matmul_shape] = GetParam();
+    const auto& [variant_index, matmul_shape, portion] = GetParam();
     const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi4cxp.at(variant_index);
 
     if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) {
@@ -395,11 +468,31 @@ TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_kxn_qsi4cx) {
         ref_rhs_scales.data(), nullptr, K, ref_biases.data(), std::numeric_limits<float>::lowest(),
         std::numeric_limits<float>::max());
 
+    auto m_step = ukernel_variant.interface.get_m_step();
+    ASSERT_TRUE(m_step % mr == 0);
+
+    auto n_step = ukernel_variant.interface.get_n_step();
+    ASSERT_TRUE(n_step % nr == 0);
+
+    const auto rect = portion.compute_portion(M, N, m_step, n_step);
+    if (rect.height() == 0 || rect.width() == 0) {
+        GTEST_SKIP();
+    }
+
+    const auto lhs_start_row = rect.start_row();
+    size_t lhs_stride = K * sizeof(float);
+
     // Runs the LHS packing micro-kernel.
     const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr);
     std::vector<uint8_t> imp_packed_lhs(imp_packed_lhs_size);
+    auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, lhs_stride);
+    auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, K, mr, kr, sr);
+    ASSERT_EQ(lhs_packed_offset, ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K));
+
     kai_run_lhs_quant_pack_qai8dxp_f32(
-        M, K, mr, kr, sr, 0, reinterpret_cast<const float*>(ref_lhs.data()), K * sizeof(float), imp_packed_lhs.data());
+        rect.height() /* m */, K, mr, kr, sr, 0 /* m_idx_start */,
+        reinterpret_cast<const float*>(ref_lhs.data() + lhs_offset), lhs_stride,
+        imp_packed_lhs.data() + lhs_packed_offset);
 
     // Runs the RHS packing micro-kernel.
     // * Generates the 4-bit unsigned symmetric quantized input for the micro-kernel.
@@ -407,26 +500,39 @@ TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_kxn_qsi4cx) {
     const auto ref_rhs_qsi4_padded = pad_row<Int4>(
         ref_rhs_qsi4.data(), K, N, N, round_up_multiple(N, 2), round_up_division(K * round_up_multiple(N, 2), 2));
 
     const auto imp_packed_rhs_size = ukernel_variant.get_rhs_packed_size(N, K, nr, kr, sr);
+    const auto packed_rhs_start_row = rect.start_col();
+    auto rhs_packed_offset = ukernel_variant.get_rhs_packed_offset(packed_rhs_start_row, K, nr, kr, sr);
+    ASSERT_EQ(rhs_packed_offset, ukernel_variant.interface.get_rhs_packed_offset(packed_rhs_start_row, K));
+
     std::vector<uint8_t> imp_packed_rhs(imp_packed_rhs_size);
 
     const kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0_params params{.lhs_zero_point = 1, .rhs_zero_point = 0};
     ukernel_variant.run_rhs_pack(
         1, N, K, nr, kr, sr, ref_rhs_qsi4_padded.data(), reinterpret_cast<const float*>(ref_biases.data()),
         reinterpret_cast<const float*>(ref_rhs_scales.data()), imp_packed_rhs.data(), 0, &params);
 
+    const auto dst_stride = N * sizeof(float);
+    const auto dst_offset = ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride);
+    const auto ref_dst_offset = rect.start_row() * dst_stride + rect.start_col() * sizeof(float);
+    ASSERT_EQ(dst_offset, ref_dst_offset);
+
     // Runs the GEMM micro-kernel.
     const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N);
     ASSERT_EQ(imp_dst_size, ref_dst.size());
     std::vector<uint8_t> imp_dst(imp_dst_size);
     ukernel_variant.interface.run_matmul(
-        M, N, K, imp_packed_lhs.data(), imp_packed_rhs.data(), reinterpret_cast<float*>(imp_dst.data()),
+        rect.height(), rect.width(), K, imp_packed_lhs.data() + lhs_packed_offset,
+        imp_packed_rhs.data() + rhs_packed_offset, reinterpret_cast<float*>(imp_dst.data() + dst_offset),
         N * sizeof(float), sizeof(float), std::numeric_limits<float>::lowest(), std::numeric_limits<float>::max());
 
-    // Compares the output of the micro-kernels against the output of the reference implementation.
-    for (size_t y = 0; y < M; ++y) {
-        for (size_t x = 0; x < N; ++x) {
-            const auto imp_value = read_array<float>(imp_dst.data(), y * N + x);
-            const auto ref_value = read_array<float>(ref_dst.data(), y * N + x);
-            const auto rel_error = ref_value != 0 ? std::abs((imp_value - ref_value) / ref_value) : std::abs(imp_value);
+    // Compares the output of the micro-kernels against the output of the reference implementation for the portion
+    // tested.
+    for (size_t y = 0; y < rect.height(); ++y) {
+        for (size_t x = 0; x < rect.width(); ++x) {
+            const auto imp_value =
+                read_array<float>(imp_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col()));
+            const auto ref_value =
+                read_array<float>(ref_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col()));
+            const auto rel_error = ref_value != 0 ? std::abs((imp_value - ref_value) / ref_value) : std::abs(imp_value);
 
             if (rel_error > 0.0001F) {
                 ASSERT_EQ(imp_value, ref_value);
@@ -436,7 +542,7 @@ TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_kxn_qsi4cx) {
 }
 
 TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_kxn_qsu4cx) {
-    const auto& [variant_index, matmul_shape] = GetParam();
+    const auto& [variant_index, matmul_shape, portion] = GetParam();
     const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi4cxp.at(variant_index);
 
     if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) {
@@ -488,11 +594,32 @@ TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_kxn_qsu4cx) {
         ref_rhs_scales.data(), nullptr, K, ref_biases.data(), std::numeric_limits<float>::lowest(),
         std::numeric_limits<float>::max());
 
+    auto m_step = ukernel_variant.interface.get_m_step();
+    ASSERT_TRUE(m_step % mr == 0);
+
+    auto n_step = ukernel_variant.interface.get_n_step();
+    ASSERT_TRUE(n_step % nr == 0);
+
+    const auto rect = portion.compute_portion(M, N, m_step, n_step);
+    if (rect.height() == 0 || rect.width() == 0) {
+        GTEST_SKIP();
+    }
+
+    const auto lhs_start_row = rect.start_row();
+    size_t lhs_stride = K * sizeof(float);
+
     // Runs the LHS packing micro-kernel.
     const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr);
     std::vector<uint8_t> imp_packed_lhs(imp_packed_lhs_size);
+
+    auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, lhs_stride);
+    auto lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(lhs_start_row, K, mr, kr, sr);
+    ASSERT_EQ(lhs_packed_offset, ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K));
+
     kai_run_lhs_quant_pack_qai8dxp_f32(
-        M, K, mr, kr, sr, 0, reinterpret_cast<const float*>(ref_lhs.data()), K * sizeof(float), imp_packed_lhs.data());
+        rect.height() /* m */, K, mr, kr, sr, 0 /* m_idx_start */,
+        reinterpret_cast<const float*>(ref_lhs.data() + lhs_offset), lhs_stride,
+        imp_packed_lhs.data() + lhs_packed_offset);
 
     // Runs the RHS packing micro-kernel.
     // * Generates the 4-bit unsigned symmetric quantized input for the micro-kernel.
@@ -501,26 +628,39 @@ TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_kxn_qsu4cx) {
     const auto ref_rhs_qsu4_padded = pad_row<UInt4>(
         ref_rhs_qsu4.data(), K, N, N, round_up_multiple(N, 2), round_up_division(K * round_up_multiple(N, 2), 2));
 
     const auto imp_packed_rhs_size = ukernel_variant.get_rhs_packed_size(N, K, nr, kr, sr);
+    const auto packed_rhs_start_row = rect.start_col();
+    auto rhs_packed_offset =
+        kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(packed_rhs_start_row, K, nr, kr, sr);
+    ASSERT_EQ(rhs_packed_offset, ukernel_variant.interface.get_rhs_packed_offset(packed_rhs_start_row, K));
+
     std::vector<uint8_t> imp_packed_rhs(imp_packed_rhs_size);
 
     const kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0_params params{.lhs_zero_point = 1, .rhs_zero_point = 8};
     ukernel_variant.run_rhs_pack(
         1, N, K, nr, kr, sr, ref_rhs_qsu4_padded.data(), reinterpret_cast<const float*>(ref_biases.data()),
         reinterpret_cast<const float*>(ref_rhs_scales.data()), imp_packed_rhs.data(), 0, &params);
 
+    const auto dst_stride = N * sizeof(float);
+    const auto dst_offset = ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride);
+    const auto ref_dst_offset = rect.start_row() * dst_stride + rect.start_col() * sizeof(float);
+    ASSERT_EQ(dst_offset, ref_dst_offset);
+
     // Runs the GEMM micro-kernel.
     const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N);
     ASSERT_EQ(imp_dst_size, ref_dst.size());
     std::vector<uint8_t> imp_dst(imp_dst_size);
     ukernel_variant.interface.run_matmul(
-        M, N, K, imp_packed_lhs.data(), imp_packed_rhs.data(), reinterpret_cast<float*>(imp_dst.data()),
+        rect.height(), rect.width(), K, imp_packed_lhs.data() + lhs_packed_offset,
+        imp_packed_rhs.data() + rhs_packed_offset, reinterpret_cast<float*>(imp_dst.data() + dst_offset),
         N * sizeof(float), sizeof(float), std::numeric_limits<float>::lowest(), std::numeric_limits<float>::max());
 
     // Compares the output of the micro-kernels against the output of the reference implementation.
-    for (size_t y = 0; y < M; ++y) {
-        for (size_t x = 0; x < N; ++x) {
-            const auto imp_value = read_array<float>(imp_dst.data(), y * N + x);
-            const auto ref_value = read_array<float>(ref_dst.data(), y * N + x);
-            const auto rel_error = ref_value != 0 ? std::abs((imp_value - ref_value) / ref_value) : std::abs(imp_value);
+    for (size_t y = 0; y < rect.height(); ++y) {
+        for (size_t x = 0; x < rect.width(); ++x) {
+            const auto imp_value =
+                read_array<float>(imp_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col()));
+            const auto ref_value =
+                read_array<float>(ref_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col()));
+            const auto rel_error = ref_value != 0 ? std::abs((imp_value - ref_value) / ref_value) : std::abs(imp_value);
 
             if (rel_error > 0.0001F) {
                 ASSERT_EQ(imp_value, ref_value);
@@ -539,14 +679,25 @@ INSTANTIATE_TEST_SUITE_P(
             MatMulShape{15, 35, 65},  //
             MatMulShape{8, 32, 64},   //
             MatMulShape{15, 31, 45},  //
-            MatMulShape{1, 35, 65})),
+            MatMulShape{1, 35, 65}),
+        testing::Values(
+            MatrixPortion(0, 0, 1, 1),     // Full matrix.
+            MatrixPortion(0, 0, 1, 0.25),  // Leftmost portion.
+            MatrixPortion(0, 0.75, 1, 1),  // Rightmost portion.
+            MatrixPortion(0, 0.5, 1, 0.8)  // Somewhere in the middle.
+            )),
     [](const auto& info) {
         const auto variant_idx = std::get<0>(info.param);
         const std::string name{variants_kai_matmul_clamp_f32_qai8dxp_qsi4cxp.at(variant_idx).name};
         const auto shape = std::get<MatMulShape>(info.param);
+        const auto portion = std::get<2>(info.param);
 
         std::stringstream sstream;
-        sstream << name << "__M_" << shape.m << "__N_" << shape.n << "__K_" << shape.k;
+        sstream << name << "__M_" << shape.m << "__N_" << shape.n << "__K_" << shape.k << "__PortionStartRow_"
+                << static_cast<int>(portion.start_row() * 1000)  //
+                << "__PortionStartCol_" << static_cast<int>(portion.start_col() * 1000)  //
+                << "__PortionHeight_" << static_cast<int>(portion.height() * 1000)       //
+                << "__PortionWidth_" << static_cast<int>(portion.width() * 1000);
         return sstream.str();
     });
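Note on the name builders in both INSTANTIATE_TEST_SUITE_P hunks: gtest parameterized test names may only contain alphanumerics and underscores, so the fractional portion coordinates are scaled by 1000 and truncated to integers before being appended to the suffix. A small self-contained example that mirrors that formatting (standalone main(), not test code):

    #include <iostream>
    #include <sstream>

    int main() {
        // Mirrors the name builder above for MatrixPortion(0, 0.75, 1, 1).
        float start_row = 0.0F, start_col = 0.75F, height = 1.0F, width = 1.0F;
        std::stringstream sstream;
        sstream << "__PortionStartRow_" << static_cast<int>(start_row * 1000)
                << "__PortionStartCol_" << static_cast<int>(start_col * 1000)
                << "__PortionHeight_" << static_cast<int>(height * 1000)
                << "__PortionWidth_" << static_cast<int>(width * 1000);
        // Prints: __PortionStartRow_0__PortionStartCol_750__PortionHeight_1000__PortionWidth_1000
        std::cout << sstream.str() << "\n";
        return 0;
    }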