From 9cefc417db875f8dddf343ad165df359ee77c9a2 Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Mon, 17 Jun 2024 15:25:30 +0100 Subject: [PATCH 1/4] Extend kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 to pack biases - Extend kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 to pack biases at the of each row - Adjust the int4 matmul micro-kernels (qsi4cx) to skip the bias Signed-off-by: Gian Marco Iodice --- ...ai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c | 4 + ...ai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.c | 4 + ...2_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.c | 156 +-- ...2_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c | 406 +++--- ...2_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.c | 257 ++-- ...2_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.c | 1104 +++++++++-------- .../kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c | 20 +- 7 files changed, 995 insertions(+), 956 deletions(-) diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c index fa9a3f5c..d062c5be 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c @@ -23,6 +23,7 @@ static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); +static const size_t kai_num_bytes_bias = sizeof(float); inline static size_t kai_k_roundedup(size_t k) { // Since we pack a float and int32 value at the end of the row, @@ -184,6 +185,9 @@ void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod( const float32x4_t rhs_scale = vld1q_f32((const float*)rhs_ptr); rhs_ptr += sizeof(float32x4_t); + // Skip the bias + rhs_ptr += kai_nr * kai_num_bytes_bias; + // Add the reduction sum iacc = vmlaq_s32(iacc, sum_n_s32, lhs_offset); diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.c index 17707488..707cb0bd 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.c @@ -23,6 +23,7 @@ static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); +static const size_t kai_num_bytes_bias = sizeof(float); inline static size_t kai_k_roundedup(size_t k) { // Since we pack a float and int32 value at the end of the row, @@ -211,6 +212,9 @@ void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod( const float32x4_t rhs_scale1 = vld1q_f32((const float*)rhs_ptr); rhs_ptr += sizeof(float32x4_t); + // Skip the bias + rhs_ptr += kai_nr * kai_num_bytes_bias; + // Add the reduction sum iacc0 = vmlaq_s32(iacc0, sum_n_s32_0, lhs_offset); iacc1 = vmlaq_s32(iacc1, sum_n_s32_1, lhs_offset); diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.c index ebcaaf54..97e3e308 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.c @@ -23,6 +23,7 @@ static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); +static const size_t kai_num_bytes_bias = sizeof(float); inline static size_t kai_k_roundedup(size_t k) { // Since we pack a float and int32 value at the end of the row, @@ -36,7 +37,8 @@ inline static size_t kai_lhs_packed_stride(size_t k) { KAI_ASSERT((k_internal % 2) == 0); - return kai_mr * (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); + return kai_mr * + (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs + kai_num_bytes_bias); } inline static size_t kai_rhs_packed_stride(size_t k) { @@ -116,19 +118,19 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm( "movi v4.16b, #0xf0\n" "mov x27, %x[m]\n" "madd x28, %x[num_blocks], x28, x20\n" - "cbz x27, 8f\n" + "cbz x27, 9f\n" "1:" // Row loop "mov x26, %x[rhs_packed]\n" "mov x25, %x[n]\n" "add x24, %x[dst], %x[dst_stride_row], LSL #2\n" "2:" // Column loop + "mov x21, %x[lhs_packed]\n" "movi v3.4s, #0x0\n" "movi v2.4s, #0x0\n" - "mov x21, %x[lhs_packed]\n" "mov x20, %x[num_blocks]\n" "movi v1.4s, #0x0\n" "movi v0.4s, #0x0\n" - "3:" // Block loop + "3:" // Sub block loop "ldr q31, [x26, #0x0]\n" "ldr q30, [x26, #0x10]\n" "subs x20, x20, #0x1\n" @@ -170,43 +172,61 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm( ".inst 0x4e9aa640 // smmla v0.4s, v18.16b, v26.16b\n" "bgt 3b\n" "ldr q18, [x26, #0x0]\n" - "ldr q17, [x21, #0x0]\n" - "uzp1 v26.2d, v3.2d, v2.2d\n" - "uzp2 v25.2d, v3.2d, v2.2d\n" - "ldr q24, [x26, #0x10]\n" - "ldr q16, [x21, #0x10]\n" - "uzp1 v23.2d, v1.2d, v0.2d\n" - "uzp2 v22.2d, v1.2d, v0.2d\n" - "ld1r { v21.4s }, [%x[clamp_vals]]\n" - "add x20, %x[clamp_vals], #0x4\n" - "cmp x25, #0x4\n" - "ld1r { v20.4s }, [x20]\n" - "mla v26.4s, v18.4s, v17.s[0]\n" - "mla v25.4s, v18.4s, v17.s[1]\n" + "ld1 { v17.4s }, [x21]\n" + "uzp1 v24.2d, v3.2d, v2.2d\n" + "uzp2 v23.2d, v3.2d, v2.2d\n" + "ldr q22, [x26, #0x10]\n" + "uzp1 v21.2d, v1.2d, v0.2d\n" + "uzp2 v20.2d, v1.2d, v0.2d\n" + "add x21, x21, #0x10\n" + "ldr q16, [x21, #0x0]\n" "add x26, x26, #0x20\n" - "mla v23.4s, v18.4s, v17.s[2]\n" - "mla v22.4s, v18.4s, v17.s[3]\n" - "fmul v19.4s, v24.4s, v16.s[0]\n" - "fmul v18.4s, v24.4s, v16.s[1]\n" - "fmul v17.4s, v24.4s, v16.s[2]\n" - "fmul v16.4s, v24.4s, v16.s[3]\n" - "scvtf v26.4s, v26.4s\n" - "scvtf v25.4s, v25.4s\n" + "mla v24.4s, v18.4s, v17.s[0]\n" + "mla v23.4s, v18.4s, v17.s[1]\n" + "mla v21.4s, v18.4s, v17.s[2]\n" + "mla v20.4s, v18.4s, v17.s[3]\n" + "fmul v19.4s, v22.4s, v16.s[0]\n" + "fmul v18.4s, v22.4s, v16.s[1]\n" + "fmul v17.4s, v22.4s, v16.s[2]\n" + "fmul v16.4s, v22.4s, v16.s[3]\n" + "scvtf v24.4s, v24.4s\n" "scvtf v23.4s, v23.4s\n" - "scvtf v22.4s, v22.4s\n" - "fmul v26.4s, v26.4s, v19.4s\n" - "fmul v25.4s, v25.4s, v18.4s\n" - "fmul v23.4s, v23.4s, v17.4s\n" - "fmul v22.4s, v22.4s, v16.4s\n" - "fmax v26.4s, v26.4s, v21.4s\n" - "fmax v25.4s, v25.4s, v21.4s\n" - "fmax v23.4s, v23.4s, v21.4s\n" - "fmax v22.4s, v22.4s, v21.4s\n" - "fmin v26.4s, v26.4s, v20.4s\n" - "fmin v25.4s, v25.4s, v20.4s\n" - "fmin v23.4s, v23.4s, v20.4s\n" - "fmin v22.4s, v22.4s, v20.4s\n" - "bge 6f\n" + "scvtf v21.4s, v21.4s\n" + "scvtf v20.4s, v20.4s\n" + "fmul v3.4s, v24.4s, v19.4s\n" + "fmul v2.4s, v23.4s, v18.4s\n" + "fmul v1.4s, v21.4s, v17.4s\n" + "fmul v0.4s, v20.4s, v16.4s\n" + "ld1r { v17.4s }, [%x[clamp_vals]]\n" + "add x20, %x[clamp_vals], #0x4\n" + "cmp x25, #0x4\n" + "ld1r { v16.4s }, [x20]\n" + "add x26, x26, #0x10\n" + "fmax v3.4s, v3.4s, v17.4s\n" + "fmax v2.4s, v2.4s, v17.4s\n" + "fmax v1.4s, v1.4s, v17.4s\n" + "fmax v0.4s, v0.4s, v17.4s\n" + "fmin v3.4s, v3.4s, v16.4s\n" + "fmin v2.4s, v2.4s, v16.4s\n" + "fmin v1.4s, v1.4s, v16.4s\n" + "fmin v0.4s, v0.4s, v16.4s\n" + "blt 5f\n" + "mov x20, %x[dst]\n" + "cmp x27, #0x1\n" + "str q3, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 8f\n" + "cmp x27, #0x2\n" + "str q2, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 8f\n" + "cmp x27, #0x3\n" + "str q1, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 8f\n" + "str q0, [x20, #0x0]\n" + "b 8f\n" + "5:" // Partial output "mov x23, %x[dst]\n" "cmp x27, #0x1\n" "add x22, x23, %x[dst_stride_row]\n" @@ -217,40 +237,24 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm( "cmp x27, #0x3\n" "add x20, x21, %x[dst_stride_row]\n" "csel x20, x20, x21, GE\n" - "tbz x25, #1, 4f\n" - "str d22, [x20], #0x8\n" - "str d23, [x21], #0x8\n" - "str d25, [x22], #0x8\n" - "str d26, [x23], #0x8\n" - "tbz x25, #0, 5f\n" - "st1 { v22.s }[2], [x20]\n" - "st1 { v23.s }[2], [x21]\n" - "st1 { v25.s }[2], [x22]\n" - "st1 { v26.s }[2], [x23]\n" - "b 5f\n" - "4:" // Output block 0: partial_1_0 - "str s22, [x20, #0x0]\n" - "str s23, [x21, #0x0]\n" - "str s25, [x22, #0x0]\n" - "str s26, [x23, #0x0]\n" - "5:" // Output block 0: Done + "tbz x25, #1, 6f\n" + "st1 { v0.d }[0], [x20], #0x8\n" + "st1 { v1.d }[0], [x21], #0x8\n" + "st1 { v2.d }[0], [x22], #0x8\n" + "st1 { v3.d }[0], [x23], #0x8\n" + "tbz x25, #0, 7f\n" + "st1 { v0.s }[2], [x20]\n" + "st1 { v1.s }[2], [x21]\n" + "st1 { v2.s }[2], [x22]\n" + "st1 { v3.s }[2], [x23]\n" "b 7f\n" - "6:" // Full output - "mov x20, %x[dst]\n" - "cmp x27, #0x1\n" - "str q26, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "ble 7f\n" - "cmp x27, #0x2\n" - "str q25, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "ble 7f\n" - "cmp x27, #0x3\n" - "str q23, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "ble 7f\n" - "str q22, [x20, #0x0]\n" - "7:" // Output stage exit + "6:" // Output block 0: partial_1_0 + "st1 { v0.s }[0], [x20]\n" + "st1 { v1.s }[0], [x21]\n" + "st1 { v2.s }[0], [x22]\n" + "st1 { v3.s }[0], [x23]\n" + "7:" // Output block 0: Done + "8:" // Output stage exit "subs x25, x25, #0x4\n" "add %x[dst], %x[dst], #0x10\n" "bgt 2b\n" @@ -258,10 +262,10 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm( "add %x[lhs_packed], %x[lhs_packed], x28\n" "mov %x[dst], x24\n" "bgt 1b\n" - "8:" // Row loop skip - : [lhs_packed] "+&r"(lhs_packed), [dst] "+&r"(dst) - : [rhs_packed] "r"(rhs_packed), [clamp_vals] "r"(clamp_vals), [m] "r"(m), [num_blocks] "r"(num_blocks), - [dst_stride_row] "r"(dst_stride_row), [n] "r"(n) + "9:" // Row loop skip + : [dst] "+&r"(dst), [lhs_packed] "+&r"(lhs_packed) + : [clamp_vals] "r"(clamp_vals), [dst_stride_row] "r"(dst_stride_row), [m] "r"(m), [n] "r"(n), + [num_blocks] "r"(num_blocks), [rhs_packed] "r"(rhs_packed) : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"); diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c index ac9ec40e..71f7070d 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c @@ -23,6 +23,7 @@ static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); +static const size_t kai_num_bytes_bias = sizeof(float); inline static size_t kai_k_roundedup(size_t k) { // Since we pack a float and int32 value at the end of the row, @@ -44,7 +45,7 @@ inline static size_t kai_rhs_packed_stride(size_t k) { KAI_ASSERT((k_internal % 2) == 0); - return kai_nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs); + return kai_nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias); } size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void) { @@ -117,7 +118,7 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm( "mov x20, #0x20\n" "cmp x12, #0x8\n" "madd x11, %x[num_blocks], x11, x20\n" - "blt 8f\n" + "blt 10f\n" "1:" // Row loop "mov x10, %x[rhs_packed]\n" "mov x9, %x[n]\n" @@ -131,10 +132,10 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm( "movi v7.4s, #0x0\n" "movi v6.4s, #0x0\n" "movi v5.4s, #0x0\n" + "add x20, x22, x11\n" "movi v4.4s, #0x0\n" "movi v3.4s, #0x0\n" - "add x20, x22, x11\n" - "3:" // Block loop + "3:" // Sub block loop "ldr q2, [x10, #0x0]\n" "ldr q1, [x10, #0x10]\n" "subs x21, x21, #0x1\n" @@ -200,130 +201,133 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm( ".inst 0x4e9fa644 // smmla v4.4s, v18.16b, v31.16b\n" ".inst 0x4e9ea643 // smmla v3.4s, v18.16b, v30.16b\n" "bgt 3b\n" - "ldr q20, [x10, #0x0]\n" - "ldr q19, [x22, #0x0]\n" - "uzp1 v2.2d, v10.2d, v9.2d\n" - "uzp2 v1.2d, v10.2d, v9.2d\n" - "ldr q18, [x20, #0x0]\n" - "ldr q0, [x10, #0x10]\n" - "uzp1 v31.2d, v8.2d, v7.2d\n" - "uzp2 v30.2d, v8.2d, v7.2d\n" - "ldr q17, [x22, #0x10]\n" - "ldr q16, [x20, #0x10]\n" - "uzp1 v29.2d, v6.2d, v5.2d\n" - "uzp2 v28.2d, v6.2d, v5.2d\n" - "ld1r { v27.4s }, [%x[clamp_vals]]\n" - "uzp1 v26.2d, v4.2d, v3.2d\n" - "uzp2 v25.2d, v4.2d, v3.2d\n" - "mla v2.4s, v20.4s, v19.s[0]\n" - "mla v1.4s, v20.4s, v19.s[1]\n" - "mla v31.4s, v20.4s, v19.s[2]\n" + "ldr q25, [x10, #0x0]\n" + "ld1 { v17.4s }, [x22]\n" + "uzp1 v23.2d, v10.2d, v9.2d\n" + "uzp2 v22.2d, v10.2d, v9.2d\n" + "ldr q24, [x10, #0x10]\n" + "uzp1 v21.2d, v8.2d, v7.2d\n" + "uzp2 v20.2d, v8.2d, v7.2d\n" + "add x22, x22, #0x10\n" + "ldr q16, [x22, #0x0]\n" + "add x10, x10, #0x20\n" + "mla v23.4s, v25.4s, v17.s[0]\n" + "mla v22.4s, v25.4s, v17.s[1]\n" + "mla v21.4s, v25.4s, v17.s[2]\n" + "mla v20.4s, v25.4s, v17.s[3]\n" + "fmul v19.4s, v24.4s, v16.s[0]\n" + "fmul v18.4s, v24.4s, v16.s[1]\n" + "fmul v17.4s, v24.4s, v16.s[2]\n" + "fmul v16.4s, v24.4s, v16.s[3]\n" + "scvtf v23.4s, v23.4s\n" + "scvtf v22.4s, v22.4s\n" + "scvtf v21.4s, v21.4s\n" + "scvtf v20.4s, v20.4s\n" + "fmul v10.4s, v23.4s, v19.4s\n" + "fmul v9.4s, v22.4s, v18.4s\n" + "fmul v8.4s, v21.4s, v17.4s\n" + "fmul v7.4s, v20.4s, v16.4s\n" + "ld1 { v17.4s }, [x20]\n" + "uzp1 v23.2d, v6.2d, v5.2d\n" + "uzp2 v22.2d, v6.2d, v5.2d\n" + "add x20, x20, #0x10\n" + "ldr q16, [x20, #0x0]\n" + "uzp1 v21.2d, v4.2d, v3.2d\n" + "uzp2 v20.2d, v4.2d, v3.2d\n" + "mla v23.4s, v25.4s, v17.s[0]\n" + "mla v22.4s, v25.4s, v17.s[1]\n" + "mla v21.4s, v25.4s, v17.s[2]\n" + "mla v20.4s, v25.4s, v17.s[3]\n" + "fmul v19.4s, v24.4s, v16.s[0]\n" + "fmul v18.4s, v24.4s, v16.s[1]\n" + "fmul v17.4s, v24.4s, v16.s[2]\n" + "scvtf v23.4s, v23.4s\n" + "fmul v16.4s, v24.4s, v16.s[3]\n" + "scvtf v22.4s, v22.4s\n" + "scvtf v21.4s, v21.4s\n" + "scvtf v20.4s, v20.4s\n" + "fmul v6.4s, v23.4s, v19.4s\n" + "fmul v5.4s, v22.4s, v18.4s\n" + "fmul v4.4s, v21.4s, v17.4s\n" + "fmul v3.4s, v20.4s, v16.4s\n" + "ld1r { v17.4s }, [%x[clamp_vals]]\n" "add x20, %x[clamp_vals], #0x4\n" "cmp x9, #0x4\n" - "ld1r { v24.4s }, [x20]\n" - "mla v30.4s, v20.4s, v19.s[3]\n" - "mla v29.4s, v20.4s, v18.s[0]\n" - "fmul v23.4s, v0.4s, v17.s[0]\n" - "mla v28.4s, v20.4s, v18.s[1]\n" - "mla v26.4s, v20.4s, v18.s[2]\n" - "fmul v22.4s, v0.4s, v17.s[1]\n" - "add x10, x10, #0x20\n" - "mla v25.4s, v20.4s, v18.s[3]\n" - "scvtf v2.4s, v2.4s\n" - "scvtf v1.4s, v1.4s\n" - "scvtf v31.4s, v31.4s\n" - "fmul v21.4s, v0.4s, v17.s[2]\n" - "scvtf v30.4s, v30.4s\n" - "fmul v20.4s, v0.4s, v17.s[3]\n" - "scvtf v29.4s, v29.4s\n" - "fmul v19.4s, v0.4s, v16.s[0]\n" - "scvtf v28.4s, v28.4s\n" - "fmul v18.4s, v0.4s, v16.s[1]\n" - "scvtf v26.4s, v26.4s\n" - "fmul v17.4s, v0.4s, v16.s[2]\n" - "scvtf v25.4s, v25.4s\n" - "fmul v16.4s, v0.4s, v16.s[3]\n" - "fmul v2.4s, v2.4s, v23.4s\n" - "fmul v1.4s, v1.4s, v22.4s\n" - "fmul v31.4s, v31.4s, v21.4s\n" - "fmul v30.4s, v30.4s, v20.4s\n" - "fmul v29.4s, v29.4s, v19.4s\n" - "fmul v28.4s, v28.4s, v18.4s\n" - "fmul v26.4s, v26.4s, v17.4s\n" - "fmul v25.4s, v25.4s, v16.4s\n" - "fmax v2.4s, v2.4s, v27.4s\n" - "fmax v1.4s, v1.4s, v27.4s\n" - "fmax v31.4s, v31.4s, v27.4s\n" - "fmax v30.4s, v30.4s, v27.4s\n" - "fmax v29.4s, v29.4s, v27.4s\n" - "fmax v28.4s, v28.4s, v27.4s\n" - "fmax v26.4s, v26.4s, v27.4s\n" - "fmax v25.4s, v25.4s, v27.4s\n" - "fmin v2.4s, v2.4s, v24.4s\n" - "fmin v1.4s, v1.4s, v24.4s\n" - "fmin v31.4s, v31.4s, v24.4s\n" - "fmin v30.4s, v30.4s, v24.4s\n" - "fmin v29.4s, v29.4s, v24.4s\n" - "fmin v28.4s, v28.4s, v24.4s\n" - "fmin v26.4s, v26.4s, v24.4s\n" - "fmin v25.4s, v25.4s, v24.4s\n" - "bge 6f\n" - "mov x27, %x[dst]\n" - "add x26, x27, %x[dst_stride_row], LSL #2\n" - "add x25, x26, %x[dst_stride_row], LSL #1\n" - "add x24, x26, %x[dst_stride_row]\n" - "add x23, x25, %x[dst_stride_row]\n" - "add x22, x27, %x[dst_stride_row], LSL #1\n" - "add x21, x27, %x[dst_stride_row]\n" - "add x20, x22, %x[dst_stride_row]\n" - "tbz x9, #1, 4f\n" - "str d25, [x23], #0x8\n" - "str d26, [x25], #0x8\n" - "str d28, [x24], #0x8\n" - "str d29, [x26], #0x8\n" - "str d30, [x20], #0x8\n" - "str d31, [x22], #0x8\n" - "str d1, [x21], #0x8\n" - "str d2, [x27], #0x8\n" - "tbz x9, #0, 5f\n" - "st1 { v25.s }[2], [x23]\n" - "st1 { v26.s }[2], [x25]\n" - "st1 { v28.s }[2], [x24]\n" - "st1 { v29.s }[2], [x26]\n" - "st1 { v30.s }[2], [x20]\n" - "st1 { v31.s }[2], [x22]\n" - "st1 { v1.s }[2], [x21]\n" - "st1 { v2.s }[2], [x27]\n" - "b 5f\n" - "4:" // Output block 0: partial_1_0 - "str s25, [x23, #0x0]\n" - "str s26, [x25, #0x0]\n" - "str s28, [x24, #0x0]\n" - "str s29, [x26, #0x0]\n" - "str s30, [x20, #0x0]\n" - "str s31, [x22, #0x0]\n" - "str s1, [x21, #0x0]\n" - "str s2, [x27, #0x0]\n" - "5:" // Output block 0: Done - "b 7f\n" - "6:" // Full output + "ld1r { v16.4s }, [x20]\n" + "add x10, x10, #0x10\n" + "fmax v10.4s, v10.4s, v17.4s\n" + "fmax v9.4s, v9.4s, v17.4s\n" + "fmax v8.4s, v8.4s, v17.4s\n" + "fmax v7.4s, v7.4s, v17.4s\n" + "fmax v6.4s, v6.4s, v17.4s\n" + "fmax v5.4s, v5.4s, v17.4s\n" + "fmax v4.4s, v4.4s, v17.4s\n" + "fmax v3.4s, v3.4s, v17.4s\n" + "fmin v10.4s, v10.4s, v16.4s\n" + "fmin v9.4s, v9.4s, v16.4s\n" + "fmin v8.4s, v8.4s, v16.4s\n" + "fmin v7.4s, v7.4s, v16.4s\n" + "fmin v6.4s, v6.4s, v16.4s\n" + "fmin v5.4s, v5.4s, v16.4s\n" + "fmin v4.4s, v4.4s, v16.4s\n" + "fmin v3.4s, v3.4s, v16.4s\n" + "blt 6f\n" "mov x20, %x[dst]\n" - "str q2, [x20, #0x0]\n" + "str q10, [x20, #0x0]\n" "add x20, x20, %x[dst_stride_row]\n" - "str q1, [x20, #0x0]\n" + "str q9, [x20, #0x0]\n" "add x20, x20, %x[dst_stride_row]\n" - "str q31, [x20, #0x0]\n" + "str q8, [x20, #0x0]\n" "add x20, x20, %x[dst_stride_row]\n" - "str q30, [x20, #0x0]\n" + "str q7, [x20, #0x0]\n" "add x20, x20, %x[dst_stride_row]\n" - "str q29, [x20, #0x0]\n" + "str q6, [x20, #0x0]\n" "add x20, x20, %x[dst_stride_row]\n" - "str q28, [x20, #0x0]\n" + "str q5, [x20, #0x0]\n" "add x20, x20, %x[dst_stride_row]\n" - "str q26, [x20, #0x0]\n" + "str q4, [x20, #0x0]\n" "add x20, x20, %x[dst_stride_row]\n" - "str q25, [x20, #0x0]\n" - "7:" // Output stage exit + "str q3, [x20, #0x0]\n" + "b 9f\n" + "6:" // Partial output + "mov x27, %x[dst]\n" + "add x26, x27, %x[dst_stride_row], LSL #2\n" + "add x25, x26, %x[dst_stride_row], LSL #1\n" + "add x24, x26, %x[dst_stride_row]\n" + "add x23, x25, %x[dst_stride_row]\n" + "add x22, x27, %x[dst_stride_row], LSL #1\n" + "add x21, x27, %x[dst_stride_row]\n" + "add x20, x22, %x[dst_stride_row]\n" + "tbz x9, #1, 7f\n" + "st1 { v3.d }[0], [x23], #0x8\n" + "st1 { v4.d }[0], [x25], #0x8\n" + "st1 { v5.d }[0], [x24], #0x8\n" + "st1 { v6.d }[0], [x26], #0x8\n" + "st1 { v7.d }[0], [x20], #0x8\n" + "st1 { v8.d }[0], [x22], #0x8\n" + "st1 { v9.d }[0], [x21], #0x8\n" + "st1 { v10.d }[0], [x27], #0x8\n" + "tbz x9, #0, 8f\n" + "st1 { v3.s }[2], [x23]\n" + "st1 { v4.s }[2], [x25]\n" + "st1 { v5.s }[2], [x24]\n" + "st1 { v6.s }[2], [x26]\n" + "st1 { v7.s }[2], [x20]\n" + "st1 { v8.s }[2], [x22]\n" + "st1 { v9.s }[2], [x21]\n" + "st1 { v10.s }[2], [x27]\n" + "b 8f\n" + "7:" // Output block 0: partial_1_0 + "st1 { v3.s }[0], [x23]\n" + "st1 { v4.s }[0], [x25]\n" + "st1 { v5.s }[0], [x24]\n" + "st1 { v6.s }[0], [x26]\n" + "st1 { v7.s }[0], [x20]\n" + "st1 { v8.s }[0], [x22]\n" + "st1 { v9.s }[0], [x21]\n" + "st1 { v10.s }[0], [x27]\n" + "8:" // Output block 0: Done + "9:" // Output stage exit "subs x9, x9, #0x4\n" "add %x[dst], %x[dst], #0x10\n" "bgt 2b\n" @@ -333,20 +337,20 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm( "mov %x[dst], x28\n" "madd %x[lhs_packed], x20, x11, %x[lhs_packed]\n" "bge 1b\n" - "8:" // Row loop skip - "cbz x12, 16f\n" - "9:" // Row tail: Row loop + "10:" // Row loop skip + "cbz x12, 19f\n" + "11:" // Row tail: Row loop "mov x26, %x[rhs_packed]\n" "mov x25, %x[n]\n" "add x24, %x[dst], %x[dst_stride_row], LSL #2\n" - "10:" // Row tail: Column loop + "12:" // Row tail: Column loop + "mov x22, %x[lhs_packed]\n" "movi v10.4s, #0x0\n" "movi v9.4s, #0x0\n" - "mov x22, %x[lhs_packed]\n" "mov x20, %x[num_blocks]\n" "movi v8.4s, #0x0\n" "movi v7.4s, #0x0\n" - "11:" // Row tail: Block loop + "13:" // Row tail: Sub block loop "ldr q31, [x26, #0x0]\n" "ldr q30, [x26, #0x10]\n" "subs x20, x20, #0x1\n" @@ -386,45 +390,63 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm( ".inst 0x4e9aa669 // smmla v9.4s, v19.16b, v26.16b\n" ".inst 0x4e9ba648 // smmla v8.4s, v18.16b, v27.16b\n" ".inst 0x4e9aa647 // smmla v7.4s, v18.16b, v26.16b\n" - "bgt 11b\n" + "bgt 13b\n" "ldr q18, [x26, #0x0]\n" - "ldr q17, [x22, #0x0]\n" - "uzp1 v26.2d, v10.2d, v9.2d\n" - "uzp2 v25.2d, v10.2d, v9.2d\n" - "ldr q24, [x26, #0x10]\n" - "ldr q16, [x22, #0x10]\n" - "uzp1 v23.2d, v8.2d, v7.2d\n" - "uzp2 v22.2d, v8.2d, v7.2d\n" - "ld1r { v21.4s }, [%x[clamp_vals]]\n" - "add x20, %x[clamp_vals], #0x4\n" - "cmp x25, #0x4\n" - "ld1r { v20.4s }, [x20]\n" - "mla v26.4s, v18.4s, v17.s[0]\n" - "mla v25.4s, v18.4s, v17.s[1]\n" + "ld1 { v17.4s }, [x22]\n" + "uzp1 v24.2d, v10.2d, v9.2d\n" + "uzp2 v23.2d, v10.2d, v9.2d\n" + "ldr q22, [x26, #0x10]\n" + "uzp1 v21.2d, v8.2d, v7.2d\n" + "uzp2 v20.2d, v8.2d, v7.2d\n" + "add x22, x22, #0x10\n" + "ldr q16, [x22, #0x0]\n" "add x26, x26, #0x20\n" - "mla v23.4s, v18.4s, v17.s[2]\n" - "mla v22.4s, v18.4s, v17.s[3]\n" - "fmul v19.4s, v24.4s, v16.s[0]\n" - "fmul v18.4s, v24.4s, v16.s[1]\n" - "fmul v17.4s, v24.4s, v16.s[2]\n" - "fmul v16.4s, v24.4s, v16.s[3]\n" - "scvtf v26.4s, v26.4s\n" - "scvtf v25.4s, v25.4s\n" + "mla v24.4s, v18.4s, v17.s[0]\n" + "mla v23.4s, v18.4s, v17.s[1]\n" + "mla v21.4s, v18.4s, v17.s[2]\n" + "mla v20.4s, v18.4s, v17.s[3]\n" + "fmul v19.4s, v22.4s, v16.s[0]\n" + "fmul v18.4s, v22.4s, v16.s[1]\n" + "fmul v17.4s, v22.4s, v16.s[2]\n" + "fmul v16.4s, v22.4s, v16.s[3]\n" + "scvtf v24.4s, v24.4s\n" "scvtf v23.4s, v23.4s\n" - "scvtf v22.4s, v22.4s\n" - "fmul v26.4s, v26.4s, v19.4s\n" - "fmul v25.4s, v25.4s, v18.4s\n" - "fmul v23.4s, v23.4s, v17.4s\n" - "fmul v22.4s, v22.4s, v16.4s\n" - "fmax v26.4s, v26.4s, v21.4s\n" - "fmax v25.4s, v25.4s, v21.4s\n" - "fmax v23.4s, v23.4s, v21.4s\n" - "fmax v22.4s, v22.4s, v21.4s\n" - "fmin v26.4s, v26.4s, v20.4s\n" - "fmin v25.4s, v25.4s, v20.4s\n" - "fmin v23.4s, v23.4s, v20.4s\n" - "fmin v22.4s, v22.4s, v20.4s\n" - "bge 14f\n" + "scvtf v21.4s, v21.4s\n" + "scvtf v20.4s, v20.4s\n" + "fmul v10.4s, v24.4s, v19.4s\n" + "fmul v9.4s, v23.4s, v18.4s\n" + "fmul v8.4s, v21.4s, v17.4s\n" + "fmul v7.4s, v20.4s, v16.4s\n" + "ld1r { v17.4s }, [%x[clamp_vals]]\n" + "add x20, %x[clamp_vals], #0x4\n" + "cmp x25, #0x4\n" + "ld1r { v16.4s }, [x20]\n" + "add x26, x26, #0x10\n" + "fmax v10.4s, v10.4s, v17.4s\n" + "fmax v9.4s, v9.4s, v17.4s\n" + "fmax v8.4s, v8.4s, v17.4s\n" + "fmax v7.4s, v7.4s, v17.4s\n" + "fmin v10.4s, v10.4s, v16.4s\n" + "fmin v9.4s, v9.4s, v16.4s\n" + "fmin v8.4s, v8.4s, v16.4s\n" + "fmin v7.4s, v7.4s, v16.4s\n" + "blt 15f\n" + "mov x20, %x[dst]\n" + "cmp x12, #0x1\n" + "str q10, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 18f\n" + "cmp x12, #0x2\n" + "str q9, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 18f\n" + "cmp x12, #0x3\n" + "str q8, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 18f\n" + "str q7, [x20, #0x0]\n" + "b 18f\n" + "15:" // Row tail: Partial output "mov x23, %x[dst]\n" "cmp x12, #0x1\n" "add x22, x23, %x[dst_stride_row]\n" @@ -435,51 +457,35 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm( "cmp x12, #0x3\n" "add x20, x21, %x[dst_stride_row]\n" "csel x20, x20, x21, GE\n" - "tbz x25, #1, 12f\n" - "str d22, [x20], #0x8\n" - "str d23, [x21], #0x8\n" - "str d25, [x22], #0x8\n" - "str d26, [x23], #0x8\n" - "tbz x25, #0, 13f\n" - "st1 { v22.s }[2], [x20]\n" - "st1 { v23.s }[2], [x21]\n" - "st1 { v25.s }[2], [x22]\n" - "st1 { v26.s }[2], [x23]\n" - "b 13f\n" - "12:" // Row tail: Output block 0: partial_1_0 - "str s22, [x20, #0x0]\n" - "str s23, [x21, #0x0]\n" - "str s25, [x22, #0x0]\n" - "str s26, [x23, #0x0]\n" - "13:" // Row tail: Output block 0: Done - "b 15f\n" - "14:" // Row tail: Full output - "mov x20, %x[dst]\n" - "cmp x12, #0x1\n" - "str q26, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "ble 15f\n" - "cmp x12, #0x2\n" - "str q25, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "ble 15f\n" - "cmp x12, #0x3\n" - "str q23, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "ble 15f\n" - "str q22, [x20, #0x0]\n" - "15:" // Row tail: Output stage exit + "tbz x25, #1, 16f\n" + "st1 { v7.d }[0], [x20], #0x8\n" + "st1 { v8.d }[0], [x21], #0x8\n" + "st1 { v9.d }[0], [x22], #0x8\n" + "st1 { v10.d }[0], [x23], #0x8\n" + "tbz x25, #0, 17f\n" + "st1 { v7.s }[2], [x20]\n" + "st1 { v8.s }[2], [x21]\n" + "st1 { v9.s }[2], [x22]\n" + "st1 { v10.s }[2], [x23]\n" + "b 17f\n" + "16:" // Row tail: Output block 0: partial_1_0 + "st1 { v7.s }[0], [x20]\n" + "st1 { v8.s }[0], [x21]\n" + "st1 { v9.s }[0], [x22]\n" + "st1 { v10.s }[0], [x23]\n" + "17:" // Row tail: Output block 0: Done + "18:" // Row tail: Output stage exit "subs x25, x25, #0x4\n" "add %x[dst], %x[dst], #0x10\n" - "bgt 10b\n" + "bgt 12b\n" "subs x12, x12, #0x4\n" "add %x[lhs_packed], %x[lhs_packed], x11\n" "mov %x[dst], x24\n" - "bgt 9b\n" - "16:" // Row tail: Row loop skip - : [lhs_packed] "+&r"(lhs_packed), [dst] "+&r"(dst) - : [rhs_packed] "r"(rhs_packed), [clamp_vals] "r"(clamp_vals), [m] "r"(m), [num_blocks] "r"(num_blocks), - [dst_stride_row] "r"(dst_stride_row), [n] "r"(n) + "bgt 11b\n" + "19:" // Row tail: Row loop skip + : [dst] "+&r"(dst), [lhs_packed] "+&r"(lhs_packed) + : [clamp_vals] "r"(clamp_vals), [dst_stride_row] "r"(dst_stride_row), [m] "r"(m), [n] "r"(n), + [num_blocks] "r"(num_blocks), [rhs_packed] "r"(rhs_packed) : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x11", "x12", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"); diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.c index 93c1a350..136248e0 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.c @@ -23,6 +23,7 @@ static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); +static const size_t kai_num_bytes_bias = sizeof(float); inline static size_t kai_k_roundedup(size_t k) { // Since we pack a float and int32 value at the end of the row, @@ -44,7 +45,7 @@ inline static size_t kai_rhs_packed_stride(size_t k) { KAI_ASSERT((k_internal % 2) == 0); - return kai_nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs); + return kai_nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias); } size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(void) { @@ -116,15 +117,15 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm( "movi v12.16b, #0xf0\n" "mov x27, %x[m]\n" "madd x28, %x[num_blocks], x28, x20\n" - "cbz x27, 10f\n" + "cbz x27, 11f\n" "1:" // Row loop "mov x26, %x[rhs_packed]\n" "mov x25, %x[n]\n" "add x24, %x[dst], %x[dst_stride_row], LSL #2\n" "2:" // Column loop + "mov x21, %x[lhs_packed]\n" "movi v11.4s, #0x0\n" "movi v10.4s, #0x0\n" - "mov x21, %x[lhs_packed]\n" "mov x20, %x[num_blocks]\n" "movi v9.4s, #0x0\n" "movi v8.4s, #0x0\n" @@ -132,7 +133,7 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm( "movi v6.4s, #0x0\n" "movi v5.4s, #0x0\n" "movi v4.4s, #0x0\n" - "3:" // Block loop + "3:" // Sub block loop "ldr q3, [x26, #0x0]\n" "ldr q2, [x26, #0x10]\n" "subs x20, x20, #0x1\n" @@ -163,11 +164,11 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm( ".inst 0x4e93a7c7 // smmla v7.4s, v30.16b, v19.16b\n" ".inst 0x4e92a7c5 // smmla v5.4s, v30.16b, v18.16b\n" "shl v19.16b, v29.16b, #0x4\n" - "add x21, x21, #0x80\n" + "add x26, x26, #0x80\n" ".inst 0x4e91a7c6 // smmla v6.4s, v30.16b, v17.16b\n" ".inst 0x4e90a7c4 // smmla v4.4s, v30.16b, v16.16b\n" "shl v18.16b, v28.16b, #0x4\n" - "add x26, x26, #0x80\n" + "add x21, x21, #0x80\n" "shl v17.16b, v27.16b, #0x4\n" "shl v16.16b, v26.16b, #0x4\n" ".inst 0x4e93a72b // smmla v11.4s, v25.16b, v19.16b\n" @@ -203,72 +204,94 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm( "bgt 3b\n" "ldr q20, [x26, #0x0]\n" "ldr q19, [x26, #0x10]\n" - "uzp1 v2.2d, v11.2d, v9.2d\n" - "uzp1 v1.2d, v10.2d, v8.2d\n" - "ldr q18, [x21, #0x0]\n" + "uzp1 v0.2d, v11.2d, v9.2d\n" + "uzp2 v31.2d, v11.2d, v9.2d\n" + "ld1 { v18.4s }, [x21]\n" "ldr q17, [x26, #0x20]\n" - "uzp2 v0.2d, v11.2d, v9.2d\n" - "uzp2 v31.2d, v10.2d, v8.2d\n" - "ldr q30, [x26, #0x30]\n" - "ldr q16, [x21, #0x10]\n" - "uzp1 v29.2d, v7.2d, v5.2d\n" - "uzp1 v28.2d, v6.2d, v4.2d\n" - "ld1r { v27.4s }, [%x[clamp_vals]]\n" + "uzp1 v30.2d, v10.2d, v8.2d\n" + "uzp2 v29.2d, v10.2d, v8.2d\n" + "ldr q28, [x26, #0x30]\n" + "uzp1 v27.2d, v7.2d, v5.2d\n" "uzp2 v26.2d, v7.2d, v5.2d\n" - "uzp2 v25.2d, v6.2d, v4.2d\n" - "add x20, %x[clamp_vals], #0x4\n" - "ld1r { v24.4s }, [x20]\n" - "mla v2.4s, v20.4s, v18.s[0]\n" - "mla v1.4s, v19.4s, v18.s[0]\n" - "cmp x25, #0x8\n" - "mla v0.4s, v20.4s, v18.s[1]\n" - "mla v31.4s, v19.4s, v18.s[1]\n" - "fmul v23.4s, v17.4s, v16.s[0]\n" + "add x21, x21, #0x10\n" + "ldr q16, [x21, #0x0]\n" + "uzp1 v25.2d, v6.2d, v4.2d\n" + "uzp2 v24.2d, v6.2d, v4.2d\n" "add x26, x26, #0x40\n" - "mla v29.4s, v20.4s, v18.s[2]\n" - "mla v28.4s, v19.4s, v18.s[2]\n" - "fmul v22.4s, v30.4s, v16.s[0]\n" + "mla v0.4s, v20.4s, v18.s[0]\n" + "mla v30.4s, v19.4s, v18.s[0]\n" + "mla v31.4s, v20.4s, v18.s[1]\n" + "mla v29.4s, v19.4s, v18.s[1]\n" + "mla v27.4s, v20.4s, v18.s[2]\n" + "mla v25.4s, v19.4s, v18.s[2]\n" + "fmul v23.4s, v17.4s, v16.s[0]\n" "mla v26.4s, v20.4s, v18.s[3]\n" - "mla v25.4s, v19.4s, v18.s[3]\n" - "fmul v21.4s, v17.4s, v16.s[1]\n" - "scvtf v2.4s, v2.4s\n" - "scvtf v1.4s, v1.4s\n" + "mla v24.4s, v19.4s, v18.s[3]\n" + "fmul v22.4s, v28.4s, v16.s[0]\n" "scvtf v0.4s, v0.4s\n" + "scvtf v30.4s, v30.4s\n" + "fmul v21.4s, v17.4s, v16.s[1]\n" "scvtf v31.4s, v31.4s\n" - "fmul v20.4s, v30.4s, v16.s[1]\n" + "fmul v20.4s, v28.4s, v16.s[1]\n" "scvtf v29.4s, v29.4s\n" "fmul v19.4s, v17.4s, v16.s[2]\n" - "scvtf v28.4s, v28.4s\n" - "fmul v18.4s, v30.4s, v16.s[2]\n" - "scvtf v26.4s, v26.4s\n" - "fmul v17.4s, v17.4s, v16.s[3]\n" + "scvtf v27.4s, v27.4s\n" + "fmul v18.4s, v28.4s, v16.s[2]\n" "scvtf v25.4s, v25.4s\n" - "fmul v16.4s, v30.4s, v16.s[3]\n" - "fmul v2.4s, v2.4s, v23.4s\n" - "fmul v1.4s, v1.4s, v22.4s\n" - "fmul v0.4s, v0.4s, v21.4s\n" - "fmul v31.4s, v31.4s, v20.4s\n" - "fmul v29.4s, v29.4s, v19.4s\n" - "fmul v28.4s, v28.4s, v18.4s\n" - "fmul v26.4s, v26.4s, v17.4s\n" - "fmul v25.4s, v25.4s, v16.4s\n" - "fmax v2.4s, v2.4s, v27.4s\n" - "fmax v1.4s, v1.4s, v27.4s\n" - "fmax v0.4s, v0.4s, v27.4s\n" - "fmax v31.4s, v31.4s, v27.4s\n" - "fmax v29.4s, v29.4s, v27.4s\n" - "fmax v28.4s, v28.4s, v27.4s\n" - "fmax v26.4s, v26.4s, v27.4s\n" - "fmax v25.4s, v25.4s, v27.4s\n" - "fmin v2.4s, v2.4s, v24.4s\n" - "fmin v1.4s, v1.4s, v24.4s\n" - "fmin v0.4s, v0.4s, v24.4s\n" - "fmin v31.4s, v31.4s, v24.4s\n" - "fmin v29.4s, v29.4s, v24.4s\n" - "fmin v28.4s, v28.4s, v24.4s\n" - "fmin v26.4s, v26.4s, v24.4s\n" - "fmin v25.4s, v25.4s, v24.4s\n" - "bge 8f\n" + "fmul v17.4s, v17.4s, v16.s[3]\n" + "scvtf v26.4s, v26.4s\n" + "fmul v16.4s, v28.4s, v16.s[3]\n" + "scvtf v24.4s, v24.4s\n" + "fmul v11.4s, v0.4s, v23.4s\n" + "fmul v10.4s, v30.4s, v22.4s\n" + "fmul v9.4s, v31.4s, v21.4s\n" + "fmul v8.4s, v29.4s, v20.4s\n" + "fmul v7.4s, v27.4s, v19.4s\n" + "fmul v6.4s, v25.4s, v18.4s\n" + "fmul v5.4s, v26.4s, v17.4s\n" + "fmul v4.4s, v24.4s, v16.4s\n" + "ld1r { v17.4s }, [%x[clamp_vals]]\n" + "add x20, %x[clamp_vals], #0x4\n" + "cmp x25, #0x8\n" + "ld1r { v16.4s }, [x20]\n" + "add x26, x26, #0x20\n" + "fmax v11.4s, v11.4s, v17.4s\n" + "fmax v10.4s, v10.4s, v17.4s\n" + "fmax v9.4s, v9.4s, v17.4s\n" + "fmax v8.4s, v8.4s, v17.4s\n" + "fmax v7.4s, v7.4s, v17.4s\n" + "fmax v6.4s, v6.4s, v17.4s\n" + "fmax v5.4s, v5.4s, v17.4s\n" + "fmax v4.4s, v4.4s, v17.4s\n" + "fmin v11.4s, v11.4s, v16.4s\n" + "fmin v10.4s, v10.4s, v16.4s\n" + "fmin v9.4s, v9.4s, v16.4s\n" + "fmin v8.4s, v8.4s, v16.4s\n" + "fmin v7.4s, v7.4s, v16.4s\n" + "fmin v6.4s, v6.4s, v16.4s\n" + "fmin v5.4s, v5.4s, v16.4s\n" + "fmin v4.4s, v4.4s, v16.4s\n" + "blt 5f\n" + "mov x20, %x[dst]\n" + "cmp x27, #0x1\n" + "str q11, [x20, #0x0]\n" + "str q10, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 10f\n" + "cmp x27, #0x2\n" + "str q9, [x20, #0x0]\n" + "str q8, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 10f\n" + "cmp x27, #0x3\n" + "str q7, [x20, #0x0]\n" + "str q6, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 10f\n" + "str q5, [x20, #0x0]\n" + "str q4, [x20, #0x10]\n" + "b 10f\n" + "5:" // Partial output "mov x23, %x[dst]\n" "cmp x27, #0x1\n" "add x22, x23, %x[dst_stride_row]\n" @@ -279,68 +302,48 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm( "cmp x27, #0x3\n" "add x20, x21, %x[dst_stride_row]\n" "csel x20, x20, x21, GE\n" - "tbz x25, #2, 5f\n" - "st1 { v26.4s }, [x20], #0x10\n" - "st1 { v29.4s }, [x21], #0x10\n" - "st1 { v0.4s }, [x22], #0x10\n" - "st1 { v2.4s }, [x23], #0x10\n" - "tbz x25, #1, 4f\n" - "str d25, [x20], #0x8\n" - "str d28, [x21], #0x8\n" - "str d31, [x22], #0x8\n" - "str d1, [x23], #0x8\n" - "tbz x25, #0, 7f\n" - "st1 { v25.s }[2], [x20]\n" - "st1 { v28.s }[2], [x21]\n" - "st1 { v31.s }[2], [x22]\n" - "st1 { v1.s }[2], [x23]\n" - "b 7f\n" - "4:" // Output block 0: partial_1_4 - "tbz x25, #0, 7f\n" - "str s25, [x20, #0x0]\n" - "str s28, [x21, #0x0]\n" - "str s31, [x22, #0x0]\n" - "str s1, [x23, #0x0]\n" - "b 7f\n" - "5:" // Output block 0: partial_2_0 + "tbz x25, #2, 7f\n" + "st1 { v5.4s }, [x20], #0x10\n" + "st1 { v7.4s }, [x21], #0x10\n" + "st1 { v9.4s }, [x22], #0x10\n" + "st1 { v11.4s }, [x23], #0x10\n" "tbz x25, #1, 6f\n" - "str d26, [x20], #0x8\n" - "str d29, [x21], #0x8\n" - "str d0, [x22], #0x8\n" - "str d2, [x23], #0x8\n" - "tbz x25, #0, 7f\n" - "st1 { v26.s }[2], [x20]\n" - "st1 { v29.s }[2], [x21]\n" - "st1 { v0.s }[2], [x22]\n" - "st1 { v2.s }[2], [x23]\n" - "b 7f\n" - "6:" // Output block 0: partial_1_0 - "str s26, [x20, #0x0]\n" - "str s29, [x21, #0x0]\n" - "str s0, [x22, #0x0]\n" - "str s2, [x23, #0x0]\n" - "7:" // Output block 0: Done + "st1 { v4.d }[0], [x20], #0x8\n" + "st1 { v6.d }[0], [x21], #0x8\n" + "st1 { v8.d }[0], [x22], #0x8\n" + "st1 { v10.d }[0], [x23], #0x8\n" + "tbz x25, #0, 9f\n" + "st1 { v4.s }[2], [x20]\n" + "st1 { v6.s }[2], [x21]\n" + "st1 { v8.s }[2], [x22]\n" + "st1 { v10.s }[2], [x23]\n" "b 9f\n" - "8:" // Full output - "mov x20, %x[dst]\n" - "cmp x27, #0x1\n" - "str q2, [x20, #0x0]\n" - "str q1, [x20, #0x10]\n" - "add x20, x20, %x[dst_stride_row]\n" - "ble 9f\n" - "cmp x27, #0x2\n" - "str q0, [x20, #0x0]\n" - "str q31, [x20, #0x10]\n" - "add x20, x20, %x[dst_stride_row]\n" - "ble 9f\n" - "cmp x27, #0x3\n" - "str q29, [x20, #0x0]\n" - "str q28, [x20, #0x10]\n" - "add x20, x20, %x[dst_stride_row]\n" - "ble 9f\n" - "str q26, [x20, #0x0]\n" - "str q25, [x20, #0x10]\n" - "9:" // Output stage exit + "6:" // Output block 0: partial_1_4 + "tbz x25, #0, 9f\n" + "st1 { v4.s }[0], [x20]\n" + "st1 { v6.s }[0], [x21]\n" + "st1 { v8.s }[0], [x22]\n" + "st1 { v10.s }[0], [x23]\n" + "b 9f\n" + "7:" // Output block 0: partial_2_0 + "tbz x25, #1, 8f\n" + "st1 { v5.d }[0], [x20], #0x8\n" + "st1 { v7.d }[0], [x21], #0x8\n" + "st1 { v9.d }[0], [x22], #0x8\n" + "st1 { v11.d }[0], [x23], #0x8\n" + "tbz x25, #0, 9f\n" + "st1 { v5.s }[2], [x20]\n" + "st1 { v7.s }[2], [x21]\n" + "st1 { v9.s }[2], [x22]\n" + "st1 { v11.s }[2], [x23]\n" + "b 9f\n" + "8:" // Output block 0: partial_1_0 + "st1 { v5.s }[0], [x20]\n" + "st1 { v7.s }[0], [x21]\n" + "st1 { v9.s }[0], [x22]\n" + "st1 { v11.s }[0], [x23]\n" + "9:" // Output block 0: Done + "10:" // Output stage exit "subs x25, x25, #0x8\n" "add %x[dst], %x[dst], #0x20\n" "bgt 2b\n" @@ -348,10 +351,10 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm( "add %x[lhs_packed], %x[lhs_packed], x28\n" "mov %x[dst], x24\n" "bgt 1b\n" - "10:" // Row loop skip - : [lhs_packed] "+&r"(lhs_packed), [dst] "+&r"(dst) - : [rhs_packed] "r"(rhs_packed), [clamp_vals] "r"(clamp_vals), [m] "r"(m), [num_blocks] "r"(num_blocks), - [dst_stride_row] "r"(dst_stride_row), [n] "r"(n) + "11:" // Row loop skip + : [dst] "+&r"(dst), [lhs_packed] "+&r"(lhs_packed) + : [clamp_vals] "r"(clamp_vals), [dst_stride_row] "r"(dst_stride_row), [m] "r"(m), [n] "r"(n), + [num_blocks] "r"(num_blocks), [rhs_packed] "r"(rhs_packed) : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"); diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.c index 2b2b1e09..277d6e5f 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.c @@ -23,6 +23,7 @@ static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); +static const size_t kai_num_bytes_bias = sizeof(float); inline static size_t kai_k_roundedup(size_t k) { // Since we pack a float and int32 value at the end of the row, @@ -44,7 +45,7 @@ inline static size_t kai_rhs_packed_stride(size_t k) { KAI_ASSERT((k_internal % 2) == 0); - return kai_nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs); + return kai_nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias); } size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm(void) { @@ -113,375 +114,378 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm( __asm__ __volatile__( "mov x12, %x[m]\n" "mov x11, #0x80\n" - "movi v3.16b, #0xf0\n" + "movi v25.16b, #0xf0\n" "mov x20, #0x20\n" "cmp x12, #0x8\n" "madd x11, %x[num_blocks], x11, x20\n" - "blt 10f\n" + "blt 12f\n" "1:" // Row loop "mov x10, %x[rhs_packed]\n" "mov x9, %x[n]\n" "add x28, %x[dst], %x[dst_stride_row], LSL #3\n" "2:" // Column loop "mov x22, %x[lhs_packed]\n" - "movi v13.4s, #0x0\n" - "movi v14.4s, #0x0\n" + "movi v29.4s, #0x0\n" + "movi v12.4s, #0x0\n" "mov x21, %x[num_blocks]\n" - "movi v25.4s, #0x0\n" - "movi v16.4s, #0x0\n" - "movi v26.4s, #0x0\n" - "movi v30.4s, #0x0\n" - "movi v10.4s, #0x0\n" - "movi v19.4s, #0x0\n" + "movi v18.4s, #0x0\n" + "movi v4.4s, #0x0\n" + "movi v11.4s, #0x0\n" + "movi v14.4s, #0x0\n" "add x20, x22, x11\n" "movi v24.4s, #0x0\n" - "movi v0.4s, #0x0\n" - "movi v28.4s, #0x0\n" - "movi v15.4s, #0x0\n" - "movi v29.4s, #0x0\n" - "movi v11.4s, #0x0\n" + "movi v1.4s, #0x0\n" + "movi v19.4s, #0x0\n" + "movi v30.4s, #0x0\n" "movi v31.4s, #0x0\n" + "movi v26.4s, #0x0\n" + "movi v5.4s, #0x0\n" + "movi v13.4s, #0x0\n" "movi v7.4s, #0x0\n" - "3:" // Block loop - "ldr q21, [x10, #0x0]\n" - "ldr q20, [x10, #0x10]\n" + "movi v15.4s, #0x0\n" + "3:" // Sub block loop + "ldr q0, [x10, #0x0]\n" + "ldr q17, [x10, #0x10]\n" "subs x21, x21, #0x1\n" - "ldr q2, [x10, #0x20]\n" - "ldr q23, [x10, #0x30]\n" - "ldr q8, [x22, #0x0]\n" - "ldr q1, [x22, #0x10]\n" - "ldr q12, [x20, #0x0]\n" - "ldr q6, [x20, #0x10]\n" - "shl v17.16b, v21.16b, #0x4\n" - "shl v22.16b, v20.16b, #0x4\n" - "ldr q9, [x10, #0x40]\n" - "ldr q18, [x10, #0x50]\n" - "shl v4.16b, v2.16b, #0x4\n" - "shl v5.16b, v23.16b, #0x4\n" - "ldr q27, [x10, #0x60]\n" - "and v21.16b, v21.16b, v3.16b\n" - "and v20.16b, v20.16b, v3.16b\n" - ".inst 0x4e91a50d // smmla v13.4s, v8.16b, v17.16b\n" - ".inst 0x4e96a519 // smmla v25.4s, v8.16b, v22.16b\n" - ".inst 0x4e91a43a // smmla v26.4s, v1.16b, v17.16b\n" - "and v2.16b, v2.16b, v3.16b\n" - ".inst 0x4e84a50e // smmla v14.4s, v8.16b, v4.16b\n" - ".inst 0x4e85a510 // smmla v16.4s, v8.16b, v5.16b\n" - "ldr q8, [x10, #0x70]\n" - "and v23.16b, v23.16b, v3.16b\n" - ".inst 0x4e96a42a // smmla v10.4s, v1.16b, v22.16b\n" - ".inst 0x4e84a43e // smmla v30.4s, v1.16b, v4.16b\n" + "ldr q10, [x10, #0x20]\n" + "ldr q8, [x10, #0x30]\n" + "ldr q9, [x22, #0x0]\n" + "ldr q20, [x22, #0x10]\n" + "ldr q2, [x20, #0x0]\n" + "ldr q3, [x20, #0x10]\n" + "shl v23.16b, v0.16b, #0x4\n" + "shl v21.16b, v17.16b, #0x4\n" + "ldr q27, [x10, #0x40]\n" + "ldr q6, [x10, #0x50]\n" + "shl v16.16b, v10.16b, #0x4\n" + "shl v22.16b, v8.16b, #0x4\n" + "ldr q28, [x10, #0x60]\n" + "and v0.16b, v0.16b, v25.16b\n" + "and v17.16b, v17.16b, v25.16b\n" + ".inst 0x4e97a53d // smmla v29.4s, v9.16b, v23.16b\n" + ".inst 0x4e95a532 // smmla v18.4s, v9.16b, v21.16b\n" + ".inst 0x4e97a68b // smmla v11.4s, v20.16b, v23.16b\n" + "and v10.16b, v10.16b, v25.16b\n" + ".inst 0x4e90a52c // smmla v12.4s, v9.16b, v16.16b\n" + ".inst 0x4e96a524 // smmla v4.4s, v9.16b, v22.16b\n" + "ldr q9, [x10, #0x70]\n" + "and v8.16b, v8.16b, v25.16b\n" + ".inst 0x4e95a698 // smmla v24.4s, v20.16b, v21.16b\n" + ".inst 0x4e90a68e // smmla v14.4s, v20.16b, v16.16b\n" "add x10, x10, #0x80\n" - ".inst 0x4e85a433 // smmla v19.4s, v1.16b, v5.16b\n" - "ldr q1, [x22, #0x20]\n" - ".inst 0x4e91a598 // smmla v24.4s, v12.16b, v17.16b\n" - ".inst 0x4e96a59c // smmla v28.4s, v12.16b, v22.16b\n" - ".inst 0x4e84a580 // smmla v0.4s, v12.16b, v4.16b\n" - ".inst 0x4e85a58f // smmla v15.4s, v12.16b, v5.16b\n" - "ldr q12, [x22, #0x30]\n" - ".inst 0x4e91a4dd // smmla v29.4s, v6.16b, v17.16b\n" - "ldr q17, [x20, #0x20]\n" - ".inst 0x4e96a4df // smmla v31.4s, v6.16b, v22.16b\n" - "ldr q22, [x20, #0x30]\n" - ".inst 0x4e84a4cb // smmla v11.4s, v6.16b, v4.16b\n" - "ldr q4, [x22, #0x40]\n" - ".inst 0x4e85a4c7 // smmla v7.4s, v6.16b, v5.16b\n" - "ldr q5, [x22, #0x50]\n" - "shl v6.16b, v9.16b, #0x4\n" - "and v9.16b, v9.16b, v3.16b\n" - ".inst 0x4e86a42d // smmla v13.4s, v1.16b, v6.16b\n" - ".inst 0x4e86a59a // smmla v26.4s, v12.16b, v6.16b\n" - ".inst 0x4e86a638 // smmla v24.4s, v17.16b, v6.16b\n" - ".inst 0x4e86a6dd // smmla v29.4s, v22.16b, v6.16b\n" - "shl v6.16b, v18.16b, #0x4\n" - "and v18.16b, v18.16b, v3.16b\n" - ".inst 0x4e86a439 // smmla v25.4s, v1.16b, v6.16b\n" - ".inst 0x4e86a58a // smmla v10.4s, v12.16b, v6.16b\n" - ".inst 0x4e86a63c // smmla v28.4s, v17.16b, v6.16b\n" - ".inst 0x4e86a6df // smmla v31.4s, v22.16b, v6.16b\n" - "shl v6.16b, v27.16b, #0x4\n" - ".inst 0x4e95a48d // smmla v13.4s, v4.16b, v21.16b\n" - ".inst 0x4e95a4ba // smmla v26.4s, v5.16b, v21.16b\n" - "and v27.16b, v27.16b, v3.16b\n" - ".inst 0x4e86a42e // smmla v14.4s, v1.16b, v6.16b\n" - ".inst 0x4e86a59e // smmla v30.4s, v12.16b, v6.16b\n" - ".inst 0x4e86a620 // smmla v0.4s, v17.16b, v6.16b\n" - ".inst 0x4e86a6cb // smmla v11.4s, v22.16b, v6.16b\n" - "shl v6.16b, v8.16b, #0x4\n" - ".inst 0x4e94a499 // smmla v25.4s, v4.16b, v20.16b\n" - ".inst 0x4e94a4aa // smmla v10.4s, v5.16b, v20.16b\n" - "and v8.16b, v8.16b, v3.16b\n" - ".inst 0x4e86a430 // smmla v16.4s, v1.16b, v6.16b\n" - "ldr q1, [x20, #0x40]\n" - ".inst 0x4e86a593 // smmla v19.4s, v12.16b, v6.16b\n" - "ldr q12, [x20, #0x50]\n" - ".inst 0x4e86a62f // smmla v15.4s, v17.16b, v6.16b\n" - "ldr q17, [x22, #0x60]\n" - ".inst 0x4e86a6c7 // smmla v7.4s, v22.16b, v6.16b\n" + ".inst 0x4e96a681 // smmla v1.4s, v20.16b, v22.16b\n" + "ldr q20, [x22, #0x20]\n" + ".inst 0x4e97a453 // smmla v19.4s, v2.16b, v23.16b\n" + ".inst 0x4e95a45f // smmla v31.4s, v2.16b, v21.16b\n" + ".inst 0x4e90a45e // smmla v30.4s, v2.16b, v16.16b\n" + ".inst 0x4e96a45a // smmla v26.4s, v2.16b, v22.16b\n" + "ldr q2, [x22, #0x30]\n" + ".inst 0x4e97a465 // smmla v5.4s, v3.16b, v23.16b\n" + "ldr q23, [x20, #0x20]\n" + ".inst 0x4e95a467 // smmla v7.4s, v3.16b, v21.16b\n" + "ldr q21, [x20, #0x30]\n" + ".inst 0x4e90a46d // smmla v13.4s, v3.16b, v16.16b\n" + "ldr q16, [x22, #0x40]\n" + ".inst 0x4e96a46f // smmla v15.4s, v3.16b, v22.16b\n" + "ldr q3, [x22, #0x50]\n" + "shl v22.16b, v27.16b, #0x4\n" + "and v27.16b, v27.16b, v25.16b\n" + ".inst 0x4e96a69d // smmla v29.4s, v20.16b, v22.16b\n" + ".inst 0x4e96a44b // smmla v11.4s, v2.16b, v22.16b\n" + ".inst 0x4e96a6f3 // smmla v19.4s, v23.16b, v22.16b\n" + ".inst 0x4e96a6a5 // smmla v5.4s, v21.16b, v22.16b\n" + "shl v22.16b, v6.16b, #0x4\n" + "and v6.16b, v6.16b, v25.16b\n" + ".inst 0x4e96a692 // smmla v18.4s, v20.16b, v22.16b\n" + ".inst 0x4e96a458 // smmla v24.4s, v2.16b, v22.16b\n" + ".inst 0x4e96a6ff // smmla v31.4s, v23.16b, v22.16b\n" + ".inst 0x4e96a6a7 // smmla v7.4s, v21.16b, v22.16b\n" + "shl v22.16b, v28.16b, #0x4\n" + ".inst 0x4e80a61d // smmla v29.4s, v16.16b, v0.16b\n" + ".inst 0x4e80a46b // smmla v11.4s, v3.16b, v0.16b\n" + "and v28.16b, v28.16b, v25.16b\n" + ".inst 0x4e96a68c // smmla v12.4s, v20.16b, v22.16b\n" + ".inst 0x4e96a44e // smmla v14.4s, v2.16b, v22.16b\n" + ".inst 0x4e96a6fe // smmla v30.4s, v23.16b, v22.16b\n" + ".inst 0x4e96a6ad // smmla v13.4s, v21.16b, v22.16b\n" + "shl v22.16b, v9.16b, #0x4\n" + ".inst 0x4e91a612 // smmla v18.4s, v16.16b, v17.16b\n" + ".inst 0x4e91a478 // smmla v24.4s, v3.16b, v17.16b\n" + "and v9.16b, v9.16b, v25.16b\n" + ".inst 0x4e96a684 // smmla v4.4s, v20.16b, v22.16b\n" + "ldr q20, [x20, #0x40]\n" + ".inst 0x4e96a441 // smmla v1.4s, v2.16b, v22.16b\n" + "ldr q2, [x20, #0x50]\n" + ".inst 0x4e96a6fa // smmla v26.4s, v23.16b, v22.16b\n" + "ldr q23, [x22, #0x60]\n" + ".inst 0x4e96a6af // smmla v15.4s, v21.16b, v22.16b\n" "ldr q22, [x22, #0x70]\n" - "ldr q6, [x20, #0x60]\n" - ".inst 0x4e82a48e // smmla v14.4s, v4.16b, v2.16b\n" - ".inst 0x4e82a4be // smmla v30.4s, v5.16b, v2.16b\n" + "ldr q21, [x20, #0x60]\n" + ".inst 0x4e8aa60c // smmla v12.4s, v16.16b, v10.16b\n" + ".inst 0x4e8aa46e // smmla v14.4s, v3.16b, v10.16b\n" "add x22, x22, #0x80\n" - ".inst 0x4e95a438 // smmla v24.4s, v1.16b, v21.16b\n" - ".inst 0x4e94a43c // smmla v28.4s, v1.16b, v20.16b\n" - ".inst 0x4e97a490 // smmla v16.4s, v4.16b, v23.16b\n" - "ldr q4, [x20, #0x70]\n" - ".inst 0x4e97a4b3 // smmla v19.4s, v5.16b, v23.16b\n" + ".inst 0x4e80a693 // smmla v19.4s, v20.16b, v0.16b\n" + ".inst 0x4e91a69f // smmla v31.4s, v20.16b, v17.16b\n" + ".inst 0x4e88a604 // smmla v4.4s, v16.16b, v8.16b\n" + "ldr q16, [x20, #0x70]\n" + ".inst 0x4e88a461 // smmla v1.4s, v3.16b, v8.16b\n" "add x20, x20, #0x80\n" - ".inst 0x4e82a420 // smmla v0.4s, v1.16b, v2.16b\n" - ".inst 0x4e97a42f // smmla v15.4s, v1.16b, v23.16b\n" - ".inst 0x4e95a59d // smmla v29.4s, v12.16b, v21.16b\n" - ".inst 0x4e94a59f // smmla v31.4s, v12.16b, v20.16b\n" - ".inst 0x4e82a58b // smmla v11.4s, v12.16b, v2.16b\n" - ".inst 0x4e97a587 // smmla v7.4s, v12.16b, v23.16b\n" - ".inst 0x4e89a62d // smmla v13.4s, v17.16b, v9.16b\n" - ".inst 0x4e92a639 // smmla v25.4s, v17.16b, v18.16b\n" - ".inst 0x4e9ba62e // smmla v14.4s, v17.16b, v27.16b\n" - ".inst 0x4e88a630 // smmla v16.4s, v17.16b, v8.16b\n" - ".inst 0x4e89a6da // smmla v26.4s, v22.16b, v9.16b\n" - ".inst 0x4e92a6ca // smmla v10.4s, v22.16b, v18.16b\n" - ".inst 0x4e9ba6de // smmla v30.4s, v22.16b, v27.16b\n" - ".inst 0x4e88a6d3 // smmla v19.4s, v22.16b, v8.16b\n" - ".inst 0x4e89a4d8 // smmla v24.4s, v6.16b, v9.16b\n" - ".inst 0x4e92a4dc // smmla v28.4s, v6.16b, v18.16b\n" - ".inst 0x4e9ba4c0 // smmla v0.4s, v6.16b, v27.16b\n" - ".inst 0x4e88a4cf // smmla v15.4s, v6.16b, v8.16b\n" - ".inst 0x4e89a49d // smmla v29.4s, v4.16b, v9.16b\n" - ".inst 0x4e92a49f // smmla v31.4s, v4.16b, v18.16b\n" - ".inst 0x4e9ba48b // smmla v11.4s, v4.16b, v27.16b\n" - ".inst 0x4e88a487 // smmla v7.4s, v4.16b, v8.16b\n" + ".inst 0x4e8aa69e // smmla v30.4s, v20.16b, v10.16b\n" + ".inst 0x4e88a69a // smmla v26.4s, v20.16b, v8.16b\n" + ".inst 0x4e80a445 // smmla v5.4s, v2.16b, v0.16b\n" + ".inst 0x4e91a447 // smmla v7.4s, v2.16b, v17.16b\n" + ".inst 0x4e8aa44d // smmla v13.4s, v2.16b, v10.16b\n" + ".inst 0x4e88a44f // smmla v15.4s, v2.16b, v8.16b\n" + ".inst 0x4e9ba6fd // smmla v29.4s, v23.16b, v27.16b\n" + ".inst 0x4e86a6f2 // smmla v18.4s, v23.16b, v6.16b\n" + ".inst 0x4e9ca6ec // smmla v12.4s, v23.16b, v28.16b\n" + ".inst 0x4e89a6e4 // smmla v4.4s, v23.16b, v9.16b\n" + ".inst 0x4e9ba6cb // smmla v11.4s, v22.16b, v27.16b\n" + ".inst 0x4e86a6d8 // smmla v24.4s, v22.16b, v6.16b\n" + ".inst 0x4e9ca6ce // smmla v14.4s, v22.16b, v28.16b\n" + ".inst 0x4e89a6c1 // smmla v1.4s, v22.16b, v9.16b\n" + ".inst 0x4e9ba6b3 // smmla v19.4s, v21.16b, v27.16b\n" + ".inst 0x4e86a6bf // smmla v31.4s, v21.16b, v6.16b\n" + ".inst 0x4e9ca6be // smmla v30.4s, v21.16b, v28.16b\n" + ".inst 0x4e89a6ba // smmla v26.4s, v21.16b, v9.16b\n" + ".inst 0x4e9ba605 // smmla v5.4s, v16.16b, v27.16b\n" + ".inst 0x4e86a607 // smmla v7.4s, v16.16b, v6.16b\n" + ".inst 0x4e9ca60d // smmla v13.4s, v16.16b, v28.16b\n" + ".inst 0x4e89a60f // smmla v15.4s, v16.16b, v9.16b\n" "bgt 3b\n" - "ldr q18, [x10, #0x0]\n" - "ldr q2, [x10, #0x10]\n" - "uzp1 v4.2d, v13.2d, v25.2d\n" - "uzp1 v5.2d, v14.2d, v16.2d\n" - "ldr q22, [x22, #0x0]\n" - "ldr q27, [x20, #0x0]\n" - "uzp2 v1.2d, v13.2d, v25.2d\n" - "uzp2 v20.2d, v14.2d, v16.2d\n" - "ldr q17, [x10, #0x20]\n" - "ldr q6, [x10, #0x30]\n" - "uzp1 v9.2d, v26.2d, v10.2d\n" - "uzp1 v13.2d, v30.2d, v19.2d\n" - "ldr q23, [x22, #0x10]\n" - "ldr q12, [x20, #0x10]\n" - "uzp2 v21.2d, v26.2d, v10.2d\n" - "uzp2 v25.2d, v30.2d, v19.2d\n" - "ld1r { v8.4s }, [%x[clamp_vals]]\n" - "uzp1 v16.2d, v24.2d, v28.2d\n" - "uzp1 v10.2d, v0.2d, v15.2d\n" - "mla v4.4s, v18.4s, v22.s[0]\n" - "uzp2 v30.2d, v24.2d, v28.2d\n" - "uzp2 v28.2d, v0.2d, v15.2d\n" - "mla v5.4s, v2.4s, v22.s[0]\n" - "add x20, %x[clamp_vals], #0x4\n" - "ld1r { v24.4s }, [x20]\n" - "uzp1 v14.2d, v29.2d, v31.2d\n" - "uzp1 v26.2d, v11.2d, v7.2d\n" - "mla v1.4s, v18.4s, v22.s[1]\n" - "uzp2 v0.2d, v29.2d, v31.2d\n" - "uzp2 v11.2d, v11.2d, v7.2d\n" - "mla v20.4s, v2.4s, v22.s[1]\n" - "cmp x9, #0x8\n" - "mla v9.4s, v18.4s, v22.s[2]\n" - "mla v13.4s, v2.4s, v22.s[2]\n" - "scvtf v4.4s, v4.4s\n" + "ldr q6, [x10, #0x0]\n" + "ldr q22, [x10, #0x10]\n" + "uzp1 v9.2d, v29.2d, v18.2d\n" + "uzp2 v2.2d, v29.2d, v18.2d\n" + "ld1 { v21.4s }, [x22]\n" + "ldr q20, [x10, #0x20]\n" + "uzp1 v16.2d, v12.2d, v4.2d\n" + "uzp2 v23.2d, v12.2d, v4.2d\n" + "ldr q10, [x10, #0x30]\n" + "uzp1 v17.2d, v11.2d, v24.2d\n" + "uzp2 v24.2d, v11.2d, v24.2d\n" + "add x22, x22, #0x10\n" + "ldr q28, [x22, #0x0]\n" + "uzp1 v0.2d, v14.2d, v1.2d\n" + "uzp2 v27.2d, v14.2d, v1.2d\n" "add x10, x10, #0x40\n" - "mla v21.4s, v18.4s, v22.s[3]\n" - "mla v25.4s, v2.4s, v22.s[3]\n" - "fmul v19.4s, v17.4s, v23.s[0]\n" - "mla v16.4s, v18.4s, v27.s[0]\n" - "mla v10.4s, v2.4s, v27.s[0]\n" - "scvtf v5.4s, v5.4s\n" - "mla v30.4s, v18.4s, v27.s[1]\n" - "mla v28.4s, v2.4s, v27.s[1]\n" - "fmul v15.4s, v6.4s, v23.s[0]\n" - "mla v14.4s, v18.4s, v27.s[2]\n" - "mla v26.4s, v2.4s, v27.s[2]\n" - "scvtf v1.4s, v1.4s\n" - "mla v0.4s, v18.4s, v27.s[3]\n" - "mla v11.4s, v2.4s, v27.s[3]\n" - "fmul v22.4s, v17.4s, v23.s[1]\n" - "scvtf v20.4s, v20.4s\n" - "fmul v29.4s, v6.4s, v23.s[1]\n" + "mla v9.4s, v6.4s, v21.s[0]\n" + "mla v16.4s, v22.4s, v21.s[0]\n" + "mla v2.4s, v6.4s, v21.s[1]\n" + "mla v23.4s, v22.4s, v21.s[1]\n" + "mla v17.4s, v6.4s, v21.s[2]\n" + "mla v0.4s, v22.4s, v21.s[2]\n" + "fmul v12.4s, v20.4s, v28.s[0]\n" + "mla v24.4s, v6.4s, v21.s[3]\n" + "mla v27.4s, v22.4s, v21.s[3]\n" + "fmul v11.4s, v10.4s, v28.s[0]\n" "scvtf v9.4s, v9.4s\n" - "fmul v2.4s, v17.4s, v23.s[2]\n" - "scvtf v13.4s, v13.4s\n" - "fmul v18.4s, v6.4s, v23.s[2]\n" - "scvtf v21.4s, v21.4s\n" - "fmul v31.4s, v17.4s, v23.s[3]\n" - "scvtf v25.4s, v25.4s\n" - "fmul v7.4s, v6.4s, v23.s[3]\n" "scvtf v16.4s, v16.4s\n" - "fmul v27.4s, v17.4s, v12.s[0]\n" - "scvtf v10.4s, v10.4s\n" - "fmul v23.4s, v6.4s, v12.s[0]\n" - "scvtf v30.4s, v30.4s\n" - "scvtf v28.4s, v28.4s\n" - "scvtf v14.4s, v14.4s\n" - "scvtf v26.4s, v26.4s\n" + "fmul v18.4s, v20.4s, v28.s[1]\n" + "scvtf v2.4s, v2.4s\n" + "fmul v1.4s, v10.4s, v28.s[1]\n" + "scvtf v23.4s, v23.4s\n" + "fmul v14.4s, v20.4s, v28.s[2]\n" + "scvtf v17.4s, v17.4s\n" + "fmul v3.4s, v10.4s, v28.s[2]\n" "scvtf v0.4s, v0.4s\n" - "scvtf v11.4s, v11.4s\n" - "fmul v4.4s, v4.4s, v19.4s\n" - "fmul v19.4s, v17.4s, v12.s[1]\n" - "fmul v5.4s, v5.4s, v15.4s\n" - "fmul v15.4s, v6.4s, v12.s[1]\n" - "fmul v1.4s, v1.4s, v22.4s\n" - "fmul v22.4s, v17.4s, v12.s[2]\n" - "fmul v17.4s, v17.4s, v12.s[3]\n" - "fmul v20.4s, v20.4s, v29.4s\n" - "fmul v29.4s, v6.4s, v12.s[2]\n" - "fmul v12.4s, v6.4s, v12.s[3]\n" - "fmul v9.4s, v9.4s, v2.4s\n" - "fmul v13.4s, v13.4s, v18.4s\n" - "fmul v21.4s, v21.4s, v31.4s\n" - "fmul v25.4s, v25.4s, v7.4s\n" - "fmul v16.4s, v16.4s, v27.4s\n" - "fmul v10.4s, v10.4s, v23.4s\n" - "fmul v30.4s, v30.4s, v19.4s\n" - "fmul v28.4s, v28.4s, v15.4s\n" - "fmul v14.4s, v14.4s, v22.4s\n" - "fmul v26.4s, v26.4s, v29.4s\n" - "fmul v0.4s, v0.4s, v17.4s\n" - "fmul v11.4s, v11.4s, v12.4s\n" - "fmax v4.4s, v4.4s, v8.4s\n" - "fmax v5.4s, v5.4s, v8.4s\n" - "fmax v1.4s, v1.4s, v8.4s\n" - "fmax v20.4s, v20.4s, v8.4s\n" - "fmax v9.4s, v9.4s, v8.4s\n" - "fmax v13.4s, v13.4s, v8.4s\n" - "fmax v21.4s, v21.4s, v8.4s\n" - "fmax v25.4s, v25.4s, v8.4s\n" - "fmax v16.4s, v16.4s, v8.4s\n" - "fmax v10.4s, v10.4s, v8.4s\n" - "fmax v30.4s, v30.4s, v8.4s\n" - "fmax v28.4s, v28.4s, v8.4s\n" - "fmax v14.4s, v14.4s, v8.4s\n" - "fmax v26.4s, v26.4s, v8.4s\n" - "fmax v0.4s, v0.4s, v8.4s\n" - "fmax v11.4s, v11.4s, v8.4s\n" - "fmin v4.4s, v4.4s, v24.4s\n" - "fmin v5.4s, v5.4s, v24.4s\n" - "fmin v1.4s, v1.4s, v24.4s\n" - "fmin v20.4s, v20.4s, v24.4s\n" - "fmin v9.4s, v9.4s, v24.4s\n" - "fmin v13.4s, v13.4s, v24.4s\n" - "fmin v21.4s, v21.4s, v24.4s\n" - "fmin v25.4s, v25.4s, v24.4s\n" - "fmin v16.4s, v16.4s, v24.4s\n" - "fmin v10.4s, v10.4s, v24.4s\n" - "fmin v30.4s, v30.4s, v24.4s\n" - "fmin v28.4s, v28.4s, v24.4s\n" - "fmin v14.4s, v14.4s, v24.4s\n" - "fmin v26.4s, v26.4s, v24.4s\n" - "fmin v0.4s, v0.4s, v24.4s\n" - "fmin v11.4s, v11.4s, v24.4s\n" - "bge 8f\n" - "mov x27, %x[dst]\n" - "add x26, x27, %x[dst_stride_row], LSL #2\n" - "add x25, x26, %x[dst_stride_row], LSL #1\n" - "add x24, x26, %x[dst_stride_row]\n" - "add x23, x25, %x[dst_stride_row]\n" - "add x22, x27, %x[dst_stride_row], LSL #1\n" - "add x21, x27, %x[dst_stride_row]\n" - "add x20, x22, %x[dst_stride_row]\n" - "tbz x9, #2, 5f\n" - "st1 { v0.4s }, [x23], #0x10\n" - "st1 { v14.4s }, [x25], #0x10\n" - "st1 { v30.4s }, [x24], #0x10\n" - "st1 { v16.4s }, [x26], #0x10\n" - "st1 { v21.4s }, [x20], #0x10\n" - "st1 { v9.4s }, [x22], #0x10\n" - "st1 { v1.4s }, [x21], #0x10\n" - "st1 { v4.4s }, [x27], #0x10\n" - "tbz x9, #1, 4f\n" - "str d11, [x23], #0x8\n" - "str d26, [x25], #0x8\n" - "str d28, [x24], #0x8\n" - "str d10, [x26], #0x8\n" - "str d25, [x20], #0x8\n" - "str d13, [x22], #0x8\n" - "str d20, [x21], #0x8\n" - "str d5, [x27], #0x8\n" - "tbz x9, #0, 7f\n" - "st1 { v11.s }[2], [x23]\n" - "st1 { v26.s }[2], [x25]\n" - "st1 { v28.s }[2], [x24]\n" - "st1 { v10.s }[2], [x26]\n" - "st1 { v25.s }[2], [x20]\n" - "st1 { v13.s }[2], [x22]\n" - "st1 { v20.s }[2], [x21]\n" - "st1 { v5.s }[2], [x27]\n" - "b 7f\n" - "4:" // Output block 0: partial_1_4 - "tbz x9, #0, 7f\n" - "str s11, [x23, #0x0]\n" - "str s26, [x25, #0x0]\n" - "str s28, [x24, #0x0]\n" - "str s10, [x26, #0x0]\n" - "str s25, [x20, #0x0]\n" - "str s13, [x22, #0x0]\n" - "str s20, [x21, #0x0]\n" - "str s5, [x27, #0x0]\n" - "b 7f\n" - "5:" // Output block 0: partial_2_0 - "tbz x9, #1, 6f\n" - "str d0, [x23], #0x8\n" - "str d14, [x25], #0x8\n" - "str d30, [x24], #0x8\n" - "str d16, [x26], #0x8\n" - "str d21, [x20], #0x8\n" - "str d9, [x22], #0x8\n" - "str d1, [x21], #0x8\n" - "str d4, [x27], #0x8\n" - "tbz x9, #0, 7f\n" - "st1 { v0.s }[2], [x23]\n" - "st1 { v14.s }[2], [x25]\n" - "st1 { v30.s }[2], [x24]\n" - "st1 { v16.s }[2], [x26]\n" - "st1 { v21.s }[2], [x20]\n" - "st1 { v9.s }[2], [x22]\n" - "st1 { v1.s }[2], [x21]\n" - "st1 { v4.s }[2], [x27]\n" - "b 7f\n" - "6:" // Output block 0: partial_1_0 - "str s0, [x23, #0x0]\n" - "str s14, [x25, #0x0]\n" - "str s30, [x24, #0x0]\n" - "str s16, [x26, #0x0]\n" - "str s21, [x20, #0x0]\n" - "str s9, [x22, #0x0]\n" - "str s1, [x21, #0x0]\n" - "str s4, [x27, #0x0]\n" - "7:" // Output block 0: Done - "b 9f\n" - "8:" // Full output + "fmul v8.4s, v20.4s, v28.s[3]\n" + "scvtf v24.4s, v24.4s\n" + "fmul v28.4s, v10.4s, v28.s[3]\n" + "scvtf v27.4s, v27.4s\n" + "fmul v29.4s, v9.4s, v12.4s\n" + "fmul v12.4s, v16.4s, v11.4s\n" + "fmul v18.4s, v2.4s, v18.4s\n" + "fmul v4.4s, v23.4s, v1.4s\n" + "fmul v11.4s, v17.4s, v14.4s\n" + "fmul v14.4s, v0.4s, v3.4s\n" + "fmul v24.4s, v24.4s, v8.4s\n" + "fmul v1.4s, v27.4s, v28.4s\n" + "ld1 { v0.4s }, [x20]\n" + "uzp1 v23.2d, v19.2d, v31.2d\n" + "uzp2 v2.2d, v19.2d, v31.2d\n" + "add x20, x20, #0x10\n" + "ldr q16, [x20, #0x0]\n" + "uzp1 v3.2d, v30.2d, v26.2d\n" + "uzp2 v21.2d, v30.2d, v26.2d\n" + "uzp1 v27.2d, v5.2d, v7.2d\n" + "uzp2 v9.2d, v5.2d, v7.2d\n" + "uzp1 v7.2d, v13.2d, v15.2d\n" + "uzp2 v28.2d, v13.2d, v15.2d\n" + "mla v23.4s, v6.4s, v0.s[0]\n" + "mla v3.4s, v22.4s, v0.s[0]\n" + "mla v2.4s, v6.4s, v0.s[1]\n" + "fmul v30.4s, v20.4s, v16.s[0]\n" + "mla v21.4s, v22.4s, v0.s[1]\n" + "mla v27.4s, v6.4s, v0.s[2]\n" + "fmul v5.4s, v10.4s, v16.s[0]\n" + "mla v7.4s, v22.4s, v0.s[2]\n" + "mla v9.4s, v6.4s, v0.s[3]\n" + "fmul v15.4s, v20.4s, v16.s[1]\n" + "mla v28.4s, v22.4s, v0.s[3]\n" + "scvtf v23.4s, v23.4s\n" + "scvtf v3.4s, v3.4s\n" + "scvtf v2.4s, v2.4s\n" + "fmul v6.4s, v10.4s, v16.s[1]\n" + "scvtf v21.4s, v21.4s\n" + "fmul v13.4s, v20.4s, v16.s[2]\n" + "scvtf v27.4s, v27.4s\n" + "fmul v8.4s, v10.4s, v16.s[2]\n" + "scvtf v7.4s, v7.4s\n" + "fmul v0.4s, v20.4s, v16.s[3]\n" + "scvtf v9.4s, v9.4s\n" + "fmul v16.4s, v10.4s, v16.s[3]\n" + "scvtf v28.4s, v28.4s\n" + "fmul v19.4s, v23.4s, v30.4s\n" + "fmul v30.4s, v3.4s, v5.4s\n" + "fmul v31.4s, v2.4s, v15.4s\n" + "fmul v26.4s, v21.4s, v6.4s\n" + "fmul v5.4s, v27.4s, v13.4s\n" + "fmul v13.4s, v7.4s, v8.4s\n" + "fmul v7.4s, v9.4s, v0.4s\n" + "fmul v15.4s, v28.4s, v16.4s\n" + "ld1r { v3.4s }, [%x[clamp_vals]]\n" + "add x20, %x[clamp_vals], #0x4\n" + "cmp x9, #0x8\n" + "ld1r { v16.4s }, [x20]\n" + "add x10, x10, #0x20\n" + "fmax v29.4s, v29.4s, v3.4s\n" + "fmax v12.4s, v12.4s, v3.4s\n" + "fmax v18.4s, v18.4s, v3.4s\n" + "fmax v4.4s, v4.4s, v3.4s\n" + "fmax v11.4s, v11.4s, v3.4s\n" + "fmax v14.4s, v14.4s, v3.4s\n" + "fmax v24.4s, v24.4s, v3.4s\n" + "fmax v1.4s, v1.4s, v3.4s\n" + "fmax v19.4s, v19.4s, v3.4s\n" + "fmax v30.4s, v30.4s, v3.4s\n" + "fmax v31.4s, v31.4s, v3.4s\n" + "fmax v26.4s, v26.4s, v3.4s\n" + "fmax v5.4s, v5.4s, v3.4s\n" + "fmax v13.4s, v13.4s, v3.4s\n" + "fmax v7.4s, v7.4s, v3.4s\n" + "fmax v15.4s, v15.4s, v3.4s\n" + "fmin v29.4s, v29.4s, v16.4s\n" + "fmin v12.4s, v12.4s, v16.4s\n" + "fmin v18.4s, v18.4s, v16.4s\n" + "fmin v4.4s, v4.4s, v16.4s\n" + "fmin v11.4s, v11.4s, v16.4s\n" + "fmin v14.4s, v14.4s, v16.4s\n" + "fmin v24.4s, v24.4s, v16.4s\n" + "fmin v1.4s, v1.4s, v16.4s\n" + "fmin v19.4s, v19.4s, v16.4s\n" + "fmin v30.4s, v30.4s, v16.4s\n" + "fmin v31.4s, v31.4s, v16.4s\n" + "fmin v26.4s, v26.4s, v16.4s\n" + "fmin v5.4s, v5.4s, v16.4s\n" + "fmin v13.4s, v13.4s, v16.4s\n" + "fmin v7.4s, v7.4s, v16.4s\n" + "fmin v15.4s, v15.4s, v16.4s\n" + "blt 6f\n" "mov x20, %x[dst]\n" - "str q4, [x20, #0x0]\n" - "str q5, [x20, #0x10]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q1, [x20, #0x0]\n" - "str q20, [x20, #0x10]\n" + "str q29, [x20, #0x0]\n" + "str q12, [x20, #0x10]\n" "add x20, x20, %x[dst_stride_row]\n" - "str q9, [x20, #0x0]\n" - "str q13, [x20, #0x10]\n" + "str q18, [x20, #0x0]\n" + "str q4, [x20, #0x10]\n" "add x20, x20, %x[dst_stride_row]\n" - "str q21, [x20, #0x0]\n" - "str q25, [x20, #0x10]\n" + "str q11, [x20, #0x0]\n" + "str q14, [x20, #0x10]\n" "add x20, x20, %x[dst_stride_row]\n" - "str q16, [x20, #0x0]\n" - "str q10, [x20, #0x10]\n" + "str q24, [x20, #0x0]\n" + "str q1, [x20, #0x10]\n" "add x20, x20, %x[dst_stride_row]\n" - "str q30, [x20, #0x0]\n" - "str q28, [x20, #0x10]\n" + "str q19, [x20, #0x0]\n" + "str q30, [x20, #0x10]\n" "add x20, x20, %x[dst_stride_row]\n" - "str q14, [x20, #0x0]\n" + "str q31, [x20, #0x0]\n" "str q26, [x20, #0x10]\n" "add x20, x20, %x[dst_stride_row]\n" - "str q0, [x20, #0x0]\n" - "str q11, [x20, #0x10]\n" - "9:" // Output stage exit + "str q5, [x20, #0x0]\n" + "str q13, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q7, [x20, #0x0]\n" + "str q15, [x20, #0x10]\n" + "b 11f\n" + "6:" // Partial output + "mov x27, %x[dst]\n" + "add x26, x27, %x[dst_stride_row], LSL #2\n" + "add x25, x26, %x[dst_stride_row], LSL #1\n" + "add x24, x26, %x[dst_stride_row]\n" + "add x23, x25, %x[dst_stride_row]\n" + "add x22, x27, %x[dst_stride_row], LSL #1\n" + "add x21, x27, %x[dst_stride_row]\n" + "add x20, x22, %x[dst_stride_row]\n" + "tbz x9, #2, 8f\n" + "st1 { v7.4s }, [x23], #0x10\n" + "st1 { v5.4s }, [x25], #0x10\n" + "st1 { v31.4s }, [x24], #0x10\n" + "st1 { v19.4s }, [x26], #0x10\n" + "st1 { v24.4s }, [x20], #0x10\n" + "st1 { v11.4s }, [x22], #0x10\n" + "st1 { v18.4s }, [x21], #0x10\n" + "st1 { v29.4s }, [x27], #0x10\n" + "tbz x9, #1, 7f\n" + "st1 { v15.d }[0], [x23], #0x8\n" + "st1 { v13.d }[0], [x25], #0x8\n" + "st1 { v26.d }[0], [x24], #0x8\n" + "st1 { v30.d }[0], [x26], #0x8\n" + "st1 { v1.d }[0], [x20], #0x8\n" + "st1 { v14.d }[0], [x22], #0x8\n" + "st1 { v4.d }[0], [x21], #0x8\n" + "st1 { v12.d }[0], [x27], #0x8\n" + "tbz x9, #0, 10f\n" + "st1 { v15.s }[2], [x23]\n" + "st1 { v13.s }[2], [x25]\n" + "st1 { v26.s }[2], [x24]\n" + "st1 { v30.s }[2], [x26]\n" + "st1 { v1.s }[2], [x20]\n" + "st1 { v14.s }[2], [x22]\n" + "st1 { v4.s }[2], [x21]\n" + "st1 { v12.s }[2], [x27]\n" + "b 10f\n" + "7:" // Output block 0: partial_1_4 + "tbz x9, #0, 10f\n" + "st1 { v15.s }[0], [x23]\n" + "st1 { v13.s }[0], [x25]\n" + "st1 { v26.s }[0], [x24]\n" + "st1 { v30.s }[0], [x26]\n" + "st1 { v1.s }[0], [x20]\n" + "st1 { v14.s }[0], [x22]\n" + "st1 { v4.s }[0], [x21]\n" + "st1 { v12.s }[0], [x27]\n" + "b 10f\n" + "8:" // Output block 0: partial_2_0 + "tbz x9, #1, 9f\n" + "st1 { v7.d }[0], [x23], #0x8\n" + "st1 { v5.d }[0], [x25], #0x8\n" + "st1 { v31.d }[0], [x24], #0x8\n" + "st1 { v19.d }[0], [x26], #0x8\n" + "st1 { v24.d }[0], [x20], #0x8\n" + "st1 { v11.d }[0], [x22], #0x8\n" + "st1 { v18.d }[0], [x21], #0x8\n" + "st1 { v29.d }[0], [x27], #0x8\n" + "tbz x9, #0, 10f\n" + "st1 { v7.s }[2], [x23]\n" + "st1 { v5.s }[2], [x25]\n" + "st1 { v31.s }[2], [x24]\n" + "st1 { v19.s }[2], [x26]\n" + "st1 { v24.s }[2], [x20]\n" + "st1 { v11.s }[2], [x22]\n" + "st1 { v18.s }[2], [x21]\n" + "st1 { v29.s }[2], [x27]\n" + "b 10f\n" + "9:" // Output block 0: partial_1_0 + "st1 { v7.s }[0], [x23]\n" + "st1 { v5.s }[0], [x25]\n" + "st1 { v31.s }[0], [x24]\n" + "st1 { v19.s }[0], [x26]\n" + "st1 { v24.s }[0], [x20]\n" + "st1 { v11.s }[0], [x22]\n" + "st1 { v18.s }[0], [x21]\n" + "st1 { v29.s }[0], [x27]\n" + "10:" // Output block 0: Done + "11:" // Output stage exit "subs x9, x9, #0x8\n" "add %x[dst], %x[dst], #0x20\n" "bgt 2b\n" @@ -491,160 +495,182 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm( "mov %x[dst], x28\n" "madd %x[lhs_packed], x20, x11, %x[lhs_packed]\n" "bge 1b\n" - "10:" // Row loop skip - "cbz x12, 20f\n" - "11:" // Row tail: Row loop + "12:" // Row loop skip + "cbz x12, 23f\n" + "13:" // Row tail: Row loop "mov x26, %x[rhs_packed]\n" "mov x25, %x[n]\n" "add x24, %x[dst], %x[dst_stride_row], LSL #2\n" - "12:" // Row tail: Column loop - "movi v13.4s, #0x0\n" - "movi v14.4s, #0x0\n" + "14:" // Row tail: Column loop "mov x22, %x[lhs_packed]\n" + "movi v29.4s, #0x0\n" + "movi v12.4s, #0x0\n" "mov x20, %x[num_blocks]\n" - "movi v25.4s, #0x0\n" - "movi v16.4s, #0x0\n" - "movi v26.4s, #0x0\n" - "movi v30.4s, #0x0\n" - "movi v10.4s, #0x0\n" - "movi v19.4s, #0x0\n" - "13:" // Row tail: Block loop - "ldr q4, [x26, #0x0]\n" - "ldr q8, [x26, #0x10]\n" + "movi v18.4s, #0x0\n" + "movi v4.4s, #0x0\n" + "movi v11.4s, #0x0\n" + "movi v14.4s, #0x0\n" + "movi v24.4s, #0x0\n" + "movi v1.4s, #0x0\n" + "15:" // Row tail: Sub block loop + "ldr q16, [x26, #0x0]\n" + "ldr q7, [x26, #0x10]\n" "subs x20, x20, #0x1\n" - "ldr q2, [x26, #0x20]\n" - "ldr q11, [x26, #0x30]\n" - "ldr q18, [x22, #0x0]\n" - "ldr q15, [x22, #0x10]\n" - "ldr q12, [x26, #0x40]\n" - "ldr q6, [x26, #0x50]\n" - "shl v9.16b, v4.16b, #0x4\n" - "shl v22.16b, v8.16b, #0x4\n" - "ldr q28, [x26, #0x60]\n" - "ldr q27, [x26, #0x70]\n" - "shl v17.16b, v2.16b, #0x4\n" - "shl v23.16b, v11.16b, #0x4\n" - "ldr q31, [x22, #0x20]\n" - "ldr q7, [x22, #0x30]\n" - "and v4.16b, v4.16b, v3.16b\n" - "and v8.16b, v8.16b, v3.16b\n" - "ldr q24, [x22, #0x40]\n" - "ldr q1, [x22, #0x50]\n" - ".inst 0x4e89a64d // smmla v13.4s, v18.16b, v9.16b\n" - ".inst 0x4e96a659 // smmla v25.4s, v18.16b, v22.16b\n" - "ldr q21, [x22, #0x60]\n" - "ldr q20, [x22, #0x70]\n" - ".inst 0x4e91a64e // smmla v14.4s, v18.16b, v17.16b\n" - ".inst 0x4e97a650 // smmla v16.4s, v18.16b, v23.16b\n" - ".inst 0x4e89a5fa // smmla v26.4s, v15.16b, v9.16b\n" - ".inst 0x4e96a5ea // smmla v10.4s, v15.16b, v22.16b\n" - "shl v22.16b, v12.16b, #0x4\n" - "add x22, x22, #0x80\n" - ".inst 0x4e91a5fe // smmla v30.4s, v15.16b, v17.16b\n" - ".inst 0x4e97a5f3 // smmla v19.4s, v15.16b, v23.16b\n" - "shl v17.16b, v6.16b, #0x4\n" + "ldr q6, [x26, #0x20]\n" + "ldr q17, [x26, #0x30]\n" + "ldr q20, [x22, #0x0]\n" + "ldr q3, [x22, #0x10]\n" + "ldr q26, [x26, #0x40]\n" + "ldr q21, [x26, #0x50]\n" + "shl v13.16b, v16.16b, #0x4\n" + "shl v28.16b, v7.16b, #0x4\n" + "ldr q31, [x26, #0x60]\n" + "ldr q8, [x26, #0x70]\n" + "shl v2.16b, v6.16b, #0x4\n" + "shl v0.16b, v17.16b, #0x4\n" + "ldr q10, [x22, #0x20]\n" + "ldr q9, [x22, #0x30]\n" + "and v16.16b, v16.16b, v25.16b\n" + "and v7.16b, v7.16b, v25.16b\n" + "ldr q27, [x22, #0x40]\n" + "ldr q23, [x22, #0x50]\n" + ".inst 0x4e8da69d // smmla v29.4s, v20.16b, v13.16b\n" + ".inst 0x4e9ca692 // smmla v18.4s, v20.16b, v28.16b\n" + "ldr q22, [x22, #0x60]\n" + "ldr q15, [x22, #0x70]\n" + ".inst 0x4e82a68c // smmla v12.4s, v20.16b, v2.16b\n" + ".inst 0x4e80a684 // smmla v4.4s, v20.16b, v0.16b\n" + ".inst 0x4e8da46b // smmla v11.4s, v3.16b, v13.16b\n" + ".inst 0x4e9ca478 // smmla v24.4s, v3.16b, v28.16b\n" + "shl v20.16b, v26.16b, #0x4\n" "add x26, x26, #0x80\n" - "shl v23.16b, v28.16b, #0x4\n" - "shl v5.16b, v27.16b, #0x4\n" - ".inst 0x4e96a7ed // smmla v13.4s, v31.16b, v22.16b\n" - "and v2.16b, v2.16b, v3.16b\n" - "and v11.16b, v11.16b, v3.16b\n" - ".inst 0x4e91a7f9 // smmla v25.4s, v31.16b, v17.16b\n" - ".inst 0x4e96a4fa // smmla v26.4s, v7.16b, v22.16b\n" - ".inst 0x4e91a4ea // smmla v10.4s, v7.16b, v17.16b\n" - "and v12.16b, v12.16b, v3.16b\n" - ".inst 0x4e97a7ee // smmla v14.4s, v31.16b, v23.16b\n" - ".inst 0x4e85a7f0 // smmla v16.4s, v31.16b, v5.16b\n" - "and v6.16b, v6.16b, v3.16b\n" - ".inst 0x4e97a4fe // smmla v30.4s, v7.16b, v23.16b\n" - ".inst 0x4e85a4f3 // smmla v19.4s, v7.16b, v5.16b\n" - "and v28.16b, v28.16b, v3.16b\n" - ".inst 0x4e84a70d // smmla v13.4s, v24.16b, v4.16b\n" - ".inst 0x4e88a719 // smmla v25.4s, v24.16b, v8.16b\n" - "and v27.16b, v27.16b, v3.16b\n" - ".inst 0x4e84a43a // smmla v26.4s, v1.16b, v4.16b\n" - ".inst 0x4e88a42a // smmla v10.4s, v1.16b, v8.16b\n" - ".inst 0x4e82a70e // smmla v14.4s, v24.16b, v2.16b\n" - ".inst 0x4e8ba710 // smmla v16.4s, v24.16b, v11.16b\n" - ".inst 0x4e82a43e // smmla v30.4s, v1.16b, v2.16b\n" - ".inst 0x4e8ba433 // smmla v19.4s, v1.16b, v11.16b\n" - ".inst 0x4e8ca6ad // smmla v13.4s, v21.16b, v12.16b\n" - ".inst 0x4e86a6b9 // smmla v25.4s, v21.16b, v6.16b\n" - ".inst 0x4e8ca69a // smmla v26.4s, v20.16b, v12.16b\n" - ".inst 0x4e86a68a // smmla v10.4s, v20.16b, v6.16b\n" - ".inst 0x4e9ca6ae // smmla v14.4s, v21.16b, v28.16b\n" - ".inst 0x4e9ba6b0 // smmla v16.4s, v21.16b, v27.16b\n" - ".inst 0x4e9ca69e // smmla v30.4s, v20.16b, v28.16b\n" - ".inst 0x4e9ba693 // smmla v19.4s, v20.16b, v27.16b\n" - "bgt 13b\n" - "ldr q5, [x26, #0x0]\n" - "ldr q20, [x26, #0x10]\n" - "uzp1 v2.2d, v13.2d, v25.2d\n" - "uzp1 v21.2d, v14.2d, v16.2d\n" - "ldr q6, [x22, #0x0]\n" - "ldr q1, [x26, #0x20]\n" - "uzp2 v4.2d, v13.2d, v25.2d\n" - "uzp2 v28.2d, v14.2d, v16.2d\n" - "ldr q7, [x26, #0x30]\n" - "ldr q17, [x22, #0x10]\n" - "uzp1 v29.2d, v26.2d, v10.2d\n" - "uzp1 v15.2d, v30.2d, v19.2d\n" - "ld1r { v27.4s }, [%x[clamp_vals]]\n" - "uzp2 v26.2d, v26.2d, v10.2d\n" - "uzp2 v25.2d, v30.2d, v19.2d\n" - "add x20, %x[clamp_vals], #0x4\n" - "ld1r { v19.4s }, [x20]\n" - "mla v2.4s, v5.4s, v6.s[0]\n" - "mla v21.4s, v20.4s, v6.s[0]\n" - "cmp x25, #0x8\n" - "mla v4.4s, v5.4s, v6.s[1]\n" - "mla v28.4s, v20.4s, v6.s[1]\n" - "fmul v23.4s, v1.4s, v17.s[0]\n" + ".inst 0x4e82a46e // smmla v14.4s, v3.16b, v2.16b\n" + ".inst 0x4e80a461 // smmla v1.4s, v3.16b, v0.16b\n" + "shl v0.16b, v21.16b, #0x4\n" + "add x22, x22, #0x80\n" + "shl v5.16b, v31.16b, #0x4\n" + "shl v13.16b, v8.16b, #0x4\n" + ".inst 0x4e94a55d // smmla v29.4s, v10.16b, v20.16b\n" + "and v6.16b, v6.16b, v25.16b\n" + "and v17.16b, v17.16b, v25.16b\n" + ".inst 0x4e80a552 // smmla v18.4s, v10.16b, v0.16b\n" + ".inst 0x4e94a52b // smmla v11.4s, v9.16b, v20.16b\n" + ".inst 0x4e80a538 // smmla v24.4s, v9.16b, v0.16b\n" + "and v26.16b, v26.16b, v25.16b\n" + ".inst 0x4e85a54c // smmla v12.4s, v10.16b, v5.16b\n" + ".inst 0x4e8da544 // smmla v4.4s, v10.16b, v13.16b\n" + "and v21.16b, v21.16b, v25.16b\n" + ".inst 0x4e85a52e // smmla v14.4s, v9.16b, v5.16b\n" + ".inst 0x4e8da521 // smmla v1.4s, v9.16b, v13.16b\n" + "and v31.16b, v31.16b, v25.16b\n" + ".inst 0x4e90a77d // smmla v29.4s, v27.16b, v16.16b\n" + ".inst 0x4e87a772 // smmla v18.4s, v27.16b, v7.16b\n" + "and v8.16b, v8.16b, v25.16b\n" + ".inst 0x4e90a6eb // smmla v11.4s, v23.16b, v16.16b\n" + ".inst 0x4e87a6f8 // smmla v24.4s, v23.16b, v7.16b\n" + ".inst 0x4e86a76c // smmla v12.4s, v27.16b, v6.16b\n" + ".inst 0x4e91a764 // smmla v4.4s, v27.16b, v17.16b\n" + ".inst 0x4e86a6ee // smmla v14.4s, v23.16b, v6.16b\n" + ".inst 0x4e91a6e1 // smmla v1.4s, v23.16b, v17.16b\n" + ".inst 0x4e9aa6dd // smmla v29.4s, v22.16b, v26.16b\n" + ".inst 0x4e95a6d2 // smmla v18.4s, v22.16b, v21.16b\n" + ".inst 0x4e9aa5eb // smmla v11.4s, v15.16b, v26.16b\n" + ".inst 0x4e95a5f8 // smmla v24.4s, v15.16b, v21.16b\n" + ".inst 0x4e9fa6cc // smmla v12.4s, v22.16b, v31.16b\n" + ".inst 0x4e88a6c4 // smmla v4.4s, v22.16b, v8.16b\n" + ".inst 0x4e9fa5ee // smmla v14.4s, v15.16b, v31.16b\n" + ".inst 0x4e88a5e1 // smmla v1.4s, v15.16b, v8.16b\n" + "bgt 15b\n" + "ldr q22, [x26, #0x0]\n" + "ldr q21, [x26, #0x10]\n" + "uzp1 v3.2d, v29.2d, v18.2d\n" + "uzp2 v31.2d, v29.2d, v18.2d\n" + "ld1 { v7.4s }, [x22]\n" + "ldr q27, [x26, #0x20]\n" + "uzp1 v15.2d, v12.2d, v4.2d\n" + "uzp2 v6.2d, v12.2d, v4.2d\n" + "ldr q28, [x26, #0x30]\n" + "uzp1 v8.2d, v11.2d, v24.2d\n" + "uzp2 v30.2d, v11.2d, v24.2d\n" + "add x22, x22, #0x10\n" + "ldr q16, [x22, #0x0]\n" + "uzp1 v26.2d, v14.2d, v1.2d\n" + "uzp2 v20.2d, v14.2d, v1.2d\n" "add x26, x26, #0x40\n" - "mla v29.4s, v5.4s, v6.s[2]\n" - "mla v15.4s, v20.4s, v6.s[2]\n" - "fmul v31.4s, v7.4s, v17.s[0]\n" - "mla v26.4s, v5.4s, v6.s[3]\n" - "mla v25.4s, v20.4s, v6.s[3]\n" - "fmul v22.4s, v1.4s, v17.s[1]\n" - "scvtf v2.4s, v2.4s\n" - "scvtf v21.4s, v21.4s\n" - "scvtf v4.4s, v4.4s\n" - "scvtf v28.4s, v28.4s\n" - "fmul v20.4s, v7.4s, v17.s[1]\n" - "scvtf v29.4s, v29.4s\n" - "fmul v24.4s, v1.4s, v17.s[2]\n" + "mla v3.4s, v22.4s, v7.s[0]\n" + "mla v15.4s, v21.4s, v7.s[0]\n" + "mla v31.4s, v22.4s, v7.s[1]\n" + "mla v6.4s, v21.4s, v7.s[1]\n" + "mla v8.4s, v22.4s, v7.s[2]\n" + "mla v26.4s, v21.4s, v7.s[2]\n" + "fmul v23.4s, v27.4s, v16.s[0]\n" + "mla v30.4s, v22.4s, v7.s[3]\n" + "mla v20.4s, v21.4s, v7.s[3]\n" + "fmul v22.4s, v28.4s, v16.s[0]\n" + "scvtf v3.4s, v3.4s\n" "scvtf v15.4s, v15.4s\n" - "fmul v10.4s, v7.4s, v17.s[2]\n" + "fmul v21.4s, v27.4s, v16.s[1]\n" + "scvtf v31.4s, v31.4s\n" + "fmul v10.4s, v28.4s, v16.s[1]\n" + "scvtf v6.4s, v6.4s\n" + "fmul v5.4s, v27.4s, v16.s[2]\n" + "scvtf v8.4s, v8.4s\n" + "fmul v7.4s, v28.4s, v16.s[2]\n" "scvtf v26.4s, v26.4s\n" - "fmul v0.4s, v1.4s, v17.s[3]\n" - "scvtf v25.4s, v25.4s\n" - "fmul v8.4s, v7.4s, v17.s[3]\n" - "fmul v2.4s, v2.4s, v23.4s\n" - "fmul v21.4s, v21.4s, v31.4s\n" - "fmul v4.4s, v4.4s, v22.4s\n" - "fmul v28.4s, v28.4s, v20.4s\n" - "fmul v29.4s, v29.4s, v24.4s\n" - "fmul v15.4s, v15.4s, v10.4s\n" - "fmul v26.4s, v26.4s, v0.4s\n" - "fmul v25.4s, v25.4s, v8.4s\n" - "fmax v2.4s, v2.4s, v27.4s\n" - "fmax v21.4s, v21.4s, v27.4s\n" - "fmax v4.4s, v4.4s, v27.4s\n" - "fmax v28.4s, v28.4s, v27.4s\n" - "fmax v29.4s, v29.4s, v27.4s\n" - "fmax v15.4s, v15.4s, v27.4s\n" - "fmax v26.4s, v26.4s, v27.4s\n" - "fmax v25.4s, v25.4s, v27.4s\n" - "fmin v2.4s, v2.4s, v19.4s\n" - "fmin v21.4s, v21.4s, v19.4s\n" - "fmin v4.4s, v4.4s, v19.4s\n" - "fmin v28.4s, v28.4s, v19.4s\n" - "fmin v29.4s, v29.4s, v19.4s\n" - "fmin v15.4s, v15.4s, v19.4s\n" - "fmin v26.4s, v26.4s, v19.4s\n" - "fmin v25.4s, v25.4s, v19.4s\n" - "bge 18f\n" + "fmul v0.4s, v27.4s, v16.s[3]\n" + "scvtf v30.4s, v30.4s\n" + "fmul v16.4s, v28.4s, v16.s[3]\n" + "scvtf v20.4s, v20.4s\n" + "fmul v29.4s, v3.4s, v23.4s\n" + "fmul v12.4s, v15.4s, v22.4s\n" + "fmul v18.4s, v31.4s, v21.4s\n" + "fmul v4.4s, v6.4s, v10.4s\n" + "fmul v11.4s, v8.4s, v5.4s\n" + "fmul v14.4s, v26.4s, v7.4s\n" + "fmul v24.4s, v30.4s, v0.4s\n" + "fmul v1.4s, v20.4s, v16.4s\n" + "ld1r { v23.4s }, [%x[clamp_vals]]\n" + "add x20, %x[clamp_vals], #0x4\n" + "cmp x25, #0x8\n" + "ld1r { v16.4s }, [x20]\n" + "add x26, x26, #0x20\n" + "fmax v29.4s, v29.4s, v23.4s\n" + "fmax v12.4s, v12.4s, v23.4s\n" + "fmax v18.4s, v18.4s, v23.4s\n" + "fmax v4.4s, v4.4s, v23.4s\n" + "fmax v11.4s, v11.4s, v23.4s\n" + "fmax v14.4s, v14.4s, v23.4s\n" + "fmax v24.4s, v24.4s, v23.4s\n" + "fmax v1.4s, v1.4s, v23.4s\n" + "fmin v29.4s, v29.4s, v16.4s\n" + "fmin v12.4s, v12.4s, v16.4s\n" + "fmin v18.4s, v18.4s, v16.4s\n" + "fmin v4.4s, v4.4s, v16.4s\n" + "fmin v11.4s, v11.4s, v16.4s\n" + "fmin v14.4s, v14.4s, v16.4s\n" + "fmin v24.4s, v24.4s, v16.4s\n" + "fmin v1.4s, v1.4s, v16.4s\n" + "blt 17f\n" + "mov x20, %x[dst]\n" + "cmp x12, #0x1\n" + "str q29, [x20, #0x0]\n" + "str q12, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 22f\n" + "cmp x12, #0x2\n" + "str q18, [x20, #0x0]\n" + "str q4, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 22f\n" + "cmp x12, #0x3\n" + "str q11, [x20, #0x0]\n" + "str q14, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 22f\n" + "str q24, [x20, #0x0]\n" + "str q1, [x20, #0x10]\n" + "b 22f\n" + "17:" // Row tail: Partial output "mov x23, %x[dst]\n" "cmp x12, #0x1\n" "add x22, x23, %x[dst_stride_row]\n" @@ -655,79 +681,59 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm( "cmp x12, #0x3\n" "add x20, x21, %x[dst_stride_row]\n" "csel x20, x20, x21, GE\n" - "tbz x25, #2, 15f\n" - "st1 { v26.4s }, [x20], #0x10\n" - "st1 { v29.4s }, [x21], #0x10\n" - "st1 { v4.4s }, [x22], #0x10\n" - "st1 { v2.4s }, [x23], #0x10\n" - "tbz x25, #1, 14f\n" - "str d25, [x20], #0x8\n" - "str d15, [x21], #0x8\n" - "str d28, [x22], #0x8\n" - "str d21, [x23], #0x8\n" - "tbz x25, #0, 17f\n" - "st1 { v25.s }[2], [x20]\n" - "st1 { v15.s }[2], [x21]\n" - "st1 { v28.s }[2], [x22]\n" - "st1 { v21.s }[2], [x23]\n" - "b 17f\n" - "14:" // Row tail: Output block 0: partial_1_4 - "tbz x25, #0, 17f\n" - "str s25, [x20, #0x0]\n" - "str s15, [x21, #0x0]\n" - "str s28, [x22, #0x0]\n" - "str s21, [x23, #0x0]\n" - "b 17f\n" - "15:" // Row tail: Output block 0: partial_2_0 - "tbz x25, #1, 16f\n" - "str d26, [x20], #0x8\n" - "str d29, [x21], #0x8\n" - "str d4, [x22], #0x8\n" - "str d2, [x23], #0x8\n" - "tbz x25, #0, 17f\n" - "st1 { v26.s }[2], [x20]\n" - "st1 { v29.s }[2], [x21]\n" + "tbz x25, #2, 19f\n" + "st1 { v24.4s }, [x20], #0x10\n" + "st1 { v11.4s }, [x21], #0x10\n" + "st1 { v18.4s }, [x22], #0x10\n" + "st1 { v29.4s }, [x23], #0x10\n" + "tbz x25, #1, 18f\n" + "st1 { v1.d }[0], [x20], #0x8\n" + "st1 { v14.d }[0], [x21], #0x8\n" + "st1 { v4.d }[0], [x22], #0x8\n" + "st1 { v12.d }[0], [x23], #0x8\n" + "tbz x25, #0, 21f\n" + "st1 { v1.s }[2], [x20]\n" + "st1 { v14.s }[2], [x21]\n" "st1 { v4.s }[2], [x22]\n" - "st1 { v2.s }[2], [x23]\n" - "b 17f\n" - "16:" // Row tail: Output block 0: partial_1_0 - "str s26, [x20, #0x0]\n" - "str s29, [x21, #0x0]\n" - "str s4, [x22, #0x0]\n" - "str s2, [x23, #0x0]\n" - "17:" // Row tail: Output block 0: Done - "b 19f\n" - "18:" // Row tail: Full output - "mov x20, %x[dst]\n" - "cmp x12, #0x1\n" - "str q2, [x20, #0x0]\n" - "str q21, [x20, #0x10]\n" - "add x20, x20, %x[dst_stride_row]\n" - "ble 19f\n" - "cmp x12, #0x2\n" - "str q4, [x20, #0x0]\n" - "str q28, [x20, #0x10]\n" - "add x20, x20, %x[dst_stride_row]\n" - "ble 19f\n" - "cmp x12, #0x3\n" - "str q29, [x20, #0x0]\n" - "str q15, [x20, #0x10]\n" - "add x20, x20, %x[dst_stride_row]\n" - "ble 19f\n" - "str q26, [x20, #0x0]\n" - "str q25, [x20, #0x10]\n" - "19:" // Row tail: Output stage exit + "st1 { v12.s }[2], [x23]\n" + "b 21f\n" + "18:" // Row tail: Output block 0: partial_1_4 + "tbz x25, #0, 21f\n" + "st1 { v1.s }[0], [x20]\n" + "st1 { v14.s }[0], [x21]\n" + "st1 { v4.s }[0], [x22]\n" + "st1 { v12.s }[0], [x23]\n" + "b 21f\n" + "19:" // Row tail: Output block 0: partial_2_0 + "tbz x25, #1, 20f\n" + "st1 { v24.d }[0], [x20], #0x8\n" + "st1 { v11.d }[0], [x21], #0x8\n" + "st1 { v18.d }[0], [x22], #0x8\n" + "st1 { v29.d }[0], [x23], #0x8\n" + "tbz x25, #0, 21f\n" + "st1 { v24.s }[2], [x20]\n" + "st1 { v11.s }[2], [x21]\n" + "st1 { v18.s }[2], [x22]\n" + "st1 { v29.s }[2], [x23]\n" + "b 21f\n" + "20:" // Row tail: Output block 0: partial_1_0 + "st1 { v24.s }[0], [x20]\n" + "st1 { v11.s }[0], [x21]\n" + "st1 { v18.s }[0], [x22]\n" + "st1 { v29.s }[0], [x23]\n" + "21:" // Row tail: Output block 0: Done + "22:" // Row tail: Output stage exit "subs x25, x25, #0x8\n" "add %x[dst], %x[dst], #0x20\n" - "bgt 12b\n" + "bgt 14b\n" "subs x12, x12, #0x4\n" "add %x[lhs_packed], %x[lhs_packed], x11\n" "mov %x[dst], x24\n" - "bgt 11b\n" - "20:" // Row tail: Row loop skip - : [lhs_packed] "+&r"(lhs_packed), [dst] "+&r"(dst) - : [rhs_packed] "r"(rhs_packed), [clamp_vals] "r"(clamp_vals), [m] "r"(m), [num_blocks] "r"(num_blocks), - [dst_stride_row] "r"(dst_stride_row), [n] "r"(n) + "bgt 13b\n" + "23:" // Row tail: Row loop skip + : [dst] "+&r"(dst), [lhs_packed] "+&r"(lhs_packed) + : [clamp_vals] "r"(clamp_vals), [dst_stride_row] "r"(dst_stride_row), [m] "r"(m), [n] "r"(n), + [num_blocks] "r"(num_blocks), [rhs_packed] "r"(rhs_packed) : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x11", "x12", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"); diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c index 89027ed9..275ccb55 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c @@ -13,6 +13,7 @@ static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); +static const size_t kai_num_bytes_bias = sizeof(float); inline static size_t kai_k_roundedup(size_t k, size_t kr, size_t sr) { // Since we pack a float and int32 value at the end of the row, @@ -26,7 +27,7 @@ inline static size_t kai_rhs_packed_stride(size_t k, size_t kr, size_t nr, size_ KAI_ASSERT((k_internal % 2) == 0); - return nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs); + return nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias); } size_t kai_get_n_step_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t nr) { @@ -56,10 +57,8 @@ void kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( const struct kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0_params* params) { KAI_ASSERT((k % 2) == 0); KAI_ASSERT(num_groups == 1); - KAI_ASSERT(bias == NULL); KAI_ASSERT(extra_bytes == 0); KAI_ASSERT((kr % sr) == 0); - KAI_ASSERT(rhs != NULL); KAI_ASSERT(scale != NULL); KAI_ASSERT(rhs_packed != NULL); @@ -130,7 +129,7 @@ void kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( // Adjust the reduction sums for (size_t i = 0; i < nr; ++i) { - *((int32_t*)(dst_row)) = sums[i] * 16; + sums[i] = sums[i] * 16; dst_row += sizeof(int32_t); } @@ -141,5 +140,18 @@ void kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( *((float*)(dst_row)) = scale[src_row_idx] * 0.0625F; dst_row += sizeof(float); } + + // Set the bias + if (bias == NULL) { + memset(dst_row, 0, nr * kai_num_bytes_bias); + } else { + for (size_t i = 0; i < nr; ++i) { + // Clamp the row index to avoid out-of-bound reads + const size_t src_row_idx = KAI_MIN(y + i, n - 1); + ((float*)dst_row)[i] = bias[src_row_idx]; + } + } + + dst_row += (kai_num_bytes_bias * nr); } } -- GitLab From af7f71952b71005eb1d7f617e3ade59b1b056d18 Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Thu, 20 Jun 2024 10:06:36 +0100 Subject: [PATCH 2/4] Add support for the non-transposed RHS matrix in the packing function - Remove nxk reference into the original packing function - Add input argument in the packing function to tell whether the RHS matrix is N x K or K x N - Extend the example to work with N x K or K x N RHS matrices - Expose in the header file of the packing function the RHS packed stride Signed-off-by: Gian Marco Iodice --- CMakeLists.txt | 2 +- docs/matmul_qsi4cx/README.md | 5 +- .../CMakeLists.txt | 2 +- .../matmul_clamp_f32_qai8dxp_qsi4cxp.cpp | 287 +++--- ...ai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c | 10 +- ...ai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h | 2 +- ...ai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.c | 13 +- ...ai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h | 2 +- ...2_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.c | 5 + ...2_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.h | 2 +- ...2_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c | 14 + ...2_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h | 2 +- ...2_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.c | 12 +- ...2_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.h | 2 +- ...2_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.c | 960 +++++++++--------- ...2_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h | 2 +- .../kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h | 91 -- ...s0.c => kai_rhs_pack_qsi4cxp_qsu4cxs1s0.c} | 57 +- .../pack/kai_rhs_pack_qsi4cxp_qsu4cxs1s0.h | 124 +++ 19 files changed, 862 insertions(+), 732 deletions(-) delete mode 100644 kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h rename kai/ukernels/matmul/pack/{kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c => kai_rhs_pack_qsi4cxp_qsu4cxs1s0.c} (74%) create mode 100644 kai/ukernels/matmul/pack/kai_rhs_pack_qsi4cxp_qsu4cxs1s0.h diff --git a/CMakeLists.txt b/CMakeLists.txt index be90b2e9..e3493a50 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -76,7 +76,7 @@ endif() set(KLEIDIAI_FILES_SCALAR kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.c - kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c + kai/ukernels/matmul/pack/kai_rhs_pack_qsi4cxp_qsu4cxs1s0.c ) set(KLEIDIAI_FILES_NEON_FP16 diff --git a/docs/matmul_qsi4cx/README.md b/docs/matmul_qsi4cx/README.md index 6ad01f68..6eeed395 100644 --- a/docs/matmul_qsi4cx/README.md +++ b/docs/matmul_qsi4cx/README.md @@ -183,13 +183,14 @@ Once we know the size of the packed matrices, we allocate the memory for the pac Assuming you have filled the native LHS and RHS matrices with some random values, perform the RHS packing: ```c - struct kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0_params params; + struct kai_rhs_pack_qsi4cxp_qsu4cxs1s0_params params; params.lhs_zero_point = 1; params.rhs_zero_point = 8; // RHS packing kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( 1, n, k, nr, kr, sr, // Packing arguments + true, // The RHS matrix is transposed (const uint8_t*)(rhs_native_mtx_qs4cx), // RHS NULL, // Bias (const float*)(rhs_scales_f32), // Scale @@ -261,7 +262,7 @@ include_directories( # Files requires to build the executable add_executable(matmul_clamp_f32_qai8dxp_qsi4cxp matmul_clamp_f32_qai8dxp_qsi4cxp.cpp - ${MATMUL_PACK_PATH}/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c + ${MATMUL_PACK_PATH}/kai_rhs_pack_qsi4cxp_qsu4cxs1s0.c ${MATMUL_PACK_PATH}/kai_lhs_quant_pack_qai8dxp_f32.c ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.c) ``` diff --git a/examples/matmul_clamp_f32_qai8dxp_qsi4cxp/CMakeLists.txt b/examples/matmul_clamp_f32_qai8dxp_qsi4cxp/CMakeLists.txt index 50e74161..f3afb6e4 100644 --- a/examples/matmul_clamp_f32_qai8dxp_qsi4cxp/CMakeLists.txt +++ b/examples/matmul_clamp_f32_qai8dxp_qsi4cxp/CMakeLists.txt @@ -20,7 +20,7 @@ include_directories( # Files requires to build the executable add_executable(matmul_clamp_f32_qai8dxp_qsi4cxp matmul_clamp_f32_qai8dxp_qsi4cxp.cpp - ${MATMUL_PACK_PATH}/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c + ${MATMUL_PACK_PATH}/kai_rhs_pack_qsi4cxp_qsu4cxs1s0.c ${MATMUL_PACK_PATH}/kai_lhs_quant_pack_qai8dxp_f32.c ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.c diff --git a/examples/matmul_clamp_f32_qai8dxp_qsi4cxp/matmul_clamp_f32_qai8dxp_qsi4cxp.cpp b/examples/matmul_clamp_f32_qai8dxp_qsi4cxp/matmul_clamp_f32_qai8dxp_qsi4cxp.cpp index 1d1ae147..a74b2d64 100644 --- a/examples/matmul_clamp_f32_qai8dxp_qsi4cxp/matmul_clamp_f32_qai8dxp_qsi4cxp.cpp +++ b/examples/matmul_clamp_f32_qai8dxp_qsi4cxp/matmul_clamp_f32_qai8dxp_qsi4cxp.cpp @@ -3,7 +3,7 @@ // // SPDX-License-Identifier: Apache-2.0 // -#if !defined(__ARM_FEATURE_DOTPROD) && !defined(__ARM_FEATURE_MATMUL_INT8) +#if 0 //! defined(__ARM_FEATURE_DOTPROD) && !defined(__ARM_FEATURE_MATMUL_INT8) #error "Dotprod and I8mm extensions required to compile this example" #else #include @@ -22,11 +22,16 @@ #include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.h" #include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h" #include "kai_matmul_clamp_f32_qai8dxp_qsi4cxp_interface.h" -#include "kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h" +#include "kai_rhs_pack_qsi4cxp_qsu4cxs1s0.h" #define INT4_MIN (-8) #define INT4_MAX (7) +enum class rhs_format { + nxk, + kxn, +}; + // Micro-kernel interface struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp { kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel ukernel; @@ -120,8 +125,11 @@ static void fill_uniform_random(size_t num_rows, size_t num_cols, float* dst, si } } -static void quant_qs4cx_f32(size_t n, size_t k, const float* rhs_f32, uint8_t* rhs_qs4cx, float* rhs_scales_f32) { - const size_t dst_stride = (k / 2) * sizeof(int8_t); +static void quant_qs4cx_f32( + size_t n, size_t k, rhs_format format, const float* rhs_f32, uint8_t* rhs_qs4cx, float* rhs_scales_f32) { + const size_t dst_k_step = format == rhs_format::nxk ? sizeof(int8_t) : n * sizeof(int8_t); + + const size_t dst_n_step = format == rhs_format::nxk ? (k / 2) * sizeof(int8_t) : sizeof(int8_t); for (size_t row_idx = 0; row_idx < n; ++row_idx) { const float* src_ptr = rhs_f32 + row_idx * k; @@ -149,7 +157,7 @@ static void quant_qs4cx_f32(size_t n, size_t k, const float* rhs_f32, uint8_t* r // Reciprocal to quantize const float recip_scale0 = scale0 ? 1.0f / scale0 : 0.0f; - uint8_t* dst_ptr = (uint8_t*)rhs_qs4cx + row_idx * dst_stride; + uint8_t* dst_ptr = (uint8_t*)rhs_qs4cx + row_idx * dst_n_step; // Quantize the channels for (size_t k_idx = 0; k_idx < k; k_idx += 2) { @@ -172,7 +180,7 @@ static void quant_qs4cx_f32(size_t n, size_t k, const float* rhs_f32, uint8_t* r const uint8_t rhs_v0 = (v1_u8 << 4) | v0_u8; dst_ptr[0] = rhs_v0; - dst_ptr += sizeof(uint8_t); + dst_ptr += dst_k_step; } rhs_scales_f32[row_idx] = recip_scale0; @@ -248,10 +256,15 @@ static void ref_quant_qa8dx_f32(size_t m, size_t k, const float* lhs_f32, int8_t }; static void ref_matmul_f32_qa8dx_qs4cx( - size_t m, size_t n, size_t k, const int8_t* lhs_qa8dx, const uint8_t* rhs_qs4cx, const float* rhs_scales_f32, - float* dst_f32, float scalar_min, float scalar_max) { + size_t m, size_t n, size_t k, rhs_format format, const int8_t* lhs_qa8dx, const uint8_t* rhs_qs4cx, + const float* rhs_scales_f32, float* dst_f32, float scalar_min, float scalar_max) { const size_t lhs_stride = k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t); - const size_t rhs_stride = (k / 2) * sizeof(uint8_t); + + const size_t lhs_k_step = sizeof(int8_t) * 2; + + const size_t rhs_k_step = format == rhs_format::nxk ? sizeof(int8_t) : n * sizeof(int8_t); + + const size_t rhs_n_step = format == rhs_format::nxk ? (k / 2) * sizeof(int8_t) : sizeof(int8_t); for (size_t row_idx = 0; row_idx < m; ++row_idx) { const int8_t* lhs_ptr_start = lhs_qa8dx + row_idx * lhs_stride; @@ -260,7 +273,7 @@ static void ref_matmul_f32_qa8dx_qs4cx( int32_t iacc = 0; const int8_t* lhs_ptr = lhs_ptr_start; - const uint8_t* rhs_ptr = rhs_qs4cx + col_idx * rhs_stride; + const uint8_t* rhs_ptr = rhs_qs4cx + col_idx * rhs_n_step; // Get the LHS quantization parameters stored at the // beginning of each row @@ -287,8 +300,8 @@ static void ref_matmul_f32_qa8dx_qs4cx( iacc += lhs_offset * rhs_v0; iacc += lhs_offset * rhs_v1; - lhs_ptr += 2; - rhs_ptr += 1; + lhs_ptr += lhs_k_step; + rhs_ptr += rhs_k_step; } // Get the RHS scale @@ -329,131 +342,139 @@ int main(int argc, char** argv) { const size_t seed_lhs = 4568; const size_t seed_rhs = seed_lhs + 4; - const size_t lhs_native_size_f32 = m * k * sizeof(float); - const size_t rhs_native_size_f32 = n * k * sizeof(float); - const size_t rhs_native_size_qs4cx = n * (k / 2) * sizeof(uint8_t); - const size_t rhs_scales_size_f32 = n * sizeof(float); - - // Allocate the memory - uint8_t* lhs_native_mtx_f32 = new uint8_t[lhs_native_size_f32]; - uint8_t* rhs_native_mtx_f32 = new uint8_t[rhs_native_size_f32]; - uint8_t* rhs_native_mtx_qs4cx = new uint8_t[rhs_native_size_qs4cx]; - uint8_t* rhs_scales_f32 = new uint8_t[rhs_scales_size_f32]; - - fill_uniform_random(m, k, (float*)lhs_native_mtx_f32, seed_lhs); - fill_uniform_random(n, k, (float*)rhs_native_mtx_f32, seed_rhs); - - quant_qs4cx_f32(n, k, (const float*)rhs_native_mtx_f32, (uint8_t*)rhs_native_mtx_qs4cx, (float*)rhs_scales_f32); - - delete[] rhs_native_mtx_f32; - - //----------- REFERENCE IMPLEMENTATION - //------------------------------------ - //------------------------------------ - // Memory sizes for the reference implementation - // After dynamically quantized the LHS matrix, we have the scale and offset for each - // row. The scale (f32) and offset (int32) are stored at the beginning of each row - const size_t lhs_ref_size_qa8dx = m * (k + sizeof(int32_t) + sizeof(float)); - const size_t dst_ref_size_f32 = m * n * sizeof(float); - - uint8_t* lhs_ref_mtx_qa8dx = new uint8_t[lhs_ref_size_qa8dx]; - uint8_t* dst_ref_mtx_f32 = new uint8_t[dst_ref_size_f32]; - - ref_quant_qa8dx_f32(m, k, (const float*)lhs_native_mtx_f32, (int8_t*)lhs_ref_mtx_qa8dx); - - ref_matmul_f32_qa8dx_qs4cx( - m, n, k, (const int8_t*)lhs_ref_mtx_qa8dx, (const uint8_t*)rhs_native_mtx_qs4cx, (const float*)rhs_scales_f32, - (float*)dst_ref_mtx_f32, -FLT_MAX, FLT_MAX); - - // Remove the unnecessary buffer - delete[] lhs_ref_mtx_qa8dx; - - //----------- END REFERENCE IMPLEMENTATION - //------------------------------------ - //------------------------------------ - - //----------- MICRO-KERNELS TESTS - //------------------------------------ - //------------------------------------ - for (size_t idx_variant = 0; idx_variant < num_ukernel_variants; ++idx_variant) { - std::cout << "Testing " << ukernel_variants[idx_variant].name << std::endl; - - // Get the packing parameters - const size_t mr = ukernel_variants[idx_variant].ukernel.get_mr(); - const size_t nr = ukernel_variants[idx_variant].ukernel.get_nr(); - const size_t kr = ukernel_variants[idx_variant].ukernel.get_kr(); - const size_t sr = ukernel_variants[idx_variant].ukernel.get_sr(); - - // Get the size in bytes for the packed matrices - const size_t lhs_packed_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(m, k, mr, kr, sr); - const size_t rhs_packed_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(n, k, nr, kr, sr); - const size_t dst_size = ukernel_variants[idx_variant].ukernel.get_dst_size(m, n); - - // Allocate the matrices - uint8_t* lhs_packed_mtx_qa8dx = new uint8_t[lhs_packed_size]; - uint8_t* rhs_packed_mtx_qs4cx = new uint8_t[rhs_packed_size]; - uint8_t* dst_act_mtx_f32 = new uint8_t[dst_size]; - - // If the RHS matrix contains constant values, the packing can be performed - // only once - struct kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0_params params; - params.lhs_zero_point = 1; - params.rhs_zero_point = 8; - - // RHS packing - kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( - 1, n, k, nr, kr, sr, // Packing arguments - (const uint8_t*)(rhs_native_mtx_qs4cx), // RHS - NULL, // Bias - (const float*)(rhs_scales_f32), // Scale - rhs_packed_mtx_qs4cx, // RHS packed - 0, ¶ms); - - // LHS packing - kai_run_lhs_quant_pack_qai8dxp_f32( - m, k, mr, kr, sr, 0, // Packing arguments - (const float*)lhs_native_mtx_f32, // LHS - k * sizeof(float), // LHS stride - lhs_packed_mtx_qa8dx); // LHS packed - - // Matmul - { - const size_t dst_stride = n * sizeof(float); - const size_t lhs_offset = ukernel_variants[idx_variant].ukernel.get_lhs_packed_offset(0, k); - const size_t rhs_offset = ukernel_variants[idx_variant].ukernel.get_rhs_packed_offset(0, k); - const size_t dst_offset = ukernel_variants[idx_variant].ukernel.get_dst_offset(0, 0, dst_stride); - - const void* lhs_ptr = (const void*)((const char*)lhs_packed_mtx_qa8dx + lhs_offset); - const void* rhs_ptr = (const void*)((const char*)rhs_packed_mtx_qs4cx + rhs_offset); - float* dst_ptr = (float*)((uint8_t*)dst_act_mtx_f32 + dst_offset); - - ukernel_variants[idx_variant].ukernel.run_matmul( - m, n, k, // Dimensions - lhs_ptr, // LHS packed - rhs_ptr, // RHS packed - dst_ptr, // DST - dst_stride, // DST stride (row) - sizeof(float), // DST stride (col) - -FLT_MAX, FLT_MAX // Min and max for the clamp operation - ); - } + // Iterate over the RHS format (NxK or KxN) + for (const rhs_format& format : {rhs_format::nxk, rhs_format::kxn}) { + std::cout << "Testing RHS format = " << (format == rhs_format::nxk ? "N x K" : "K x N") << std::endl; + + const size_t lhs_native_size_f32 = m * k * sizeof(float); + const size_t rhs_native_size_f32 = n * k * sizeof(float); + const size_t rhs_native_size_qs4cx = n * (k / 2) * sizeof(uint8_t); + const size_t rhs_scales_size_f32 = n * sizeof(float); + + // Allocate the memory + uint8_t* lhs_native_mtx_f32 = new uint8_t[lhs_native_size_f32]; + uint8_t* rhs_native_mtx_f32 = new uint8_t[rhs_native_size_f32]; + uint8_t* rhs_native_mtx_qs4cx = new uint8_t[rhs_native_size_qs4cx]; + uint8_t* rhs_scales_f32 = new uint8_t[rhs_scales_size_f32]; + + fill_uniform_random(m, k, (float*)lhs_native_mtx_f32, seed_lhs); + fill_uniform_random(n, k, (float*)rhs_native_mtx_f32, seed_rhs); + + quant_qs4cx_f32( + n, k, format, (const float*)rhs_native_mtx_f32, (uint8_t*)rhs_native_mtx_qs4cx, (float*)rhs_scales_f32); + + delete[] rhs_native_mtx_f32; + + //----------- REFERENCE IMPLEMENTATION + //------------------------------------ + //------------------------------------ + // Memory sizes for the reference implementation + // After dynamically quantized the LHS matrix, we have the scale and offset for each + // row. The scale (f32) and offset (int32) are stored at the beginning of each row + const size_t lhs_ref_size_qa8dx = m * (k + sizeof(int32_t) + sizeof(float)); + const size_t dst_ref_size_f32 = m * n * sizeof(float); + + uint8_t* lhs_ref_mtx_qa8dx = new uint8_t[lhs_ref_size_qa8dx]; + uint8_t* dst_ref_mtx_f32 = new uint8_t[dst_ref_size_f32]; + + ref_quant_qa8dx_f32(m, k, (const float*)lhs_native_mtx_f32, (int8_t*)lhs_ref_mtx_qa8dx); + + ref_matmul_f32_qa8dx_qs4cx( + m, n, k, format, (const int8_t*)lhs_ref_mtx_qa8dx, (const uint8_t*)rhs_native_mtx_qs4cx, + (const float*)rhs_scales_f32, (float*)dst_ref_mtx_f32, -FLT_MAX, FLT_MAX); + + // Remove the unnecessary buffer + delete[] lhs_ref_mtx_qa8dx; + + //----------- END REFERENCE IMPLEMENTATION + //------------------------------------ + //------------------------------------ + + //----------- MICRO-KERNELS TESTS + //------------------------------------ + //------------------------------------ + for (size_t idx_variant = 0; idx_variant < num_ukernel_variants; ++idx_variant) { + std::cout << "Testing " << ukernel_variants[idx_variant].name << std::endl; + ; + + // Get the packing parameters + const size_t mr = ukernel_variants[idx_variant].ukernel.get_mr(); + const size_t nr = ukernel_variants[idx_variant].ukernel.get_nr(); + const size_t kr = ukernel_variants[idx_variant].ukernel.get_kr(); + const size_t sr = ukernel_variants[idx_variant].ukernel.get_sr(); + + // Get the size in bytes for the packed matrices + const size_t lhs_packed_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(m, k, mr, kr, sr); + const size_t rhs_packed_size = kai_get_rhs_packed_size_rhs_pack_qsi4cxp_qsu4cxs1s0(n, k, nr, kr, sr); + const size_t dst_size = ukernel_variants[idx_variant].ukernel.get_dst_size(m, n); + + // Allocate the matrices + uint8_t* lhs_packed_mtx_qa8dx = new uint8_t[lhs_packed_size]; + uint8_t* rhs_packed_mtx_qs4cx = new uint8_t[rhs_packed_size]; + uint8_t* dst_act_mtx_f32 = new uint8_t[dst_size]; + + // If the RHS matrix contains constant values, the packing can be performed + // only once + struct kai_rhs_pack_qsi4cxp_qsu4cxs1s0_params params; + params.lhs_zero_point = 1; + params.rhs_zero_point = 8; + + // RHS packing + kai_run_rhs_pack_qsi4cxp_qsu4cxs1s0( + 1, n, k, nr, kr, sr, // Packing arguments + format == rhs_format::nxk, // True, if the RHS matrix is N x K (transposed) + (const uint8_t*)(rhs_native_mtx_qs4cx), // RHS + NULL, // Bias + (const float*)(rhs_scales_f32), // Scale + rhs_packed_mtx_qs4cx, // RHS packed + 0, ¶ms); + + // LHS packing + kai_run_lhs_quant_pack_qai8dxp_f32( + m, k, mr, kr, sr, 0, // Packing arguments + (const float*)lhs_native_mtx_f32, // LHS + k * sizeof(float), // LHS stride + lhs_packed_mtx_qa8dx); // LHS packed + + // Matmul + { + const size_t dst_stride = n * sizeof(float); + const size_t lhs_offset = ukernel_variants[idx_variant].ukernel.get_lhs_packed_offset(0, k); + const size_t rhs_offset = ukernel_variants[idx_variant].ukernel.get_rhs_packed_offset(0, k); + const size_t dst_offset = ukernel_variants[idx_variant].ukernel.get_dst_offset(0, 0, dst_stride); + + const void* lhs_ptr = (const void*)((const char*)lhs_packed_mtx_qa8dx + lhs_offset); + const void* rhs_ptr = (const void*)((const char*)rhs_packed_mtx_qs4cx + rhs_offset); + float* dst_ptr = (float*)((uint8_t*)dst_act_mtx_f32 + dst_offset); + + ukernel_variants[idx_variant].ukernel.run_matmul( + m, n, k, // Dimensions + lhs_ptr, // LHS packed + rhs_ptr, // RHS packed + dst_ptr, // DST + dst_stride, // DST stride (row) + sizeof(float), // DST stride (col) + -FLT_MAX, FLT_MAX // Min and max for the clamp operation + ); + } - const bool is_valid = - is_output_correct(m, n, 0.0001f, (const float*)dst_ref_mtx_f32, (const float*)dst_act_mtx_f32); + const bool is_valid = + is_output_correct(m, n, 0.0001f, (const float*)dst_ref_mtx_f32, (const float*)dst_act_mtx_f32); - if (is_valid) { - printf("TEST[%ld] = PASSED\n", idx_variant); - } else { - printf("TEST[%ld] = FAILED\n", idx_variant); + if (is_valid) { + printf("TEST[%ld] = PASSED\n", idx_variant); + } else { + printf("TEST[%ld] = FAILED\n", idx_variant); + } + delete[] lhs_packed_mtx_qa8dx; + delete[] rhs_packed_mtx_qs4cx; + delete[] dst_act_mtx_f32; } - delete[] lhs_packed_mtx_qa8dx; - delete[] rhs_packed_mtx_qs4cx; - delete[] dst_act_mtx_f32; + delete[] lhs_native_mtx_f32; + delete[] rhs_native_mtx_qs4cx; + delete[] rhs_scales_f32; + delete[] dst_ref_mtx_f32; } - delete[] lhs_native_mtx_f32; - delete[] rhs_native_mtx_qs4cx; - delete[] rhs_scales_f32; - delete[] dst_ref_mtx_f32; } //----------- END MICRO-KERNELS TESTS diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c index d062c5be..8f1479dc 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c @@ -45,7 +45,7 @@ inline static size_t kai_rhs_packed_stride(size_t k) { KAI_ASSERT((k_internal % 2) == 0); - return kai_nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs); + return kai_nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias); } size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void) { @@ -185,8 +185,9 @@ void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod( const float32x4_t rhs_scale = vld1q_f32((const float*)rhs_ptr); rhs_ptr += sizeof(float32x4_t); - // Skip the bias - rhs_ptr += kai_nr * kai_num_bytes_bias; + // Load the bias + const float32x4_t bias0 = vld1q_f32((const float*)rhs_ptr); + rhs_ptr += sizeof(float32x4_t); // Add the reduction sum iacc = vmlaq_s32(iacc, sum_n_s32, lhs_offset); @@ -195,6 +196,9 @@ void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod( main_acc = vmulq_f32(main_acc, lhs_scale); + // Add the bias + main_acc = vaddq_f32(main_acc, bias0); + // clamp (min-max) operation const float32x4_t vmin_f32 = vdupq_n_f32(scalar_min); const float32x4_t vmax_f32 = vdupq_n_f32(scalar_max); diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h index 98dc5880..b79e8a57 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h @@ -17,7 +17,7 @@ extern "C" { /// Micro-kernel dependencies /// /// -# kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix -/// -# kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 to pack the RHS matrix +/// -# kai_rhs_pack_qsi4cxp_qsu4cxs1s0 to pack the RHS matrix /// -------------------------------------------------- diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.c index 707cb0bd..2b9a1234 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.c @@ -45,7 +45,7 @@ inline static size_t kai_rhs_packed_stride(size_t k) { KAI_ASSERT((k_internal % 2) == 0); - return kai_nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs); + return kai_nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias); } size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(void) { @@ -212,8 +212,11 @@ void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod( const float32x4_t rhs_scale1 = vld1q_f32((const float*)rhs_ptr); rhs_ptr += sizeof(float32x4_t); - // Skip the bias - rhs_ptr += kai_nr * kai_num_bytes_bias; + // Load the bias + const float32x4_t bias0 = vld1q_f32((const float*)rhs_ptr); + rhs_ptr += sizeof(float32x4_t); + const float32x4_t bias1 = vld1q_f32((const float*)rhs_ptr); + rhs_ptr += sizeof(float32x4_t); // Add the reduction sum iacc0 = vmlaq_s32(iacc0, sum_n_s32_0, lhs_offset); @@ -225,6 +228,10 @@ void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod( main_acc0 = vmulq_f32(main_acc0, lhs_scale); main_acc1 = vmulq_f32(main_acc1, lhs_scale); + // Add the bias + main_acc0 = vaddq_f32(main_acc0, bias0); + main_acc1 = vaddq_f32(main_acc1, bias1); + // clamp (min-max) operation const float32x4_t vmin_f32 = vdupq_n_f32(scalar_min); const float32x4_t vmax_f32 = vdupq_n_f32(scalar_max); diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h index b7810898..bb1de658 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h @@ -17,7 +17,7 @@ extern "C" { /// Micro-kernel dependencies /// /// -# kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix -/// -# kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 to pack the RHS matrix +/// -# kai_rhs_pack_qsi4cxp_qsu4cxs1s0 to pack the RHS matrix /// -------------------------------------------------- diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.c index 97e3e308..d987b2f4 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.c @@ -197,11 +197,16 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm( "fmul v2.4s, v23.4s, v18.4s\n" "fmul v1.4s, v21.4s, v17.4s\n" "fmul v0.4s, v20.4s, v16.4s\n" + "ldr q18, [x26, #0x0]\n" "ld1r { v17.4s }, [%x[clamp_vals]]\n" "add x20, %x[clamp_vals], #0x4\n" "cmp x25, #0x4\n" "ld1r { v16.4s }, [x20]\n" "add x26, x26, #0x10\n" + "fadd v3.4s, v3.4s, v18.4s\n" + "fadd v2.4s, v2.4s, v18.4s\n" + "fadd v1.4s, v1.4s, v18.4s\n" + "fadd v0.4s, v0.4s, v18.4s\n" "fmax v3.4s, v3.4s, v17.4s\n" "fmax v2.4s, v2.4s, v17.4s\n" "fmax v1.4s, v1.4s, v17.4s\n" diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.h index aec1ca0e..e5d82e7c 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.h @@ -17,7 +17,7 @@ extern "C" { /// Micro-kernel dependencies /// /// -# kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix -/// -# kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 to pack the RHS matrix +/// -# kai_rhs_pack_qsi4cxp_qsu4cxs1s0 to pack the RHS matrix /// -------------------------------------------------- diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c index 71f7070d..b93f8147 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c @@ -250,11 +250,20 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm( "fmul v5.4s, v22.4s, v18.4s\n" "fmul v4.4s, v21.4s, v17.4s\n" "fmul v3.4s, v20.4s, v16.4s\n" + "ldr q18, [x10, #0x0]\n" "ld1r { v17.4s }, [%x[clamp_vals]]\n" "add x20, %x[clamp_vals], #0x4\n" "cmp x9, #0x4\n" "ld1r { v16.4s }, [x20]\n" "add x10, x10, #0x10\n" + "fadd v10.4s, v10.4s, v18.4s\n" + "fadd v9.4s, v9.4s, v18.4s\n" + "fadd v8.4s, v8.4s, v18.4s\n" + "fadd v7.4s, v7.4s, v18.4s\n" + "fadd v6.4s, v6.4s, v18.4s\n" + "fadd v5.4s, v5.4s, v18.4s\n" + "fadd v4.4s, v4.4s, v18.4s\n" + "fadd v3.4s, v3.4s, v18.4s\n" "fmax v10.4s, v10.4s, v17.4s\n" "fmax v9.4s, v9.4s, v17.4s\n" "fmax v8.4s, v8.4s, v17.4s\n" @@ -417,11 +426,16 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm( "fmul v9.4s, v23.4s, v18.4s\n" "fmul v8.4s, v21.4s, v17.4s\n" "fmul v7.4s, v20.4s, v16.4s\n" + "ldr q18, [x26, #0x0]\n" "ld1r { v17.4s }, [%x[clamp_vals]]\n" "add x20, %x[clamp_vals], #0x4\n" "cmp x25, #0x4\n" "ld1r { v16.4s }, [x20]\n" "add x26, x26, #0x10\n" + "fadd v10.4s, v10.4s, v18.4s\n" + "fadd v9.4s, v9.4s, v18.4s\n" + "fadd v8.4s, v8.4s, v18.4s\n" + "fadd v7.4s, v7.4s, v18.4s\n" "fmax v10.4s, v10.4s, v17.4s\n" "fmax v9.4s, v9.4s, v17.4s\n" "fmax v8.4s, v8.4s, v17.4s\n" diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h index bc277b12..861500e3 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h @@ -17,7 +17,7 @@ extern "C" { /// Micro-kernel dependencies /// /// -# kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix -/// -# kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 to pack the RHS matrix +/// -# kai_rhs_pack_qsi4cxp_qsu4cxs1s0 to pack the RHS matrix /// -------------------------------------------------- diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.c index 136248e0..1f7d0b8f 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.c @@ -250,11 +250,21 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm( "fmul v6.4s, v25.4s, v18.4s\n" "fmul v5.4s, v26.4s, v17.4s\n" "fmul v4.4s, v24.4s, v16.4s\n" - "ld1r { v17.4s }, [%x[clamp_vals]]\n" + "ldr q19, [x26, #0x0]\n" + "ldr q18, [x26, #0x10]\n" "add x20, %x[clamp_vals], #0x4\n" "cmp x25, #0x8\n" + "ld1r { v17.4s }, [%x[clamp_vals]]\n" "ld1r { v16.4s }, [x20]\n" "add x26, x26, #0x20\n" + "fadd v11.4s, v11.4s, v19.4s\n" + "fadd v10.4s, v10.4s, v18.4s\n" + "fadd v9.4s, v9.4s, v19.4s\n" + "fadd v8.4s, v8.4s, v18.4s\n" + "fadd v7.4s, v7.4s, v19.4s\n" + "fadd v6.4s, v6.4s, v18.4s\n" + "fadd v5.4s, v5.4s, v19.4s\n" + "fadd v4.4s, v4.4s, v18.4s\n" "fmax v11.4s, v11.4s, v17.4s\n" "fmax v10.4s, v10.4s, v17.4s\n" "fmax v9.4s, v9.4s, v17.4s\n" diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.h index b9f06ea2..6ccbec69 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.h @@ -17,7 +17,7 @@ extern "C" { /// Micro-kernel dependencies /// /// -# kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix -/// -# kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 to pack the RHS matrix +/// -# kai_rhs_pack_qsi4cxp_qsu4cxs1s0 to pack the RHS matrix /// -------------------------------------------------- diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.c index 277d6e5f..8bcce742 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.c @@ -114,7 +114,7 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm( __asm__ __volatile__( "mov x12, %x[m]\n" "mov x11, #0x80\n" - "movi v25.16b, #0xf0\n" + "movi v24.16b, #0xf0\n" "mov x20, #0x20\n" "cmp x12, #0x8\n" "madd x11, %x[num_blocks], x11, x20\n" @@ -125,287 +125,305 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm( "add x28, %x[dst], %x[dst_stride_row], LSL #3\n" "2:" // Column loop "mov x22, %x[lhs_packed]\n" - "movi v29.4s, #0x0\n" - "movi v12.4s, #0x0\n" + "movi v13.4s, #0x0\n" + "movi v22.4s, #0x0\n" "mov x21, %x[num_blocks]\n" - "movi v18.4s, #0x0\n" - "movi v4.4s, #0x0\n" "movi v11.4s, #0x0\n" + "movi v15.4s, #0x0\n" "movi v14.4s, #0x0\n" + "movi v17.4s, #0x0\n" "add x20, x22, x11\n" - "movi v24.4s, #0x0\n" "movi v1.4s, #0x0\n" - "movi v19.4s, #0x0\n" + "movi v8.4s, #0x0\n" "movi v30.4s, #0x0\n" - "movi v31.4s, #0x0\n" + "movi v28.4s, #0x0\n" + "movi v10.4s, #0x0\n" + "movi v27.4s, #0x0\n" + "movi v18.4s, #0x0\n" + "movi v12.4s, #0x0\n" + "movi v4.4s, #0x0\n" "movi v26.4s, #0x0\n" - "movi v5.4s, #0x0\n" - "movi v13.4s, #0x0\n" - "movi v7.4s, #0x0\n" - "movi v15.4s, #0x0\n" "3:" // Sub block loop - "ldr q0, [x10, #0x0]\n" - "ldr q17, [x10, #0x10]\n" + "ldr q31, [x10, #0x0]\n" + "ldr q3, [x10, #0x10]\n" "subs x21, x21, #0x1\n" - "ldr q10, [x10, #0x20]\n" - "ldr q8, [x10, #0x30]\n" - "ldr q9, [x22, #0x0]\n" - "ldr q20, [x22, #0x10]\n" - "ldr q2, [x20, #0x0]\n" - "ldr q3, [x20, #0x10]\n" - "shl v23.16b, v0.16b, #0x4\n" - "shl v21.16b, v17.16b, #0x4\n" - "ldr q27, [x10, #0x40]\n" - "ldr q6, [x10, #0x50]\n" - "shl v16.16b, v10.16b, #0x4\n" - "shl v22.16b, v8.16b, #0x4\n" - "ldr q28, [x10, #0x60]\n" - "and v0.16b, v0.16b, v25.16b\n" - "and v17.16b, v17.16b, v25.16b\n" - ".inst 0x4e97a53d // smmla v29.4s, v9.16b, v23.16b\n" - ".inst 0x4e95a532 // smmla v18.4s, v9.16b, v21.16b\n" - ".inst 0x4e97a68b // smmla v11.4s, v20.16b, v23.16b\n" - "and v10.16b, v10.16b, v25.16b\n" - ".inst 0x4e90a52c // smmla v12.4s, v9.16b, v16.16b\n" - ".inst 0x4e96a524 // smmla v4.4s, v9.16b, v22.16b\n" - "ldr q9, [x10, #0x70]\n" - "and v8.16b, v8.16b, v25.16b\n" - ".inst 0x4e95a698 // smmla v24.4s, v20.16b, v21.16b\n" - ".inst 0x4e90a68e // smmla v14.4s, v20.16b, v16.16b\n" + "ldr q6, [x10, #0x20]\n" + "ldr q25, [x10, #0x30]\n" + "ldr q19, [x22, #0x0]\n" + "ldr q7, [x22, #0x10]\n" + "ldr q20, [x20, #0x0]\n" + "ldr q5, [x20, #0x10]\n" + "shl v21.16b, v31.16b, #0x4\n" + "shl v2.16b, v3.16b, #0x4\n" + "ldr q16, [x10, #0x40]\n" + "ldr q23, [x10, #0x50]\n" + "shl v9.16b, v6.16b, #0x4\n" + "shl v29.16b, v25.16b, #0x4\n" + "ldr q0, [x10, #0x60]\n" + "and v31.16b, v31.16b, v24.16b\n" + "and v3.16b, v3.16b, v24.16b\n" + ".inst 0x4e95a66d // smmla v13.4s, v19.16b, v21.16b\n" + ".inst 0x4e82a66b // smmla v11.4s, v19.16b, v2.16b\n" + ".inst 0x4e95a4ee // smmla v14.4s, v7.16b, v21.16b\n" + "and v6.16b, v6.16b, v24.16b\n" + ".inst 0x4e89a676 // smmla v22.4s, v19.16b, v9.16b\n" + ".inst 0x4e9da66f // smmla v15.4s, v19.16b, v29.16b\n" + "ldr q19, [x10, #0x70]\n" + "and v25.16b, v25.16b, v24.16b\n" + ".inst 0x4e82a4e1 // smmla v1.4s, v7.16b, v2.16b\n" + ".inst 0x4e89a4f1 // smmla v17.4s, v7.16b, v9.16b\n" "add x10, x10, #0x80\n" - ".inst 0x4e96a681 // smmla v1.4s, v20.16b, v22.16b\n" - "ldr q20, [x22, #0x20]\n" - ".inst 0x4e97a453 // smmla v19.4s, v2.16b, v23.16b\n" - ".inst 0x4e95a45f // smmla v31.4s, v2.16b, v21.16b\n" - ".inst 0x4e90a45e // smmla v30.4s, v2.16b, v16.16b\n" - ".inst 0x4e96a45a // smmla v26.4s, v2.16b, v22.16b\n" - "ldr q2, [x22, #0x30]\n" - ".inst 0x4e97a465 // smmla v5.4s, v3.16b, v23.16b\n" - "ldr q23, [x20, #0x20]\n" - ".inst 0x4e95a467 // smmla v7.4s, v3.16b, v21.16b\n" - "ldr q21, [x20, #0x30]\n" - ".inst 0x4e90a46d // smmla v13.4s, v3.16b, v16.16b\n" - "ldr q16, [x22, #0x40]\n" - ".inst 0x4e96a46f // smmla v15.4s, v3.16b, v22.16b\n" - "ldr q3, [x22, #0x50]\n" - "shl v22.16b, v27.16b, #0x4\n" - "and v27.16b, v27.16b, v25.16b\n" - ".inst 0x4e96a69d // smmla v29.4s, v20.16b, v22.16b\n" - ".inst 0x4e96a44b // smmla v11.4s, v2.16b, v22.16b\n" - ".inst 0x4e96a6f3 // smmla v19.4s, v23.16b, v22.16b\n" - ".inst 0x4e96a6a5 // smmla v5.4s, v21.16b, v22.16b\n" - "shl v22.16b, v6.16b, #0x4\n" - "and v6.16b, v6.16b, v25.16b\n" - ".inst 0x4e96a692 // smmla v18.4s, v20.16b, v22.16b\n" - ".inst 0x4e96a458 // smmla v24.4s, v2.16b, v22.16b\n" - ".inst 0x4e96a6ff // smmla v31.4s, v23.16b, v22.16b\n" - ".inst 0x4e96a6a7 // smmla v7.4s, v21.16b, v22.16b\n" - "shl v22.16b, v28.16b, #0x4\n" - ".inst 0x4e80a61d // smmla v29.4s, v16.16b, v0.16b\n" - ".inst 0x4e80a46b // smmla v11.4s, v3.16b, v0.16b\n" - "and v28.16b, v28.16b, v25.16b\n" - ".inst 0x4e96a68c // smmla v12.4s, v20.16b, v22.16b\n" - ".inst 0x4e96a44e // smmla v14.4s, v2.16b, v22.16b\n" - ".inst 0x4e96a6fe // smmla v30.4s, v23.16b, v22.16b\n" - ".inst 0x4e96a6ad // smmla v13.4s, v21.16b, v22.16b\n" - "shl v22.16b, v9.16b, #0x4\n" - ".inst 0x4e91a612 // smmla v18.4s, v16.16b, v17.16b\n" - ".inst 0x4e91a478 // smmla v24.4s, v3.16b, v17.16b\n" - "and v9.16b, v9.16b, v25.16b\n" - ".inst 0x4e96a684 // smmla v4.4s, v20.16b, v22.16b\n" - "ldr q20, [x20, #0x40]\n" - ".inst 0x4e96a441 // smmla v1.4s, v2.16b, v22.16b\n" - "ldr q2, [x20, #0x50]\n" - ".inst 0x4e96a6fa // smmla v26.4s, v23.16b, v22.16b\n" - "ldr q23, [x22, #0x60]\n" - ".inst 0x4e96a6af // smmla v15.4s, v21.16b, v22.16b\n" - "ldr q22, [x22, #0x70]\n" - "ldr q21, [x20, #0x60]\n" - ".inst 0x4e8aa60c // smmla v12.4s, v16.16b, v10.16b\n" - ".inst 0x4e8aa46e // smmla v14.4s, v3.16b, v10.16b\n" + ".inst 0x4e9da4e8 // smmla v8.4s, v7.16b, v29.16b\n" + "ldr q7, [x22, #0x20]\n" + ".inst 0x4e95a69e // smmla v30.4s, v20.16b, v21.16b\n" + ".inst 0x4e82a68a // smmla v10.4s, v20.16b, v2.16b\n" + ".inst 0x4e89a69c // smmla v28.4s, v20.16b, v9.16b\n" + ".inst 0x4e9da69b // smmla v27.4s, v20.16b, v29.16b\n" + "ldr q20, [x22, #0x30]\n" + ".inst 0x4e95a4b2 // smmla v18.4s, v5.16b, v21.16b\n" + "ldr q21, [x20, #0x20]\n" + ".inst 0x4e82a4a4 // smmla v4.4s, v5.16b, v2.16b\n" + "ldr q2, [x20, #0x30]\n" + ".inst 0x4e89a4ac // smmla v12.4s, v5.16b, v9.16b\n" + "ldr q9, [x22, #0x40]\n" + ".inst 0x4e9da4ba // smmla v26.4s, v5.16b, v29.16b\n" + "ldr q29, [x22, #0x50]\n" + "shl v5.16b, v16.16b, #0x4\n" + "and v16.16b, v16.16b, v24.16b\n" + ".inst 0x4e85a4ed // smmla v13.4s, v7.16b, v5.16b\n" + ".inst 0x4e85a68e // smmla v14.4s, v20.16b, v5.16b\n" + ".inst 0x4e85a6be // smmla v30.4s, v21.16b, v5.16b\n" + ".inst 0x4e85a452 // smmla v18.4s, v2.16b, v5.16b\n" + "shl v5.16b, v23.16b, #0x4\n" + "and v23.16b, v23.16b, v24.16b\n" + ".inst 0x4e85a4eb // smmla v11.4s, v7.16b, v5.16b\n" + ".inst 0x4e85a681 // smmla v1.4s, v20.16b, v5.16b\n" + ".inst 0x4e85a6aa // smmla v10.4s, v21.16b, v5.16b\n" + ".inst 0x4e85a444 // smmla v4.4s, v2.16b, v5.16b\n" + "shl v5.16b, v0.16b, #0x4\n" + ".inst 0x4e9fa52d // smmla v13.4s, v9.16b, v31.16b\n" + ".inst 0x4e9fa7ae // smmla v14.4s, v29.16b, v31.16b\n" + "and v0.16b, v0.16b, v24.16b\n" + ".inst 0x4e85a4f6 // smmla v22.4s, v7.16b, v5.16b\n" + ".inst 0x4e85a691 // smmla v17.4s, v20.16b, v5.16b\n" + ".inst 0x4e85a6bc // smmla v28.4s, v21.16b, v5.16b\n" + ".inst 0x4e85a44c // smmla v12.4s, v2.16b, v5.16b\n" + "shl v5.16b, v19.16b, #0x4\n" + ".inst 0x4e83a52b // smmla v11.4s, v9.16b, v3.16b\n" + ".inst 0x4e83a7a1 // smmla v1.4s, v29.16b, v3.16b\n" + "and v19.16b, v19.16b, v24.16b\n" + ".inst 0x4e85a4ef // smmla v15.4s, v7.16b, v5.16b\n" + "ldr q7, [x20, #0x40]\n" + ".inst 0x4e85a688 // smmla v8.4s, v20.16b, v5.16b\n" + "ldr q20, [x20, #0x50]\n" + ".inst 0x4e85a6bb // smmla v27.4s, v21.16b, v5.16b\n" + "ldr q21, [x22, #0x60]\n" + ".inst 0x4e85a45a // smmla v26.4s, v2.16b, v5.16b\n" + "ldr q5, [x22, #0x70]\n" + "ldr q2, [x20, #0x60]\n" + ".inst 0x4e86a536 // smmla v22.4s, v9.16b, v6.16b\n" + ".inst 0x4e86a7b1 // smmla v17.4s, v29.16b, v6.16b\n" "add x22, x22, #0x80\n" - ".inst 0x4e80a693 // smmla v19.4s, v20.16b, v0.16b\n" - ".inst 0x4e91a69f // smmla v31.4s, v20.16b, v17.16b\n" - ".inst 0x4e88a604 // smmla v4.4s, v16.16b, v8.16b\n" - "ldr q16, [x20, #0x70]\n" - ".inst 0x4e88a461 // smmla v1.4s, v3.16b, v8.16b\n" + ".inst 0x4e9fa4fe // smmla v30.4s, v7.16b, v31.16b\n" + ".inst 0x4e83a4ea // smmla v10.4s, v7.16b, v3.16b\n" + ".inst 0x4e99a52f // smmla v15.4s, v9.16b, v25.16b\n" + "ldr q9, [x20, #0x70]\n" + ".inst 0x4e99a7a8 // smmla v8.4s, v29.16b, v25.16b\n" "add x20, x20, #0x80\n" - ".inst 0x4e8aa69e // smmla v30.4s, v20.16b, v10.16b\n" - ".inst 0x4e88a69a // smmla v26.4s, v20.16b, v8.16b\n" - ".inst 0x4e80a445 // smmla v5.4s, v2.16b, v0.16b\n" - ".inst 0x4e91a447 // smmla v7.4s, v2.16b, v17.16b\n" - ".inst 0x4e8aa44d // smmla v13.4s, v2.16b, v10.16b\n" - ".inst 0x4e88a44f // smmla v15.4s, v2.16b, v8.16b\n" - ".inst 0x4e9ba6fd // smmla v29.4s, v23.16b, v27.16b\n" - ".inst 0x4e86a6f2 // smmla v18.4s, v23.16b, v6.16b\n" - ".inst 0x4e9ca6ec // smmla v12.4s, v23.16b, v28.16b\n" - ".inst 0x4e89a6e4 // smmla v4.4s, v23.16b, v9.16b\n" - ".inst 0x4e9ba6cb // smmla v11.4s, v22.16b, v27.16b\n" - ".inst 0x4e86a6d8 // smmla v24.4s, v22.16b, v6.16b\n" - ".inst 0x4e9ca6ce // smmla v14.4s, v22.16b, v28.16b\n" - ".inst 0x4e89a6c1 // smmla v1.4s, v22.16b, v9.16b\n" - ".inst 0x4e9ba6b3 // smmla v19.4s, v21.16b, v27.16b\n" - ".inst 0x4e86a6bf // smmla v31.4s, v21.16b, v6.16b\n" - ".inst 0x4e9ca6be // smmla v30.4s, v21.16b, v28.16b\n" - ".inst 0x4e89a6ba // smmla v26.4s, v21.16b, v9.16b\n" - ".inst 0x4e9ba605 // smmla v5.4s, v16.16b, v27.16b\n" - ".inst 0x4e86a607 // smmla v7.4s, v16.16b, v6.16b\n" - ".inst 0x4e9ca60d // smmla v13.4s, v16.16b, v28.16b\n" - ".inst 0x4e89a60f // smmla v15.4s, v16.16b, v9.16b\n" + ".inst 0x4e86a4fc // smmla v28.4s, v7.16b, v6.16b\n" + ".inst 0x4e99a4fb // smmla v27.4s, v7.16b, v25.16b\n" + ".inst 0x4e9fa692 // smmla v18.4s, v20.16b, v31.16b\n" + ".inst 0x4e83a684 // smmla v4.4s, v20.16b, v3.16b\n" + ".inst 0x4e86a68c // smmla v12.4s, v20.16b, v6.16b\n" + ".inst 0x4e99a69a // smmla v26.4s, v20.16b, v25.16b\n" + ".inst 0x4e90a6ad // smmla v13.4s, v21.16b, v16.16b\n" + ".inst 0x4e97a6ab // smmla v11.4s, v21.16b, v23.16b\n" + ".inst 0x4e80a6b6 // smmla v22.4s, v21.16b, v0.16b\n" + ".inst 0x4e93a6af // smmla v15.4s, v21.16b, v19.16b\n" + ".inst 0x4e90a4ae // smmla v14.4s, v5.16b, v16.16b\n" + ".inst 0x4e97a4a1 // smmla v1.4s, v5.16b, v23.16b\n" + ".inst 0x4e80a4b1 // smmla v17.4s, v5.16b, v0.16b\n" + ".inst 0x4e93a4a8 // smmla v8.4s, v5.16b, v19.16b\n" + ".inst 0x4e90a45e // smmla v30.4s, v2.16b, v16.16b\n" + ".inst 0x4e97a44a // smmla v10.4s, v2.16b, v23.16b\n" + ".inst 0x4e80a45c // smmla v28.4s, v2.16b, v0.16b\n" + ".inst 0x4e93a45b // smmla v27.4s, v2.16b, v19.16b\n" + ".inst 0x4e90a532 // smmla v18.4s, v9.16b, v16.16b\n" + ".inst 0x4e97a524 // smmla v4.4s, v9.16b, v23.16b\n" + ".inst 0x4e80a52c // smmla v12.4s, v9.16b, v0.16b\n" + ".inst 0x4e93a53a // smmla v26.4s, v9.16b, v19.16b\n" "bgt 3b\n" - "ldr q6, [x10, #0x0]\n" - "ldr q22, [x10, #0x10]\n" - "uzp1 v9.2d, v29.2d, v18.2d\n" - "uzp2 v2.2d, v29.2d, v18.2d\n" - "ld1 { v21.4s }, [x22]\n" - "ldr q20, [x10, #0x20]\n" - "uzp1 v16.2d, v12.2d, v4.2d\n" - "uzp2 v23.2d, v12.2d, v4.2d\n" - "ldr q10, [x10, #0x30]\n" - "uzp1 v17.2d, v11.2d, v24.2d\n" - "uzp2 v24.2d, v11.2d, v24.2d\n" + "ldr q5, [x10, #0x0]\n" + "ldr q19, [x10, #0x10]\n" + "uzp1 v2.2d, v13.2d, v11.2d\n" + "uzp2 v20.2d, v13.2d, v11.2d\n" + "ld1 { v11.4s }, [x22]\n" + "ldr q23, [x10, #0x20]\n" + "uzp1 v9.2d, v22.2d, v15.2d\n" + "uzp2 v29.2d, v22.2d, v15.2d\n" + "ldr q6, [x10, #0x30]\n" + "uzp1 v31.2d, v14.2d, v1.2d\n" + "uzp2 v7.2d, v14.2d, v1.2d\n" "add x22, x22, #0x10\n" - "ldr q28, [x22, #0x0]\n" - "uzp1 v0.2d, v14.2d, v1.2d\n" - "uzp2 v27.2d, v14.2d, v1.2d\n" + "ldr q22, [x22, #0x0]\n" + "uzp1 v0.2d, v17.2d, v8.2d\n" + "uzp2 v16.2d, v17.2d, v8.2d\n" "add x10, x10, #0x40\n" - "mla v9.4s, v6.4s, v21.s[0]\n" - "mla v16.4s, v22.4s, v21.s[0]\n" - "mla v2.4s, v6.4s, v21.s[1]\n" - "mla v23.4s, v22.4s, v21.s[1]\n" - "mla v17.4s, v6.4s, v21.s[2]\n" - "mla v0.4s, v22.4s, v21.s[2]\n" - "fmul v12.4s, v20.4s, v28.s[0]\n" - "mla v24.4s, v6.4s, v21.s[3]\n" - "mla v27.4s, v22.4s, v21.s[3]\n" - "fmul v11.4s, v10.4s, v28.s[0]\n" - "scvtf v9.4s, v9.4s\n" - "scvtf v16.4s, v16.4s\n" - "fmul v18.4s, v20.4s, v28.s[1]\n" + "mla v2.4s, v5.4s, v11.s[0]\n" + "mla v9.4s, v19.4s, v11.s[0]\n" + "mla v20.4s, v5.4s, v11.s[1]\n" + "mla v29.4s, v19.4s, v11.s[1]\n" + "mla v31.4s, v5.4s, v11.s[2]\n" + "mla v0.4s, v19.4s, v11.s[2]\n" + "fmul v15.4s, v23.4s, v22.s[0]\n" + "mla v7.4s, v5.4s, v11.s[3]\n" + "mla v16.4s, v19.4s, v11.s[3]\n" + "fmul v11.4s, v6.4s, v22.s[0]\n" "scvtf v2.4s, v2.4s\n" - "fmul v1.4s, v10.4s, v28.s[1]\n" - "scvtf v23.4s, v23.4s\n" - "fmul v14.4s, v20.4s, v28.s[2]\n" - "scvtf v17.4s, v17.4s\n" - "fmul v3.4s, v10.4s, v28.s[2]\n" + "scvtf v9.4s, v9.4s\n" + "fmul v25.4s, v23.4s, v22.s[1]\n" + "scvtf v20.4s, v20.4s\n" + "fmul v14.4s, v6.4s, v22.s[1]\n" + "scvtf v29.4s, v29.4s\n" + "fmul v1.4s, v23.4s, v22.s[2]\n" + "scvtf v31.4s, v31.4s\n" + "fmul v17.4s, v6.4s, v22.s[2]\n" "scvtf v0.4s, v0.4s\n" - "fmul v8.4s, v20.4s, v28.s[3]\n" - "scvtf v24.4s, v24.4s\n" - "fmul v28.4s, v10.4s, v28.s[3]\n" - "scvtf v27.4s, v27.4s\n" - "fmul v29.4s, v9.4s, v12.4s\n" - "fmul v12.4s, v16.4s, v11.4s\n" - "fmul v18.4s, v2.4s, v18.4s\n" - "fmul v4.4s, v23.4s, v1.4s\n" - "fmul v11.4s, v17.4s, v14.4s\n" - "fmul v14.4s, v0.4s, v3.4s\n" - "fmul v24.4s, v24.4s, v8.4s\n" - "fmul v1.4s, v27.4s, v28.4s\n" - "ld1 { v0.4s }, [x20]\n" - "uzp1 v23.2d, v19.2d, v31.2d\n" - "uzp2 v2.2d, v19.2d, v31.2d\n" + "fmul v21.4s, v23.4s, v22.s[3]\n" + "scvtf v7.4s, v7.4s\n" + "fmul v3.4s, v6.4s, v22.s[3]\n" + "scvtf v16.4s, v16.4s\n" + "fmul v13.4s, v2.4s, v15.4s\n" + "fmul v22.4s, v9.4s, v11.4s\n" + "fmul v11.4s, v20.4s, v25.4s\n" + "fmul v15.4s, v29.4s, v14.4s\n" + "fmul v14.4s, v31.4s, v1.4s\n" + "fmul v17.4s, v0.4s, v17.4s\n" + "fmul v1.4s, v7.4s, v21.4s\n" + "fmul v8.4s, v16.4s, v3.4s\n" + "ld1 { v20.4s }, [x20]\n" + "uzp1 v2.2d, v30.2d, v10.2d\n" + "uzp2 v10.2d, v30.2d, v10.2d\n" "add x20, x20, #0x10\n" - "ldr q16, [x20, #0x0]\n" - "uzp1 v3.2d, v30.2d, v26.2d\n" - "uzp2 v21.2d, v30.2d, v26.2d\n" - "uzp1 v27.2d, v5.2d, v7.2d\n" - "uzp2 v9.2d, v5.2d, v7.2d\n" - "uzp1 v7.2d, v13.2d, v15.2d\n" - "uzp2 v28.2d, v13.2d, v15.2d\n" - "mla v23.4s, v6.4s, v0.s[0]\n" - "mla v3.4s, v22.4s, v0.s[0]\n" - "mla v2.4s, v6.4s, v0.s[1]\n" - "fmul v30.4s, v20.4s, v16.s[0]\n" - "mla v21.4s, v22.4s, v0.s[1]\n" - "mla v27.4s, v6.4s, v0.s[2]\n" - "fmul v5.4s, v10.4s, v16.s[0]\n" - "mla v7.4s, v22.4s, v0.s[2]\n" - "mla v9.4s, v6.4s, v0.s[3]\n" - "fmul v15.4s, v20.4s, v16.s[1]\n" - "mla v28.4s, v22.4s, v0.s[3]\n" - "scvtf v23.4s, v23.4s\n" - "scvtf v3.4s, v3.4s\n" + "ldr q3, [x20, #0x0]\n" + "uzp1 v0.2d, v28.2d, v27.2d\n" + "uzp2 v31.2d, v28.2d, v27.2d\n" + "uzp1 v29.2d, v18.2d, v4.2d\n" + "uzp2 v9.2d, v18.2d, v4.2d\n" + "uzp1 v4.2d, v12.2d, v26.2d\n" + "uzp2 v21.2d, v12.2d, v26.2d\n" + "mla v2.4s, v5.4s, v20.s[0]\n" + "mla v0.4s, v19.4s, v20.s[0]\n" + "mla v10.4s, v5.4s, v20.s[1]\n" + "fmul v30.4s, v23.4s, v3.s[0]\n" + "mla v31.4s, v19.4s, v20.s[1]\n" + "mla v29.4s, v5.4s, v20.s[2]\n" + "fmul v7.4s, v6.4s, v3.s[0]\n" + "mla v4.4s, v19.4s, v20.s[2]\n" + "mla v9.4s, v5.4s, v20.s[3]\n" + "fmul v18.4s, v23.4s, v3.s[1]\n" + "mla v21.4s, v19.4s, v20.s[3]\n" "scvtf v2.4s, v2.4s\n" - "fmul v6.4s, v10.4s, v16.s[1]\n" - "scvtf v21.4s, v21.4s\n" - "fmul v13.4s, v20.4s, v16.s[2]\n" - "scvtf v27.4s, v27.4s\n" - "fmul v8.4s, v10.4s, v16.s[2]\n" - "scvtf v7.4s, v7.4s\n" - "fmul v0.4s, v20.4s, v16.s[3]\n" + "scvtf v0.4s, v0.4s\n" + "scvtf v10.4s, v10.4s\n" + "fmul v27.4s, v6.4s, v3.s[1]\n" + "scvtf v31.4s, v31.4s\n" + "fmul v20.4s, v23.4s, v3.s[2]\n" + "scvtf v29.4s, v29.4s\n" + "fmul v19.4s, v6.4s, v3.s[2]\n" + "scvtf v4.4s, v4.4s\n" + "fmul v23.4s, v23.4s, v3.s[3]\n" "scvtf v9.4s, v9.4s\n" - "fmul v16.4s, v10.4s, v16.s[3]\n" - "scvtf v28.4s, v28.4s\n" - "fmul v19.4s, v23.4s, v30.4s\n" - "fmul v30.4s, v3.4s, v5.4s\n" - "fmul v31.4s, v2.4s, v15.4s\n" + "fmul v6.4s, v6.4s, v3.s[3]\n" + "scvtf v21.4s, v21.4s\n" + "fmul v30.4s, v2.4s, v30.4s\n" + "fmul v28.4s, v0.4s, v7.4s\n" + "fmul v10.4s, v10.4s, v18.4s\n" + "fmul v27.4s, v31.4s, v27.4s\n" + "fmul v18.4s, v29.4s, v20.4s\n" + "fmul v12.4s, v4.4s, v19.4s\n" + "fmul v4.4s, v9.4s, v23.4s\n" "fmul v26.4s, v21.4s, v6.4s\n" - "fmul v5.4s, v27.4s, v13.4s\n" - "fmul v13.4s, v7.4s, v8.4s\n" - "fmul v7.4s, v9.4s, v0.4s\n" - "fmul v15.4s, v28.4s, v16.4s\n" - "ld1r { v3.4s }, [%x[clamp_vals]]\n" + "ldr q20, [x10, #0x0]\n" + "ldr q19, [x10, #0x10]\n" "add x20, %x[clamp_vals], #0x4\n" "cmp x9, #0x8\n" - "ld1r { v16.4s }, [x20]\n" + "ld1r { v9.4s }, [%x[clamp_vals]]\n" + "ld1r { v6.4s }, [x20]\n" "add x10, x10, #0x20\n" - "fmax v29.4s, v29.4s, v3.4s\n" - "fmax v12.4s, v12.4s, v3.4s\n" - "fmax v18.4s, v18.4s, v3.4s\n" - "fmax v4.4s, v4.4s, v3.4s\n" - "fmax v11.4s, v11.4s, v3.4s\n" - "fmax v14.4s, v14.4s, v3.4s\n" - "fmax v24.4s, v24.4s, v3.4s\n" - "fmax v1.4s, v1.4s, v3.4s\n" - "fmax v19.4s, v19.4s, v3.4s\n" - "fmax v30.4s, v30.4s, v3.4s\n" - "fmax v31.4s, v31.4s, v3.4s\n" - "fmax v26.4s, v26.4s, v3.4s\n" - "fmax v5.4s, v5.4s, v3.4s\n" - "fmax v13.4s, v13.4s, v3.4s\n" - "fmax v7.4s, v7.4s, v3.4s\n" - "fmax v15.4s, v15.4s, v3.4s\n" - "fmin v29.4s, v29.4s, v16.4s\n" - "fmin v12.4s, v12.4s, v16.4s\n" - "fmin v18.4s, v18.4s, v16.4s\n" - "fmin v4.4s, v4.4s, v16.4s\n" - "fmin v11.4s, v11.4s, v16.4s\n" - "fmin v14.4s, v14.4s, v16.4s\n" - "fmin v24.4s, v24.4s, v16.4s\n" - "fmin v1.4s, v1.4s, v16.4s\n" - "fmin v19.4s, v19.4s, v16.4s\n" - "fmin v30.4s, v30.4s, v16.4s\n" - "fmin v31.4s, v31.4s, v16.4s\n" - "fmin v26.4s, v26.4s, v16.4s\n" - "fmin v5.4s, v5.4s, v16.4s\n" - "fmin v13.4s, v13.4s, v16.4s\n" - "fmin v7.4s, v7.4s, v16.4s\n" - "fmin v15.4s, v15.4s, v16.4s\n" + "fadd v13.4s, v13.4s, v20.4s\n" + "fadd v22.4s, v22.4s, v19.4s\n" + "fadd v11.4s, v11.4s, v20.4s\n" + "fadd v15.4s, v15.4s, v19.4s\n" + "fadd v14.4s, v14.4s, v20.4s\n" + "fadd v17.4s, v17.4s, v19.4s\n" + "fadd v1.4s, v1.4s, v20.4s\n" + "fadd v8.4s, v8.4s, v19.4s\n" + "fadd v30.4s, v30.4s, v20.4s\n" + "fadd v28.4s, v28.4s, v19.4s\n" + "fadd v10.4s, v10.4s, v20.4s\n" + "fadd v27.4s, v27.4s, v19.4s\n" + "fadd v18.4s, v18.4s, v20.4s\n" + "fadd v12.4s, v12.4s, v19.4s\n" + "fadd v4.4s, v4.4s, v20.4s\n" + "fadd v26.4s, v26.4s, v19.4s\n" + "fmax v13.4s, v13.4s, v9.4s\n" + "fmax v22.4s, v22.4s, v9.4s\n" + "fmax v11.4s, v11.4s, v9.4s\n" + "fmax v15.4s, v15.4s, v9.4s\n" + "fmax v14.4s, v14.4s, v9.4s\n" + "fmax v17.4s, v17.4s, v9.4s\n" + "fmax v1.4s, v1.4s, v9.4s\n" + "fmax v8.4s, v8.4s, v9.4s\n" + "fmax v30.4s, v30.4s, v9.4s\n" + "fmax v28.4s, v28.4s, v9.4s\n" + "fmax v10.4s, v10.4s, v9.4s\n" + "fmax v27.4s, v27.4s, v9.4s\n" + "fmax v18.4s, v18.4s, v9.4s\n" + "fmax v12.4s, v12.4s, v9.4s\n" + "fmax v4.4s, v4.4s, v9.4s\n" + "fmax v26.4s, v26.4s, v9.4s\n" + "fmin v13.4s, v13.4s, v6.4s\n" + "fmin v22.4s, v22.4s, v6.4s\n" + "fmin v11.4s, v11.4s, v6.4s\n" + "fmin v15.4s, v15.4s, v6.4s\n" + "fmin v14.4s, v14.4s, v6.4s\n" + "fmin v17.4s, v17.4s, v6.4s\n" + "fmin v1.4s, v1.4s, v6.4s\n" + "fmin v8.4s, v8.4s, v6.4s\n" + "fmin v30.4s, v30.4s, v6.4s\n" + "fmin v28.4s, v28.4s, v6.4s\n" + "fmin v10.4s, v10.4s, v6.4s\n" + "fmin v27.4s, v27.4s, v6.4s\n" + "fmin v18.4s, v18.4s, v6.4s\n" + "fmin v12.4s, v12.4s, v6.4s\n" + "fmin v4.4s, v4.4s, v6.4s\n" + "fmin v26.4s, v26.4s, v6.4s\n" "blt 6f\n" "mov x20, %x[dst]\n" - "str q29, [x20, #0x0]\n" - "str q12, [x20, #0x10]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q18, [x20, #0x0]\n" - "str q4, [x20, #0x10]\n" + "str q13, [x20, #0x0]\n" + "str q22, [x20, #0x10]\n" "add x20, x20, %x[dst_stride_row]\n" "str q11, [x20, #0x0]\n" - "str q14, [x20, #0x10]\n" + "str q15, [x20, #0x10]\n" "add x20, x20, %x[dst_stride_row]\n" - "str q24, [x20, #0x0]\n" - "str q1, [x20, #0x10]\n" + "str q14, [x20, #0x0]\n" + "str q17, [x20, #0x10]\n" "add x20, x20, %x[dst_stride_row]\n" - "str q19, [x20, #0x0]\n" - "str q30, [x20, #0x10]\n" + "str q1, [x20, #0x0]\n" + "str q8, [x20, #0x10]\n" "add x20, x20, %x[dst_stride_row]\n" - "str q31, [x20, #0x0]\n" - "str q26, [x20, #0x10]\n" + "str q30, [x20, #0x0]\n" + "str q28, [x20, #0x10]\n" "add x20, x20, %x[dst_stride_row]\n" - "str q5, [x20, #0x0]\n" - "str q13, [x20, #0x10]\n" + "str q10, [x20, #0x0]\n" + "str q27, [x20, #0x10]\n" "add x20, x20, %x[dst_stride_row]\n" - "str q7, [x20, #0x0]\n" - "str q15, [x20, #0x10]\n" + "str q18, [x20, #0x0]\n" + "str q12, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q4, [x20, #0x0]\n" + "str q26, [x20, #0x10]\n" "b 11f\n" "6:" // Partial output "mov x27, %x[dst]\n" @@ -417,73 +435,73 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm( "add x21, x27, %x[dst_stride_row]\n" "add x20, x22, %x[dst_stride_row]\n" "tbz x9, #2, 8f\n" - "st1 { v7.4s }, [x23], #0x10\n" - "st1 { v5.4s }, [x25], #0x10\n" - "st1 { v31.4s }, [x24], #0x10\n" - "st1 { v19.4s }, [x26], #0x10\n" - "st1 { v24.4s }, [x20], #0x10\n" - "st1 { v11.4s }, [x22], #0x10\n" - "st1 { v18.4s }, [x21], #0x10\n" - "st1 { v29.4s }, [x27], #0x10\n" + "st1 { v4.4s }, [x23], #0x10\n" + "st1 { v18.4s }, [x25], #0x10\n" + "st1 { v10.4s }, [x24], #0x10\n" + "st1 { v30.4s }, [x26], #0x10\n" + "st1 { v1.4s }, [x20], #0x10\n" + "st1 { v14.4s }, [x22], #0x10\n" + "st1 { v11.4s }, [x21], #0x10\n" + "st1 { v13.4s }, [x27], #0x10\n" "tbz x9, #1, 7f\n" - "st1 { v15.d }[0], [x23], #0x8\n" - "st1 { v13.d }[0], [x25], #0x8\n" - "st1 { v26.d }[0], [x24], #0x8\n" + "st1 { v26.d }[0], [x23], #0x8\n" + "st1 { v12.d }[0], [x25], #0x8\n" + "st1 { v27.d }[0], [x24], #0x8\n" + "st1 { v28.d }[0], [x26], #0x8\n" + "st1 { v8.d }[0], [x20], #0x8\n" + "st1 { v17.d }[0], [x22], #0x8\n" + "st1 { v15.d }[0], [x21], #0x8\n" + "st1 { v22.d }[0], [x27], #0x8\n" + "tbz x9, #0, 10f\n" + "st1 { v26.s }[2], [x23]\n" + "st1 { v12.s }[2], [x25]\n" + "st1 { v27.s }[2], [x24]\n" + "st1 { v28.s }[2], [x26]\n" + "st1 { v8.s }[2], [x20]\n" + "st1 { v17.s }[2], [x22]\n" + "st1 { v15.s }[2], [x21]\n" + "st1 { v22.s }[2], [x27]\n" + "b 10f\n" + "7:" // Output block 0: partial_1_4 + "tbz x9, #0, 10f\n" + "st1 { v26.s }[0], [x23]\n" + "st1 { v12.s }[0], [x25]\n" + "st1 { v27.s }[0], [x24]\n" + "st1 { v28.s }[0], [x26]\n" + "st1 { v8.s }[0], [x20]\n" + "st1 { v17.s }[0], [x22]\n" + "st1 { v15.s }[0], [x21]\n" + "st1 { v22.s }[0], [x27]\n" + "b 10f\n" + "8:" // Output block 0: partial_2_0 + "tbz x9, #1, 9f\n" + "st1 { v4.d }[0], [x23], #0x8\n" + "st1 { v18.d }[0], [x25], #0x8\n" + "st1 { v10.d }[0], [x24], #0x8\n" "st1 { v30.d }[0], [x26], #0x8\n" "st1 { v1.d }[0], [x20], #0x8\n" "st1 { v14.d }[0], [x22], #0x8\n" - "st1 { v4.d }[0], [x21], #0x8\n" - "st1 { v12.d }[0], [x27], #0x8\n" + "st1 { v11.d }[0], [x21], #0x8\n" + "st1 { v13.d }[0], [x27], #0x8\n" "tbz x9, #0, 10f\n" - "st1 { v15.s }[2], [x23]\n" - "st1 { v13.s }[2], [x25]\n" - "st1 { v26.s }[2], [x24]\n" + "st1 { v4.s }[2], [x23]\n" + "st1 { v18.s }[2], [x25]\n" + "st1 { v10.s }[2], [x24]\n" "st1 { v30.s }[2], [x26]\n" "st1 { v1.s }[2], [x20]\n" "st1 { v14.s }[2], [x22]\n" - "st1 { v4.s }[2], [x21]\n" - "st1 { v12.s }[2], [x27]\n" + "st1 { v11.s }[2], [x21]\n" + "st1 { v13.s }[2], [x27]\n" "b 10f\n" - "7:" // Output block 0: partial_1_4 - "tbz x9, #0, 10f\n" - "st1 { v15.s }[0], [x23]\n" - "st1 { v13.s }[0], [x25]\n" - "st1 { v26.s }[0], [x24]\n" + "9:" // Output block 0: partial_1_0 + "st1 { v4.s }[0], [x23]\n" + "st1 { v18.s }[0], [x25]\n" + "st1 { v10.s }[0], [x24]\n" "st1 { v30.s }[0], [x26]\n" "st1 { v1.s }[0], [x20]\n" "st1 { v14.s }[0], [x22]\n" - "st1 { v4.s }[0], [x21]\n" - "st1 { v12.s }[0], [x27]\n" - "b 10f\n" - "8:" // Output block 0: partial_2_0 - "tbz x9, #1, 9f\n" - "st1 { v7.d }[0], [x23], #0x8\n" - "st1 { v5.d }[0], [x25], #0x8\n" - "st1 { v31.d }[0], [x24], #0x8\n" - "st1 { v19.d }[0], [x26], #0x8\n" - "st1 { v24.d }[0], [x20], #0x8\n" - "st1 { v11.d }[0], [x22], #0x8\n" - "st1 { v18.d }[0], [x21], #0x8\n" - "st1 { v29.d }[0], [x27], #0x8\n" - "tbz x9, #0, 10f\n" - "st1 { v7.s }[2], [x23]\n" - "st1 { v5.s }[2], [x25]\n" - "st1 { v31.s }[2], [x24]\n" - "st1 { v19.s }[2], [x26]\n" - "st1 { v24.s }[2], [x20]\n" - "st1 { v11.s }[2], [x22]\n" - "st1 { v18.s }[2], [x21]\n" - "st1 { v29.s }[2], [x27]\n" - "b 10f\n" - "9:" // Output block 0: partial_1_0 - "st1 { v7.s }[0], [x23]\n" - "st1 { v5.s }[0], [x25]\n" - "st1 { v31.s }[0], [x24]\n" - "st1 { v19.s }[0], [x26]\n" - "st1 { v24.s }[0], [x20]\n" - "st1 { v11.s }[0], [x22]\n" - "st1 { v18.s }[0], [x21]\n" - "st1 { v29.s }[0], [x27]\n" + "st1 { v11.s }[0], [x21]\n" + "st1 { v13.s }[0], [x27]\n" "10:" // Output block 0: Done "11:" // Output stage exit "subs x9, x9, #0x8\n" @@ -503,172 +521,182 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm( "add x24, %x[dst], %x[dst_stride_row], LSL #2\n" "14:" // Row tail: Column loop "mov x22, %x[lhs_packed]\n" - "movi v29.4s, #0x0\n" - "movi v12.4s, #0x0\n" + "movi v13.4s, #0x0\n" + "movi v22.4s, #0x0\n" "mov x20, %x[num_blocks]\n" - "movi v18.4s, #0x0\n" - "movi v4.4s, #0x0\n" "movi v11.4s, #0x0\n" + "movi v15.4s, #0x0\n" "movi v14.4s, #0x0\n" - "movi v24.4s, #0x0\n" + "movi v17.4s, #0x0\n" "movi v1.4s, #0x0\n" + "movi v8.4s, #0x0\n" "15:" // Row tail: Sub block loop "ldr q16, [x26, #0x0]\n" "ldr q7, [x26, #0x10]\n" "subs x20, x20, #0x1\n" "ldr q6, [x26, #0x20]\n" - "ldr q17, [x26, #0x30]\n" - "ldr q20, [x22, #0x0]\n" - "ldr q3, [x22, #0x10]\n" - "ldr q26, [x26, #0x40]\n" - "ldr q21, [x26, #0x50]\n" - "shl v13.16b, v16.16b, #0x4\n" - "shl v28.16b, v7.16b, #0x4\n" + "ldr q5, [x26, #0x30]\n" + "ldr q4, [x22, #0x0]\n" + "ldr q9, [x22, #0x10]\n" + "ldr q10, [x26, #0x40]\n" + "ldr q3, [x26, #0x50]\n" + "shl v0.16b, v16.16b, #0x4\n" + "shl v19.16b, v7.16b, #0x4\n" "ldr q31, [x26, #0x60]\n" - "ldr q8, [x26, #0x70]\n" - "shl v2.16b, v6.16b, #0x4\n" - "shl v0.16b, v17.16b, #0x4\n" - "ldr q10, [x22, #0x20]\n" - "ldr q9, [x22, #0x30]\n" - "and v16.16b, v16.16b, v25.16b\n" - "and v7.16b, v7.16b, v25.16b\n" - "ldr q27, [x22, #0x40]\n" + "ldr q27, [x26, #0x70]\n" + "shl v18.16b, v6.16b, #0x4\n" + "shl v12.16b, v5.16b, #0x4\n" + "ldr q29, [x22, #0x20]\n" + "ldr q28, [x22, #0x30]\n" + "and v16.16b, v16.16b, v24.16b\n" + "and v7.16b, v7.16b, v24.16b\n" + "ldr q2, [x22, #0x40]\n" "ldr q23, [x22, #0x50]\n" - ".inst 0x4e8da69d // smmla v29.4s, v20.16b, v13.16b\n" - ".inst 0x4e9ca692 // smmla v18.4s, v20.16b, v28.16b\n" - "ldr q22, [x22, #0x60]\n" - "ldr q15, [x22, #0x70]\n" - ".inst 0x4e82a68c // smmla v12.4s, v20.16b, v2.16b\n" - ".inst 0x4e80a684 // smmla v4.4s, v20.16b, v0.16b\n" - ".inst 0x4e8da46b // smmla v11.4s, v3.16b, v13.16b\n" - ".inst 0x4e9ca478 // smmla v24.4s, v3.16b, v28.16b\n" - "shl v20.16b, v26.16b, #0x4\n" + ".inst 0x4e80a48d // smmla v13.4s, v4.16b, v0.16b\n" + ".inst 0x4e93a48b // smmla v11.4s, v4.16b, v19.16b\n" + "ldr q30, [x22, #0x60]\n" + "ldr q21, [x22, #0x70]\n" + ".inst 0x4e92a496 // smmla v22.4s, v4.16b, v18.16b\n" + ".inst 0x4e8ca48f // smmla v15.4s, v4.16b, v12.16b\n" + ".inst 0x4e80a52e // smmla v14.4s, v9.16b, v0.16b\n" + ".inst 0x4e93a521 // smmla v1.4s, v9.16b, v19.16b\n" + "shl v20.16b, v10.16b, #0x4\n" "add x26, x26, #0x80\n" - ".inst 0x4e82a46e // smmla v14.4s, v3.16b, v2.16b\n" - ".inst 0x4e80a461 // smmla v1.4s, v3.16b, v0.16b\n" - "shl v0.16b, v21.16b, #0x4\n" + ".inst 0x4e92a531 // smmla v17.4s, v9.16b, v18.16b\n" + ".inst 0x4e8ca528 // smmla v8.4s, v9.16b, v12.16b\n" + "shl v19.16b, v3.16b, #0x4\n" "add x22, x22, #0x80\n" - "shl v5.16b, v31.16b, #0x4\n" - "shl v13.16b, v8.16b, #0x4\n" - ".inst 0x4e94a55d // smmla v29.4s, v10.16b, v20.16b\n" - "and v6.16b, v6.16b, v25.16b\n" - "and v17.16b, v17.16b, v25.16b\n" - ".inst 0x4e80a552 // smmla v18.4s, v10.16b, v0.16b\n" - ".inst 0x4e94a52b // smmla v11.4s, v9.16b, v20.16b\n" - ".inst 0x4e80a538 // smmla v24.4s, v9.16b, v0.16b\n" - "and v26.16b, v26.16b, v25.16b\n" - ".inst 0x4e85a54c // smmla v12.4s, v10.16b, v5.16b\n" - ".inst 0x4e8da544 // smmla v4.4s, v10.16b, v13.16b\n" - "and v21.16b, v21.16b, v25.16b\n" - ".inst 0x4e85a52e // smmla v14.4s, v9.16b, v5.16b\n" - ".inst 0x4e8da521 // smmla v1.4s, v9.16b, v13.16b\n" - "and v31.16b, v31.16b, v25.16b\n" - ".inst 0x4e90a77d // smmla v29.4s, v27.16b, v16.16b\n" - ".inst 0x4e87a772 // smmla v18.4s, v27.16b, v7.16b\n" - "and v8.16b, v8.16b, v25.16b\n" - ".inst 0x4e90a6eb // smmla v11.4s, v23.16b, v16.16b\n" - ".inst 0x4e87a6f8 // smmla v24.4s, v23.16b, v7.16b\n" - ".inst 0x4e86a76c // smmla v12.4s, v27.16b, v6.16b\n" - ".inst 0x4e91a764 // smmla v4.4s, v27.16b, v17.16b\n" - ".inst 0x4e86a6ee // smmla v14.4s, v23.16b, v6.16b\n" - ".inst 0x4e91a6e1 // smmla v1.4s, v23.16b, v17.16b\n" - ".inst 0x4e9aa6dd // smmla v29.4s, v22.16b, v26.16b\n" - ".inst 0x4e95a6d2 // smmla v18.4s, v22.16b, v21.16b\n" - ".inst 0x4e9aa5eb // smmla v11.4s, v15.16b, v26.16b\n" - ".inst 0x4e95a5f8 // smmla v24.4s, v15.16b, v21.16b\n" - ".inst 0x4e9fa6cc // smmla v12.4s, v22.16b, v31.16b\n" - ".inst 0x4e88a6c4 // smmla v4.4s, v22.16b, v8.16b\n" - ".inst 0x4e9fa5ee // smmla v14.4s, v15.16b, v31.16b\n" - ".inst 0x4e88a5e1 // smmla v1.4s, v15.16b, v8.16b\n" + "shl v18.16b, v31.16b, #0x4\n" + "shl v12.16b, v27.16b, #0x4\n" + ".inst 0x4e94a7ad // smmla v13.4s, v29.16b, v20.16b\n" + "and v6.16b, v6.16b, v24.16b\n" + "and v5.16b, v5.16b, v24.16b\n" + ".inst 0x4e93a7ab // smmla v11.4s, v29.16b, v19.16b\n" + ".inst 0x4e94a78e // smmla v14.4s, v28.16b, v20.16b\n" + ".inst 0x4e93a781 // smmla v1.4s, v28.16b, v19.16b\n" + "and v10.16b, v10.16b, v24.16b\n" + ".inst 0x4e92a7b6 // smmla v22.4s, v29.16b, v18.16b\n" + ".inst 0x4e8ca7af // smmla v15.4s, v29.16b, v12.16b\n" + "and v3.16b, v3.16b, v24.16b\n" + ".inst 0x4e92a791 // smmla v17.4s, v28.16b, v18.16b\n" + ".inst 0x4e8ca788 // smmla v8.4s, v28.16b, v12.16b\n" + "and v31.16b, v31.16b, v24.16b\n" + ".inst 0x4e90a44d // smmla v13.4s, v2.16b, v16.16b\n" + ".inst 0x4e87a44b // smmla v11.4s, v2.16b, v7.16b\n" + "and v27.16b, v27.16b, v24.16b\n" + ".inst 0x4e90a6ee // smmla v14.4s, v23.16b, v16.16b\n" + ".inst 0x4e87a6e1 // smmla v1.4s, v23.16b, v7.16b\n" + ".inst 0x4e86a456 // smmla v22.4s, v2.16b, v6.16b\n" + ".inst 0x4e85a44f // smmla v15.4s, v2.16b, v5.16b\n" + ".inst 0x4e86a6f1 // smmla v17.4s, v23.16b, v6.16b\n" + ".inst 0x4e85a6e8 // smmla v8.4s, v23.16b, v5.16b\n" + ".inst 0x4e8aa7cd // smmla v13.4s, v30.16b, v10.16b\n" + ".inst 0x4e83a7cb // smmla v11.4s, v30.16b, v3.16b\n" + ".inst 0x4e8aa6ae // smmla v14.4s, v21.16b, v10.16b\n" + ".inst 0x4e83a6a1 // smmla v1.4s, v21.16b, v3.16b\n" + ".inst 0x4e9fa7d6 // smmla v22.4s, v30.16b, v31.16b\n" + ".inst 0x4e9ba7cf // smmla v15.4s, v30.16b, v27.16b\n" + ".inst 0x4e9fa6b1 // smmla v17.4s, v21.16b, v31.16b\n" + ".inst 0x4e9ba6a8 // smmla v8.4s, v21.16b, v27.16b\n" "bgt 15b\n" - "ldr q22, [x26, #0x0]\n" - "ldr q21, [x26, #0x10]\n" - "uzp1 v3.2d, v29.2d, v18.2d\n" - "uzp2 v31.2d, v29.2d, v18.2d\n" - "ld1 { v7.4s }, [x22]\n" + "ldr q21, [x26, #0x0]\n" + "ldr q20, [x26, #0x10]\n" + "uzp1 v9.2d, v13.2d, v11.2d\n" + "uzp2 v2.2d, v13.2d, v11.2d\n" + "ld1 { v19.4s }, [x22]\n" "ldr q27, [x26, #0x20]\n" - "uzp1 v15.2d, v12.2d, v4.2d\n" - "uzp2 v6.2d, v12.2d, v4.2d\n" - "ldr q28, [x26, #0x30]\n" - "uzp1 v8.2d, v11.2d, v24.2d\n" - "uzp2 v30.2d, v11.2d, v24.2d\n" + "uzp1 v0.2d, v22.2d, v15.2d\n" + "uzp2 v31.2d, v22.2d, v15.2d\n" + "ldr q13, [x26, #0x30]\n" + "uzp1 v29.2d, v14.2d, v1.2d\n" + "uzp2 v10.2d, v14.2d, v1.2d\n" "add x22, x22, #0x10\n" - "ldr q16, [x22, #0x0]\n" - "uzp1 v26.2d, v14.2d, v1.2d\n" - "uzp2 v20.2d, v14.2d, v1.2d\n" + "ldr q23, [x22, #0x0]\n" + "uzp1 v5.2d, v17.2d, v8.2d\n" + "uzp2 v18.2d, v17.2d, v8.2d\n" "add x26, x26, #0x40\n" - "mla v3.4s, v22.4s, v7.s[0]\n" - "mla v15.4s, v21.4s, v7.s[0]\n" - "mla v31.4s, v22.4s, v7.s[1]\n" - "mla v6.4s, v21.4s, v7.s[1]\n" - "mla v8.4s, v22.4s, v7.s[2]\n" - "mla v26.4s, v21.4s, v7.s[2]\n" - "fmul v23.4s, v27.4s, v16.s[0]\n" - "mla v30.4s, v22.4s, v7.s[3]\n" - "mla v20.4s, v21.4s, v7.s[3]\n" - "fmul v22.4s, v28.4s, v16.s[0]\n" - "scvtf v3.4s, v3.4s\n" - "scvtf v15.4s, v15.4s\n" - "fmul v21.4s, v27.4s, v16.s[1]\n" + "mla v9.4s, v21.4s, v19.s[0]\n" + "mla v0.4s, v20.4s, v19.s[0]\n" + "mla v2.4s, v21.4s, v19.s[1]\n" + "mla v31.4s, v20.4s, v19.s[1]\n" + "mla v29.4s, v21.4s, v19.s[2]\n" + "mla v5.4s, v20.4s, v19.s[2]\n" + "fmul v30.4s, v27.4s, v23.s[0]\n" + "mla v10.4s, v21.4s, v19.s[3]\n" + "mla v18.4s, v20.4s, v19.s[3]\n" + "fmul v17.4s, v13.4s, v23.s[0]\n" + "scvtf v9.4s, v9.4s\n" + "scvtf v0.4s, v0.4s\n" + "fmul v21.4s, v27.4s, v23.s[1]\n" + "scvtf v2.4s, v2.4s\n" + "fmul v20.4s, v13.4s, v23.s[1]\n" "scvtf v31.4s, v31.4s\n" - "fmul v10.4s, v28.4s, v16.s[1]\n" - "scvtf v6.4s, v6.4s\n" - "fmul v5.4s, v27.4s, v16.s[2]\n" - "scvtf v8.4s, v8.4s\n" - "fmul v7.4s, v28.4s, v16.s[2]\n" - "scvtf v26.4s, v26.4s\n" - "fmul v0.4s, v27.4s, v16.s[3]\n" - "scvtf v30.4s, v30.4s\n" - "fmul v16.4s, v28.4s, v16.s[3]\n" - "scvtf v20.4s, v20.4s\n" - "fmul v29.4s, v3.4s, v23.4s\n" - "fmul v12.4s, v15.4s, v22.4s\n" - "fmul v18.4s, v31.4s, v21.4s\n" - "fmul v4.4s, v6.4s, v10.4s\n" - "fmul v11.4s, v8.4s, v5.4s\n" - "fmul v14.4s, v26.4s, v7.4s\n" - "fmul v24.4s, v30.4s, v0.4s\n" - "fmul v1.4s, v20.4s, v16.4s\n" - "ld1r { v23.4s }, [%x[clamp_vals]]\n" + "fmul v19.4s, v27.4s, v23.s[2]\n" + "scvtf v29.4s, v29.4s\n" + "fmul v28.4s, v13.4s, v23.s[2]\n" + "scvtf v5.4s, v5.4s\n" + "fmul v26.4s, v27.4s, v23.s[3]\n" + "scvtf v10.4s, v10.4s\n" + "fmul v16.4s, v13.4s, v23.s[3]\n" + "scvtf v18.4s, v18.4s\n" + "fmul v13.4s, v9.4s, v30.4s\n" + "fmul v22.4s, v0.4s, v17.4s\n" + "fmul v11.4s, v2.4s, v21.4s\n" + "fmul v15.4s, v31.4s, v20.4s\n" + "fmul v14.4s, v29.4s, v19.4s\n" + "fmul v17.4s, v5.4s, v28.4s\n" + "fmul v1.4s, v10.4s, v26.4s\n" + "fmul v8.4s, v18.4s, v16.4s\n" + "ldr q19, [x26, #0x0]\n" + "ldr q18, [x26, #0x10]\n" "add x20, %x[clamp_vals], #0x4\n" "cmp x25, #0x8\n" - "ld1r { v16.4s }, [x20]\n" + "ld1r { v20.4s }, [%x[clamp_vals]]\n" + "ld1r { v27.4s }, [x20]\n" "add x26, x26, #0x20\n" - "fmax v29.4s, v29.4s, v23.4s\n" - "fmax v12.4s, v12.4s, v23.4s\n" - "fmax v18.4s, v18.4s, v23.4s\n" - "fmax v4.4s, v4.4s, v23.4s\n" - "fmax v11.4s, v11.4s, v23.4s\n" - "fmax v14.4s, v14.4s, v23.4s\n" - "fmax v24.4s, v24.4s, v23.4s\n" - "fmax v1.4s, v1.4s, v23.4s\n" - "fmin v29.4s, v29.4s, v16.4s\n" - "fmin v12.4s, v12.4s, v16.4s\n" - "fmin v18.4s, v18.4s, v16.4s\n" - "fmin v4.4s, v4.4s, v16.4s\n" - "fmin v11.4s, v11.4s, v16.4s\n" - "fmin v14.4s, v14.4s, v16.4s\n" - "fmin v24.4s, v24.4s, v16.4s\n" - "fmin v1.4s, v1.4s, v16.4s\n" + "fadd v13.4s, v13.4s, v19.4s\n" + "fadd v22.4s, v22.4s, v18.4s\n" + "fadd v11.4s, v11.4s, v19.4s\n" + "fadd v15.4s, v15.4s, v18.4s\n" + "fadd v14.4s, v14.4s, v19.4s\n" + "fadd v17.4s, v17.4s, v18.4s\n" + "fadd v1.4s, v1.4s, v19.4s\n" + "fadd v8.4s, v8.4s, v18.4s\n" + "fmax v13.4s, v13.4s, v20.4s\n" + "fmax v22.4s, v22.4s, v20.4s\n" + "fmax v11.4s, v11.4s, v20.4s\n" + "fmax v15.4s, v15.4s, v20.4s\n" + "fmax v14.4s, v14.4s, v20.4s\n" + "fmax v17.4s, v17.4s, v20.4s\n" + "fmax v1.4s, v1.4s, v20.4s\n" + "fmax v8.4s, v8.4s, v20.4s\n" + "fmin v13.4s, v13.4s, v27.4s\n" + "fmin v22.4s, v22.4s, v27.4s\n" + "fmin v11.4s, v11.4s, v27.4s\n" + "fmin v15.4s, v15.4s, v27.4s\n" + "fmin v14.4s, v14.4s, v27.4s\n" + "fmin v17.4s, v17.4s, v27.4s\n" + "fmin v1.4s, v1.4s, v27.4s\n" + "fmin v8.4s, v8.4s, v27.4s\n" "blt 17f\n" "mov x20, %x[dst]\n" "cmp x12, #0x1\n" - "str q29, [x20, #0x0]\n" - "str q12, [x20, #0x10]\n" + "str q13, [x20, #0x0]\n" + "str q22, [x20, #0x10]\n" "add x20, x20, %x[dst_stride_row]\n" "ble 22f\n" "cmp x12, #0x2\n" - "str q18, [x20, #0x0]\n" - "str q4, [x20, #0x10]\n" + "str q11, [x20, #0x0]\n" + "str q15, [x20, #0x10]\n" "add x20, x20, %x[dst_stride_row]\n" "ble 22f\n" "cmp x12, #0x3\n" - "str q11, [x20, #0x0]\n" - "str q14, [x20, #0x10]\n" + "str q14, [x20, #0x0]\n" + "str q17, [x20, #0x10]\n" "add x20, x20, %x[dst_stride_row]\n" "ble 22f\n" - "str q24, [x20, #0x0]\n" - "str q1, [x20, #0x10]\n" + "str q1, [x20, #0x0]\n" + "str q8, [x20, #0x10]\n" "b 22f\n" "17:" // Row tail: Partial output "mov x23, %x[dst]\n" @@ -682,45 +710,45 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm( "add x20, x21, %x[dst_stride_row]\n" "csel x20, x20, x21, GE\n" "tbz x25, #2, 19f\n" - "st1 { v24.4s }, [x20], #0x10\n" - "st1 { v11.4s }, [x21], #0x10\n" - "st1 { v18.4s }, [x22], #0x10\n" - "st1 { v29.4s }, [x23], #0x10\n" + "st1 { v1.4s }, [x20], #0x10\n" + "st1 { v14.4s }, [x21], #0x10\n" + "st1 { v11.4s }, [x22], #0x10\n" + "st1 { v13.4s }, [x23], #0x10\n" "tbz x25, #1, 18f\n" - "st1 { v1.d }[0], [x20], #0x8\n" - "st1 { v14.d }[0], [x21], #0x8\n" - "st1 { v4.d }[0], [x22], #0x8\n" - "st1 { v12.d }[0], [x23], #0x8\n" + "st1 { v8.d }[0], [x20], #0x8\n" + "st1 { v17.d }[0], [x21], #0x8\n" + "st1 { v15.d }[0], [x22], #0x8\n" + "st1 { v22.d }[0], [x23], #0x8\n" "tbz x25, #0, 21f\n" - "st1 { v1.s }[2], [x20]\n" - "st1 { v14.s }[2], [x21]\n" - "st1 { v4.s }[2], [x22]\n" - "st1 { v12.s }[2], [x23]\n" + "st1 { v8.s }[2], [x20]\n" + "st1 { v17.s }[2], [x21]\n" + "st1 { v15.s }[2], [x22]\n" + "st1 { v22.s }[2], [x23]\n" "b 21f\n" "18:" // Row tail: Output block 0: partial_1_4 "tbz x25, #0, 21f\n" - "st1 { v1.s }[0], [x20]\n" - "st1 { v14.s }[0], [x21]\n" - "st1 { v4.s }[0], [x22]\n" - "st1 { v12.s }[0], [x23]\n" + "st1 { v8.s }[0], [x20]\n" + "st1 { v17.s }[0], [x21]\n" + "st1 { v15.s }[0], [x22]\n" + "st1 { v22.s }[0], [x23]\n" "b 21f\n" "19:" // Row tail: Output block 0: partial_2_0 "tbz x25, #1, 20f\n" - "st1 { v24.d }[0], [x20], #0x8\n" - "st1 { v11.d }[0], [x21], #0x8\n" - "st1 { v18.d }[0], [x22], #0x8\n" - "st1 { v29.d }[0], [x23], #0x8\n" + "st1 { v1.d }[0], [x20], #0x8\n" + "st1 { v14.d }[0], [x21], #0x8\n" + "st1 { v11.d }[0], [x22], #0x8\n" + "st1 { v13.d }[0], [x23], #0x8\n" "tbz x25, #0, 21f\n" - "st1 { v24.s }[2], [x20]\n" - "st1 { v11.s }[2], [x21]\n" - "st1 { v18.s }[2], [x22]\n" - "st1 { v29.s }[2], [x23]\n" + "st1 { v1.s }[2], [x20]\n" + "st1 { v14.s }[2], [x21]\n" + "st1 { v11.s }[2], [x22]\n" + "st1 { v13.s }[2], [x23]\n" "b 21f\n" "20:" // Row tail: Output block 0: partial_1_0 - "st1 { v24.s }[0], [x20]\n" - "st1 { v11.s }[0], [x21]\n" - "st1 { v18.s }[0], [x22]\n" - "st1 { v29.s }[0], [x23]\n" + "st1 { v1.s }[0], [x20]\n" + "st1 { v14.s }[0], [x21]\n" + "st1 { v11.s }[0], [x22]\n" + "st1 { v13.s }[0], [x23]\n" "21:" // Row tail: Output block 0: Done "22:" // Row tail: Output stage exit "subs x25, x25, #0x8\n" diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h index 7c289441..57bd914c 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h @@ -17,7 +17,7 @@ extern "C" { /// Micro-kernel dependencies /// /// -# kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix -/// -# kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 to pack the RHS matrix +/// -# kai_rhs_pack_qsi4cxp_qsu4cxs1s0 to pack the RHS matrix /// -------------------------------------------------- diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h deleted file mode 100644 index bf947c1c..00000000 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h +++ /dev/null @@ -1,91 +0,0 @@ -// -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates -// -// SPDX-License-Identifier: Apache-2.0 -// -#pragma once - -#include -#include - -#ifdef __cplusplus -extern "C" { -#endif - -struct kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0_params { - int8_t lhs_zero_point; - uint8_t rhs_zero_point; -}; - -/// Get the n step value. -/// The micro-kernel can process any N values. However, the starting N index to -/// be processed must be a multiple of n step. -/// -/// @param[in] nr The number of columns written by the matmul micro-kernel -/// -/// @return the n step value -size_t kai_get_n_step_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t nr); - -/// Gets the offset in bytes for the RHS matrix (not packed), which holds -/// the int4 values in a N x K matrix, where N is number of rows and K is the number of columns. -/// Two int4 values are stored in one byte. The lower order part of the byte (low) holds -/// the first nibble (K-index + 0). The higher order of the byte holds the second nibble (K-index + 1). -/// -/// @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of n_step. -/// @param[in] rhs_stride The number of bytes in in each row of the RHS matrix (not packed) -/// -/// @return the offset in bytes to the RHS matrix (not packed) -size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t n_idx, size_t rhs_stride); - -/// Gets the offset in bytes for the packed RHS matrix, -/// which contains the packed 4-bit quantized symmetric per-channel (qsu4cx) values. -/// -/// @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of n_step. -/// @param[in] k The common dimension between the LHS and RHS matrix (K) -/// @param[in] nr The number of columns written by the matmul micro-kernel -/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. -/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. -/// -/// @return the offset in bytes to the packed RHS matrix -size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( - size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr); - -/// @brief Gets the size in bytes for the packed RHS matrix -/// -/// @param[in] n The number of rows in the RHS matrix (not packed) -/// @param[in] k The number of columns in the RHS matrix (not packed). -/// @param[in] nr The number of columns written by the matmul micro-kernel -/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. -/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. -/// -/// @return the packed RHS matrix size in bytes -size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t n, size_t k, size_t nr, size_t kr, size_t sr); - -/// Run the micro-kernel to pack the RHS matrix. -/// -/// @note The int4 values are stored in a N x K matrix, where N is number of rows and K is the number of columns. -/// Two int4 values are stored in one byte. The lower order part of the byte (low) holds -/// the first nibble (K-index + 0). The higher order of the byte holds the second nibble (K-index + 1). -/// -/// @param[in] num_groups The number of groups. It must be 1. -/// @param[in] n The number of columns of the output matrix (N). -/// @param[in] k The common dimension between the LHS and RHS matrix (K). It must be an even value. -/// @param[in] nr The number of N columns to interleave on the same output output row. -/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. -/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. -/// However, kr must be multiple of sr. -/// @param[in] rhs The RHS matrix containing the 4-bit values. -/// Size in bytes is expected to be greater than or equal to n * k * (sizeof(uint8_t) / 2). -/// @param[in] bias The biases. -/// @param[in] scale The scale for each output channel. -/// @param[out] rhs_packed The packed RHS matrix. -/// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix. -/// @param[in] params Parameters for the micro-kernel. -void kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( - size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const int32_t* bias, - const float* scale, void* rhs_packed, size_t extra_bytes, - const struct kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0_params* params); - -#ifdef __cplusplus -} -#endif diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_qsi4cxp_qsu4cxs1s0.c similarity index 74% rename from kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c rename to kai/ukernels/matmul/pack/kai_rhs_pack_qsi4cxp_qsu4cxs1s0.c index 275ccb55..b58937c3 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_qsi4cxp_qsu4cxs1s0.c @@ -3,7 +3,7 @@ // // SPDX-License-Identifier: Apache-2.0 // -#include "kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h" +#include "kai_rhs_pack_qsi4cxp_qsu4cxs1s0.h" #include #include @@ -15,6 +15,8 @@ static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); static const size_t kai_num_bytes_bias = sizeof(float); +#define kai_get_rhs_packed_stride kai_get_rhs_packed_stride_rhs_pack_qsi4cxp_qsu4cxs1s0 + inline static size_t kai_k_roundedup(size_t k, size_t kr, size_t sr) { // Since we pack a float and int32 value at the end of the row, // we must make sure that k is a multiple of 4 for alignment @@ -22,7 +24,19 @@ inline static size_t kai_k_roundedup(size_t k, size_t kr, size_t sr) { return kai_roundup(k, kr_sr_roundedup4); } -inline static size_t kai_rhs_packed_stride(size_t k, size_t kr, size_t nr, size_t sr) { +size_t kai_get_n_step_rhs_pack_qsi4cxp_qsu4cxs1s0(size_t nr) { + return nr; +} + +size_t kai_get_rhs_offset_rhs_pack_qsi4cxp_qsu4cxs1s0(size_t n_idx, bool is_rhs_nxk, size_t rhs_stride) { + if (is_rhs_nxk) { + return n_idx * rhs_stride; + } else { + return n_idx * sizeof(int8_t); + } +} + +size_t kai_get_rhs_packed_stride_rhs_pack_qsi4cxp_qsu4cxs1s0(size_t k, size_t nr, size_t kr, size_t sr) { const size_t k_internal = kai_k_roundedup(k, kr, sr); KAI_ASSERT((k_internal % 2) == 0); @@ -30,31 +44,22 @@ inline static size_t kai_rhs_packed_stride(size_t k, size_t kr, size_t nr, size_ return nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias); } -size_t kai_get_n_step_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t nr) { - return nr; -} - -size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t n_idx, size_t rhs_stride) { - return n_idx * rhs_stride; -} - -size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( - size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr) { +size_t kai_get_rhs_packed_offset_rhs_pack_qsi4cxp_qsu4cxs1s0(size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr) { KAI_ASSERT((n_idx % nr) == 0); - return (n_idx / nr) * kai_rhs_packed_stride(k, kr, nr, sr); + return (n_idx / nr) * kai_get_rhs_packed_stride(k, nr, kr, sr); } -size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t n, size_t k, size_t nr, size_t kr, size_t sr) { +size_t kai_get_rhs_packed_size_rhs_pack_qsi4cxp_qsu4cxs1s0(size_t n, size_t k, size_t nr, size_t kr, size_t sr) { const size_t num_rows = kai_roundup(n, nr) / nr; - return num_rows * kai_rhs_packed_stride(k, kr, nr, sr); + return num_rows * kai_get_rhs_packed_stride(k, nr, kr, sr); } -void kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( - size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const int32_t* bias, - const float* scale, void* rhs_packed, size_t extra_bytes, - const struct kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0_params* params) { +void kai_run_rhs_pack_qsi4cxp_qsu4cxs1s0( + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, bool is_rhs_nxk, const uint8_t* rhs, + const float* bias, const float* scale, void* rhs_packed, size_t extra_bytes, + const struct kai_rhs_pack_qsi4cxp_qsu4cxs1s0_params* params) { KAI_ASSERT((k % 2) == 0); KAI_ASSERT(num_groups == 1); KAI_ASSERT(extra_bytes == 0); @@ -67,15 +72,17 @@ void kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( KAI_ASSERT(params->lhs_zero_point == 1); // Note: The input matrix (rhs) is expected with: - // "k" columns and "n" rows (NxK) + // "k" columns and "n" rows (NxK) if is_rhs_nxk = true + // "n" columns and "k" rows (KxN) if is_rhs_nxk = false const size_t rhs_zero_point = params->rhs_zero_point; - const size_t rhs_stride = k / 2; - const size_t rhs_packed_stride = kai_rhs_packed_stride(k, kr, nr, sr); + const size_t rhs_n_step = is_rhs_nxk == true ? k / 2 : sizeof(int8_t); + const size_t rhs_k_step = is_rhs_nxk == true ? sizeof(int8_t) : n * sizeof(int8_t); + const size_t rhs_packed_stride = kai_get_rhs_packed_stride(k, nr, kr, sr); const size_t k_internal = kai_k_roundedup(k, kr, sr); for (size_t y = 0; y < n; y += nr) { - const uint8_t* src_row = rhs + y * rhs_stride; + const uint8_t* src_row = rhs + y * rhs_n_step; uint8_t* dst_row = (uint8_t*)rhs_packed + (y / nr) * rhs_packed_stride; int32_t* sums = (int32_t*)(dst_row + nr * (k_internal / 2)); @@ -92,8 +99,8 @@ void kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( // Clamp the row index to avoid out-of-bound reads const size_t src_row_idx = y + i >= n ? 0 : i; - const size_t src_addr_byte0 = src_row_idx * rhs_stride + k_idx_start0; - const size_t src_addr_byte1 = src_row_idx * rhs_stride + k_idx_start1; + const size_t src_addr_byte0 = src_row_idx * rhs_n_step + k_idx_start0 * rhs_k_step; + const size_t src_addr_byte1 = src_row_idx * rhs_n_step + k_idx_start1 * rhs_k_step; uint8_t byte0 = rhs_zero_point | rhs_zero_point << 4; uint8_t byte1 = rhs_zero_point | rhs_zero_point << 4; diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_qsi4cxp_qsu4cxs1s0.h b/kai/ukernels/matmul/pack/kai_rhs_pack_qsi4cxp_qsu4cxs1s0.h new file mode 100644 index 00000000..584cd619 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_qsi4cxp_qsu4cxs1s0.h @@ -0,0 +1,124 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#ifndef __cplusplus +#include +#endif +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +struct kai_rhs_pack_qsi4cxp_qsu4cxs1s0_params { + int8_t lhs_zero_point; + uint8_t rhs_zero_point; +}; + +/// Get the n step value. +/// The micro-kernel can process any N values. However, the starting N index to +/// be processed must be a multiple of n step. +/// +/// @param[in] nr The number of columns written by the matmul micro-kernel +/// +/// @return the n step value +size_t kai_get_n_step_rhs_pack_qsi4cxp_qsu4cxs1s0(size_t nr); + +/// Gets the offset in bytes for the RHS matrix (not packed), which holds +/// the int4 values in a N x K matrix (if is_rhs_nxk = true), or in a K x N matrix (if is_rhs_nxk = false). +/// In the RHS matrix (not packed), N is number of rows if is_rhs_nxk = true or the number of columns if is_rhs_nxk = +/// false. In the RHS matrix (not packed), K is the number of columns if is_rhs_nxk = true or the number of rows if +/// is_rhs_nxk = false. Two int4 values are stored in one byte. The lower order part of the byte (low) holds the first +/// nibble (K-index + 0). The higher order of the byte holds the second nibble (K-index + 1). +/// +/// @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of n_step. +/// @param[in] is_rhs_nxk True if the RHS matrix not packed is N x K. It should be false otherwise. +/// @param[in] rhs_stride The number of bytes in in each row of the RHS matrix (not packed) +/// +/// @return the offset in bytes to the RHS matrix (not packed) +size_t kai_get_rhs_offset_rhs_pack_qsi4cxp_qsu4cxs1s0(size_t n_idx, bool is_rhs_nxk, size_t rhs_stride); + +/// Get the row stride in bytes to the packed RHS matrix +/// +/// @param[in] k In the RHS matrix (not packed), K is the number of columns if is_rhs_nxk = true or the number of +/// rows if is_rhs_nxk = false. +/// @param[in] nr The number of columns written by the matmul micro-kernel. +/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. +/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. +/// +/// @return the stride in bytes to the packed RHS matrix +size_t kai_get_rhs_packed_stride_rhs_pack_qsi4cxp_qsu4cxs1s0(size_t k, size_t nr, size_t kr, size_t sr); + +/// Gets the offset in bytes for the packed RHS matrix, +/// which contains the packed 4-bit quantized symmetric per-channel (qsu4cx) values. +/// +/// @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of n_step. +/// @param[in] k In the RHS matrix (not packed), K is the number of columns if is_rhs_nxk = true or the number of +/// rows if is_rhs_nxk = false. +/// @param[in] nr The number of columns written by the matmul micro-kernel +/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. +/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. +/// +/// @return the offset in bytes to the packed RHS matrix +size_t kai_get_rhs_packed_offset_rhs_pack_qsi4cxp_qsu4cxs1s0(size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr); + +/// @brief Gets the size in bytes for the packed RHS matrix +/// +/// @param[in] n The number of rows in the RHS matrix (not packed) +/// @param[in] k The number of columns in the RHS matrix (not packed). +/// @param[in] nr The number of columns written by the matmul micro-kernel +/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. +/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. +/// +/// @return the packed RHS matrix size in bytes +size_t kai_get_rhs_packed_size_rhs_pack_qsi4cxp_qsu4cxs1s0(size_t n, size_t k, size_t nr, size_t kr, size_t sr); + +/// Run the micro-kernel to pack the RHS matrix. +/// +/// @note The int4 values are stored in a N x K matrix (if is_rhs_nxk = true) or in a K x N matrix (if is_rhs_nxk = +/// false). +/// In the RHS matrix (not packed), N is number of rows if is_rhs_nxk = true or the number of columns if +/// is_rhs_nxk = false. In the RHS matrix (not packed), K is the number of columns if is_rhs_nxk = true or the +/// number of rows if is_rhs_nxk = false. Two int4 values are stored in one byte. The lower order part of the +/// byte (low) holds the first nibble (K-index + 0). The higher order of the byte holds the second nibble +/// (K-index + 1). +/// +/// @param[in] num_groups The number of groups. It must be 1. +/// @param[in] n The number of rows if is_rhs_nxk = true or the number of columns if is_rhs_nxk = false. +/// @param[in] k The common dimension between the LHS and RHS matrix (K). It must be an even value. +/// @param[in] nr The number of N rows (if is_rhs_nxk = true) or colums (if is_rhs_nxk = false) to interleave +/// on the same output output row. +/// @param[in] kr The number of K values loaded in the single inner most loop of the matmul micro-kernel. +/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. +/// However, kr must be multiple of sr. +/// @param[in] is_rhs_nxk True if the RHS matrix (not packed) is N x K. It should be false otherwise. +/// @param[in] rhs The RHS matrix containing the 4-bit values. +/// Size in bytes is expected to be greater than or equal to n * k * (sizeof(uint8_t) / 2). +/// @param[in] bias The biases. +/// @param[in] scale The scale for each output channel. +/// @param[out] rhs_packed The packed RHS matrix. +/// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix. +/// @param[in] params Parameters for the micro-kernel. +void kai_run_rhs_pack_qsi4cxp_qsu4cxs1s0( + size_t num_groups, // + size_t n, // + size_t k, // + size_t nr, // + size_t kr, // + size_t sr, // + bool is_rhs_nxk, // + const uint8_t* rhs, // + const float* bias, // + const float* scale, // + void* rhs_packed, // + size_t extra_bytes, // + const struct kai_rhs_pack_qsi4cxp_qsu4cxs1s0_params* params); + +#ifdef __cplusplus +} +#endif -- GitLab From fef34618ef7e8398fa896dd6eef818f78415f21f Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Fri, 28 Jun 2024 15:07:00 +0100 Subject: [PATCH 3/4] Refactor the packing function for non-transposed RHS - Refactor the non-transposed case. Now, if the RHS matrix is not transposed, each byte holds two N values - Extend the example to support the NxK and KxN RHS matrices - Add support for non-even K values Signed-off-by: Gian Marco Iodice --- .../matmul_clamp_f32_qai8dxp_qsi4cxp.cpp | 311 +++++++++++++----- .../pack/kai_rhs_pack_qsi4cxp_qsu4cxs1s0.c | 240 ++++++++++---- 2 files changed, 411 insertions(+), 140 deletions(-) diff --git a/examples/matmul_clamp_f32_qai8dxp_qsi4cxp/matmul_clamp_f32_qai8dxp_qsi4cxp.cpp b/examples/matmul_clamp_f32_qai8dxp_qsi4cxp/matmul_clamp_f32_qai8dxp_qsi4cxp.cpp index a74b2d64..12597671 100644 --- a/examples/matmul_clamp_f32_qai8dxp_qsi4cxp/matmul_clamp_f32_qai8dxp_qsi4cxp.cpp +++ b/examples/matmul_clamp_f32_qai8dxp_qsi4cxp/matmul_clamp_f32_qai8dxp_qsi4cxp.cpp @@ -3,13 +3,14 @@ // // SPDX-License-Identifier: Apache-2.0 // -#if 0 //! defined(__ARM_FEATURE_DOTPROD) && !defined(__ARM_FEATURE_MATMUL_INT8) +#if !defined(__ARM_FEATURE_DOTPROD) && !defined(__ARM_FEATURE_MATMUL_INT8) #error "Dotprod and I8mm extensions required to compile this example" #else #include #include #include #include +#include #include #include @@ -111,11 +112,16 @@ kai_matmul_ukernel_f32_qa8dxp_qs4cxp ukernel_variants[] = { kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, "matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm"}, + }; // Number of micro-kernel variants stored in the array const size_t num_ukernel_variants = sizeof(ukernel_variants) / sizeof(ukernel_variants[0]); +static size_t roundup(size_t a, size_t b) { + return ((a + b - 1) / b) * b; +} + static void fill_uniform_random(size_t num_rows, size_t num_cols, float* dst, size_t seed) { std::srand(seed); @@ -125,14 +131,84 @@ static void fill_uniform_random(size_t num_rows, size_t num_cols, float* dst, si } } -static void quant_qs4cx_f32( - size_t n, size_t k, rhs_format format, const float* rhs_f32, uint8_t* rhs_qs4cx, float* rhs_scales_f32) { - const size_t dst_k_step = format == rhs_format::nxk ? sizeof(int8_t) : n * sizeof(int8_t); +static void ref_quant_qa8dx_f32(size_t m, size_t k, const float* lhs_f32, int8_t* lhs_qa8dx) { + const size_t dst_stride = (k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t)); - const size_t dst_n_step = format == rhs_format::nxk ? (k / 2) * sizeof(int8_t) : sizeof(int8_t); + const size_t lhs_qa8dx_stride = k; - for (size_t row_idx = 0; row_idx < n; ++row_idx) { - const float* src_ptr = rhs_f32 + row_idx * k; + for (size_t n_idx = 0; n_idx < m; ++n_idx) { + const float* src_ptr = lhs_f32 + n_idx * lhs_qa8dx_stride; + + float max0 = -FLT_MAX; + float min0 = FLT_MAX; + + // Find min/max for each channel + for (size_t k_idx = 0; k_idx < k; ++k_idx) { + const float src0_0 = src_ptr[k_idx]; + + max0 = std::max(src0_0, max0); + min0 = std::min(src0_0, min0); + } + + // Maximum/minimum int8 values + const float qmin = (float)INT8_MIN; + const float qmax = (float)INT8_MAX; + + const float rmin0 = std::min(0.0f, min0); + const float rmax0 = std::max(0.0f, max0); + + const float scale0 = rmin0 == rmax0 ? 1.f : (qmax - qmin) / (rmax0 - rmin0); + + // Reciprocal to quantize + const float recip_scale0 = scale0 ? 1.0f / scale0 : 0.0f; + + const float descaled_min0 = rmin0 * scale0; + const float descaled_max0 = rmax0 * scale0; + + const float zero_point_from_min_error0 = qmin + descaled_min0; + const float zero_point_from_max_error0 = qmax + descaled_max0; + + float zero_point0 = + zero_point_from_min_error0 + zero_point_from_max_error0 > 0 ? qmin - descaled_min0 : qmax - descaled_max0; + + zero_point0 = std::max(zero_point0, qmin); + zero_point0 = std::min(zero_point0, qmax); + + // Round to nearest integer + const int32_t nudged_zero_point0 = lrintf(zero_point0); + + int8_t* dst_ptr = (int8_t*)lhs_qa8dx + n_idx * dst_stride; + + // LHS offset at the beginning of the row + *((float*)(dst_ptr)) = recip_scale0; + dst_ptr += sizeof(float); + *((int32_t*)(dst_ptr)) = -nudged_zero_point0; + dst_ptr += sizeof(int32_t); + + // Quantize the channels + for (size_t k_idx = 0; k_idx < k; ++k_idx) { + const float src0_0 = src_ptr[k_idx]; + + // Scale the values + int32_t v0_s32 = (int32_t)(round(src0_0 * scale0)); + + v0_s32 = v0_s32 + nudged_zero_point0; + v0_s32 = std::max(v0_s32, INT8_MIN); + v0_s32 = std::min(v0_s32, INT8_MAX); + dst_ptr[0] = (int8_t)v0_s32; + dst_ptr += sizeof(int8_t); + } + } +}; + +static void quant_nxk_qs4cx_f32(size_t n, size_t k, const float* rhs_f32, uint8_t* rhs_qs4cx, float* rhs_scales_f32) { + const size_t rhs_qs4cx_stride = (roundup(k, 2) / 2); + + // Make sure the output is filled with zeros + std::memset(rhs_qs4cx, 0, n * rhs_qs4cx_stride); + + for (size_t n_idx = 0; n_idx < n; ++n_idx) { + const float* src_ptr = rhs_f32 + n_idx * k; float max0 = -FLT_MAX; float min0 = FLT_MAX; @@ -157,41 +233,42 @@ static void quant_qs4cx_f32( // Reciprocal to quantize const float recip_scale0 = scale0 ? 1.0f / scale0 : 0.0f; - uint8_t* dst_ptr = (uint8_t*)rhs_qs4cx + row_idx * dst_n_step; - // Quantize the channels - for (size_t k_idx = 0; k_idx < k; k_idx += 2) { - const float src0_0 = src_ptr[k_idx + 0]; - const float src0_1 = src_ptr[k_idx + 1]; + for (size_t k_idx = 0; k_idx < k; ++k_idx) { + const float src0_0 = src_ptr[k_idx]; // Scale the values int32_t v0_s32 = (int32_t)(round(src0_0 * scale0)); - int32_t v1_s32 = (int32_t)(round(src0_1 * scale0)); // Maximum/minimum int4 values v0_s32 = std::max(v0_s32, INT4_MIN); v0_s32 = std::min(v0_s32, INT4_MAX); - v1_s32 = std::max(v1_s32, INT4_MIN); - v1_s32 = std::min(v1_s32, INT4_MAX); - int32_t v0_u8 = (uint8_t)(v0_s32 + 8); - int32_t v1_u8 = (uint8_t)(v1_s32 + 8); + const uint8_t v0_u8 = (uint8_t)(v0_s32 + 8); - const uint8_t rhs_v0 = (v1_u8 << 4) | v0_u8; + const size_t dst_addr = (k_idx / 2) + n_idx * rhs_qs4cx_stride; + uint8_t rhs_v0 = rhs_qs4cx[dst_addr]; - dst_ptr[0] = rhs_v0; - dst_ptr += dst_k_step; + if ((k_idx % 2) == 0) { + rhs_v0 |= v0_u8; + } else { + rhs_v0 |= (v0_u8 << 4); + } + rhs_qs4cx[dst_addr] = rhs_v0; } - rhs_scales_f32[row_idx] = recip_scale0; + rhs_scales_f32[n_idx] = recip_scale0; } }; -static void ref_quant_qa8dx_f32(size_t m, size_t k, const float* lhs_f32, int8_t* lhs_qa8dx) { - const size_t dst_stride = (k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t)); +static void quant_kxn_qs4cx_f32(size_t n, size_t k, const float* rhs_f32, uint8_t* rhs_qs4cx, float* rhs_scales_f32) { + const size_t rhs_qs4cx_stride = (roundup(n, 2) / 2); - for (size_t row_idx = 0; row_idx < m; ++row_idx) { - const float* src_ptr = lhs_f32 + row_idx * k; + // Make sure the output is filled with zeros + std::memset(rhs_qs4cx, 0, k * rhs_qs4cx_stride); + + for (size_t n_idx = 0; n_idx < n; ++n_idx) { + const float* src_ptr = rhs_f32 + n_idx * k; float max0 = -FLT_MAX; float min0 = FLT_MAX; @@ -205,8 +282,8 @@ static void ref_quant_qa8dx_f32(size_t m, size_t k, const float* lhs_f32, int8_t } // Maximum/minimum int8 values - const float qmin = (float)INT8_MIN; - const float qmax = (float)INT8_MAX; + const float qmin = (float)INT4_MIN; + const float qmax = (float)INT4_MAX; const float rmin0 = std::min(0.0f, min0); const float rmax0 = std::max(0.0f, max0); @@ -216,64 +293,125 @@ static void ref_quant_qa8dx_f32(size_t m, size_t k, const float* lhs_f32, int8_t // Reciprocal to quantize const float recip_scale0 = scale0 ? 1.0f / scale0 : 0.0f; - const float descaled_min0 = rmin0 * scale0; - const float descaled_max0 = rmax0 * scale0; + // Quantize the channels + for (size_t k_idx = 0; k_idx < k; ++k_idx) { + const float src0_0 = src_ptr[k_idx]; - const float zero_point_from_min_error0 = qmin + descaled_min0; - const float zero_point_from_max_error0 = qmax + descaled_max0; + // Scale the values + int32_t v0_s32 = (int32_t)(round(src0_0 * scale0)); - float zero_point0 = - zero_point_from_min_error0 + zero_point_from_max_error0 > 0 ? qmin - descaled_min0 : qmax - descaled_max0; + // Maximum/minimum int4 values + v0_s32 = std::max(v0_s32, INT4_MIN); + v0_s32 = std::min(v0_s32, INT4_MAX); - zero_point0 = std::max(zero_point0, qmin); - zero_point0 = std::min(zero_point0, qmax); + const uint8_t v0_u8 = (uint8_t)(v0_s32 + 8); - // Round to nearest integer - const int32_t nudged_zero_point0 = lrintf(zero_point0); + const size_t dst_addr = (n_idx / 2) + k_idx * rhs_qs4cx_stride; + uint8_t rhs_v0 = rhs_qs4cx[dst_addr]; - int8_t* dst_ptr = (int8_t*)lhs_qa8dx + row_idx * dst_stride; + if ((n_idx % 2) == 0) { + rhs_v0 |= v0_u8; + } else { + rhs_v0 |= (v0_u8 << 4); + } + rhs_qs4cx[dst_addr] = rhs_v0; + } - // LHS offset at the beginning of the row - *((float*)(dst_ptr)) = recip_scale0; - dst_ptr += sizeof(float); - *((int32_t*)(dst_ptr)) = -nudged_zero_point0; - dst_ptr += sizeof(int32_t); + rhs_scales_f32[n_idx] = recip_scale0; + } +}; - // Quantize the channels - for (size_t k_idx = 0; k_idx < k; ++k_idx) { - const float src0_0 = src_ptr[k_idx]; +static void quant_qs4cx_f32( + size_t n, size_t k, rhs_format format, const float* rhs_f32, uint8_t* rhs_qs4cx, float* rhs_scales_f32) { + if (rhs_format::nxk == format) { + quant_nxk_qs4cx_f32(n, k, rhs_f32, rhs_qs4cx, rhs_scales_f32); + } else { + quant_kxn_qs4cx_f32(n, k, rhs_f32, rhs_qs4cx, rhs_scales_f32); + } +}; - // Scale the values - int32_t v0_s32 = (int32_t)(round(src0_0 * scale0)); +static void ref_matmul_mxn_mxk_nxk_f32_qa8dx_qs4cx( + size_t m, size_t n, size_t k, const int8_t* lhs_qa8dx, const uint8_t* rhs_qs4cx, const float* rhs_scales_f32, + float* dst_f32, float scalar_min, float scalar_max) { + const size_t lhs_stride = k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t); - v0_s32 = v0_s32 + nudged_zero_point0; - v0_s32 = std::max(v0_s32, INT8_MIN); - v0_s32 = std::min(v0_s32, INT8_MAX); - dst_ptr[0] = (int8_t)v0_s32; - dst_ptr += sizeof(int8_t); + const size_t rhs_qs4cx_stride = (roundup(k, 2) / 2); + + for (size_t m_idx = 0; m_idx < m; ++m_idx) { + const int8_t* lhs_ptr_start = lhs_qa8dx + m_idx * lhs_stride; + + for (size_t n_idx = 0; n_idx < n; ++n_idx) { + // Main f32 accumulator + int32_t iacc = 0; + + const int8_t* lhs_ptr = lhs_ptr_start; + const uint8_t* rhs_ptr = rhs_qs4cx + n_idx * rhs_qs4cx_stride; + + // Get the LHS quantization parameters stored at the + // beginning of each row + const float lhs_scale = *(const float*)lhs_ptr; + lhs_ptr += sizeof(float); + + const int32_t lhs_offset = *(const int32_t*)lhs_ptr; + lhs_ptr += sizeof(int32_t); + + for (size_t k_idx = 0; k_idx < k; ++k_idx) { + // Get the LHS values + const int32_t lhs_v0 = (int32_t)lhs_ptr[0]; + + // Get the RHS values + const uint8_t rhs_byte = rhs_ptr[0]; + + // Unpack the RHS values + int32_t rhs_v0 = 0; + if ((k_idx % 2) == 0) { + rhs_v0 = (((int32_t)(rhs_byte & 0x0F)) - 8); + } else { + rhs_v0 = (((int32_t)(rhs_byte >> 4)) - 8); + } + + iacc += lhs_v0 * rhs_v0; + iacc += lhs_offset * rhs_v0; + + lhs_ptr += 1; + + // Increment only when k_idx is not a multiple of 2 + rhs_ptr += k_idx % 2; + } + + // Get the RHS scale + const float rhs_scale = rhs_scales_f32[n_idx]; + + float main_acc = iacc * rhs_scale; + + main_acc = main_acc * lhs_scale; + + // Clamp (min-max) operation + main_acc = std::max(main_acc, scalar_min); + main_acc = std::min(main_acc, scalar_max); + + dst_f32[0] = main_acc; + dst_f32 += 1; } } }; -static void ref_matmul_f32_qa8dx_qs4cx( - size_t m, size_t n, size_t k, rhs_format format, const int8_t* lhs_qa8dx, const uint8_t* rhs_qs4cx, - const float* rhs_scales_f32, float* dst_f32, float scalar_min, float scalar_max) { +static void ref_matmul_mxn_mxk_kxn_f32_qa8dx_qs4cx( + size_t m, size_t n, size_t k, const int8_t* lhs_qa8dx, const uint8_t* rhs_qs4cx, const float* rhs_scales_f32, + float* dst_f32, float scalar_min, float scalar_max) { const size_t lhs_stride = k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t); - const size_t lhs_k_step = sizeof(int8_t) * 2; + const size_t rhs_qs4cx_stride = (roundup(n, 2) / 2); - const size_t rhs_k_step = format == rhs_format::nxk ? sizeof(int8_t) : n * sizeof(int8_t); + for (size_t m_idx = 0; m_idx < m; ++m_idx) { + const int8_t* lhs_ptr_start = lhs_qa8dx + m_idx * lhs_stride; - const size_t rhs_n_step = format == rhs_format::nxk ? (k / 2) * sizeof(int8_t) : sizeof(int8_t); - - for (size_t row_idx = 0; row_idx < m; ++row_idx) { - const int8_t* lhs_ptr_start = lhs_qa8dx + row_idx * lhs_stride; - for (size_t col_idx = 0; col_idx < n; ++col_idx) { + for (size_t n_idx = 0; n_idx < n; ++n_idx) { // Main f32 accumulator int32_t iacc = 0; const int8_t* lhs_ptr = lhs_ptr_start; - const uint8_t* rhs_ptr = rhs_qs4cx + col_idx * rhs_n_step; + const uint8_t* rhs_ptr = rhs_qs4cx + (n_idx / 2); // Get the LHS quantization parameters stored at the // beginning of each row @@ -283,29 +421,32 @@ static void ref_matmul_f32_qa8dx_qs4cx( const int32_t lhs_offset = *(const int32_t*)lhs_ptr; lhs_ptr += sizeof(int32_t); - for (size_t b = 0; b < k; b += 2) { + for (size_t k_idx = 0; k_idx < k; ++k_idx) { // Get the LHS values const int32_t lhs_v0 = (int32_t)lhs_ptr[0]; - const int32_t lhs_v1 = (int32_t)lhs_ptr[1]; // Get the RHS values const uint8_t rhs_byte = rhs_ptr[0]; // Unpack the RHS values - const int32_t rhs_v0 = (((int32_t)(rhs_byte & 0x0F)) - 8); - const int32_t rhs_v1 = (((int32_t)(rhs_byte >> 4)) - 8); + int32_t rhs_v0 = 0; + if ((n_idx % 2) == 0) { + rhs_v0 = (((int32_t)(rhs_byte & 0x0F)) - 8); + } else { + rhs_v0 = (((int32_t)(rhs_byte >> 4)) - 8); + } iacc += lhs_v0 * rhs_v0; - iacc += lhs_v1 * rhs_v1; iacc += lhs_offset * rhs_v0; - iacc += lhs_offset * rhs_v1; - lhs_ptr += lhs_k_step; - rhs_ptr += rhs_k_step; + lhs_ptr += 1; + + // Increment only when k_idx is not a multiple of 2 + rhs_ptr += rhs_qs4cx_stride; } // Get the RHS scale - const float rhs_scale = rhs_scales_f32[col_idx]; + const float rhs_scale = rhs_scales_f32[n_idx]; float main_acc = iacc * rhs_scale; @@ -321,6 +462,20 @@ static void ref_matmul_f32_qa8dx_qs4cx( } }; +static void ref_matmul_f32_qa8dx_qs4cx( + size_t m, size_t n, size_t k, rhs_format format, const int8_t* lhs_qa8dx, const uint8_t* rhs_qs4cx, + const float* rhs_scales_f32, float* dst_f32, float scalar_min, float scalar_max) { + const size_t lhs_stride = k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t); + + if (rhs_format::nxk == format) { + ref_matmul_mxn_mxk_nxk_f32_qa8dx_qs4cx( + m, n, k, lhs_qa8dx, rhs_qs4cx, rhs_scales_f32, dst_f32, scalar_min, scalar_max); + } else { + ref_matmul_mxn_mxk_kxn_f32_qa8dx_qs4cx( + m, n, k, lhs_qa8dx, rhs_qs4cx, rhs_scales_f32, dst_f32, scalar_min, scalar_max); + } +}; + static bool is_output_correct(size_t num_rows, size_t num_cols, float tolerance, const float* ref, const float* act) { bool is_valid = true; @@ -336,9 +491,9 @@ static bool is_output_correct(size_t num_rows, size_t num_cols, float tolerance, } int main(int argc, char** argv) { - const size_t m = 13; - const size_t n = 17; - const size_t k = 18; + const size_t m = 17; + const size_t n = 7; + const size_t k = 21; const size_t seed_lhs = 4568; const size_t seed_rhs = seed_lhs + 4; @@ -348,7 +503,8 @@ int main(int argc, char** argv) { const size_t lhs_native_size_f32 = m * k * sizeof(float); const size_t rhs_native_size_f32 = n * k * sizeof(float); - const size_t rhs_native_size_qs4cx = n * (k / 2) * sizeof(uint8_t); + const size_t rhs_native_size_qs4cx = format == rhs_format::nxk ? n * (roundup(k, 2) / 2) * sizeof(uint8_t) + : k * (roundup(n, 2) / 2) * sizeof(uint8_t); const size_t rhs_scales_size_f32 = n * sizeof(float); // Allocate the memory @@ -466,6 +622,7 @@ int main(int argc, char** argv) { } else { printf("TEST[%ld] = FAILED\n", idx_variant); } + delete[] lhs_packed_mtx_qa8dx; delete[] rhs_packed_mtx_qs4cx; delete[] dst_act_mtx_f32; diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_qsi4cxp_qsu4cxs1s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_qsi4cxp_qsu4cxs1s0.c index b58937c3..25ad65a2 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_qsi4cxp_qsu4cxs1s0.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_qsi4cxp_qsu4cxs1s0.c @@ -32,7 +32,8 @@ size_t kai_get_rhs_offset_rhs_pack_qsi4cxp_qsu4cxs1s0(size_t n_idx, bool is_rhs_ if (is_rhs_nxk) { return n_idx * rhs_stride; } else { - return n_idx * sizeof(int8_t); + KAI_ASSERT((n_idx % 2) == 0); + return (n_idx / 2) * sizeof(int8_t); } } @@ -60,7 +61,6 @@ void kai_run_rhs_pack_qsi4cxp_qsu4cxs1s0( size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, bool is_rhs_nxk, const uint8_t* rhs, const float* bias, const float* scale, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qsi4cxp_qsu4cxs1s0_params* params) { - KAI_ASSERT((k % 2) == 0); KAI_ASSERT(num_groups == 1); KAI_ASSERT(extra_bytes == 0); KAI_ASSERT((kr % sr) == 0); @@ -75,90 +75,204 @@ void kai_run_rhs_pack_qsi4cxp_qsu4cxs1s0( // "k" columns and "n" rows (NxK) if is_rhs_nxk = true // "n" columns and "k" rows (KxN) if is_rhs_nxk = false - const size_t rhs_zero_point = params->rhs_zero_point; - const size_t rhs_n_step = is_rhs_nxk == true ? k / 2 : sizeof(int8_t); - const size_t rhs_k_step = is_rhs_nxk == true ? sizeof(int8_t) : n * sizeof(int8_t); - const size_t rhs_packed_stride = kai_get_rhs_packed_stride(k, nr, kr, sr); - const size_t k_internal = kai_k_roundedup(k, kr, sr); + if (is_rhs_nxk == true) { + const size_t rhs_zero_point = params->rhs_zero_point; + const size_t rhs_packed_stride = kai_get_rhs_packed_stride(k, nr, kr, sr); + const size_t k_internal = kai_k_roundedup(k, kr, sr); + const size_t dst_num_rows = kai_roundup(n, nr) / nr; + const size_t dst_num_bytes_per_row = nr * (kai_k_roundedup(k, kr, sr) / 2); + const size_t block_length_in_bytes = kr / sr; + const size_t k_interleaved_v = 16U; + const size_t rhs_stride = kai_roundup(k, 2) / 2; - for (size_t y = 0; y < n; y += nr) { - const uint8_t* src_row = rhs + y * rhs_n_step; - uint8_t* dst_row = (uint8_t*)rhs_packed + (y / nr) * rhs_packed_stride; + for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; ++dst_row_idx) { + uint8_t* dst_row = (uint8_t*)rhs_packed + dst_row_idx * rhs_packed_stride; - int32_t* sums = (int32_t*)(dst_row + nr * (k_internal / 2)); + int32_t* sums = (int32_t*)(dst_row + nr * (k_internal / 2)); - // Initialize to zero the RHS reduction sums - memset(sums, 0, nr * sizeof(int32_t)); + // Initialize to zero the RHS reduction sums + memset(sums, 0, nr * sizeof(int32_t)); - for (size_t x = 0; x < k_internal; x += (kr * sr)) { - for (size_t s = 0; s < sr; ++s) { - for (size_t i = 0; i < nr; ++i) { - for (size_t kr_idx = 0; kr_idx < kr / sr; kr_idx += 2) { - const size_t k_idx_start0 = (x / 2) + kr_idx / 2 + s * (kr / sr) / 2; - const size_t k_idx_start1 = k_idx_start0 + (kr / 2); + for (size_t dst_byte_idx = 0; dst_byte_idx < dst_num_bytes_per_row; ++dst_byte_idx) { + const size_t block_idx = dst_byte_idx / block_length_in_bytes; + const size_t block_byte_idx = dst_byte_idx % block_length_in_bytes; + const size_t super_block_idx = block_idx / nr; + const size_t nr_idx = block_idx % nr; - // Clamp the row index to avoid out-of-bound reads - const size_t src_row_idx = y + i >= n ? 0 : i; - const size_t src_addr_byte0 = src_row_idx * rhs_n_step + k_idx_start0 * rhs_k_step; - const size_t src_addr_byte1 = src_row_idx * rhs_n_step + k_idx_start1 * rhs_k_step; + const size_t k_adjustment = + ((block_byte_idx + super_block_idx * block_length_in_bytes) / k_interleaved_v) * k_interleaved_v; + const size_t k0_idx = block_byte_idx + super_block_idx * block_length_in_bytes + k_adjustment; + const size_t k1_idx = k0_idx + k_interleaved_v; + const size_t n0_idx = dst_row_idx * nr + nr_idx; - uint8_t byte0 = rhs_zero_point | rhs_zero_point << 4; - uint8_t byte1 = rhs_zero_point | rhs_zero_point << 4; + // Clamp the index to avoid out-of-bound reads + const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1); - if (k_idx_start0 < (k / 2)) { - byte0 = src_row[src_addr_byte0]; - } + const size_t src_addr_byte0 = (k0_idx / 2) + n0_valid_idx * rhs_stride; + const size_t src_addr_byte1 = (k1_idx / 2) + n0_valid_idx * rhs_stride; - if (k_idx_start1 < (k / 2)) { - byte1 = src_row[src_addr_byte1]; - } + uint8_t byte0 = rhs_zero_point | rhs_zero_point << 4; + uint8_t byte1 = rhs_zero_point | rhs_zero_point << 4; - const uint8_t src_x0_lo = (byte0 & 0x0F); - const uint8_t src_x1_lo = (byte0 >> 4); + if (k0_idx < k) { + byte0 = rhs[src_addr_byte0]; + } - const uint8_t src_x0_hi = (byte1 & 0x0F); - const uint8_t src_x1_hi = (byte1 >> 4); + if (k1_idx < k) { + byte1 = rhs[src_addr_byte1]; + } - sums[i] += (int32_t)src_x0_lo + (int32_t)src_x0_hi - 2 * (int32_t)rhs_zero_point; - sums[i] += (int32_t)src_x1_lo + (int32_t)src_x1_hi - 2 * (int32_t)rhs_zero_point; + // The following operations where we extract the values from the bytes + // can be also written in the following and less efficient manner: + /* + uint8_t src_x0_lo = 0; + uint8_t src_x0_hi = 0; - const uint8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4); - const uint8_t dst_qs1 = src_x1_lo | (src_x1_hi << 4); + if ((k0_idx % 2) == 0) { + src_x0_lo = (byte0 & 0x0F); + } else { + src_x0_lo = (byte0 >> 4); + } - *dst_row = dst_qs0 ^ 0x88; - dst_row += sizeof(uint8_t); - *dst_row = dst_qs1 ^ 0x88; - dst_row += sizeof(uint8_t); + if ((k1_idx % 2) == 0) { + src_x0_hi = (byte1 & 0x0F); + } else { + src_x0_hi = (byte1 >> 4); } + */ + const size_t shift_right_x0 = (k0_idx % 2) * 4; + const size_t shift_right_x1 = (k1_idx % 2) * 4; + + const uint8_t src_x0_lo = (byte0 >> shift_right_x0) & 0x0F; + const uint8_t src_x0_hi = (byte1 >> shift_right_x1) & 0x0F; + + sums[nr_idx] += (int32_t)src_x0_lo + (int32_t)src_x0_hi - 2 * (int32_t)rhs_zero_point; + + const uint8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4); + + *dst_row = dst_qs0 ^ 0x88; + dst_row += sizeof(uint8_t); + } + + // Adjust the reduction sums + for (size_t i = 0; i < nr; ++i) { + sums[i] = sums[i] * 16; + dst_row += sizeof(int32_t); + } + + // Adjust the scales + for (size_t i = 0; i < nr; ++i) { + // Clamp the row index to avoid out-of-bound reads + const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); + *((float*)(dst_row)) = scale[src_row_idx] * 0.0625F; + dst_row += sizeof(float); + } + + // Set the bias + if (bias == NULL) { + memset(dst_row, 0, nr * kai_num_bytes_bias); + } else { + for (size_t i = 0; i < nr; ++i) { + // Clamp the row index to avoid out-of-bound reads + const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); + ((float*)dst_row)[i] = bias[src_row_idx]; } } } + } else { + const size_t rhs_zero_point = params->rhs_zero_point; + const size_t rhs_packed_stride = kai_get_rhs_packed_stride(k, nr, kr, sr); + const size_t k_internal = kai_k_roundedup(k, kr, sr); + const size_t dst_num_rows = kai_roundup(n, nr) / nr; + const size_t dst_num_bytes_per_row = nr * (kai_k_roundedup(k, kr, sr) / 2); + const size_t block_length_in_bytes = kr / sr; + const size_t k_interleaved_v = 16U; + const size_t rhs_stride = kai_roundup(n, 2) / 2; + + for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; ++dst_row_idx) { + uint8_t* dst_row = (uint8_t*)rhs_packed + dst_row_idx * rhs_packed_stride; + + int32_t* sums = (int32_t*)(dst_row + nr * (k_internal / 2)); + + // Initialize to zero the RHS reduction sums + memset(sums, 0, nr * sizeof(int32_t)); + + for (size_t dst_byte_idx = 0; dst_byte_idx < dst_num_bytes_per_row; ++dst_byte_idx) { + const size_t block_idx = dst_byte_idx / block_length_in_bytes; + const size_t block_byte_idx = dst_byte_idx % block_length_in_bytes; + const size_t super_block_idx = block_idx / nr; + const size_t nr_idx = block_idx % nr; + + const size_t k_adjustment = + ((block_byte_idx + super_block_idx * block_length_in_bytes) / k_interleaved_v) * k_interleaved_v; + const size_t k0_idx = block_byte_idx + super_block_idx * block_length_in_bytes + k_adjustment; + const size_t k1_idx = k0_idx + k_interleaved_v; + const size_t n0_idx = dst_row_idx * nr + nr_idx; + + // Clamp the index to avoid out-of-bound reads + const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1); + + const size_t src_addr_byte0 = (n0_valid_idx / 2) + k0_idx * rhs_stride; + const size_t src_addr_byte1 = (n0_valid_idx / 2) + k1_idx * rhs_stride; + + uint8_t byte0 = rhs_zero_point | rhs_zero_point << 4; + uint8_t byte1 = rhs_zero_point | rhs_zero_point << 4; + + if (k0_idx < k) { + byte0 = rhs[src_addr_byte0]; + } - // Adjust the reduction sums - for (size_t i = 0; i < nr; ++i) { - sums[i] = sums[i] * 16; - dst_row += sizeof(int32_t); - } + if (k1_idx < k) { + byte1 = rhs[src_addr_byte1]; + } - // Adjust the scales - for (size_t i = 0; i < nr; ++i) { - // Clamp the row index to avoid out-of-bound reads - const size_t src_row_idx = KAI_MIN(y + i, n - 1); - *((float*)(dst_row)) = scale[src_row_idx] * 0.0625F; - dst_row += sizeof(float); - } + if ((n0_idx % 2) == 0) { + const uint8_t src_x0_lo = (byte0 & 0x0F); + const uint8_t src_x0_hi = (byte1 & 0x0F); + + sums[nr_idx] += (int32_t)src_x0_lo + (int32_t)src_x0_hi - 2 * (int32_t)rhs_zero_point; + + const uint8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4); + + *dst_row = dst_qs0 ^ 0x88; + dst_row += sizeof(uint8_t); + } else { + const uint8_t src_x1_lo = (byte0 >> 4); + const uint8_t src_x1_hi = (byte1 >> 4); + + sums[nr_idx] += (int32_t)src_x1_lo + (int32_t)src_x1_hi - 2 * (int32_t)rhs_zero_point; + + const uint8_t dst_qs1 = src_x1_lo | (src_x1_hi << 4); + + *dst_row = dst_qs1 ^ 0x88; + dst_row += sizeof(uint8_t); + } + } + + // Adjust the reduction sums + for (size_t i = 0; i < nr; ++i) { + sums[i] = sums[i] * 16; + dst_row += sizeof(int32_t); + } - // Set the bias - if (bias == NULL) { - memset(dst_row, 0, nr * kai_num_bytes_bias); - } else { + // Adjust the scales for (size_t i = 0; i < nr; ++i) { // Clamp the row index to avoid out-of-bound reads - const size_t src_row_idx = KAI_MIN(y + i, n - 1); - ((float*)dst_row)[i] = bias[src_row_idx]; + const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); + *((float*)(dst_row)) = scale[src_row_idx] * 0.0625F; + dst_row += sizeof(float); } - } - dst_row += (kai_num_bytes_bias * nr); + // Set the bias + if (bias == NULL) { + memset(dst_row, 0, nr * kai_num_bytes_bias); + } else { + for (size_t i = 0; i < nr; ++i) { + // Clamp the row index to avoid out-of-bound reads + const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); + ((float*)dst_row)[i] = bias[src_row_idx]; + } + } + } } } -- GitLab From 00cbe61900fe868af62a3461b95077f33b48f9ef Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Fri, 28 Jun 2024 15:17:03 +0100 Subject: [PATCH 4/4] Rename n_idx to m_idx Signed-off-by: Gian Marco Iodice --- .../matmul_clamp_f32_qai8dxp_qsi4cxp.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/matmul_clamp_f32_qai8dxp_qsi4cxp/matmul_clamp_f32_qai8dxp_qsi4cxp.cpp b/examples/matmul_clamp_f32_qai8dxp_qsi4cxp/matmul_clamp_f32_qai8dxp_qsi4cxp.cpp index 12597671..f0c92826 100644 --- a/examples/matmul_clamp_f32_qai8dxp_qsi4cxp/matmul_clamp_f32_qai8dxp_qsi4cxp.cpp +++ b/examples/matmul_clamp_f32_qai8dxp_qsi4cxp/matmul_clamp_f32_qai8dxp_qsi4cxp.cpp @@ -136,8 +136,8 @@ static void ref_quant_qa8dx_f32(size_t m, size_t k, const float* lhs_f32, int8_t const size_t lhs_qa8dx_stride = k; - for (size_t n_idx = 0; n_idx < m; ++n_idx) { - const float* src_ptr = lhs_f32 + n_idx * lhs_qa8dx_stride; + for (size_t m_idx = 0; m_idx < m; ++m_idx) { + const float* src_ptr = lhs_f32 + m_idx * lhs_qa8dx_stride; float max0 = -FLT_MAX; float min0 = FLT_MAX; @@ -177,7 +177,7 @@ static void ref_quant_qa8dx_f32(size_t m, size_t k, const float* lhs_f32, int8_t // Round to nearest integer const int32_t nudged_zero_point0 = lrintf(zero_point0); - int8_t* dst_ptr = (int8_t*)lhs_qa8dx + n_idx * dst_stride; + int8_t* dst_ptr = (int8_t*)lhs_qa8dx + m_idx * dst_stride; // LHS offset at the beginning of the row *((float*)(dst_ptr)) = recip_scale0; -- GitLab