diff --git a/CMakeLists.txt b/CMakeLists.txt
index 80adb858238e89478ba8db2821d3757ded311d1a..344638d5de3e4adbfd9129efc01933e6caab263a 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -82,7 +82,7 @@ set(KLEIDIAI_FILES_SCALAR
 set(KLEIDIAI_FILES_NEON_FP16
     kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.c
-    kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.c
+    kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla.c
 )
 set(KLEIDIAI_FILES_NEON_DOTPROD
diff --git a/examples/matmul_clamp_f16_f16_f16p/CMakeLists.txt b/examples/matmul_clamp_f16_f16_f16p/CMakeLists.txt
index 487f69d1a83fd8b17a82b72efb1523221bd03886..2d3da7018627d725aafc644ee93fb0e8cc1e3f01 100644
--- a/examples/matmul_clamp_f16_f16_f16p/CMakeLists.txt
+++ b/examples/matmul_clamp_f16_f16_f16p/CMakeLists.txt
@@ -20,7 +20,7 @@ include_directories(
 # Files requires to build the executable
 add_executable(matmul_clamp_f16_f16_f16p
     matmul_clamp_f16_f16_f16p.cpp
-    ${MATMUL_PATH}/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.c
+    ${MATMUL_PATH}/kai_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla.c
     ${MATMUL_PACK_PATH}/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.c
 )
diff --git a/examples/matmul_clamp_f16_f16_f16p/matmul_clamp_f16_f16_f16p.cpp b/examples/matmul_clamp_f16_f16_f16p/matmul_clamp_f16_f16_f16p.cpp
index 725d91e5bc130f70775dd7f183678914fa4033fa..1e1975dada1170aa70b633ede8f7141f5cd3da17 100644
--- a/examples/matmul_clamp_f16_f16_f16p/matmul_clamp_f16_f16_f16p.cpp
+++ b/examples/matmul_clamp_f16_f16_f16p/matmul_clamp_f16_f16_f16p.cpp
@@ -25,8 +25,8 @@
 #include
 // Include micro-kernel variants
-#include "kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.h"
 #include "kai_matmul_clamp_f16_f16_f16p_interface.h"
+#include "kai_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla.h"
 #include "kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.h"
 #define FLOAT16_MIN (-65504)
@@ -35,16 +35,16 @@ namespace {
 /// Micro-kernel interface
 constexpr kai_matmul_clamp_f16_f16_f16p_ukernel ukernel{
-    kai_get_m_step_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla,
-    kai_get_n_step_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla,
-    kai_get_nr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla,
-    kai_get_kr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla,
-    kai_get_sr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla,
-    kai_get_lhs_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla,
-    kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla,
-    kai_get_dst_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla,
-    kai_get_dst_size_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla,
-    kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla};
+    kai_get_m_step_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla,
+    kai_get_n_step_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla,
+    kai_get_nr_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla,
+    kai_get_kr_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla,
+    kai_get_sr_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla,
+    kai_get_lhs_offset_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla,
+    kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla,
+    kai_get_dst_offset_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla,
+    kai_get_dst_size_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla,
+    kai_run_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla};
 /// Reference
implementation of matrix multiplication void run_matmul_ref( @@ -209,7 +209,7 @@ int main() { const bool is_valid = is_output_correct(M, N, 0.0001, dst_ref, dst); std::cout << "TEST[matmul_clamp_f16_f16_f16p]\n"; - std::cout << "- ukernel: matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla\n"; + std::cout << "- ukernel: matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla\n"; if (is_valid) { std::cout << "- Status: PASSED\n"; std::cout << "- Performance: " << time_matmul.count() << "ns\n"; diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 013a96c30514c85ecdbead955b5cf85dcfada887..840d10e313ee1b8ceba2ceb060dc6f280057b6ed 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -25,8 +25,8 @@ kai_c_library( kai_c_library( name = "clamp_f16_f16_f16p", - srcs = ["matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.c"], - hdrs = ["matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.h"], + srcs = ["matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla.c"], + hdrs = ["matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla.h"], cpu_uarch = kai_cpu_fp16(), deps = [ ":clamp_f16_f16_f16p_interface", diff --git a/kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.c b/kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla.c similarity index 93% rename from kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.c rename to kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla.c index 9f9d5a753284e88bf9bf4ac714d31fc9888ca20c..840258adcfd771b2b2601983c1b7bc87996f8916 100644 --- a/kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.c +++ b/kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla.c @@ -9,7 +9,7 @@ #error This file must be compiled for AArch64, FEAT_FP16. #else // Architectural features check. 
-#include "kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.h" +#include "kai_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla.h" #include #include @@ -22,39 +22,39 @@ static const size_t kai_nr = 16; static const size_t kai_kr = 1; static const size_t kai_sr = 1; -size_t kai_get_m_step_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla(void) { +size_t kai_get_m_step_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla(void) { return kai_mr; } -size_t kai_get_n_step_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla(void) { +size_t kai_get_n_step_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla(void) { return kai_nr; } -size_t kai_get_nr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla(void) { +size_t kai_get_nr_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla(void) { return kai_nr; } -size_t kai_get_kr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla(void) { +size_t kai_get_kr_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla(void) { return kai_kr; } -size_t kai_get_sr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla(void) { +size_t kai_get_sr_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla(void) { return kai_sr; } -size_t kai_get_lhs_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla(size_t m_idx, size_t stride) { +size_t kai_get_lhs_offset_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla(size_t m_idx, size_t stride) { KAI_ASSUME(m_idx % kai_mr == 0); return m_idx * stride; } -size_t kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla(size_t n_idx, size_t k) { +size_t kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla(size_t n_idx, size_t k) { KAI_ASSUME(n_idx % kai_nr == 0); return n_idx / kai_nr * (kai_nr * sizeof(__fp16) + kai_nr * k * sizeof(__fp16)); } -size_t kai_get_dst_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( +size_t kai_get_dst_offset_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla( size_t m_idx, size_t n_idx, size_t stride) { KAI_ASSUME(m_idx % kai_mr == 0); KAI_ASSUME(n_idx % kai_nr == 0); @@ -62,11 +62,11 @@ size_t kai_get_dst_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( return m_idx * stride + n_idx * sizeof(__fp16); } -size_t kai_get_dst_size_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla(size_t m, size_t n) { +size_t kai_get_dst_size_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla(size_t m, size_t n) { return m * n * sizeof(__fp16); } -void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( +void kai_run_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla( size_t m, size_t n, size_t k, // const void* lhs, size_t lhs_stride, // const void* rhs_packed, // @@ -235,9 +235,9 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "blt 19f\n" "18:" // Height 1: Multiply loop: Main loop head "fmla v20.8h, v6.8h, v0.h[0]\n" - "ldr q6, [x10, #0xe0]\n" + "ldr q23, [x10, #0xe0]\n" "fmla v21.8h, v7.8h, v0.h[0]\n" - "ldr q7, [x10, #0xf0]\n" + "ldr q22, [x10, #0xf0]\n" "sub x27, x27, #0x8\n" "add x26, x26, #0x10\n" "cmp x27, #0x10\n" @@ -267,17 +267,17 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "ldr q18, [x10, #0xc0]\n" "fmla v21.8h, v19.8h, v0.h[6]\n" "ldr q19, [x10, #0xd0]\n" - "fmla v20.8h, v6.8h, v0.h[7]\n" + "fmla v20.8h, v23.8h, v0.h[7]\n" "ldr q6, [x10, #0x0]\n" - "fmla v21.8h, v7.8h, v0.h[7]\n" + "fmla v21.8h, v22.8h, v0.h[7]\n" "ldr q0, [x26, #0x0]\n" "ldr q7, [x10, #0x10]\n" "bge 18b\n" "19:" // Height 1: Multiply loop: Single iteration only "fmla v20.8h, v6.8h, v0.h[0]\n" - 
"ldr q6, [x10, #0xe0]\n" + "ldr q23, [x10, #0xe0]\n" "fmla v21.8h, v7.8h, v0.h[0]\n" - "ldr q7, [x10, #0xf0]\n" + "ldr q22, [x10, #0xf0]\n" "add x26, x26, #0x10\n" "sub x27, x27, #0x8\n" "add x10, x10, #0x100\n" @@ -294,18 +294,18 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "fmla v21.8h, v17.8h, v0.h[5]\n" "fmla v20.8h, v18.8h, v0.h[6]\n" "fmla v21.8h, v19.8h, v0.h[6]\n" - "fmla v20.8h, v6.8h, v0.h[7]\n" - "fmla v21.8h, v7.8h, v0.h[7]\n" + "fmla v20.8h, v23.8h, v0.h[7]\n" + "fmla v21.8h, v22.8h, v0.h[7]\n" "20:" // Height 1: Multiply loop: Main loop skip "cbz x27, 22f\n" "21:" // Height 1: Multiply loop: Odd block loop "ldr h0, [x26], #0x2\n" - "ldr q8, [x10, #0x0]\n" + "ldr q17, [x10, #0x0]\n" "sub x27, x27, #0x1\n" - "ldr q9, [x10, #0x10]\n" + "ldr q16, [x10, #0x10]\n" "add x10, x10, #0x20\n" - "fmla v20.8h, v8.8h, v0.h[0]\n" - "fmla v21.8h, v9.8h, v0.h[0]\n" + "fmla v20.8h, v17.8h, v0.h[0]\n" + "fmla v21.8h, v16.8h, v0.h[0]\n" "cbnz x27, 21b\n" "22:" // Height 1: Multiply loop: No odd multiplies "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n" @@ -519,11 +519,11 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "51:" // Height 2: Multiply loop: Main loop head "fmla v20.8h, v6.8h, v0.h[0]\n" "fmla v22.8h, v6.8h, v1.h[0]\n" - "ldr q6, [x10, #0xe0]\n" + "ldr q25, [x10, #0xe0]\n" "sub x27, x27, #0x8\n" "fmla v21.8h, v7.8h, v0.h[0]\n" "fmla v23.8h, v7.8h, v1.h[0]\n" - "ldr q7, [x10, #0xf0]\n" + "ldr q24, [x10, #0xf0]\n" "add x26, x26, #0x10\n" "add x25, x25, #0x10\n" "cmp x27, #0x10\n" @@ -566,23 +566,23 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "fmla v21.8h, v19.8h, v0.h[6]\n" "fmla v23.8h, v19.8h, v1.h[6]\n" "ldr q19, [x10, #0xd0]\n" - "fmla v20.8h, v6.8h, v0.h[7]\n" - "fmla v22.8h, v6.8h, v1.h[7]\n" + "fmla v20.8h, v25.8h, v0.h[7]\n" + "fmla v22.8h, v25.8h, v1.h[7]\n" "ldr q6, [x10, #0x0]\n" - "fmla v21.8h, v7.8h, v0.h[7]\n" + "fmla v21.8h, v24.8h, v0.h[7]\n" "ldr q0, [x26, #0x0]\n" - "fmla v23.8h, v7.8h, v1.h[7]\n" + "fmla v23.8h, v24.8h, v1.h[7]\n" "ldr q1, [x25, #0x0]\n" "ldr q7, [x10, #0x10]\n" "bge 51b\n" "52:" // Height 2: Multiply loop: Single iteration only "fmla v20.8h, v6.8h, v0.h[0]\n" "fmla v22.8h, v6.8h, v1.h[0]\n" - "ldr q6, [x10, #0xe0]\n" + "ldr q25, [x10, #0xe0]\n" "add x26, x26, #0x10\n" "fmla v21.8h, v7.8h, v0.h[0]\n" "fmla v23.8h, v7.8h, v1.h[0]\n" - "ldr q7, [x10, #0xf0]\n" + "ldr q24, [x10, #0xf0]\n" "add x25, x25, #0x10\n" "sub x27, x27, #0x8\n" "prfm pldl1keep, [x26, #0x80]\n" @@ -612,23 +612,23 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "fmla v22.8h, v18.8h, v1.h[6]\n" "fmla v21.8h, v19.8h, v0.h[6]\n" "fmla v23.8h, v19.8h, v1.h[6]\n" - "fmla v20.8h, v6.8h, v0.h[7]\n" - "fmla v22.8h, v6.8h, v1.h[7]\n" - "fmla v21.8h, v7.8h, v0.h[7]\n" - "fmla v23.8h, v7.8h, v1.h[7]\n" + "fmla v20.8h, v25.8h, v0.h[7]\n" + "fmla v22.8h, v25.8h, v1.h[7]\n" + "fmla v21.8h, v24.8h, v0.h[7]\n" + "fmla v23.8h, v24.8h, v1.h[7]\n" "53:" // Height 2: Multiply loop: Main loop skip "cbz x27, 55f\n" "54:" // Height 2: Multiply loop: Odd block loop - "ldr h0, [x26], #0x2\n" - "ldr h1, [x25], #0x2\n" + "ldr h1, [x26], #0x2\n" + "ldr h0, [x25], #0x2\n" "sub x27, x27, #0x1\n" - "ldr q8, [x10, #0x0]\n" - "ldr q9, [x10, #0x10]\n" + "ldr q17, [x10, #0x0]\n" + "ldr q16, [x10, #0x10]\n" "add x10, x10, #0x20\n" - "fmla v20.8h, v8.8h, v0.h[0]\n" - "fmla v22.8h, v8.8h, v1.h[0]\n" - "fmla v21.8h, v9.8h, v0.h[0]\n" - "fmla v23.8h, v9.8h, v1.h[0]\n" + "fmla v20.8h, v17.8h, v1.h[0]\n" + "fmla v22.8h, 
v17.8h, v0.h[0]\n" + "fmla v21.8h, v16.8h, v1.h[0]\n" + "fmla v23.8h, v16.8h, v0.h[0]\n" "cbnz x27, 54b\n" "55:" // Height 2: Multiply loop: No odd multiplies "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n" @@ -895,12 +895,12 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "sub x27, x27, #0x8\n" "add x26, x26, #0x10\n" "fmla v24.8h, v6.8h, v2.h[0]\n" - "ldr q6, [x10, #0xe0]\n" + "ldr q27, [x10, #0xe0]\n" "fmla v21.8h, v7.8h, v0.h[0]\n" "add x25, x25, #0x10\n" "fmla v23.8h, v7.8h, v1.h[0]\n" "fmla v25.8h, v7.8h, v2.h[0]\n" - "ldr q7, [x10, #0xf0]\n" + "ldr q26, [x10, #0xf0]\n" "add x24, x24, #0x10\n" "cmp x27, #0x10\n" "add x10, x10, #0x100\n" @@ -955,15 +955,15 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "fmla v23.8h, v19.8h, v1.h[6]\n" "fmla v25.8h, v19.8h, v2.h[6]\n" "ldr q19, [x10, #0xd0]\n" - "fmla v20.8h, v6.8h, v0.h[7]\n" - "fmla v22.8h, v6.8h, v1.h[7]\n" - "fmla v24.8h, v6.8h, v2.h[7]\n" + "fmla v20.8h, v27.8h, v0.h[7]\n" + "fmla v22.8h, v27.8h, v1.h[7]\n" + "fmla v24.8h, v27.8h, v2.h[7]\n" "ldr q6, [x10, #0x0]\n" - "fmla v21.8h, v7.8h, v0.h[7]\n" + "fmla v21.8h, v26.8h, v0.h[7]\n" "ldr q0, [x26, #0x0]\n" - "fmla v23.8h, v7.8h, v1.h[7]\n" + "fmla v23.8h, v26.8h, v1.h[7]\n" "ldr q1, [x25, #0x0]\n" - "fmla v25.8h, v7.8h, v2.h[7]\n" + "fmla v25.8h, v26.8h, v2.h[7]\n" "ldr q2, [x24, #0x0]\n" "ldr q7, [x10, #0x10]\n" "bge 84b\n" @@ -973,12 +973,12 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "add x26, x26, #0x10\n" "add x25, x25, #0x10\n" "fmla v24.8h, v6.8h, v2.h[0]\n" - "ldr q6, [x10, #0xe0]\n" + "ldr q27, [x10, #0xe0]\n" "fmla v21.8h, v7.8h, v0.h[0]\n" "add x24, x24, #0x10\n" "fmla v23.8h, v7.8h, v1.h[0]\n" "fmla v25.8h, v7.8h, v2.h[0]\n" - "ldr q7, [x10, #0xf0]\n" + "ldr q26, [x10, #0xf0]\n" "prfm pldl1keep, [x26, #0x80]\n" "sub x27, x27, #0x8\n" "prfm pldl1keep, [x25, #0x80]\n" @@ -1020,28 +1020,28 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "fmla v21.8h, v19.8h, v0.h[6]\n" "fmla v23.8h, v19.8h, v1.h[6]\n" "fmla v25.8h, v19.8h, v2.h[6]\n" - "fmla v20.8h, v6.8h, v0.h[7]\n" - "fmla v22.8h, v6.8h, v1.h[7]\n" - "fmla v24.8h, v6.8h, v2.h[7]\n" - "fmla v21.8h, v7.8h, v0.h[7]\n" - "fmla v23.8h, v7.8h, v1.h[7]\n" - "fmla v25.8h, v7.8h, v2.h[7]\n" + "fmla v20.8h, v27.8h, v0.h[7]\n" + "fmla v22.8h, v27.8h, v1.h[7]\n" + "fmla v24.8h, v27.8h, v2.h[7]\n" + "fmla v21.8h, v26.8h, v0.h[7]\n" + "fmla v23.8h, v26.8h, v1.h[7]\n" + "fmla v25.8h, v26.8h, v2.h[7]\n" "86:" // Height 3: Multiply loop: Main loop skip "cbz x27, 88f\n" "87:" // Height 3: Multiply loop: Odd block loop - "ldr h0, [x26], #0x2\n" + "ldr h2, [x26], #0x2\n" "ldr h1, [x25], #0x2\n" "sub x27, x27, #0x1\n" - "ldr h2, [x24], #0x2\n" - "ldr q8, [x10, #0x0]\n" - "ldr q9, [x10, #0x10]\n" + "ldr h0, [x24], #0x2\n" + "ldr q17, [x10, #0x0]\n" + "ldr q16, [x10, #0x10]\n" "add x10, x10, #0x20\n" - "fmla v20.8h, v8.8h, v0.h[0]\n" - "fmla v22.8h, v8.8h, v1.h[0]\n" - "fmla v24.8h, v8.8h, v2.h[0]\n" - "fmla v21.8h, v9.8h, v0.h[0]\n" - "fmla v23.8h, v9.8h, v1.h[0]\n" - "fmla v25.8h, v9.8h, v2.h[0]\n" + "fmla v20.8h, v17.8h, v2.h[0]\n" + "fmla v22.8h, v17.8h, v1.h[0]\n" + "fmla v24.8h, v17.8h, v0.h[0]\n" + "fmla v21.8h, v16.8h, v2.h[0]\n" + "fmla v23.8h, v16.8h, v1.h[0]\n" + "fmla v25.8h, v16.8h, v0.h[0]\n" "cbnz x27, 87b\n" "88:" // Height 3: Multiply loop: No odd multiplies "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n" @@ -1358,7 +1358,7 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "add x26, x26, #0x10\n" "fmla 
v24.8h, v6.8h, v2.h[0]\n" "fmla v26.8h, v6.8h, v3.h[0]\n" - "ldr q6, [x10, #0xe0]\n" + "ldr q29, [x10, #0xe0]\n" "add x25, x25, #0x10\n" "fmla v21.8h, v7.8h, v0.h[0]\n" "fmla v23.8h, v7.8h, v1.h[0]\n" @@ -1366,7 +1366,7 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "add x23, x23, #0x10\n" "fmla v25.8h, v7.8h, v2.h[0]\n" "fmla v27.8h, v7.8h, v3.h[0]\n" - "ldr q7, [x10, #0xf0]\n" + "ldr q28, [x10, #0xf0]\n" "cmp x27, #0x10\n" "fmla v20.8h, v8.8h, v0.h[1]\n" "fmla v22.8h, v8.8h, v1.h[1]\n" @@ -1433,18 +1433,18 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "fmla v25.8h, v19.8h, v2.h[6]\n" "fmla v27.8h, v19.8h, v3.h[6]\n" "ldr q19, [x10, #0xd0]\n" - "fmla v20.8h, v6.8h, v0.h[7]\n" - "fmla v22.8h, v6.8h, v1.h[7]\n" - "fmla v24.8h, v6.8h, v2.h[7]\n" - "fmla v26.8h, v6.8h, v3.h[7]\n" + "fmla v20.8h, v29.8h, v0.h[7]\n" + "fmla v22.8h, v29.8h, v1.h[7]\n" + "fmla v24.8h, v29.8h, v2.h[7]\n" + "fmla v26.8h, v29.8h, v3.h[7]\n" "ldr q6, [x10, #0x0]\n" - "fmla v21.8h, v7.8h, v0.h[7]\n" + "fmla v21.8h, v28.8h, v0.h[7]\n" "ldr q0, [x26, #0x0]\n" - "fmla v23.8h, v7.8h, v1.h[7]\n" + "fmla v23.8h, v28.8h, v1.h[7]\n" "ldr q1, [x25, #0x0]\n" - "fmla v25.8h, v7.8h, v2.h[7]\n" + "fmla v25.8h, v28.8h, v2.h[7]\n" "ldr q2, [x24, #0x0]\n" - "fmla v27.8h, v7.8h, v3.h[7]\n" + "fmla v27.8h, v28.8h, v3.h[7]\n" "ldr q3, [x23, #0x0]\n" "ldr q7, [x10, #0x10]\n" "bge 117b\n" @@ -1455,7 +1455,7 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "add x25, x25, #0x10\n" "fmla v24.8h, v6.8h, v2.h[0]\n" "fmla v26.8h, v6.8h, v3.h[0]\n" - "ldr q6, [x10, #0xe0]\n" + "ldr q29, [x10, #0xe0]\n" "add x24, x24, #0x10\n" "fmla v21.8h, v7.8h, v0.h[0]\n" "fmla v23.8h, v7.8h, v1.h[0]\n" @@ -1463,7 +1463,7 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "sub x27, x27, #0x8\n" "fmla v25.8h, v7.8h, v2.h[0]\n" "fmla v27.8h, v7.8h, v3.h[0]\n" - "ldr q7, [x10, #0xf0]\n" + "ldr q28, [x10, #0xf0]\n" "prfm pldl1keep, [x26, #0x80]\n" "fmla v20.8h, v8.8h, v0.h[1]\n" "fmla v22.8h, v8.8h, v1.h[1]\n" @@ -1517,33 +1517,33 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "fmla v23.8h, v19.8h, v1.h[6]\n" "fmla v25.8h, v19.8h, v2.h[6]\n" "fmla v27.8h, v19.8h, v3.h[6]\n" - "fmla v20.8h, v6.8h, v0.h[7]\n" - "fmla v22.8h, v6.8h, v1.h[7]\n" - "fmla v24.8h, v6.8h, v2.h[7]\n" - "fmla v26.8h, v6.8h, v3.h[7]\n" - "fmla v21.8h, v7.8h, v0.h[7]\n" - "fmla v23.8h, v7.8h, v1.h[7]\n" - "fmla v25.8h, v7.8h, v2.h[7]\n" - "fmla v27.8h, v7.8h, v3.h[7]\n" + "fmla v20.8h, v29.8h, v0.h[7]\n" + "fmla v22.8h, v29.8h, v1.h[7]\n" + "fmla v24.8h, v29.8h, v2.h[7]\n" + "fmla v26.8h, v29.8h, v3.h[7]\n" + "fmla v21.8h, v28.8h, v0.h[7]\n" + "fmla v23.8h, v28.8h, v1.h[7]\n" + "fmla v25.8h, v28.8h, v2.h[7]\n" + "fmla v27.8h, v28.8h, v3.h[7]\n" "119:" // Height 4: Multiply loop: Main loop skip "cbz x27, 121f\n" "120:" // Height 4: Multiply loop: Odd block loop - "ldr h0, [x26], #0x2\n" - "ldr h1, [x25], #0x2\n" + "ldr h3, [x26], #0x2\n" + "ldr h2, [x25], #0x2\n" "sub x27, x27, #0x1\n" - "ldr h2, [x24], #0x2\n" - "ldr h3, [x23], #0x2\n" - "ldr q8, [x10, #0x0]\n" - "ldr q9, [x10, #0x10]\n" + "ldr h1, [x24], #0x2\n" + "ldr h0, [x23], #0x2\n" + "ldr q17, [x10, #0x0]\n" + "ldr q16, [x10, #0x10]\n" "add x10, x10, #0x20\n" - "fmla v20.8h, v8.8h, v0.h[0]\n" - "fmla v22.8h, v8.8h, v1.h[0]\n" - "fmla v24.8h, v8.8h, v2.h[0]\n" - "fmla v26.8h, v8.8h, v3.h[0]\n" - "fmla v21.8h, v9.8h, v0.h[0]\n" - "fmla v23.8h, v9.8h, v1.h[0]\n" - "fmla v25.8h, v9.8h, v2.h[0]\n" - "fmla v27.8h, v9.8h, v3.h[0]\n" + 
"fmla v20.8h, v17.8h, v3.h[0]\n" + "fmla v22.8h, v17.8h, v2.h[0]\n" + "fmla v24.8h, v17.8h, v1.h[0]\n" + "fmla v26.8h, v17.8h, v0.h[0]\n" + "fmla v21.8h, v16.8h, v3.h[0]\n" + "fmla v23.8h, v16.8h, v2.h[0]\n" + "fmla v25.8h, v16.8h, v1.h[0]\n" + "fmla v27.8h, v16.8h, v0.h[0]\n" "cbnz x27, 120b\n" "121:" // Height 4: Multiply loop: No odd multiplies "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n" @@ -1912,7 +1912,7 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "add x25, x25, #0x10\n" "add x24, x24, #0x10\n" "fmla v28.8h, v6.8h, v4.h[0]\n" - "ldr q6, [x10, #0xe0]\n" + "ldr q31, [x10, #0xe0]\n" "fmla v21.8h, v7.8h, v0.h[0]\n" "add x23, x23, #0x10\n" "fmla v23.8h, v7.8h, v1.h[0]\n" @@ -1921,7 +1921,7 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "cmp x27, #0x10\n" "fmla v27.8h, v7.8h, v3.h[0]\n" "fmla v29.8h, v7.8h, v4.h[0]\n" - "ldr q7, [x10, #0xf0]\n" + "ldr q30, [x10, #0xf0]\n" "add x10, x10, #0x100\n" "fmla v20.8h, v8.8h, v0.h[1]\n" "fmla v22.8h, v8.8h, v1.h[1]\n" @@ -2000,21 +2000,21 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "fmla v27.8h, v19.8h, v3.h[6]\n" "fmla v29.8h, v19.8h, v4.h[6]\n" "ldr q19, [x10, #0xd0]\n" - "fmla v20.8h, v6.8h, v0.h[7]\n" - "fmla v22.8h, v6.8h, v1.h[7]\n" - "fmla v24.8h, v6.8h, v2.h[7]\n" - "fmla v26.8h, v6.8h, v3.h[7]\n" - "fmla v28.8h, v6.8h, v4.h[7]\n" + "fmla v20.8h, v31.8h, v0.h[7]\n" + "fmla v22.8h, v31.8h, v1.h[7]\n" + "fmla v24.8h, v31.8h, v2.h[7]\n" + "fmla v26.8h, v31.8h, v3.h[7]\n" + "fmla v28.8h, v31.8h, v4.h[7]\n" "ldr q6, [x10, #0x0]\n" - "fmla v21.8h, v7.8h, v0.h[7]\n" + "fmla v21.8h, v30.8h, v0.h[7]\n" "ldr q0, [x26, #0x0]\n" - "fmla v23.8h, v7.8h, v1.h[7]\n" + "fmla v23.8h, v30.8h, v1.h[7]\n" "ldr q1, [x25, #0x0]\n" - "fmla v25.8h, v7.8h, v2.h[7]\n" + "fmla v25.8h, v30.8h, v2.h[7]\n" "ldr q2, [x24, #0x0]\n" - "fmla v27.8h, v7.8h, v3.h[7]\n" + "fmla v27.8h, v30.8h, v3.h[7]\n" "ldr q3, [x23, #0x0]\n" - "fmla v29.8h, v7.8h, v4.h[7]\n" + "fmla v29.8h, v30.8h, v4.h[7]\n" "ldr q4, [x22, #0x0]\n" "ldr q7, [x10, #0x10]\n" "bge 150b\n" @@ -2028,7 +2028,7 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "add x24, x24, #0x10\n" "add x23, x23, #0x10\n" "fmla v28.8h, v6.8h, v4.h[0]\n" - "ldr q6, [x10, #0xe0]\n" + "ldr q31, [x10, #0xe0]\n" "fmla v21.8h, v7.8h, v0.h[0]\n" "add x22, x22, #0x10\n" "fmla v23.8h, v7.8h, v1.h[0]\n" @@ -2037,7 +2037,7 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "prfm pldl1keep, [x26, #0x80]\n" "fmla v27.8h, v7.8h, v3.h[0]\n" "fmla v29.8h, v7.8h, v4.h[0]\n" - "ldr q7, [x10, #0xf0]\n" + "ldr q30, [x10, #0xf0]\n" "prfm pldl1keep, [x25, #0x80]\n" "fmla v20.8h, v8.8h, v0.h[1]\n" "fmla v22.8h, v8.8h, v1.h[1]\n" @@ -2103,38 +2103,38 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "fmla v25.8h, v19.8h, v2.h[6]\n" "fmla v27.8h, v19.8h, v3.h[6]\n" "fmla v29.8h, v19.8h, v4.h[6]\n" - "fmla v20.8h, v6.8h, v0.h[7]\n" - "fmla v22.8h, v6.8h, v1.h[7]\n" - "fmla v24.8h, v6.8h, v2.h[7]\n" - "fmla v26.8h, v6.8h, v3.h[7]\n" - "fmla v28.8h, v6.8h, v4.h[7]\n" - "fmla v21.8h, v7.8h, v0.h[7]\n" - "fmla v23.8h, v7.8h, v1.h[7]\n" - "fmla v25.8h, v7.8h, v2.h[7]\n" - "fmla v27.8h, v7.8h, v3.h[7]\n" - "fmla v29.8h, v7.8h, v4.h[7]\n" + "fmla v20.8h, v31.8h, v0.h[7]\n" + "fmla v22.8h, v31.8h, v1.h[7]\n" + "fmla v24.8h, v31.8h, v2.h[7]\n" + "fmla v26.8h, v31.8h, v3.h[7]\n" + "fmla v28.8h, v31.8h, v4.h[7]\n" + "fmla v21.8h, v30.8h, v0.h[7]\n" + "fmla v23.8h, v30.8h, v1.h[7]\n" + "fmla v25.8h, v30.8h, v2.h[7]\n" + "fmla 
v27.8h, v30.8h, v3.h[7]\n" + "fmla v29.8h, v30.8h, v4.h[7]\n" "152:" // Height 5: Multiply loop: Main loop skip "cbz x27, 154f\n" "153:" // Height 5: Multiply loop: Odd block loop - "ldr h0, [x26], #0x2\n" - "ldr h1, [x25], #0x2\n" + "ldr h4, [x26], #0x2\n" + "ldr h3, [x25], #0x2\n" "sub x27, x27, #0x1\n" "ldr h2, [x24], #0x2\n" - "ldr h3, [x23], #0x2\n" - "ldr h4, [x22], #0x2\n" - "ldr q8, [x10, #0x0]\n" - "ldr q9, [x10, #0x10]\n" + "ldr h1, [x23], #0x2\n" + "ldr h0, [x22], #0x2\n" + "ldr q17, [x10, #0x0]\n" + "ldr q16, [x10, #0x10]\n" "add x10, x10, #0x20\n" - "fmla v20.8h, v8.8h, v0.h[0]\n" - "fmla v22.8h, v8.8h, v1.h[0]\n" - "fmla v24.8h, v8.8h, v2.h[0]\n" - "fmla v26.8h, v8.8h, v3.h[0]\n" - "fmla v28.8h, v8.8h, v4.h[0]\n" - "fmla v21.8h, v9.8h, v0.h[0]\n" - "fmla v23.8h, v9.8h, v1.h[0]\n" - "fmla v25.8h, v9.8h, v2.h[0]\n" - "fmla v27.8h, v9.8h, v3.h[0]\n" - "fmla v29.8h, v9.8h, v4.h[0]\n" + "fmla v20.8h, v17.8h, v4.h[0]\n" + "fmla v22.8h, v17.8h, v3.h[0]\n" + "fmla v24.8h, v17.8h, v2.h[0]\n" + "fmla v26.8h, v17.8h, v1.h[0]\n" + "fmla v28.8h, v17.8h, v0.h[0]\n" + "fmla v21.8h, v16.8h, v4.h[0]\n" + "fmla v23.8h, v16.8h, v3.h[0]\n" + "fmla v25.8h, v16.8h, v2.h[0]\n" + "fmla v27.8h, v16.8h, v1.h[0]\n" + "fmla v29.8h, v16.8h, v0.h[0]\n" "cbnz x27, 153b\n" "154:" // Height 5: Multiply loop: No odd multiplies "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n" @@ -2797,28 +2797,28 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "185:" // Height 6: Multiply loop: Main loop skip "cbz x27, 187f\n" "186:" // Height 6: Multiply loop: Odd block loop - "ldr h0, [x26], #0x2\n" - "ldr h1, [x25], #0x2\n" + "ldr h5, [x26], #0x2\n" + "ldr h4, [x25], #0x2\n" "sub x27, x27, #0x1\n" - "ldr h2, [x24], #0x2\n" - "ldr h3, [x23], #0x2\n" - "ldr h4, [x22], #0x2\n" - "ldr h5, [x21], #0x2\n" - "ldr q8, [x10, #0x0]\n" - "ldr q9, [x10, #0x10]\n" + "ldr h3, [x24], #0x2\n" + "ldr h2, [x23], #0x2\n" + "ldr h1, [x22], #0x2\n" + "ldr h0, [x21], #0x2\n" + "ldr q17, [x10, #0x0]\n" + "ldr q16, [x10, #0x10]\n" "add x10, x10, #0x20\n" - "fmla v20.8h, v8.8h, v0.h[0]\n" - "fmla v22.8h, v8.8h, v1.h[0]\n" - "fmla v24.8h, v8.8h, v2.h[0]\n" - "fmla v26.8h, v8.8h, v3.h[0]\n" - "fmla v28.8h, v8.8h, v4.h[0]\n" - "fmla v30.8h, v8.8h, v5.h[0]\n" - "fmla v21.8h, v9.8h, v0.h[0]\n" - "fmla v23.8h, v9.8h, v1.h[0]\n" - "fmla v25.8h, v9.8h, v2.h[0]\n" - "fmla v27.8h, v9.8h, v3.h[0]\n" - "fmla v29.8h, v9.8h, v4.h[0]\n" - "fmla v31.8h, v9.8h, v5.h[0]\n" + "fmla v20.8h, v17.8h, v5.h[0]\n" + "fmla v22.8h, v17.8h, v4.h[0]\n" + "fmla v24.8h, v17.8h, v3.h[0]\n" + "fmla v26.8h, v17.8h, v2.h[0]\n" + "fmla v28.8h, v17.8h, v1.h[0]\n" + "fmla v30.8h, v17.8h, v0.h[0]\n" + "fmla v21.8h, v16.8h, v5.h[0]\n" + "fmla v23.8h, v16.8h, v4.h[0]\n" + "fmla v25.8h, v16.8h, v3.h[0]\n" + "fmla v27.8h, v16.8h, v2.h[0]\n" + "fmla v29.8h, v16.8h, v1.h[0]\n" + "fmla v31.8h, v16.8h, v0.h[0]\n" "cbnz x27, 186b\n" "187:" // Height 6: Multiply loop: No odd multiplies "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n" diff --git a/kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.h b/kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla.h similarity index 78% rename from kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.h rename to kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla.h index 
a70de290a102270350902e67bbc42a3c786430a6..b61c2c383b65b449804799f676e362c35294b975 100644 --- a/kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.h +++ b/kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla.h @@ -28,35 +28,35 @@ extern "C" { /// The starting row index must be divisible by `m_step`. /// /// @return The m step value. -size_t kai_get_m_step_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla(void); +size_t kai_get_m_step_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla(void); /// Gets n step value. /// /// The starting column index must be divisible by `n_step`. /// /// @return The n step value. -size_t kai_get_n_step_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla(void); +size_t kai_get_n_step_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla(void); /// Gets nr value. /// /// This is the packing parameter which must be used to pack the RHS matrix. /// /// @return The nr value. -size_t kai_get_nr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla(void); +size_t kai_get_nr_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla(void); /// Gets kr value. /// /// This is the packing parameter which must be used to pack the RHS matrix. /// /// @return The kr value. -size_t kai_get_kr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla(void); +size_t kai_get_kr_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla(void); /// Gets sr value. /// /// This is the packing parameter which must be used to pack the RHS matrix. /// /// @return The sr value. -size_t kai_get_sr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla(void); +size_t kai_get_sr_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla(void); /// Gets the offset in bytes to the data element in the LHS matrix buffer. /// @@ -64,7 +64,7 @@ size_t kai_get_sr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla(void); /// @param[in] stride Row stride in bytes. /// /// @return The offset in bytes to the data element. -size_t kai_get_lhs_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla(size_t m_idx, size_t stride); +size_t kai_get_lhs_offset_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla(size_t m_idx, size_t stride); /// Gets the offset in bytes to the data element in the packed RHS matrix buffer. /// @@ -72,7 +72,7 @@ size_t kai_get_lhs_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla(s /// @param[in] k Number of columns. /// /// @return The offset in bytes to the data element. -size_t kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla(size_t n_idx, size_t k); +size_t kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla(size_t n_idx, size_t k); /// Gets the offset in bytes to the data element in the destination matrix buffer. /// @@ -81,7 +81,7 @@ size_t kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neo /// @param[in] stride Row stride in bytes. /// /// @return The offset in bytes to the data element. -size_t kai_get_dst_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( +size_t kai_get_dst_offset_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla( size_t m_idx, size_t n_idx, size_t stride); /// Gets the size in bytes of the destination matrix buffer. @@ -90,16 +90,16 @@ size_t kai_get_dst_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( /// @param[in] n Number of columns. /// /// @return The size in bytes of the destination matrix buffer. 
-size_t kai_get_dst_size_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla(size_t m, size_t n); +size_t kai_get_dst_size_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla(size_t m, size_t n); /// Runs the matrix multiplication microkernel followed by a clamp operation. /// /// The pointer of each buffers (LHS, packed RHS and output) needs to be added with offset /// calculated using the following functions: /// -/// * LHS: @ref kai_get_lhs_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla. -/// * Packed RHS: @ref kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla. -/// * Output: @ref kai_get_dst_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla. +/// * LHS: @ref kai_get_lhs_offset_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla. +/// * Packed RHS: @ref kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla. +/// * Output: @ref kai_get_dst_offset_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla. /// /// @param[in] m Number of output rows to be computed. /// @param[in] n Number of output columns to be computed. @@ -112,7 +112,7 @@ size_t kai_get_dst_size_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla(siz /// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(__fp16) /// @param[in] clamp_min Minimum value to clamp the final result. /// @param[in] clamp_max Maximum value to clamp the final result. -void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( +void kai_run_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla( size_t m, size_t n, size_t k, // const void* lhs, size_t lhs_stride, // const void* rhs_packed, // diff --git a/test/tests/matmul_test.cpp b/test/tests/matmul_test.cpp index ac534645bae2bfecc68839bb2c6a30c36c64390f..216c8a1e69ef6cd84a3511a1d9698bbe5cdab68f 100644 --- a/test/tests/matmul_test.cpp +++ b/test/tests/matmul_test.cpp @@ -33,7 +33,7 @@ #include "test/reference/pack.hpp" // matmul_nt_nt_fp16_fp16_fp16_6x16_neon_mla -#include "kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.h" +#include "kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.h" // matmul_nt_nt_fp32_fp32_fp32_2vlx2vl_sme2_mopa @@ -340,15 +340,15 @@ static const std::array matmul_methods = { .bias_format = DataFormat(DataType::FP16), .fn_get_mr = nullptr, - .fn_get_nr = kai_get_nr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, - .fn_get_kr = kai_get_kr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, - .fn_get_sr = kai_get_sr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, + .fn_get_nr = kai_get_nr_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla, + .fn_get_kr = kai_get_kr_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla, + .fn_get_sr = kai_get_sr_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla, - .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, + .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla, .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon, - .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, + .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla, - .fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, + 
.fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla, .fn_get_packed_lhs_size = nullptr, .fn_get_packed_lhs_offset = nullptr, .fn_pack_lhs = nullptr, @@ -356,15 +356,15 @@ static const std::array matmul_methods = { .fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon, .fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon, .fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon, - .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, + .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla, .fn_pack_rhs = kai_run_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon, .fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon, - .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, - .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, + .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla, + .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla, - .fn_matmul_f16_f16_f16p = kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, + .fn_matmul_f16_f16_f16p = kai_run_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla, .fn_matmul_f32_f32p_f32p = nullptr, },
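Usage note (not part of the patch): the rename changes only the symbol names, not the calling pattern of the micro-kernel. The sketch below shows, under stated assumptions, how the renamed entry points are meant to be driven: query the m/n step sizes, compute per-tile byte offsets with the kai_get_*_offset helpers, and call the run function on each tile. The wrapper name run_fp16_matmul_tiles is made up for illustration, the trailing arguments of the run call (destination pointer, row and column strides in bytes, clamp bounds) follow the parameter list documented in the renamed header rather than a signature visible in this diff, and __fp16 assumes an AArch64 toolchain with FEAT_FP16 support.

// Hypothetical usage sketch for the renamed micro-kernel (illustration only).
// Assumes: AArch64 with __fp16, RHS already packed by
// kai_run_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon, and the trailing run
// parameters ordered as documented in the renamed header. The destination
// buffer should be at least kai_get_dst_size_...(m, n) bytes.
#include <cstddef>
#include <cstdint>

#include "kai_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla.h"

void run_fp16_matmul_tiles(
    size_t m, size_t n, size_t k,          // problem size
    const __fp16* lhs, size_t lhs_stride,  // LHS row stride in bytes
    const void* rhs_packed,                // output of the RHS packing kernel
    __fp16* dst, size_t dst_stride_row) {  // DST row stride in bytes
    const size_t m_step = kai_get_m_step_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla();
    const size_t n_step = kai_get_n_step_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla();

    for (size_t m_idx = 0; m_idx < m; m_idx += m_step) {
        for (size_t n_idx = 0; n_idx < n; n_idx += n_step) {
            // Per-tile byte offsets, computed with the helpers declared in the header.
            const size_t lhs_offset =
                kai_get_lhs_offset_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla(m_idx, lhs_stride);
            const size_t rhs_offset =
                kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla(n_idx, k);
            const size_t dst_offset =
                kai_get_dst_offset_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla(m_idx, n_idx, dst_stride_row);

            const size_t m_tile = (m - m_idx < m_step) ? (m - m_idx) : m_step;
            const size_t n_tile = (n - n_idx < n_step) ? (n - n_idx) : n_step;

            // Trailing arguments (dst, strides, clamp bounds) follow the header's
            // documented parameter order; their exact types are an assumption here.
            kai_run_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla(
                m_tile, n_tile, k,
                reinterpret_cast<const uint8_t*>(lhs) + lhs_offset, lhs_stride,
                reinterpret_cast<const uint8_t*>(rhs_packed) + rhs_offset,
                reinterpret_cast<uint8_t*>(dst) + dst_offset, dst_stride_row, sizeof(__fp16),
                static_cast<__fp16>(-65504.0f), static_cast<__fp16>(65504.0f));
        }
    }
}

The in-tree example reaches the same entry points through the kai_matmul_clamp_f16_f16_f16p_ukernel function-pointer table shown earlier in this diff, which keeps the example's call sites independent of the concrete micro-kernel variant.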