From cd0e915ec5f60a41ce3cf80d7bc72a421bd8abd7 Mon Sep 17 00:00:00 2001 From: Jakub Sujak Date: Wed, 9 Oct 2024 14:42:17 +0100 Subject: [PATCH] Rename SME2 GEMM micro-kernels The SME2 GEMM micro-kernels now follow the naming convention correctly. * GEMM reports the m_step and n_step output block. * The LHS packing function reports the mr and kr packing parameters. * The RHS packing function reports the nr and kr packing parameters. Signed-off-by: Jakub Sujak --- CHANGELOG.md | 4 ++ CMakeLists.txt | 6 +-- benchmark/matmul/matmul_f32_f32p_f32p.cpp | 35 +++++++------- kai/ukernels/matmul/BUILD.bazel | 20 ++++---- ..._clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa.c} | 33 +++++++------ ..._clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa.h} | 39 ++++++++------- ...me.c => kai_lhs_pack_f32p_f32_2vlx1_sme.c} | 10 ++-- ...me.h => kai_lhs_pack_f32p_f32_2vlx1_sme.h} | 14 +++--- ...ai_rhs_pack_kxn_f32pb_f32_f32_2vlx1_sme.c} | 18 +++---- ...ai_rhs_pack_kxn_f32pb_f32_f32_2vlx1_sme.h} | 20 ++++---- test/tests/matmul_test.cpp | 48 +++++++++---------- 11 files changed, 122 insertions(+), 125 deletions(-) rename kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/{kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.c => kai_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa.c} (93%) rename kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/{kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h => kai_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa.h} (67%) rename kai/ukernels/matmul/pack/{kai_lhs_pack_f32p2vlx1_f32_sme.c => kai_lhs_pack_f32p_f32_2vlx1_sme.c} (97%) rename kai/ukernels/matmul/pack/{kai_lhs_pack_f32p2vlx1_f32_sme.h => kai_lhs_pack_f32p_f32_2vlx1_sme.h} (82%) rename kai/ukernels/matmul/pack/{kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.c => kai_rhs_pack_kxn_f32pb_f32_f32_2vlx1_sme.c} (89%) rename kai/ukernels/matmul/pack/{kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h => kai_rhs_pack_kxn_f32pb_f32_f32_2vlx1_sme.h} (75%) diff --git a/CHANGELOG.md b/CHANGELOG.md index 123686b0..0344831f 
100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,10 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification for releases. +## v0.4.0 - Upcoming Release + +- Rename the SME2 GEMM micro-kernels to follow the function naming convention. + ## v0.3.0 - Advanced SIMD FP32 GEMM micro-kernel. diff --git a/CMakeLists.txt b/CMakeLists.txt index 3b6ea9a5..e2d93a7a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -113,12 +113,12 @@ set(KLEIDIAI_FILES_NEON_I8MM ) set(KLEIDIAI_FILES_SME - kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.c - kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c + kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32pb_f32_f32_2vlx1_sme.c + kai/ukernels/matmul/pack/kai_lhs_pack_f32p_f32_2vlx1_sme.c ) set(KLEIDIAI_FILES_SME2 - kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.c + kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa.c ) add_library(kleidiai) diff --git a/benchmark/matmul/matmul_f32_f32p_f32p.cpp b/benchmark/matmul/matmul_f32_f32p_f32p.cpp index 64f605b8..219d9e2c 100644 --- a/benchmark/matmul/matmul_f32_f32p_f32p.cpp +++ b/benchmark/matmul/matmul_f32_f32p_f32p.cpp @@ -13,10 +13,10 @@ #include #include "benchmark/matmul/matmul_utils.hpp" -#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h" #include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p_f32p_interface.h" -#include "kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.h" -#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa.h" +#include "kai/ukernels/matmul/pack/kai_lhs_pack_f32p_f32_2vlx1_sme.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32pb_f32_f32_2vlx1_sme.h" #include 
"test/common/cpu_info.hpp" namespace kai::bench::matmul_f32_f32p_f32p { @@ -31,18 +31,15 @@ struct kai_matmul_ukernel_f32_f32p_f32p { }; kai_matmul_ukernel_f32_f32p_f32p sme_variants[] = { - {kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - "matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa"}, + {kai_get_m_step_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa, + kai_get_n_step_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa, + kai_get_mr_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa, kai_get_nr_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa, + kai_get_kr_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa, kai_get_sr_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa, + kai_get_lhs_packed_offset_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa, + kai_get_rhs_packed_offset_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa, + kai_get_dst_offset_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa, + kai_get_dst_size_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa, + kai_run_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa, "matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa"}, }; struct kai_matmul_f32_f32p_f32p_sme { @@ -67,10 +64,10 @@ struct kai_matmul_f32_f32p_f32p_sme { const size_t kr = variant.ukernel.get_kr(); const size_t sr = variant.ukernel.get_sr(); - const 
size_t lhs_packed_size = kai_get_lhs_packed_size_lhs_pack_f32p2vlx1_f32_sme(m, k, mr, kr, sr); + const size_t lhs_packed_size = kai_get_lhs_packed_size_lhs_pack_f32p_f32_2vlx1_sme(m, k, mr, kr, sr); float* lhs_packed = new float[lhs_packed_size / sizeof(float)]; - const size_t rhs_packed_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(n, k); + const size_t rhs_packed_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32pb_f32_f32_2vlx1_sme(n, k); float* rhs_packed = new float[rhs_packed_size / sizeof(float)]; @@ -78,11 +75,11 @@ struct kai_matmul_f32_f32p_f32p_sme { const size_t rhs_stride = n * sizeof(float); const size_t dst_stride_row = n * sizeof(float); const size_t dst_stride_col = sizeof(float); - kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme( + kai_run_rhs_pack_kxn_f32pb_f32_f32_2vlx1_sme( 1, n, k, nr, kr, sr, // Packing arguments rhs_stride, rhs, bias, NULL, rhs_packed, 0, NULL); - kai_run_lhs_pack_f32p2vlx1_f32_sme(m, k, mr, kr, sr, 0, lhs, k * sizeof(float), lhs_packed); + kai_run_lhs_pack_f32p_f32_2vlx1_sme(m, k, mr, kr, sr, 0, lhs, k * sizeof(float), lhs_packed); float* dst = new float[dst_size]; for (auto _ : state) { diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 96c26780..e4b6f59f 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -41,8 +41,8 @@ kai_c_library( kai_c_library( name = "clamp_f32_f32p_f32p", - srcs = ["matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.c"], - hdrs = ["matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h"], + srcs = ["matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa.c"], + hdrs = ["matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa.h"], cpu_uarch = kai_cpu_sme(), ) @@ -119,9 +119,9 @@ kai_c_library( ) kai_c_library( - name = "lhs_pack_f32p2vlx1_f32_sme", - srcs = ["pack/kai_lhs_pack_f32p2vlx1_f32_sme.c"], - 
hdrs = ["pack/kai_lhs_pack_f32p2vlx1_f32_sme.h"], + name = "lhs_pack_f32p_f32_2vlx1_sme", + srcs = ["pack/kai_lhs_pack_f32p_f32_2vlx1_sme.c"], + hdrs = ["pack/kai_lhs_pack_f32p_f32_2vlx1_sme.h"], cpu_uarch = kai_cpu_sme(), ) @@ -140,9 +140,9 @@ kai_c_library( ) kai_c_library( - name = "rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme", - srcs = ["pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.c"], - hdrs = ["pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h"], + name = "rhs_pack_kxn_f32pb_f32_f32_2vlx1_sme", + srcs = ["pack/kai_rhs_pack_kxn_f32pb_f32_f32_2vlx1_sme.c"], + hdrs = ["pack/kai_rhs_pack_kxn_f32pb_f32_f32_2vlx1_sme.h"], cpu_uarch = kai_cpu_sme(), ) @@ -282,10 +282,10 @@ kai_c_library( ":clamp_f32_qsi8d32p_qsi4c32p_dotprod", ":clamp_f32_qsi8d32p_qsi4c32p_i8mm", ":clamp_f32_qsi8d32p_qsi4c32p_interface", - ":lhs_pack_f32p2vlx1_f32_sme", + ":lhs_pack_f32p_f32_2vlx1_sme", ":lhs_quant_pack_qai8dxp_f32", ":rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon", - ":rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme", + ":rhs_pack_kxn_f32pb_f32_f32_2vlx1_sme", ":rhs_pack_kxn_f32pbiasf32_f32_f32_neon", ":rhs_pack_kxn_qsi4c32p_qsu4c32s1s0", ":rhs_pack_kxn_qsi4cxp_qs4cxs1s0", diff --git a/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.c b/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa.c similarity index 93% rename from kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.c rename to kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa.c index 5611fad1..0b02eea3 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa.c @@ -8,7 +8,7 @@ #error This file must be compiled for AArch64, FEAT_SVE2. 
#else // Architectural features check. -#include "kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h" +#include "kai_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa.h" #include #include @@ -20,53 +20,52 @@ static const size_t kai_nr = 2; static const size_t kai_kr = 1; static const size_t kai_sr = 1; -size_t kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(void) { +size_t kai_get_m_step_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa(void) { return kai_mr * kai_get_sme_vector_length_u32(); } -size_t kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(void) { +size_t kai_get_n_step_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa(void) { return kai_nr * kai_get_sme_vector_length_u32(); } -size_t kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(void) { +size_t kai_get_mr_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa(void) { return kai_mr * kai_get_sme_vector_length_u32(); } -size_t kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(void) { +size_t kai_get_nr_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa(void) { return kai_nr * kai_get_sme_vector_length_u32(); } -size_t kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(void) { +size_t kai_get_kr_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa(void) { return kai_kr; } -size_t kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(void) { +size_t kai_get_sr_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa(void) { return kai_sr; } -size_t kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(size_t m_idx, size_t k) { - KAI_ASSUME(m_idx % kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa() == 0); +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa(size_t m_idx, size_t k) { + KAI_ASSUME(m_idx % kai_get_m_step_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa() == 0); return m_idx * k * sizeof(float); } -size_t 
kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(size_t n_idx, size_t k) { - KAI_ASSUME(n_idx % kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa() == 0); +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa(size_t n_idx, size_t k) { + KAI_ASSUME(n_idx % kai_get_n_step_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa() == 0); return n_idx * (k * sizeof(float) + sizeof(float)); } -size_t kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa( - size_t m_idx, size_t n_idx, size_t dst_stride) { - KAI_ASSUME(m_idx % kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa() == 0); - KAI_ASSUME(n_idx % kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa() == 0); +size_t kai_get_dst_offset_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa(size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSUME(m_idx % kai_get_m_step_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa() == 0); + KAI_ASSUME(n_idx % kai_get_n_step_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa() == 0); return m_idx * dst_stride + n_idx * sizeof(float); } -size_t kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(size_t m, size_t n) { +size_t kai_get_dst_size_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa(size_t m, size_t n) { return m * n * sizeof(float); } -void kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa( +void kai_run_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa( size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row, size_t dst_stride_col, float clamp_min, float clamp_max) { KAI_ASSUME(dst_stride_col == sizeof(float)); diff --git a/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h b/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa.h similarity index 67% rename from 
kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h rename to kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa.h index 1dcc3404..0c66589e 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa.h @@ -14,50 +14,50 @@ extern "C" { /// Micro-kernel dependencies /// -/// -# kai_lhs_pack_f32p2vlx1_f32_sme to pack the LHS matrix. -/// -# kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme to pack the RHS matrix. +/// -# kai_lhs_pack_f32p_f32_2vlx1_sme to pack the LHS matrix. +/// -# kai_rhs_pack_kxn_f32pb_f32_f32_2vlx1_sme to pack the RHS matrix. /// Gets m step value. /// /// The starting row index must be divisible by `m_step`. /// /// @return The m step value. -size_t kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(void); +size_t kai_get_m_step_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa(void); /// Gets n step value. /// /// The starting column index must be divisible by `n_step`. /// /// @return The n step value. -size_t kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(void); +size_t kai_get_n_step_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa(void); /// Gets mr value. /// /// This is the packing parameter which must be used to pack the LHS matrix. /// /// @return The mr value. -size_t kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(void); +size_t kai_get_mr_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa(void); /// Gets nr value. /// /// This is the packing parameter which must be used to pack the RHS matrix. /// /// @return The nr value. -size_t kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(void); +size_t kai_get_nr_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa(void); /// Gets kr value. 
/// /// This is the packing parameter which must be used to pack the LHS and RHS matrix. /// /// @return The kr value. -size_t kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(void); +size_t kai_get_kr_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa(void); /// Gets sr value. /// /// This is the packing parameter which must be used to pack the LHS and RHS matrix. /// /// @return The sr value. -size_t kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(void); +size_t kai_get_sr_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa(void); /// Gets the offset in bytes to the data element in the packed LHS matrix buffer. /// @@ -65,7 +65,7 @@ size_t kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(void); /// @param[in] k Number of columns in the unpacked LHS matrix. /// /// @return The offset in bytes to the data element. -size_t kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(size_t m_idx, size_t k); +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa(size_t m_idx, size_t k); /// Gets the offset in bytes to the data element in the packed RHS matrix buffer. /// @@ -73,17 +73,16 @@ size_t kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme /// @param[in] k Number of rows in the unpacked RHS matrix. /// /// @return The offset in bytes to the data element. -size_t kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(size_t n_idx, size_t k); +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa(size_t n_idx, size_t k); /// Gets the offset in bytes to the data element in the destination matrix buffer. /// /// @param[in] m_idx Row index. /// @param[in] n_idx Column index. -/// @param[in] stride Row stride in bytes. +/// @param[in] dst_stride Row stride in bytes. /// /// @return The offset in bytes to the data element. 
-size_t kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa( - size_t m_idx, size_t n_idx, size_t dst_stride); +size_t kai_get_dst_offset_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa(size_t m_idx, size_t n_idx, size_t dst_stride); /// Gets the size in bytes of the destination matrix buffer. /// @@ -91,28 +90,28 @@ size_t kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa( /// @param[in] n Number of columns. /// /// @return The size in bytes of the destination matrix buffer. -size_t kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(size_t m, size_t n); +size_t kai_get_dst_size_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa(size_t m, size_t n); /// Runs the matrix multiplication microkernel followed by a clamp operation. /// /// The pointer of each buffers (packed LHS, packed RHS and output) needs to be added with offset /// calculated using the following functions: /// -/// * Packed LHS: @ref kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa. -/// * Packed RHS: @ref kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa. -/// * Output: @ref kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa. +/// * Packed LHS: @ref kai_get_lhs_packed_offset_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa. +/// * Packed RHS: @ref kai_get_rhs_packed_offset_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa. +/// * Output: @ref kai_get_dst_offset_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa. /// /// @param[in] m Number of output rows to be computed. /// @param[in] n Number of output columns to be computed. /// @param[in] k Common dimension of the LHS and RHS operands. -/// @param[in] packed_lhs Packed LHS matrix buffer. -/// @param[in] packed_rhs Packed RHS matrix buffer. +/// @param[in] lhs_packed Packed LHS matrix buffer. +/// @param[in] rhs_packed Packed RHS matrix buffer. /// @param[out] dst Output matrix buffer. 
/// @param[in] dst_stride_row Row stride in bytes of the output matrix. /// @param[in] dst_stride_col Column stride in bytes of the output matrix. /// @param[in] clamp_min Minimum value to clamp the final result. /// @param[in] clamp_max Maximum value to clamp the final result. -void kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa( +void kai_run_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa( size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row, size_t dst_stride_col, float clamp_min, float clamp_max); diff --git a/kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c b/kai/ukernels/matmul/pack/kai_lhs_pack_f32p_f32_2vlx1_sme.c similarity index 97% rename from kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c rename to kai/ukernels/matmul/pack/kai_lhs_pack_f32p_f32_2vlx1_sme.c index 8f0d0ecf..a9cde832 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c +++ b/kai/ukernels/matmul/pack/kai_lhs_pack_f32p_f32_2vlx1_sme.c @@ -17,20 +17,20 @@ static const size_t kai_mr = 2; static const size_t kai_kr = 1; static const size_t kai_sr = 1; -size_t kai_get_m_step_lhs_pack_f32p2vlx1_f32_sme(size_t mr) { +size_t kai_get_m_step_lhs_pack_f32p_f32_2vlx1_sme(size_t mr) { KAI_ASSUME(mr == kai_mr * kai_get_sme_vector_length_u32()); KAI_UNUSED(mr); return kai_mr * kai_get_sme_vector_length_u32(); } -size_t kai_get_lhs_offset_lhs_pack_f32p2vlx1_f32_sme(size_t m_idx, size_t lhs_stride) { +size_t kai_get_lhs_offset_lhs_pack_f32p_f32_2vlx1_sme(size_t m_idx, size_t lhs_stride) { KAI_ASSUME(m_idx % (kai_mr * kai_get_sme_vector_length_u32()) == 0); return m_idx * lhs_stride; } -size_t kai_get_lhs_packed_offset_lhs_pack_f32p2vlx1_f32_sme(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) { +size_t kai_get_lhs_packed_offset_lhs_pack_f32p_f32_2vlx1_sme(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) { const size_t scaled_mr = kai_mr * kai_get_sme_vector_length_u32(); 
KAI_ASSUME(m_idx % scaled_mr == 0); KAI_ASSUME(mr == scaled_mr); @@ -44,7 +44,7 @@ size_t kai_get_lhs_packed_offset_lhs_pack_f32p2vlx1_f32_sme(size_t m_idx, size_t return m_idx * k * sizeof(float); } -size_t kai_get_lhs_packed_size_lhs_pack_f32p2vlx1_f32_sme(size_t m, size_t k, size_t mr, size_t kr, size_t sr) { +size_t kai_get_lhs_packed_size_lhs_pack_f32p_f32_2vlx1_sme(size_t m, size_t k, size_t mr, size_t kr, size_t sr) { KAI_ASSUME(mr == kai_mr * kai_get_sme_vector_length_u32()); KAI_ASSUME(kr == kai_kr); KAI_ASSUME(sr == kai_sr); @@ -56,7 +56,7 @@ size_t kai_get_lhs_packed_size_lhs_pack_f32p2vlx1_f32_sme(size_t m, size_t k, si return kai_roundup(m, kai_mr * kai_get_sme_vector_length_u32()) * k * sizeof(float); } -void kai_run_lhs_pack_f32p2vlx1_f32_sme( +void kai_run_lhs_pack_f32p_f32_2vlx1_sme( size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed) { KAI_ASSUME(mr == kai_mr * kai_get_sme_vector_length_u32()); diff --git a/kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.h b/kai/ukernels/matmul/pack/kai_lhs_pack_f32p_f32_2vlx1_sme.h similarity index 82% rename from kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.h rename to kai/ukernels/matmul/pack/kai_lhs_pack_f32p_f32_2vlx1_sme.h index 82c5db48..da89e093 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.h +++ b/kai/ukernels/matmul/pack/kai_lhs_pack_f32p_f32_2vlx1_sme.h @@ -19,7 +19,7 @@ extern "C" { /// @param[in] mr Number of rows to be interleaved. /// /// @return The m step value. -size_t kai_get_m_step_lhs_pack_f32p2vlx1_f32_sme(size_t mr); +size_t kai_get_m_step_lhs_pack_f32p_f32_2vlx1_sme(size_t mr); /// Gets the offset in bytes to the data element in the LHS buffer. /// @@ -27,7 +27,7 @@ size_t kai_get_m_step_lhs_pack_f32p2vlx1_f32_sme(size_t mr); /// @param[in] lhs_stride Row stride in bytes. /// /// @return The offset in bytes to the data element. 
-size_t kai_get_lhs_offset_lhs_pack_f32p2vlx1_f32_sme(size_t m_idx, size_t lhs_stride); +size_t kai_get_lhs_offset_lhs_pack_f32p_f32_2vlx1_sme(size_t m_idx, size_t lhs_stride); /// Gets the offset in bytes to the data element in the packed LHS buffer. /// @@ -38,7 +38,7 @@ size_t kai_get_lhs_offset_lhs_pack_f32p2vlx1_f32_sme(size_t m_idx, size_t lhs_st /// @param[in] sr Unused. Must be 1. /// /// @return The offset in bytes to the data element. -size_t kai_get_lhs_packed_offset_lhs_pack_f32p2vlx1_f32_sme(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr); +size_t kai_get_lhs_packed_offset_lhs_pack_f32p_f32_2vlx1_sme(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr); /// Gets the size in bytes of the packed LHS buffer. /// @@ -49,15 +49,15 @@ size_t kai_get_lhs_packed_offset_lhs_pack_f32p2vlx1_f32_sme(size_t m_idx, size_t /// @param[in] sr Unused. Must be 1. /// /// @return The size in bytes of the packed LHS buffer. -size_t kai_get_lhs_packed_size_lhs_pack_f32p2vlx1_f32_sme(size_t m, size_t k, size_t mr, size_t kr, size_t sr); +size_t kai_get_lhs_packed_size_lhs_pack_f32p_f32_2vlx1_sme(size_t m, size_t k, size_t mr, size_t kr, size_t sr); /// Runs the LHS packing function for matrix multiplication. /// /// The pointer of each buffers (LHS and packed LHS) needs to be added with offset /// calculated using the following functions: /// -/// * LHS: @ref kai_get_lhs_offset_lhs_pack_f32p2vlx1_f32_sme. -/// * Packed LHS: @ref kai_get_lhs_packed_offset_lhs_pack_f32p2vlx1_f32_sme. +/// * LHS: @ref kai_get_lhs_offset_lhs_pack_f32p_f32_2vlx1_sme. +/// * Packed LHS: @ref kai_get_lhs_packed_offset_lhs_pack_f32p_f32_2vlx1_sme. /// /// @param[in] m Number of rows of the unpacked LHS matrix. /// @param[in] k Common dimension between the LHS and RHS matrix. @@ -68,7 +68,7 @@ size_t kai_get_lhs_packed_size_lhs_pack_f32p2vlx1_f32_sme(size_t m, size_t k, si /// @param[in] lhs LHS matrix data buffer. /// @param[in] lhs_stride Row stride in bytes of the LHS matrix. 
/// @param[out] lhs_packed Packed RHS matrix. -void kai_run_lhs_pack_f32p2vlx1_f32_sme( +void kai_run_lhs_pack_f32p_f32_2vlx1_sme( size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed); diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.c b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32pb_f32_f32_2vlx1_sme.c similarity index 89% rename from kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.c rename to kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32pb_f32_f32_2vlx1_sme.c index 015e70ba..42a3a6fa 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32pb_f32_f32_2vlx1_sme.c @@ -18,36 +18,36 @@ static const size_t kai_kr = 1; static const size_t kai_num_bytes_data = sizeof(uint32_t); static const size_t kai_num_bytes_bias = sizeof(uint32_t); -size_t kai_get_n_step_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(void) { +size_t kai_get_n_step_rhs_pack_kxn_f32pb_f32_f32_2vlx1_sme(void) { return kai_nr * kai_get_sme_vector_length_u32(); } -size_t kai_get_rhs_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(size_t n_idx) { +size_t kai_get_rhs_offset_rhs_pack_kxn_f32pb_f32_f32_2vlx1_sme(size_t n_idx) { KAI_ASSUME(n_idx % (kai_nr * kai_get_sme_vector_length_u32()) == 0); return n_idx * kai_num_bytes_data; } -size_t kai_get_bias_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(size_t n_idx) { +size_t kai_get_bias_offset_rhs_pack_kxn_f32pb_f32_f32_2vlx1_sme(size_t n_idx) { return n_idx * kai_num_bytes_bias; } -size_t kai_get_rhs_packed_stride_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(size_t k) { +size_t kai_get_rhs_packed_stride_rhs_pack_kxn_f32pb_f32_f32_2vlx1_sme(size_t k) { return kai_nr * kai_get_sme_vector_length_u32() * (kai_num_bytes_bias + k * kai_num_bytes_data); } -size_t kai_get_rhs_packed_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(size_t n_idx, size_t k) { 
+size_t kai_get_rhs_packed_offset_rhs_pack_kxn_f32pb_f32_f32_2vlx1_sme(size_t n_idx, size_t k) { KAI_ASSUME(n_idx % (kai_nr * kai_get_sme_vector_length_u32()) == 0); return n_idx * (kai_num_bytes_bias + k * kai_num_bytes_data); } -size_t kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(size_t n, size_t k) { - return kai_get_rhs_packed_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme( +size_t kai_get_rhs_packed_size_rhs_pack_kxn_f32pb_f32_f32_2vlx1_sme(size_t n, size_t k) { + return kai_get_rhs_packed_offset_rhs_pack_kxn_f32pb_f32_f32_2vlx1_sme( kai_roundup(n, kai_nr * kai_get_sme_vector_length_u32()), k); } -void kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme( +void kai_run_rhs_pack_kxn_f32pb_f32_f32_2vlx1_sme( size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params) { KAI_ASSUME(num_groups == 1); @@ -66,7 +66,7 @@ void kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme( const void* in = rhs; void* out = rhs_packed; const size_t in_stride = rhs_stride; - size_t out_stride = kai_get_rhs_packed_stride_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(height); + size_t out_stride = kai_get_rhs_packed_stride_rhs_pack_kxn_f32pb_f32_f32_2vlx1_sme(height); __asm__ __volatile__( ".inst 0xd503477f // SMSTART ZA\n" diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32pb_f32_f32_2vlx1_sme.h similarity index 75% rename from kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h rename to kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32pb_f32_f32_2vlx1_sme.h index 602e24c2..561924d2 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32pb_f32_f32_2vlx1_sme.h @@ -17,28 +17,28 @@ extern "C" { /// The starting row index must be divisible by `n_step`. 
/// /// @return The n step value. -size_t kai_get_n_step_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(void); +size_t kai_get_n_step_rhs_pack_kxn_f32pb_f32_f32_2vlx1_sme(void); /// Gets the offset in bytes to the data element in the RHS matrix buffer. /// /// @param[in] n_idx Column index. /// /// @return The offset in bytes to the data element. -size_t kai_get_rhs_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(size_t n_idx); +size_t kai_get_rhs_offset_rhs_pack_kxn_f32pb_f32_f32_2vlx1_sme(size_t n_idx); /// Gets the offset in bytes to the data element in the bias buffer. /// /// @param[in] n_idx Column index. /// /// @return The offset in bytes to the data element. -size_t kai_get_bias_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(size_t n_idx); +size_t kai_get_bias_offset_rhs_pack_kxn_f32pb_f32_f32_2vlx1_sme(size_t n_idx); /// Get the row stride in bytes to the packed RHS matrix /// /// @param[in] k In the RHS matrix (not packed), K is the number of columns. /// /// @return The stride in bytes to the packed RHS matrix -size_t kai_get_rhs_packed_stride_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(size_t k); +size_t kai_get_rhs_packed_stride_rhs_pack_kxn_f32pb_f32_f32_2vlx1_sme(size_t k); /// Gets the offset in bytes to the data element in the packed RHS buffer. /// @@ -46,7 +46,7 @@ size_t kai_get_rhs_packed_stride_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(size_ /// @param[in] k Number of columns. /// /// @return The offset in bytes to the data element. -size_t kai_get_rhs_packed_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(size_t n_idx, size_t k); +size_t kai_get_rhs_packed_offset_rhs_pack_kxn_f32pb_f32_f32_2vlx1_sme(size_t n_idx, size_t k); /// Gets the size in bytes of the packed RHS buffer. /// @@ -54,16 +54,16 @@ size_t kai_get_rhs_packed_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(size_ /// @param[in] k Number of columns. /// /// @return The size in bytes of the packed RHS buffer. 
-size_t kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(size_t n, size_t k); +size_t kai_get_rhs_packed_size_rhs_pack_kxn_f32pb_f32_f32_2vlx1_sme(size_t n, size_t k); /// Runs the RHS packing function for matrix multiplication. /// /// The pointer of each buffers (RHS, bias and packed RHS) needs to be added with offset /// calculated using the following functions: /// -/// * RHS: @ref kai_get_rhs_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme. -/// * Bias: @ref kai_get_packed_rhs_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme. -/// * Output: @ref kai_get_dst_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme. +/// * RHS: @ref kai_get_rhs_offset_rhs_pack_kxn_f32pb_f32_f32_2vlx1_sme. +/// * Bias: @ref kai_get_bias_offset_rhs_pack_kxn_f32pb_f32_f32_2vlx1_sme. +/// * Output: @ref kai_get_rhs_packed_offset_rhs_pack_kxn_f32pb_f32_f32_2vlx1_sme. /// /// @param[in] num_groups Number of groups. It must be 1. /// @param[in] n Number of columns of the output matrix. @@ -78,7 +78,7 @@ size_t kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(size_t /// @param[out] rhs_packed Packed RHS matrix. /// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix. It must be 0. /// @param[in] params Extra packing parameters. It must be NULL. 
-void kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme( +void kai_run_rhs_pack_kxn_f32pb_f32_f32_2vlx1_sme( size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params); diff --git a/test/tests/matmul_test.cpp b/test/tests/matmul_test.cpp index f2d5e544..9b76ef52 100644 --- a/test/tests/matmul_test.cpp +++ b/test/tests/matmul_test.cpp @@ -37,9 +37,9 @@ #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.h" // matmul_nt_nt_fp32_fp32_fp32_2vlx2vl_sme2_mopa -#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h" -#include "kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.h" -#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa.h" +#include "kai/ukernels/matmul/pack/kai_lhs_pack_f32p_f32_2vlx1_sme.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32pb_f32_f32_2vlx1_sme.h" // matmul_clamp_f32_f32_f32p #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.h" @@ -453,36 +453,34 @@ static const std::array matmul_methods = { .bias_format = DataFormat(DataType::FP32), .fn_is_supported = cpu_has_sme2, - .fn_get_mr = kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - .fn_get_nr = kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - .fn_get_kr = kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - .fn_get_sr = kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + .fn_get_mr = kai_get_mr_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa, + .fn_get_nr = kai_get_nr_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa, + .fn_get_kr = kai_get_kr_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa, + 
.fn_get_sr = kai_get_sr_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa, - .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme, - .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa, + .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f32pb_f32_f32_2vlx1_sme, + .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa, - .fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_f32p2vlx1_f32_sme, - .fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_f32p2vlx1_f32_sme, - .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - .fn_pack_lhs = kai_run_lhs_pack_f32p2vlx1_f32_sme, + .fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_f32p_f32_2vlx1_sme, + .fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_f32p_f32_2vlx1_sme, + .fn_get_packed_lhs_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa, + .fn_pack_lhs = kai_run_lhs_pack_f32p_f32_2vlx1_sme, - .fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme, - .fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme, - .fn_get_pack_rhs_packed_rhs_offset = - kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - .fn_get_main_packed_rhs_offset = - kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - .fn_pack_rhs = kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme, + .fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f32pb_f32_f32_2vlx1_sme, + .fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f32pb_f32_f32_2vlx1_sme, + .fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa, + 
.fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa, + .fn_pack_rhs = kai_run_rhs_pack_kxn_f32pb_f32_f32_2vlx1_sme, - .fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme, + .fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f32pb_f32_f32_2vlx1_sme, - .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, - .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa, + .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa, .fn_matmul_f16_f16_f16p = nullptr, .fn_matmul_f32_f32_f32p = nullptr, - .fn_matmul_f32_f32p_f32p = kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + .fn_matmul_f32_f32p_f32p = kai_run_matmul_clamp_f32_f32p_f32pb_2vlx2vl_sme2_mopa, }, }; -- GitLab