From 137515a740fbad9b72966358ca2d38a1a90f5a07 Mon Sep 17 00:00:00 2001 From: Anton Bondarenko Date: Tue, 8 Apr 2025 18:15:46 +0200 Subject: [PATCH] Move matmul_clamp_fp32_bf16p_bf16p into matmul_clamp_f32_bf16p_bf16p Folder name matmul_clamp_fp32_bf16p_bf16p does not match naming pattern so it needs to be fixed. Signed-off-by: Anton Bondarenko --- CMakeLists.txt | 2 +- benchmark/matmul/matmul_registry.cpp | 4 +- kai/ukernels/matmul/BUILD.bazel | 2 +- ..._bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c | 0 ..._bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h | 0 ...i_matmul_clamp_f32_bf16p_bf16p_interface.h | 10 +--- ...i_matmul_clamp_f32_bf16p_bf16p_interface.h | 55 ------------------- .../matmul_clamp_f32_bf16p_bf16p_test.cpp | 29 ++++++++-- 8 files changed, 28 insertions(+), 74 deletions(-) rename kai/ukernels/matmul/{matmul_clamp_fp32_bf16p_bf16p => matmul_clamp_f32_bf16p_bf16p}/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c (100%) rename kai/ukernels/matmul/{matmul_clamp_fp32_bf16p_bf16p => matmul_clamp_f32_bf16p_bf16p}/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h (100%) delete mode 100644 kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 8e0c7130..308648cb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -222,6 +222,7 @@ set(KLEIDIAI_FILES_SME2 kai/ukernels/matmul/imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot.c kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c + kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.c kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.c kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.c @@ -229,7 +230,6 @@ set(KLEIDIAI_FILES_SME2 kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot.c kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c - kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.c kai/ukernels/matmul/matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa.c kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c diff --git a/benchmark/matmul/matmul_registry.cpp b/benchmark/matmul/matmul_registry.cpp index 75d6a48b..8db6b36e 100644 --- a/benchmark/matmul/matmul_registry.cpp +++ b/benchmark/matmul/matmul_registry.cpp @@ -40,6 +40,7 @@ // matmul_clamp_f32_bf16p_bf16p #include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h" #include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.h" // matmul_clamp_f32_f32_f32p @@ -91,9 +92,6 @@ #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h" -// matmul_clamp_fp32_bf16p_bf16p -#include "kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h" - // matmul_clamp_qai8_qai8_qsi8cxp #include "kai/ukernels/matmul/matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot.h" diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 79bfac8b..c78b5106 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -149,6 +149,7 @@ SME2_KERNELS = [ "imatmul_clamp_qai8_qai8p_qsi8cxp/kai_imatmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa", "matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot", "matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa", + "matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa", "matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla", "matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla", "matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa", @@ -156,7 +157,6 @@ SME2_KERNELS = [ "matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot", "matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa", "matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot", - "matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa", "matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot", "matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa", ] diff --git a/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c similarity index 100% rename from kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c rename to kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c diff --git a/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h similarity index 100% rename from kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h rename to kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h diff --git a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h index b1d747c5..8e3e6f47 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h @@ -1,14 +1,10 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // #pragma once -#if !defined(__aarch64__) || !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) -#error This file must be compiled for AArch64, FEAT_BF16 -#else // Architectural features check. - #include #ifdef __cplusplus @@ -25,7 +21,7 @@ typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_mr_func_t)(void); typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_nr_func_t)(void); typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_kr_func_t)(void); typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_sr_func_t)(void); -typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_lhs_packed_offset_func_t)(size_t m_idx, size_t k); +typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_lhs_packed_offset_func_t)(size_t m_idx, size_t lhs_stride); typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_rhs_packed_offset_func_t)(size_t n_idx, size_t k); typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_dst_offset_func_t)(size_t m_idx, size_t n_idx, size_t dst_stride); typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_dst_size_func_t)(size_t m, size_t n); @@ -53,5 +49,3 @@ struct kai_matmul_clamp_f32_bf16p_bf16p_ukernel { #ifdef __cplusplus } #endif - -#endif // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h b/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h deleted file mode 100644 index 38cc633a..00000000 --- a/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h +++ /dev/null @@ -1,55 +0,0 @@ -// -// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates -// -// SPDX-License-Identifier: Apache-2.0 -// -#pragma once - -#if !defined(__aarch64__) || !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) -#error This file must be compiled for AArch64, FEAT_BF16. -#else // Architectural features check. - -#include - -#ifdef __cplusplus -extern "C" { -#endif - -// All micro-kernels variants of the same type share the same interfaces -// In this case, the micro-kernel type is: matmul_clamp_f32_bf16p_bf16p - -/// Micro-kernel helper functions ("get" methods) -typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_m_step_func_t)(void); -typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_n_step_func_t)(void); -typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_nr_func_t)(void); -typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_kr_func_t)(void); -typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_sr_func_t)(void); -typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_lhs_offset_func_t)(size_t m_idx, size_t lhs_stride); -typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_rhs_packed_offset_func_t)(size_t n_idx, size_t k); -typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_dst_offset_func_t)(size_t m_idx, size_t n_idx, size_t dst_stride); -typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_dst_size_func_t)(size_t m, size_t n); - -/// Micro-kernel core function ("run" method) -typedef void (*kai_matmul_clamp_f32_bf16p_bf16p_run_matmul_func_t)( - size_t m, size_t n, size_t k, const void* lhs, size_t lhs_stride, const void* rhs_packed, void* dst, - size_t dst_stride_row, size_t dst_stride_col, __fp16 scalar_min, __fp16 scalar_max); - -/// Micro-kernel interface -struct kai_matmul_clamp_f32_bf16p_bf16p_ukernel { - kai_matmul_clamp_f32_bf16p_bf16p_get_m_step_func_t get_m_step; - kai_matmul_clamp_f32_bf16p_bf16p_get_n_step_func_t get_n_step; - kai_matmul_clamp_f32_bf16p_bf16p_get_nr_func_t get_nr; - kai_matmul_clamp_f32_bf16p_bf16p_get_kr_func_t get_kr; - kai_matmul_clamp_f32_bf16p_bf16p_get_sr_func_t get_sr; - kai_matmul_clamp_f32_bf16p_bf16p_get_lhs_offset_func_t get_lhs_packed_offset; - kai_matmul_clamp_f32_bf16p_bf16p_get_rhs_packed_offset_func_t get_rhs_packed_offset; - kai_matmul_clamp_f32_bf16p_bf16p_get_dst_offset_func_t get_dst_offset; - kai_matmul_clamp_f32_bf16p_bf16p_get_dst_size_func_t get_dst_size; - kai_matmul_clamp_f32_bf16p_bf16p_run_matmul_func_t run_matmul; -}; - -#ifdef __cplusplus -} -#endif - -#endif // Architectural features check. diff --git a/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp index a17399ef..044c35e3 100644 --- a/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp +++ b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp @@ -4,14 +4,12 @@ // SPDX-License-Identifier: Apache-2.0 // -#include #include #include #include #include #include -#include #include #include #include @@ -34,17 +32,29 @@ // matmul_clamp_f32_bf16p_bf16p #include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h" #include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h" +#include "kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.h" #include "kai/ukernels/matmul/pack/kai_lhs_pack_bf16p8x4_f16_neon.h" #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p1x4_f32_neon.h" #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p8x4_f32_neon.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p12x4biasf32_f16_neon.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h" #include "kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon.h" -// SME files here. -#include "kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h" -#include "kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.h" -#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h" +#define INITIALIZE_INTERFACE_F32_BF16P_BF16P(interface_name, kernel_name) \ + (interface_name).get_n_step = kai_get_n_step_##kernel_name; \ + (interface_name).get_m_step = kai_get_m_step_##kernel_name; \ + (interface_name).get_mr = kai_get_mr_##kernel_name; \ + (interface_name).get_nr = kai_get_nr_##kernel_name; \ + (interface_name).get_kr = kai_get_kr_##kernel_name; \ + (interface_name).get_sr = kai_get_sr_##kernel_name; \ + (interface_name).get_lhs_packed_offset = kai_get_lhs_packed_offset_##kernel_name; \ + (interface_name).get_rhs_packed_offset = kai_get_rhs_packed_offset_##kernel_name; \ + (interface_name).get_dst_offset = kai_get_dst_offset_##kernel_name; \ + (interface_name).get_dst_size = kai_get_dst_size_##kernel_name; \ + (interface_name).run_matmul = kai_run_##kernel_name; namespace kai::test { @@ -305,6 +315,13 @@ struct MatMulMethodInitializer { gemv_methods[1].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; gemv_methods[1].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; gemv_methods[1].fn_matmul_f32_bf16p_bf16p = kai_run_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot; + + // Test kernel interface match kernels API + kai_matmul_clamp_f32_bf16p_bf16p_ukernel interface_test [[maybe_unused]]{}; + + INITIALIZE_INTERFACE_F32_BF16P_BF16P(interface_test, matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa) + INITIALIZE_INTERFACE_F32_BF16P_BF16P(interface_test, matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla) + INITIALIZE_INTERFACE_F32_BF16P_BF16P(interface_test, matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot) }; }; -- GitLab