From af1ab950b5fe3211b1d311021d724b61252a3380 Mon Sep 17 00:00:00 2001 From: Gunes Bayir Date: Mon, 28 Oct 2024 14:49:32 +0000 Subject: [PATCH 1/2] Rename bf16 matmul kernel to include Lhs packing parameters Signed-off-by: Gunes Bayir --- ...32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla.c} | 24 +++++++-------- ...32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla.h} | 29 ++++++++++--------- .../matmul_clamp_f32_bf16p_bf16p_test.cpp | 28 +++++++++--------- 3 files changed, 41 insertions(+), 40 deletions(-) rename kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/{kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.c => kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla.c} (95%) rename kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/{kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.h => kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla.h} (75%) diff --git a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.c b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla.c similarity index 95% rename from kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.c rename to kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla.c index 4934496f..524fde1f 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla.c @@ -8,7 +8,7 @@ #error This file must be compiled for AArch64, FEAT_BF16. #else // Architectural features check. -#include "kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.h" +#include "kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla.h" #include #include @@ -22,43 +22,43 @@ static const size_t kai_nr = 12; static const size_t kai_kr = 4; static const size_t kai_sr = 1; -size_t kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void) { +size_t kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla(void) { return kai_mr; } -size_t kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void) { +size_t kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla(void) { return kai_nr; } -size_t kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void) { +size_t kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla(void) { return kai_mr; } -size_t kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void) { +size_t kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla(void) { return kai_nr; } -size_t kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void) { +size_t kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla(void) { return kai_kr; } -size_t kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void) { +size_t kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla(void) { return kai_sr; } -size_t kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(size_t m_idx, size_t k) { +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla(size_t m_idx, size_t k) { KAI_ASSUME(m_idx % kai_mr == 0); return m_idx * kai_roundup(k, kai_kr) * sizeof(uint16_t); } -size_t kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(size_t n_idx, size_t k) { +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla(size_t n_idx, size_t k) { KAI_ASSUME(n_idx % kai_nr == 0); return n_idx * (sizeof(float) + kai_roundup(k, kai_kr) * sizeof(uint16_t)); } -size_t kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla( +size_t kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla( size_t m_idx, size_t n_idx, size_t stride) { KAI_ASSUME(m_idx % kai_mr == 0); KAI_ASSUME(n_idx % kai_nr == 0); @@ -66,11 +66,11 @@ size_t kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla( return m_idx * stride + n_idx * sizeof(float); } -size_t kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(size_t m, size_t n) { +size_t kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla(size_t m, size_t n) { return m * n * sizeof(float); } -void kai_run_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla( +void kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla( size_t m, size_t n, size_t k, // const void* lhs_packed, // const void* rhs_packed, // diff --git a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.h b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla.h similarity index 75% rename from kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.h rename to kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla.h index e870fb2a..38533117 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla.h @@ -24,42 +24,42 @@ extern "C" { /// The starting row index must be divisible by `m_step`. /// /// @return The m step value. -size_t kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void); +size_t kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla(void); /// Gets n step value. /// /// The starting column index must be divisible by `n_step`. /// /// @return The n step value. -size_t kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void); +size_t kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla(void); /// Gets mr value. /// /// This is the packing parameter which must be used to pack the LHS matrix. /// /// @return The mr value. -size_t kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void); +size_t kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla(void); /// Gets nr value. /// /// This is the packing parameter which must be used to pack the RHS matrix. /// /// @return The nr value. -size_t kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void); +size_t kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla(void); /// Gets kr value. /// /// This is the packing parameter which must be used to pack the LHS & RHS matrices. /// /// @return The kr value. -size_t kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void); +size_t kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla(void); /// Gets sr value. /// /// This is the packing parameter which must be used to pack the RHS matrix. /// /// @return The sr value. -size_t kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void); +size_t kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla(void); /// Gets the offset in bytes to the data element in the packed LHS matrix buffer. /// @@ -67,7 +67,7 @@ size_t kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void); /// @param[in] k Number of columns. /// /// @return The offset in bytes to the data element. -size_t kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(size_t m_idx, size_t k); +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla(size_t m_idx, size_t k); /// Gets the offset in bytes to the data element in the packed RHS matrix buffer. /// @@ -75,7 +75,7 @@ size_t kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_m /// @param[in] k Number of columns. /// /// @return The offset in bytes to the data element. -size_t kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(size_t n_idx, size_t k); +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla(size_t n_idx, size_t k); /// Gets the offset in bytes to the data element in the destination matrix buffer. /// @@ -84,7 +84,8 @@ size_t kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_m /// @param[in] stride Row stride in bytes. /// /// @return The offset in bytes to the data element. -size_t kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(size_t m_idx, size_t n_idx, size_t stride); +size_t kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla( + size_t m_idx, size_t n_idx, size_t stride); /// Gets the size in bytes of the destination matrix buffer. /// @@ -92,16 +93,16 @@ size_t kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(siz /// @param[in] n Number of columns. /// /// @return The size in bytes of the destination matrix buffer. -size_t kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(size_t m, size_t n); +size_t kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla(size_t m, size_t n); /// Runs the matrix multiplication microkernel followed by a clamp operation. /// /// The pointer of each buffers (packed LHS, packed RHS and output) needs to be added with offset /// calculated using the following functions: /// -/// * Packed LHS: @ref kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla. -/// * Packed RHS: @ref kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla. -/// * Output: @ref kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla. +/// * Packed LHS: @ref kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla. +/// * Packed RHS: @ref kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla. +/// * Output: @ref kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla. /// /// @param[in] m Number of output rows to be computed. /// @param[in] n Number of output columns to be computed. @@ -113,7 +114,7 @@ size_t kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(size_ /// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(float) /// @param[in] clamp_min Minimum value to clamp the final result. /// @param[in] clamp_max Maximum value to clamp the final result. -void kai_run_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla( +void kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla( size_t m, size_t n, size_t k, // const void* lhs_packed, // const void* rhs_packed, // diff --git a/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp index 30dc207b..3f69422f 100644 --- a/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp +++ b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp @@ -59,10 +59,10 @@ const std::array gemm_methods = { .bias_format = DataFormat(DataType::FP32), .fn_is_supported = cpu_has_bf16, - .fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, - .fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, - .fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, - .fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla, .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, @@ -83,10 +83,10 @@ const std::array gemm_methods = { .fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, - .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla, - .fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla, }, MatMulMethod{ .name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla_opt_bias", @@ -105,10 +105,10 @@ const std::array gemm_methods = { .bias_format = DataFormat(DataType::UNKNOWN), .fn_is_supported = cpu_has_bf16, - .fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, - .fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, - .fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, - .fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla, .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, @@ -129,10 +129,10 @@ const std::array gemm_methods = { .fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, - .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla, - .fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla, }}; const std::array gemv_methods = { -- GitLab From b0c398af8cefb372681ec194dab3dd05d283f546 Mon Sep 17 00:00:00 2001 From: Gunes Bayir Date: Wed, 30 Oct 2024 12:41:39 +0000 Subject: [PATCH 2/2] Rebase and remove the kacc parameter from the block size suffix Signed-off-by: Gunes Bayir --- CMakeLists.txt | 2 +- .../CMakeLists.txt | 2 +- .../matmul_clamp_f32_bf16p_bf16p.cpp | 26 +++++------ kai/ukernels/matmul/BUILD.bazel | 4 +- ..._f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.c} | 24 +++++----- ..._f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.h} | 28 +++++------ .../matmul_clamp_f32_bf16p_bf16p_test.cpp | 46 +++++++++---------- 7 files changed, 66 insertions(+), 66 deletions(-) rename kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/{kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla.c => kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.c} (96%) rename kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/{kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla.h => kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.h} (82%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 34b345c9..2b497e2c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -93,7 +93,7 @@ set(KLEIDIAI_FILES_NEON_BF16 kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p8x4_f32_neon.c kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon.c kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot.c - kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.c + kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.c ) set(KLEIDIAI_FILES_NEON diff --git a/examples/matmul_clamp_f32_bf16p_bf16p/CMakeLists.txt b/examples/matmul_clamp_f32_bf16p_bf16p/CMakeLists.txt index 8ae3896e..afbcc670 100644 --- a/examples/matmul_clamp_f32_bf16p_bf16p/CMakeLists.txt +++ b/examples/matmul_clamp_f32_bf16p_bf16p/CMakeLists.txt @@ -23,7 +23,7 @@ include_directories( add_executable(matmul_clamp_f32_bf16p_bf16p matmul_clamp_f32_bf16p_bf16p.cpp ${MATMUL_PATH}/kai_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot.c - ${MATMUL_PATH}/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.c + ${MATMUL_PATH}/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.c ${MATMUL_PACK_PATH}/kai_lhs_quant_pack_bf16p1x4_f32_neon.c ${MATMUL_PACK_PATH}/kai_lhs_quant_pack_bf16p8x4_f32_neon.c ${MATMUL_PACK_PATH}/kai_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon.c diff --git a/examples/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p.cpp b/examples/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p.cpp index ec3f55e7..09c108f8 100644 --- a/examples/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p.cpp +++ b/examples/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p.cpp @@ -31,7 +31,7 @@ #include "kai_lhs_quant_pack_bf16p1x4_f32_neon.h" #include "kai_lhs_quant_pack_bf16p8x4_f32_neon.h" #include "kai_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot.h" -#include "kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.h" +#include "kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.h" #include "kai_matmul_clamp_f32_bf16p_bf16p_interface.h" #include "kai_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon.h" @@ -72,20 +72,20 @@ const kai_matmul_clamp_f32_bf16p_bf16p ukernel_variants[] = { kai_run_lhs_quant_pack_bf16p1x4_f32_neon, kai_get_lhs_packed_size_lhs_quant_pack_bf16p1x4_f32_neon, "matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot"}, - {{kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, - kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, - kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, - kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, - kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, - kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, - kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, - kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, - kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, - kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, - kai_run_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla}, + {{kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, + kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, + kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, + kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, + kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, + kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, + kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, + kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, + kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, + kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, + kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla}, kai_run_lhs_quant_pack_bf16p8x4_f32_neon, kai_get_lhs_packed_size_lhs_quant_pack_bf16p8x4_f32_neon, - "matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla"}}; + "matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla"}}; // Number of micro-kernel variants stored in the array constexpr size_t num_ukernel_variants = sizeof(ukernel_variants) / sizeof(ukernel_variants[0]); diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index bb047b41..c1fb02e5 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -43,11 +43,11 @@ kai_c_library( name = "clamp_f32_bf16p_bf16p_neon_mmla", srcs = [ "matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot.c", - "matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.c", + "matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.c", ], hdrs = [ "matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot.h", - "matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.h", + "matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.h", ], cpu_uarch = kai_cpu_bf16(), deps = [ diff --git a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla.c b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.c similarity index 96% rename from kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla.c rename to kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.c index 524fde1f..9e82dd59 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.c @@ -8,7 +8,7 @@ #error This file must be compiled for AArch64, FEAT_BF16. #else // Architectural features check. -#include "kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla.h" +#include "kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.h" #include #include @@ -22,43 +22,43 @@ static const size_t kai_nr = 12; static const size_t kai_kr = 4; static const size_t kai_sr = 1; -size_t kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla(void) { +size_t kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(void) { return kai_mr; } -size_t kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla(void) { +size_t kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(void) { return kai_nr; } -size_t kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla(void) { +size_t kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(void) { return kai_mr; } -size_t kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla(void) { +size_t kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(void) { return kai_nr; } -size_t kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla(void) { +size_t kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(void) { return kai_kr; } -size_t kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla(void) { +size_t kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(void) { return kai_sr; } -size_t kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla(size_t m_idx, size_t k) { +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(size_t m_idx, size_t k) { KAI_ASSUME(m_idx % kai_mr == 0); return m_idx * kai_roundup(k, kai_kr) * sizeof(uint16_t); } -size_t kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla(size_t n_idx, size_t k) { +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(size_t n_idx, size_t k) { KAI_ASSUME(n_idx % kai_nr == 0); return n_idx * (sizeof(float) + kai_roundup(k, kai_kr) * sizeof(uint16_t)); } -size_t kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla( +size_t kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla( size_t m_idx, size_t n_idx, size_t stride) { KAI_ASSUME(m_idx % kai_mr == 0); KAI_ASSUME(n_idx % kai_nr == 0); @@ -66,11 +66,11 @@ size_t kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla( return m_idx * stride + n_idx * sizeof(float); } -size_t kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla(size_t m, size_t n) { +size_t kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(size_t m, size_t n) { return m * n * sizeof(float); } -void kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla( +void kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla( size_t m, size_t n, size_t k, // const void* lhs_packed, // const void* rhs_packed, // diff --git a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla.h b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.h similarity index 82% rename from kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla.h rename to kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.h index 38533117..251977e1 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.h @@ -24,42 +24,42 @@ extern "C" { /// The starting row index must be divisible by `m_step`. /// /// @return The m step value. -size_t kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla(void); +size_t kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(void); /// Gets n step value. /// /// The starting column index must be divisible by `n_step`. /// /// @return The n step value. -size_t kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla(void); +size_t kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(void); /// Gets mr value. /// /// This is the packing parameter which must be used to pack the LHS matrix. /// /// @return The mr value. -size_t kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla(void); +size_t kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(void); /// Gets nr value. /// /// This is the packing parameter which must be used to pack the RHS matrix. /// /// @return The nr value. -size_t kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla(void); +size_t kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(void); /// Gets kr value. /// /// This is the packing parameter which must be used to pack the LHS & RHS matrices. /// /// @return The kr value. -size_t kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla(void); +size_t kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(void); /// Gets sr value. /// /// This is the packing parameter which must be used to pack the RHS matrix. /// /// @return The sr value. -size_t kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla(void); +size_t kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(void); /// Gets the offset in bytes to the data element in the packed LHS matrix buffer. /// @@ -67,7 +67,7 @@ size_t kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla(void); /// @param[in] k Number of columns. /// /// @return The offset in bytes to the data element. -size_t kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla(size_t m_idx, size_t k); +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(size_t m_idx, size_t k); /// Gets the offset in bytes to the data element in the packed RHS matrix buffer. /// @@ -75,7 +75,7 @@ size_t kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neo /// @param[in] k Number of columns. /// /// @return The offset in bytes to the data element. -size_t kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla(size_t n_idx, size_t k); +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(size_t n_idx, size_t k); /// Gets the offset in bytes to the data element in the destination matrix buffer. /// @@ -84,7 +84,7 @@ size_t kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neo /// @param[in] stride Row stride in bytes. /// /// @return The offset in bytes to the data element. -size_t kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla( +size_t kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla( size_t m_idx, size_t n_idx, size_t stride); /// Gets the size in bytes of the destination matrix buffer. @@ -93,16 +93,16 @@ size_t kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla( /// @param[in] n Number of columns. /// /// @return The size in bytes of the destination matrix buffer. -size_t kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla(size_t m, size_t n); +size_t kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(size_t m, size_t n); /// Runs the matrix multiplication microkernel followed by a clamp operation. /// /// The pointer of each buffers (packed LHS, packed RHS and output) needs to be added with offset /// calculated using the following functions: /// -/// * Packed LHS: @ref kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla. -/// * Packed RHS: @ref kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla. -/// * Output: @ref kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla. +/// * Packed LHS: @ref kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla. +/// * Packed RHS: @ref kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla. +/// * Output: @ref kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla. /// /// @param[in] m Number of output rows to be computed. /// @param[in] n Number of output columns to be computed. @@ -114,7 +114,7 @@ size_t kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla(si /// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(float) /// @param[in] clamp_min Minimum value to clamp the final result. /// @param[in] clamp_max Maximum value to clamp the final result. -void kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla( +void kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla( size_t m, size_t n, size_t k, // const void* lhs_packed, // const void* rhs_packed, // diff --git a/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp index 3f69422f..7d9df33e 100644 --- a/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp +++ b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp @@ -33,7 +33,7 @@ // matmul_clamp_f32_bf16p_bf16p #include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot.h" -#include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.h" #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p1x4_f32_neon.h" #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p8x4_f32_neon.h" #include "kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon.h" @@ -59,18 +59,18 @@ const std::array gemm_methods = { .bias_format = DataFormat(DataType::FP32), .fn_is_supported = cpu_has_bf16, - .fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla, - .fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla, - .fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla, - .fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, + .fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, + .fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, + .fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, .fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p8x4_f32_neon, .fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p8x4_f32_neon, - .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, .fn_pack_lhs = kai_run_lhs_quant_pack_bf16p8x4_f32_neon, .fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, @@ -78,15 +78,15 @@ const std::array gemm_methods = { .fn_get_packed_rhs_size_generic_block_size = kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, .fn_get_pack_rhs_packed_rhs_offset = nullptr, - .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, .fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, .fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla, - .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, + .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla, + .fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, }, MatMulMethod{ .name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla_opt_bias", @@ -105,18 +105,18 @@ const std::array gemm_methods = { .bias_format = DataFormat(DataType::UNKNOWN), .fn_is_supported = cpu_has_bf16, - .fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla, - .fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla, - .fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla, - .fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, + .fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, + .fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, + .fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, .fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p8x4_f32_neon, .fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p8x4_f32_neon, - .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, .fn_pack_lhs = kai_run_lhs_quant_pack_bf16p8x4_f32_neon, .fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, @@ -124,15 +124,15 @@ const std::array gemm_methods = { .fn_get_packed_rhs_size_generic_block_size = kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, .fn_get_pack_rhs_packed_rhs_offset = nullptr, - .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, .fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, .fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon, - .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla, - .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, + .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12x4_neon_mmla, + .fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, }}; const std::array gemv_methods = { -- GitLab