From f38c33c4c9c377c169aecc42b4f4eafb238b5dc8 Mon Sep 17 00:00:00 2001 From: Evie Wright Date: Tue, 27 May 2025 14:37:37 +0100 Subject: [PATCH 1/3] Address code review comments and clang-tidy failure, rebase against main Signed-off-by: Evie Wright --- ...rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon.c | 82 +++++++++++++------ test/reference/fill.cpp | 34 ++++---- test/reference/fill.hpp | 11 +++ .../matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp | 81 ++++++++---------- 4 files changed, 122 insertions(+), 86 deletions(-) diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon.c index d0c66276..36c20fd1 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon.c @@ -128,33 +128,69 @@ void kai_run_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon( int32_t sum = 0; // Iterate over k src columns in blocks of kr columns - for (size_t col_idx = 0; col_idx < k_internal; col_idx += kr) { - // Iterate over columns in the kr block - // Kr checked to be multiple of 2 (because 2 values per byte) - for (size_t kr_block_idx = 0; kr_block_idx < kr; kr_block_idx += 2) { - // We pad dst with 0s if the rounded k or n values have been exceeded - if (row_idx + nr_block_idx >= n || col_idx + kr_block_idx >= k) { - dst_kr_block[kr_block_idx / 2] = 0; - continue; + if (rhs_zero_point == 8) { + for (size_t col_idx = 0; col_idx < k_internal; col_idx += kr) { + // Iterate over columns in the kr block + // Kr checked to be multiple of 2 (because 2 values per byte) + for (size_t kr_block_idx = 0; kr_block_idx < kr; kr_block_idx += 2) { + // We pad dst with 0s if the rounded k or n values have been exceeded + if (row_idx + nr_block_idx >= n || col_idx + kr_block_idx >= k) { + dst_kr_block[kr_block_idx / 2] = 0; + continue; + } + + // Load the 2 u4 values from source + const uint8_t dst_byte = src_row[(col_idx + kr_block_idx) / 2]; + + // extract i8 values from the 2 u4 values + const uint8_t first_value = (dst_byte & 0xF) - rhs_zero_point; + const uint8_t second_value = + col_idx + kr_block_idx + 1 >= k ? 0 : (dst_byte >> 4) - rhs_zero_point; + + // Add the i4 value to the row sum + sum += (int32_t)first_value + (int32_t)second_value; + + // Truncate i8 to i4 and write to dst + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + dst_kr_block[kr_block_idx / 2] = (second_value << 4) | (first_value & 0xF); } - // Load the 2 u4 values from source - const uint8_t dst_byte = src_row[(col_idx + kr_block_idx) / 2]; - - // extract i8 values from the 2 u4 values - const int32_t first_value = (dst_byte & 0xF) - rhs_zero_point; - const int32_t second_value = col_idx + kr_block_idx + 1 >= k ? 0 : (dst_byte >> 4) - rhs_zero_point; - - // Add the i4 value to the row sum - sum += first_value + second_value; - - // Truncate i8 to i4 and write to dst - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - dst_kr_block[kr_block_idx / 2] = (second_value << 4) | (first_value & 0xF); + // Go to the next kr block for this row in the nr rows + dst_kr_block += dst_nr_block_size; } + } else { + for (size_t col_idx = 0; col_idx < k_internal; col_idx += kr) { + // Iterate over columns in the kr block + // Kr checked to be multiple of 2 (because 2 values per byte) + for (size_t kr_block_idx = 0; kr_block_idx < kr; kr_block_idx += 2) { + // We pad dst with 0s if the rounded k or n values have been + // exceeded + if (row_idx + nr_block_idx >= n || col_idx + kr_block_idx >= k) { + dst_kr_block[kr_block_idx / 2] = 0; + continue; + } + + // NOLINTBEGIN(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + // Load the 2 u4 values from source + const int8_t dst_byte = src_row[(col_idx + kr_block_idx) / 2]; + + // extract i8 values from the 2 u4 values, shift first value + // back and forth to get the sign right. + const int8_t first_value = kai_ext_sign_i8_i4(dst_byte & 0xF); + const int8_t second_value = + col_idx + kr_block_idx + 1 >= k ? 0 : kai_ext_sign_i8_i4((dst_byte >> 4) & 0xF); + + // Add the i4 value to the row sum + sum += (int32_t)first_value + (int32_t)second_value; + + // Truncate i8 to i4 and write to dst + dst_kr_block[kr_block_idx / 2] = (second_value << 4) | (first_value & 0xF); + // NOLINTEND(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + } - // Go to the next kr block for this row in the nr rows - dst_kr_block += dst_nr_block_size; + // Go to the next kr block for this row in the nr rows + dst_kr_block += dst_nr_block_size; + } } // save sum diff --git a/test/reference/fill.cpp b/test/reference/fill.cpp index 0e155340..2e113a3b 100644 --- a/test/reference/fill.cpp +++ b/test/reference/fill.cpp @@ -25,23 +25,6 @@ namespace kai::test { namespace { -template -Buffer fill_matrix_raw(size_t height, size_t width, std::function gen) { - const auto size = height * width * size_in_bits / 8; - KAI_ASSUME(width * size_in_bits % 8 == 0); - - Buffer data(size); - auto ptr = reinterpret_cast(data.data()); - - for (size_t y = 0; y < height; ++y) { - for (size_t x = 0; x < width; ++x) { - write_array(ptr, y * width + x, gen(y, x)); - } - } - - return data; -} - template Buffer fill_matrix_random_raw(size_t height, size_t width, uint32_t seed) { using TDist = std::conditional_t< @@ -87,6 +70,23 @@ Buffer fill_matrix_random_raw(size_t height, size_t width, uint32_t seed) } // namespace +template +Buffer fill_matrix_raw(size_t height, size_t width, std::function gen) { + const auto size = height * width * size_in_bits / 8; + KAI_ASSUME(width * size_in_bits % 8 == 0); + + Buffer data(size); + auto ptr = reinterpret_cast(data.data()); + + for (size_t y = 0; y < height; ++y) { + for (size_t x = 0; x < width; ++x) { + write_array(ptr, y * width + x, gen(y, x)); + } + } + + return data; +} + Buffer fill_matrix_random(size_t height, size_t width, const DataFormat& format, uint32_t seed) { switch (format.pack_format()) { case DataFormat::PackFormat::NONE: diff --git a/test/reference/fill.hpp b/test/reference/fill.hpp index 29c9cf3b..7f12d0cb 100644 --- a/test/reference/fill.hpp +++ b/test/reference/fill.hpp @@ -8,6 +8,7 @@ #include #include +#include #include "test/common/buffer.hpp" @@ -36,4 +37,14 @@ Buffer fill_matrix_random(size_t height, size_t width, const DataFormat& format, template Buffer fill_random(size_t length, uint32_t seed); +/// Creates a new matrix filled with data produced by a generator function. +/// +/// @param[in] height Number of rows. +/// @param[in] width Number of columns. +/// @param[in] gen Generator function. +/// +/// @return The data buffer for the matrix. +template +Buffer fill_matrix_raw(size_t height, size_t width, std::function gen); + } // namespace kai::test diff --git a/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp b/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp index dd6a27a2..9c9e129a 100644 --- a/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp +++ b/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp @@ -10,8 +10,8 @@ #include #include #include -#include #include +#include #include #include #include @@ -33,6 +33,7 @@ #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon.h" #include "test/common/buffer.hpp" +#include "test/common/compare.hpp" #include "test/common/cpu_info.hpp" #include "test/common/int4.hpp" #include "test/common/matmul_test_common.hpp" @@ -64,7 +65,6 @@ struct UkernelVariantCustom : public UkernelVariant { ukernel_get_rhs_packed_offset get_rhs_packed_offset; ukernel_get_rhs_offset get_rhs_offset; RhsPackType rhs_pack_type; - bool signed_integer_support; UkernelVariantCustom() = delete; @@ -72,14 +72,13 @@ struct UkernelVariantCustom : public UkernelVariant { T interface, std::string_view name, const std::function& fn_is_supported, ukernel_rhs_pack_function run_rhs_pack, ukernel_get_rhs_packed_size get_rhs_packed_size, ukernel_get_rhs_packed_offset get_rhs_packed_offset, ukernel_get_rhs_offset get_rhs_offset, - const RhsPackType pack_type, const bool signed_integer_support) : + const RhsPackType pack_type) : UkernelVariant(interface, name, fn_is_supported), run_rhs_pack(std::move(run_rhs_pack)), get_rhs_packed_size(std::move(get_rhs_packed_size)), get_rhs_packed_offset(std::move(get_rhs_packed_offset)), get_rhs_offset(std::move(get_rhs_offset)), - rhs_pack_type(pack_type), - signed_integer_support(signed_integer_support) { + rhs_pack_type(pack_type) { } }; @@ -90,115 +89,115 @@ static const std::array(M * K, seed + 0); - const auto ref_rhs = fill_random(N * K, seed + 1); const auto ref_biases = fill_random(N, seed + 2); + std::uniform_real_distribution dist(-10.0, 1.0); + + std::mt19937 rnd(seed + 1); + + const auto ref_rhs = fill_matrix_raw(1, N * K, [&](size_t, size_t) { return dist(rnd); }); // Runs the reference implementation. // * Quantizes the LHS matrix using 8-bit asymmetric quantization. @@ -533,9 +533,6 @@ TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_kxn_qsi4cx) { if (ukernel_variant.rhs_pack_type == RhsPackType::NxK) { GTEST_SKIP() << "Wrong type. This test for KxN"; } - if (!ukernel_variant.signed_integer_support) { - GTEST_SKIP() << "Signed integer input unsupported"; - } const uint32_t seed = 0; @@ -774,20 +771,12 @@ TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_kxn_qsu4cx) { imp_packed_rhs.data() + rhs_matmul_offset, reinterpret_cast(imp_dst.data() + dst_offset), N * sizeof(float), sizeof(float), std::numeric_limits::lowest(), std::numeric_limits::max()); - // Compares the output of the micro-kernels against the output of the reference implementation. - for (size_t y = 0; y < rect.height(); ++y) { - for (size_t x = 0; x < rect.width(); ++x) { - const auto imp_value = - read_array(imp_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col())); - const auto ref_value = - read_array(ref_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col())); - const auto rel_error = ref_value != 0 ? std::abs((imp_value - ref_value) / ref_value) : imp_value; - - if (rel_error > 0.0001F) { - ASSERT_EQ(imp_value, ref_value); - } - } - } + // Compares the output of the micro-kernels against the output of the reference implementation for the portion + // tested. + DefaultMismatchHandler handler(0, 0.1, 0, 0.05); + DataFormat dst_format = DataFormat(DataType::FP32); + const auto success = compare(imp_dst.data(), ref_dst.data(), dst_format, M, N, rect, handler); + ASSERT_TRUE(success); } INSTANTIATE_TEST_SUITE_P( -- GitLab From ce2b22fae71fb0f36a04e1f769d3a26eedbb3067 Mon Sep 17 00:00:00 2001 From: Evie Wright Date: Tue, 27 May 2025 16:09:56 +0100 Subject: [PATCH 2/3] add explicit instantiation for fill_matrix_raw function Signed-off-by: Evie Wright --- test/reference/fill.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/test/reference/fill.cpp b/test/reference/fill.cpp index 2e113a3b..179acc58 100644 --- a/test/reference/fill.cpp +++ b/test/reference/fill.cpp @@ -124,5 +124,6 @@ Buffer fill_random(size_t length, uint32_t seed) { template Buffer fill_random(size_t length, uint32_t seed); template Buffer fill_random(size_t length, uint32_t seed); +template Buffer fill_matrix_raw(size_t height, size_t width, std::function gen); } // namespace kai::test -- GitLab From a5925f9e7726d668dd81031693f2dfcd09f62512 Mon Sep 17 00:00:00 2001 From: Evie Wright Date: Wed, 28 May 2025 10:05:38 +0100 Subject: [PATCH 3/3] address review comments for test file Signed-off-by: Evie Wright --- test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp b/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp index 9c9e129a..d9850f3d 100644 --- a/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp +++ b/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -293,7 +294,7 @@ TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_nxk_qsi4cx) { std::mt19937 rnd(seed + 1); - const auto ref_rhs = fill_matrix_raw(1, N * K, [&](size_t, size_t) { return dist(rnd); }); + const auto ref_rhs = fill_matrix_raw(1, N * K, [&dist, &rnd](size_t, size_t) { return dist(rnd); }); // Runs the reference implementation. // * Quantizes the LHS matrix using 8-bit asymmetric quantization. -- GitLab