From cd326230da0f4110c911c2e2308fb33a681db8dc Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Fri, 20 Sep 2024 14:59:28 +0100 Subject: [PATCH 1/5] Optimize the qsi4c32p int4 matmul for m_step = 16 and bl = 32 Signed-off-by: Anitha Raj --- ..._qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c | 644 +++++++++++++++++- 1 file changed, 641 insertions(+), 3 deletions(-) diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c index 5e84e870..5af3e172 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c @@ -14,7 +14,7 @@ #include "kai/kai_common.h" -static const size_t kai_m_step = 8; +static const size_t kai_m_step = 16; static const size_t kai_n_step = 4; static const size_t kai_mr = 4; static const size_t kai_nr = 4; @@ -106,10 +106,10 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm return m * n * sizeof(float); } -void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( +inline static void kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm( size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { - KAI_ASSERT((bl % kai_kr) == 0); + KAI_ASSERT(bl == 32); KAI_ASSERT((bl % kai_bl_multiple_of) == 0); KAI_ASSERT(dst_stride_col == sizeof(float)); @@ -117,6 +117,625 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( return; } + size_t num_blocks = kai_num_blocks_per_row(k, bl); + + float clamp_vals[2] = {scalar_min, scalar_max}; + + __asm__ __volatile__( + "mov x13, %x[m]\n" + "mov x12, #0x80\n" + "mov x20, #0x20\n" + "cmp x13, #0x10\n" + "madd x12, %x[num_blocks], x12, x20\n" + "blt 14f\n" + "1:" // Row loop + "mov x11, %x[rhs_packed]\n" + "mov x10, %x[n]\n" + "add x9, %x[dst], %x[dst_stride_row], LSL #4\n" + "2:" // Column loop + "mov x27, %x[lhs_packed]\n" + "movi v31.16b, #0x0\n" + "movi v30.16b, #0x0\n" + "mov x20, %x[num_blocks]\n" + "movi v29.16b, #0x0\n" + "movi v28.16b, #0x0\n" + "movi v27.16b, #0x0\n" + "movi v26.16b, #0x0\n" + "add x23, x27, x12\n" + "add x22, x23, x12\n" + "movi v25.16b, #0x0\n" + "movi v24.16b, #0x0\n" + "add x21, x22, x12\n" + "movi v23.16b, #0x0\n" + "movi v22.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v20.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "movi v18.16b, #0x0\n" + "movi v17.16b, #0x0\n" + "movi v16.16b, #0x0\n" + "3:" // Block loop + "ldr q11, [x11, #0x0]\n" + "ldr q4, [x11, #0x10]\n" + "movi v2.4s, #0x0\n" + "movi v9.4s, #0x0\n" + "ldr q12, [x27, #0x0]\n" + "ldr q0, [x27, #0x10]\n" + "movi v7.4s, #0x0\n" + "movi v5.4s, #0x0\n" + "ldr q15, [x11, #0x20]\n" + "ldr q13, [x11, #0x30]\n" + "movi v10.16b, #0xf0\n" + "add x11, x11, #0x40\n" + "ldr q8, [x27, #0x20]\n" + "ldr q6, [x27, #0x30]\n" + "shl v14.16b, v11.16b, #0x4\n" + "shl v3.16b, v4.16b, #0x4\n" + "ldr q1, [x27, #0x40]\n" + "and v11.16b, v11.16b, v10.16b\n" + "and v4.16b, v4.16b, v10.16b\n" + ".inst 0x4e8ea582 // smmla v2.4s, v12.16b, v14.16b\n" + ".inst 0x4e83a589 // smmla v9.4s, v12.16b, v3.16b\n" + "shl v12.16b, v15.16b, #0x4\n" + ".inst 0x4e8ea407 // smmla v7.4s, v0.16b, v14.16b\n" + ".inst 0x4e83a405 // smmla 
v5.4s, v0.16b, v3.16b\n" + "shl v0.16b, v13.16b, #0x4\n" + "and v15.16b, v15.16b, v10.16b\n" + "and v13.16b, v13.16b, v10.16b\n" + "ldr q10, [x27, #0x50]\n" + ".inst 0x4e8ca502 // smmla v2.4s, v8.16b, v12.16b\n" + ".inst 0x4e80a509 // smmla v9.4s, v8.16b, v0.16b\n" + "ldr q8, [x27, #0x60]\n" + ".inst 0x4e8ca4c7 // smmla v7.4s, v6.16b, v12.16b\n" + ".inst 0x4e80a4c5 // smmla v5.4s, v6.16b, v0.16b\n" + "ldr q6, [x27, #0x70]\n" + "add x27, x27, #0x80\n" + ".inst 0x4e8ba422 // smmla v2.4s, v1.16b, v11.16b\n" + ".inst 0x4e84a429 // smmla v9.4s, v1.16b, v4.16b\n" + "ldr d1, [x11, #0x0]\n" + "add x11, x11, #0x8\n" + ".inst 0x4e8ba547 // smmla v7.4s, v10.16b, v11.16b\n" + ".inst 0x4e84a545 // smmla v5.4s, v10.16b, v4.16b\n" + ".inst 0x4e8fa502 // smmla v2.4s, v8.16b, v15.16b\n" + "shll v1.4s, v1.4h, #0x10\n" + ".inst 0x4e8da509 // smmla v9.4s, v8.16b, v13.16b\n" + ".inst 0x4e8fa4c7 // smmla v7.4s, v6.16b, v15.16b\n" + ".inst 0x4e8da4c5 // smmla v5.4s, v6.16b, v13.16b\n" + "uzp1 v6.2d, v2.2d, v9.2d\n" + "uzp2 v8.2d, v2.2d, v9.2d\n" + "scvtf v6.4s, v6.4s, #0x4\n" + "uzp1 v9.2d, v7.2d, v5.2d\n" + "uzp2 v2.2d, v7.2d, v5.2d\n" + "scvtf v8.4s, v8.4s, #0x4\n" + "fmla v31.4s, v6.4s, v1.4s\n" + "scvtf v9.4s, v9.4s, #0x4\n" + "scvtf v2.4s, v2.4s, #0x4\n" + "fmla v30.4s, v8.4s, v1.4s\n" + "fmla v29.4s, v9.4s, v1.4s\n" + "fmla v28.4s, v2.4s, v1.4s\n" + "ldr q9, [x23, #0x0]\n" + "ldr q7, [x23, #0x10]\n" + "movi v8.4s, #0x0\n" + "movi v2.4s, #0x0\n" + "ldr q5, [x23, #0x20]\n" + "ldr q10, [x23, #0x30]\n" + "movi v6.4s, #0x0\n" + ".inst 0x4e8ea528 // smmla v8.4s, v9.16b, v14.16b\n" + ".inst 0x4e83a522 // smmla v2.4s, v9.16b, v3.16b\n" + "ldr q9, [x23, #0x40]\n" + ".inst 0x4e8ea4e6 // smmla v6.4s, v7.16b, v14.16b\n" + ".inst 0x4e8ca4a8 // smmla v8.4s, v5.16b, v12.16b\n" + ".inst 0x4e80a4a2 // smmla v2.4s, v5.16b, v0.16b\n" + "ldr q5, [x23, #0x50]\n" + ".inst 0x4e8ca546 // smmla v6.4s, v10.16b, v12.16b\n" + ".inst 0x4e8ba528 // smmla v8.4s, v9.16b, v11.16b\n" + ".inst 0x4e84a522 // smmla v2.4s, v9.16b, v4.16b\n" + "ldr q9, [x23, #0x60]\n" + ".inst 0x4e8ba4a6 // smmla v6.4s, v5.16b, v11.16b\n" + ".inst 0x4e8fa528 // smmla v8.4s, v9.16b, v15.16b\n" + ".inst 0x4e8da522 // smmla v2.4s, v9.16b, v13.16b\n" + "movi v9.4s, #0x0\n" + ".inst 0x4e83a4e9 // smmla v9.4s, v7.16b, v3.16b\n" + "ldr q7, [x23, #0x70]\n" + "add x23, x23, #0x80\n" + ".inst 0x4e8fa4e6 // smmla v6.4s, v7.16b, v15.16b\n" + ".inst 0x4e80a549 // smmla v9.4s, v10.16b, v0.16b\n" + "uzp1 v10.2d, v8.2d, v2.2d\n" + "uzp2 v2.2d, v8.2d, v2.2d\n" + "scvtf v10.4s, v10.4s, #0x4\n" + ".inst 0x4e84a4a9 // smmla v9.4s, v5.16b, v4.16b\n" + "scvtf v2.4s, v2.4s, #0x4\n" + "fmla v27.4s, v10.4s, v1.4s\n" + ".inst 0x4e8da4e9 // smmla v9.4s, v7.16b, v13.16b\n" + "fmla v26.4s, v2.4s, v1.4s\n" + "uzp1 v2.2d, v6.2d, v9.2d\n" + "uzp2 v10.2d, v6.2d, v9.2d\n" + "scvtf v2.4s, v2.4s, #0x4\n" + "scvtf v10.4s, v10.4s, #0x4\n" + "fmla v25.4s, v2.4s, v1.4s\n" + "fmla v24.4s, v10.4s, v1.4s\n" + "ldr q8, [x22, #0x0]\n" + "ldr q7, [x22, #0x10]\n" + "movi v9.4s, #0x0\n" + "movi v6.4s, #0x0\n" + "ldr q2, [x22, #0x20]\n" + "ldr q5, [x22, #0x30]\n" + "movi v10.4s, #0x0\n" + ".inst 0x4e8ea509 // smmla v9.4s, v8.16b, v14.16b\n" + ".inst 0x4e83a506 // smmla v6.4s, v8.16b, v3.16b\n" + "ldr q8, [x22, #0x40]\n" + ".inst 0x4e8ea4ea // smmla v10.4s, v7.16b, v14.16b\n" + ".inst 0x4e8ca449 // smmla v9.4s, v2.16b, v12.16b\n" + ".inst 0x4e80a446 // smmla v6.4s, v2.16b, v0.16b\n" + "ldr q2, [x22, #0x50]\n" + ".inst 0x4e8ca4aa // smmla v10.4s, v5.16b, v12.16b\n" + ".inst 0x4e8ba509 // smmla v9.4s, v8.16b, 
v11.16b\n" + ".inst 0x4e84a506 // smmla v6.4s, v8.16b, v4.16b\n" + "ldr q8, [x22, #0x60]\n" + ".inst 0x4e8ba44a // smmla v10.4s, v2.16b, v11.16b\n" + ".inst 0x4e8fa509 // smmla v9.4s, v8.16b, v15.16b\n" + ".inst 0x4e8da506 // smmla v6.4s, v8.16b, v13.16b\n" + "movi v8.4s, #0x0\n" + ".inst 0x4e83a4e8 // smmla v8.4s, v7.16b, v3.16b\n" + "ldr q7, [x22, #0x70]\n" + "add x22, x22, #0x80\n" + ".inst 0x4e8fa4ea // smmla v10.4s, v7.16b, v15.16b\n" + ".inst 0x4e80a4a8 // smmla v8.4s, v5.16b, v0.16b\n" + "uzp1 v5.2d, v9.2d, v6.2d\n" + "uzp2 v9.2d, v9.2d, v6.2d\n" + "scvtf v5.4s, v5.4s, #0x4\n" + ".inst 0x4e84a448 // smmla v8.4s, v2.16b, v4.16b\n" + "scvtf v9.4s, v9.4s, #0x4\n" + "fmla v23.4s, v5.4s, v1.4s\n" + ".inst 0x4e8da4e8 // smmla v8.4s, v7.16b, v13.16b\n" + "fmla v22.4s, v9.4s, v1.4s\n" + "uzp1 v2.2d, v10.2d, v8.2d\n" + "uzp2 v10.2d, v10.2d, v8.2d\n" + "scvtf v2.4s, v2.4s, #0x4\n" + "scvtf v10.4s, v10.4s, #0x4\n" + "fmla v21.4s, v2.4s, v1.4s\n" + "fmla v20.4s, v10.4s, v1.4s\n" + "ldr q2, [x21, #0x0]\n" + "ldr q10, [x21, #0x10]\n" + "movi v6.4s, #0x0\n" + "movi v9.4s, #0x0\n" + "ldr q5, [x21, #0x20]\n" + "ldr q8, [x21, #0x30]\n" + "movi v7.4s, #0x0\n" + ".inst 0x4e8ea446 // smmla v6.4s, v2.16b, v14.16b\n" + ".inst 0x4e83a449 // smmla v9.4s, v2.16b, v3.16b\n" + "ldr q2, [x21, #0x40]\n" + ".inst 0x4e8ea547 // smmla v7.4s, v10.16b, v14.16b\n" + "ldr q14, [x21, #0x50]\n" + ".inst 0x4e8ca4a6 // smmla v6.4s, v5.16b, v12.16b\n" + ".inst 0x4e80a4a9 // smmla v9.4s, v5.16b, v0.16b\n" + "ldr q5, [x21, #0x60]\n" + ".inst 0x4e8ca507 // smmla v7.4s, v8.16b, v12.16b\n" + "ldr q12, [x21, #0x70]\n" + "add x21, x21, #0x80\n" + ".inst 0x4e8ba446 // smmla v6.4s, v2.16b, v11.16b\n" + ".inst 0x4e84a449 // smmla v9.4s, v2.16b, v4.16b\n" + "movi v2.4s, #0x0\n" + ".inst 0x4e83a542 // smmla v2.4s, v10.16b, v3.16b\n" + ".inst 0x4e8ba5c7 // smmla v7.4s, v14.16b, v11.16b\n" + ".inst 0x4e8fa4a6 // smmla v6.4s, v5.16b, v15.16b\n" + ".inst 0x4e80a502 // smmla v2.4s, v8.16b, v0.16b\n" + ".inst 0x4e8da4a9 // smmla v9.4s, v5.16b, v13.16b\n" + ".inst 0x4e8fa587 // smmla v7.4s, v12.16b, v15.16b\n" + ".inst 0x4e84a5c2 // smmla v2.4s, v14.16b, v4.16b\n" + "uzp1 v11.2d, v6.2d, v9.2d\n" + "uzp2 v14.2d, v6.2d, v9.2d\n" + "scvtf v11.4s, v11.4s, #0x4\n" + ".inst 0x4e8da582 // smmla v2.4s, v12.16b, v13.16b\n" + "scvtf v14.4s, v14.4s, #0x4\n" + "fmla v19.4s, v11.4s, v1.4s\n" + "uzp1 v9.2d, v7.2d, v2.2d\n" + "uzp2 v0.2d, v7.2d, v2.2d\n" + "fmla v18.4s, v14.4s, v1.4s\n" + "scvtf v9.4s, v9.4s, #0x4\n" + "scvtf v0.4s, v0.4s, #0x4\n" + "fmla v17.4s, v9.4s, v1.4s\n" + "fmla v16.4s, v0.4s, v1.4s\n" + "subs x20, x20, #0x1\n" + "bgt 3b\n" + "ld1 { v11.4s }, [x27]\n" + "ld1 { v10.4s }, [x23]\n" + "add x27, x27, #0x10\n" + "add x23, x23, #0x10\n" + "ld1 { v9.4s }, [x22]\n" + "ld1 { v8.4s }, [x21]\n" + "add x22, x22, #0x10\n" + "add x21, x21, #0x10\n" + "ldr q7, [x11, #0x0]\n" + "ldr q6, [x27, #0x0]\n" + "add x20, %x[clamp_vals], #0x4\n" + "cmp x10, #0x4\n" + "ldr q5, [x23, #0x0]\n" + "ldr q4, [x22, #0x0]\n" + "scvtf v11.4s, v11.4s\n" + "scvtf v10.4s, v10.4s\n" + "ldr q3, [x21, #0x0]\n" + "ldr q2, [x11, #0x10]\n" + "scvtf v9.4s, v9.4s\n" + "scvtf v8.4s, v8.4s\n" + "ld1r { v1.4s }, [%x[clamp_vals]]\n" + "ld1r { v0.4s }, [x20]\n" + "add x11, x11, #0x20\n" + "fmla v31.4s, v7.4s, v11.s[0]\n" + "fmla v30.4s, v7.4s, v11.s[1]\n" + "fmla v29.4s, v7.4s, v11.s[2]\n" + "fmla v28.4s, v7.4s, v11.s[3]\n" + "fmla v27.4s, v7.4s, v10.s[0]\n" + "fmla v26.4s, v7.4s, v10.s[1]\n" + "fmla v25.4s, v7.4s, v10.s[2]\n" + "fmla v24.4s, v7.4s, v10.s[3]\n" + "fmla v23.4s, v7.4s, 
v9.s[0]\n" + "fmul v31.4s, v31.4s, v6.s[0]\n" + "fmla v22.4s, v7.4s, v9.s[1]\n" + "fmla v21.4s, v7.4s, v9.s[2]\n" + "fmul v30.4s, v30.4s, v6.s[1]\n" + "fmla v20.4s, v7.4s, v9.s[3]\n" + "fmla v19.4s, v7.4s, v8.s[0]\n" + "fmul v29.4s, v29.4s, v6.s[2]\n" + "fmla v18.4s, v7.4s, v8.s[1]\n" + "fmla v17.4s, v7.4s, v8.s[2]\n" + "fmul v28.4s, v28.4s, v6.s[3]\n" + "fmla v16.4s, v7.4s, v8.s[3]\n" + "fmul v27.4s, v27.4s, v5.s[0]\n" + "fmul v26.4s, v26.4s, v5.s[1]\n" + "fmul v25.4s, v25.4s, v5.s[2]\n" + "fmul v24.4s, v24.4s, v5.s[3]\n" + "fmul v23.4s, v23.4s, v4.s[0]\n" + "fmul v22.4s, v22.4s, v4.s[1]\n" + "fmul v21.4s, v21.4s, v4.s[2]\n" + "fmul v20.4s, v20.4s, v4.s[3]\n" + "fmul v19.4s, v19.4s, v3.s[0]\n" + "fmul v18.4s, v18.4s, v3.s[1]\n" + "fmul v17.4s, v17.4s, v3.s[2]\n" + "fmul v16.4s, v16.4s, v3.s[3]\n" + "fadd v31.4s, v31.4s, v2.4s\n" + "fadd v30.4s, v30.4s, v2.4s\n" + "fadd v29.4s, v29.4s, v2.4s\n" + "fadd v28.4s, v28.4s, v2.4s\n" + "fadd v27.4s, v27.4s, v2.4s\n" + "fadd v26.4s, v26.4s, v2.4s\n" + "fadd v25.4s, v25.4s, v2.4s\n" + "fadd v24.4s, v24.4s, v2.4s\n" + "fadd v23.4s, v23.4s, v2.4s\n" + "fadd v22.4s, v22.4s, v2.4s\n" + "fadd v21.4s, v21.4s, v2.4s\n" + "fadd v20.4s, v20.4s, v2.4s\n" + "fadd v19.4s, v19.4s, v2.4s\n" + "fadd v18.4s, v18.4s, v2.4s\n" + "fadd v17.4s, v17.4s, v2.4s\n" + "fadd v16.4s, v16.4s, v2.4s\n" + "fmax v31.4s, v31.4s, v1.4s\n" + "fmax v30.4s, v30.4s, v1.4s\n" + "fmax v29.4s, v29.4s, v1.4s\n" + "fmax v28.4s, v28.4s, v1.4s\n" + "fmax v27.4s, v27.4s, v1.4s\n" + "fmax v26.4s, v26.4s, v1.4s\n" + "fmax v25.4s, v25.4s, v1.4s\n" + "fmax v24.4s, v24.4s, v1.4s\n" + "fmax v23.4s, v23.4s, v1.4s\n" + "fmax v22.4s, v22.4s, v1.4s\n" + "fmax v21.4s, v21.4s, v1.4s\n" + "fmax v20.4s, v20.4s, v1.4s\n" + "fmax v19.4s, v19.4s, v1.4s\n" + "fmax v18.4s, v18.4s, v1.4s\n" + "fmax v17.4s, v17.4s, v1.4s\n" + "fmax v16.4s, v16.4s, v1.4s\n" + "fmin v31.4s, v31.4s, v0.4s\n" + "fmin v30.4s, v30.4s, v0.4s\n" + "fmin v29.4s, v29.4s, v0.4s\n" + "fmin v28.4s, v28.4s, v0.4s\n" + "fmin v27.4s, v27.4s, v0.4s\n" + "fmin v26.4s, v26.4s, v0.4s\n" + "fmin v25.4s, v25.4s, v0.4s\n" + "fmin v24.4s, v24.4s, v0.4s\n" + "fmin v23.4s, v23.4s, v0.4s\n" + "fmin v22.4s, v22.4s, v0.4s\n" + "fmin v21.4s, v21.4s, v0.4s\n" + "fmin v20.4s, v20.4s, v0.4s\n" + "fmin v19.4s, v19.4s, v0.4s\n" + "fmin v18.4s, v18.4s, v0.4s\n" + "fmin v17.4s, v17.4s, v0.4s\n" + "fmin v16.4s, v16.4s, v0.4s\n" + "blt 8f\n" + "mov x20, %x[dst]\n" + "str q31, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q30, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q29, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q28, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q27, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q26, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q25, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q24, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q23, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q22, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q21, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q20, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q19, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q18, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q17, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q16, [x20, #0x0]\n" + "b 13f\n" + "8:" // Partial output + "mov x28, %x[dst]\n" + "add x26, x28, %x[dst_stride_row], LSL #2\n" + 
"add x25, x26, %x[dst_stride_row], LSL #1\n" + "add x24, x26, %x[dst_stride_row]\n" + "add x23, x25, %x[dst_stride_row]\n" + "add x22, x28, %x[dst_stride_row], LSL #1\n" + "add x21, x28, %x[dst_stride_row]\n" + "add x20, x22, %x[dst_stride_row]\n" + "add x27, x23, %x[dst_stride_row]\n" + "tbz x10, #1, 9f\n" + "st1 { v24.d }[0], [x23], #0x8\n" + "st1 { v25.d }[0], [x25], #0x8\n" + "st1 { v26.d }[0], [x24], #0x8\n" + "st1 { v27.d }[0], [x26], #0x8\n" + "st1 { v28.d }[0], [x20], #0x8\n" + "st1 { v29.d }[0], [x22], #0x8\n" + "st1 { v30.d }[0], [x21], #0x8\n" + "st1 { v31.d }[0], [x28], #0x8\n" + "tbz x10, #0, 10f\n" + "st1 { v24.s }[2], [x23]\n" + "st1 { v25.s }[2], [x25]\n" + "st1 { v26.s }[2], [x24]\n" + "st1 { v27.s }[2], [x26]\n" + "st1 { v28.s }[2], [x20]\n" + "st1 { v29.s }[2], [x22]\n" + "st1 { v30.s }[2], [x21]\n" + "st1 { v31.s }[2], [x28]\n" + "b 10f\n" + "9:" // Output block 0: partial_1_0 + "st1 { v24.s }[0], [x23]\n" + "st1 { v25.s }[0], [x25]\n" + "st1 { v26.s }[0], [x24]\n" + "st1 { v27.s }[0], [x26]\n" + "st1 { v28.s }[0], [x20]\n" + "st1 { v29.s }[0], [x22]\n" + "st1 { v30.s }[0], [x21]\n" + "st1 { v31.s }[0], [x28]\n" + "10:" // Output block 0: Done + "add x26, x27, %x[dst_stride_row], LSL #2\n" + "add x25, x27, %x[dst_stride_row], LSL #1\n" + "add x24, x26, %x[dst_stride_row], LSL #1\n" + "add x23, x27, %x[dst_stride_row]\n" + "add x22, x25, %x[dst_stride_row]\n" + "add x21, x26, %x[dst_stride_row]\n" + "add x20, x24, %x[dst_stride_row]\n" + "tbz x10, #1, 11f\n" + "st1 { v16.d }[0], [x20], #0x8\n" + "st1 { v17.d }[0], [x24], #0x8\n" + "st1 { v18.d }[0], [x21], #0x8\n" + "st1 { v19.d }[0], [x26], #0x8\n" + "st1 { v20.d }[0], [x22], #0x8\n" + "st1 { v21.d }[0], [x25], #0x8\n" + "st1 { v22.d }[0], [x23], #0x8\n" + "st1 { v23.d }[0], [x27], #0x8\n" + "tbz x10, #0, 12f\n" + "st1 { v16.s }[2], [x20]\n" + "st1 { v17.s }[2], [x24]\n" + "st1 { v18.s }[2], [x21]\n" + "st1 { v19.s }[2], [x26]\n" + "st1 { v20.s }[2], [x22]\n" + "st1 { v21.s }[2], [x25]\n" + "st1 { v22.s }[2], [x23]\n" + "st1 { v23.s }[2], [x27]\n" + "b 12f\n" + "11:" // Output block 1: partial_1_0 + "st1 { v16.s }[0], [x20]\n" + "st1 { v17.s }[0], [x24]\n" + "st1 { v18.s }[0], [x21]\n" + "st1 { v19.s }[0], [x26]\n" + "st1 { v20.s }[0], [x22]\n" + "st1 { v21.s }[0], [x25]\n" + "st1 { v22.s }[0], [x23]\n" + "st1 { v23.s }[0], [x27]\n" + "12:" // Output block 1: Done + "13:" // Output stage exit + "subs x10, x10, #0x4\n" + "add %x[dst], %x[dst], #0x10\n" + "bgt 2b\n" + "mov x20, #0x4\n" + "sub x13, x13, #0x10\n" + "cmp x13, #0x10\n" + "mov %x[dst], x9\n" + "madd %x[lhs_packed], x20, x12, %x[lhs_packed]\n" + "bge 1b\n" + "14:" // Row loop skip + "cbz x13, 23f\n" + "15:" // Row tail: Row loop + "mov x26, %x[rhs_packed]\n" + "mov x25, %x[n]\n" + "add x24, %x[dst], %x[dst_stride_row], LSL #2\n" + "16:" // Row tail: Column loop + "movi v31.16b, #0x0\n" + "movi v30.16b, #0x0\n" + "mov x27, %x[lhs_packed]\n" + "mov x20, %x[num_blocks]\n" + "movi v29.16b, #0x0\n" + "movi v28.16b, #0x0\n" + "17:" // Row tail: Block loop + "ldr q9, [x26, #0x0]\n" + "ldr q8, [x26, #0x10]\n" + "movi v7.4s, #0x0\n" + "movi v6.4s, #0x0\n" + "ldr q5, [x27, #0x0]\n" + "ldr q4, [x27, #0x10]\n" + "movi v3.4s, #0x0\n" + "movi v2.4s, #0x0\n" + "ldr q1, [x26, #0x20]\n" + "ldr q0, [x26, #0x30]\n" + "movi v27.16b, #0xf0\n" + "add x26, x26, #0x40\n" + "ldr q26, [x27, #0x20]\n" + "ldr q25, [x27, #0x30]\n" + "shl v24.16b, v9.16b, #0x4\n" + "shl v20.16b, v8.16b, #0x4\n" + "ldr q23, [x27, #0x40]\n" + "ldr q22, [x27, #0x50]\n" + "and v9.16b, v9.16b, v27.16b\n" + "and 
v8.16b, v8.16b, v27.16b\n" + "ldr q21, [x27, #0x60]\n" + "ldr q19, [x27, #0x70]\n" + "shl v18.16b, v1.16b, #0x4\n" + "shl v17.16b, v0.16b, #0x4\n" + "ldr d16, [x26, #0x0]\n" + ".inst 0x4e98a4a7 // smmla v7.4s, v5.16b, v24.16b\n" + ".inst 0x4e94a4a6 // smmla v6.4s, v5.16b, v20.16b\n" + "and v1.16b, v1.16b, v27.16b\n" + ".inst 0x4e98a483 // smmla v3.4s, v4.16b, v24.16b\n" + ".inst 0x4e94a482 // smmla v2.4s, v4.16b, v20.16b\n" + "and v0.16b, v0.16b, v27.16b\n" + "add x26, x26, #0x8\n" + "add x27, x27, #0x80\n" + "shll v20.4s, v16.4h, #0x10\n" + ".inst 0x4e92a747 // smmla v7.4s, v26.16b, v18.16b\n" + ".inst 0x4e91a746 // smmla v6.4s, v26.16b, v17.16b\n" + ".inst 0x4e92a723 // smmla v3.4s, v25.16b, v18.16b\n" + ".inst 0x4e91a722 // smmla v2.4s, v25.16b, v17.16b\n" + ".inst 0x4e89a6e7 // smmla v7.4s, v23.16b, v9.16b\n" + ".inst 0x4e88a6e6 // smmla v6.4s, v23.16b, v8.16b\n" + ".inst 0x4e89a6c3 // smmla v3.4s, v22.16b, v9.16b\n" + ".inst 0x4e88a6c2 // smmla v2.4s, v22.16b, v8.16b\n" + ".inst 0x4e81a6a7 // smmla v7.4s, v21.16b, v1.16b\n" + ".inst 0x4e80a6a6 // smmla v6.4s, v21.16b, v0.16b\n" + ".inst 0x4e81a663 // smmla v3.4s, v19.16b, v1.16b\n" + ".inst 0x4e80a662 // smmla v2.4s, v19.16b, v0.16b\n" + "uzp1 v19.2d, v7.2d, v6.2d\n" + "uzp2 v18.2d, v7.2d, v6.2d\n" + "scvtf v19.4s, v19.4s, #0x4\n" + "uzp1 v17.2d, v3.2d, v2.2d\n" + "uzp2 v16.2d, v3.2d, v2.2d\n" + "scvtf v18.4s, v18.4s, #0x4\n" + "fmla v31.4s, v19.4s, v20.4s\n" + "scvtf v17.4s, v17.4s, #0x4\n" + "scvtf v16.4s, v16.4s, #0x4\n" + "fmla v30.4s, v18.4s, v20.4s\n" + "fmla v29.4s, v17.4s, v20.4s\n" + "fmla v28.4s, v16.4s, v20.4s\n" + "subs x20, x20, #0x1\n" + "bgt 17b\n" + "ld1 { v21.4s }, [x27]\n" + "ldr q20, [x26, #0x0]\n" + "add x27, x27, #0x10\n" + "add x20, %x[clamp_vals], #0x4\n" + "ldr q19, [x27, #0x0]\n" + "ldr q18, [x26, #0x10]\n" + "cmp x25, #0x4\n" + "add x26, x26, #0x20\n" + "ld1r { v17.4s }, [%x[clamp_vals]]\n" + "ld1r { v16.4s }, [x20]\n" + "scvtf v21.4s, v21.4s\n" + "fmla v31.4s, v20.4s, v21.s[0]\n" + "fmla v30.4s, v20.4s, v21.s[1]\n" + "fmla v29.4s, v20.4s, v21.s[2]\n" + "fmla v28.4s, v20.4s, v21.s[3]\n" + "fmul v31.4s, v31.4s, v19.s[0]\n" + "fmul v30.4s, v30.4s, v19.s[1]\n" + "fmul v29.4s, v29.4s, v19.s[2]\n" + "fadd v31.4s, v31.4s, v18.4s\n" + "fmul v28.4s, v28.4s, v19.s[3]\n" + "fadd v30.4s, v30.4s, v18.4s\n" + "fadd v29.4s, v29.4s, v18.4s\n" + "fadd v28.4s, v28.4s, v18.4s\n" + "fmax v31.4s, v31.4s, v17.4s\n" + "fmax v30.4s, v30.4s, v17.4s\n" + "fmax v29.4s, v29.4s, v17.4s\n" + "fmax v28.4s, v28.4s, v17.4s\n" + "fmin v31.4s, v31.4s, v16.4s\n" + "fmin v30.4s, v30.4s, v16.4s\n" + "fmin v29.4s, v29.4s, v16.4s\n" + "fmin v28.4s, v28.4s, v16.4s\n" + "blt 19f\n" + "mov x20, %x[dst]\n" + "cmp x13, #0x1\n" + "str q31, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 22f\n" + "cmp x13, #0x2\n" + "str q30, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 22f\n" + "cmp x13, #0x3\n" + "str q29, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 22f\n" + "str q28, [x20, #0x0]\n" + "b 22f\n" + "19:" // Row tail: Partial output + "mov x23, %x[dst]\n" + "cmp x13, #0x1\n" + "add x22, x23, %x[dst_stride_row]\n" + "csel x22, x22, x23, GT\n" + "cmp x13, #0x2\n" + "add x21, x23, %x[dst_stride_row], LSL #1\n" + "csel x21, x21, x22, GT\n" + "cmp x13, #0x3\n" + "add x20, x21, %x[dst_stride_row]\n" + "csel x20, x20, x21, GT\n" + "tbz x25, #1, 20f\n" + "st1 { v28.d }[0], [x20], #0x8\n" + "st1 { v29.d }[0], [x21], #0x8\n" + "st1 { v30.d }[0], [x22], #0x8\n" + "st1 { v31.d }[0], [x23], #0x8\n" + "tbz x25, #0, 21f\n" + 
"st1 { v28.s }[2], [x20]\n" + "st1 { v29.s }[2], [x21]\n" + "st1 { v30.s }[2], [x22]\n" + "st1 { v31.s }[2], [x23]\n" + "b 21f\n" + "20:" // Row tail: Output block 0: partial_1_0 + "st1 { v28.s }[0], [x20]\n" + "st1 { v29.s }[0], [x21]\n" + "st1 { v30.s }[0], [x22]\n" + "st1 { v31.s }[0], [x23]\n" + "21:" // Row tail: Output block 0: Done + "22:" // Row tail: Output stage exit + "subs x25, x25, #0x4\n" + "add %x[dst], %x[dst], #0x10\n" + "bgt 16b\n" + "subs x13, x13, #0x4\n" + "add %x[lhs_packed], %x[lhs_packed], x12\n" + "mov %x[dst], x24\n" + "bgt 15b\n" + "23:" // Row tail: Row loop skip + : [dst] "+&r"(dst), [lhs_packed] "+&r"(lhs_packed) + : [clamp_vals] "r"(clamp_vals), [dst_stride_row] "r"(dst_stride_row), [m] "r"(m), [n] "r"(n), + [num_blocks] "r"(num_blocks), [rhs_packed] "r"(rhs_packed) + : "cc", "memory", "v0", "v1", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v2", "v20", + "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v3", "v30", "v31", "v4", "v5", "v6", "v7", + "v8", "v9", "x10", "x11", "x12", "x13", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "x9"); +} + +inline static void kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( + size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed, float* dst, + size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { + KAI_ASSERT((bl % kai_kr) == 0); + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); + KAI_ASSERT(dst_stride_col == sizeof(float)); + + if (m == 0) { + return; + } size_t num_subblocks = bl / kai_bl_multiple_of; size_t num_blocks = kai_num_blocks_per_row(k, bl); @@ -545,4 +1164,23 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x11", "x12", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"); } + +void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( + size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed, float* dst, + size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { + KAI_ASSERT((bl % kai_kr) == 0); + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); + KAI_ASSERT(dst_stride_col == sizeof(float)); + + if (m == 0) { + return; + } + if (m >= 16 && bl == 32) { + kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm( + m, n, k, bl, lhs_packed, rhs_packed, dst, dst_stride_row, dst_stride_col, scalar_min, scalar_max); + } else { + kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( + m, n, k, bl, lhs_packed, rhs_packed, dst, dst_stride_row, dst_stride_col, scalar_min, scalar_max); + } +} #endif // Architectural feature check -- GitLab From 239dd4ebd8ab18c6a452ba5a3cb9fd3455749a02 Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Mon, 30 Sep 2024 14:41:50 +0100 Subject: [PATCH 2/5] Move 16x4 matmul kernel to new ukernel files Signed-off-by: Anitha Raj --- CMakeLists.txt | 1 + .../CMakeLists.txt | 4 +- kai/ukernels/matmul/BUILD.bazel | 11 + ...qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.c | 1410 +++++++++++++++++ ...qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.h | 142 ++ ..._qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c | 644 +------- 6 files changed, 1570 insertions(+), 642 deletions(-) create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.c create mode 100644 
kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.h diff --git a/CMakeLists.txt b/CMakeLists.txt index ee4f3304..7fc0dc2c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -109,6 +109,7 @@ set(KLEIDIAI_FILES_NEON_I8MM kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.c ) set(KLEIDIAI_FILES_SME diff --git a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/CMakeLists.txt b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/CMakeLists.txt index 47527352..cbc4e93e 100644 --- a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/CMakeLists.txt +++ b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/CMakeLists.txt @@ -27,7 +27,9 @@ add_executable(matmul_clamp_f32_qai8dxp_qsi4c32p ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c - ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c) + ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c + ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.c + ) target_compile_options(matmul_clamp_f32_qai8dxp_qsi4c32p PRIVATE -march=armv8.2-a+dotprod+i8mm) diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 1a0fde61..96c26780 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -237,6 +237,16 @@ kai_c_library( ], ) +kai_c_library( + name = "clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm", + srcs = ["matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.c"], + hdrs = ["matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.h"], + cpu_uarch = kai_cpu_i8mm(), + deps = [ + ":clamp_f32_qai8dxp_qsi4c32p_interface", + ], +) + kai_c_library( name = "rhs_pack_nxk_qsi4c32p_qsu4c32s1s0", srcs = ["pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c"], @@ -261,6 +271,7 @@ kai_c_library( ":clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod", ":clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod", ":clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod", + ":clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm", ":clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm", ":clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm", ":clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm", diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.c new file mode 100644 index 00000000..92f8a56d --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.c @@ -0,0 +1,1410 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#if !defined(__ARM_FEATURE_MATMUL_INT8) +#error "I8mm extension required to compile 
this micro-kernel" +#else +#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.h" + +#include +#include +#include + +#include "kai/kai_common.h" + +static const size_t kai_m_step = 16; +static const size_t kai_n_step = 4; +static const size_t kai_mr = 4; +static const size_t kai_nr = 4; +static const size_t kai_kr = 16; +static const size_t kai_sr = 2; +static const size_t kai_bl_multiple_of = 32; +static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); +static const size_t kai_num_bytes_multiplier_rhs = sizeof(uint16_t); +static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); +static const size_t kai_num_bytes_sum_rhs = sizeof(float); +static const size_t kai_num_bytes_bias = sizeof(float); + +inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) { + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); + return kai_roundup(k, bl) / bl; +} + +inline static size_t kai_k_roundedup(size_t k) { + // Since we pack a float and int32 value at the end of the row, + // we must make sure that k is a multiple of 4 for alignment + size_t kr_sr_roundedup4 = kai_roundup(kai_kr * kai_sr, 4); + return kai_roundup(k, kr_sr_roundedup4); +} + +inline static size_t kai_lhs_packed_stride(size_t k) { + const size_t k_internal = kai_k_roundedup(k); + + KAI_ASSERT((k_internal % 2) == 0); + + return kai_mr * (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); +} + +inline static size_t kai_rhs_packed_stride(size_t k, size_t bl) { + KAI_ASSERT((bl % kai_kr) == 0); + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); + + const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); + const size_t num_bytes_per_block = (bl / 2) + kai_num_bytes_multiplier_rhs; + + return kai_nr * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs + kai_num_bytes_bias); +} + +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm(void) { + return kai_m_step; +} + +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm(void) { + return kai_n_step; +} + +size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm(void) { + return kai_mr; +} + +size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm(void) { + return kai_nr; +} + +size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm(void) { + return kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm(size_t m_idx, size_t k) { + KAI_ASSERT((m_idx % kai_m_step) == 0); + + return (m_idx / kai_m_step) * kai_lhs_packed_stride(k); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm( + size_t n_idx, size_t k, size_t bl) { + KAI_ASSERT((n_idx % kai_n_step) == 0); + + return (n_idx / kai_n_step) * kai_rhs_packed_stride(k, bl); +} + +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSERT((m_idx % kai_m_step) == 0); + KAI_ASSERT((n_idx % kai_n_step) == 0); + + return (n_idx * sizeof(float)) + m_idx * dst_stride; +} + +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm(size_t m, size_t n) { + return m * n * sizeof(float); +} + +void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm( + size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* 
rhs_packed, float* dst, + size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { + KAI_ASSERT((bl % kai_kr) == 0); + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); + KAI_ASSERT(dst_stride_col == sizeof(float)); + + if (m == 0) { + return; + } + + size_t num_subblocks = bl / kai_bl_multiple_of; + size_t num_blocks = kai_num_blocks_per_row(k, bl); + + float clamp_vals[2] = {scalar_min, scalar_max}; + + if (bl == 32) { + __asm__ __volatile__( + "mov x13, %x[m]\n" + "mov x12, #0x80\n" + "mov x20, #0x20\n" + "cmp x13, #0x10\n" + "madd x12, %x[num_blocks], x12, x20\n" + "blt 14f\n" + "1:" // Row loop + "mov x11, %x[rhs_packed]\n" + "mov x10, %x[n]\n" + "add x9, %x[dst], %x[dst_stride_row], LSL #4\n" + "2:" // Column loop + "mov x27, %x[lhs_packed]\n" + "movi v31.16b, #0x0\n" + "movi v30.16b, #0x0\n" + "mov x20, %x[num_blocks]\n" + "movi v29.16b, #0x0\n" + "movi v28.16b, #0x0\n" + "movi v27.16b, #0x0\n" + "movi v26.16b, #0x0\n" + "add x23, x27, x12\n" + "add x22, x23, x12\n" + "movi v25.16b, #0x0\n" + "movi v24.16b, #0x0\n" + "add x21, x22, x12\n" + "movi v23.16b, #0x0\n" + "movi v22.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v20.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "movi v18.16b, #0x0\n" + "movi v17.16b, #0x0\n" + "movi v16.16b, #0x0\n" + "3:" // Block loop + "ldr q11, [x11, #0x0]\n" + "ldr q4, [x11, #0x10]\n" + "movi v2.4s, #0x0\n" + "movi v9.4s, #0x0\n" + "ldr q12, [x27, #0x0]\n" + "ldr q0, [x27, #0x10]\n" + "movi v7.4s, #0x0\n" + "movi v5.4s, #0x0\n" + "ldr q15, [x11, #0x20]\n" + "ldr q13, [x11, #0x30]\n" + "movi v10.16b, #0xf0\n" + "add x11, x11, #0x40\n" + "ldr q8, [x27, #0x20]\n" + "ldr q6, [x27, #0x30]\n" + "shl v14.16b, v11.16b, #0x4\n" + "shl v3.16b, v4.16b, #0x4\n" + "ldr q1, [x27, #0x40]\n" + "and v11.16b, v11.16b, v10.16b\n" + "and v4.16b, v4.16b, v10.16b\n" + ".inst 0x4e8ea582 // smmla v2.4s, v12.16b, v14.16b\n" + ".inst 0x4e83a589 // smmla v9.4s, v12.16b, v3.16b\n" + "shl v12.16b, v15.16b, #0x4\n" + ".inst 0x4e8ea407 // smmla v7.4s, v0.16b, v14.16b\n" + ".inst 0x4e83a405 // smmla v5.4s, v0.16b, v3.16b\n" + "shl v0.16b, v13.16b, #0x4\n" + "and v15.16b, v15.16b, v10.16b\n" + "and v13.16b, v13.16b, v10.16b\n" + "ldr q10, [x27, #0x50]\n" + ".inst 0x4e8ca502 // smmla v2.4s, v8.16b, v12.16b\n" + ".inst 0x4e80a509 // smmla v9.4s, v8.16b, v0.16b\n" + "ldr q8, [x27, #0x60]\n" + ".inst 0x4e8ca4c7 // smmla v7.4s, v6.16b, v12.16b\n" + ".inst 0x4e80a4c5 // smmla v5.4s, v6.16b, v0.16b\n" + "ldr q6, [x27, #0x70]\n" + "add x27, x27, #0x80\n" + ".inst 0x4e8ba422 // smmla v2.4s, v1.16b, v11.16b\n" + ".inst 0x4e84a429 // smmla v9.4s, v1.16b, v4.16b\n" + "ldr d1, [x11, #0x0]\n" + "add x11, x11, #0x8\n" + ".inst 0x4e8ba547 // smmla v7.4s, v10.16b, v11.16b\n" + ".inst 0x4e84a545 // smmla v5.4s, v10.16b, v4.16b\n" + ".inst 0x4e8fa502 // smmla v2.4s, v8.16b, v15.16b\n" + "shll v1.4s, v1.4h, #0x10\n" + ".inst 0x4e8da509 // smmla v9.4s, v8.16b, v13.16b\n" + ".inst 0x4e8fa4c7 // smmla v7.4s, v6.16b, v15.16b\n" + ".inst 0x4e8da4c5 // smmla v5.4s, v6.16b, v13.16b\n" + "uzp1 v6.2d, v2.2d, v9.2d\n" + "uzp2 v8.2d, v2.2d, v9.2d\n" + "scvtf v6.4s, v6.4s, #0x4\n" + "uzp1 v9.2d, v7.2d, v5.2d\n" + "uzp2 v2.2d, v7.2d, v5.2d\n" + "scvtf v8.4s, v8.4s, #0x4\n" + "fmla v31.4s, v6.4s, v1.4s\n" + "scvtf v9.4s, v9.4s, #0x4\n" + "scvtf v2.4s, v2.4s, #0x4\n" + "fmla v30.4s, v8.4s, v1.4s\n" + "fmla v29.4s, v9.4s, v1.4s\n" + "fmla v28.4s, v2.4s, v1.4s\n" + "ldr q9, [x23, #0x0]\n" + "ldr q7, [x23, #0x10]\n" + "movi v8.4s, #0x0\n" + "movi v2.4s, #0x0\n" + "ldr q5, [x23, #0x20]\n" + "ldr q10, [x23, 
#0x30]\n" + "movi v6.4s, #0x0\n" + ".inst 0x4e8ea528 // smmla v8.4s, v9.16b, v14.16b\n" + ".inst 0x4e83a522 // smmla v2.4s, v9.16b, v3.16b\n" + "ldr q9, [x23, #0x40]\n" + ".inst 0x4e8ea4e6 // smmla v6.4s, v7.16b, v14.16b\n" + ".inst 0x4e8ca4a8 // smmla v8.4s, v5.16b, v12.16b\n" + ".inst 0x4e80a4a2 // smmla v2.4s, v5.16b, v0.16b\n" + "ldr q5, [x23, #0x50]\n" + ".inst 0x4e8ca546 // smmla v6.4s, v10.16b, v12.16b\n" + ".inst 0x4e8ba528 // smmla v8.4s, v9.16b, v11.16b\n" + ".inst 0x4e84a522 // smmla v2.4s, v9.16b, v4.16b\n" + "ldr q9, [x23, #0x60]\n" + ".inst 0x4e8ba4a6 // smmla v6.4s, v5.16b, v11.16b\n" + ".inst 0x4e8fa528 // smmla v8.4s, v9.16b, v15.16b\n" + ".inst 0x4e8da522 // smmla v2.4s, v9.16b, v13.16b\n" + "movi v9.4s, #0x0\n" + ".inst 0x4e83a4e9 // smmla v9.4s, v7.16b, v3.16b\n" + "ldr q7, [x23, #0x70]\n" + "add x23, x23, #0x80\n" + ".inst 0x4e8fa4e6 // smmla v6.4s, v7.16b, v15.16b\n" + ".inst 0x4e80a549 // smmla v9.4s, v10.16b, v0.16b\n" + "uzp1 v10.2d, v8.2d, v2.2d\n" + "uzp2 v2.2d, v8.2d, v2.2d\n" + "scvtf v10.4s, v10.4s, #0x4\n" + ".inst 0x4e84a4a9 // smmla v9.4s, v5.16b, v4.16b\n" + "scvtf v2.4s, v2.4s, #0x4\n" + "fmla v27.4s, v10.4s, v1.4s\n" + ".inst 0x4e8da4e9 // smmla v9.4s, v7.16b, v13.16b\n" + "fmla v26.4s, v2.4s, v1.4s\n" + "uzp1 v2.2d, v6.2d, v9.2d\n" + "uzp2 v10.2d, v6.2d, v9.2d\n" + "scvtf v2.4s, v2.4s, #0x4\n" + "scvtf v10.4s, v10.4s, #0x4\n" + "fmla v25.4s, v2.4s, v1.4s\n" + "fmla v24.4s, v10.4s, v1.4s\n" + "ldr q8, [x22, #0x0]\n" + "ldr q7, [x22, #0x10]\n" + "movi v9.4s, #0x0\n" + "movi v6.4s, #0x0\n" + "ldr q2, [x22, #0x20]\n" + "ldr q5, [x22, #0x30]\n" + "movi v10.4s, #0x0\n" + ".inst 0x4e8ea509 // smmla v9.4s, v8.16b, v14.16b\n" + ".inst 0x4e83a506 // smmla v6.4s, v8.16b, v3.16b\n" + "ldr q8, [x22, #0x40]\n" + ".inst 0x4e8ea4ea // smmla v10.4s, v7.16b, v14.16b\n" + ".inst 0x4e8ca449 // smmla v9.4s, v2.16b, v12.16b\n" + ".inst 0x4e80a446 // smmla v6.4s, v2.16b, v0.16b\n" + "ldr q2, [x22, #0x50]\n" + ".inst 0x4e8ca4aa // smmla v10.4s, v5.16b, v12.16b\n" + ".inst 0x4e8ba509 // smmla v9.4s, v8.16b, v11.16b\n" + ".inst 0x4e84a506 // smmla v6.4s, v8.16b, v4.16b\n" + "ldr q8, [x22, #0x60]\n" + ".inst 0x4e8ba44a // smmla v10.4s, v2.16b, v11.16b\n" + ".inst 0x4e8fa509 // smmla v9.4s, v8.16b, v15.16b\n" + ".inst 0x4e8da506 // smmla v6.4s, v8.16b, v13.16b\n" + "movi v8.4s, #0x0\n" + ".inst 0x4e83a4e8 // smmla v8.4s, v7.16b, v3.16b\n" + "ldr q7, [x22, #0x70]\n" + "add x22, x22, #0x80\n" + ".inst 0x4e8fa4ea // smmla v10.4s, v7.16b, v15.16b\n" + ".inst 0x4e80a4a8 // smmla v8.4s, v5.16b, v0.16b\n" + "uzp1 v5.2d, v9.2d, v6.2d\n" + "uzp2 v9.2d, v9.2d, v6.2d\n" + "scvtf v5.4s, v5.4s, #0x4\n" + ".inst 0x4e84a448 // smmla v8.4s, v2.16b, v4.16b\n" + "scvtf v9.4s, v9.4s, #0x4\n" + "fmla v23.4s, v5.4s, v1.4s\n" + ".inst 0x4e8da4e8 // smmla v8.4s, v7.16b, v13.16b\n" + "fmla v22.4s, v9.4s, v1.4s\n" + "uzp1 v2.2d, v10.2d, v8.2d\n" + "uzp2 v10.2d, v10.2d, v8.2d\n" + "scvtf v2.4s, v2.4s, #0x4\n" + "scvtf v10.4s, v10.4s, #0x4\n" + "fmla v21.4s, v2.4s, v1.4s\n" + "fmla v20.4s, v10.4s, v1.4s\n" + "ldr q2, [x21, #0x0]\n" + "ldr q10, [x21, #0x10]\n" + "movi v6.4s, #0x0\n" + "movi v9.4s, #0x0\n" + "ldr q5, [x21, #0x20]\n" + "ldr q8, [x21, #0x30]\n" + "movi v7.4s, #0x0\n" + ".inst 0x4e8ea446 // smmla v6.4s, v2.16b, v14.16b\n" + ".inst 0x4e83a449 // smmla v9.4s, v2.16b, v3.16b\n" + "ldr q2, [x21, #0x40]\n" + ".inst 0x4e8ea547 // smmla v7.4s, v10.16b, v14.16b\n" + "ldr q14, [x21, #0x50]\n" + ".inst 0x4e8ca4a6 // smmla v6.4s, v5.16b, v12.16b\n" + ".inst 0x4e80a4a9 // smmla v9.4s, v5.16b, v0.16b\n" + 
"ldr q5, [x21, #0x60]\n" + ".inst 0x4e8ca507 // smmla v7.4s, v8.16b, v12.16b\n" + "ldr q12, [x21, #0x70]\n" + "add x21, x21, #0x80\n" + ".inst 0x4e8ba446 // smmla v6.4s, v2.16b, v11.16b\n" + ".inst 0x4e84a449 // smmla v9.4s, v2.16b, v4.16b\n" + "movi v2.4s, #0x0\n" + ".inst 0x4e83a542 // smmla v2.4s, v10.16b, v3.16b\n" + ".inst 0x4e8ba5c7 // smmla v7.4s, v14.16b, v11.16b\n" + ".inst 0x4e8fa4a6 // smmla v6.4s, v5.16b, v15.16b\n" + ".inst 0x4e80a502 // smmla v2.4s, v8.16b, v0.16b\n" + ".inst 0x4e8da4a9 // smmla v9.4s, v5.16b, v13.16b\n" + ".inst 0x4e8fa587 // smmla v7.4s, v12.16b, v15.16b\n" + ".inst 0x4e84a5c2 // smmla v2.4s, v14.16b, v4.16b\n" + "uzp1 v11.2d, v6.2d, v9.2d\n" + "uzp2 v14.2d, v6.2d, v9.2d\n" + "scvtf v11.4s, v11.4s, #0x4\n" + ".inst 0x4e8da582 // smmla v2.4s, v12.16b, v13.16b\n" + "scvtf v14.4s, v14.4s, #0x4\n" + "fmla v19.4s, v11.4s, v1.4s\n" + "uzp1 v9.2d, v7.2d, v2.2d\n" + "uzp2 v0.2d, v7.2d, v2.2d\n" + "fmla v18.4s, v14.4s, v1.4s\n" + "scvtf v9.4s, v9.4s, #0x4\n" + "scvtf v0.4s, v0.4s, #0x4\n" + "fmla v17.4s, v9.4s, v1.4s\n" + "fmla v16.4s, v0.4s, v1.4s\n" + "subs x20, x20, #0x1\n" + "bgt 3b\n" + "ld1 { v11.4s }, [x27]\n" + "ld1 { v10.4s }, [x23]\n" + "add x27, x27, #0x10\n" + "add x23, x23, #0x10\n" + "ld1 { v9.4s }, [x22]\n" + "ld1 { v8.4s }, [x21]\n" + "add x22, x22, #0x10\n" + "add x21, x21, #0x10\n" + "ldr q7, [x11, #0x0]\n" + "ldr q6, [x27, #0x0]\n" + "add x20, %x[clamp_vals], #0x4\n" + "cmp x10, #0x4\n" + "ldr q5, [x23, #0x0]\n" + "ldr q4, [x22, #0x0]\n" + "scvtf v11.4s, v11.4s\n" + "scvtf v10.4s, v10.4s\n" + "ldr q3, [x21, #0x0]\n" + "ldr q2, [x11, #0x10]\n" + "scvtf v9.4s, v9.4s\n" + "scvtf v8.4s, v8.4s\n" + "ld1r { v1.4s }, [%x[clamp_vals]]\n" + "ld1r { v0.4s }, [x20]\n" + "add x11, x11, #0x20\n" + "fmla v31.4s, v7.4s, v11.s[0]\n" + "fmla v30.4s, v7.4s, v11.s[1]\n" + "fmla v29.4s, v7.4s, v11.s[2]\n" + "fmla v28.4s, v7.4s, v11.s[3]\n" + "fmla v27.4s, v7.4s, v10.s[0]\n" + "fmla v26.4s, v7.4s, v10.s[1]\n" + "fmla v25.4s, v7.4s, v10.s[2]\n" + "fmla v24.4s, v7.4s, v10.s[3]\n" + "fmla v23.4s, v7.4s, v9.s[0]\n" + "fmul v31.4s, v31.4s, v6.s[0]\n" + "fmla v22.4s, v7.4s, v9.s[1]\n" + "fmla v21.4s, v7.4s, v9.s[2]\n" + "fmul v30.4s, v30.4s, v6.s[1]\n" + "fmla v20.4s, v7.4s, v9.s[3]\n" + "fmla v19.4s, v7.4s, v8.s[0]\n" + "fmul v29.4s, v29.4s, v6.s[2]\n" + "fmla v18.4s, v7.4s, v8.s[1]\n" + "fmla v17.4s, v7.4s, v8.s[2]\n" + "fmul v28.4s, v28.4s, v6.s[3]\n" + "fmla v16.4s, v7.4s, v8.s[3]\n" + "fmul v27.4s, v27.4s, v5.s[0]\n" + "fmul v26.4s, v26.4s, v5.s[1]\n" + "fmul v25.4s, v25.4s, v5.s[2]\n" + "fmul v24.4s, v24.4s, v5.s[3]\n" + "fmul v23.4s, v23.4s, v4.s[0]\n" + "fmul v22.4s, v22.4s, v4.s[1]\n" + "fmul v21.4s, v21.4s, v4.s[2]\n" + "fmul v20.4s, v20.4s, v4.s[3]\n" + "fmul v19.4s, v19.4s, v3.s[0]\n" + "fmul v18.4s, v18.4s, v3.s[1]\n" + "fmul v17.4s, v17.4s, v3.s[2]\n" + "fmul v16.4s, v16.4s, v3.s[3]\n" + "fadd v31.4s, v31.4s, v2.4s\n" + "fadd v30.4s, v30.4s, v2.4s\n" + "fadd v29.4s, v29.4s, v2.4s\n" + "fadd v28.4s, v28.4s, v2.4s\n" + "fadd v27.4s, v27.4s, v2.4s\n" + "fadd v26.4s, v26.4s, v2.4s\n" + "fadd v25.4s, v25.4s, v2.4s\n" + "fadd v24.4s, v24.4s, v2.4s\n" + "fadd v23.4s, v23.4s, v2.4s\n" + "fadd v22.4s, v22.4s, v2.4s\n" + "fadd v21.4s, v21.4s, v2.4s\n" + "fadd v20.4s, v20.4s, v2.4s\n" + "fadd v19.4s, v19.4s, v2.4s\n" + "fadd v18.4s, v18.4s, v2.4s\n" + "fadd v17.4s, v17.4s, v2.4s\n" + "fadd v16.4s, v16.4s, v2.4s\n" + "fmax v31.4s, v31.4s, v1.4s\n" + "fmax v30.4s, v30.4s, v1.4s\n" + "fmax v29.4s, v29.4s, v1.4s\n" + "fmax v28.4s, v28.4s, v1.4s\n" + "fmax v27.4s, v27.4s, 
v1.4s\n" + "fmax v26.4s, v26.4s, v1.4s\n" + "fmax v25.4s, v25.4s, v1.4s\n" + "fmax v24.4s, v24.4s, v1.4s\n" + "fmax v23.4s, v23.4s, v1.4s\n" + "fmax v22.4s, v22.4s, v1.4s\n" + "fmax v21.4s, v21.4s, v1.4s\n" + "fmax v20.4s, v20.4s, v1.4s\n" + "fmax v19.4s, v19.4s, v1.4s\n" + "fmax v18.4s, v18.4s, v1.4s\n" + "fmax v17.4s, v17.4s, v1.4s\n" + "fmax v16.4s, v16.4s, v1.4s\n" + "fmin v31.4s, v31.4s, v0.4s\n" + "fmin v30.4s, v30.4s, v0.4s\n" + "fmin v29.4s, v29.4s, v0.4s\n" + "fmin v28.4s, v28.4s, v0.4s\n" + "fmin v27.4s, v27.4s, v0.4s\n" + "fmin v26.4s, v26.4s, v0.4s\n" + "fmin v25.4s, v25.4s, v0.4s\n" + "fmin v24.4s, v24.4s, v0.4s\n" + "fmin v23.4s, v23.4s, v0.4s\n" + "fmin v22.4s, v22.4s, v0.4s\n" + "fmin v21.4s, v21.4s, v0.4s\n" + "fmin v20.4s, v20.4s, v0.4s\n" + "fmin v19.4s, v19.4s, v0.4s\n" + "fmin v18.4s, v18.4s, v0.4s\n" + "fmin v17.4s, v17.4s, v0.4s\n" + "fmin v16.4s, v16.4s, v0.4s\n" + "blt 8f\n" + "mov x20, %x[dst]\n" + "str q31, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q30, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q29, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q28, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q27, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q26, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q25, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q24, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q23, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q22, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q21, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q20, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q19, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q18, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q17, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q16, [x20, #0x0]\n" + "b 13f\n" + "8:" // Partial output + "mov x28, %x[dst]\n" + "add x26, x28, %x[dst_stride_row], LSL #2\n" + "add x25, x26, %x[dst_stride_row], LSL #1\n" + "add x24, x26, %x[dst_stride_row]\n" + "add x23, x25, %x[dst_stride_row]\n" + "add x22, x28, %x[dst_stride_row], LSL #1\n" + "add x21, x28, %x[dst_stride_row]\n" + "add x20, x22, %x[dst_stride_row]\n" + "add x27, x23, %x[dst_stride_row]\n" + "tbz x10, #1, 9f\n" + "st1 { v24.d }[0], [x23], #0x8\n" + "st1 { v25.d }[0], [x25], #0x8\n" + "st1 { v26.d }[0], [x24], #0x8\n" + "st1 { v27.d }[0], [x26], #0x8\n" + "st1 { v28.d }[0], [x20], #0x8\n" + "st1 { v29.d }[0], [x22], #0x8\n" + "st1 { v30.d }[0], [x21], #0x8\n" + "st1 { v31.d }[0], [x28], #0x8\n" + "tbz x10, #0, 10f\n" + "st1 { v24.s }[2], [x23]\n" + "st1 { v25.s }[2], [x25]\n" + "st1 { v26.s }[2], [x24]\n" + "st1 { v27.s }[2], [x26]\n" + "st1 { v28.s }[2], [x20]\n" + "st1 { v29.s }[2], [x22]\n" + "st1 { v30.s }[2], [x21]\n" + "st1 { v31.s }[2], [x28]\n" + "b 10f\n" + "9:" // Output block 0: partial_1_0 + "st1 { v24.s }[0], [x23]\n" + "st1 { v25.s }[0], [x25]\n" + "st1 { v26.s }[0], [x24]\n" + "st1 { v27.s }[0], [x26]\n" + "st1 { v28.s }[0], [x20]\n" + "st1 { v29.s }[0], [x22]\n" + "st1 { v30.s }[0], [x21]\n" + "st1 { v31.s }[0], [x28]\n" + "10:" // Output block 0: Done + "add x26, x27, %x[dst_stride_row], LSL #2\n" + "add x25, x27, %x[dst_stride_row], LSL #1\n" + "add x24, x26, %x[dst_stride_row], LSL #1\n" + "add x23, x27, %x[dst_stride_row]\n" + "add x22, x25, %x[dst_stride_row]\n" + "add x21, x26, %x[dst_stride_row]\n" + "add x20, x24, %x[dst_stride_row]\n" + "tbz x10, 
#1, 11f\n" + "st1 { v16.d }[0], [x20], #0x8\n" + "st1 { v17.d }[0], [x24], #0x8\n" + "st1 { v18.d }[0], [x21], #0x8\n" + "st1 { v19.d }[0], [x26], #0x8\n" + "st1 { v20.d }[0], [x22], #0x8\n" + "st1 { v21.d }[0], [x25], #0x8\n" + "st1 { v22.d }[0], [x23], #0x8\n" + "st1 { v23.d }[0], [x27], #0x8\n" + "tbz x10, #0, 12f\n" + "st1 { v16.s }[2], [x20]\n" + "st1 { v17.s }[2], [x24]\n" + "st1 { v18.s }[2], [x21]\n" + "st1 { v19.s }[2], [x26]\n" + "st1 { v20.s }[2], [x22]\n" + "st1 { v21.s }[2], [x25]\n" + "st1 { v22.s }[2], [x23]\n" + "st1 { v23.s }[2], [x27]\n" + "b 12f\n" + "11:" // Output block 1: partial_1_0 + "st1 { v16.s }[0], [x20]\n" + "st1 { v17.s }[0], [x24]\n" + "st1 { v18.s }[0], [x21]\n" + "st1 { v19.s }[0], [x26]\n" + "st1 { v20.s }[0], [x22]\n" + "st1 { v21.s }[0], [x25]\n" + "st1 { v22.s }[0], [x23]\n" + "st1 { v23.s }[0], [x27]\n" + "12:" // Output block 1: Done + "13:" // Output stage exit + "subs x10, x10, #0x4\n" + "add %x[dst], %x[dst], #0x10\n" + "bgt 2b\n" + "mov x20, #0x4\n" + "sub x13, x13, #0x10\n" + "cmp x13, #0x10\n" + "mov %x[dst], x9\n" + "madd %x[lhs_packed], x20, x12, %x[lhs_packed]\n" + "bge 1b\n" + "14:" // Row loop skip + "cbz x13, 23f\n" + "15:" // Row tail: Row loop + "mov x26, %x[rhs_packed]\n" + "mov x25, %x[n]\n" + "add x24, %x[dst], %x[dst_stride_row], LSL #2\n" + "16:" // Row tail: Column loop + "movi v31.16b, #0x0\n" + "movi v30.16b, #0x0\n" + "mov x27, %x[lhs_packed]\n" + "mov x20, %x[num_blocks]\n" + "movi v29.16b, #0x0\n" + "movi v28.16b, #0x0\n" + "17:" // Row tail: Block loop + "ldr q9, [x26, #0x0]\n" + "ldr q8, [x26, #0x10]\n" + "movi v7.4s, #0x0\n" + "movi v6.4s, #0x0\n" + "ldr q5, [x27, #0x0]\n" + "ldr q4, [x27, #0x10]\n" + "movi v3.4s, #0x0\n" + "movi v2.4s, #0x0\n" + "ldr q1, [x26, #0x20]\n" + "ldr q0, [x26, #0x30]\n" + "movi v27.16b, #0xf0\n" + "add x26, x26, #0x40\n" + "ldr q26, [x27, #0x20]\n" + "ldr q25, [x27, #0x30]\n" + "shl v24.16b, v9.16b, #0x4\n" + "shl v20.16b, v8.16b, #0x4\n" + "ldr q23, [x27, #0x40]\n" + "ldr q22, [x27, #0x50]\n" + "and v9.16b, v9.16b, v27.16b\n" + "and v8.16b, v8.16b, v27.16b\n" + "ldr q21, [x27, #0x60]\n" + "ldr q19, [x27, #0x70]\n" + "shl v18.16b, v1.16b, #0x4\n" + "shl v17.16b, v0.16b, #0x4\n" + "ldr d16, [x26, #0x0]\n" + ".inst 0x4e98a4a7 // smmla v7.4s, v5.16b, v24.16b\n" + ".inst 0x4e94a4a6 // smmla v6.4s, v5.16b, v20.16b\n" + "and v1.16b, v1.16b, v27.16b\n" + ".inst 0x4e98a483 // smmla v3.4s, v4.16b, v24.16b\n" + ".inst 0x4e94a482 // smmla v2.4s, v4.16b, v20.16b\n" + "and v0.16b, v0.16b, v27.16b\n" + "add x26, x26, #0x8\n" + "add x27, x27, #0x80\n" + "shll v20.4s, v16.4h, #0x10\n" + ".inst 0x4e92a747 // smmla v7.4s, v26.16b, v18.16b\n" + ".inst 0x4e91a746 // smmla v6.4s, v26.16b, v17.16b\n" + ".inst 0x4e92a723 // smmla v3.4s, v25.16b, v18.16b\n" + ".inst 0x4e91a722 // smmla v2.4s, v25.16b, v17.16b\n" + ".inst 0x4e89a6e7 // smmla v7.4s, v23.16b, v9.16b\n" + ".inst 0x4e88a6e6 // smmla v6.4s, v23.16b, v8.16b\n" + ".inst 0x4e89a6c3 // smmla v3.4s, v22.16b, v9.16b\n" + ".inst 0x4e88a6c2 // smmla v2.4s, v22.16b, v8.16b\n" + ".inst 0x4e81a6a7 // smmla v7.4s, v21.16b, v1.16b\n" + ".inst 0x4e80a6a6 // smmla v6.4s, v21.16b, v0.16b\n" + ".inst 0x4e81a663 // smmla v3.4s, v19.16b, v1.16b\n" + ".inst 0x4e80a662 // smmla v2.4s, v19.16b, v0.16b\n" + "uzp1 v19.2d, v7.2d, v6.2d\n" + "uzp2 v18.2d, v7.2d, v6.2d\n" + "scvtf v19.4s, v19.4s, #0x4\n" + "uzp1 v17.2d, v3.2d, v2.2d\n" + "uzp2 v16.2d, v3.2d, v2.2d\n" + "scvtf v18.4s, v18.4s, #0x4\n" + "fmla v31.4s, v19.4s, v20.4s\n" + "scvtf v17.4s, v17.4s, #0x4\n" + "scvtf v16.4s, 
v16.4s, #0x4\n" + "fmla v30.4s, v18.4s, v20.4s\n" + "fmla v29.4s, v17.4s, v20.4s\n" + "fmla v28.4s, v16.4s, v20.4s\n" + "subs x20, x20, #0x1\n" + "bgt 17b\n" + "ld1 { v21.4s }, [x27]\n" + "ldr q20, [x26, #0x0]\n" + "add x27, x27, #0x10\n" + "add x20, %x[clamp_vals], #0x4\n" + "ldr q19, [x27, #0x0]\n" + "ldr q18, [x26, #0x10]\n" + "cmp x25, #0x4\n" + "add x26, x26, #0x20\n" + "ld1r { v17.4s }, [%x[clamp_vals]]\n" + "ld1r { v16.4s }, [x20]\n" + "scvtf v21.4s, v21.4s\n" + "fmla v31.4s, v20.4s, v21.s[0]\n" + "fmla v30.4s, v20.4s, v21.s[1]\n" + "fmla v29.4s, v20.4s, v21.s[2]\n" + "fmla v28.4s, v20.4s, v21.s[3]\n" + "fmul v31.4s, v31.4s, v19.s[0]\n" + "fmul v30.4s, v30.4s, v19.s[1]\n" + "fmul v29.4s, v29.4s, v19.s[2]\n" + "fadd v31.4s, v31.4s, v18.4s\n" + "fmul v28.4s, v28.4s, v19.s[3]\n" + "fadd v30.4s, v30.4s, v18.4s\n" + "fadd v29.4s, v29.4s, v18.4s\n" + "fadd v28.4s, v28.4s, v18.4s\n" + "fmax v31.4s, v31.4s, v17.4s\n" + "fmax v30.4s, v30.4s, v17.4s\n" + "fmax v29.4s, v29.4s, v17.4s\n" + "fmax v28.4s, v28.4s, v17.4s\n" + "fmin v31.4s, v31.4s, v16.4s\n" + "fmin v30.4s, v30.4s, v16.4s\n" + "fmin v29.4s, v29.4s, v16.4s\n" + "fmin v28.4s, v28.4s, v16.4s\n" + "blt 19f\n" + "mov x20, %x[dst]\n" + "cmp x13, #0x1\n" + "str q31, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 22f\n" + "cmp x13, #0x2\n" + "str q30, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 22f\n" + "cmp x13, #0x3\n" + "str q29, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 22f\n" + "str q28, [x20, #0x0]\n" + "b 22f\n" + "19:" // Row tail: Partial output + "mov x23, %x[dst]\n" + "cmp x13, #0x1\n" + "add x22, x23, %x[dst_stride_row]\n" + "csel x22, x22, x23, GT\n" + "cmp x13, #0x2\n" + "add x21, x23, %x[dst_stride_row], LSL #1\n" + "csel x21, x21, x22, GT\n" + "cmp x13, #0x3\n" + "add x20, x21, %x[dst_stride_row]\n" + "csel x20, x20, x21, GT\n" + "tbz x25, #1, 20f\n" + "st1 { v28.d }[0], [x20], #0x8\n" + "st1 { v29.d }[0], [x21], #0x8\n" + "st1 { v30.d }[0], [x22], #0x8\n" + "st1 { v31.d }[0], [x23], #0x8\n" + "tbz x25, #0, 21f\n" + "st1 { v28.s }[2], [x20]\n" + "st1 { v29.s }[2], [x21]\n" + "st1 { v30.s }[2], [x22]\n" + "st1 { v31.s }[2], [x23]\n" + "b 21f\n" + "20:" // Row tail: Output block 0: partial_1_0 + "st1 { v28.s }[0], [x20]\n" + "st1 { v29.s }[0], [x21]\n" + "st1 { v30.s }[0], [x22]\n" + "st1 { v31.s }[0], [x23]\n" + "21:" // Row tail: Output block 0: Done + "22:" // Row tail: Output stage exit + "subs x25, x25, #0x4\n" + "add %x[dst], %x[dst], #0x10\n" + "bgt 16b\n" + "subs x13, x13, #0x4\n" + "add %x[lhs_packed], %x[lhs_packed], x12\n" + "mov %x[dst], x24\n" + "bgt 15b\n" + "23:" // Row tail: Row loop skip + : [dst] "+&r"(dst), [lhs_packed] "+&r"(lhs_packed) + : [clamp_vals] "r"(clamp_vals), [dst_stride_row] "r"(dst_stride_row), [m] "r"(m), [n] "r"(n), + [num_blocks] "r"(num_blocks), [rhs_packed] "r"(rhs_packed) + : "cc", "memory", "v0", "v1", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v2", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v3", "v30", "v31", "v4", "v5", + "v6", "v7", "v8", "v9", "x10", "x11", "x12", "x13", "x20", "x21", "x22", "x23", "x24", "x25", "x26", + "x27", "x28", "x9"); + } else { + __asm__ __volatile__( + "mov x13, #0x80\n" + "mov x12, %x[m]\n" + "mov x20, #0x20\n" + "sub SP, SP, #0x100\n" + "mul x13, %x[num_subblocks], x13\n" + "cmp x12, #0x10\n" + "madd x13, %x[num_blocks], x13, x20\n" + "blt 15f\n" + "1:" // Row loop + "mov x11, %x[rhs_packed]\n" + "mov x10, %x[n]\n" + "add x9, %x[dst], 
%x[dst_stride_row], LSL #4\n" + "2:" // Column loop + "mov x27, %x[lhs_packed]\n" + "movi v29.4s, #0x0\n" + "mov x24, %x[num_blocks]\n" + "str q29, [SP, #0x0]\n" + "str q29, [SP, #0x10]\n" + "str q29, [SP, #0x20]\n" + "add x23, x27, x13\n" + "add x22, x23, x13\n" + "str q29, [SP, #0x30]\n" + "add x21, x22, x13\n" + "str q29, [SP, #0x40]\n" + "str q29, [SP, #0x50]\n" + "str q29, [SP, #0x60]\n" + "str q29, [SP, #0x70]\n" + "str q29, [SP, #0x80]\n" + "str q29, [SP, #0x90]\n" + "str q29, [SP, #0xa0]\n" + "str q29, [SP, #0xb0]\n" + "str q29, [SP, #0xc0]\n" + "str q29, [SP, #0xd0]\n" + "str q29, [SP, #0xe0]\n" + "str q29, [SP, #0xf0]\n" + "3:" // Block loop + "movi v7.4s, #0x0\n" + "movi v13.4s, #0x0\n" + "mov x20, %x[num_subblocks]\n" + "movi v29.4s, #0x0\n" + "movi v12.4s, #0x0\n" + "movi v28.4s, #0x0\n" + "movi v15.4s, #0x0\n" + "movi v2.4s, #0x0\n" + "movi v22.4s, #0x0\n" + "movi v30.4s, #0x0\n" + "movi v26.4s, #0x0\n" + "movi v6.4s, #0x0\n" + "movi v10.4s, #0x0\n" + "movi v9.4s, #0x0\n" + "movi v18.4s, #0x0\n" + "movi v0.4s, #0x0\n" + "movi v14.4s, #0x0\n" + "4:" // Sub block loop + "ldr q4, [x11, #0x0]\n" + "ldr q3, [x11, #0x10]\n" + "movi v31.16b, #0xf0\n" + "subs x20, x20, #0x1\n" + "ldr q27, [x27, #0x0]\n" + "ldr q1, [x27, #0x10]\n" + "ldr q19, [x23, #0x0]\n" + "ldr q17, [x23, #0x10]\n" + "ldr q21, [x22, #0x0]\n" + "ldr q23, [x22, #0x10]\n" + "shl v25.16b, v4.16b, #0x4\n" + "shl v20.16b, v3.16b, #0x4\n" + "ldr q5, [x21, #0x0]\n" + "ldr q16, [x21, #0x10]\n" + "and v4.16b, v4.16b, v31.16b\n" + "and v3.16b, v3.16b, v31.16b\n" + "ldr q8, [x11, #0x20]\n" + "ldr q11, [x11, #0x30]\n" + "add x11, x11, #0x40\n" + "ldr q24, [x27, #0x20]\n" + ".inst 0x4e99a767 // smmla v7.4s, v27.16b, v25.16b\n" + ".inst 0x4e94a76d // smmla v13.4s, v27.16b, v20.16b\n" + "ldr q27, [x27, #0x30]\n" + ".inst 0x4e99a43d // smmla v29.4s, v1.16b, v25.16b\n" + ".inst 0x4e94a42c // smmla v12.4s, v1.16b, v20.16b\n" + "ldr q1, [x23, #0x20]\n" + ".inst 0x4e99a67c // smmla v28.4s, v19.16b, v25.16b\n" + ".inst 0x4e94a66f // smmla v15.4s, v19.16b, v20.16b\n" + "ldr q19, [x23, #0x30]\n" + ".inst 0x4e99a622 // smmla v2.4s, v17.16b, v25.16b\n" + ".inst 0x4e94a636 // smmla v22.4s, v17.16b, v20.16b\n" + "ldr q17, [x22, #0x20]\n" + ".inst 0x4e99a6be // smmla v30.4s, v21.16b, v25.16b\n" + ".inst 0x4e94a6ba // smmla v26.4s, v21.16b, v20.16b\n" + "ldr q21, [x22, #0x30]\n" + ".inst 0x4e99a6e6 // smmla v6.4s, v23.16b, v25.16b\n" + ".inst 0x4e94a6ea // smmla v10.4s, v23.16b, v20.16b\n" + "ldr q23, [x21, #0x20]\n" + ".inst 0x4e99a4a9 // smmla v9.4s, v5.16b, v25.16b\n" + ".inst 0x4e94a4b2 // smmla v18.4s, v5.16b, v20.16b\n" + "ldr q5, [x21, #0x30]\n" + ".inst 0x4e99a600 // smmla v0.4s, v16.16b, v25.16b\n" + "ldr q25, [x27, #0x40]\n" + ".inst 0x4e94a60e // smmla v14.4s, v16.16b, v20.16b\n" + "ldr q16, [x27, #0x50]\n" + "shl v20.16b, v8.16b, #0x4\n" + "and v8.16b, v8.16b, v31.16b\n" + ".inst 0x4e94a707 // smmla v7.4s, v24.16b, v20.16b\n" + ".inst 0x4e94a77d // smmla v29.4s, v27.16b, v20.16b\n" + ".inst 0x4e94a43c // smmla v28.4s, v1.16b, v20.16b\n" + ".inst 0x4e94a662 // smmla v2.4s, v19.16b, v20.16b\n" + ".inst 0x4e94a63e // smmla v30.4s, v17.16b, v20.16b\n" + ".inst 0x4e94a6a6 // smmla v6.4s, v21.16b, v20.16b\n" + ".inst 0x4e94a6e9 // smmla v9.4s, v23.16b, v20.16b\n" + ".inst 0x4e94a4a0 // smmla v0.4s, v5.16b, v20.16b\n" + "shl v20.16b, v11.16b, #0x4\n" + ".inst 0x4e84a727 // smmla v7.4s, v25.16b, v4.16b\n" + ".inst 0x4e84a61d // smmla v29.4s, v16.16b, v4.16b\n" + "and v11.16b, v11.16b, v31.16b\n" + "ldr q31, [x23, #0x40]\n" + ".inst 
0x4e94a70d // smmla v13.4s, v24.16b, v20.16b\n" + "ldr q24, [x23, #0x50]\n" + ".inst 0x4e94a76c // smmla v12.4s, v27.16b, v20.16b\n" + "ldr q27, [x22, #0x40]\n" + ".inst 0x4e94a42f // smmla v15.4s, v1.16b, v20.16b\n" + "ldr q1, [x22, #0x50]\n" + ".inst 0x4e94a676 // smmla v22.4s, v19.16b, v20.16b\n" + "ldr q19, [x21, #0x40]\n" + ".inst 0x4e94a63a // smmla v26.4s, v17.16b, v20.16b\n" + "ldr q17, [x21, #0x50]\n" + ".inst 0x4e94a6aa // smmla v10.4s, v21.16b, v20.16b\n" + "ldr q21, [x27, #0x60]\n" + ".inst 0x4e94a6f2 // smmla v18.4s, v23.16b, v20.16b\n" + "ldr q23, [x27, #0x70]\n" + ".inst 0x4e94a4ae // smmla v14.4s, v5.16b, v20.16b\n" + "ldr q20, [x23, #0x60]\n" + ".inst 0x4e83a72d // smmla v13.4s, v25.16b, v3.16b\n" + "ldr q5, [x23, #0x70]\n" + "ldr q25, [x22, #0x60]\n" + ".inst 0x4e83a60c // smmla v12.4s, v16.16b, v3.16b\n" + ".inst 0x4e84a7fc // smmla v28.4s, v31.16b, v4.16b\n" + "ldr q16, [x22, #0x70]\n" + ".inst 0x4e83a7ef // smmla v15.4s, v31.16b, v3.16b\n" + "ldr q31, [x21, #0x60]\n" + ".inst 0x4e84a702 // smmla v2.4s, v24.16b, v4.16b\n" + ".inst 0x4e83a716 // smmla v22.4s, v24.16b, v3.16b\n" + "ldr q24, [x21, #0x70]\n" + ".inst 0x4e84a77e // smmla v30.4s, v27.16b, v4.16b\n" + "add x27, x27, #0x80\n" + ".inst 0x4e83a77a // smmla v26.4s, v27.16b, v3.16b\n" + ".inst 0x4e84a426 // smmla v6.4s, v1.16b, v4.16b\n" + "add x23, x23, #0x80\n" + "add x22, x22, #0x80\n" + ".inst 0x4e83a42a // smmla v10.4s, v1.16b, v3.16b\n" + ".inst 0x4e84a669 // smmla v9.4s, v19.16b, v4.16b\n" + "add x21, x21, #0x80\n" + ".inst 0x4e83a672 // smmla v18.4s, v19.16b, v3.16b\n" + ".inst 0x4e84a620 // smmla v0.4s, v17.16b, v4.16b\n" + ".inst 0x4e83a62e // smmla v14.4s, v17.16b, v3.16b\n" + ".inst 0x4e88a6a7 // smmla v7.4s, v21.16b, v8.16b\n" + ".inst 0x4e8ba6ad // smmla v13.4s, v21.16b, v11.16b\n" + ".inst 0x4e88a6fd // smmla v29.4s, v23.16b, v8.16b\n" + ".inst 0x4e8ba6ec // smmla v12.4s, v23.16b, v11.16b\n" + ".inst 0x4e88a69c // smmla v28.4s, v20.16b, v8.16b\n" + ".inst 0x4e8ba68f // smmla v15.4s, v20.16b, v11.16b\n" + ".inst 0x4e88a4a2 // smmla v2.4s, v5.16b, v8.16b\n" + ".inst 0x4e8ba4b6 // smmla v22.4s, v5.16b, v11.16b\n" + ".inst 0x4e88a73e // smmla v30.4s, v25.16b, v8.16b\n" + ".inst 0x4e8ba73a // smmla v26.4s, v25.16b, v11.16b\n" + ".inst 0x4e88a606 // smmla v6.4s, v16.16b, v8.16b\n" + ".inst 0x4e8ba60a // smmla v10.4s, v16.16b, v11.16b\n" + ".inst 0x4e88a7e9 // smmla v9.4s, v31.16b, v8.16b\n" + ".inst 0x4e8ba7f2 // smmla v18.4s, v31.16b, v11.16b\n" + ".inst 0x4e88a700 // smmla v0.4s, v24.16b, v8.16b\n" + ".inst 0x4e8ba70e // smmla v14.4s, v24.16b, v11.16b\n" + "bgt 4b\n" + "ldr d4, [x11, #0x0]\n" + "ldr q23, [SP, #0x0]\n" + "uzp1 v16.2d, v7.2d, v13.2d\n" + "uzp2 v19.2d, v7.2d, v13.2d\n" + "uzp1 v20.2d, v29.2d, v12.2d\n" + "uzp2 v17.2d, v29.2d, v12.2d\n" + "add x11, x11, #0x8\n" + "shll v24.4s, v4.4h, #0x10\n" + "scvtf v16.4s, v16.4s, #0x4\n" + "scvtf v19.4s, v19.4s, #0x4\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "scvtf v17.4s, v17.4s, #0x4\n" + "fmla v23.4s, v16.4s, v24.4s\n" + "str q23, [SP, #0x0]\n" + "ldr q16, [SP, #0x10]\n" + "fmla v16.4s, v19.4s, v24.4s\n" + "str q16, [SP, #0x10]\n" + "ldr q16, [SP, #0x20]\n" + "fmla v16.4s, v20.4s, v24.4s\n" + "str q16, [SP, #0x20]\n" + "ldr q16, [SP, #0x30]\n" + "fmla v16.4s, v17.4s, v24.4s\n" + "str q16, [SP, #0x30]\n" + "ldr q1, [SP, #0x40]\n" + "uzp1 v16.2d, v28.2d, v15.2d\n" + "uzp2 v19.2d, v28.2d, v15.2d\n" + "uzp1 v5.2d, v2.2d, v22.2d\n" + "uzp2 v17.2d, v2.2d, v22.2d\n" + "scvtf v16.4s, v16.4s, #0x4\n" + "scvtf v19.4s, v19.4s, #0x4\n" + "scvtf v5.4s, v5.4s, #0x4\n" 
+ "scvtf v17.4s, v17.4s, #0x4\n" + "fmla v1.4s, v16.4s, v24.4s\n" + "str q1, [SP, #0x40]\n" + "ldr q16, [SP, #0x50]\n" + "fmla v16.4s, v19.4s, v24.4s\n" + "str q16, [SP, #0x50]\n" + "ldr q16, [SP, #0x60]\n" + "fmla v16.4s, v5.4s, v24.4s\n" + "str q16, [SP, #0x60]\n" + "ldr q16, [SP, #0x70]\n" + "fmla v16.4s, v17.4s, v24.4s\n" + "str q16, [SP, #0x70]\n" + "ldr q1, [SP, #0x80]\n" + "uzp1 v16.2d, v30.2d, v26.2d\n" + "uzp2 v19.2d, v30.2d, v26.2d\n" + "uzp1 v30.2d, v6.2d, v10.2d\n" + "uzp2 v17.2d, v6.2d, v10.2d\n" + "scvtf v16.4s, v16.4s, #0x4\n" + "scvtf v19.4s, v19.4s, #0x4\n" + "scvtf v30.4s, v30.4s, #0x4\n" + "scvtf v17.4s, v17.4s, #0x4\n" + "fmla v1.4s, v16.4s, v24.4s\n" + "str q1, [SP, #0x80]\n" + "ldr q16, [SP, #0x90]\n" + "fmla v16.4s, v19.4s, v24.4s\n" + "str q16, [SP, #0x90]\n" + "ldr q16, [SP, #0xa0]\n" + "fmla v16.4s, v30.4s, v24.4s\n" + "str q16, [SP, #0xa0]\n" + "ldr q16, [SP, #0xb0]\n" + "fmla v16.4s, v17.4s, v24.4s\n" + "str q16, [SP, #0xb0]\n" + "ldr q31, [SP, #0xc0]\n" + "uzp1 v16.2d, v9.2d, v18.2d\n" + "uzp2 v19.2d, v9.2d, v18.2d\n" + "uzp1 v21.2d, v0.2d, v14.2d\n" + "uzp2 v17.2d, v0.2d, v14.2d\n" + "scvtf v16.4s, v16.4s, #0x4\n" + "scvtf v19.4s, v19.4s, #0x4\n" + "scvtf v21.4s, v21.4s, #0x4\n" + "scvtf v17.4s, v17.4s, #0x4\n" + "fmla v31.4s, v16.4s, v24.4s\n" + "str q31, [SP, #0xc0]\n" + "ldr q16, [SP, #0xd0]\n" + "fmla v16.4s, v19.4s, v24.4s\n" + "str q16, [SP, #0xd0]\n" + "ldr q16, [SP, #0xe0]\n" + "fmla v16.4s, v21.4s, v24.4s\n" + "str q16, [SP, #0xe0]\n" + "ldr q16, [SP, #0xf0]\n" + "fmla v16.4s, v17.4s, v24.4s\n" + "str q16, [SP, #0xf0]\n" + "subs x24, x24, #0x1\n" + "bgt 3b\n" + "ld1 { v11.4s }, [x27]\n" + "ld1 { v10.4s }, [x23]\n" + "add x27, x27, #0x10\n" + "add x23, x23, #0x10\n" + "ld1 { v9.4s }, [x22]\n" + "ld1 { v8.4s }, [x21]\n" + "add x22, x22, #0x10\n" + "add x21, x21, #0x10\n" + "ldr q31, [SP, #0x0]\n" + "ldr q30, [SP, #0x10]\n" + "add x20, %x[clamp_vals], #0x4\n" + "cmp x10, #0x4\n" + "ldr q29, [SP, #0x20]\n" + "ldr q28, [SP, #0x30]\n" + "scvtf v11.4s, v11.4s\n" + "scvtf v10.4s, v10.4s\n" + "ldr q27, [SP, #0x40]\n" + "ldr q26, [SP, #0x50]\n" + "scvtf v9.4s, v9.4s\n" + "scvtf v8.4s, v8.4s\n" + "ldr q25, [SP, #0x60]\n" + "ldr q24, [SP, #0x70]\n" + "ldr q23, [SP, #0x80]\n" + "ldr q22, [SP, #0x90]\n" + "ldr q21, [SP, #0xa0]\n" + "ldr q20, [SP, #0xb0]\n" + "ldr q19, [SP, #0xc0]\n" + "ldr q18, [SP, #0xd0]\n" + "ldr q17, [SP, #0xe0]\n" + "ldr q16, [SP, #0xf0]\n" + "ldr q7, [x11, #0x0]\n" + "ldr q6, [x27, #0x0]\n" + "ldr q5, [x23, #0x0]\n" + "ldr q4, [x22, #0x0]\n" + "ldr q3, [x21, #0x0]\n" + "ldr q2, [x11, #0x10]\n" + "add x11, x11, #0x20\n" + "ld1r { v1.4s }, [%x[clamp_vals]]\n" + "ld1r { v0.4s }, [x20]\n" + "fmla v31.4s, v7.4s, v11.s[0]\n" + "fmla v30.4s, v7.4s, v11.s[1]\n" + "fmla v29.4s, v7.4s, v11.s[2]\n" + "fmla v28.4s, v7.4s, v11.s[3]\n" + "fmla v27.4s, v7.4s, v10.s[0]\n" + "fmla v26.4s, v7.4s, v10.s[1]\n" + "fmla v25.4s, v7.4s, v10.s[2]\n" + "fmla v24.4s, v7.4s, v10.s[3]\n" + "fmla v23.4s, v7.4s, v9.s[0]\n" + "fmla v22.4s, v7.4s, v9.s[1]\n" + "fmul v31.4s, v31.4s, v6.s[0]\n" + "fmla v21.4s, v7.4s, v9.s[2]\n" + "fmla v20.4s, v7.4s, v9.s[3]\n" + "fmul v30.4s, v30.4s, v6.s[1]\n" + "fmla v19.4s, v7.4s, v8.s[0]\n" + "fmla v18.4s, v7.4s, v8.s[1]\n" + "fmul v29.4s, v29.4s, v6.s[2]\n" + "fmla v17.4s, v7.4s, v8.s[2]\n" + "fmla v16.4s, v7.4s, v8.s[3]\n" + "fmul v28.4s, v28.4s, v6.s[3]\n" + "fmul v27.4s, v27.4s, v5.s[0]\n" + "fmul v26.4s, v26.4s, v5.s[1]\n" + "fmul v25.4s, v25.4s, v5.s[2]\n" + "fmul v24.4s, v24.4s, v5.s[3]\n" + "fmul v23.4s, v23.4s, v4.s[0]\n" + "fmul 
v22.4s, v22.4s, v4.s[1]\n" + "fmul v21.4s, v21.4s, v4.s[2]\n" + "fmul v20.4s, v20.4s, v4.s[3]\n" + "fmul v19.4s, v19.4s, v3.s[0]\n" + "fmul v18.4s, v18.4s, v3.s[1]\n" + "fmul v17.4s, v17.4s, v3.s[2]\n" + "fmul v16.4s, v16.4s, v3.s[3]\n" + "fadd v31.4s, v31.4s, v2.4s\n" + "fadd v30.4s, v30.4s, v2.4s\n" + "fadd v29.4s, v29.4s, v2.4s\n" + "fadd v28.4s, v28.4s, v2.4s\n" + "fadd v27.4s, v27.4s, v2.4s\n" + "fadd v26.4s, v26.4s, v2.4s\n" + "fadd v25.4s, v25.4s, v2.4s\n" + "fadd v24.4s, v24.4s, v2.4s\n" + "fadd v23.4s, v23.4s, v2.4s\n" + "fadd v22.4s, v22.4s, v2.4s\n" + "fadd v21.4s, v21.4s, v2.4s\n" + "fadd v20.4s, v20.4s, v2.4s\n" + "fadd v19.4s, v19.4s, v2.4s\n" + "fadd v18.4s, v18.4s, v2.4s\n" + "fadd v17.4s, v17.4s, v2.4s\n" + "fadd v16.4s, v16.4s, v2.4s\n" + "fmax v31.4s, v31.4s, v1.4s\n" + "fmax v30.4s, v30.4s, v1.4s\n" + "fmax v29.4s, v29.4s, v1.4s\n" + "fmax v28.4s, v28.4s, v1.4s\n" + "fmax v27.4s, v27.4s, v1.4s\n" + "fmax v26.4s, v26.4s, v1.4s\n" + "fmax v25.4s, v25.4s, v1.4s\n" + "fmax v24.4s, v24.4s, v1.4s\n" + "fmax v23.4s, v23.4s, v1.4s\n" + "fmax v22.4s, v22.4s, v1.4s\n" + "fmax v21.4s, v21.4s, v1.4s\n" + "fmax v20.4s, v20.4s, v1.4s\n" + "fmax v19.4s, v19.4s, v1.4s\n" + "fmax v18.4s, v18.4s, v1.4s\n" + "fmax v17.4s, v17.4s, v1.4s\n" + "fmax v16.4s, v16.4s, v1.4s\n" + "fmin v31.4s, v31.4s, v0.4s\n" + "fmin v30.4s, v30.4s, v0.4s\n" + "fmin v29.4s, v29.4s, v0.4s\n" + "fmin v28.4s, v28.4s, v0.4s\n" + "fmin v27.4s, v27.4s, v0.4s\n" + "fmin v26.4s, v26.4s, v0.4s\n" + "fmin v25.4s, v25.4s, v0.4s\n" + "fmin v24.4s, v24.4s, v0.4s\n" + "fmin v23.4s, v23.4s, v0.4s\n" + "fmin v22.4s, v22.4s, v0.4s\n" + "fmin v21.4s, v21.4s, v0.4s\n" + "fmin v20.4s, v20.4s, v0.4s\n" + "fmin v19.4s, v19.4s, v0.4s\n" + "fmin v18.4s, v18.4s, v0.4s\n" + "fmin v17.4s, v17.4s, v0.4s\n" + "fmin v16.4s, v16.4s, v0.4s\n" + "blt 9f\n" + "mov x20, %x[dst]\n" + "str q31, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q30, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q29, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q28, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q27, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q26, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q25, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q24, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q23, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q22, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q21, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q20, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q19, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q18, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q17, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q16, [x20, #0x0]\n" + "b 14f\n" + "9:" // Partial output + "mov x28, %x[dst]\n" + "add x26, x28, %x[dst_stride_row], LSL #2\n" + "add x25, x26, %x[dst_stride_row], LSL #1\n" + "add x24, x26, %x[dst_stride_row]\n" + "add x23, x25, %x[dst_stride_row]\n" + "add x22, x28, %x[dst_stride_row], LSL #1\n" + "add x21, x28, %x[dst_stride_row]\n" + "add x20, x22, %x[dst_stride_row]\n" + "add x27, x23, %x[dst_stride_row]\n" + "tbz x10, #1, 10f\n" + "st1 { v24.d }[0], [x23], #0x8\n" + "st1 { v25.d }[0], [x25], #0x8\n" + "st1 { v26.d }[0], [x24], #0x8\n" + "st1 { v27.d }[0], [x26], #0x8\n" + "st1 { v28.d }[0], [x20], #0x8\n" + "st1 { v29.d }[0], [x22], #0x8\n" + "st1 { v30.d }[0], [x21], #0x8\n" + "st1 { 
v31.d }[0], [x28], #0x8\n" + "tbz x10, #0, 11f\n" + "st1 { v24.s }[2], [x23]\n" + "st1 { v25.s }[2], [x25]\n" + "st1 { v26.s }[2], [x24]\n" + "st1 { v27.s }[2], [x26]\n" + "st1 { v28.s }[2], [x20]\n" + "st1 { v29.s }[2], [x22]\n" + "st1 { v30.s }[2], [x21]\n" + "st1 { v31.s }[2], [x28]\n" + "b 11f\n" + "10:" // Output block 0: partial_1_0 + "st1 { v24.s }[0], [x23]\n" + "st1 { v25.s }[0], [x25]\n" + "st1 { v26.s }[0], [x24]\n" + "st1 { v27.s }[0], [x26]\n" + "st1 { v28.s }[0], [x20]\n" + "st1 { v29.s }[0], [x22]\n" + "st1 { v30.s }[0], [x21]\n" + "st1 { v31.s }[0], [x28]\n" + "11:" // Output block 0: Done + "add x26, x27, %x[dst_stride_row], LSL #2\n" + "add x25, x27, %x[dst_stride_row], LSL #1\n" + "add x24, x26, %x[dst_stride_row], LSL #1\n" + "add x23, x27, %x[dst_stride_row]\n" + "add x22, x25, %x[dst_stride_row]\n" + "add x21, x26, %x[dst_stride_row]\n" + "add x20, x24, %x[dst_stride_row]\n" + "tbz x10, #1, 12f\n" + "st1 { v16.d }[0], [x20], #0x8\n" + "st1 { v17.d }[0], [x24], #0x8\n" + "st1 { v18.d }[0], [x21], #0x8\n" + "st1 { v19.d }[0], [x26], #0x8\n" + "st1 { v20.d }[0], [x22], #0x8\n" + "st1 { v21.d }[0], [x25], #0x8\n" + "st1 { v22.d }[0], [x23], #0x8\n" + "st1 { v23.d }[0], [x27], #0x8\n" + "tbz x10, #0, 13f\n" + "st1 { v16.s }[2], [x20]\n" + "st1 { v17.s }[2], [x24]\n" + "st1 { v18.s }[2], [x21]\n" + "st1 { v19.s }[2], [x26]\n" + "st1 { v20.s }[2], [x22]\n" + "st1 { v21.s }[2], [x25]\n" + "st1 { v22.s }[2], [x23]\n" + "st1 { v23.s }[2], [x27]\n" + "b 13f\n" + "12:" // Output block 1: partial_1_0 + "st1 { v16.s }[0], [x20]\n" + "st1 { v17.s }[0], [x24]\n" + "st1 { v18.s }[0], [x21]\n" + "st1 { v19.s }[0], [x26]\n" + "st1 { v20.s }[0], [x22]\n" + "st1 { v21.s }[0], [x25]\n" + "st1 { v22.s }[0], [x23]\n" + "st1 { v23.s }[0], [x27]\n" + "13:" // Output block 1: Done + "14:" // Output stage exit + "subs x10, x10, #0x4\n" + "add %x[dst], %x[dst], #0x10\n" + "bgt 2b\n" + "mov x20, #0x4\n" + "sub x12, x12, #0x10\n" + "cmp x12, #0x10\n" + "mov %x[dst], x9\n" + "madd %x[lhs_packed], x20, x13, %x[lhs_packed]\n" + "bge 1b\n" + "15:" // Row loop skip + "cbz x12, 25f\n" + "16:" // Row tail: Row loop + "mov x26, %x[rhs_packed]\n" + "mov x25, %x[n]\n" + "add x24, %x[dst], %x[dst_stride_row], LSL #2\n" + "17:" // Row tail: Column loop + "movi v16.4s, #0x0\n" + "mov x27, %x[lhs_packed]\n" + "mov x21, %x[num_blocks]\n" + "str q16, [SP, #0x0]\n" + "str q16, [SP, #0x10]\n" + "str q16, [SP, #0x20]\n" + "str q16, [SP, #0x30]\n" + "18:" // Row tail: Block loop + "movi v7.4s, #0x0\n" + "movi v13.4s, #0x0\n" + "mov x20, %x[num_subblocks]\n" + "movi v29.4s, #0x0\n" + "movi v12.4s, #0x0\n" + "19:" // Row tail: Sub block loop + "ldr q0, [x26, #0x0]\n" + "ldr q31, [x26, #0x10]\n" + "movi v30.16b, #0xf0\n" + "subs x20, x20, #0x1\n" + "ldr q18, [x27, #0x0]\n" + "ldr q28, [x27, #0x10]\n" + "ldr q27, [x26, #0x20]\n" + "ldr q26, [x26, #0x30]\n" + "add x26, x26, #0x40\n" + "ldr q25, [x27, #0x20]\n" + "ldr q24, [x27, #0x30]\n" + "shl v23.16b, v0.16b, #0x4\n" + "shl v22.16b, v31.16b, #0x4\n" + "ldr q21, [x27, #0x40]\n" + "ldr q20, [x27, #0x50]\n" + "and v0.16b, v0.16b, v30.16b\n" + "and v31.16b, v31.16b, v30.16b\n" + "ldr q19, [x27, #0x60]\n" + "ldr q14, [x27, #0x70]\n" + "shl v17.16b, v27.16b, #0x4\n" + "shl v16.16b, v26.16b, #0x4\n" + ".inst 0x4e97a647 // smmla v7.4s, v18.16b, v23.16b\n" + ".inst 0x4e96a64d // smmla v13.4s, v18.16b, v22.16b\n" + "and v27.16b, v27.16b, v30.16b\n" + "add x27, x27, #0x80\n" + ".inst 0x4e97a79d // smmla v29.4s, v28.16b, v23.16b\n" + ".inst 0x4e96a78c // smmla v12.4s, v28.16b, 
v22.16b\n" + "and v26.16b, v26.16b, v30.16b\n" + ".inst 0x4e91a727 // smmla v7.4s, v25.16b, v17.16b\n" + ".inst 0x4e90a72d // smmla v13.4s, v25.16b, v16.16b\n" + ".inst 0x4e91a71d // smmla v29.4s, v24.16b, v17.16b\n" + ".inst 0x4e90a70c // smmla v12.4s, v24.16b, v16.16b\n" + ".inst 0x4e80a6a7 // smmla v7.4s, v21.16b, v0.16b\n" + ".inst 0x4e9fa6ad // smmla v13.4s, v21.16b, v31.16b\n" + ".inst 0x4e80a69d // smmla v29.4s, v20.16b, v0.16b\n" + ".inst 0x4e9fa68c // smmla v12.4s, v20.16b, v31.16b\n" + ".inst 0x4e9ba667 // smmla v7.4s, v19.16b, v27.16b\n" + ".inst 0x4e9aa66d // smmla v13.4s, v19.16b, v26.16b\n" + ".inst 0x4e9ba5dd // smmla v29.4s, v14.16b, v27.16b\n" + ".inst 0x4e9aa5cc // smmla v12.4s, v14.16b, v26.16b\n" + "bgt 19b\n" + "ldr d17, [x26, #0x0]\n" + "ldr q21, [SP, #0x0]\n" + "uzp1 v16.2d, v7.2d, v13.2d\n" + "uzp2 v20.2d, v7.2d, v13.2d\n" + "uzp1 v19.2d, v29.2d, v12.2d\n" + "uzp2 v18.2d, v29.2d, v12.2d\n" + "add x26, x26, #0x8\n" + "shll v17.4s, v17.4h, #0x10\n" + "scvtf v16.4s, v16.4s, #0x4\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "scvtf v19.4s, v19.4s, #0x4\n" + "scvtf v18.4s, v18.4s, #0x4\n" + "fmla v21.4s, v16.4s, v17.4s\n" + "str q21, [SP, #0x0]\n" + "ldr q16, [SP, #0x10]\n" + "fmla v16.4s, v20.4s, v17.4s\n" + "str q16, [SP, #0x10]\n" + "ldr q16, [SP, #0x20]\n" + "fmla v16.4s, v19.4s, v17.4s\n" + "str q16, [SP, #0x20]\n" + "ldr q16, [SP, #0x30]\n" + "fmla v16.4s, v18.4s, v17.4s\n" + "str q16, [SP, #0x30]\n" + "subs x21, x21, #0x1\n" + "bgt 18b\n" + "ld1 { v21.4s }, [x27]\n" + "ldr q31, [SP, #0x0]\n" + "add x27, x27, #0x10\n" + "add x20, %x[clamp_vals], #0x4\n" + "ldr q30, [SP, #0x10]\n" + "ldr q29, [SP, #0x20]\n" + "cmp x25, #0x4\n" + "ldr q28, [SP, #0x30]\n" + "ldr q20, [x26, #0x0]\n" + "ldr q19, [x27, #0x0]\n" + "ldr q18, [x26, #0x10]\n" + "scvtf v21.4s, v21.4s\n" + "add x26, x26, #0x20\n" + "ld1r { v17.4s }, [%x[clamp_vals]]\n" + "ld1r { v16.4s }, [x20]\n" + "fmla v31.4s, v20.4s, v21.s[0]\n" + "fmla v30.4s, v20.4s, v21.s[1]\n" + "fmla v29.4s, v20.4s, v21.s[2]\n" + "fmla v28.4s, v20.4s, v21.s[3]\n" + "fmul v31.4s, v31.4s, v19.s[0]\n" + "fmul v30.4s, v30.4s, v19.s[1]\n" + "fadd v31.4s, v31.4s, v18.4s\n" + "fmul v29.4s, v29.4s, v19.s[2]\n" + "fmul v28.4s, v28.4s, v19.s[3]\n" + "fadd v30.4s, v30.4s, v18.4s\n" + "fmax v31.4s, v31.4s, v17.4s\n" + "fadd v29.4s, v29.4s, v18.4s\n" + "fadd v28.4s, v28.4s, v18.4s\n" + "fmax v30.4s, v30.4s, v17.4s\n" + "fmin v31.4s, v31.4s, v16.4s\n" + "fmax v29.4s, v29.4s, v17.4s\n" + "fmax v28.4s, v28.4s, v17.4s\n" + "fmin v30.4s, v30.4s, v16.4s\n" + "fmin v29.4s, v29.4s, v16.4s\n" + "fmin v28.4s, v28.4s, v16.4s\n" + "blt 21f\n" + "mov x20, %x[dst]\n" + "cmp x12, #0x1\n" + "str q31, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 24f\n" + "cmp x12, #0x2\n" + "str q30, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 24f\n" + "cmp x12, #0x3\n" + "str q29, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 24f\n" + "str q28, [x20, #0x0]\n" + "b 24f\n" + "21:" // Row tail: Partial output + "mov x23, %x[dst]\n" + "cmp x12, #0x1\n" + "add x22, x23, %x[dst_stride_row]\n" + "csel x22, x22, x23, GT\n" + "cmp x12, #0x2\n" + "add x21, x23, %x[dst_stride_row], LSL #1\n" + "csel x21, x21, x22, GT\n" + "cmp x12, #0x3\n" + "add x20, x21, %x[dst_stride_row]\n" + "csel x20, x20, x21, GT\n" + "tbz x25, #1, 22f\n" + "st1 { v28.d }[0], [x20], #0x8\n" + "st1 { v29.d }[0], [x21], #0x8\n" + "st1 { v30.d }[0], [x22], #0x8\n" + "st1 { v31.d }[0], [x23], #0x8\n" + "tbz x25, #0, 23f\n" + "st1 { v28.s }[2], [x20]\n" + "st1 { v29.s }[2], [x21]\n" 
+ "st1 { v30.s }[2], [x22]\n" + "st1 { v31.s }[2], [x23]\n" + "b 23f\n" + "22:" // Row tail: Output block 0: partial_1_0 + "st1 { v28.s }[0], [x20]\n" + "st1 { v29.s }[0], [x21]\n" + "st1 { v30.s }[0], [x22]\n" + "st1 { v31.s }[0], [x23]\n" + "23:" // Row tail: Output block 0: Done + "24:" // Row tail: Output stage exit + "subs x25, x25, #0x4\n" + "add %x[dst], %x[dst], #0x10\n" + "bgt 17b\n" + "subs x12, x12, #0x4\n" + "add %x[lhs_packed], %x[lhs_packed], x13\n" + "mov %x[dst], x24\n" + "bgt 16b\n" + "25:" // Row tail: Row loop skip + "add SP, SP, #0x100\n" + : [dst] "+&r"(dst), [lhs_packed] "+&r"(lhs_packed) + : [clamp_vals] "r"(clamp_vals), [dst_stride_row] "r"(dst_stride_row), [m] "r"(m), [n] "r"(n), + [num_blocks] "r"(num_blocks), [num_subblocks] "r"(num_subblocks), [rhs_packed] "r"(rhs_packed) + : "cc", "memory", "v0", "v1", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v2", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v3", "v30", "v31", "v4", "v5", + "v6", "v7", "v8", "v9", "x10", "x11", "x12", "x13", "x20", "x21", "x22", "x23", "x24", "x25", "x26", + "x27", "x28", "x9"); + } +} +#endif // Architectural feature check diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.h new file mode 100644 index 00000000..d53d62bd --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.h @@ -0,0 +1,142 @@ + +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/// Micro-kernel dependencies +/// +/// -# kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix +/// -# kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 OR kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0 to pack the RHS matrix + +/// -------------------------------------------------- + +/// Gets the m step value. +/// The micro-kernel can process any M values. However, the starting M index to +/// be processed must be a multiple of m step. +/// +/// @return the m step value +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm(void); + +/// Gets the n step value. +/// The micro-kernel can process any N values. However, the starting N index to +/// be processed must be a multiple of n step. 
+/// +/// @return the n step +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm(void); + +/// Gets the mr value, which must be used to pack the LHS matrix +/// +/// @return the mr value +size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm(void); + +/// Gets the nr value, which must be used to pack the RHS matrix +/// +/// @return the nr value +size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm(void); + +/// Gets the kr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the kr value +size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm(void); + +/// Gets the sr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the sr value +size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm(void); + +/// Gets the offset in bytes for the packed LHS matrix, +/// which contains the packed Signed 8-bit quantized asymmetric per-row (qai8dx) values. +/// +/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. +/// +/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of 16. +/// @param[in] k Total number of columns in the LHS matrix (not packed). +/// +/// @return the offset in bytes to the packed LHS matrix +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm( + size_t m_idx, // + size_t k); // + +/// Gets the offset in bytes for the packed RHS matrix, +/// which contains the packed Signed 4-bit quantized symmetric per-channel (qsi4c32) values. +/// +/// @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of 4. +/// @param[in] k The common dimension between the LHS and RHS matrix (K). +/// @param[in] bl Block length. It must be a multiple of 32. +/// +/// @return the offset in bytes to the packed RHS matrix +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm( + size_t n_idx, // + size_t k, // + size_t bl); // + +/// Gets the offset in bytes for the DST matrix +/// +/// @param[in] m_idx Row index in the DST matrix. It must be a multiple of 16. +/// @param[in] n_idx Column index in the DST matrix. It must be a multiple of 4. +/// @param[in] dst_stride The number of bytes in each row of the DST matrix +/// +/// @return the DST offset in bytes +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm( + size_t m_idx, // + size_t n_idx, // + size_t dst_stride); // + +/// Gets the size in bytes for the destination (DST) matrix. +/// +/// @param[in] m Number of rows in the destination (DST) matrix. +/// @param[in] n Number of columns in the destination (DST) matrix. +/// +/// @return the destination (DST) matrix size in bytes +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm( + size_t m, // + size_t n); // + +/// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation. +/// +/// LHS matrix: Signed 8-bit quantized asymmetric per-row (qai8dx) and packed +/// RHS matrix: Signed 4-bit quantized symmetric per-channel (qsi4c32) and packed. +/// Output tile: (rows x cols) = 16 x 4 +/// Accumulation performed in a single for loop: 32 +/// Extension used: i8mm +/// +/// @param[in] m The number of output rows written. +/// @param[in] n The number of output columns written. +/// @param[in] k The number of channels. The common dimension between the LHS and RHS matrix. 
+/// @param[in] bl Block length. It must be a multiple of 32. +/// @param[in] lhs_packed The LHS packed matrix. +/// When the activation are dynamically quantized, you can obtain this matrix +/// by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs +/// both the dynamic quantization to 8-bit and activation packing in a single step. +/// @param[in] rhs_packed The RHS packed matrix, which is obtained by calling @ref +/// kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 or @ref kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0 +/// @param[out] dst The DST matrix. +/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. +/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(float). +/// @param[in] scalar_min Min value used to clamp the final result. +/// @param[in] scalar_max Max value used to clamp the final result. +void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm( + size_t m, // + size_t n, // + size_t k, // + size_t bl, // + const void* lhs_packed, // + const void* rhs_packed, // + float* dst, // + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max); // + +#ifdef __cplusplus +} +#endif diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c index 5af3e172..5e84e870 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c @@ -14,7 +14,7 @@ #include "kai/kai_common.h" -static const size_t kai_m_step = 16; +static const size_t kai_m_step = 8; static const size_t kai_n_step = 4; static const size_t kai_mr = 4; static const size_t kai_nr = 4; @@ -106,627 +106,7 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm return m * n * sizeof(float); } -inline static void kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm( - size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed, float* dst, - size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { - KAI_ASSERT(bl == 32); - KAI_ASSERT((bl % kai_bl_multiple_of) == 0); - KAI_ASSERT(dst_stride_col == sizeof(float)); - - if (m == 0) { - return; - } - - size_t num_blocks = kai_num_blocks_per_row(k, bl); - - float clamp_vals[2] = {scalar_min, scalar_max}; - - __asm__ __volatile__( - "mov x13, %x[m]\n" - "mov x12, #0x80\n" - "mov x20, #0x20\n" - "cmp x13, #0x10\n" - "madd x12, %x[num_blocks], x12, x20\n" - "blt 14f\n" - "1:" // Row loop - "mov x11, %x[rhs_packed]\n" - "mov x10, %x[n]\n" - "add x9, %x[dst], %x[dst_stride_row], LSL #4\n" - "2:" // Column loop - "mov x27, %x[lhs_packed]\n" - "movi v31.16b, #0x0\n" - "movi v30.16b, #0x0\n" - "mov x20, %x[num_blocks]\n" - "movi v29.16b, #0x0\n" - "movi v28.16b, #0x0\n" - "movi v27.16b, #0x0\n" - "movi v26.16b, #0x0\n" - "add x23, x27, x12\n" - "add x22, x23, x12\n" - "movi v25.16b, #0x0\n" - "movi v24.16b, #0x0\n" - "add x21, x22, x12\n" - "movi v23.16b, #0x0\n" - "movi v22.16b, #0x0\n" - "movi v21.16b, #0x0\n" - "movi v20.16b, #0x0\n" - "movi v19.16b, #0x0\n" - "movi v18.16b, #0x0\n" - "movi v17.16b, #0x0\n" - "movi v16.16b, #0x0\n" - "3:" // Block loop - "ldr q11, [x11, #0x0]\n" 
- "ldr q4, [x11, #0x10]\n" - "movi v2.4s, #0x0\n" - "movi v9.4s, #0x0\n" - "ldr q12, [x27, #0x0]\n" - "ldr q0, [x27, #0x10]\n" - "movi v7.4s, #0x0\n" - "movi v5.4s, #0x0\n" - "ldr q15, [x11, #0x20]\n" - "ldr q13, [x11, #0x30]\n" - "movi v10.16b, #0xf0\n" - "add x11, x11, #0x40\n" - "ldr q8, [x27, #0x20]\n" - "ldr q6, [x27, #0x30]\n" - "shl v14.16b, v11.16b, #0x4\n" - "shl v3.16b, v4.16b, #0x4\n" - "ldr q1, [x27, #0x40]\n" - "and v11.16b, v11.16b, v10.16b\n" - "and v4.16b, v4.16b, v10.16b\n" - ".inst 0x4e8ea582 // smmla v2.4s, v12.16b, v14.16b\n" - ".inst 0x4e83a589 // smmla v9.4s, v12.16b, v3.16b\n" - "shl v12.16b, v15.16b, #0x4\n" - ".inst 0x4e8ea407 // smmla v7.4s, v0.16b, v14.16b\n" - ".inst 0x4e83a405 // smmla v5.4s, v0.16b, v3.16b\n" - "shl v0.16b, v13.16b, #0x4\n" - "and v15.16b, v15.16b, v10.16b\n" - "and v13.16b, v13.16b, v10.16b\n" - "ldr q10, [x27, #0x50]\n" - ".inst 0x4e8ca502 // smmla v2.4s, v8.16b, v12.16b\n" - ".inst 0x4e80a509 // smmla v9.4s, v8.16b, v0.16b\n" - "ldr q8, [x27, #0x60]\n" - ".inst 0x4e8ca4c7 // smmla v7.4s, v6.16b, v12.16b\n" - ".inst 0x4e80a4c5 // smmla v5.4s, v6.16b, v0.16b\n" - "ldr q6, [x27, #0x70]\n" - "add x27, x27, #0x80\n" - ".inst 0x4e8ba422 // smmla v2.4s, v1.16b, v11.16b\n" - ".inst 0x4e84a429 // smmla v9.4s, v1.16b, v4.16b\n" - "ldr d1, [x11, #0x0]\n" - "add x11, x11, #0x8\n" - ".inst 0x4e8ba547 // smmla v7.4s, v10.16b, v11.16b\n" - ".inst 0x4e84a545 // smmla v5.4s, v10.16b, v4.16b\n" - ".inst 0x4e8fa502 // smmla v2.4s, v8.16b, v15.16b\n" - "shll v1.4s, v1.4h, #0x10\n" - ".inst 0x4e8da509 // smmla v9.4s, v8.16b, v13.16b\n" - ".inst 0x4e8fa4c7 // smmla v7.4s, v6.16b, v15.16b\n" - ".inst 0x4e8da4c5 // smmla v5.4s, v6.16b, v13.16b\n" - "uzp1 v6.2d, v2.2d, v9.2d\n" - "uzp2 v8.2d, v2.2d, v9.2d\n" - "scvtf v6.4s, v6.4s, #0x4\n" - "uzp1 v9.2d, v7.2d, v5.2d\n" - "uzp2 v2.2d, v7.2d, v5.2d\n" - "scvtf v8.4s, v8.4s, #0x4\n" - "fmla v31.4s, v6.4s, v1.4s\n" - "scvtf v9.4s, v9.4s, #0x4\n" - "scvtf v2.4s, v2.4s, #0x4\n" - "fmla v30.4s, v8.4s, v1.4s\n" - "fmla v29.4s, v9.4s, v1.4s\n" - "fmla v28.4s, v2.4s, v1.4s\n" - "ldr q9, [x23, #0x0]\n" - "ldr q7, [x23, #0x10]\n" - "movi v8.4s, #0x0\n" - "movi v2.4s, #0x0\n" - "ldr q5, [x23, #0x20]\n" - "ldr q10, [x23, #0x30]\n" - "movi v6.4s, #0x0\n" - ".inst 0x4e8ea528 // smmla v8.4s, v9.16b, v14.16b\n" - ".inst 0x4e83a522 // smmla v2.4s, v9.16b, v3.16b\n" - "ldr q9, [x23, #0x40]\n" - ".inst 0x4e8ea4e6 // smmla v6.4s, v7.16b, v14.16b\n" - ".inst 0x4e8ca4a8 // smmla v8.4s, v5.16b, v12.16b\n" - ".inst 0x4e80a4a2 // smmla v2.4s, v5.16b, v0.16b\n" - "ldr q5, [x23, #0x50]\n" - ".inst 0x4e8ca546 // smmla v6.4s, v10.16b, v12.16b\n" - ".inst 0x4e8ba528 // smmla v8.4s, v9.16b, v11.16b\n" - ".inst 0x4e84a522 // smmla v2.4s, v9.16b, v4.16b\n" - "ldr q9, [x23, #0x60]\n" - ".inst 0x4e8ba4a6 // smmla v6.4s, v5.16b, v11.16b\n" - ".inst 0x4e8fa528 // smmla v8.4s, v9.16b, v15.16b\n" - ".inst 0x4e8da522 // smmla v2.4s, v9.16b, v13.16b\n" - "movi v9.4s, #0x0\n" - ".inst 0x4e83a4e9 // smmla v9.4s, v7.16b, v3.16b\n" - "ldr q7, [x23, #0x70]\n" - "add x23, x23, #0x80\n" - ".inst 0x4e8fa4e6 // smmla v6.4s, v7.16b, v15.16b\n" - ".inst 0x4e80a549 // smmla v9.4s, v10.16b, v0.16b\n" - "uzp1 v10.2d, v8.2d, v2.2d\n" - "uzp2 v2.2d, v8.2d, v2.2d\n" - "scvtf v10.4s, v10.4s, #0x4\n" - ".inst 0x4e84a4a9 // smmla v9.4s, v5.16b, v4.16b\n" - "scvtf v2.4s, v2.4s, #0x4\n" - "fmla v27.4s, v10.4s, v1.4s\n" - ".inst 0x4e8da4e9 // smmla v9.4s, v7.16b, v13.16b\n" - "fmla v26.4s, v2.4s, v1.4s\n" - "uzp1 v2.2d, v6.2d, v9.2d\n" - "uzp2 v10.2d, v6.2d, v9.2d\n" - "scvtf 
v2.4s, v2.4s, #0x4\n" - "scvtf v10.4s, v10.4s, #0x4\n" - "fmla v25.4s, v2.4s, v1.4s\n" - "fmla v24.4s, v10.4s, v1.4s\n" - "ldr q8, [x22, #0x0]\n" - "ldr q7, [x22, #0x10]\n" - "movi v9.4s, #0x0\n" - "movi v6.4s, #0x0\n" - "ldr q2, [x22, #0x20]\n" - "ldr q5, [x22, #0x30]\n" - "movi v10.4s, #0x0\n" - ".inst 0x4e8ea509 // smmla v9.4s, v8.16b, v14.16b\n" - ".inst 0x4e83a506 // smmla v6.4s, v8.16b, v3.16b\n" - "ldr q8, [x22, #0x40]\n" - ".inst 0x4e8ea4ea // smmla v10.4s, v7.16b, v14.16b\n" - ".inst 0x4e8ca449 // smmla v9.4s, v2.16b, v12.16b\n" - ".inst 0x4e80a446 // smmla v6.4s, v2.16b, v0.16b\n" - "ldr q2, [x22, #0x50]\n" - ".inst 0x4e8ca4aa // smmla v10.4s, v5.16b, v12.16b\n" - ".inst 0x4e8ba509 // smmla v9.4s, v8.16b, v11.16b\n" - ".inst 0x4e84a506 // smmla v6.4s, v8.16b, v4.16b\n" - "ldr q8, [x22, #0x60]\n" - ".inst 0x4e8ba44a // smmla v10.4s, v2.16b, v11.16b\n" - ".inst 0x4e8fa509 // smmla v9.4s, v8.16b, v15.16b\n" - ".inst 0x4e8da506 // smmla v6.4s, v8.16b, v13.16b\n" - "movi v8.4s, #0x0\n" - ".inst 0x4e83a4e8 // smmla v8.4s, v7.16b, v3.16b\n" - "ldr q7, [x22, #0x70]\n" - "add x22, x22, #0x80\n" - ".inst 0x4e8fa4ea // smmla v10.4s, v7.16b, v15.16b\n" - ".inst 0x4e80a4a8 // smmla v8.4s, v5.16b, v0.16b\n" - "uzp1 v5.2d, v9.2d, v6.2d\n" - "uzp2 v9.2d, v9.2d, v6.2d\n" - "scvtf v5.4s, v5.4s, #0x4\n" - ".inst 0x4e84a448 // smmla v8.4s, v2.16b, v4.16b\n" - "scvtf v9.4s, v9.4s, #0x4\n" - "fmla v23.4s, v5.4s, v1.4s\n" - ".inst 0x4e8da4e8 // smmla v8.4s, v7.16b, v13.16b\n" - "fmla v22.4s, v9.4s, v1.4s\n" - "uzp1 v2.2d, v10.2d, v8.2d\n" - "uzp2 v10.2d, v10.2d, v8.2d\n" - "scvtf v2.4s, v2.4s, #0x4\n" - "scvtf v10.4s, v10.4s, #0x4\n" - "fmla v21.4s, v2.4s, v1.4s\n" - "fmla v20.4s, v10.4s, v1.4s\n" - "ldr q2, [x21, #0x0]\n" - "ldr q10, [x21, #0x10]\n" - "movi v6.4s, #0x0\n" - "movi v9.4s, #0x0\n" - "ldr q5, [x21, #0x20]\n" - "ldr q8, [x21, #0x30]\n" - "movi v7.4s, #0x0\n" - ".inst 0x4e8ea446 // smmla v6.4s, v2.16b, v14.16b\n" - ".inst 0x4e83a449 // smmla v9.4s, v2.16b, v3.16b\n" - "ldr q2, [x21, #0x40]\n" - ".inst 0x4e8ea547 // smmla v7.4s, v10.16b, v14.16b\n" - "ldr q14, [x21, #0x50]\n" - ".inst 0x4e8ca4a6 // smmla v6.4s, v5.16b, v12.16b\n" - ".inst 0x4e80a4a9 // smmla v9.4s, v5.16b, v0.16b\n" - "ldr q5, [x21, #0x60]\n" - ".inst 0x4e8ca507 // smmla v7.4s, v8.16b, v12.16b\n" - "ldr q12, [x21, #0x70]\n" - "add x21, x21, #0x80\n" - ".inst 0x4e8ba446 // smmla v6.4s, v2.16b, v11.16b\n" - ".inst 0x4e84a449 // smmla v9.4s, v2.16b, v4.16b\n" - "movi v2.4s, #0x0\n" - ".inst 0x4e83a542 // smmla v2.4s, v10.16b, v3.16b\n" - ".inst 0x4e8ba5c7 // smmla v7.4s, v14.16b, v11.16b\n" - ".inst 0x4e8fa4a6 // smmla v6.4s, v5.16b, v15.16b\n" - ".inst 0x4e80a502 // smmla v2.4s, v8.16b, v0.16b\n" - ".inst 0x4e8da4a9 // smmla v9.4s, v5.16b, v13.16b\n" - ".inst 0x4e8fa587 // smmla v7.4s, v12.16b, v15.16b\n" - ".inst 0x4e84a5c2 // smmla v2.4s, v14.16b, v4.16b\n" - "uzp1 v11.2d, v6.2d, v9.2d\n" - "uzp2 v14.2d, v6.2d, v9.2d\n" - "scvtf v11.4s, v11.4s, #0x4\n" - ".inst 0x4e8da582 // smmla v2.4s, v12.16b, v13.16b\n" - "scvtf v14.4s, v14.4s, #0x4\n" - "fmla v19.4s, v11.4s, v1.4s\n" - "uzp1 v9.2d, v7.2d, v2.2d\n" - "uzp2 v0.2d, v7.2d, v2.2d\n" - "fmla v18.4s, v14.4s, v1.4s\n" - "scvtf v9.4s, v9.4s, #0x4\n" - "scvtf v0.4s, v0.4s, #0x4\n" - "fmla v17.4s, v9.4s, v1.4s\n" - "fmla v16.4s, v0.4s, v1.4s\n" - "subs x20, x20, #0x1\n" - "bgt 3b\n" - "ld1 { v11.4s }, [x27]\n" - "ld1 { v10.4s }, [x23]\n" - "add x27, x27, #0x10\n" - "add x23, x23, #0x10\n" - "ld1 { v9.4s }, [x22]\n" - "ld1 { v8.4s }, [x21]\n" - "add x22, x22, #0x10\n" - "add x21, 
x21, #0x10\n" - "ldr q7, [x11, #0x0]\n" - "ldr q6, [x27, #0x0]\n" - "add x20, %x[clamp_vals], #0x4\n" - "cmp x10, #0x4\n" - "ldr q5, [x23, #0x0]\n" - "ldr q4, [x22, #0x0]\n" - "scvtf v11.4s, v11.4s\n" - "scvtf v10.4s, v10.4s\n" - "ldr q3, [x21, #0x0]\n" - "ldr q2, [x11, #0x10]\n" - "scvtf v9.4s, v9.4s\n" - "scvtf v8.4s, v8.4s\n" - "ld1r { v1.4s }, [%x[clamp_vals]]\n" - "ld1r { v0.4s }, [x20]\n" - "add x11, x11, #0x20\n" - "fmla v31.4s, v7.4s, v11.s[0]\n" - "fmla v30.4s, v7.4s, v11.s[1]\n" - "fmla v29.4s, v7.4s, v11.s[2]\n" - "fmla v28.4s, v7.4s, v11.s[3]\n" - "fmla v27.4s, v7.4s, v10.s[0]\n" - "fmla v26.4s, v7.4s, v10.s[1]\n" - "fmla v25.4s, v7.4s, v10.s[2]\n" - "fmla v24.4s, v7.4s, v10.s[3]\n" - "fmla v23.4s, v7.4s, v9.s[0]\n" - "fmul v31.4s, v31.4s, v6.s[0]\n" - "fmla v22.4s, v7.4s, v9.s[1]\n" - "fmla v21.4s, v7.4s, v9.s[2]\n" - "fmul v30.4s, v30.4s, v6.s[1]\n" - "fmla v20.4s, v7.4s, v9.s[3]\n" - "fmla v19.4s, v7.4s, v8.s[0]\n" - "fmul v29.4s, v29.4s, v6.s[2]\n" - "fmla v18.4s, v7.4s, v8.s[1]\n" - "fmla v17.4s, v7.4s, v8.s[2]\n" - "fmul v28.4s, v28.4s, v6.s[3]\n" - "fmla v16.4s, v7.4s, v8.s[3]\n" - "fmul v27.4s, v27.4s, v5.s[0]\n" - "fmul v26.4s, v26.4s, v5.s[1]\n" - "fmul v25.4s, v25.4s, v5.s[2]\n" - "fmul v24.4s, v24.4s, v5.s[3]\n" - "fmul v23.4s, v23.4s, v4.s[0]\n" - "fmul v22.4s, v22.4s, v4.s[1]\n" - "fmul v21.4s, v21.4s, v4.s[2]\n" - "fmul v20.4s, v20.4s, v4.s[3]\n" - "fmul v19.4s, v19.4s, v3.s[0]\n" - "fmul v18.4s, v18.4s, v3.s[1]\n" - "fmul v17.4s, v17.4s, v3.s[2]\n" - "fmul v16.4s, v16.4s, v3.s[3]\n" - "fadd v31.4s, v31.4s, v2.4s\n" - "fadd v30.4s, v30.4s, v2.4s\n" - "fadd v29.4s, v29.4s, v2.4s\n" - "fadd v28.4s, v28.4s, v2.4s\n" - "fadd v27.4s, v27.4s, v2.4s\n" - "fadd v26.4s, v26.4s, v2.4s\n" - "fadd v25.4s, v25.4s, v2.4s\n" - "fadd v24.4s, v24.4s, v2.4s\n" - "fadd v23.4s, v23.4s, v2.4s\n" - "fadd v22.4s, v22.4s, v2.4s\n" - "fadd v21.4s, v21.4s, v2.4s\n" - "fadd v20.4s, v20.4s, v2.4s\n" - "fadd v19.4s, v19.4s, v2.4s\n" - "fadd v18.4s, v18.4s, v2.4s\n" - "fadd v17.4s, v17.4s, v2.4s\n" - "fadd v16.4s, v16.4s, v2.4s\n" - "fmax v31.4s, v31.4s, v1.4s\n" - "fmax v30.4s, v30.4s, v1.4s\n" - "fmax v29.4s, v29.4s, v1.4s\n" - "fmax v28.4s, v28.4s, v1.4s\n" - "fmax v27.4s, v27.4s, v1.4s\n" - "fmax v26.4s, v26.4s, v1.4s\n" - "fmax v25.4s, v25.4s, v1.4s\n" - "fmax v24.4s, v24.4s, v1.4s\n" - "fmax v23.4s, v23.4s, v1.4s\n" - "fmax v22.4s, v22.4s, v1.4s\n" - "fmax v21.4s, v21.4s, v1.4s\n" - "fmax v20.4s, v20.4s, v1.4s\n" - "fmax v19.4s, v19.4s, v1.4s\n" - "fmax v18.4s, v18.4s, v1.4s\n" - "fmax v17.4s, v17.4s, v1.4s\n" - "fmax v16.4s, v16.4s, v1.4s\n" - "fmin v31.4s, v31.4s, v0.4s\n" - "fmin v30.4s, v30.4s, v0.4s\n" - "fmin v29.4s, v29.4s, v0.4s\n" - "fmin v28.4s, v28.4s, v0.4s\n" - "fmin v27.4s, v27.4s, v0.4s\n" - "fmin v26.4s, v26.4s, v0.4s\n" - "fmin v25.4s, v25.4s, v0.4s\n" - "fmin v24.4s, v24.4s, v0.4s\n" - "fmin v23.4s, v23.4s, v0.4s\n" - "fmin v22.4s, v22.4s, v0.4s\n" - "fmin v21.4s, v21.4s, v0.4s\n" - "fmin v20.4s, v20.4s, v0.4s\n" - "fmin v19.4s, v19.4s, v0.4s\n" - "fmin v18.4s, v18.4s, v0.4s\n" - "fmin v17.4s, v17.4s, v0.4s\n" - "fmin v16.4s, v16.4s, v0.4s\n" - "blt 8f\n" - "mov x20, %x[dst]\n" - "str q31, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q30, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q29, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q28, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q27, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q26, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - 
"str q25, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q24, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q23, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q22, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q21, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q20, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q19, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q18, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q17, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q16, [x20, #0x0]\n" - "b 13f\n" - "8:" // Partial output - "mov x28, %x[dst]\n" - "add x26, x28, %x[dst_stride_row], LSL #2\n" - "add x25, x26, %x[dst_stride_row], LSL #1\n" - "add x24, x26, %x[dst_stride_row]\n" - "add x23, x25, %x[dst_stride_row]\n" - "add x22, x28, %x[dst_stride_row], LSL #1\n" - "add x21, x28, %x[dst_stride_row]\n" - "add x20, x22, %x[dst_stride_row]\n" - "add x27, x23, %x[dst_stride_row]\n" - "tbz x10, #1, 9f\n" - "st1 { v24.d }[0], [x23], #0x8\n" - "st1 { v25.d }[0], [x25], #0x8\n" - "st1 { v26.d }[0], [x24], #0x8\n" - "st1 { v27.d }[0], [x26], #0x8\n" - "st1 { v28.d }[0], [x20], #0x8\n" - "st1 { v29.d }[0], [x22], #0x8\n" - "st1 { v30.d }[0], [x21], #0x8\n" - "st1 { v31.d }[0], [x28], #0x8\n" - "tbz x10, #0, 10f\n" - "st1 { v24.s }[2], [x23]\n" - "st1 { v25.s }[2], [x25]\n" - "st1 { v26.s }[2], [x24]\n" - "st1 { v27.s }[2], [x26]\n" - "st1 { v28.s }[2], [x20]\n" - "st1 { v29.s }[2], [x22]\n" - "st1 { v30.s }[2], [x21]\n" - "st1 { v31.s }[2], [x28]\n" - "b 10f\n" - "9:" // Output block 0: partial_1_0 - "st1 { v24.s }[0], [x23]\n" - "st1 { v25.s }[0], [x25]\n" - "st1 { v26.s }[0], [x24]\n" - "st1 { v27.s }[0], [x26]\n" - "st1 { v28.s }[0], [x20]\n" - "st1 { v29.s }[0], [x22]\n" - "st1 { v30.s }[0], [x21]\n" - "st1 { v31.s }[0], [x28]\n" - "10:" // Output block 0: Done - "add x26, x27, %x[dst_stride_row], LSL #2\n" - "add x25, x27, %x[dst_stride_row], LSL #1\n" - "add x24, x26, %x[dst_stride_row], LSL #1\n" - "add x23, x27, %x[dst_stride_row]\n" - "add x22, x25, %x[dst_stride_row]\n" - "add x21, x26, %x[dst_stride_row]\n" - "add x20, x24, %x[dst_stride_row]\n" - "tbz x10, #1, 11f\n" - "st1 { v16.d }[0], [x20], #0x8\n" - "st1 { v17.d }[0], [x24], #0x8\n" - "st1 { v18.d }[0], [x21], #0x8\n" - "st1 { v19.d }[0], [x26], #0x8\n" - "st1 { v20.d }[0], [x22], #0x8\n" - "st1 { v21.d }[0], [x25], #0x8\n" - "st1 { v22.d }[0], [x23], #0x8\n" - "st1 { v23.d }[0], [x27], #0x8\n" - "tbz x10, #0, 12f\n" - "st1 { v16.s }[2], [x20]\n" - "st1 { v17.s }[2], [x24]\n" - "st1 { v18.s }[2], [x21]\n" - "st1 { v19.s }[2], [x26]\n" - "st1 { v20.s }[2], [x22]\n" - "st1 { v21.s }[2], [x25]\n" - "st1 { v22.s }[2], [x23]\n" - "st1 { v23.s }[2], [x27]\n" - "b 12f\n" - "11:" // Output block 1: partial_1_0 - "st1 { v16.s }[0], [x20]\n" - "st1 { v17.s }[0], [x24]\n" - "st1 { v18.s }[0], [x21]\n" - "st1 { v19.s }[0], [x26]\n" - "st1 { v20.s }[0], [x22]\n" - "st1 { v21.s }[0], [x25]\n" - "st1 { v22.s }[0], [x23]\n" - "st1 { v23.s }[0], [x27]\n" - "12:" // Output block 1: Done - "13:" // Output stage exit - "subs x10, x10, #0x4\n" - "add %x[dst], %x[dst], #0x10\n" - "bgt 2b\n" - "mov x20, #0x4\n" - "sub x13, x13, #0x10\n" - "cmp x13, #0x10\n" - "mov %x[dst], x9\n" - "madd %x[lhs_packed], x20, x12, %x[lhs_packed]\n" - "bge 1b\n" - "14:" // Row loop skip - "cbz x13, 23f\n" - "15:" // Row tail: Row loop - "mov x26, %x[rhs_packed]\n" - "mov x25, %x[n]\n" - "add x24, %x[dst], %x[dst_stride_row], LSL #2\n" - "16:" // Row 
tail: Column loop - "movi v31.16b, #0x0\n" - "movi v30.16b, #0x0\n" - "mov x27, %x[lhs_packed]\n" - "mov x20, %x[num_blocks]\n" - "movi v29.16b, #0x0\n" - "movi v28.16b, #0x0\n" - "17:" // Row tail: Block loop - "ldr q9, [x26, #0x0]\n" - "ldr q8, [x26, #0x10]\n" - "movi v7.4s, #0x0\n" - "movi v6.4s, #0x0\n" - "ldr q5, [x27, #0x0]\n" - "ldr q4, [x27, #0x10]\n" - "movi v3.4s, #0x0\n" - "movi v2.4s, #0x0\n" - "ldr q1, [x26, #0x20]\n" - "ldr q0, [x26, #0x30]\n" - "movi v27.16b, #0xf0\n" - "add x26, x26, #0x40\n" - "ldr q26, [x27, #0x20]\n" - "ldr q25, [x27, #0x30]\n" - "shl v24.16b, v9.16b, #0x4\n" - "shl v20.16b, v8.16b, #0x4\n" - "ldr q23, [x27, #0x40]\n" - "ldr q22, [x27, #0x50]\n" - "and v9.16b, v9.16b, v27.16b\n" - "and v8.16b, v8.16b, v27.16b\n" - "ldr q21, [x27, #0x60]\n" - "ldr q19, [x27, #0x70]\n" - "shl v18.16b, v1.16b, #0x4\n" - "shl v17.16b, v0.16b, #0x4\n" - "ldr d16, [x26, #0x0]\n" - ".inst 0x4e98a4a7 // smmla v7.4s, v5.16b, v24.16b\n" - ".inst 0x4e94a4a6 // smmla v6.4s, v5.16b, v20.16b\n" - "and v1.16b, v1.16b, v27.16b\n" - ".inst 0x4e98a483 // smmla v3.4s, v4.16b, v24.16b\n" - ".inst 0x4e94a482 // smmla v2.4s, v4.16b, v20.16b\n" - "and v0.16b, v0.16b, v27.16b\n" - "add x26, x26, #0x8\n" - "add x27, x27, #0x80\n" - "shll v20.4s, v16.4h, #0x10\n" - ".inst 0x4e92a747 // smmla v7.4s, v26.16b, v18.16b\n" - ".inst 0x4e91a746 // smmla v6.4s, v26.16b, v17.16b\n" - ".inst 0x4e92a723 // smmla v3.4s, v25.16b, v18.16b\n" - ".inst 0x4e91a722 // smmla v2.4s, v25.16b, v17.16b\n" - ".inst 0x4e89a6e7 // smmla v7.4s, v23.16b, v9.16b\n" - ".inst 0x4e88a6e6 // smmla v6.4s, v23.16b, v8.16b\n" - ".inst 0x4e89a6c3 // smmla v3.4s, v22.16b, v9.16b\n" - ".inst 0x4e88a6c2 // smmla v2.4s, v22.16b, v8.16b\n" - ".inst 0x4e81a6a7 // smmla v7.4s, v21.16b, v1.16b\n" - ".inst 0x4e80a6a6 // smmla v6.4s, v21.16b, v0.16b\n" - ".inst 0x4e81a663 // smmla v3.4s, v19.16b, v1.16b\n" - ".inst 0x4e80a662 // smmla v2.4s, v19.16b, v0.16b\n" - "uzp1 v19.2d, v7.2d, v6.2d\n" - "uzp2 v18.2d, v7.2d, v6.2d\n" - "scvtf v19.4s, v19.4s, #0x4\n" - "uzp1 v17.2d, v3.2d, v2.2d\n" - "uzp2 v16.2d, v3.2d, v2.2d\n" - "scvtf v18.4s, v18.4s, #0x4\n" - "fmla v31.4s, v19.4s, v20.4s\n" - "scvtf v17.4s, v17.4s, #0x4\n" - "scvtf v16.4s, v16.4s, #0x4\n" - "fmla v30.4s, v18.4s, v20.4s\n" - "fmla v29.4s, v17.4s, v20.4s\n" - "fmla v28.4s, v16.4s, v20.4s\n" - "subs x20, x20, #0x1\n" - "bgt 17b\n" - "ld1 { v21.4s }, [x27]\n" - "ldr q20, [x26, #0x0]\n" - "add x27, x27, #0x10\n" - "add x20, %x[clamp_vals], #0x4\n" - "ldr q19, [x27, #0x0]\n" - "ldr q18, [x26, #0x10]\n" - "cmp x25, #0x4\n" - "add x26, x26, #0x20\n" - "ld1r { v17.4s }, [%x[clamp_vals]]\n" - "ld1r { v16.4s }, [x20]\n" - "scvtf v21.4s, v21.4s\n" - "fmla v31.4s, v20.4s, v21.s[0]\n" - "fmla v30.4s, v20.4s, v21.s[1]\n" - "fmla v29.4s, v20.4s, v21.s[2]\n" - "fmla v28.4s, v20.4s, v21.s[3]\n" - "fmul v31.4s, v31.4s, v19.s[0]\n" - "fmul v30.4s, v30.4s, v19.s[1]\n" - "fmul v29.4s, v29.4s, v19.s[2]\n" - "fadd v31.4s, v31.4s, v18.4s\n" - "fmul v28.4s, v28.4s, v19.s[3]\n" - "fadd v30.4s, v30.4s, v18.4s\n" - "fadd v29.4s, v29.4s, v18.4s\n" - "fadd v28.4s, v28.4s, v18.4s\n" - "fmax v31.4s, v31.4s, v17.4s\n" - "fmax v30.4s, v30.4s, v17.4s\n" - "fmax v29.4s, v29.4s, v17.4s\n" - "fmax v28.4s, v28.4s, v17.4s\n" - "fmin v31.4s, v31.4s, v16.4s\n" - "fmin v30.4s, v30.4s, v16.4s\n" - "fmin v29.4s, v29.4s, v16.4s\n" - "fmin v28.4s, v28.4s, v16.4s\n" - "blt 19f\n" - "mov x20, %x[dst]\n" - "cmp x13, #0x1\n" - "str q31, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "ble 22f\n" - "cmp x13, #0x2\n" - "str 
q30, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "ble 22f\n" - "cmp x13, #0x3\n" - "str q29, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "ble 22f\n" - "str q28, [x20, #0x0]\n" - "b 22f\n" - "19:" // Row tail: Partial output - "mov x23, %x[dst]\n" - "cmp x13, #0x1\n" - "add x22, x23, %x[dst_stride_row]\n" - "csel x22, x22, x23, GT\n" - "cmp x13, #0x2\n" - "add x21, x23, %x[dst_stride_row], LSL #1\n" - "csel x21, x21, x22, GT\n" - "cmp x13, #0x3\n" - "add x20, x21, %x[dst_stride_row]\n" - "csel x20, x20, x21, GT\n" - "tbz x25, #1, 20f\n" - "st1 { v28.d }[0], [x20], #0x8\n" - "st1 { v29.d }[0], [x21], #0x8\n" - "st1 { v30.d }[0], [x22], #0x8\n" - "st1 { v31.d }[0], [x23], #0x8\n" - "tbz x25, #0, 21f\n" - "st1 { v28.s }[2], [x20]\n" - "st1 { v29.s }[2], [x21]\n" - "st1 { v30.s }[2], [x22]\n" - "st1 { v31.s }[2], [x23]\n" - "b 21f\n" - "20:" // Row tail: Output block 0: partial_1_0 - "st1 { v28.s }[0], [x20]\n" - "st1 { v29.s }[0], [x21]\n" - "st1 { v30.s }[0], [x22]\n" - "st1 { v31.s }[0], [x23]\n" - "21:" // Row tail: Output block 0: Done - "22:" // Row tail: Output stage exit - "subs x25, x25, #0x4\n" - "add %x[dst], %x[dst], #0x10\n" - "bgt 16b\n" - "subs x13, x13, #0x4\n" - "add %x[lhs_packed], %x[lhs_packed], x12\n" - "mov %x[dst], x24\n" - "bgt 15b\n" - "23:" // Row tail: Row loop skip - : [dst] "+&r"(dst), [lhs_packed] "+&r"(lhs_packed) - : [clamp_vals] "r"(clamp_vals), [dst_stride_row] "r"(dst_stride_row), [m] "r"(m), [n] "r"(n), - [num_blocks] "r"(num_blocks), [rhs_packed] "r"(rhs_packed) - : "cc", "memory", "v0", "v1", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v2", "v20", - "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v3", "v30", "v31", "v4", "v5", "v6", "v7", - "v8", "v9", "x10", "x11", "x12", "x13", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "x9"); -} - -inline static void kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( +void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { KAI_ASSERT((bl % kai_kr) == 0); @@ -736,6 +116,7 @@ inline static void kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( if (m == 0) { return; } + size_t num_subblocks = bl / kai_bl_multiple_of; size_t num_blocks = kai_num_blocks_per_row(k, bl); @@ -1164,23 +545,4 @@ inline static void kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x11", "x12", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"); } - -void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( - size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed, float* dst, - size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { - KAI_ASSERT((bl % kai_kr) == 0); - KAI_ASSERT((bl % kai_bl_multiple_of) == 0); - KAI_ASSERT(dst_stride_col == sizeof(float)); - - if (m == 0) { - return; - } - if (m >= 16 && bl == 32) { - kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm( - m, n, k, bl, lhs_packed, rhs_packed, dst, dst_stride_row, dst_stride_col, scalar_min, scalar_max); - } else { - kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( - m, n, k, bl, lhs_packed, rhs_packed, dst, dst_stride_row, 
dst_stride_col, scalar_min, scalar_max); - } -} #endif // Architectural feature check -- GitLab From 373fd3b9b9defdddafb706e6655967663c7da0c9 Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Mon, 30 Sep 2024 14:48:01 +0100 Subject: [PATCH 3/5] Add unit tests and examples for the new 16x4x32 ukernels Signed-off-by: Anitha Raj --- .../matmul_clamp_f32_qai8dxp_qsi4c32p.cpp | 13 ++++++++++++ ...matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp | 20 ++++++++++--------- 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp index 87bdad8f..8568d77f 100644 --- a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp +++ b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp @@ -19,6 +19,7 @@ #include "kai_lhs_quant_pack_qai8dxp_f32.h" #include "kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h" #include "kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h" +#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.h" #include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h" #include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h" #include "kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h" @@ -76,6 +77,18 @@ kai_matmul_ukernel_f32_qa8dxp_qs4c32p ukernel_variants[] = { kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm}, "matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm"}, + {{kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm, + kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm, + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm, + kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm, + kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm, + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm, + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm, + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm}, + "matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm"}, {{kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, diff --git a/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp b/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp index 9ca4aee6..39b62bb8 100644 --- a/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp +++ b/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp @@ -16,6 +16,7 @@ #include "kai/kai_common.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.h" #include 
"kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h" @@ -40,18 +41,23 @@ namespace kai::test { static auto cpu_has_dotprod_and_bf16 = []() { return cpu_has_dotprod() && cpu_has_bf16(); }; static auto cpu_has_i8mm_and_bf16 = []() { return cpu_has_i8mm() && cpu_has_bf16(); }; -static const std::array, 4> +static const std::array, 5> variants_kai_matmul_clamp_f32_qai8dxp_qsi4c32p = {{ UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, cpu_has_dotprod_and_bf16), UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, cpu_has_dotprod_and_bf16), UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, cpu_has_i8mm_and_bf16), UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, cpu_has_i8mm_and_bf16), + UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm, cpu_has_i8mm_and_bf16), }}; -class MatMulTest_f32_qmatmul_clamp_f32_qai8dxp_qsi4c32p : public UkernelVariantTest {}; +using MatMulTestParams_withBL = std::tuple; + +class UkernelVariantTest_withBL : public ::testing::TestWithParam {}; + +class MatMulTest_f32_qmatmul_clamp_f32_qai8dxp_qsi4c32p : public UkernelVariantTest_withBL {}; TEST_P(MatMulTest_f32_qmatmul_clamp_f32_qai8dxp_qsi4c32p, EndToEnd_RHS_Transposed) { - const auto& [variant_index, matmul_shape] = GetParam(); + const auto& [variant_index, matmul_shape, bl] = GetParam(); const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi4c32p.at(variant_index); if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) { @@ -69,8 +75,6 @@ TEST_P(MatMulTest_f32_qmatmul_clamp_f32_qai8dxp_qsi4c32p, EndToEnd_RHS_Transpose const auto kr = ukernel_variant.interface.get_kr(); const auto sr = ukernel_variant.interface.get_sr(); - constexpr size_t bl = 32; - // Generates input data. const auto ref_lhs = fill_random(M * K, seed + 0); const auto ref_rhs = fill_random(N * K, seed + 1); @@ -136,7 +140,7 @@ TEST_P(MatMulTest_f32_qmatmul_clamp_f32_qai8dxp_qsi4c32p, EndToEnd_RHS_Transpose } TEST_P(MatMulTest_f32_qmatmul_clamp_f32_qai8dxp_qsi4c32p, EndToEnd_RHS_NonTransposed) { - const auto& [variant_index, matmul_shape] = GetParam(); + const auto& [variant_index, matmul_shape, bl] = GetParam(); const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi4c32p.at(variant_index); if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) { @@ -154,8 +158,6 @@ TEST_P(MatMulTest_f32_qmatmul_clamp_f32_qai8dxp_qsi4c32p, EndToEnd_RHS_NonTransp const auto kr = ukernel_variant.interface.get_kr(); const auto sr = ukernel_variant.interface.get_sr(); - constexpr size_t bl = 32; - // Generates input data. 
const auto ref_lhs = fill_random(M * K, seed + 0); const auto ref_rhs_transposed = fill_random(N * K, seed + 1); @@ -224,6 +226,6 @@ INSTANTIATE_TEST_SUITE_P( MatMul, MatMulTest_f32_qmatmul_clamp_f32_qai8dxp_qsi4c32p, testing::Combine( testing::Range(0, variants_kai_matmul_clamp_f32_qai8dxp_qsi4c32p.size()), - testing::Values(MatMulShape{16, 32, 64}))); + testing::Values(MatMulShape{16, 32, 64}, MatMulShape{8, 32, 64}), testing::Values(32, 64))); } // namespace kai::test -- GitLab From f854e00664f634a2f59c3da3060a7ee7f436dbd6 Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Wed, 2 Oct 2024 15:46:04 +0100 Subject: [PATCH 4/5] Add more matmul shapes in the example Signed-off-by: Anitha Raj --- .../matmul_clamp_f32_qai8dxp_qsi4c32p.cpp | 378 +++++++++--------- 1 file changed, 194 insertions(+), 184 deletions(-) diff --git a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp index 8568d77f..8217562c 100644 --- a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp +++ b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp @@ -33,7 +33,13 @@ enum class rhs_format { nxk, kxn, }; - +struct mnk { + size_t m = 0; + size_t n = 0; + size_t k = 0; + size_t bl = 0; +}; +mnk matmul_shapes[] = {{13, 33, 32, 32}, {37, 75, 256, 64}, {16, 32, 64, 32}, {8, 32, 64, 64}}; // Micro-kernel interface struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p { kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel ukernel; @@ -533,203 +539,207 @@ static bool is_output_correct(size_t num_rows, size_t num_cols, float tolerance, } int main() { - const size_t m = 37; - const size_t n = 75; - const size_t k = 256; - const size_t bl = 64; + const size_t num_shapes = sizeof(matmul_shapes) / sizeof(matmul_shapes[0]); const size_t seed_lhs = 4568; const size_t seed_rhs = seed_lhs + 4; std::cout << "------------" << std::endl; + for (size_t test_idx = 0; test_idx < num_shapes; ++test_idx) { + size_t m = matmul_shapes[test_idx].m; + size_t n = matmul_shapes[test_idx].n; + size_t k = matmul_shapes[test_idx].k; + size_t bl = matmul_shapes[test_idx].bl; + + printf("\nTEST[%ld, %ld, %ld], with Block Size %ld \n", m, n, k, bl); + // Iterate over the RHS format (NxK or KxN) + for (const rhs_format& format : {rhs_format::nxk, rhs_format::kxn}) { + std::cout << "Testing RHS format = " << (format == rhs_format::nxk ? "N x K" : "K x N") << std::endl; + + const size_t lhs_native_size_f32 = m * k * sizeof(float); + const size_t rhs_native_size_f32 = n * k * sizeof(float); + const size_t rhs_native_size_qs4c32 = + format == rhs_format::nxk ? 
n * get_rhs_native_stride(k) : k * get_rhs_native_stride(n); + const size_t rhs_scales_size_bf16 = n * get_rhs_scale_stride(k, bl); + + // Allocate the memory + uint8_t* lhs_native_mtx_f32 = new uint8_t[lhs_native_size_f32]; + uint8_t* rhs_native_mtx_f32 = new uint8_t[rhs_native_size_f32]; + uint8_t* rhs_native_mtx_qs4c32 = new uint8_t[rhs_native_size_qs4c32]; + uint8_t* rhs_scales_mtx_bf16 = new uint8_t[rhs_scales_size_bf16]; + + fill_uniform_random(m, k, (float*)lhs_native_mtx_f32, seed_lhs); + fill_uniform_random(n, k, (float*)rhs_native_mtx_f32, seed_rhs); + + quant_qs4cx_f32( + n, k, bl, // Dimensions + format, // Format (NxK or KxN) + (const float*)rhs_native_mtx_f32, // RHS (F32) + rhs_native_mtx_qs4c32, // RHS (QS4C32) + (uint16_t*)rhs_scales_mtx_bf16); // Scales (Bf16) + + delete[] rhs_native_mtx_f32; + + //----------- REFERENCE IMPLEMENTATION + //------------------------------------ + //------------------------------------ + // Memory sizes for the reference implementation + // After dynamically quantized the LHS matrix, we have the scale and offset for each + // row. The scale (f32) and offset (int32) are stored at the beginning of each row + const size_t lhs_ref_size_qa8dx = m * (k + sizeof(int32_t) + sizeof(float)); + const size_t dst_ref_size_f32 = m * n * sizeof(float); + + uint8_t* lhs_ref_mtx_qa8dx = new uint8_t[lhs_ref_size_qa8dx]; + uint8_t* dst_ref_mtx_f32 = new uint8_t[dst_ref_size_f32]; + + ref_quant_qa8dx_f32(m, k, (const float*)lhs_native_mtx_f32, (int8_t*)lhs_ref_mtx_qa8dx); + + ref_matmul_f32_qa8dx_qs4c32( + m, n, k, // Dimensions + bl, // Block length + format, // Format (NxK or KxN) + (const int8_t*)lhs_ref_mtx_qa8dx, // LHS + (const uint8_t*)rhs_native_mtx_qs4c32, // RHS + (const uint16_t*)rhs_scales_mtx_bf16, // Scale + (float*)dst_ref_mtx_f32, // DST + -FLT_MAX, FLT_MAX); // Min and max for the clamp operation + + // Remove the unnecessary buffer + delete[] lhs_ref_mtx_qa8dx; + + //----------- END REFERENCE IMPLEMENTATION + //------------------------------------ + //------------------------------------ + + //----------- MICRO-KERNELS TESTS + //------------------------------------ + //------------------------------------ + for (size_t idx_variant = 0; idx_variant < num_ukernel_variants; ++idx_variant) { + // Get the packing parameters + const size_t mr = ukernel_variants[idx_variant].ukernel.get_mr(); + const size_t nr = ukernel_variants[idx_variant].ukernel.get_nr(); + const size_t kr = ukernel_variants[idx_variant].ukernel.get_kr(); + const size_t sr = ukernel_variants[idx_variant].ukernel.get_sr(); + + // Get the size in bytes for the packed matrices + const size_t lhs_packed_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(m, k, mr, kr, sr); + size_t rhs_packed_size = 0; + + if (format == rhs_format::nxk) { + rhs_packed_size = + kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(n, k, nr, kr, sr, bl, kai_dt_bf16); - // Iterate over the RHS format (NxK or KxN) - for (const rhs_format& format : {rhs_format::nxk, rhs_format::kxn}) { - std::cout << "Testing RHS format = " << (format == rhs_format::nxk ? "N x K" : "K x N") << std::endl; - - const size_t lhs_native_size_f32 = m * k * sizeof(float); - const size_t rhs_native_size_f32 = n * k * sizeof(float); - const size_t rhs_native_size_qs4c32 = - format == rhs_format::nxk ? 
-            n * get_rhs_native_stride(k) : k * get_rhs_native_stride(n);
-        const size_t rhs_scales_size_bf16 = n * get_rhs_scale_stride(k, bl);
-
-        // Allocate the memory
-        uint8_t* lhs_native_mtx_f32 = new uint8_t[lhs_native_size_f32];
-        uint8_t* rhs_native_mtx_f32 = new uint8_t[rhs_native_size_f32];
-        uint8_t* rhs_native_mtx_qs4c32 = new uint8_t[rhs_native_size_qs4c32];
-        uint8_t* rhs_scales_mtx_bf16 = new uint8_t[rhs_scales_size_bf16];
-
-        fill_uniform_random(m, k, (float*)lhs_native_mtx_f32, seed_lhs);
-        fill_uniform_random(n, k, (float*)rhs_native_mtx_f32, seed_rhs);
-
-        quant_qs4cx_f32(
-            n, k, bl,  // Dimensions
-            format,  // Format (NxK or KxN)
-            (const float*)rhs_native_mtx_f32,  // RHS (F32)
-            rhs_native_mtx_qs4c32,  // RHS (QS4C32)
-            (uint16_t*)rhs_scales_mtx_bf16);  // Scales (Bf16)
-
-        delete[] rhs_native_mtx_f32;
-
-        //----------- REFERENCE IMPLEMENTATION
-        //------------------------------------
-        //------------------------------------
-        // Memory sizes for the reference implementation
-        // After dynamically quantized the LHS matrix, we have the scale and offset for each
-        // row. The scale (f32) and offset (int32) are stored at the beginning of each row
-        const size_t lhs_ref_size_qa8dx = m * (k + sizeof(int32_t) + sizeof(float));
-        const size_t dst_ref_size_f32 = m * n * sizeof(float);
-
-        uint8_t* lhs_ref_mtx_qa8dx = new uint8_t[lhs_ref_size_qa8dx];
-        uint8_t* dst_ref_mtx_f32 = new uint8_t[dst_ref_size_f32];
-
-        ref_quant_qa8dx_f32(m, k, (const float*)lhs_native_mtx_f32, (int8_t*)lhs_ref_mtx_qa8dx);
-
-        ref_matmul_f32_qa8dx_qs4c32(
-            m, n, k,  // Dimensions
-            bl,  // Block length
-            format,  // Format (NxK or KxN)
-            (const int8_t*)lhs_ref_mtx_qa8dx,  // LHS
-            (const uint8_t*)rhs_native_mtx_qs4c32,  // RHS
-            (const uint16_t*)rhs_scales_mtx_bf16,  // Scale
-            (float*)dst_ref_mtx_f32,  // DST
-            -FLT_MAX, FLT_MAX);  // Min and max for the clamp operation
-
-        // Remove the unnecessary buffer
-        delete[] lhs_ref_mtx_qa8dx;
-
-        //----------- END REFERENCE IMPLEMENTATION
-        //------------------------------------
-        //------------------------------------
-
-        //----------- MICRO-KERNELS TESTS
-        //------------------------------------
-        //------------------------------------
-        for (size_t idx_variant = 0; idx_variant < num_ukernel_variants; ++idx_variant) {
-            // Get the packing parameters
-            const size_t mr = ukernel_variants[idx_variant].ukernel.get_mr();
-            const size_t nr = ukernel_variants[idx_variant].ukernel.get_nr();
-            const size_t kr = ukernel_variants[idx_variant].ukernel.get_kr();
-            const size_t sr = ukernel_variants[idx_variant].ukernel.get_sr();
-
-            // Get the size in bytes for the packed matrices
-            const size_t lhs_packed_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(m, k, mr, kr, sr);
-            size_t rhs_packed_size = 0;
-
-            if (format == rhs_format::nxk) {
-                rhs_packed_size =
-                    kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(n, k, nr, kr, sr, bl, kai_dt_bf16);
-
-            } else {
-                rhs_packed_size =
-                    kai_get_rhs_packed_size_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(n, k, nr, kr, sr, bl, kai_dt_bf16);
-            }
+                } else {
+                    rhs_packed_size =
+                        kai_get_rhs_packed_size_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(n, k, nr, kr, sr, bl, kai_dt_bf16);
+                }
-            const size_t dst_size = ukernel_variants[idx_variant].ukernel.get_dst_size(m, n);
-
-            // Allocate the matrices
-            uint8_t* lhs_packed_mtx_qa8dx = new uint8_t[lhs_packed_size];
-            uint8_t* rhs_packed_mtx_qs4c32 = new uint8_t[rhs_packed_size];
-            uint8_t* dst_act_mtx_f32 = new uint8_t[dst_size];
-
-            memset(dst_act_mtx_f32, 0, dst_size);
-
-            // If the RHS matrix contains constant values, the packing can be performed
-            // only once
-            if (format == rhs_format::nxk) {
-                struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params;
-                params.lhs_zero_point = 1;
-                params.rhs_zero_point = 8;
-                params.scale_dt = kai_dt_bf16;
-
-                kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
-                    1, n, k,  // Dimensions
-                    nr, kr, sr,  // Packing arguments
-                    bl,  // Block length
-                    (const uint8_t*)(rhs_native_mtx_qs4c32),  // RHS
-                    get_rhs_native_stride(k),  // RHS stride
-                    NULL,  // Bias
-                    rhs_scales_mtx_bf16,  // Scale
-                    get_rhs_scale_stride(k, bl),  // Scale stride
-                    rhs_packed_mtx_qs4c32,  // RHS packed
-                    0, &params);
-
-            } else {
-                struct kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_params params;
-                params.lhs_zero_point = 1;
-                params.rhs_zero_point = 8;
-                params.scale_dt = kai_dt_bf16;
-
-                kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(
-                    1, n, k,  // Dimensions
-                    nr, kr, sr,  // Packing arguments
-                    bl,  // Block length
-                    (const uint8_t*)(rhs_native_mtx_qs4c32),  // RHS
-                    get_rhs_native_stride(n),  // RHS stride
-                    NULL,  // Bias
-                    rhs_scales_mtx_bf16,  // Scale
-                    get_rhs_scale_stride(k, bl),  // Scale stride
-                    rhs_packed_mtx_qs4c32,  // RHS packed
-                    0, &params);
-            }
+                const size_t dst_size = ukernel_variants[idx_variant].ukernel.get_dst_size(m, n);
+
+                // Allocate the matrices
+                uint8_t* lhs_packed_mtx_qa8dx = new uint8_t[lhs_packed_size];
+                uint8_t* rhs_packed_mtx_qs4c32 = new uint8_t[rhs_packed_size];
+                uint8_t* dst_act_mtx_f32 = new uint8_t[dst_size];
+
+                memset(dst_act_mtx_f32, 0, dst_size);
+
+                // If the RHS matrix contains constant values, the packing can be performed
+                // only once
+                if (format == rhs_format::nxk) {
+                    struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params;
+                    params.lhs_zero_point = 1;
+                    params.rhs_zero_point = 8;
+                    params.scale_dt = kai_dt_bf16;
+
+                    kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
+                        1, n, k,  // Dimensions
+                        nr, kr, sr,  // Packing arguments
+                        bl,  // Block length
+                        (const uint8_t*)(rhs_native_mtx_qs4c32),  // RHS
+                        get_rhs_native_stride(k),  // RHS stride
+                        NULL,  // Bias
+                        rhs_scales_mtx_bf16,  // Scale
+                        get_rhs_scale_stride(k, bl),  // Scale stride
+                        rhs_packed_mtx_qs4c32,  // RHS packed
+                        0, &params);
-            const auto time_s = std::chrono::high_resolution_clock::now();
-
-            // LHS packing
-            kai_run_lhs_quant_pack_qai8dxp_f32(
-                m, k,  // Dimensions
-                mr, kr, sr, 0,  // Packing arguments
-                (const float*)lhs_native_mtx_f32,  // LHS
-                k * sizeof(float),  // LHS stride
-                lhs_packed_mtx_qa8dx);  // LHS packed
-
-            // Matmul
-            {
-                const size_t dst_stride = n * sizeof(float);
-                const size_t lhs_offset = ukernel_variants[idx_variant].ukernel.get_lhs_packed_offset(0, k);
-                const size_t rhs_offset = ukernel_variants[idx_variant].ukernel.get_rhs_packed_offset(0, k, bl);
-                const size_t dst_offset = ukernel_variants[idx_variant].ukernel.get_dst_offset(0, 0, dst_stride);
-
-                const void* lhs_ptr = (const void*)((const char*)lhs_packed_mtx_qa8dx + lhs_offset);
-                const void* rhs_ptr = (const void*)((const char*)rhs_packed_mtx_qs4c32 + rhs_offset);
-                float* dst_ptr = (float*)((uint8_t*)dst_act_mtx_f32 + dst_offset);
-
-                ukernel_variants[idx_variant].ukernel.run_matmul(
-                    m, n, k,  // Dimensions
-                    bl,  // Block length
-                    lhs_ptr,  // LHS packed
-                    rhs_ptr,  // RHS packed
-                    dst_ptr,  // DST
-                    dst_stride,  // DST stride (row)
-                    sizeof(float),  // DST stride (col)
-                    -FLT_MAX, FLT_MAX  // Min and max for the clamp operation
-                );
-            }
+                } else {
+                    struct kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_params params;
+                    params.lhs_zero_point = 1;
+                    params.rhs_zero_point = 8;
+                    params.scale_dt = kai_dt_bf16;
+
+                    kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(
+                        1, n, k,  // Dimensions
+                        nr, kr, sr,  // Packing arguments
+                        bl,  // Block length
+                        (const uint8_t*)(rhs_native_mtx_qs4c32),  // RHS
+                        get_rhs_native_stride(n),  // RHS stride
+                        NULL,  // Bias
+                        rhs_scales_mtx_bf16,  // Scale
+                        get_rhs_scale_stride(k, bl),  // Scale stride
+                        rhs_packed_mtx_qs4c32,  // RHS packed
+                        0, &params);
+                }
+
+                const auto time_s = std::chrono::high_resolution_clock::now();
+
+                // LHS packing
+                kai_run_lhs_quant_pack_qai8dxp_f32(
+                    m, k,  // Dimensions
+                    mr, kr, sr, 0,  // Packing arguments
+                    (const float*)lhs_native_mtx_f32,  // LHS
+                    k * sizeof(float),  // LHS stride
+                    lhs_packed_mtx_qa8dx);  // LHS packed
+
+                // Matmul
+                {
+                    const size_t dst_stride = n * sizeof(float);
+                    const size_t lhs_offset = ukernel_variants[idx_variant].ukernel.get_lhs_packed_offset(0, k);
+                    const size_t rhs_offset = ukernel_variants[idx_variant].ukernel.get_rhs_packed_offset(0, k, bl);
+                    const size_t dst_offset = ukernel_variants[idx_variant].ukernel.get_dst_offset(0, 0, dst_stride);
+
+                    const void* lhs_ptr = (const void*)((const char*)lhs_packed_mtx_qa8dx + lhs_offset);
+                    const void* rhs_ptr = (const void*)((const char*)rhs_packed_mtx_qs4c32 + rhs_offset);
+                    float* dst_ptr = (float*)((uint8_t*)dst_act_mtx_f32 + dst_offset);
+
+                    ukernel_variants[idx_variant].ukernel.run_matmul(
+                        m, n, k,  // Dimensions
+                        bl,  // Block length
+                        lhs_ptr,  // LHS packed
+                        rhs_ptr,  // RHS packed
+                        dst_ptr,  // DST
+                        dst_stride,  // DST stride (row)
+                        sizeof(float),  // DST stride (col)
+                        -FLT_MAX, FLT_MAX  // Min and max for the clamp operation
+                    );
+                }
-            const auto time_e = std::chrono::high_resolution_clock::now();
+                const auto time_e = std::chrono::high_resolution_clock::now();
-            const auto elap = std::chrono::duration_cast<std::chrono::microseconds>(time_e - time_s);
+                const auto elap = std::chrono::duration_cast<std::chrono::microseconds>(time_e - time_s);
-            const bool is_valid =
-                is_output_correct(m, n, 0.0001f, (const float*)dst_ref_mtx_f32, (const float*)dst_act_mtx_f32);
+                const bool is_valid =
+                    is_output_correct(m, n, 0.0001f, (const float*)dst_ref_mtx_f32, (const float*)dst_act_mtx_f32);
-            std::cout << "TEST[" << idx_variant << "]: Dynamic quantization + matmul" << std::endl;
-            std::cout << "- ukernel: " << ukernel_variants[idx_variant].name << std::endl;
-            if (is_valid) {
-                std::cout << "- Status: PASSED" << std::endl;
-                std::cout << "- Performance: " << elap.count() << " us" << std::endl;
-            } else {
-                std::cout << "Status: FAILED" << std::endl;
+                std::cout << "TEST[" << idx_variant << "]: Dynamic quantization + matmul" << std::endl;
+                std::cout << "- ukernel: " << ukernel_variants[idx_variant].name << std::endl;
+                if (is_valid) {
+                    std::cout << "- Status: PASSED" << std::endl;
+                    std::cout << "- Performance: " << elap.count() << " us" << std::endl;
+                } else {
+                    std::cout << "Status: FAILED" << std::endl;
+                }
+                std::cout << "------------" << std::endl;
+                delete[] lhs_packed_mtx_qa8dx;
+                delete[] rhs_packed_mtx_qs4c32;
+                delete[] dst_act_mtx_f32;
             }
-            std::cout << "------------" << std::endl;
-            delete[] lhs_packed_mtx_qa8dx;
-            delete[] rhs_packed_mtx_qs4c32;
-            delete[] dst_act_mtx_f32;
+            delete[] lhs_native_mtx_f32;
+            delete[] rhs_native_mtx_qs4c32;
+            delete[] rhs_scales_mtx_bf16;
+            delete[] dst_ref_mtx_f32;
         }
-        delete[] lhs_native_mtx_f32;
-        delete[] rhs_native_mtx_qs4c32;
-        delete[] rhs_scales_mtx_bf16;
-        delete[] dst_ref_mtx_f32;
     }
 }
-- 
GitLab

From 3e3e04d435a52d98a9c1da265d208d7f6b2ed6a5 Mon Sep 17 00:00:00 2001
From: Anitha Raj
Date: Thu, 3 Oct 2024 11:57:12 +0100
Subject: [PATCH 5/5] Address review comments

Signed-off-by: Anitha Raj
---
 .../matmul_clamp_f32_qai8dxp_qsi4c32p.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp
index 8217562c..03f8b332 100644
--- a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp
+++ b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp
@@ -539,7 +539,7 @@ static bool is_output_correct(size_t num_rows, size_t num_cols, float tolerance,
 }
 
 int main() {
-    const size_t num_shapes = sizeof(matmul_shapes) / sizeof(matmul_shapes[0]);
+    const size_t num_shapes = std::size(matmul_shapes);
     const size_t seed_lhs = 4568;
     const size_t seed_rhs = seed_lhs + 4;
 
@@ -551,7 +551,7 @@ int main() {
         size_t k = matmul_shapes[test_idx].k;
         size_t bl = matmul_shapes[test_idx].bl;
 
-        printf("\nTEST[%ld, %ld, %ld], with Block Size %ld \n", m, n, k, bl);
+        std::cout << "\nTEST[" << m << ", " << n << "," << k << "] with Block Size " << bl << "\n";
         // Iterate over the RHS format (NxK or KxN)
         for (const rhs_format& format : {rhs_format::nxk, rhs_format::kxn}) {
             std::cout << "Testing RHS format = " << (format == rhs_format::nxk ? "N x K" : "K x N") << std::endl;
@@ -646,7 +646,7 @@ int main() {
                 // If the RHS matrix contains constant values, the packing can be performed
                 // only once
                 if (format == rhs_format::nxk) {
-                    struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params;
+                    kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params;
                     params.lhs_zero_point = 1;
                     params.rhs_zero_point = 8;
                     params.scale_dt = kai_dt_bf16;
@@ -664,7 +664,7 @@ int main() {
                         0, &params);
 
                 } else {
-                    struct kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_params params;
+                    kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_params params;
                     params.lhs_zero_point = 1;
                     params.rhs_zero_point = 8;
                     params.scale_dt = kai_dt_bf16;
-- 
GitLab
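
Editor's note: for readers skimming the series, below is a minimal standalone sketch of the shape-table pattern that PATCH 4/5 introduces and the std::size() cleanup from PATCH 5/5. It is not part of the patches; run_one_case() is a hypothetical stand-in for the per-shape body of main() (packing, running every ukernel variant and checking against the reference).

// Minimal sketch (C++17), illustrative only.
#include <cstddef>
#include <iostream>
#include <iterator>

struct mnk {
    std::size_t m = 0;
    std::size_t n = 0;
    std::size_t k = 0;
    std::size_t bl = 0;
};

// Same shape table as PATCH 4/5: block sizes 32 and 64, matching the bl
// values the updated test suite exercises.
static const mnk matmul_shapes[] = {{13, 33, 32, 32}, {37, 75, 256, 64}, {16, 32, 64, 32}, {8, 32, 64, 64}};

// Hypothetical per-shape driver; the real example does the packing, matmul
// and verification here.
static void run_one_case(const mnk& s) {
    std::cout << "TEST[" << s.m << ", " << s.n << ", " << s.k << "] with Block Size " << s.bl << "\n";
}

int main() {
    // std::size (C++17) replaces the sizeof(a) / sizeof(a[0]) idiom removed in PATCH 5/5.
    for (std::size_t i = 0; i < std::size(matmul_shapes); ++i) {
        run_one_case(matmul_shapes[i]);
    }
    return 0;
}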