diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.c b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.c index 50c04ad3fe5ecb86a9bcc523f0dccbcd0ba445c3..61b8ba480ec595d630909b915f749a3ad3be6d2b 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.c @@ -25,18 +25,19 @@ size_t kai_get_rhs_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon(size_t n_idx return n_idx * sizeof(uint16_t); } + size_t kai_get_bias_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon(size_t n_idx) { return n_idx * sizeof(uint16_t); } -size_t kai_get_rhs_packed_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon(size_t k, size_t n_idx) { +size_t kai_get_rhs_packed_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon(size_t n_idx, size_t k) { KAI_ASSUME(n_idx % kai_nr == 0); - return n_idx / kai_nr * (kai_nr * sizeof(uint16_t) + kai_nr * k * sizeof(uint16_t)); + return n_idx * (sizeof(uint16_t) + k * sizeof(uint16_t)); } size_t kai_get_rhs_packed_size_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon(size_t n, size_t k) { - return kai_get_rhs_packed_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon(k, kai_roundup(n, kai_nr)); + return kai_get_rhs_packed_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon(kai_roundup(n, kai_nr), k); } void kai_run_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon( @@ -65,30 +66,30 @@ void kai_run_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon( "mov x21, %x[out]\n" "cmp x22, #0x10\n" "blt 2f\n" - "1:" // Bias: Full row loop body - "ldr q16, [%x[bias], #0x0]\n" - "sub x22, x22, #0x10\n" - "cmp x22, #0x10\n" - "str q16, [x21, #0x0]\n" + "1:" // Bias: Full loop + "ldr q17, [%x[bias], #0x0]\n" "ldr q16, [%x[bias], #0x10]\n" + "sub x22, x22, #0x10\n" "add %x[bias], %x[bias], #0x20\n" + "cmp x22, #0x10\n" + "str q17, [x21, #0x0]\n" "str q16, [x21, #0x10]\n" "add x21, x21, %x[out_stride]\n" "bge 1b\n" - "2:" // Bias: Tail row loop start - "cbz x22, 4f\n" - "3:" // Bias: Tail row loop body + "cbz x22, 3f\n" + "2:" // Bias: Tail loop "ldr h20, [%x[bias], #0x0]\n" "sub x22, x22, #0x1\n" "add %x[bias], %x[bias], #0x2\n" + "cmp x22, #0x0\n" "str h20, [x21]\n" "add x21, x21, #0x2\n" - "cbnz x22, 3b\n" - "4:" // Bias: Done + "bgt 2b\n" + "3:" // Bias: Done "cmp %x[height], #0x4\n" "add %x[out], %x[out], #0x20\n" - "blt 13f\n" - "5:" // Main row loop: Head + "blt 12f\n" + "4:" // Main row loop: Head "mov x25, %x[in]\n" "mov x24, %x[width]\n" "mov x23, %x[out]\n" @@ -98,8 +99,8 @@ void kai_run_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon( "add x20, x21, %x[in_stride]\n" "cmp x24, #0x10\n" "add %x[in], x20, %x[in_stride]\n" - "blt 7f\n" - "6:" // Main row loop: Column loop + "blt 6f\n" + "5:" // Main row loop: Column loop "ldr q23, [x25], #0x10\n" "ldr q22, [x22], #0x10\n" "sub x24, x24, #0x10\n" @@ -119,9 +120,9 @@ void kai_run_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon( "str q20, [x23, #0x60]\n" "str q16, [x23, #0x70]\n" "add x23, x23, %x[out_stride]\n" - "bge 6b\n" - "7:" // Main row loop: Column loop skip - "cbz x24, 12f\n" + "bge 5b\n" + "6:" // Main row loop: Column loop skip + "cbz x24, 11f\n" "cmp x24, #0x4\n" "movi v16.8h, #0x0\n" "str q16, [x23, #0x0]\n" @@ -132,8 +133,8 @@ void kai_run_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon( "str q16, [x23, #0x50]\n" "str q16, [x23, #0x60]\n" "str q16, [x23, #0x70]\n" - "blt 9f\n" - "8:" // Main row loop: width 4 loop: loop + "blt 8f\n" + "7:" // Main row loop: width 4 loop: loop "ldr d19, [x25], #0x8\n" "ldr d18, [x22], #0x8\n" "sub x24, x24, #0x4\n" @@ -145,11 +146,11 @@ void kai_run_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon( "str d17, [x23, #0x40]\n" "str d16, [x23, #0x60]\n" "add x23, x23, #0x8\n" - "bge 8b\n" - "9:" // Main row loop: width 4 loop: skip + "bge 7b\n" + "8:" // Main row loop: width 4 loop: skip "cmp x24, #0x1\n" - "blt 11f\n" - "10:" // Main row loop: width 1 loop: loop + "blt 10f\n" + "9:" // Main row loop: width 1 loop: loop "ldr h19, [x25], #0x2\n" "ldr h18, [x22], #0x2\n" "sub x24, x24, #0x1\n" @@ -161,23 +162,23 @@ void kai_run_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon( "str h17, [x23, #0x40]\n" "str h16, [x23, #0x60]\n" "add x23, x23, #0x2\n" - "bge 10b\n" - "11:" // Main row loop: width 1 loop: skip - "12:" // Main row loop: odd col skip + "bge 9b\n" + "10:" // Main row loop: width 1 loop: skip + "11:" // Main row loop: odd col skip "cmp %x[height], #0x4\n" "add %x[out], %x[out], #0x80\n" - "bge 5b\n" - "cbz %x[height], 22f\n" - "13:" // Main loop skip - "14:" // Tail row loop: Head + "bge 4b\n" + "cbz %x[height], 21f\n" + "12:" // Main loop skip + "13:" // Tail row loop: Head "mov x20, %x[width]\n" "mov x25, %x[in]\n" "mov x23, %x[out]\n" "sub %x[height], %x[height], #0x1\n" "cmp x20, #0x10\n" "add %x[in], x25, %x[in_stride]\n" - "blt 16f\n" - "15:" // Tail row loop: Column loop + "blt 15f\n" + "14:" // Tail row loop: Column loop "ldr q17, [x25], #0x10\n" "sub x20, x20, #0x10\n" "ldr q16, [x25], #0x10\n" @@ -185,37 +186,37 @@ void kai_run_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon( "str q17, [x23, #0x0]\n" "str q16, [x23, #0x10]\n" "add x23, x23, %x[out_stride]\n" - "bge 15b\n" - "16:" // Tail row loop: Column loop skip - "cbz x20, 21f\n" + "bge 14b\n" + "15:" // Tail row loop: Column loop skip + "cbz x20, 20f\n" "cmp x20, #0x4\n" "movi v16.8h, #0x0\n" "str q16, [x23, #0x0]\n" "str q16, [x23, #0x10]\n" - "blt 18f\n" - "17:" // Tail row loop: width 4 loop: loop + "blt 17f\n" + "16:" // Tail row loop: width 4 loop: loop "ldr d16, [x25], #0x8\n" "sub x20, x20, #0x4\n" "cmp x20, #0x4\n" "str d16, [x23, #0x0]\n" "add x23, x23, #0x8\n" - "bge 17b\n" - "18:" // Tail row loop: width 4 loop: skip + "bge 16b\n" + "17:" // Tail row loop: width 4 loop: skip "cmp x20, #0x1\n" - "blt 20f\n" - "19:" // Tail row loop: width 1 loop: loop + "blt 19f\n" + "18:" // Tail row loop: width 1 loop: loop "ldr h16, [x25], #0x2\n" "sub x20, x20, #0x1\n" "cmp x20, #0x1\n" "str h16, [x23, #0x0]\n" "add x23, x23, #0x2\n" - "bge 19b\n" - "20:" // Tail row loop: width 1 loop: skip - "21:" // Tail row loop: odd col skip + "bge 18b\n" + "19:" // Tail row loop: width 1 loop: skip + "20:" // Tail row loop: odd col skip "cmp %x[height], #0x1\n" "add %x[out], %x[out], #0x20\n" - "bge 14b\n" - "22:" // Done + "bge 13b\n" + "21:" // Done : [bias] "+&r"(bias), [height] "+&r"(height), [in] "+&r"(in), [out] "+&r"(out) : [in_stride] "r"(in_stride), [out_stride] "r"(out_stride), [width] "r"(width) : "cc", "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "x20", "x21", "x22", "x23", "x24", diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.h b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.h index 7b131d0d015afa4c3ce77002dd895386e2268ca6..4721f8b9efa731c8f1881b8b7b9e4714dbd37c35 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.h +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.h @@ -35,11 +35,11 @@ size_t kai_get_bias_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon(size_t n_id /// Gets the offset in bytes to the data element in the packed RHS buffer. /// -/// @param[in] k Number of columns. /// @param[in] n_idx Row index. +/// @param[in] k Number of columns. /// /// @return The offset in bytes to the data element. -size_t kai_get_rhs_packed_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon(size_t k, size_t n_idx); +size_t kai_get_rhs_packed_offset_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon(size_t n_idx, size_t k); /// Gets the size in bytes of the packed RHS buffer. /// diff --git a/test/tests/matmul_test.cpp b/test/tests/matmul_test.cpp index 9b9e481faf28e55a66576d214a550e7ec1baea6e..0d0f1d2da18b3dc241f432c62cacab85cb49ae6a 100644 --- a/test/tests/matmul_test.cpp +++ b/test/tests/matmul_test.cpp @@ -137,11 +137,11 @@ struct MatMulMethod { /// Gets the offset in bytes of the packed RHS matrix. /// - /// @param[in] k Size of the matrix in K dimension. /// @param[in] n_idx Coordinate of the matrix in N dimension. + /// @param[in] k Size of the matrix in K dimension. /// /// @return The offset in bytes. - std::function fn_get_packed_rhs_offset; + std::function fn_get_packed_rhs_offset; std::function