From 9f77195bc6387ff7dc8417e5002b1030bed82ac6 Mon Sep 17 00:00:00 2001 From: Gunes Bayir Date: Fri, 25 Oct 2024 13:35:54 +0100 Subject: [PATCH] Name change of bf16 Lhs/Rhs packed GEMM kernel Signed-off-by: Gunes Bayir --- CMakeLists.txt | 2 +- .../CMakeLists.txt | 2 +- .../matmul_clamp_f32_bf16p_bf16p.cpp | 26 +++++------ kai/ukernels/matmul/BUILD.bazel | 4 +- ..._f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.c} | 24 +++++----- ..._f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.h} | 29 ++++++------ .../matmul_clamp_f32_bf16p_bf16p_test.cpp | 46 +++++++++---------- 7 files changed, 67 insertions(+), 66 deletions(-) rename kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/{kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.c => kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.c} (95%) rename kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/{kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.h => kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.h} (75%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 649e1585..28a8cf89 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -91,7 +91,7 @@ set(KLEIDIAI_FILES_NEON_FP16 set(KLEIDIAI_FILES_NEON_BF16 kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p_f32_neon.c kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.c - kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.c + kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.c ) set(KLEIDIAI_FILES_NEON diff --git a/examples/matmul_clamp_f32_bf16p_bf16p/CMakeLists.txt b/examples/matmul_clamp_f32_bf16p_bf16p/CMakeLists.txt index 4b13a183..3dbf737d 100644 --- a/examples/matmul_clamp_f32_bf16p_bf16p/CMakeLists.txt +++ b/examples/matmul_clamp_f32_bf16p_bf16p/CMakeLists.txt @@ -22,7 +22,7 @@ include_directories( # Files requires to build the executable add_executable(matmul_clamp_f32_bf16p_bf16p matmul_clamp_f32_bf16p_bf16p.cpp - ${MATMUL_PATH}/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.c + ${MATMUL_PATH}/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.c ${MATMUL_PACK_PATH}/kai_lhs_quant_pack_bf16p_f32_neon.c ${MATMUL_PACK_PATH}/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.c ) diff --git a/examples/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p.cpp b/examples/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p.cpp index f11900c3..2e9dc2c3 100644 --- a/examples/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p.cpp +++ b/examples/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p.cpp @@ -29,7 +29,7 @@ // Include micro-kernel variants #include "kai/kai_common.h" #include "kai_lhs_quant_pack_bf16p_f32_neon.h" -#include "kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.h" +#include "kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.h" #include "kai_matmul_clamp_f32_bf16p_bf16p_interface.h" #include "kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.h" @@ -41,17 +41,17 @@ inline static float bf16_to_float(const uint16_t* v) { namespace { /// Micro-kernel interface constexpr kai_matmul_clamp_f32_bf16p_bf16p_ukernel ukernel{ - kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, - kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, - kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, - kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, - kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, - kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, - kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, - kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, - kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, - kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, - kai_run_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla}; + kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, + kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, + kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, + kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, + kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, + kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, + kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, + kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, + kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, + kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, + kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla}; /// @brief Truncate the 32-bit floating point number's least significant 16 mantissa bits /// @param x floating-point number @@ -294,7 +294,7 @@ int main() { const bool is_valid = is_output_correct(M, N, rel_tolerance, dst_ref, dst); std::cout << "TEST[matmul_clamp_f32_bf16p_bf16p]\n"; - std::cout << "- ukernel: matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla\n"; + std::cout << "- ukernel: matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla\n"; if (is_valid) { std::cout << "- Status: PASSED\n"; std::cout << "- Performance: " << time_matmul.count() << "ns\n"; diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 66a3c386..598d4b0d 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -41,8 +41,8 @@ kai_c_library( kai_c_library( name = "clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla", - srcs = ["matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.c"], - hdrs = ["matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.h"], + srcs = ["matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.c"], + hdrs = ["matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.h"], cpu_uarch = kai_cpu_bf16(), deps = [ ":clamp_f32_bf16p_bf16p_interface", diff --git a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.c b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.c similarity index 95% rename from kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.c rename to kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.c index 929e3753..ca2b728c 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.c @@ -8,7 +8,7 @@ #error This file must be compiled for AArch64, FEAT_BF16. #else // Architectural features check. -#include "kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.h" +#include "kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.h" #include #include @@ -21,43 +21,43 @@ static const size_t kai_nr = 12; static const size_t kai_kr = 4; static const size_t kai_sr = 1; -size_t kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void) { +size_t kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(void) { return kai_mr; } -size_t kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void) { +size_t kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(void) { return kai_nr; } -size_t kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void) { +size_t kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(void) { return kai_mr; } -size_t kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void) { +size_t kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(void) { return kai_nr; } -size_t kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void) { +size_t kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(void) { return kai_kr; } -size_t kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void) { +size_t kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(void) { return kai_sr; } -size_t kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(size_t m_idx, size_t k) { +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(size_t m_idx, size_t k) { KAI_ASSUME(m_idx % kai_mr == 0); return m_idx * kai_roundup(k, kai_kr) * sizeof(uint16_t); } -size_t kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(size_t n_idx, size_t k) { +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(size_t n_idx, size_t k) { KAI_ASSUME(n_idx % kai_nr == 0); return n_idx * (sizeof(float) + kai_roundup(k, kai_kr) * sizeof(uint16_t)); } -size_t kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla( +size_t kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla( size_t m_idx, size_t n_idx, size_t stride) { KAI_ASSUME(m_idx % kai_mr == 0); KAI_ASSUME(n_idx % kai_nr == 0); @@ -65,11 +65,11 @@ size_t kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla( return m_idx * stride + n_idx * sizeof(float); } -size_t kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(size_t m, size_t n) { +size_t kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(size_t m, size_t n) { return m * n * sizeof(float); } -void kai_run_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla( +void kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla( size_t m, size_t n, size_t k, // const void* lhs_packed, // const void* rhs_packed, // diff --git a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.h b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.h similarity index 75% rename from kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.h rename to kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.h index e870fb2a..251977e1 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.h @@ -24,42 +24,42 @@ extern "C" { /// The starting row index must be divisible by `m_step`. /// /// @return The m step value. -size_t kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void); +size_t kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(void); /// Gets n step value. /// /// The starting column index must be divisible by `n_step`. /// /// @return The n step value. -size_t kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void); +size_t kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(void); /// Gets mr value. /// /// This is the packing parameter which must be used to pack the LHS matrix. /// /// @return The mr value. -size_t kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void); +size_t kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(void); /// Gets nr value. /// /// This is the packing parameter which must be used to pack the RHS matrix. /// /// @return The nr value. -size_t kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void); +size_t kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(void); /// Gets kr value. /// /// This is the packing parameter which must be used to pack the LHS & RHS matrices. /// /// @return The kr value. -size_t kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void); +size_t kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(void); /// Gets sr value. /// /// This is the packing parameter which must be used to pack the RHS matrix. /// /// @return The sr value. -size_t kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void); +size_t kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(void); /// Gets the offset in bytes to the data element in the packed LHS matrix buffer. /// @@ -67,7 +67,7 @@ size_t kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void); /// @param[in] k Number of columns. /// /// @return The offset in bytes to the data element. -size_t kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(size_t m_idx, size_t k); +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(size_t m_idx, size_t k); /// Gets the offset in bytes to the data element in the packed RHS matrix buffer. /// @@ -75,7 +75,7 @@ size_t kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_m /// @param[in] k Number of columns. /// /// @return The offset in bytes to the data element. -size_t kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(size_t n_idx, size_t k); +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(size_t n_idx, size_t k); /// Gets the offset in bytes to the data element in the destination matrix buffer. /// @@ -84,7 +84,8 @@ size_t kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_m /// @param[in] stride Row stride in bytes. /// /// @return The offset in bytes to the data element. -size_t kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(size_t m_idx, size_t n_idx, size_t stride); +size_t kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla( + size_t m_idx, size_t n_idx, size_t stride); /// Gets the size in bytes of the destination matrix buffer. /// @@ -92,16 +93,16 @@ size_t kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(siz /// @param[in] n Number of columns. /// /// @return The size in bytes of the destination matrix buffer. -size_t kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(size_t m, size_t n); +size_t kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(size_t m, size_t n); /// Runs the matrix multiplication microkernel followed by a clamp operation. /// /// The pointer of each buffers (packed LHS, packed RHS and output) needs to be added with offset /// calculated using the following functions: /// -/// * Packed LHS: @ref kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla. -/// * Packed RHS: @ref kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla. -/// * Output: @ref kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla. +/// * Packed LHS: @ref kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla. +/// * Packed RHS: @ref kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla. +/// * Output: @ref kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla. /// /// @param[in] m Number of output rows to be computed. /// @param[in] n Number of output columns to be computed. @@ -113,7 +114,7 @@ size_t kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(size_ /// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(float) /// @param[in] clamp_min Minimum value to clamp the final result. /// @param[in] clamp_max Maximum value to clamp the final result. -void kai_run_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla( +void kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla( size_t m, size_t n, size_t k, // const void* lhs_packed, // const void* rhs_packed, // diff --git a/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp index 730ff5ae..02a94498 100644 --- a/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp +++ b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp @@ -32,7 +32,7 @@ #include "test/reference/pack.hpp" // matmul_clamp_f32_bf16p_bf16p -#include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.h" #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p_f32_neon.h" #include "kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.h" namespace kai::test { @@ -60,33 +60,33 @@ const std::array matmul_methods = { .bias_format = DataFormat(DataType::FP32), .fn_is_supported = cpu_has_bf16, - .fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, - .fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, - .fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, - .fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, + .fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, + .fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, + .fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon, - .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, .fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p_f32_neon, .fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p_f32_neon, - .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, .fn_pack_lhs = kai_run_lhs_quant_pack_bf16p_f32_neon, .fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon, .fn_get_packed_rhs_size = nullptr, .fn_get_packed_rhs_size_generic_block_size = kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon, .fn_get_pack_rhs_packed_rhs_offset = nullptr, - .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, .fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon, .fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon, - .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, - .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, + .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, }, MatMulMethod{ .name = "matmul_nt_nt_f32_bf16p_bf16p_8x12_neon_mla_opt_bias", @@ -108,33 +108,33 @@ const std::array matmul_methods = { .bias_format = DataFormat(DataType::UNKNOWN), .fn_is_supported = cpu_has_bf16, - .fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, - .fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, - .fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, - .fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, + .fn_get_nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, + .fn_get_kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, + .fn_get_sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon, - .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, .fn_get_lhs_offset = kai_get_lhs_offset_lhs_quant_pack_bf16p_f32_neon, .fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p_f32_neon, - .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, .fn_pack_lhs = kai_run_lhs_quant_pack_bf16p_f32_neon, .fn_get_rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon, .fn_get_packed_rhs_size = nullptr, .fn_get_packed_rhs_size_generic_block_size = kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon, .fn_get_pack_rhs_packed_rhs_offset = nullptr, - .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, .fn_pack_rhs = kai_run_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon, .fn_get_bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon, - .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, - .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, + .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, - .fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla, + .fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla, }}; } // namespace -- GitLab