From 0a74f65ce87ca231b2dfa725f49e8a0aadbc60a5 Mon Sep 17 00:00:00 2001
From: Gian Marco Iodice
Date: Fri, 25 Oct 2024 17:18:06 +0100
Subject: [PATCH] Integrate ASM matmul micro-kernels for F32 <- QSI8D32 x QSI4C32

- Integrate ASM matmul micro-kernel for the GeMV and GeMM variants
- Refactor the LHS and RHS packing function to load the scale from the beginning of the block
- Add timer in the example for profiling the ukernels

Signed-off-by: Gian Marco Iodice
---
 .../matmul_clamp_f32_qsi8d32p_qsi4c32p.cpp    |  10 +-
 ...8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c | 169 ++---
 ...qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c | 602 +++++++++++-------
 .../pack/kai_lhs_quant_pack_qsi8d32p_f32.c    |  11 +-
 ...s_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c |  16 +-
 5 files changed, 488 insertions(+), 320 deletions(-)

diff --git a/examples/matmul_clamp_f32_qsi8d32p_qsi4c32p/matmul_clamp_f32_qsi8d32p_qsi4c32p.cpp b/examples/matmul_clamp_f32_qsi8d32p_qsi4c32p/matmul_clamp_f32_qsi8d32p_qsi4c32p.cpp
index fb455183..39ba7faa 100644
--- a/examples/matmul_clamp_f32_qsi8d32p_qsi4c32p/matmul_clamp_f32_qsi8d32p_qsi4c32p.cpp
+++ b/examples/matmul_clamp_f32_qsi8d32p_qsi4c32p/matmul_clamp_f32_qsi8d32p_qsi4c32p.cpp
@@ -8,6 +8,7 @@
 #else
 #include
 #include
+#include <chrono>
 #include
 #include
 #include
@@ -268,7 +269,7 @@ static bool is_output_correct(size_t num_rows, size_t num_cols, float tolerance,

 int main(int argc, char** argv) {
     const size_t bl = 32; // Block length. It must be 32
-    const size_t m = 13;
+    const size_t m = 71;
     const size_t n = 64;
     const size_t k = 128;
     const size_t seed_lhs = 4568;
@@ -361,6 +362,8 @@ int main(int argc, char** argv) {
         rhs_packed_mtx_qs4c32, // RHS packed
         0, &params);

+    const auto time_s = std::chrono::high_resolution_clock::now();
+
     // LHS packing
     kai_run_lhs_quant_pack_qsi8d32p_f32(
         m, k, bl, // Dimensions
@@ -391,11 +394,16 @@ int main(int argc, char** argv) {
         );
     }

+    const auto time_e = std::chrono::high_resolution_clock::now();
+
+    const auto elap = std::chrono::duration_cast<std::chrono::microseconds>(time_e - time_s);
+
     const bool is_valid =
         is_output_correct(m, n, 0.0001f, (const float*)dst_ref_mtx_f32, (const float*)dst_act_mtx_f32);

     if (is_valid) {
         printf("TEST[%ld] = PASSED\n", idx_variant);
+        std::cout << "- Performance: " << elap.count() << " us" << std::endl;
     } else {
         printf("TEST[%ld] = FAILED\n", idx_variant);
     }
diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c
index e41de7f6..ba7c2136 100644
--- a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c
+++ b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c
@@ -121,88 +121,93 @@ void kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod(
     }

     const size_t num_blocks = k / kai_block_size;
-    const size_t num_cols = n;
-    const size_t num_rows = m;
-    const size_t lhs_packed_stride = kai_lhs_packed_stride(k, bl);
-
-    const int8x16_t nibble_mask = vdupq_n_s8(0xF0);
-
-    const uint8_t* lhs_ptr_start = lhs_packed;
-
-    for (size_t row_idx = 0; row_idx < num_rows; row_idx += kai_m_step) {
-        const uint8_t* rhs_ptr = rhs_packed;
-        for (size_t col_idx = 0; col_idx < num_cols; col_idx += kai_n_step) {
-            // Main f32 accumulator
-            float32x4_t main_acc = vdupq_n_f32(0.0F);
-
-            const uint8_t* lhs_ptr = lhs_ptr_start;
-
-            for (size_t b
= 0; b < num_blocks; b++) { - // Set up RHS - const int8x16_t rhs_raw_vec_0 = vld1q_s8((const int8_t*)rhs_ptr + 0); - const int8x16_t rhs_raw_vec_1 = vld1q_s8((const int8_t*)rhs_ptr + 16); - const int8x16_t rhs_raw_vec_2 = vld1q_s8((const int8_t*)rhs_ptr + 32); - const int8x16_t rhs_raw_vec_3 = vld1q_s8((const int8_t*)rhs_ptr + 48); - - // Low nibble - const int8x16_t rhs_vec_0_0 = vshlq_n_s8(rhs_raw_vec_0, 4); - const int8x16_t rhs_vec_1_0 = vshlq_n_s8(rhs_raw_vec_1, 4); - const int8x16_t rhs_vec_2_0 = vshlq_n_s8(rhs_raw_vec_2, 4); - const int8x16_t rhs_vec_3_0 = vshlq_n_s8(rhs_raw_vec_3, 4); - - // High nibble - const int8x16_t rhs_vec_0_1 = vandq_s8(rhs_raw_vec_0, nibble_mask); - const int8x16_t rhs_vec_1_1 = vandq_s8(rhs_raw_vec_1, nibble_mask); - const int8x16_t rhs_vec_2_1 = vandq_s8(rhs_raw_vec_2, nibble_mask); - const int8x16_t rhs_vec_3_1 = vandq_s8(rhs_raw_vec_3, nibble_mask); - - const int8x16_t lhs_vec_0 = vld1q_s8((const int8_t*)(lhs_ptr + 0)); - const int8x16_t lhs_vec_1 = vld1q_s8((const int8_t*)(lhs_ptr + 16)); - - int32x4_t iacc0011 = vdupq_n_s32(0); - int32x4_t iacc2233 = vdupq_n_s32(0); - - int8x16_t t; - - t = vcombine_s8(vget_low_s8(lhs_vec_0), vget_low_s8(lhs_vec_0)); - iacc0011 = vdotq_s32(iacc0011, rhs_vec_0_0, t); - iacc2233 = vdotq_s32(iacc2233, rhs_vec_1_0, t); - t = vcombine_s8(vget_high_s8(lhs_vec_0), vget_high_s8(lhs_vec_0)); - iacc0011 = vdotq_s32(iacc0011, rhs_vec_2_0, t); - iacc2233 = vdotq_s32(iacc2233, rhs_vec_3_0, t); - t = vcombine_s8(vget_low_s8(lhs_vec_1), vget_low_s8(lhs_vec_1)); - iacc0011 = vdotq_s32(iacc0011, rhs_vec_0_1, t); - iacc2233 = vdotq_s32(iacc2233, rhs_vec_1_1, t); - t = vcombine_s8(vget_high_s8(lhs_vec_1), vget_high_s8(lhs_vec_1)); - iacc0011 = vdotq_s32(iacc0011, rhs_vec_2_1, t); - iacc2233 = vdotq_s32(iacc2233, rhs_vec_3_1, t); - - int32x4_t iacc = vpaddq_s32(iacc0011, iacc2233); - - // RHS scale values - const float16x4_t col_scale_f16 = vld1_f16((const float16_t*)(rhs_ptr + 64)); - const float32x4_t col_scale_f32 = vcvt_f32_f16(col_scale_f16); - - // LHS scale values - const float16x4_t row_scale_f16 = vld1_dup_f16((const float16_t*)(lhs_ptr + 32)); - const float32x4_t row_scale_f32 = vcvt_f32_f16(row_scale_f16); - - lhs_ptr += 34; - rhs_ptr += 72; - - main_acc = vfmaq_f32(main_acc, vcvtq_f32_s32(iacc), vmulq_f32(col_scale_f32, row_scale_f32)); - } - - const float32x4_t vmin_f32 = vdupq_n_f32(scalar_min); - const float32x4_t vmax_f32 = vdupq_n_f32(scalar_max); - - main_acc = vmaxq_f32(main_acc, vmin_f32); - main_acc = vminq_f32(main_acc, vmax_f32); - - vst1q_f32((float*)((uint8_t*)dst + col_idx * sizeof(float) + row_idx * dst_stride_row), main_acc); - } - lhs_ptr_start += lhs_packed_stride; - } + float clamp_vals[2] = {scalar_min, scalar_max}; + + __asm__ __volatile__( + "mov x26, #0x22\n" + "movi v1.16b, #0xf0\n" + "mov x25, %x[m]\n" + "mul x26, %x[num_blocks], x26\n" + "1:" // Row loop + "mov x24, %x[rhs_packed]\n" + "mov x23, %x[n]\n" + "add x22, %x[dst], %x[dst_stride_row]\n" + "2:" // Column loop + "mov x21, %x[lhs_packed]\n" + "movi v0.16b, #0x0\n" + "mov x20, %x[num_blocks]\n" + "3:" // Block loop + "ldr d16, [x24, #0x0]\n" + "ld1r { v31.8h }, [x21]\n" + "add x24, x24, #0x8\n" + "add x21, x21, #0x2\n" + "ldr q30, [x24, #0x0]\n" + "ldr q29, [x24, #0x10]\n" + "movi v28.4s, #0x0\n" + "movi v27.4s, #0x0\n" + "ld1r { v26.2d }, [x21], #0x8\n" + "ldr q25, [x24, #0x20]\n" + "sub x20, x20, #0x1\n" + "ldr q24, [x24, #0x30]\n" + "fcvtl v31.4s, v31.4h\n" + "fcvtl v23.4s, v16.4h\n" + "add x24, x24, #0x40\n" + "ld1r { v22.2d }, [x21], 
#0x8\n" + "shl v21.16b, v30.16b, #0x4\n" + "shl v20.16b, v29.16b, #0x4\n" + "ld1r { v19.2d }, [x21], #0x8\n" + "ld1r { v18.2d }, [x21], #0x8\n" + "shl v17.16b, v25.16b, #0x4\n" + "and v30.16b, v30.16b, v1.16b\n" + "shl v16.16b, v24.16b, #0x4\n" + "and v29.16b, v29.16b, v1.16b\n" + ".inst 0x4e9a96bc // sdot v28.4s, v21.16b, v26.16b\n" + ".inst 0x4e9a969b // sdot v27.4s, v20.16b, v26.16b\n" + "and v25.16b, v25.16b, v1.16b\n" + "and v24.16b, v24.16b, v1.16b\n" + "fmul v23.4s, v23.4s, v31.4s\n" + ".inst 0x4e96963c // sdot v28.4s, v17.16b, v22.16b\n" + ".inst 0x4e96961b // sdot v27.4s, v16.16b, v22.16b\n" + ".inst 0x4e9397dc // sdot v28.4s, v30.16b, v19.16b\n" + ".inst 0x4e9397bb // sdot v27.4s, v29.16b, v19.16b\n" + ".inst 0x4e92973c // sdot v28.4s, v25.16b, v18.16b\n" + ".inst 0x4e92971b // sdot v27.4s, v24.16b, v18.16b\n" + "addp v28.4s, v28.4s, v27.4s\n" + "scvtf v28.4s, v28.4s, #0x4\n" + "fmla v0.4s, v28.4s, v23.4s\n" + "cbnz x20, 3b\n" + "ld1r { v17.4s }, [%x[clamp_vals]]\n" + "add x20, %x[clamp_vals], #0x4\n" + "cmp x23, #0x4\n" + "ld1r { v16.4s }, [x20]\n" + "fmax v0.4s, v0.4s, v17.4s\n" + "fmin v0.4s, v0.4s, v16.4s\n" + "blt 4f\n" + "str q0, [%x[dst], #0x0]\n" + "b 7f\n" + "4:" // Partial output + "mov x20, %x[dst]\n" + "tbz x23, #1, 5f\n" + "st1 { v0.d }[0], [x20], #0x8\n" + "tbz x23, #0, 6f\n" + "st1 { v0.s }[2], [x20]\n" + "b 6f\n" + "5:" // Output block 0: partial_1_0 + "st1 { v0.s }[0], [x20]\n" + "6:" // Output block 0: Done + "7:" // Stores done + "subs x23, x23, #0x4\n" + "add %x[dst], %x[dst], #0x10\n" + "bgt 2b\n" + "subs x25, x25, #0x1\n" + "add %x[lhs_packed], %x[lhs_packed], x26\n" + "mov %x[dst], x22\n" + "bgt 1b\n" + : [dst] "+&r"(dst), [lhs_packed] "+&r"(lhs_packed) + : [clamp_vals] "r"(clamp_vals), [dst_stride_row] "r"(dst_stride_row), [m] "r"(m), [n] "r"(n), + [num_blocks] "r"(num_blocks), [rhs_packed] "r"(rhs_packed) + : "cc", "memory", "v0", "v1", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", + "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23", "x24", "x25", "x26"); } #endif // Architectural feature check diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c index 78c9ec0b..d10a6eda 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c @@ -120,232 +120,384 @@ void kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm( return; } - const size_t lhs_packed_stride = kai_lhs_packed_stride(k, bl); const size_t num_blocks = k / kai_block_size; - const size_t num_cols = n; - const size_t num_rows = m; - - const int8x16_t nibble_mask = vdupq_n_s8(0xF0); - - const uint8_t* lhs_ptr_start = lhs_packed; - - for (size_t row_idx = 0; row_idx < num_rows; row_idx += kai_m_step) { - const size_t step_packed_row = (int32_t)num_rows - (int32_t)row_idx <= 4 ? 
0 : 1; - - const uint8_t* rhs_ptr = rhs_packed; - - for (size_t col_idx = 0; col_idx < num_cols; col_idx += kai_n_step) { - const uint8_t* lhs_ptr = lhs_ptr_start; - - // Main f32 accumulator - float32x4_t main_acc0 = vdupq_n_f32(0.0F); - float32x4_t main_acc1 = vdupq_n_f32(0.0F); - float32x4_t main_acc2 = vdupq_n_f32(0.0F); - float32x4_t main_acc3 = vdupq_n_f32(0.0F); - float32x4_t main_acc4 = vdupq_n_f32(0.0F); - float32x4_t main_acc5 = vdupq_n_f32(0.0F); - float32x4_t main_acc6 = vdupq_n_f32(0.0F); - float32x4_t main_acc7 = vdupq_n_f32(0.0F); - - for (size_t b = 0; b < num_blocks; b++) { - // Set up RHS - const int8x16_t rhs_raw_mat_01_0 = vld1q_s8((const int8_t*)rhs_ptr + 0); - const int8x16_t rhs_raw_mat_23_0 = vld1q_s8((const int8_t*)rhs_ptr + 16); - const int8x16_t rhs_raw_mat_01_1 = vld1q_s8((const int8_t*)rhs_ptr + 32); - const int8x16_t rhs_raw_mat_23_1 = vld1q_s8((const int8_t*)rhs_ptr + 48); - - const float16x4_t col_scale_f16 = vld1_f16((const float16_t*)((const uint8_t*)rhs_ptr + 64)); - const float32x4_t col_scale_f32 = vcvt_f32_f16(col_scale_f16); - - // Low nibble - const int8x16_t rhs_mat_01_0 = vshlq_n_s8(rhs_raw_mat_01_0, 4); - const int8x16_t rhs_mat_23_0 = vshlq_n_s8(rhs_raw_mat_23_0, 4); - const int8x16_t rhs_mat_01_1 = vshlq_n_s8(rhs_raw_mat_01_1, 4); - const int8x16_t rhs_mat_23_1 = vshlq_n_s8(rhs_raw_mat_23_1, 4); - - // High nibble - const int8x16_t rhs_mat_01_2 = vandq_s8(rhs_raw_mat_01_0, nibble_mask); - const int8x16_t rhs_mat_23_2 = vandq_s8(rhs_raw_mat_23_0, nibble_mask); - const int8x16_t rhs_mat_01_3 = vandq_s8(rhs_raw_mat_01_1, nibble_mask); - const int8x16_t rhs_mat_23_3 = vandq_s8(rhs_raw_mat_23_1, nibble_mask); - - // Process LHS in pairs of rows - { - const int8x16_t lhs_mat_01_0 = vld1q_s8((const int8_t*)lhs_ptr + 0); - const int8x16_t lhs_mat_23_0 = vld1q_s8((const int8_t*)lhs_ptr + 16); - const int8x16_t lhs_mat_01_1 = vld1q_s8((const int8_t*)lhs_ptr + 32); - const int8x16_t lhs_mat_23_1 = vld1q_s8((const int8_t*)lhs_ptr + 48); - const int8x16_t lhs_mat_01_2 = vld1q_s8((const int8_t*)lhs_ptr + 64); - const int8x16_t lhs_mat_23_2 = vld1q_s8((const int8_t*)lhs_ptr + 80); - const int8x16_t lhs_mat_01_3 = vld1q_s8((const int8_t*)lhs_ptr + 96); - const int8x16_t lhs_mat_23_3 = vld1q_s8((const int8_t*)lhs_ptr + 112); - - // Do the MMLAs into 2x2 matrices - const int32x4_t iacc_mat_00 = vmmlaq_s32( - vmmlaq_s32( - vmmlaq_s32( - vmmlaq_s32(vdupq_n_s32(0), lhs_mat_01_0, rhs_mat_01_0), lhs_mat_01_1, rhs_mat_01_1), - lhs_mat_01_2, rhs_mat_01_2), - lhs_mat_01_3, rhs_mat_01_3); - const int32x4_t iacc_mat_01 = vmmlaq_s32( - vmmlaq_s32( - vmmlaq_s32( - vmmlaq_s32(vdupq_n_s32(0), lhs_mat_01_0, rhs_mat_23_0), lhs_mat_01_1, rhs_mat_23_1), - lhs_mat_01_2, rhs_mat_23_2), - lhs_mat_01_3, rhs_mat_23_3); - const int32x4_t iacc_mat_10 = vmmlaq_s32( - vmmlaq_s32( - vmmlaq_s32( - vmmlaq_s32(vdupq_n_s32(0), lhs_mat_23_0, rhs_mat_01_0), lhs_mat_23_1, rhs_mat_01_1), - lhs_mat_23_2, rhs_mat_01_2), - lhs_mat_23_3, rhs_mat_01_3); - const int32x4_t iacc_mat_11 = vmmlaq_s32( - vmmlaq_s32( - vmmlaq_s32( - vmmlaq_s32(vdupq_n_s32(0), lhs_mat_23_0, rhs_mat_23_0), lhs_mat_23_1, rhs_mat_23_1), - lhs_mat_23_2, rhs_mat_23_2), - lhs_mat_23_3, rhs_mat_23_3); - - // Straighten out to make 4 row vectors - const int32x4_t iacc_row_0 = vreinterpretq_s32_u64( - vtrn1q_u64(vreinterpretq_u64_s32(iacc_mat_00), vreinterpretq_u64_s32(iacc_mat_01))); - const int32x4_t iacc_row_1 = vreinterpretq_s32_u64( - vtrn2q_u64(vreinterpretq_u64_s32(iacc_mat_00), vreinterpretq_u64_s32(iacc_mat_01))); - const 
int32x4_t iacc_row_2 = vreinterpretq_s32_u64( - vtrn1q_u64(vreinterpretq_u64_s32(iacc_mat_10), vreinterpretq_u64_s32(iacc_mat_11))); - const int32x4_t iacc_row_3 = vreinterpretq_s32_u64( - vtrn2q_u64(vreinterpretq_u64_s32(iacc_mat_10), vreinterpretq_u64_s32(iacc_mat_11))); - - const float16x4_t row_scale_f16 = vld1_f16((const float16_t*)((const uint8_t*)lhs_ptr + 128)); - const float32x4_t row_scale_f32 = vcvt_f32_f16(row_scale_f16); - - main_acc0 = vfmaq_f32( - main_acc0, vcvtq_f32_s32(iacc_row_0), vmulq_laneq_f32(col_scale_f32, row_scale_f32, 0)); - main_acc1 = vfmaq_f32( - main_acc1, vcvtq_f32_s32(iacc_row_1), vmulq_laneq_f32(col_scale_f32, row_scale_f32, 1)); - main_acc2 = vfmaq_f32( - main_acc2, vcvtq_f32_s32(iacc_row_2), vmulq_laneq_f32(col_scale_f32, row_scale_f32, 2)); - main_acc3 = vfmaq_f32( - main_acc3, vcvtq_f32_s32(iacc_row_3), vmulq_laneq_f32(col_scale_f32, row_scale_f32, 3)); - } - - lhs_ptr += step_packed_row * lhs_packed_stride; - - { - const int8x16_t lhs_mat_01_0 = vld1q_s8((const int8_t*)lhs_ptr + 0); - const int8x16_t lhs_mat_23_0 = vld1q_s8((const int8_t*)lhs_ptr + 16); - const int8x16_t lhs_mat_01_1 = vld1q_s8((const int8_t*)lhs_ptr + 32); - const int8x16_t lhs_mat_23_1 = vld1q_s8((const int8_t*)lhs_ptr + 48); - const int8x16_t lhs_mat_01_2 = vld1q_s8((const int8_t*)lhs_ptr + 64); - const int8x16_t lhs_mat_23_2 = vld1q_s8((const int8_t*)lhs_ptr + 80); - const int8x16_t lhs_mat_01_3 = vld1q_s8((const int8_t*)lhs_ptr + 96); - const int8x16_t lhs_mat_23_3 = vld1q_s8((const int8_t*)lhs_ptr + 112); - - // Do the MMLAs into 2x2 matrices - const int32x4_t iacc_mat_00 = vmmlaq_s32( - vmmlaq_s32( - vmmlaq_s32( - vmmlaq_s32(vdupq_n_s32(0), lhs_mat_01_0, rhs_mat_01_0), lhs_mat_01_1, rhs_mat_01_1), - lhs_mat_01_2, rhs_mat_01_2), - lhs_mat_01_3, rhs_mat_01_3); - const int32x4_t iacc_mat_01 = vmmlaq_s32( - vmmlaq_s32( - vmmlaq_s32( - vmmlaq_s32(vdupq_n_s32(0), lhs_mat_01_0, rhs_mat_23_0), lhs_mat_01_1, rhs_mat_23_1), - lhs_mat_01_2, rhs_mat_23_2), - lhs_mat_01_3, rhs_mat_23_3); - const int32x4_t iacc_mat_10 = vmmlaq_s32( - vmmlaq_s32( - vmmlaq_s32( - vmmlaq_s32(vdupq_n_s32(0), lhs_mat_23_0, rhs_mat_01_0), lhs_mat_23_1, rhs_mat_01_1), - lhs_mat_23_2, rhs_mat_01_2), - lhs_mat_23_3, rhs_mat_01_3); - const int32x4_t iacc_mat_11 = vmmlaq_s32( - vmmlaq_s32( - vmmlaq_s32( - vmmlaq_s32(vdupq_n_s32(0), lhs_mat_23_0, rhs_mat_23_0), lhs_mat_23_1, rhs_mat_23_1), - lhs_mat_23_2, rhs_mat_23_2), - lhs_mat_23_3, rhs_mat_23_3); - - // Straighten out to make 4 row vectors - const int32x4_t iacc_row_0 = vreinterpretq_s32_u64( - vtrn1q_u64(vreinterpretq_u64_s32(iacc_mat_00), vreinterpretq_u64_s32(iacc_mat_01))); - const int32x4_t iacc_row_1 = vreinterpretq_s32_u64( - vtrn2q_u64(vreinterpretq_u64_s32(iacc_mat_00), vreinterpretq_u64_s32(iacc_mat_01))); - const int32x4_t iacc_row_2 = vreinterpretq_s32_u64( - vtrn1q_u64(vreinterpretq_u64_s32(iacc_mat_10), vreinterpretq_u64_s32(iacc_mat_11))); - const int32x4_t iacc_row_3 = vreinterpretq_s32_u64( - vtrn2q_u64(vreinterpretq_u64_s32(iacc_mat_10), vreinterpretq_u64_s32(iacc_mat_11))); - - const float16x4_t row_scale_f16 = vld1_f16((const float16_t*)((const uint8_t*)lhs_ptr + 128)); - const float32x4_t row_scale_f32 = vcvt_f32_f16(row_scale_f16); - - main_acc4 = vfmaq_f32( - main_acc4, vcvtq_f32_s32(iacc_row_0), vmulq_laneq_f32(col_scale_f32, row_scale_f32, 0)); - main_acc5 = vfmaq_f32( - main_acc5, vcvtq_f32_s32(iacc_row_1), vmulq_laneq_f32(col_scale_f32, row_scale_f32, 1)); - main_acc6 = vfmaq_f32( - main_acc6, vcvtq_f32_s32(iacc_row_2), 
vmulq_laneq_f32(col_scale_f32, row_scale_f32, 2)); - main_acc7 = vfmaq_f32( - main_acc7, vcvtq_f32_s32(iacc_row_3), vmulq_laneq_f32(col_scale_f32, row_scale_f32, 3)); - } - - lhs_ptr -= step_packed_row * lhs_packed_stride; - - lhs_ptr += 136; - rhs_ptr += 72; - } - - const float32x4_t vmin_f32 = vdupq_n_f32(scalar_min); - const float32x4_t vmax_f32 = vdupq_n_f32(scalar_max); - - main_acc0 = vmaxq_f32(main_acc0, vmin_f32); - main_acc0 = vminq_f32(main_acc0, vmax_f32); - main_acc1 = vmaxq_f32(main_acc1, vmin_f32); - main_acc1 = vminq_f32(main_acc1, vmax_f32); - main_acc2 = vmaxq_f32(main_acc2, vmin_f32); - main_acc2 = vminq_f32(main_acc2, vmax_f32); - main_acc3 = vmaxq_f32(main_acc3, vmin_f32); - main_acc3 = vminq_f32(main_acc3, vmax_f32); - main_acc4 = vmaxq_f32(main_acc4, vmin_f32); - main_acc4 = vminq_f32(main_acc4, vmax_f32); - main_acc5 = vmaxq_f32(main_acc5, vmin_f32); - main_acc5 = vminq_f32(main_acc5, vmax_f32); - main_acc6 = vmaxq_f32(main_acc6, vmin_f32); - main_acc6 = vminq_f32(main_acc6, vmax_f32); - main_acc7 = vmaxq_f32(main_acc7, vmin_f32); - main_acc7 = vminq_f32(main_acc7, vmax_f32); - - // Stores the rows in reverse order to avoid out-of-bound writes. - // Override out-of-bound values with in-bound values - vst1q_f32( - (float*)((uint8_t*)dst + col_idx * sizeof(float) + KAI_MIN(row_idx + 7, m - 1) * dst_stride_row), - main_acc7); - vst1q_f32( - (float*)((uint8_t*)dst + col_idx * sizeof(float) + KAI_MIN(row_idx + 6, m - 1) * dst_stride_row), - main_acc6); - vst1q_f32( - (float*)((uint8_t*)dst + col_idx * sizeof(float) + KAI_MIN(row_idx + 5, m - 1) * dst_stride_row), - main_acc5); - vst1q_f32( - (float*)((uint8_t*)dst + col_idx * sizeof(float) + KAI_MIN(row_idx + 4, m - 1) * dst_stride_row), - main_acc4); - vst1q_f32( - (float*)((uint8_t*)dst + col_idx * sizeof(float) + KAI_MIN(row_idx + 3, m - 1) * dst_stride_row), - main_acc3); - vst1q_f32( - (float*)((uint8_t*)dst + col_idx * sizeof(float) + KAI_MIN(row_idx + 2, m - 1) * dst_stride_row), - main_acc2); - vst1q_f32( - (float*)((uint8_t*)dst + col_idx * sizeof(float) + KAI_MIN(row_idx + 1, m - 1) * dst_stride_row), - main_acc1); - vst1q_f32( - (float*)((uint8_t*)dst + col_idx * sizeof(float) + KAI_MIN(row_idx + 0, m - 1) * dst_stride_row), - main_acc0); - } - - lhs_ptr_start += 2 * lhs_packed_stride; - } + float clamp_vals[2] = {scalar_min, scalar_max}; + + __asm__ __volatile__( + "mov x12, %x[m]\n" + "mov x11, #0x88\n" + "movi v13.16b, #0xf0\n" + "cmp x12, #0x8\n" + "mul x11, %x[num_blocks], x11\n" + "blt 8f\n" + "1:" // Row loop + "mov x10, %x[rhs_packed]\n" + "mov x9, %x[n]\n" + "add x28, %x[dst], %x[dst_stride_row], LSL #3\n" + "2:" // Column loop + "mov x22, %x[lhs_packed]\n" + "movi v1.16b, #0x0\n" + "movi v22.16b, #0x0\n" + "mov x21, %x[num_blocks]\n" + "movi v14.16b, #0x0\n" + "movi v12.16b, #0x0\n" + "movi v15.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "movi v3.16b, #0x0\n" + "movi v2.16b, #0x0\n" + "add x20, x22, x11\n" + "3:" // Block loop + "ldr d11, [x10, #0x0]\n" + "ldr d10, [x22, #0x0]\n" + "add x10, x10, #0x8\n" + "add x22, x22, #0x8\n" + "ldr q25, [x10, #0x0]\n" + "ldr q30, [x10, #0x10]\n" + "movi v6.4s, #0x0\n" + "movi v21.4s, #0x0\n" + "ldr d24, [x20, #0x0]\n" + "ldr q28, [x22, #0x0]\n" + "add x20, x20, #0x8\n" + "movi v9.4s, #0x0\n" + "ldr q4, [x22, #0x10]\n" + "ldr q23, [x20, #0x0]\n" + "movi v0.4s, #0x0\n" + "movi v31.4s, #0x0\n" + "ldr q17, [x20, #0x10]\n" + "ldr q18, [x10, #0x20]\n" + "shl v20.16b, v25.16b, #0x4\n" + "shl v29.16b, v30.16b, #0x4\n" + "ldr q16, [x10, #0x30]\n" + "ldr q26, [x22, #0x20]\n" + 
"movi v7.4s, #0x0\n" + "movi v27.4s, #0x0\n" + "ldr q8, [x22, #0x30]\n" + "ldr q5, [x20, #0x20]\n" + "and v25.16b, v25.16b, v13.16b\n" + "and v30.16b, v30.16b, v13.16b\n" + ".inst 0x4e94a786 // smmla v6.4s, v28.16b, v20.16b\n" + ".inst 0x4e9da795 // smmla v21.4s, v28.16b, v29.16b\n" + "ldr q28, [x20, #0x30]\n" + "fcvtl v11.4s, v11.4h\n" + ".inst 0x4e94a489 // smmla v9.4s, v4.16b, v20.16b\n" + ".inst 0x4e9da480 // smmla v0.4s, v4.16b, v29.16b\n" + "ldr q4, [x22, #0x40]\n" + "fcvtl v10.4s, v10.4h\n" + ".inst 0x4e94a6ff // smmla v31.4s, v23.16b, v20.16b\n" + ".inst 0x4e9da6e7 // smmla v7.4s, v23.16b, v29.16b\n" + "ldr q23, [x22, #0x50]\n" + "fcvtl v24.4s, v24.4h\n" + ".inst 0x4e94a63b // smmla v27.4s, v17.16b, v20.16b\n" + "movi v20.4s, #0x0\n" + "subs x21, x21, #0x1\n" + "add x10, x10, #0x40\n" + ".inst 0x4e9da634 // smmla v20.4s, v17.16b, v29.16b\n" + "ldr q17, [x20, #0x40]\n" + "shl v29.16b, v18.16b, #0x4\n" + "and v18.16b, v18.16b, v13.16b\n" + ".inst 0x4e9da746 // smmla v6.4s, v26.16b, v29.16b\n" + ".inst 0x4e9da509 // smmla v9.4s, v8.16b, v29.16b\n" + ".inst 0x4e9da4bf // smmla v31.4s, v5.16b, v29.16b\n" + ".inst 0x4e9da79b // smmla v27.4s, v28.16b, v29.16b\n" + "ldr q29, [x20, #0x50]\n" + ".inst 0x4e99a486 // smmla v6.4s, v4.16b, v25.16b\n" + ".inst 0x4e99a6e9 // smmla v9.4s, v23.16b, v25.16b\n" + ".inst 0x4e99a63f // smmla v31.4s, v17.16b, v25.16b\n" + ".inst 0x4e99a7bb // smmla v27.4s, v29.16b, v25.16b\n" + "shl v25.16b, v16.16b, #0x4\n" + "and v16.16b, v16.16b, v13.16b\n" + ".inst 0x4e99a755 // smmla v21.4s, v26.16b, v25.16b\n" + "ldr q26, [x22, #0x60]\n" + ".inst 0x4e99a500 // smmla v0.4s, v8.16b, v25.16b\n" + "ldr q8, [x22, #0x70]\n" + "add x22, x22, #0x80\n" + ".inst 0x4e99a4a7 // smmla v7.4s, v5.16b, v25.16b\n" + "ldr q5, [x20, #0x60]\n" + ".inst 0x4e99a794 // smmla v20.4s, v28.16b, v25.16b\n" + "ldr q25, [x20, #0x70]\n" + "fmul v28.4s, v11.4s, v10.s[0]\n" + "add x20, x20, #0x80\n" + ".inst 0x4e92a746 // smmla v6.4s, v26.16b, v18.16b\n" + ".inst 0x4e9ea495 // smmla v21.4s, v4.16b, v30.16b\n" + "fmul v4.4s, v11.4s, v10.s[1]\n" + ".inst 0x4e9ea6e0 // smmla v0.4s, v23.16b, v30.16b\n" + ".inst 0x4e92a509 // smmla v9.4s, v8.16b, v18.16b\n" + "fmul v23.4s, v11.4s, v10.s[2]\n" + ".inst 0x4e9ea627 // smmla v7.4s, v17.16b, v30.16b\n" + ".inst 0x4e92a4bf // smmla v31.4s, v5.16b, v18.16b\n" + "fmul v17.4s, v11.4s, v10.s[3]\n" + ".inst 0x4e9ea7b4 // smmla v20.4s, v29.16b, v30.16b\n" + ".inst 0x4e92a73b // smmla v27.4s, v25.16b, v18.16b\n" + "fmul v30.4s, v11.4s, v24.s[0]\n" + ".inst 0x4e90a755 // smmla v21.4s, v26.16b, v16.16b\n" + "fmul v29.4s, v11.4s, v24.s[1]\n" + ".inst 0x4e90a500 // smmla v0.4s, v8.16b, v16.16b\n" + "fmul v18.4s, v11.4s, v24.s[2]\n" + "fmul v10.4s, v11.4s, v24.s[3]\n" + ".inst 0x4e90a4a7 // smmla v7.4s, v5.16b, v16.16b\n" + ".inst 0x4e90a734 // smmla v20.4s, v25.16b, v16.16b\n" + "uzp1 v26.2d, v6.2d, v21.2d\n" + "uzp2 v6.2d, v6.2d, v21.2d\n" + "uzp1 v24.2d, v9.2d, v0.2d\n" + "uzp2 v16.2d, v9.2d, v0.2d\n" + "uzp1 v8.2d, v31.2d, v7.2d\n" + "uzp2 v11.2d, v31.2d, v7.2d\n" + "scvtf v26.4s, v26.4s, #0x4\n" + "uzp1 v31.2d, v27.2d, v20.2d\n" + "uzp2 v7.2d, v27.2d, v20.2d\n" + "scvtf v6.4s, v6.4s, #0x4\n" + "scvtf v24.4s, v24.4s, #0x4\n" + "scvtf v16.4s, v16.4s, #0x4\n" + "scvtf v8.4s, v8.4s, #0x4\n" + "fmla v1.4s, v26.4s, v28.4s\n" + "scvtf v11.4s, v11.4s, #0x4\n" + "scvtf v31.4s, v31.4s, #0x4\n" + "scvtf v7.4s, v7.4s, #0x4\n" + "fmla v22.4s, v6.4s, v4.4s\n" + "fmla v14.4s, v24.4s, v23.4s\n" + "fmla v12.4s, v16.4s, v17.4s\n" + "fmla v15.4s, v8.4s, v30.4s\n" + "fmla v19.4s, 
v11.4s, v29.4s\n" + "fmla v3.4s, v31.4s, v18.4s\n" + "fmla v2.4s, v7.4s, v10.4s\n" + "bgt 3b\n" + "ld1r { v17.4s }, [%x[clamp_vals]]\n" + "add x20, %x[clamp_vals], #0x4\n" + "cmp x9, #0x4\n" + "ld1r { v10.4s }, [x20]\n" + "fmax v1.4s, v1.4s, v17.4s\n" + "fmax v22.4s, v22.4s, v17.4s\n" + "fmax v14.4s, v14.4s, v17.4s\n" + "fmax v12.4s, v12.4s, v17.4s\n" + "fmax v15.4s, v15.4s, v17.4s\n" + "fmax v19.4s, v19.4s, v17.4s\n" + "fmax v3.4s, v3.4s, v17.4s\n" + "fmax v2.4s, v2.4s, v17.4s\n" + "fmin v1.4s, v1.4s, v10.4s\n" + "fmin v22.4s, v22.4s, v10.4s\n" + "fmin v14.4s, v14.4s, v10.4s\n" + "fmin v12.4s, v12.4s, v10.4s\n" + "fmin v15.4s, v15.4s, v10.4s\n" + "fmin v19.4s, v19.4s, v10.4s\n" + "fmin v3.4s, v3.4s, v10.4s\n" + "fmin v2.4s, v2.4s, v10.4s\n" + "blt 4f\n" + "mov x20, %x[dst]\n" + "str q1, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q22, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q14, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q12, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q15, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q19, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q3, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q2, [x20, #0x0]\n" + "b 7f\n" + "4:" // Partial output + "mov x27, %x[dst]\n" + "add x26, x27, %x[dst_stride_row], LSL #2\n" + "add x25, x26, %x[dst_stride_row], LSL #1\n" + "add x24, x26, %x[dst_stride_row]\n" + "add x23, x25, %x[dst_stride_row]\n" + "add x22, x27, %x[dst_stride_row], LSL #1\n" + "add x21, x27, %x[dst_stride_row]\n" + "add x20, x22, %x[dst_stride_row]\n" + "tbz x9, #1, 5f\n" + "st1 { v2.d }[0], [x23], #0x8\n" + "st1 { v3.d }[0], [x25], #0x8\n" + "st1 { v19.d }[0], [x24], #0x8\n" + "st1 { v15.d }[0], [x26], #0x8\n" + "st1 { v12.d }[0], [x20], #0x8\n" + "st1 { v14.d }[0], [x22], #0x8\n" + "st1 { v22.d }[0], [x21], #0x8\n" + "st1 { v1.d }[0], [x27], #0x8\n" + "tbz x9, #0, 6f\n" + "st1 { v2.s }[2], [x23]\n" + "st1 { v3.s }[2], [x25]\n" + "st1 { v19.s }[2], [x24]\n" + "st1 { v15.s }[2], [x26]\n" + "st1 { v12.s }[2], [x20]\n" + "st1 { v14.s }[2], [x22]\n" + "st1 { v22.s }[2], [x21]\n" + "st1 { v1.s }[2], [x27]\n" + "b 6f\n" + "5:" // Output block 0: partial_1_0 + "st1 { v2.s }[0], [x23]\n" + "st1 { v3.s }[0], [x25]\n" + "st1 { v19.s }[0], [x24]\n" + "st1 { v15.s }[0], [x26]\n" + "st1 { v12.s }[0], [x20]\n" + "st1 { v14.s }[0], [x22]\n" + "st1 { v22.s }[0], [x21]\n" + "st1 { v1.s }[0], [x27]\n" + "6:" // Output block 0: Done + "7:" // Output stage exit + "subs x9, x9, #0x4\n" + "add %x[dst], %x[dst], #0x10\n" + "bgt 2b\n" + "mov x20, #0x2\n" + "sub x12, x12, #0x8\n" + "cmp x12, #0x8\n" + "mov %x[dst], x28\n" + "madd %x[lhs_packed], x20, x11, %x[lhs_packed]\n" + "bge 1b\n" + "8:" // Row loop skip + "cbz x12, 16f\n" + "9:" // Row tail: Row loop + "mov x26, %x[rhs_packed]\n" + "mov x25, %x[n]\n" + "add x24, %x[dst], %x[dst_stride_row], LSL #2\n" + "10:" // Row tail: Column loop + "movi v1.16b, #0x0\n" + "movi v22.16b, #0x0\n" + "mov x22, %x[lhs_packed]\n" + "mov x20, %x[num_blocks]\n" + "movi v14.16b, #0x0\n" + "movi v12.16b, #0x0\n" + "11:" // Row tail: Block loop + "ldr d16, [x26, #0x0]\n" + "ldr d6, [x22, #0x0]\n" + "add x26, x26, #0x8\n" + "add x22, x22, #0x8\n" + "ldr q5, [x26, #0x0]\n" + "ldr q4, [x26, #0x10]\n" + "movi v7.4s, #0x0\n" + "movi v2.4s, #0x0\n" + "ldr q23, [x22, #0x0]\n" + "ldr q27, [x22, #0x10]\n" + "movi v0.4s, #0x0\n" + "movi v31.4s, #0x0\n" + "ldr q30, [x26, #0x20]\n" + "ldr q29, [x26, #0x30]\n" + "fcvtl v28.4s, v16.4h\n" + 
"fcvtl v6.4s, v6.4h\n" + "ldr q8, [x22, #0x20]\n" + "ldr q26, [x22, #0x30]\n" + "shl v21.16b, v5.16b, #0x4\n" + "shl v20.16b, v4.16b, #0x4\n" + "ldr q25, [x22, #0x40]\n" + "ldr q24, [x22, #0x50]\n" + "and v5.16b, v5.16b, v13.16b\n" + "and v4.16b, v4.16b, v13.16b\n" + "ldr q19, [x22, #0x60]\n" + "ldr q18, [x22, #0x70]\n" + "shl v17.16b, v30.16b, #0x4\n" + "shl v16.16b, v29.16b, #0x4\n" + ".inst 0x4e95a6e7 // smmla v7.4s, v23.16b, v21.16b\n" + ".inst 0x4e94a6e2 // smmla v2.4s, v23.16b, v20.16b\n" + "and v30.16b, v30.16b, v13.16b\n" + "subs x20, x20, #0x1\n" + ".inst 0x4e95a760 // smmla v0.4s, v27.16b, v21.16b\n" + ".inst 0x4e94a77f // smmla v31.4s, v27.16b, v20.16b\n" + "and v29.16b, v29.16b, v13.16b\n" + "add x26, x26, #0x40\n" + "fmul v23.4s, v28.4s, v6.s[0]\n" + "fmul v10.4s, v28.4s, v6.s[1]\n" + "add x22, x22, #0x80\n" + "fmul v21.4s, v28.4s, v6.s[2]\n" + "fmul v20.4s, v28.4s, v6.s[3]\n" + ".inst 0x4e91a507 // smmla v7.4s, v8.16b, v17.16b\n" + ".inst 0x4e90a502 // smmla v2.4s, v8.16b, v16.16b\n" + ".inst 0x4e91a740 // smmla v0.4s, v26.16b, v17.16b\n" + ".inst 0x4e90a75f // smmla v31.4s, v26.16b, v16.16b\n" + ".inst 0x4e85a727 // smmla v7.4s, v25.16b, v5.16b\n" + ".inst 0x4e84a722 // smmla v2.4s, v25.16b, v4.16b\n" + ".inst 0x4e85a700 // smmla v0.4s, v24.16b, v5.16b\n" + ".inst 0x4e84a71f // smmla v31.4s, v24.16b, v4.16b\n" + ".inst 0x4e9ea667 // smmla v7.4s, v19.16b, v30.16b\n" + ".inst 0x4e9da662 // smmla v2.4s, v19.16b, v29.16b\n" + ".inst 0x4e9ea640 // smmla v0.4s, v18.16b, v30.16b\n" + ".inst 0x4e9da65f // smmla v31.4s, v18.16b, v29.16b\n" + "uzp1 v19.2d, v7.2d, v2.2d\n" + "uzp2 v18.2d, v7.2d, v2.2d\n" + "scvtf v19.4s, v19.4s, #0x4\n" + "uzp1 v17.2d, v0.2d, v31.2d\n" + "uzp2 v16.2d, v0.2d, v31.2d\n" + "scvtf v18.4s, v18.4s, #0x4\n" + "fmla v1.4s, v19.4s, v23.4s\n" + "scvtf v17.4s, v17.4s, #0x4\n" + "scvtf v16.4s, v16.4s, #0x4\n" + "fmla v22.4s, v18.4s, v10.4s\n" + "fmla v14.4s, v17.4s, v21.4s\n" + "fmla v12.4s, v16.4s, v20.4s\n" + "bgt 11b\n" + "ld1r { v17.4s }, [%x[clamp_vals]]\n" + "add x20, %x[clamp_vals], #0x4\n" + "cmp x25, #0x4\n" + "ld1r { v16.4s }, [x20]\n" + "fmax v1.4s, v1.4s, v17.4s\n" + "fmax v22.4s, v22.4s, v17.4s\n" + "fmax v14.4s, v14.4s, v17.4s\n" + "fmax v12.4s, v12.4s, v17.4s\n" + "fmin v1.4s, v1.4s, v16.4s\n" + "fmin v22.4s, v22.4s, v16.4s\n" + "fmin v14.4s, v14.4s, v16.4s\n" + "fmin v12.4s, v12.4s, v16.4s\n" + "blt 12f\n" + "mov x20, %x[dst]\n" + "cmp x12, #0x1\n" + "str q1, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 15f\n" + "cmp x12, #0x2\n" + "str q22, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 15f\n" + "cmp x12, #0x3\n" + "str q14, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 15f\n" + "str q12, [x20, #0x0]\n" + "b 15f\n" + "12:" // Row tail: Partial output + "mov x23, %x[dst]\n" + "cmp x12, #0x1\n" + "add x22, x23, %x[dst_stride_row]\n" + "csel x22, x22, x23, GT\n" + "cmp x12, #0x2\n" + "add x21, x23, %x[dst_stride_row], LSL #1\n" + "csel x21, x21, x22, GT\n" + "cmp x12, #0x3\n" + "add x20, x21, %x[dst_stride_row]\n" + "csel x20, x20, x21, GT\n" + "tbz x25, #1, 13f\n" + "st1 { v12.d }[0], [x20], #0x8\n" + "st1 { v14.d }[0], [x21], #0x8\n" + "st1 { v22.d }[0], [x22], #0x8\n" + "st1 { v1.d }[0], [x23], #0x8\n" + "tbz x25, #0, 14f\n" + "st1 { v12.s }[2], [x20]\n" + "st1 { v14.s }[2], [x21]\n" + "st1 { v22.s }[2], [x22]\n" + "st1 { v1.s }[2], [x23]\n" + "b 14f\n" + "13:" // Row tail: Output block 0: partial_1_0 + "st1 { v12.s }[0], [x20]\n" + "st1 { v14.s }[0], [x21]\n" + "st1 { v22.s }[0], [x22]\n" + "st1 { v1.s 
}[0], [x23]\n" + "14:" // Row tail: Output block 0: Done + "15:" // Row tail: Output stage exit + "subs x25, x25, #0x4\n" + "add %x[dst], %x[dst], #0x10\n" + "bgt 10b\n" + "subs x12, x12, #0x4\n" + "add %x[lhs_packed], %x[lhs_packed], x11\n" + "mov %x[dst], x24\n" + "bgt 9b\n" + "16:" // Row tail: Row loop skip + : [dst] "+&r"(dst), [lhs_packed] "+&r"(lhs_packed) + : [clamp_vals] "r"(clamp_vals), [dst_stride_row] "r"(dst_stride_row), [m] "r"(m), [n] "r"(n), + [num_blocks] "r"(num_blocks), [rhs_packed] "r"(rhs_packed) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", + "v30", "v31", "x9", "x10", "x11", "x12", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"); } #endif // Architectural feature check diff --git a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c index a391335d..865e641f 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c +++ b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c @@ -81,8 +81,7 @@ void kai_run_lhs_quant_pack_qsi8d32p_f32( float abs_max = 0.0F; const size_t dst_x = ((row_idx + m_idx_start) % mr); - int8_t* dst_ptr = - (int8_t*)lhs_packed + dst_x * k_block_len * sizeof(int8_t) + (b * mr) * num_bytes_per_block; + int8_t* dst_ptr = (int8_t*)lhs_packed + (b * mr) * num_bytes_per_block; for (size_t idx_v = 0; idx_v < bl; ++idx_v) { const float val = src_ptr[idx_v]; @@ -93,6 +92,11 @@ void kai_run_lhs_quant_pack_qsi8d32p_f32( const float scale = abs_max / ((1 << 7) - 1); const float rep_scale = scale ? 1.0f / scale : 0.0f; + *((uint16_t*)(dst_ptr + dst_x * kai_num_bytes_multiplier)) = kai_cast_f16_f32(scale); + dst_ptr += mr * kai_num_bytes_multiplier; + + dst_ptr += dst_x * k_block_len * sizeof(int8_t); + // Quantize and pack the block for (size_t k_idx = 0; k_idx < bl; k_idx += k_block_len) { for (size_t k_block_idx = 0; k_block_idx < (size_t)k_block_len; ++k_block_idx) { @@ -111,9 +115,8 @@ void kai_run_lhs_quant_pack_qsi8d32p_f32( } dst_ptr = (int8_t*)lhs_packed + mr * (bl * sizeof(int8_t)); - dst_ptr += dst_x * kai_num_bytes_multiplier + b * mr * num_bytes_per_block; + dst_ptr += b * mr * num_bytes_per_block; - *((uint16_t*)(dst_ptr)) = kai_cast_f16_f32(scale); src_ptr += bl; } // Move to the next row if we have interleaved all Mr rows diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c index 475d781d..a0038446 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c @@ -111,13 +111,20 @@ void kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0( for (size_t x = 0; x < num_blocks_per_row; ++x) { // Store the scales at the end of the block - uint8_t* scales = (dst_row + (bl / 2) * nr); + uint8_t* scales = (dst_row); for (size_t i = 0; i < nr; ++i) { memcpy(scales + i * kai_num_bytes_multiplier, src_row + i * rhs_stride, kai_num_bytes_multiplier); } src_row += kai_num_bytes_multiplier; + for (size_t i = 0; i < nr; ++i) { + const float d = kai_cast_f32_f16(((uint16_t*)scales)[i]); + ((uint16_t*)scales)[i] = kai_cast_f16_f32(d); + } + + dst_row += (kai_num_bytes_multiplier * nr); + // Store the segments for (size_t s = 0; s < num_segments_per_block; ++s) { for (size_t i = 0; 
i < nr; ++i) { @@ -134,13 +141,6 @@ void kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0( src_row += num_bytes_per_segment; dst_row += num_bytes_per_segment * nr; } - - for (size_t i = 0; i < nr; ++i) { - const float d = kai_cast_f32_f16(((uint16_t*)scales)[i]); - ((uint16_t*)scales)[i] = kai_cast_f16_f32(d * 0.0625F); - } - - dst_row += (kai_num_bytes_multiplier * nr); } } } -- GitLab
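
Note on the new packed-LHS layout (illustrative sketch, not part of the patch or the library API): after the refactor of kai_run_lhs_quant_pack_qsi8d32p_f32, each group of mr rows stores its mr half-precision scales at the start of every 32-value block, followed by the rows' int8 data interleaved in chunks of k_block_len values. The helper names below are hypothetical; the offsets simply mirror the pointer arithmetic in the refactored packing routine.

```c
#include <stddef.h>
#include <stdint.h>

// Byte offset of the f16 scale for block `block_idx` and row slot `row_slot`
// inside the packed LHS buffer (scales are written before the data).
static inline size_t lhs_packed_scale_offset(size_t block_idx, size_t row_slot, size_t mr, size_t bl) {
    const size_t bytes_per_scale = sizeof(uint16_t);  // f16 multiplier
    const size_t bytes_per_block = bl * sizeof(int8_t) + bytes_per_scale;
    return (block_idx * mr) * bytes_per_block + row_slot * bytes_per_scale;
}

// Byte offset of the first int8 chunk owned by `row_slot` in the same block.
static inline size_t lhs_packed_data_offset(
    size_t block_idx, size_t row_slot, size_t mr, size_t bl, size_t k_block_len) {
    const size_t bytes_per_scale = sizeof(uint16_t);
    const size_t bytes_per_block = bl * sizeof(int8_t) + bytes_per_scale;
    return (block_idx * mr) * bytes_per_block  // start of the mr-row super block
        + mr * bytes_per_scale                 // skip the mr scales stored up front
        + row_slot * k_block_len * sizeof(int8_t);
}
```

With mr = 1 and bl = 32 this yields a 34-byte (0x22) stride per block, matching the `mov x26, #0x22` constant that the GeMV dotprod micro-kernel multiplies by num_blocks to step between packed LHS rows; with mr = 4 the i8mm kernel uses the corresponding 0x88 (136-byte) stride.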