From 351321b303b42e90e0d62f5c6982f4c5ab17301f Mon Sep 17 00:00:00 2001 From: Felix Thomasmathibalan Date: Fri, 2 Aug 2024 13:21:52 +0100 Subject: [PATCH 1/3] Rename FP16 GEMM micro kernels Redundant or unused information is removed from the file name. The micro kernels are regenerated. Signed-off-by: Felix Thomasmathibalan --- CMakeLists.txt | 2 +- .../matmul_clamp_f16_f16_f16p/CMakeLists.txt | 2 +- .../matmul_clamp_f16_f16_f16p.cpp | 24 +- kai/ukernels/matmul/BUILD.bazel | 4 +- ..._f16_f16_f16pbiasf16_6x16_neon_fp16_mla.c} | 326 +++++++++--------- ..._f16_f16_f16pbiasf16_6x16_neon_fp16_mla.h} | 26 +- test/tests/matmul_test.cpp | 22 +- 7 files changed, 203 insertions(+), 203 deletions(-) rename kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/{kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.c => kai_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla.c} (93%) rename kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/{kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.h => kai_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla.h} (78%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 80adb858..344638d5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -82,7 +82,7 @@ set(KLEIDIAI_FILES_SCALAR set(KLEIDIAI_FILES_NEON_FP16 kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.c - kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.c + kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla.c ) set(KLEIDIAI_FILES_NEON_DOTPROD diff --git a/examples/matmul_clamp_f16_f16_f16p/CMakeLists.txt b/examples/matmul_clamp_f16_f16_f16p/CMakeLists.txt index 487f69d1..2d3da701 100644 --- a/examples/matmul_clamp_f16_f16_f16p/CMakeLists.txt +++ b/examples/matmul_clamp_f16_f16_f16p/CMakeLists.txt @@ -20,7 +20,7 @@ include_directories( # Files requires to build the executable add_executable(matmul_clamp_f16_f16_f16p matmul_clamp_f16_f16_f16p.cpp - ${MATMUL_PATH}/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.c + ${MATMUL_PATH}/kai_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla.c ${MATMUL_PACK_PATH}/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.c ) diff --git a/examples/matmul_clamp_f16_f16_f16p/matmul_clamp_f16_f16_f16p.cpp b/examples/matmul_clamp_f16_f16_f16p/matmul_clamp_f16_f16_f16p.cpp index 725d91e5..1e1975da 100644 --- a/examples/matmul_clamp_f16_f16_f16p/matmul_clamp_f16_f16_f16p.cpp +++ b/examples/matmul_clamp_f16_f16_f16p/matmul_clamp_f16_f16_f16p.cpp @@ -25,8 +25,8 @@ #include // Include micro-kernel variants -#include "kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.h" #include "kai_matmul_clamp_f16_f16_f16p_interface.h" +#include "kai_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla.c" #include "kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.h" #define FLOAT16_MIN (-65504) @@ -35,16 +35,16 @@ namespace { /// Micro-kernel interface constexpr kai_matmul_clamp_f16_f16_f16p_ukernel ukernel{ - kai_get_m_step_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, - kai_get_n_step_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, - kai_get_nr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, - kai_get_kr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, - kai_get_sr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, - kai_get_lhs_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, - kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, - kai_get_dst_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, - kai_get_dst_size_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, - kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla}; + kai_get_m_step_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla, + kai_get_n_step_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla, + kai_get_nr_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla, + kai_get_kr_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla, + kai_get_sr_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla, + kai_get_lhs_offset_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla, + kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla, + kai_get_dst_offset_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla, + kai_get_dst_size_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla, + kai_run_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla}; /// Reference implementation of matrix multiplication void run_matmul_ref( @@ -209,7 +209,7 @@ int main() { const bool is_valid = is_output_correct(M, N, 0.0001, dst_ref, dst); std::cout << "TEST[matmul_clamp_f16_f16_f16p]\n"; - std::cout << "- ukernel: matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla\n"; + std::cout << "- ukernel: matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla\n"; if (is_valid) { std::cout << "- Status: PASSED\n"; std::cout << "- Performance: " << time_matmul.count() << "ns\n"; diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 013a96c3..351109e2 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -25,8 +25,8 @@ kai_c_library( kai_c_library( name = "clamp_f16_f16_f16p", - srcs = ["matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.c"], - hdrs = ["matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.h"], + srcs = ["matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla.c"], + hdrs = ["matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla.c"], cpu_uarch = kai_cpu_fp16(), deps = [ ":clamp_f16_f16_f16p_interface", diff --git a/kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.c b/kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla.c similarity index 93% rename from kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.c rename to kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla.c index 9f9d5a75..840258ad 100644 --- a/kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.c +++ b/kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla.c @@ -9,7 +9,7 @@ #error This file must be compiled for AArch64, FEAT_FP16. #else // Architectural features check. -#include "kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.h" +#include "kai_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla.h" #include #include @@ -22,39 +22,39 @@ static const size_t kai_nr = 16; static const size_t kai_kr = 1; static const size_t kai_sr = 1; -size_t kai_get_m_step_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla(void) { +size_t kai_get_m_step_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla(void) { return kai_mr; } -size_t kai_get_n_step_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla(void) { +size_t kai_get_n_step_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla(void) { return kai_nr; } -size_t kai_get_nr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla(void) { +size_t kai_get_nr_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla(void) { return kai_nr; } -size_t kai_get_kr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla(void) { +size_t kai_get_kr_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla(void) { return kai_kr; } -size_t kai_get_sr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla(void) { +size_t kai_get_sr_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla(void) { return kai_sr; } -size_t kai_get_lhs_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla(size_t m_idx, size_t stride) { +size_t kai_get_lhs_offset_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla(size_t m_idx, size_t stride) { KAI_ASSUME(m_idx % kai_mr == 0); return m_idx * stride; } -size_t kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla(size_t n_idx, size_t k) { +size_t kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla(size_t n_idx, size_t k) { KAI_ASSUME(n_idx % kai_nr == 0); return n_idx / kai_nr * (kai_nr * sizeof(__fp16) + kai_nr * k * sizeof(__fp16)); } -size_t kai_get_dst_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( +size_t kai_get_dst_offset_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla( size_t m_idx, size_t n_idx, size_t stride) { KAI_ASSUME(m_idx % kai_mr == 0); KAI_ASSUME(n_idx % kai_nr == 0); @@ -62,11 +62,11 @@ size_t kai_get_dst_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( return m_idx * stride + n_idx * sizeof(__fp16); } -size_t kai_get_dst_size_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla(size_t m, size_t n) { +size_t kai_get_dst_size_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla(size_t m, size_t n) { return m * n * sizeof(__fp16); } -void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( +void kai_run_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla( size_t m, size_t n, size_t k, // const void* lhs, size_t lhs_stride, // const void* rhs_packed, // @@ -235,9 +235,9 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "blt 19f\n" "18:" // Height 1: Multiply loop: Main loop head "fmla v20.8h, v6.8h, v0.h[0]\n" - "ldr q6, [x10, #0xe0]\n" + "ldr q23, [x10, #0xe0]\n" "fmla v21.8h, v7.8h, v0.h[0]\n" - "ldr q7, [x10, #0xf0]\n" + "ldr q22, [x10, #0xf0]\n" "sub x27, x27, #0x8\n" "add x26, x26, #0x10\n" "cmp x27, #0x10\n" @@ -267,17 +267,17 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "ldr q18, [x10, #0xc0]\n" "fmla v21.8h, v19.8h, v0.h[6]\n" "ldr q19, [x10, #0xd0]\n" - "fmla v20.8h, v6.8h, v0.h[7]\n" + "fmla v20.8h, v23.8h, v0.h[7]\n" "ldr q6, [x10, #0x0]\n" - "fmla v21.8h, v7.8h, v0.h[7]\n" + "fmla v21.8h, v22.8h, v0.h[7]\n" "ldr q0, [x26, #0x0]\n" "ldr q7, [x10, #0x10]\n" "bge 18b\n" "19:" // Height 1: Multiply loop: Single iteration only "fmla v20.8h, v6.8h, v0.h[0]\n" - "ldr q6, [x10, #0xe0]\n" + "ldr q23, [x10, #0xe0]\n" "fmla v21.8h, v7.8h, v0.h[0]\n" - "ldr q7, [x10, #0xf0]\n" + "ldr q22, [x10, #0xf0]\n" "add x26, x26, #0x10\n" "sub x27, x27, #0x8\n" "add x10, x10, #0x100\n" @@ -294,18 +294,18 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "fmla v21.8h, v17.8h, v0.h[5]\n" "fmla v20.8h, v18.8h, v0.h[6]\n" "fmla v21.8h, v19.8h, v0.h[6]\n" - "fmla v20.8h, v6.8h, v0.h[7]\n" - "fmla v21.8h, v7.8h, v0.h[7]\n" + "fmla v20.8h, v23.8h, v0.h[7]\n" + "fmla v21.8h, v22.8h, v0.h[7]\n" "20:" // Height 1: Multiply loop: Main loop skip "cbz x27, 22f\n" "21:" // Height 1: Multiply loop: Odd block loop "ldr h0, [x26], #0x2\n" - "ldr q8, [x10, #0x0]\n" + "ldr q17, [x10, #0x0]\n" "sub x27, x27, #0x1\n" - "ldr q9, [x10, #0x10]\n" + "ldr q16, [x10, #0x10]\n" "add x10, x10, #0x20\n" - "fmla v20.8h, v8.8h, v0.h[0]\n" - "fmla v21.8h, v9.8h, v0.h[0]\n" + "fmla v20.8h, v17.8h, v0.h[0]\n" + "fmla v21.8h, v16.8h, v0.h[0]\n" "cbnz x27, 21b\n" "22:" // Height 1: Multiply loop: No odd multiplies "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n" @@ -519,11 +519,11 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "51:" // Height 2: Multiply loop: Main loop head "fmla v20.8h, v6.8h, v0.h[0]\n" "fmla v22.8h, v6.8h, v1.h[0]\n" - "ldr q6, [x10, #0xe0]\n" + "ldr q25, [x10, #0xe0]\n" "sub x27, x27, #0x8\n" "fmla v21.8h, v7.8h, v0.h[0]\n" "fmla v23.8h, v7.8h, v1.h[0]\n" - "ldr q7, [x10, #0xf0]\n" + "ldr q24, [x10, #0xf0]\n" "add x26, x26, #0x10\n" "add x25, x25, #0x10\n" "cmp x27, #0x10\n" @@ -566,23 +566,23 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "fmla v21.8h, v19.8h, v0.h[6]\n" "fmla v23.8h, v19.8h, v1.h[6]\n" "ldr q19, [x10, #0xd0]\n" - "fmla v20.8h, v6.8h, v0.h[7]\n" - "fmla v22.8h, v6.8h, v1.h[7]\n" + "fmla v20.8h, v25.8h, v0.h[7]\n" + "fmla v22.8h, v25.8h, v1.h[7]\n" "ldr q6, [x10, #0x0]\n" - "fmla v21.8h, v7.8h, v0.h[7]\n" + "fmla v21.8h, v24.8h, v0.h[7]\n" "ldr q0, [x26, #0x0]\n" - "fmla v23.8h, v7.8h, v1.h[7]\n" + "fmla v23.8h, v24.8h, v1.h[7]\n" "ldr q1, [x25, #0x0]\n" "ldr q7, [x10, #0x10]\n" "bge 51b\n" "52:" // Height 2: Multiply loop: Single iteration only "fmla v20.8h, v6.8h, v0.h[0]\n" "fmla v22.8h, v6.8h, v1.h[0]\n" - "ldr q6, [x10, #0xe0]\n" + "ldr q25, [x10, #0xe0]\n" "add x26, x26, #0x10\n" "fmla v21.8h, v7.8h, v0.h[0]\n" "fmla v23.8h, v7.8h, v1.h[0]\n" - "ldr q7, [x10, #0xf0]\n" + "ldr q24, [x10, #0xf0]\n" "add x25, x25, #0x10\n" "sub x27, x27, #0x8\n" "prfm pldl1keep, [x26, #0x80]\n" @@ -612,23 +612,23 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "fmla v22.8h, v18.8h, v1.h[6]\n" "fmla v21.8h, v19.8h, v0.h[6]\n" "fmla v23.8h, v19.8h, v1.h[6]\n" - "fmla v20.8h, v6.8h, v0.h[7]\n" - "fmla v22.8h, v6.8h, v1.h[7]\n" - "fmla v21.8h, v7.8h, v0.h[7]\n" - "fmla v23.8h, v7.8h, v1.h[7]\n" + "fmla v20.8h, v25.8h, v0.h[7]\n" + "fmla v22.8h, v25.8h, v1.h[7]\n" + "fmla v21.8h, v24.8h, v0.h[7]\n" + "fmla v23.8h, v24.8h, v1.h[7]\n" "53:" // Height 2: Multiply loop: Main loop skip "cbz x27, 55f\n" "54:" // Height 2: Multiply loop: Odd block loop - "ldr h0, [x26], #0x2\n" - "ldr h1, [x25], #0x2\n" + "ldr h1, [x26], #0x2\n" + "ldr h0, [x25], #0x2\n" "sub x27, x27, #0x1\n" - "ldr q8, [x10, #0x0]\n" - "ldr q9, [x10, #0x10]\n" + "ldr q17, [x10, #0x0]\n" + "ldr q16, [x10, #0x10]\n" "add x10, x10, #0x20\n" - "fmla v20.8h, v8.8h, v0.h[0]\n" - "fmla v22.8h, v8.8h, v1.h[0]\n" - "fmla v21.8h, v9.8h, v0.h[0]\n" - "fmla v23.8h, v9.8h, v1.h[0]\n" + "fmla v20.8h, v17.8h, v1.h[0]\n" + "fmla v22.8h, v17.8h, v0.h[0]\n" + "fmla v21.8h, v16.8h, v1.h[0]\n" + "fmla v23.8h, v16.8h, v0.h[0]\n" "cbnz x27, 54b\n" "55:" // Height 2: Multiply loop: No odd multiplies "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n" @@ -895,12 +895,12 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "sub x27, x27, #0x8\n" "add x26, x26, #0x10\n" "fmla v24.8h, v6.8h, v2.h[0]\n" - "ldr q6, [x10, #0xe0]\n" + "ldr q27, [x10, #0xe0]\n" "fmla v21.8h, v7.8h, v0.h[0]\n" "add x25, x25, #0x10\n" "fmla v23.8h, v7.8h, v1.h[0]\n" "fmla v25.8h, v7.8h, v2.h[0]\n" - "ldr q7, [x10, #0xf0]\n" + "ldr q26, [x10, #0xf0]\n" "add x24, x24, #0x10\n" "cmp x27, #0x10\n" "add x10, x10, #0x100\n" @@ -955,15 +955,15 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "fmla v23.8h, v19.8h, v1.h[6]\n" "fmla v25.8h, v19.8h, v2.h[6]\n" "ldr q19, [x10, #0xd0]\n" - "fmla v20.8h, v6.8h, v0.h[7]\n" - "fmla v22.8h, v6.8h, v1.h[7]\n" - "fmla v24.8h, v6.8h, v2.h[7]\n" + "fmla v20.8h, v27.8h, v0.h[7]\n" + "fmla v22.8h, v27.8h, v1.h[7]\n" + "fmla v24.8h, v27.8h, v2.h[7]\n" "ldr q6, [x10, #0x0]\n" - "fmla v21.8h, v7.8h, v0.h[7]\n" + "fmla v21.8h, v26.8h, v0.h[7]\n" "ldr q0, [x26, #0x0]\n" - "fmla v23.8h, v7.8h, v1.h[7]\n" + "fmla v23.8h, v26.8h, v1.h[7]\n" "ldr q1, [x25, #0x0]\n" - "fmla v25.8h, v7.8h, v2.h[7]\n" + "fmla v25.8h, v26.8h, v2.h[7]\n" "ldr q2, [x24, #0x0]\n" "ldr q7, [x10, #0x10]\n" "bge 84b\n" @@ -973,12 +973,12 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "add x26, x26, #0x10\n" "add x25, x25, #0x10\n" "fmla v24.8h, v6.8h, v2.h[0]\n" - "ldr q6, [x10, #0xe0]\n" + "ldr q27, [x10, #0xe0]\n" "fmla v21.8h, v7.8h, v0.h[0]\n" "add x24, x24, #0x10\n" "fmla v23.8h, v7.8h, v1.h[0]\n" "fmla v25.8h, v7.8h, v2.h[0]\n" - "ldr q7, [x10, #0xf0]\n" + "ldr q26, [x10, #0xf0]\n" "prfm pldl1keep, [x26, #0x80]\n" "sub x27, x27, #0x8\n" "prfm pldl1keep, [x25, #0x80]\n" @@ -1020,28 +1020,28 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "fmla v21.8h, v19.8h, v0.h[6]\n" "fmla v23.8h, v19.8h, v1.h[6]\n" "fmla v25.8h, v19.8h, v2.h[6]\n" - "fmla v20.8h, v6.8h, v0.h[7]\n" - "fmla v22.8h, v6.8h, v1.h[7]\n" - "fmla v24.8h, v6.8h, v2.h[7]\n" - "fmla v21.8h, v7.8h, v0.h[7]\n" - "fmla v23.8h, v7.8h, v1.h[7]\n" - "fmla v25.8h, v7.8h, v2.h[7]\n" + "fmla v20.8h, v27.8h, v0.h[7]\n" + "fmla v22.8h, v27.8h, v1.h[7]\n" + "fmla v24.8h, v27.8h, v2.h[7]\n" + "fmla v21.8h, v26.8h, v0.h[7]\n" + "fmla v23.8h, v26.8h, v1.h[7]\n" + "fmla v25.8h, v26.8h, v2.h[7]\n" "86:" // Height 3: Multiply loop: Main loop skip "cbz x27, 88f\n" "87:" // Height 3: Multiply loop: Odd block loop - "ldr h0, [x26], #0x2\n" + "ldr h2, [x26], #0x2\n" "ldr h1, [x25], #0x2\n" "sub x27, x27, #0x1\n" - "ldr h2, [x24], #0x2\n" - "ldr q8, [x10, #0x0]\n" - "ldr q9, [x10, #0x10]\n" + "ldr h0, [x24], #0x2\n" + "ldr q17, [x10, #0x0]\n" + "ldr q16, [x10, #0x10]\n" "add x10, x10, #0x20\n" - "fmla v20.8h, v8.8h, v0.h[0]\n" - "fmla v22.8h, v8.8h, v1.h[0]\n" - "fmla v24.8h, v8.8h, v2.h[0]\n" - "fmla v21.8h, v9.8h, v0.h[0]\n" - "fmla v23.8h, v9.8h, v1.h[0]\n" - "fmla v25.8h, v9.8h, v2.h[0]\n" + "fmla v20.8h, v17.8h, v2.h[0]\n" + "fmla v22.8h, v17.8h, v1.h[0]\n" + "fmla v24.8h, v17.8h, v0.h[0]\n" + "fmla v21.8h, v16.8h, v2.h[0]\n" + "fmla v23.8h, v16.8h, v1.h[0]\n" + "fmla v25.8h, v16.8h, v0.h[0]\n" "cbnz x27, 87b\n" "88:" // Height 3: Multiply loop: No odd multiplies "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n" @@ -1358,7 +1358,7 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "add x26, x26, #0x10\n" "fmla v24.8h, v6.8h, v2.h[0]\n" "fmla v26.8h, v6.8h, v3.h[0]\n" - "ldr q6, [x10, #0xe0]\n" + "ldr q29, [x10, #0xe0]\n" "add x25, x25, #0x10\n" "fmla v21.8h, v7.8h, v0.h[0]\n" "fmla v23.8h, v7.8h, v1.h[0]\n" @@ -1366,7 +1366,7 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "add x23, x23, #0x10\n" "fmla v25.8h, v7.8h, v2.h[0]\n" "fmla v27.8h, v7.8h, v3.h[0]\n" - "ldr q7, [x10, #0xf0]\n" + "ldr q28, [x10, #0xf0]\n" "cmp x27, #0x10\n" "fmla v20.8h, v8.8h, v0.h[1]\n" "fmla v22.8h, v8.8h, v1.h[1]\n" @@ -1433,18 +1433,18 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "fmla v25.8h, v19.8h, v2.h[6]\n" "fmla v27.8h, v19.8h, v3.h[6]\n" "ldr q19, [x10, #0xd0]\n" - "fmla v20.8h, v6.8h, v0.h[7]\n" - "fmla v22.8h, v6.8h, v1.h[7]\n" - "fmla v24.8h, v6.8h, v2.h[7]\n" - "fmla v26.8h, v6.8h, v3.h[7]\n" + "fmla v20.8h, v29.8h, v0.h[7]\n" + "fmla v22.8h, v29.8h, v1.h[7]\n" + "fmla v24.8h, v29.8h, v2.h[7]\n" + "fmla v26.8h, v29.8h, v3.h[7]\n" "ldr q6, [x10, #0x0]\n" - "fmla v21.8h, v7.8h, v0.h[7]\n" + "fmla v21.8h, v28.8h, v0.h[7]\n" "ldr q0, [x26, #0x0]\n" - "fmla v23.8h, v7.8h, v1.h[7]\n" + "fmla v23.8h, v28.8h, v1.h[7]\n" "ldr q1, [x25, #0x0]\n" - "fmla v25.8h, v7.8h, v2.h[7]\n" + "fmla v25.8h, v28.8h, v2.h[7]\n" "ldr q2, [x24, #0x0]\n" - "fmla v27.8h, v7.8h, v3.h[7]\n" + "fmla v27.8h, v28.8h, v3.h[7]\n" "ldr q3, [x23, #0x0]\n" "ldr q7, [x10, #0x10]\n" "bge 117b\n" @@ -1455,7 +1455,7 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "add x25, x25, #0x10\n" "fmla v24.8h, v6.8h, v2.h[0]\n" "fmla v26.8h, v6.8h, v3.h[0]\n" - "ldr q6, [x10, #0xe0]\n" + "ldr q29, [x10, #0xe0]\n" "add x24, x24, #0x10\n" "fmla v21.8h, v7.8h, v0.h[0]\n" "fmla v23.8h, v7.8h, v1.h[0]\n" @@ -1463,7 +1463,7 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "sub x27, x27, #0x8\n" "fmla v25.8h, v7.8h, v2.h[0]\n" "fmla v27.8h, v7.8h, v3.h[0]\n" - "ldr q7, [x10, #0xf0]\n" + "ldr q28, [x10, #0xf0]\n" "prfm pldl1keep, [x26, #0x80]\n" "fmla v20.8h, v8.8h, v0.h[1]\n" "fmla v22.8h, v8.8h, v1.h[1]\n" @@ -1517,33 +1517,33 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "fmla v23.8h, v19.8h, v1.h[6]\n" "fmla v25.8h, v19.8h, v2.h[6]\n" "fmla v27.8h, v19.8h, v3.h[6]\n" - "fmla v20.8h, v6.8h, v0.h[7]\n" - "fmla v22.8h, v6.8h, v1.h[7]\n" - "fmla v24.8h, v6.8h, v2.h[7]\n" - "fmla v26.8h, v6.8h, v3.h[7]\n" - "fmla v21.8h, v7.8h, v0.h[7]\n" - "fmla v23.8h, v7.8h, v1.h[7]\n" - "fmla v25.8h, v7.8h, v2.h[7]\n" - "fmla v27.8h, v7.8h, v3.h[7]\n" + "fmla v20.8h, v29.8h, v0.h[7]\n" + "fmla v22.8h, v29.8h, v1.h[7]\n" + "fmla v24.8h, v29.8h, v2.h[7]\n" + "fmla v26.8h, v29.8h, v3.h[7]\n" + "fmla v21.8h, v28.8h, v0.h[7]\n" + "fmla v23.8h, v28.8h, v1.h[7]\n" + "fmla v25.8h, v28.8h, v2.h[7]\n" + "fmla v27.8h, v28.8h, v3.h[7]\n" "119:" // Height 4: Multiply loop: Main loop skip "cbz x27, 121f\n" "120:" // Height 4: Multiply loop: Odd block loop - "ldr h0, [x26], #0x2\n" - "ldr h1, [x25], #0x2\n" + "ldr h3, [x26], #0x2\n" + "ldr h2, [x25], #0x2\n" "sub x27, x27, #0x1\n" - "ldr h2, [x24], #0x2\n" - "ldr h3, [x23], #0x2\n" - "ldr q8, [x10, #0x0]\n" - "ldr q9, [x10, #0x10]\n" + "ldr h1, [x24], #0x2\n" + "ldr h0, [x23], #0x2\n" + "ldr q17, [x10, #0x0]\n" + "ldr q16, [x10, #0x10]\n" "add x10, x10, #0x20\n" - "fmla v20.8h, v8.8h, v0.h[0]\n" - "fmla v22.8h, v8.8h, v1.h[0]\n" - "fmla v24.8h, v8.8h, v2.h[0]\n" - "fmla v26.8h, v8.8h, v3.h[0]\n" - "fmla v21.8h, v9.8h, v0.h[0]\n" - "fmla v23.8h, v9.8h, v1.h[0]\n" - "fmla v25.8h, v9.8h, v2.h[0]\n" - "fmla v27.8h, v9.8h, v3.h[0]\n" + "fmla v20.8h, v17.8h, v3.h[0]\n" + "fmla v22.8h, v17.8h, v2.h[0]\n" + "fmla v24.8h, v17.8h, v1.h[0]\n" + "fmla v26.8h, v17.8h, v0.h[0]\n" + "fmla v21.8h, v16.8h, v3.h[0]\n" + "fmla v23.8h, v16.8h, v2.h[0]\n" + "fmla v25.8h, v16.8h, v1.h[0]\n" + "fmla v27.8h, v16.8h, v0.h[0]\n" "cbnz x27, 120b\n" "121:" // Height 4: Multiply loop: No odd multiplies "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n" @@ -1912,7 +1912,7 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "add x25, x25, #0x10\n" "add x24, x24, #0x10\n" "fmla v28.8h, v6.8h, v4.h[0]\n" - "ldr q6, [x10, #0xe0]\n" + "ldr q31, [x10, #0xe0]\n" "fmla v21.8h, v7.8h, v0.h[0]\n" "add x23, x23, #0x10\n" "fmla v23.8h, v7.8h, v1.h[0]\n" @@ -1921,7 +1921,7 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "cmp x27, #0x10\n" "fmla v27.8h, v7.8h, v3.h[0]\n" "fmla v29.8h, v7.8h, v4.h[0]\n" - "ldr q7, [x10, #0xf0]\n" + "ldr q30, [x10, #0xf0]\n" "add x10, x10, #0x100\n" "fmla v20.8h, v8.8h, v0.h[1]\n" "fmla v22.8h, v8.8h, v1.h[1]\n" @@ -2000,21 +2000,21 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "fmla v27.8h, v19.8h, v3.h[6]\n" "fmla v29.8h, v19.8h, v4.h[6]\n" "ldr q19, [x10, #0xd0]\n" - "fmla v20.8h, v6.8h, v0.h[7]\n" - "fmla v22.8h, v6.8h, v1.h[7]\n" - "fmla v24.8h, v6.8h, v2.h[7]\n" - "fmla v26.8h, v6.8h, v3.h[7]\n" - "fmla v28.8h, v6.8h, v4.h[7]\n" + "fmla v20.8h, v31.8h, v0.h[7]\n" + "fmla v22.8h, v31.8h, v1.h[7]\n" + "fmla v24.8h, v31.8h, v2.h[7]\n" + "fmla v26.8h, v31.8h, v3.h[7]\n" + "fmla v28.8h, v31.8h, v4.h[7]\n" "ldr q6, [x10, #0x0]\n" - "fmla v21.8h, v7.8h, v0.h[7]\n" + "fmla v21.8h, v30.8h, v0.h[7]\n" "ldr q0, [x26, #0x0]\n" - "fmla v23.8h, v7.8h, v1.h[7]\n" + "fmla v23.8h, v30.8h, v1.h[7]\n" "ldr q1, [x25, #0x0]\n" - "fmla v25.8h, v7.8h, v2.h[7]\n" + "fmla v25.8h, v30.8h, v2.h[7]\n" "ldr q2, [x24, #0x0]\n" - "fmla v27.8h, v7.8h, v3.h[7]\n" + "fmla v27.8h, v30.8h, v3.h[7]\n" "ldr q3, [x23, #0x0]\n" - "fmla v29.8h, v7.8h, v4.h[7]\n" + "fmla v29.8h, v30.8h, v4.h[7]\n" "ldr q4, [x22, #0x0]\n" "ldr q7, [x10, #0x10]\n" "bge 150b\n" @@ -2028,7 +2028,7 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "add x24, x24, #0x10\n" "add x23, x23, #0x10\n" "fmla v28.8h, v6.8h, v4.h[0]\n" - "ldr q6, [x10, #0xe0]\n" + "ldr q31, [x10, #0xe0]\n" "fmla v21.8h, v7.8h, v0.h[0]\n" "add x22, x22, #0x10\n" "fmla v23.8h, v7.8h, v1.h[0]\n" @@ -2037,7 +2037,7 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "prfm pldl1keep, [x26, #0x80]\n" "fmla v27.8h, v7.8h, v3.h[0]\n" "fmla v29.8h, v7.8h, v4.h[0]\n" - "ldr q7, [x10, #0xf0]\n" + "ldr q30, [x10, #0xf0]\n" "prfm pldl1keep, [x25, #0x80]\n" "fmla v20.8h, v8.8h, v0.h[1]\n" "fmla v22.8h, v8.8h, v1.h[1]\n" @@ -2103,38 +2103,38 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "fmla v25.8h, v19.8h, v2.h[6]\n" "fmla v27.8h, v19.8h, v3.h[6]\n" "fmla v29.8h, v19.8h, v4.h[6]\n" - "fmla v20.8h, v6.8h, v0.h[7]\n" - "fmla v22.8h, v6.8h, v1.h[7]\n" - "fmla v24.8h, v6.8h, v2.h[7]\n" - "fmla v26.8h, v6.8h, v3.h[7]\n" - "fmla v28.8h, v6.8h, v4.h[7]\n" - "fmla v21.8h, v7.8h, v0.h[7]\n" - "fmla v23.8h, v7.8h, v1.h[7]\n" - "fmla v25.8h, v7.8h, v2.h[7]\n" - "fmla v27.8h, v7.8h, v3.h[7]\n" - "fmla v29.8h, v7.8h, v4.h[7]\n" + "fmla v20.8h, v31.8h, v0.h[7]\n" + "fmla v22.8h, v31.8h, v1.h[7]\n" + "fmla v24.8h, v31.8h, v2.h[7]\n" + "fmla v26.8h, v31.8h, v3.h[7]\n" + "fmla v28.8h, v31.8h, v4.h[7]\n" + "fmla v21.8h, v30.8h, v0.h[7]\n" + "fmla v23.8h, v30.8h, v1.h[7]\n" + "fmla v25.8h, v30.8h, v2.h[7]\n" + "fmla v27.8h, v30.8h, v3.h[7]\n" + "fmla v29.8h, v30.8h, v4.h[7]\n" "152:" // Height 5: Multiply loop: Main loop skip "cbz x27, 154f\n" "153:" // Height 5: Multiply loop: Odd block loop - "ldr h0, [x26], #0x2\n" - "ldr h1, [x25], #0x2\n" + "ldr h4, [x26], #0x2\n" + "ldr h3, [x25], #0x2\n" "sub x27, x27, #0x1\n" "ldr h2, [x24], #0x2\n" - "ldr h3, [x23], #0x2\n" - "ldr h4, [x22], #0x2\n" - "ldr q8, [x10, #0x0]\n" - "ldr q9, [x10, #0x10]\n" + "ldr h1, [x23], #0x2\n" + "ldr h0, [x22], #0x2\n" + "ldr q17, [x10, #0x0]\n" + "ldr q16, [x10, #0x10]\n" "add x10, x10, #0x20\n" - "fmla v20.8h, v8.8h, v0.h[0]\n" - "fmla v22.8h, v8.8h, v1.h[0]\n" - "fmla v24.8h, v8.8h, v2.h[0]\n" - "fmla v26.8h, v8.8h, v3.h[0]\n" - "fmla v28.8h, v8.8h, v4.h[0]\n" - "fmla v21.8h, v9.8h, v0.h[0]\n" - "fmla v23.8h, v9.8h, v1.h[0]\n" - "fmla v25.8h, v9.8h, v2.h[0]\n" - "fmla v27.8h, v9.8h, v3.h[0]\n" - "fmla v29.8h, v9.8h, v4.h[0]\n" + "fmla v20.8h, v17.8h, v4.h[0]\n" + "fmla v22.8h, v17.8h, v3.h[0]\n" + "fmla v24.8h, v17.8h, v2.h[0]\n" + "fmla v26.8h, v17.8h, v1.h[0]\n" + "fmla v28.8h, v17.8h, v0.h[0]\n" + "fmla v21.8h, v16.8h, v4.h[0]\n" + "fmla v23.8h, v16.8h, v3.h[0]\n" + "fmla v25.8h, v16.8h, v2.h[0]\n" + "fmla v27.8h, v16.8h, v1.h[0]\n" + "fmla v29.8h, v16.8h, v0.h[0]\n" "cbnz x27, 153b\n" "154:" // Height 5: Multiply loop: No odd multiplies "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n" @@ -2797,28 +2797,28 @@ void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( "185:" // Height 6: Multiply loop: Main loop skip "cbz x27, 187f\n" "186:" // Height 6: Multiply loop: Odd block loop - "ldr h0, [x26], #0x2\n" - "ldr h1, [x25], #0x2\n" + "ldr h5, [x26], #0x2\n" + "ldr h4, [x25], #0x2\n" "sub x27, x27, #0x1\n" - "ldr h2, [x24], #0x2\n" - "ldr h3, [x23], #0x2\n" - "ldr h4, [x22], #0x2\n" - "ldr h5, [x21], #0x2\n" - "ldr q8, [x10, #0x0]\n" - "ldr q9, [x10, #0x10]\n" + "ldr h3, [x24], #0x2\n" + "ldr h2, [x23], #0x2\n" + "ldr h1, [x22], #0x2\n" + "ldr h0, [x21], #0x2\n" + "ldr q17, [x10, #0x0]\n" + "ldr q16, [x10, #0x10]\n" "add x10, x10, #0x20\n" - "fmla v20.8h, v8.8h, v0.h[0]\n" - "fmla v22.8h, v8.8h, v1.h[0]\n" - "fmla v24.8h, v8.8h, v2.h[0]\n" - "fmla v26.8h, v8.8h, v3.h[0]\n" - "fmla v28.8h, v8.8h, v4.h[0]\n" - "fmla v30.8h, v8.8h, v5.h[0]\n" - "fmla v21.8h, v9.8h, v0.h[0]\n" - "fmla v23.8h, v9.8h, v1.h[0]\n" - "fmla v25.8h, v9.8h, v2.h[0]\n" - "fmla v27.8h, v9.8h, v3.h[0]\n" - "fmla v29.8h, v9.8h, v4.h[0]\n" - "fmla v31.8h, v9.8h, v5.h[0]\n" + "fmla v20.8h, v17.8h, v5.h[0]\n" + "fmla v22.8h, v17.8h, v4.h[0]\n" + "fmla v24.8h, v17.8h, v3.h[0]\n" + "fmla v26.8h, v17.8h, v2.h[0]\n" + "fmla v28.8h, v17.8h, v1.h[0]\n" + "fmla v30.8h, v17.8h, v0.h[0]\n" + "fmla v21.8h, v16.8h, v5.h[0]\n" + "fmla v23.8h, v16.8h, v4.h[0]\n" + "fmla v25.8h, v16.8h, v3.h[0]\n" + "fmla v27.8h, v16.8h, v2.h[0]\n" + "fmla v29.8h, v16.8h, v1.h[0]\n" + "fmla v31.8h, v16.8h, v0.h[0]\n" "cbnz x27, 186b\n" "187:" // Height 6: Multiply loop: No odd multiplies "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n" diff --git a/kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.h b/kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla.h similarity index 78% rename from kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.h rename to kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla.h index a70de290..b61c2c38 100644 --- a/kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.h +++ b/kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla.h @@ -28,35 +28,35 @@ extern "C" { /// The starting row index must be divisible by `m_step`. /// /// @return The m step value. -size_t kai_get_m_step_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla(void); +size_t kai_get_m_step_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla(void); /// Gets n step value. /// /// The starting column index must be divisible by `n_step`. /// /// @return The n step value. -size_t kai_get_n_step_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla(void); +size_t kai_get_n_step_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla(void); /// Gets nr value. /// /// This is the packing parameter which must be used to pack the RHS matrix. /// /// @return The nr value. -size_t kai_get_nr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla(void); +size_t kai_get_nr_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla(void); /// Gets kr value. /// /// This is the packing parameter which must be used to pack the RHS matrix. /// /// @return The kr value. -size_t kai_get_kr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla(void); +size_t kai_get_kr_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla(void); /// Gets sr value. /// /// This is the packing parameter which must be used to pack the RHS matrix. /// /// @return The sr value. -size_t kai_get_sr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla(void); +size_t kai_get_sr_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla(void); /// Gets the offset in bytes to the data element in the LHS matrix buffer. /// @@ -64,7 +64,7 @@ size_t kai_get_sr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla(void); /// @param[in] stride Row stride in bytes. /// /// @return The offset in bytes to the data element. -size_t kai_get_lhs_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla(size_t m_idx, size_t stride); +size_t kai_get_lhs_offset_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla(size_t m_idx, size_t stride); /// Gets the offset in bytes to the data element in the packed RHS matrix buffer. /// @@ -72,7 +72,7 @@ size_t kai_get_lhs_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla(s /// @param[in] k Number of columns. /// /// @return The offset in bytes to the data element. -size_t kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla(size_t n_idx, size_t k); +size_t kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla(size_t n_idx, size_t k); /// Gets the offset in bytes to the data element in the destination matrix buffer. /// @@ -81,7 +81,7 @@ size_t kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neo /// @param[in] stride Row stride in bytes. /// /// @return The offset in bytes to the data element. -size_t kai_get_dst_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( +size_t kai_get_dst_offset_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla( size_t m_idx, size_t n_idx, size_t stride); /// Gets the size in bytes of the destination matrix buffer. @@ -90,16 +90,16 @@ size_t kai_get_dst_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( /// @param[in] n Number of columns. /// /// @return The size in bytes of the destination matrix buffer. -size_t kai_get_dst_size_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla(size_t m, size_t n); +size_t kai_get_dst_size_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla(size_t m, size_t n); /// Runs the matrix multiplication microkernel followed by a clamp operation. /// /// The pointer of each buffers (LHS, packed RHS and output) needs to be added with offset /// calculated using the following functions: /// -/// * LHS: @ref kai_get_lhs_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla. -/// * Packed RHS: @ref kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla. -/// * Output: @ref kai_get_dst_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla. +/// * LHS: @ref kai_get_lhs_offset_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla. +/// * Packed RHS: @ref kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla. +/// * Output: @ref kai_get_dst_offset_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla. /// /// @param[in] m Number of output rows to be computed. /// @param[in] n Number of output columns to be computed. @@ -112,7 +112,7 @@ size_t kai_get_dst_size_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla(siz /// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(__fp16) /// @param[in] clamp_min Minimum value to clamp the final result. /// @param[in] clamp_max Maximum value to clamp the final result. -void kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla( +void kai_run_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla( size_t m, size_t n, size_t k, // const void* lhs, size_t lhs_stride, // const void* rhs_packed, // diff --git a/test/tests/matmul_test.cpp b/test/tests/matmul_test.cpp index ac534645..e6eb3359 100644 --- a/test/tests/matmul_test.cpp +++ b/test/tests/matmul_test.cpp @@ -33,7 +33,7 @@ #include "test/reference/pack.hpp" // matmul_nt_nt_fp16_fp16_fp16_6x16_neon_mla -#include "kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.h" +#include "kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla.c" #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.h" // matmul_nt_nt_fp32_fp32_fp32_2vlx2vl_sme2_mopa @@ -340,15 +340,15 @@ static const std::array matmul_methods = { .bias_format = DataFormat(DataType::FP16), .fn_get_mr = nullptr, - .fn_get_nr = kai_get_nr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, - .fn_get_kr = kai_get_kr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, - .fn_get_sr = kai_get_sr_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, + .fn_get_nr = kai_get_nr_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla, + .fn_get_kr = kai_get_kr_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla, + .fn_get_sr = kai_get_sr_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla, - .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, + .fn_get_main_m_step = kai_get_m_step_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla, .fn_get_pack_rhs_n_step = kai_get_n_step_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon, - .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, + .fn_get_main_n_step = kai_get_n_step_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla, - .fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, + .fn_get_lhs_offset = kai_get_lhs_offset_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla, .fn_get_packed_lhs_size = nullptr, .fn_get_packed_lhs_offset = nullptr, .fn_pack_lhs = nullptr, @@ -356,15 +356,15 @@ static const std::array matmul_methods = { .fn_get_rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon, .fn_get_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon, .fn_get_pack_rhs_packed_rhs_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon, - .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, + .fn_get_main_packed_rhs_offset = kai_get_rhs_packed_offset_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla, .fn_pack_rhs = kai_run_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon, .fn_get_bias_offset = kai_get_bias_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon, - .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, - .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, + .fn_get_dst_offset = kai_get_dst_offset_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla, + .fn_get_dst_size = kai_get_dst_size_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla, - .fn_matmul_f16_f16_f16p = kai_run_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla, + .fn_matmul_f16_f16_f16p = kai_run_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla, .fn_matmul_f32_f32p_f32p = nullptr, }, -- GitLab From 6812766cc1d86a306bd01c94af39bac4e95313ca Mon Sep 17 00:00:00 2001 From: Felix Thomasmathibalan Date: Fri, 2 Aug 2024 15:19:57 +0100 Subject: [PATCH 2/3] Fix bazel build Signed-off-by: Felix Thomasmathibalan --- kai/ukernels/matmul/BUILD.bazel | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 351109e2..840d10e3 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -26,7 +26,7 @@ kai_c_library( kai_c_library( name = "clamp_f16_f16_f16p", srcs = ["matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla.c"], - hdrs = ["matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla.c"], + hdrs = ["matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla.h"], cpu_uarch = kai_cpu_fp16(), deps = [ ":clamp_f16_f16_f16p_interface", -- GitLab From b2ed3ec873e609626c2f0804cac2be382ebce81b Mon Sep 17 00:00:00 2001 From: Felix Thomasmathibalan Date: Fri, 2 Aug 2024 15:43:57 +0100 Subject: [PATCH 3/3] Fix bazel build Signed-off-by: Felix Thomasmathibalan --- test/tests/matmul_test.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/tests/matmul_test.cpp b/test/tests/matmul_test.cpp index e6eb3359..216c8a1e 100644 --- a/test/tests/matmul_test.cpp +++ b/test/tests/matmul_test.cpp @@ -33,7 +33,7 @@ #include "test/reference/pack.hpp" // matmul_nt_nt_fp16_fp16_fp16_6x16_neon_mla -#include "kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla.c" +#include "kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16pbiasf16_6x16_neon_fp16_mla.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.h" // matmul_nt_nt_fp32_fp32_fp32_2vlx2vl_sme2_mopa -- GitLab