diff --git a/CHANGELOG.md b/CHANGELOG.md index e9b3148ea5631e5e0b966e9a49159091f3ead92c..c96364c83bb2d7e800ea878c245d6efdbd7b792a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -41,8 +41,9 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo - Matrix multiplication (MxN) Micro-kernels of QSI8D32 LHS and QAI4C32 RHS with F32 output, optimized for FEAT_DotProd. - Matrix multiplication (MxN) Micro-kernels of QSI8D32 LHS and QAI4C32 RHS with F16 output, optimized for FEAT_DotProd. - Optimized version of kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 kernel for block depth of 8 bytes (`kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon`) +- New SME micro-kernels: + - Added GEMM F16 and F32 kernels using SME1 MOPA instruction, block size 2VLx2VL. - Added Convolution example using SME Indirect Matmul Kernels -- Added GEMM F32 kernel using SME1 MOPA instruction, block size 2VLx2VL. - Fixes: - Fix issue where kai_get_m_step() returns the incorrect value for kernels - matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla diff --git a/CMakeLists.txt b/CMakeLists.txt index d826540d18c42c08fd63308cedbf33489fa9d9ff..66a345488ea96dee12b0be2ccc276ccfa4ceeaab 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -231,6 +231,8 @@ set(KLEIDIAI_FILES_NEON_I8MM ) set(KLEIDIAI_FILES_SME_ASM + kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa.c + kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa_asm.S kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.c kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa_asm.S kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.c diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 4e63ebabe922dd5aa432fcfd52e325fffa49327f..1abee3c7af957cc5d8784f37078786000e2651c9 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -141,7 +141,6 @@ I8MM_KERNELS_ASM = [ # buildifier: keep sorted SME_KERNELS = [ - "matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa", "pack/kai_lhs_pack_bf16p2vlx2_f32_sme", "pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme", "pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme", @@ -150,6 +149,7 @@ SME_KERNELS = [ # buildifier: keep sorted SME_KERNELS_ASM = [ + "matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa", "matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa", "pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme", "pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme", diff --git a/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c b/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c index a1ccaa278b89f54a0b63a98790b6aa797ad40e5a..7d9a8dc068a148c77a2b4efcc745ff8b41eee60d 100644 --- a/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c +++ b/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c @@ -97,7 +97,8 @@ size_t kai_get_dst_size_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(s void kai_run_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa( size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row, size_t dst_stride_col, float clamp_min, float clamp_max) { - KAI_ASSUME(dst_stride_col == sizeof(uint16_t)); + KAI_UNUSED(dst_stride_col); + KernelArgs args; args.A = lhs_packed; diff --git a/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h b/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h index bc61fe28117f7fdc685b0e2359b09c8a204fa838..be0c787d43ad1e331d32fc72ae7c0eeddf179eea 100644 --- a/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h +++ b/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h @@ -14,8 +14,8 @@ extern "C" { /// Micro-kernel dependencies /// -/// -# kai_lhs_pack_f16p2vlx2_f16_sme to pack the LHS matrix. -/// -# kai_rhs_pack_kxn_f16p2vlx2b_f16_f16_sme to pack the RHS matrix. +/// -# kai_lhs_pack_x16p2vlx2_x16_sme to pack the LHS matrix. +/// -# kai_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme or kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme to pack the RHS matrix. /// Gets m step value. /// @@ -109,7 +109,7 @@ size_t kai_get_dst_size_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(s /// @param[in] rhs_packed Packed RHS matrix buffer. /// @param[out] dst Output matrix buffer. /// @param[in] dst_stride_row Row stride in bytes of the output matrix. -/// @param[in] dst_stride_col Column stride in bytes of the output matrix. Must be 2 +/// @param[in] dst_stride_col Column stride in bytes of the output matrix. Currently, an unused parameter. /// @param[in] clamp_min Minimum value to clamp the final result. /// @param[in] clamp_max Maximum value to clamp the final result. /// diff --git a/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa.c b/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa.c new file mode 100644 index 0000000000000000000000000000000000000000..eeecd7e4946c637d4b4e56c66844625671ada5c8 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa.c @@ -0,0 +1,119 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if (!defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)) && !defined(_M_ARM64) +#error This file must be compiled for AArch64, FEAT_SVE2. +#else // Architectural features check. +#include "kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa.h" + +#include +#include + +#include "kai/kai_common.h" + +typedef struct { + const void* A; + const void* B; + void* C; + uint64_t ldcb; + uint64_t M; + uint64_t N; + uint64_t K; + uint16_t min; + uint16_t max; + void* accumulator_buffer; + uint64_t flags; +} KernelArgs; + +static const size_t kai_mr = 2; +static const size_t kai_nr = 2; +static const size_t kai_kr = 2; +static const size_t kai_sr = 1; + +void kai_kernel_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa(KernelArgs* args); +uint16_t kai_f16_from_float_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa(float value); + +// Returns a constant value specific to this kernel that's relative to vector length +static size_t kai_get_kernel_vec_length_constant(void) { + const size_t kernel_vec_length_constant = kai_get_sme_vector_length_u16() / kai_kr; + return kernel_vec_length_constant; +} + +size_t kai_get_m_step_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa(void) { + return kai_mr * kai_get_kernel_vec_length_constant(); +} + +size_t kai_get_n_step_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa(void) { + return kai_nr * kai_get_kernel_vec_length_constant(); +} + +size_t kai_get_mr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa(void) { + return kai_mr * kai_get_kernel_vec_length_constant(); +} + +size_t kai_get_nr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa(void) { + return kai_nr * kai_get_kernel_vec_length_constant(); +} + +size_t kai_get_kr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa(void) { + return kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa(size_t m_idx, size_t k) { + KAI_ASSUME(m_idx % kai_get_m_step_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa() == 0); + return m_idx * kai_roundup(k, kai_kr) * sizeof(uint16_t); +} + +static size_t kai_get_rhs_packed_stride_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa(size_t k) { + return kai_get_n_step_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa() * + (sizeof(uint16_t) + kai_roundup(k, kai_kr) * sizeof(uint16_t)); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa(size_t n_idx, size_t k) { + KAI_ASSUME(n_idx % kai_get_n_step_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa() == 0); + const size_t block_idx = n_idx / kai_get_n_step_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa(); + return block_idx * kai_get_rhs_packed_stride_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa(k); +} + +size_t kai_get_dst_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa( + size_t m_idx, size_t n_idx, size_t dst_stride_row) { + KAI_ASSUME(m_idx % kai_get_m_step_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa() == 0); + KAI_ASSUME(n_idx % kai_get_n_step_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa() == 0); + + return m_idx * dst_stride_row + n_idx * sizeof(uint16_t); +} + +size_t kai_get_dst_size_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa(size_t m, size_t n) { + return m * n * sizeof(uint16_t); +} + +void kai_run_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa( + size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row, + size_t dst_stride_col, float clamp_min, float clamp_max) { + KAI_UNUSED(dst_stride_col); + + KernelArgs args; + + args.A = lhs_packed; + args.B = rhs_packed; + args.C = dst; + args.ldcb = dst_stride_row; + args.M = m; + args.N = n; + args.K = k; + args.min = kai_f16_from_float_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa(clamp_min); + args.max = kai_f16_from_float_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa(clamp_max); + args.accumulator_buffer = NULL; + args.flags = 0; + + kai_kernel_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa(&args); +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa.h b/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa.h new file mode 100644 index 0000000000000000000000000000000000000000..e9e3d81b78cfff1e2e38bb1a77f6098605baff3c --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa.h @@ -0,0 +1,125 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Micro-kernel dependencies +/// +/// -# kai_lhs_pack_x16p2vlx2_x16_sme to pack the LHS matrix. +/// -# kai_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme or kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme to pack the RHS matrix. + +/// Gets m step value. +/// +/// The starting row index must be divisible by `m_step`. +/// +/// @return The m step value. +size_t kai_get_m_step_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa(void); + +/// Gets n step value. +/// +/// The starting column index must be divisible by `n_step`. +/// +/// @return The n step value. +size_t kai_get_n_step_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa(void); + +/// Gets mr value. +/// +/// This is the packing parameter which must be used to pack the LHS matrix. +/// +/// @return The mr value. +size_t kai_get_mr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa(void); + +/// Gets nr value. +/// +/// This is the packing parameter which must be used to pack the RHS matrix. +/// +/// @return The nr value. +size_t kai_get_nr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa(void); + +/// Gets kr value. +/// +/// This is the packing parameter which must be used to pack the LHS and RHS matrix. +/// +/// @return The kr value. +size_t kai_get_kr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa(void); + +/// Gets sr value. +/// +/// This is the packing parameter which must be used to pack the LHS and RHS matrix. +/// +/// @return The sr value. +size_t kai_get_sr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa(void); + +/// Gets the offset in bytes to the data element in the packed LHS matrix buffer. +/// +/// @param[in] m_idx Row index in the unpacked LHS matrix. Must be a multiple of `m_step`. +/// @param[in] k Number of columns in the unpacked LHS matrix. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_lhs_packed_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa(size_t m_idx, size_t k); + +/// Gets the offset in bytes to the data element in the packed RHS matrix buffer. +/// +/// @param[in] n_idx Column index in the unpacked RHS matrix. Must be a multiple of `n_step`. +/// @param[in] k Number of columns in the unpacked LHS matrix. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_rhs_packed_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa(size_t n_idx, size_t k); + +/// Gets the offset in bytes to the data element in the destination matrix buffer. +/// +/// @param[in] m_idx Row index. Must be a multiple of `m_step`. +/// @param[in] n_idx Column index. Must be a multiple of `n_step`. +/// @param[in] dst_stride_row Row stride in bytes. +/// +/// @return The offset in bytes to the data element. +size_t kai_get_dst_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa( + size_t m_idx, size_t n_idx, size_t dst_stride_row); + +/// Gets the size in bytes of the destination matrix buffer. +/// +/// @param[in] m Number of rows. +/// @param[in] n Number of columns. +/// +/// @return The size in bytes of the destination matrix buffer. +size_t kai_get_dst_size_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa(size_t m, size_t n); + +/// Runs the matrix multiplication microkernel followed by a clamp operation. +/// +/// The pointer of each buffers (packed LHS, packed RHS and output) needs to be added with offset +/// calculated using the following functions: +/// +/// * Packed LHS: @ref kai_get_lhs_packed_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa. +/// * Packed RHS: @ref kai_get_rhs_packed_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa. +/// * Output: @ref kai_get_dst_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa. +/// +/// @param[in] m Number of output rows to be computed. +/// @param[in] n Number of output columns to be computed. +/// @param[in] k Number of columns in the unpacked LHS matrix. +/// @param[in] lhs_packed Packed LHS matrix buffer. +/// @param[in] rhs_packed Packed RHS matrix buffer. +/// @param[out] dst Output matrix buffer. +/// @param[in] dst_stride_row Row stride in bytes of the output matrix. +/// @param[in] dst_stride_col Column stride in bytes of the output matrix. Currently, an unused parameter. +/// @param[in] clamp_min Minimum value to clamp the final result. +/// @param[in] clamp_max Maximum value to clamp the final result. +/// +/// @note Clamp minimum and maximum values are cast internally to the destination type before clamping the computed +/// values. +/// +void kai_run_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa( + size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row, + size_t dst_stride_col, float clamp_min, float clamp_max); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa_asm.S b/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa_asm.S new file mode 100644 index 0000000000000000000000000000000000000000..43f12f1dd19d97c7d0c37996fb9aa3640aa02460 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa_asm.S @@ -0,0 +1,228 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa) + + KAI_ASM_GLOBAL(kai_f16_from_float_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa) + +KAI_ASM_FUNCTION_TYPE(kai_f16_from_float_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa) +KAI_ASM_FUNCTION_LABEL(kai_f16_from_float_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa) + fcvt h0, s0 + fmov w0, h0 + ret + KAI_ASM_FUNCTION_END(kai_f16_from_float_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa) +KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa) + stp x20, x21, [sp, -80]! + stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + KAI_ASM_INST(0xd503477f) // SMSTART ZA + mov x14, #0x0 + ldr x13, [x0, #0x30] + ptrue p1.b + mov x11, #0x0 + ldr w10, [x0, #0x20] + ldr w9, [x0, #0x28] + add x13, x13, #0x1 + ldr x28, [x0, #0x0] + lsr x13, x13, #0x1 +KAI_ASM_LABEL(label_1) // M loop + ldr x27, [x0, #0x8] +KAI_ASM_LABEL(label_2) // N loop + fmov z19.h, #0.0 + ld1h { z16.h }, p1/Z, [x27] + fmov z18.h, #1.0 + mov x20, x11 + whilelt p8.s, x20, x9 + incw x20 + KAI_ASM_INST(0xc00800ff) // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 } + mov x26, x28 + whilelt p8.s, x20, x9 + inch x27, ALL, MUL #2 + zip1 z17.h, z16.h, z19.h + zip2 z16.h, z16.h, z19.h + KAI_ASM_INST(0x81b12640) // fmopa za0.s, p1/M, p1/M, z18.h, z17.h + KAI_ASM_INST(0x81b02641) // fmopa za1.s, p1/M, p1/M, z18.h, z16.h + KAI_ASM_INST(0x81b12642) // fmopa za2.s, p1/M, p1/M, z18.h, z17.h + KAI_ASM_INST(0x81b02643) // fmopa za3.s, p1/M, p1/M, z18.h, z16.h + lsr x21, x13, #0x2 + and x20, x13, #0x3 + cbz x21, label_6 + subs x21, x21, #0x1 + ld1h { z31.h }, p1/Z, [x26] + ld1h { z30.h }, p1/Z, [x26, #1, MUL VL] + ld1h { z29.h }, p1/Z, [x26, #2, MUL VL] + ld1h { z28.h }, p1/Z, [x26, #3, MUL VL] + ld1h { z27.h }, p1/Z, [x26, #4, MUL VL] + ld1h { z26.h }, p1/Z, [x26, #5, MUL VL] + ld1h { z25.h }, p1/Z, [x26, #6, MUL VL] + ld1h { z24.h }, p1/Z, [x26, #7, MUL VL] + addvl x26, x26, #8 + ld1h { z23.h }, p1/Z, [x27] + ld1h { z22.h }, p1/Z, [x27, #1, MUL VL] + ld1h { z21.h }, p1/Z, [x27, #2, MUL VL] + ld1h { z20.h }, p1/Z, [x27, #3, MUL VL] + ld1h { z19.h }, p1/Z, [x27, #4, MUL VL] + ld1h { z18.h }, p1/Z, [x27, #5, MUL VL] + ld1h { z17.h }, p1/Z, [x27, #6, MUL VL] + ld1h { z16.h }, p1/Z, [x27, #7, MUL VL] + addvl x27, x27, #8 + ble label_5 +KAI_ASM_LABEL(label_4) // K loop + KAI_ASM_INST(0x81b727e0) // fmopa za0.s, p1/M, p1/M, z31.h, z23.h + subs x21, x21, #0x1 + KAI_ASM_INST(0x81b627e1) // fmopa za1.s, p1/M, p1/M, z31.h, z22.h + ld1h { z31.h }, p1/Z, [x26] + KAI_ASM_INST(0x81b727c2) // fmopa za2.s, p1/M, p1/M, z30.h, z23.h + ld1h { z23.h }, p1/Z, [x27] + KAI_ASM_INST(0x81b627c3) // fmopa za3.s, p1/M, p1/M, z30.h, z22.h + ld1h { z30.h }, p1/Z, [x26, #1, MUL VL] + KAI_ASM_INST(0x81b527a0) // fmopa za0.s, p1/M, p1/M, z29.h, z21.h + ld1h { z22.h }, p1/Z, [x27, #1, MUL VL] + KAI_ASM_INST(0x81b427a1) // fmopa za1.s, p1/M, p1/M, z29.h, z20.h + ld1h { z29.h }, p1/Z, [x26, #2, MUL VL] + KAI_ASM_INST(0x81b52782) // fmopa za2.s, p1/M, p1/M, z28.h, z21.h + ld1h { z21.h }, p1/Z, [x27, #2, MUL VL] + KAI_ASM_INST(0x81b42783) // fmopa za3.s, p1/M, p1/M, z28.h, z20.h + ld1h { z28.h }, p1/Z, [x26, #3, MUL VL] + KAI_ASM_INST(0x81b32760) // fmopa za0.s, p1/M, p1/M, z27.h, z19.h + ld1h { z20.h }, p1/Z, [x27, #3, MUL VL] + KAI_ASM_INST(0x81b22761) // fmopa za1.s, p1/M, p1/M, z27.h, z18.h + ld1h { z27.h }, p1/Z, [x26, #4, MUL VL] + KAI_ASM_INST(0x81b32742) // fmopa za2.s, p1/M, p1/M, z26.h, z19.h + ld1h { z19.h }, p1/Z, [x27, #4, MUL VL] + KAI_ASM_INST(0x81b22743) // fmopa za3.s, p1/M, p1/M, z26.h, z18.h + ld1h { z26.h }, p1/Z, [x26, #5, MUL VL] + KAI_ASM_INST(0x81b12720) // fmopa za0.s, p1/M, p1/M, z25.h, z17.h + ld1h { z18.h }, p1/Z, [x27, #5, MUL VL] + KAI_ASM_INST(0x81b02721) // fmopa za1.s, p1/M, p1/M, z25.h, z16.h + ld1h { z25.h }, p1/Z, [x26, #6, MUL VL] + KAI_ASM_INST(0x81b12702) // fmopa za2.s, p1/M, p1/M, z24.h, z17.h + ld1h { z17.h }, p1/Z, [x27, #6, MUL VL] + KAI_ASM_INST(0x81b02703) // fmopa za3.s, p1/M, p1/M, z24.h, z16.h + ld1h { z24.h }, p1/Z, [x26, #7, MUL VL] + addvl x26, x26, #8 + ld1h { z16.h }, p1/Z, [x27, #7, MUL VL] + addvl x27, x27, #8 + bgt label_4 +KAI_ASM_LABEL(label_5) // K loop tail + KAI_ASM_INST(0x81b727e0) // fmopa za0.s, p1/M, p1/M, z31.h, z23.h + KAI_ASM_INST(0x81b627e1) // fmopa za1.s, p1/M, p1/M, z31.h, z22.h + KAI_ASM_INST(0x81b727c2) // fmopa za2.s, p1/M, p1/M, z30.h, z23.h + KAI_ASM_INST(0x81b627c3) // fmopa za3.s, p1/M, p1/M, z30.h, z22.h + KAI_ASM_INST(0x81b527a0) // fmopa za0.s, p1/M, p1/M, z29.h, z21.h + KAI_ASM_INST(0x81b427a1) // fmopa za1.s, p1/M, p1/M, z29.h, z20.h + KAI_ASM_INST(0x81b52782) // fmopa za2.s, p1/M, p1/M, z28.h, z21.h + KAI_ASM_INST(0x81b42783) // fmopa za3.s, p1/M, p1/M, z28.h, z20.h + KAI_ASM_INST(0x81b32760) // fmopa za0.s, p1/M, p1/M, z27.h, z19.h + KAI_ASM_INST(0x81b22761) // fmopa za1.s, p1/M, p1/M, z27.h, z18.h + KAI_ASM_INST(0x81b32742) // fmopa za2.s, p1/M, p1/M, z26.h, z19.h + KAI_ASM_INST(0x81b22743) // fmopa za3.s, p1/M, p1/M, z26.h, z18.h + KAI_ASM_INST(0x81b12720) // fmopa za0.s, p1/M, p1/M, z25.h, z17.h + KAI_ASM_INST(0x81b02721) // fmopa za1.s, p1/M, p1/M, z25.h, z16.h + KAI_ASM_INST(0x81b12702) // fmopa za2.s, p1/M, p1/M, z24.h, z17.h + KAI_ASM_INST(0x81b02703) // fmopa za3.s, p1/M, p1/M, z24.h, z16.h +KAI_ASM_LABEL(label_6) // K oddments + cbz x20, label_8 +KAI_ASM_LABEL(label_7) // K oddments: Loop + ld1h { z19.h }, p1/Z, [x26] + subs x20, x20, #0x1 + ld1h { z18.h }, p1/Z, [x26, #1, MUL VL] + addvl x26, x26, #2 + ld1h { z17.h }, p1/Z, [x27] + ld1h { z16.h }, p1/Z, [x27, #1, MUL VL] + addvl x27, x27, #2 + KAI_ASM_INST(0x81b12660) // fmopa za0.s, p1/M, p1/M, z19.h, z17.h + KAI_ASM_INST(0x81b02661) // fmopa za1.s, p1/M, p1/M, z19.h, z16.h + KAI_ASM_INST(0x81b12642) // fmopa za2.s, p1/M, p1/M, z18.h, z17.h + KAI_ASM_INST(0x81b02643) // fmopa za3.s, p1/M, p1/M, z18.h, z16.h + bgt label_7 +KAI_ASM_LABEL(label_8) // K oddments: End + ldr x25, [x0, #0x10] + sub x24, x10, x14 + cntw x23, ALL, MUL #2 + KAI_ASM_INST(0x84dca413) // ld1rh { z19.h }, p1/Z, [x0, #56] + ldr x22, [x0, #0x18] + whilelt p0.h, x11, x9 + cmp x24, x23 + KAI_ASM_INST(0x84dda412) // ld1rh { z18.h }, p1/Z, [x0, #58] + mov x12, #0x0 + mov x21, #0x0 + add x25, x25, x11, LSL #1 // C += n + mov x20, #0x2 + madd x25, x14, x22, x25 // C += m * ldc + csel x24, x24, x23, LT +KAI_ASM_LABEL(label_10) // Store to output array: Accumulator loop + KAI_ASM_INST(0xc0020411) // mova z17.b, p1/M, za0h.b[x12] + add x21, x21, #0x1 + KAI_ASM_INST(0xc0020430) // mova z16.b, p1/M, za0h.b[x12, #1] + fcvt z17.h, p1/m, z17.s + add x12, x12, #0x4 + fcvt z16.h, p1/m, z16.s + cmp x12, x23, LSL #1 + csel x12, x12, x20, LT + cmp x21, x24 + uzp1 z16.h, z17.h, z16.h + fmin z16.h, p1/M, z16.h, z18.h + fmax z16.h, p1/M, z16.h, z19.h + st1h { z16.h }, p0, [x25] + add x25, x25, x22 + blt label_10 + incw x11, ALL, MUL #2 + cmp x11, x9 + blt label_2 + incw x14, ALL, MUL #2 + mov x11, #0x0 + cmp x14, x10 + mov x28, x26 + blt label_1 + KAI_ASM_INST(0xd503467f) // SMSTOP + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp x20, x21, [sp], 80 + ret + KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa) + + KAI_ASM_END diff --git a/test/tests/matmul_test.cpp b/test/tests/matmul_test.cpp index aa8a73ce1c336d8f5beb3bf97fcfe8f7ae29d557..dd893ec3a67bc359fd4dfc9796105de3d1595800 100644 --- a/test/tests/matmul_test.cpp +++ b/test/tests/matmul_test.cpp @@ -36,12 +36,15 @@ #include "kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.h" -// matmul_clamp_f16_f16p_f16p +// matmul_clamp_f16_f16p_f16p_2vlx2vl_sme2_mopa #include "kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h" #include "kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme.h" +// matmul_clamp_f16_f16p_f16p_2vlx2vl_sme_mopa +#include "kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa.h" + // matmul_clamp_f16_f16_f16p #include "kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot.h" @@ -61,9 +64,9 @@ namespace kai::test { -static const std::array& get_matmul_methods() { +static const std::array& get_matmul_methods() { // List of supported matrix multiplication methods. - static std::array matmul_methods{}; + static std::array matmul_methods{}; matmul_methods[0].name = "matmul_nt_nt_fp16_fp16_fp16_6x16_neon_mla"; matmul_methods[0].m0 = 6; @@ -259,6 +262,52 @@ static const std::array& get_matmul_methods() { matmul_methods[4].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa; matmul_methods[4].fn_matmul_f32_f32p_f32p = kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa; + matmul_methods[5].name = "matmul_nt_nt_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa"; + matmul_methods[5].m0 = 2 * get_sme_vector_length(); + matmul_methods[5].n0 = 2 * get_sme_vector_length(); + matmul_methods[5].dst_format = DataFormat(DataType::FP16); + matmul_methods[5].lhs_format = DataFormat(DataType::FP16); + matmul_methods[5].packed_lhs_format = DataFormat(DataType::FP16, 2 * get_sme_vector_length(), 2); + matmul_methods[5].rhs_format = DataFormat(DataType::FP16); + matmul_methods[5].packed_rhs_format = DataFormat( + DataType::FP16, // Output type + 2 * get_sme_vector_length(), 2, // Block size + DataFormat::PackFormat::BIAS_PER_ROW, // Data layout + DataType::FP16, // Bias format + DataType::UNKNOWN, // Scaling type + 2 * get_sme_vector_length(), 2); // Sub-block + matmul_methods[5].bias_format = DataFormat(DataType::FP16); + matmul_methods[5].fn_is_supported = cpu_has_sme; + matmul_methods[5].fn_get_mr = kai_get_mr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; + matmul_methods[5].fn_get_nr = kai_get_nr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; + matmul_methods[5].fn_get_kr = kai_get_kr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; + matmul_methods[5].fn_get_sr = kai_get_sr_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; + matmul_methods[5].fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; + matmul_methods[5].fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; + matmul_methods[5].fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; + matmul_methods[5].fn_get_lhs_offset = kai_get_lhs_offset_lhs_pack_x16p2vlx2_x16_sme; + matmul_methods[5].fn_get_packed_lhs_size = kai_get_lhs_packed_size_lhs_pack_x16p2vlx2_x16_sme; + matmul_methods[5].fn_get_packed_lhs_offset = + kai_get_lhs_packed_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; + matmul_methods[5].fn_pack_lhs = kai_run_lhs_pack_x16p2vlx2_x16_sme; + matmul_methods[5].fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; + matmul_methods[5].fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; + matmul_methods[5].fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; + matmul_methods[5].fn_get_main_packed_rhs_offset = + kai_get_rhs_packed_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; + matmul_methods[5].fn_pack_rhs = kai_run_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; + matmul_methods[5].fn_pack_rhs_nxk_get_n_step = kai_get_n_step_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; + matmul_methods[5].fn_pack_rhs_nxk_get_rhs_offset = kai_get_rhs_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; + matmul_methods[5].fn_pack_rhs_nxk_get_bias_offset = kai_get_bias_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; + matmul_methods[5].fn_pack_rhs_nxk_get_packed_rhs_offset = + kai_get_rhs_packed_offset_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; + matmul_methods[5].fn_pack_rhs_nxk_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; + matmul_methods[5].fn_pack_rhs_nxk = kai_run_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme; + matmul_methods[5].fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme; + matmul_methods[5].fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; + matmul_methods[5].fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; + matmul_methods[5].fn_matmul_f16_f16p_f16p = kai_run_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa; + return matmul_methods; }