From 02bc8d57f60ffc3250a3d86376832ce4f338a57e Mon Sep 17 00:00:00 2001 From: Emil Ohlsson Date: Wed, 9 Apr 2025 07:57:32 +0200 Subject: [PATCH 1/2] Address review comments from feature review * Rename `matmul` in imatmul interface to `imatmul` * rename `zero` argument in lhs pack to `pad_ptr` * Clarify `k_chunk_length` to mean "in bytes" Signed-off-by: Emil Ohlsson --- ...matmul_clamp_qai8_qai8p_qsi8cxp_interface.h | 2 +- .../kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c | 4 ++-- .../kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.h | 18 +++++++++++------- .../matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp | 8 ++++---- 4 files changed, 18 insertions(+), 14 deletions(-) diff --git a/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p_qsi8cxp_interface.h b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p_qsi8cxp_interface.h index 0a1e23f9..75eed329 100644 --- a/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p_qsi8cxp_interface.h +++ b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p_qsi8cxp_interface.h @@ -44,7 +44,7 @@ struct kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel { kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_rhs_packed_offset_func_t get_rhs_packed_offset; kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_dst_offset_func_t get_dst_offset; kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_dst_size_func_t get_dst_size; - kai_imatmul_clamp_qai8_qai8p_qsi8cxp_run_matmul_func_t run_matmul; + kai_imatmul_clamp_qai8_qai8p_qsi8cxp_run_matmul_func_t run_imatmul; }; #ifdef __cplusplus diff --git a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c index 3e682437..25a48afc 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c +++ b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.c @@ -45,7 +45,7 @@ size_t kai_get_lhs_packed_size_lhs_imatmul_pack_x8p2vlx4_x8p_sme( void kai_run_lhs_imatmul_pack_x8p2vlx4_x8p_sme( size_t m, size_t k_chunk_count, size_t k_chunk_length, const void* const* lhs_ptrs, size_t lhs_ptr_offset, - const void* zero, void* lhs_packed) { + const void* pad_ptr, void* lhs_packed) { KAI_ASSUME(lhs_ptrs != NULL); KAI_ASSUME(lhs_packed != NULL); @@ -64,7 +64,7 @@ void kai_run_lhs_imatmul_pack_x8p2vlx4_x8p_sme( for (size_t y = 0; y < height; y += 1) { KAI_ASSERT(i_k_chunk + (i_m + y) * k_chunk_count < m * k_chunk_count); in[y] = *(lhs_ptrs + i_m * k_chunk_count + i_k_chunk * m_step + y); - if (in[y] != zero) { + if (in[y] != pad_ptr) { in[y] += lhs_ptr_offset; } } diff --git a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.h b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.h index 721c07a8..7136d837 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.h +++ b/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x8p2vlx4_x8p_sme.h @@ -23,7 +23,7 @@ size_t kai_get_m_step_lhs_imatmul_pack_x8p2vlx4_x8p_sme(void); /// /// @param[in] m_idx Row index in the unpacked LHS matrix. /// @param[in] k_chunk_count Number of LHS column splits. -/// @param[in] k_chunk_length Length of a LHS column split. +/// @param[in] k_chunk_length Length, in bytes, of a LHS column split. /// /// @return The offset in bytes to the data element. size_t kai_get_lhs_packed_offset_lhs_imatmul_pack_x8p2vlx4_x8p_sme( @@ -33,7 +33,7 @@ size_t kai_get_lhs_packed_offset_lhs_imatmul_pack_x8p2vlx4_x8p_sme( /// /// @param[in] m Number of rows in the unpacked LHS matrix. /// @param[in] k_chunk_count Number of LHS column splits. -/// @param[in] k_chunk_length Length of a LHS column split. +/// @param[in] k_chunk_length Length, in bytes, of a LHS column split. /// /// @return The size in bytes of the packed LHS buffer. size_t kai_get_lhs_packed_size_lhs_imatmul_pack_x8p2vlx4_x8p_sme(size_t m, size_t k_chunk_count, size_t k_chunk_length); @@ -42,14 +42,18 @@ size_t kai_get_lhs_packed_size_lhs_imatmul_pack_x8p2vlx4_x8p_sme(size_t m, size_ /// /// @param[in] m Number of rows of the unpacked LHS matrix. /// @param[in] k_chunk_count Number of LHS column splits. -/// @param[in] k_chunk_length Length of a LHS column split. -/// @param[in] lhs_ptrs Pointer to an array of input pointers consisting of `m * k_chunk_count` pointers. -/// @param[in] lhs_ptr_offset Offset to add to each pointer of the @ref lhs_ptrs array, excluding zero pointers. -/// @param[in] zero Pointer to a zero element. Used to check for padding pointers in @ref lhs_ptrs. +/// @param[in] k_chunk_length Length, in bytes, of a LHS column split. +/// @param[in] lhs_ptrs Pointer to an array of input pointers consisting of +/// t `m * k_chunk_count` pointers. +/// @param[in] lhs_ptr_offset Offset to add to each pointer of the @ref lhs_ptrs +/// array, excluding zero pointers. +/// @param[in] pad_ptr Pointer to chunk used for padding. @ref lhs_ptr_offset is +/// not applied to this pointer when used in @ref lhs_ptrs. This can +/// be NULL if there is no padding used @ref lhs_ptrs /// @param[out] lhs_packed Packed LHS matrix. void kai_run_lhs_imatmul_pack_x8p2vlx4_x8p_sme( size_t m, size_t k_chunk_count, size_t k_chunk_length, const void* const* lhs_ptrs, size_t lhs_ptr_offset, - const void* zero, void* lhs_packed); + const void* pad_ptr, void* lhs_packed); #ifdef __cplusplus } // extern "C" diff --git a/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp b/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp index 77c697c0..ff01ab4a 100644 --- a/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp +++ b/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp @@ -133,7 +133,7 @@ struct MatMulIndirectKernel { std::function - matmul; + imatmul; }; const static RhsPackKernel rhs_pack = { @@ -261,7 +261,7 @@ const std::array indirect_gemm_variants = { kai_get_rhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, .get_dst_offset = kai_get_dst_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, .get_dst_size = kai_get_dst_size_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, - .matmul = kai_run_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .imatmul = kai_run_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, }, }, }; @@ -379,7 +379,7 @@ static const kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel kai_get_rhs_packed_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, .get_dst_offset = kai_get_dst_offset_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, .get_dst_size = kai_get_dst_size_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, - .run_matmul = kai_run_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, + .run_imatmul = kai_run_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa, }; static constexpr int8_t padding_value = 0; @@ -875,7 +875,7 @@ static Buffer matmul( }; // Call matmul kernel - variant.matmul( + variant.imatmul( portion.height(), portion.width(), k_chunk.count, k_chunk.length, // Dimensions packed_lhs.data() + lhs_offset, // LHS packed_rhs.data() + rhs_offset, // RHS -- GitLab From 1abdf358f6a7d0ff4054198cac9c8b3fde02d7e8 Mon Sep 17 00:00:00 2001 From: Emil Ohlsson Date: Wed, 9 Apr 2025 10:37:41 +0200 Subject: [PATCH 2/2] Rename IGEMM type Signed-off-by: Emil Ohlsson --- .../kai_imatmul_clamp_qai8_qai8p_qsi8cxp_interface.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p_qsi8cxp_interface.h b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p_qsi8cxp_interface.h index 75eed329..01b5ca95 100644 --- a/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p_qsi8cxp_interface.h +++ b/kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p_qsi8cxp_interface.h @@ -29,7 +29,7 @@ typedef size_t (*kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_dst_offset_func_t)( typedef size_t (*kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_dst_size_func_t)(size_t m, size_t n); /// Micro-kernel core function ("run" method) -typedef void (*kai_imatmul_clamp_qai8_qai8p_qsi8cxp_run_matmul_func_t)( +typedef void (*kai_imatmul_clamp_qai8_qai8p_qsi8cxp_run_imatmul_func_t)( size_t m, size_t n, size_t k_chunk_count, size_t k_chunk_length, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_row_stride, const struct kai_matmul_requantize32_params* params); @@ -44,7 +44,7 @@ struct kai_imatmul_clamp_qai8_qai8p_qsi8cxp_ukernel { kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_rhs_packed_offset_func_t get_rhs_packed_offset; kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_dst_offset_func_t get_dst_offset; kai_imatmul_clamp_qai8_qai8p_qsi8cxp_get_dst_size_func_t get_dst_size; - kai_imatmul_clamp_qai8_qai8p_qsi8cxp_run_matmul_func_t run_imatmul; + kai_imatmul_clamp_qai8_qai8p_qsi8cxp_run_imatmul_func_t run_imatmul; }; #ifdef __cplusplus -- GitLab