diff --git a/CMakeLists.txt b/CMakeLists.txt index 8e0c713024da474aadb962f705def01a39e46ffa..308648cb9bcbd9905e535025036f2c71327a8001 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -222,6 +222,7 @@ set(KLEIDIAI_FILES_SME2 kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot.c kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c + kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.c kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.c kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.c @@ -229,7 +230,6 @@ set(KLEIDIAI_FILES_SME2 kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot.c kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c - kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.c kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c diff --git a/benchmark/matmul/matmul_registry.cpp b/benchmark/matmul/matmul_registry.cpp index 75d6a48b92bc361d05c4fc53fee36fcaa8b7824b..8db6b36ea89df2521b36137a54c68fc5418f3ceb 100644 --- a/benchmark/matmul/matmul_registry.cpp +++ b/benchmark/matmul/matmul_registry.cpp @@ -40,6 +40,7 @@ // matmul_clamp_f32_bf16p_bf16p #include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h" #include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.h" // matmul_clamp_f32_f32_f32p @@ -91,9 +92,6 @@ #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h" -// matmul_clamp_fp32_bf16p_bf16p -#include "kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h" - // matmul_clamp_qai8_qai8_qsi8cxp #include "kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.h" diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 79bfac8b50910ed212b0d13f1a9b3378e82b0394..c78b5106bc5d35fe19a4356ea6bfb0e0540fc333 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -149,6 +149,7 @@ SME2_KERNELS = [ "imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa", "matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot", "matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa", + "matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa", "matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla", "matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla", "matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa", @@ -156,7 +157,6 @@ SME2_KERNELS = [ "matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot", "matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa", "matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot", - "matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa", "matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot", "matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa", ] diff --git a/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c similarity index 100% rename from kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c rename to kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c diff --git a/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h similarity index 100% rename from kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h rename to kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h diff --git a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h index b1d747c59c7918bd91a6baa81618c32bdbd9cafe..8e3e6f4744f2187cd80c3c9caeee83fb5d38a5c7 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h @@ -1,14 +1,10 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // #pragma once -#if !defined(__aarch64__) || !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) -#error This file must be compiled for AArch64, FEAT_BF16 -#else // Architectural features check. - #include #ifdef __cplusplus @@ -25,7 +21,7 @@ typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_mr_func_t)(void); typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_nr_func_t)(void); typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_kr_func_t)(void); typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_sr_func_t)(void); -typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_lhs_packed_offset_func_t)(size_t m_idx, size_t k); +typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_lhs_packed_offset_func_t)(size_t m_idx, size_t lhs_stride); typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_rhs_packed_offset_func_t)(size_t n_idx, size_t k); typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_dst_offset_func_t)(size_t m_idx, size_t n_idx, size_t dst_stride); typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_dst_size_func_t)(size_t m, size_t n); @@ -53,5 +49,3 @@ struct kai_matmul_clamp_f32_bf16p_bf16p_ukernel { #ifdef __cplusplus } #endif - -#endif // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h b/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h deleted file mode 100644 index 38cc633a530a7fb6ded5b4397c078ff0d61fb02f..0000000000000000000000000000000000000000 --- a/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h +++ /dev/null @@ -1,55 +0,0 @@ -// -// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates -// -// SPDX-License-Identifier: Apache-2.0 -// -#pragma once - -#if !defined(__aarch64__) || !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) -#error This file must be compiled for AArch64, FEAT_BF16. -#else // Architectural features check. - -#include - -#ifdef __cplusplus -extern "C" { -#endif - -// All micro-kernels variants of the same type share the same interfaces -// In this case, the micro-kernel type is: matmul_clamp_f32_bf16p_bf16p - -/// Micro-kernel helper functions ("get" methods) -typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_m_step_func_t)(void); -typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_n_step_func_t)(void); -typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_nr_func_t)(void); -typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_kr_func_t)(void); -typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_sr_func_t)(void); -typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_lhs_offset_func_t)(size_t m_idx, size_t lhs_stride); -typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_rhs_packed_offset_func_t)(size_t n_idx, size_t k); -typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_dst_offset_func_t)(size_t m_idx, size_t n_idx, size_t dst_stride); -typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_dst_size_func_t)(size_t m, size_t n); - -/// Micro-kernel core function ("run" method) -typedef void (*kai_matmul_clamp_f32_bf16p_bf16p_run_matmul_func_t)( - size_t m, size_t n, size_t k, const void* lhs, size_t lhs_stride, const void* rhs_packed, void* dst, - size_t dst_stride_row, size_t dst_stride_col, __fp16 scalar_min, __fp16 scalar_max); - -/// Micro-kernel interface -struct kai_matmul_clamp_f32_bf16p_bf16p_ukernel { - kai_matmul_clamp_f32_bf16p_bf16p_get_m_step_func_t get_m_step; - kai_matmul_clamp_f32_bf16p_bf16p_get_n_step_func_t get_n_step; - kai_matmul_clamp_f32_bf16p_bf16p_get_nr_func_t get_nr; - kai_matmul_clamp_f32_bf16p_bf16p_get_kr_func_t get_kr; - kai_matmul_clamp_f32_bf16p_bf16p_get_sr_func_t get_sr; - kai_matmul_clamp_f32_bf16p_bf16p_get_lhs_offset_func_t get_lhs_packed_offset; - kai_matmul_clamp_f32_bf16p_bf16p_get_rhs_packed_offset_func_t get_rhs_packed_offset; - kai_matmul_clamp_f32_bf16p_bf16p_get_dst_offset_func_t get_dst_offset; - kai_matmul_clamp_f32_bf16p_bf16p_get_dst_size_func_t get_dst_size; - kai_matmul_clamp_f32_bf16p_bf16p_run_matmul_func_t run_matmul; -}; - -#ifdef __cplusplus -} -#endif - -#endif // Architectural features check. diff --git a/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp index a17399ef285258e534967409ed8e31a3c04c4045..044c35e3024e89a5de0c02e505dc8efe0d2360da 100644 --- a/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp +++ b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp @@ -4,14 +4,12 @@ // SPDX-License-Identifier: Apache-2.0 // -#include #include #include #include #include #include -#include #include #include #include @@ -34,17 +32,29 @@ // matmul_clamp_f32_bf16p_bf16p #include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h" #include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h" +#include "kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.h" #include "kai/ukernels/matmul/pack/kai_lhs_pack_bf16p8x4_f16_neon.h" #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p1x4_f32_neon.h" #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p8x4_f32_neon.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p12x4biasf32_f16_neon.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h" #include "kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon.h" -// SME files here. -#include "kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h" -#include "kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.h" -#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h" +#define INITIALIZE_INTERFACE_F32_BF16P_BF16P(interface_name, kernel_name) \ + (interface_name).get_n_step = kai_get_n_step_##kernel_name; \ + (interface_name).get_m_step = kai_get_m_step_##kernel_name; \ + (interface_name).get_mr = kai_get_mr_##kernel_name; \ + (interface_name).get_nr = kai_get_nr_##kernel_name; \ + (interface_name).get_kr = kai_get_kr_##kernel_name; \ + (interface_name).get_sr = kai_get_sr_##kernel_name; \ + (interface_name).get_lhs_packed_offset = kai_get_lhs_packed_offset_##kernel_name; \ + (interface_name).get_rhs_packed_offset = kai_get_rhs_packed_offset_##kernel_name; \ + (interface_name).get_dst_offset = kai_get_dst_offset_##kernel_name; \ + (interface_name).get_dst_size = kai_get_dst_size_##kernel_name; \ + (interface_name).run_matmul = kai_run_##kernel_name; namespace kai::test { @@ -305,6 +315,13 @@ struct MatMulMethodInitializer { gemv_methods[1].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; gemv_methods[1].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; gemv_methods[1].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + + // Test kernel interface match kernels API + kai_matmul_clamp_f32_bf16p_bf16p_ukernel interface_test [[maybe_unused]]{}; + + INITIALIZE_INTERFACE_F32_BF16P_BF16P(interface_test, matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa) + INITIALIZE_INTERFACE_F32_BF16P_BF16P(interface_test, matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla) + INITIALIZE_INTERFACE_F32_BF16P_BF16P(interface_test, matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot) }; };