diff --git a/CHANGELOG.md b/CHANGELOG.md
index f50c025a785d13f74419409a4c019ada30fce754..592d1389fbb51cfc7982e0ad6135020e9890afde 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -18,6 +18,8 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo
 - Remove `-Werror` from default build flags as to not cause integration problems
 - Expose the rhs_packed_stride in the header file
 - Fix validation error when n > nr in kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa
+- Add MSVC support for `kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0` and `kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0` packing kernels.
+  - Add assembler implementation of f16 to f32 typecasting to avoid use of float16_t.
 
 ## v1.2.0
 
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 19675fbfbad8d94fa9434451b2a07eebd842b6cd..3fd05182b94e78e862a76cb0fc8f2b2931164adf 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -88,10 +88,12 @@ set(KLEIDIAI_FILES_SCALAR
     kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.c
     kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c
     kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c
+    kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_cast_asm.S
     kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0.c
     kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.c
     kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c
     kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c
+    kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_cast_asm.S
 )
 
 set(KLEIDIAI_FILES_NEON_FP16
diff --git a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/CMakeLists.txt b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/CMakeLists.txt
index 28bbd67c58e088866cbac5756653cd468be27fc2..bec71b04707bb9b44965c9ee9c69bb1f98352539 100644
--- a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/CMakeLists.txt
+++ b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/CMakeLists.txt
@@ -14,6 +14,8 @@ set(KLEIDIAI_PATH ../../)
 set(MATMUL_PACK_PATH ${KLEIDIAI_PATH}/kai/ukernels/matmul/pack/)
 set(MATMUL_PATH ${KLEIDIAI_PATH}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/)
 
+enable_language(ASM)
+
 # KleidiAI include directories
 include_directories(
     ${KLEIDIAI_PATH}
@@ -25,7 +27,9 @@ add_executable(matmul_clamp_f32_qai8dxp_qsi4c32p
     matmul_clamp_f32_qai8dxp_qsi4c32p.cpp
     ${KLEIDIAI_PATH}/kai/kai_common.h
     ${MATMUL_PACK_PATH}/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c
+    ${MATMUL_PACK_PATH}/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_cast_asm.S
     ${MATMUL_PACK_PATH}/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c
+    ${MATMUL_PACK_PATH}/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_cast_asm.S
     ${MATMUL_PACK_PATH}/kai_lhs_quant_pack_qai8dxp_f32.c
     ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c
     ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod_asm.S
diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel
index 97f122188e1eedf6051b7924af9043d893ff847b..d3999867633c7ce7d0ca2a5dc8cdd45e3ae2599c 100644
--- a/kai/ukernels/matmul/BUILD.bazel
+++ b/kai/ukernels/matmul/BUILD.bazel
@@ -42,6 +42,8 @@ NEON_KERNELS = [
 
 NEON_KERNELS_ASM = [
     "matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla",
+    "pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_cast",
+    "pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_cast",
 ]
 
 # buildifier: keep sorted
diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c
index 608de4105e0f6a65343da2a8289147a3af5c96b8..92c34e9a6a4f4b0505faefd2c6598754ee828539 100644
--- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c
+++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c
@@ -11,6 +11,8 @@
 
 #include "kai/kai_common.h"
 
+float kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_float_from_float16(uint16_t value);
+
 static const size_t kai_num_bytes_sum_rhs = sizeof(float);
 static const size_t kai_num_bytes_bias = sizeof(float);
 static const size_t kai_nr_multiple_of = 4;
@@ -204,7 +206,8 @@ void kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(
                         d = ((float*)rhs_packed_scale)[nr_idx];
                         break;
                     case kai_dt_f16:
-                        d = kai_cast_f32_f16(((uint16_t*)rhs_packed_scale)[nr_idx]);
+                        d = kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_float_from_float16(
+                            ((uint16_t*)rhs_packed_scale)[nr_idx]);
                         break;
                     case kai_dt_bf16:
                         d = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx]);
diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_cast_asm.S b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_cast_asm.S
new file mode 100644
index 0000000000000000000000000000000000000000..06a10a04355ce647de4a62fcd1295813d2d6355b
--- /dev/null
+++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_cast_asm.S
@@ -0,0 +1,40 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#if defined(_MSC_VER)
+#    define KAI_ASM_CODE(name) AREA name, CODE, READONLY
+#    define KAI_ASM_LABEL(label) label
+#    define KAI_ASM_LABEL_GLOBAL(label) label
+#    define KAI_ASM_GLOBAL(symbol) global symbol
+#    define KAI_ASM_END end
+#elif defined(__APPLE__)
+#    define KAI_ASM_CODE(name) .text
+#    define KAI_ASM_LABEL(label) _##label:
+#    define KAI_ASM_LABEL_GLOBAL(label) _##label:
+#    define KAI_ASM_GLOBAL(symbol) .global _##symbol
+#    define KAI_ASM_END
+#else
+#    define KAI_ASM_CODE(name) .text
+#    define KAI_ASM_LABEL(label) label:
+#    define KAI_ASM_LABEL_GLOBAL(label) label:
+#    define KAI_ASM_GLOBAL(symbol) .global symbol
+#    define KAI_ASM_END
+#endif
+
+    KAI_ASM_CODE(kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_float16)
+
+    KAI_ASM_GLOBAL(kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_float_from_float16)
+
+// float kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_float_from_float16(uint16_t value)
+//
+// Widens an IEEE 754 half-precision value to single precision. Under AAPCS64 the
+// raw f16 bits arrive in w0 and the float result is returned in s0.
+KAI_ASM_LABEL_GLOBAL(kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_float_from_float16)
+    fmov s0, w0  // Copy the raw f16 bits from the integer argument register into v0.
+    fcvt s0, h0  // Convert the half held in the low 16 bits of v0 to single precision.
+    ret
+
+    KAI_ASM_END
diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c
index 5a86b8a0ef3557c5e1ca13ce40bdf35ca20cbe07..ae59f6154885b0f73ef0508d4e909afee84a20d6 100644
--- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c
+++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c
@@ -11,6 +11,8 @@
 
 #include "kai/kai_common.h"
 
+float kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_float_from_float16(uint16_t value);
+
 static const size_t kai_num_bytes_sum_rhs = sizeof(float);
 static const size_t kai_num_bytes_bias = sizeof(float);
 static const size_t kai_nr_multiple_of = 4;
@@ -197,7 +199,8 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
                         d = ((float*)rhs_packed_scale)[nr_idx];
                         break;
                     case kai_dt_f16:
-                        d = kai_cast_f32_f16(((uint16_t*)rhs_packed_scale)[nr_idx]);
+                        d = kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_float_from_float16(
+                            ((uint16_t*)rhs_packed_scale)[nr_idx]);
                         break;
                     case kai_dt_bf16:
                         d = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx]);
diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_cast_asm.S b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_cast_asm.S
new file mode 100644
index 0000000000000000000000000000000000000000..dd4b68c1008ea92744935c86f521254283d0380c
--- /dev/null
+++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_cast_asm.S
@@ -0,0 +1,40 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#if defined(_MSC_VER)
+#    define KAI_ASM_CODE(name) AREA name, CODE, READONLY
+#    define KAI_ASM_LABEL(label) label
+#    define KAI_ASM_LABEL_GLOBAL(label) label
+#    define KAI_ASM_GLOBAL(symbol) global symbol
+#    define KAI_ASM_END end
+#elif defined(__APPLE__)
+#    define KAI_ASM_CODE(name) .text
+#    define KAI_ASM_LABEL(label) _##label:
+#    define KAI_ASM_LABEL_GLOBAL(label) _##label:
+#    define KAI_ASM_GLOBAL(symbol) .global _##symbol
+#    define KAI_ASM_END
+#else
+#    define KAI_ASM_CODE(name) .text
+#    define KAI_ASM_LABEL(label) label:
+#    define KAI_ASM_LABEL_GLOBAL(label) label:
+#    define KAI_ASM_GLOBAL(symbol) .global symbol
+#    define KAI_ASM_END
+#endif
+
+    KAI_ASM_CODE(kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_float16)
+
+    KAI_ASM_GLOBAL(kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_float_from_float16)
+
+// float kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_float_from_float16(uint16_t value)
+//
+// Widens an IEEE 754 half-precision value to single precision. Under AAPCS64 the
+// raw f16 bits arrive in w0 and the float result is returned in s0.
+KAI_ASM_LABEL_GLOBAL(kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_float_from_float16)
+    fmov s0, w0  // Copy the raw f16 bits from the integer argument register into v0.
+    fcvt s0, h0  // Convert the half held in the low 16 bits of v0 to single precision.
+    ret
+
+    KAI_ASM_END