diff --git a/BUILD.bazel b/BUILD.bazel index c4abdc25d53561cb9f090a6554aea430ab0d0536..a3c3e75546b5684f74d9f1af317fab2be055b817 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -29,6 +29,7 @@ config_setting( cc_library( name = "common", + srcs = ["kai/kai_common_sme_asm.S"], hdrs = ["kai/kai_common.h"], ) diff --git a/CMakeLists.txt b/CMakeLists.txt index 6fe93eff1b7fd3558bfac5c9252950dd0a46999d..d08faf2e9b4cda2de0037891920dbf5916b2b959 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -241,6 +241,7 @@ set(KLEIDIAI_FILES_NEON_I8MM ) set(KLEIDIAI_FILES_SME_ASM + kai/kai_common_sme_asm.S kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla.c kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla_asm.S kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa.c diff --git a/examples/conv2d_imatmul_clamp_f16_f16_f16p_sme2/CMakeLists.txt b/examples/conv2d_imatmul_clamp_f16_f16_f16p_sme2/CMakeLists.txt index afc8a4c551a9118cf5af50c3966eb79f2d2abefd..95a4b411e8b07f3261bc75673567c29b4a38e80b 100644 --- a/examples/conv2d_imatmul_clamp_f16_f16_f16p_sme2/CMakeLists.txt +++ b/examples/conv2d_imatmul_clamp_f16_f16_f16p_sme2/CMakeLists.txt @@ -23,6 +23,7 @@ set(KAI_SOURCES ${KAI_PATH}/kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.c ${KAI_PATH}/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme_asm.S ${KAI_PATH}/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.c + ${KAI_PATH}/kai/kai_common_sme_asm.S ) # Files requires to build the executable diff --git a/kai/kai_common.h b/kai/kai_common.h index fb3003f70d6ff6de8dd58cd1fed33a3be5ec283e..ad30c4cb4853e92353f1f741dafb399cef02f62e 100644 --- a/kai/kai_common.h +++ b/kai/kai_common.h @@ -143,18 +143,9 @@ inline static size_t kai_roundup(size_t a, size_t b) { return ((a + b - 1) / b) * b; } -#ifdef __ARM_FEATURE_SVE2 +#if defined(__ARM_FEATURE_SVE2) || defined(_M_ARM64) /// Gets the SME vector length for 8-bit elements. -inline static uint64_t kai_get_sme_vector_length_u8(void) { - uint64_t res = 0; - __asm__ __volatile__( - ".inst 0x04bf5827 // rdsvl x7, #1\n" - "mov %0, x7\n" - : "=r"(res) - : /* no inputs */ - : "x7"); - return res; -} +uint64_t kai_get_sme_vector_length_u8(void); /// Gets the SME vector length for 16-bit elements. inline static uint64_t kai_get_sme_vector_length_u16(void) { @@ -165,7 +156,7 @@ inline static uint64_t kai_get_sme_vector_length_u16(void) { inline static uint64_t kai_get_sme_vector_length_u32(void) { return kai_get_sme_vector_length_u8() / 4; } -#endif // __ARM_FEATURE_SVE2 +#endif // defined(__ARM_FEATURE_SVE2) || defined(_M_ARM64) /// Extends the sign bit of int 4-bit value (stored in int8_t variable) /// @param[in] value The 4-bit int value diff --git a/kai/kai_common_sme_asm.S b/kai/kai_common_sme_asm.S new file mode 100644 index 0000000000000000000000000000000000000000..baafe7ccff53b7a593b5dee59f4c932cdbad16b3 --- /dev/null +++ b/kai/kai_common_sme_asm.S @@ -0,0 +1,50 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(kai_common) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_get_sme_vector_length_u8) + +KAI_ASM_FUNCTION_TYPE(kai_get_sme_vector_length_u8) +KAI_ASM_FUNCTION_LABEL(kai_get_sme_vector_length_u8) + KAI_ASM_INST(0x04bf5820) // rdsvl x0, #1 + ret + KAI_ASM_FUNCTION_END(kai_get_sme_vector_length_u8) + + KAI_ASM_END +