From a6c93dcce879e5cadd7d92264931f82c70ce5d26 Mon Sep 17 00:00:00 2001
From: Anitha Raj
Date: Thu, 20 Mar 2025 14:15:23 +0000
Subject: [PATCH 01/15] Micro-kernels to compute the matrix multiplication of
 a dynamically quantized symmetric signed 8-bit integer LHS matrix with
 per-row quantization (QSI8DX) by a quantized asymmetric 4-bit signed integer
 RHS matrix with per-channel quantization (QAI4CX), accumulating the result
 into single-precision floating-point (F32) output:

Matrix multiplication (MxN) micro-kernels of QSI8DX LHS and QAI4CX RHS with
F32 output, optimized for FEAT_I8MM.
Matrix multiplication (1xN) micro-kernels of QSI8DX LHS and QAI4CX RHS with
F32 output, optimized for FEAT_DotProd.

Signed-off-by: Anitha Raj
---
 ...2_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod.c | 248 ++++++++++++++++++
 ...2_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod.h | 139 ++++++++++
 ...i8dxp1x8_qai4cxp4x8_1x4_neon_dotprod_asm.S | 153 +++++++++++
 ...tmul_clamp_f32_qsi8dxp_qai4cxp_interface.h |  52 ++++
 4 files changed, 592 insertions(+)
 create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod.c
 create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod.h
 create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod_asm.S
 create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp_qai4cxp_interface.h

diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod.c
new file mode 100644
index 00000000..db13bb65
--- /dev/null
+++ b/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod.c
@@ -0,0 +1,248 @@
+//
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#if !defined(__aarch64__) && !defined(__ARM_FEATURE_DOTPROD) && !defined(_M_ARM64)
+#error "Dotprod extension required to compile this micro-kernel"
+#else  // Architectural features check.
+
+#include "kai_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod.h"
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include "kai/kai_common.h"
+
+typedef struct {
+    float* dst;
+    const void* lhs_packed;
+    const void* rhs_packed;
+    const float* clamp_vals;
+    size_t dst_stride_row;
+    size_t m;
+    size_t n;
+    size_t num_blocks;
+    size_t num_subblocks;
+} KernelArgs;
+
+void kai_kernel_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod(KernelArgs* args_ptr);
+
+// Compute args
+static const size_t kai_m_step = 1;
+static const size_t kai_n_step = 4;
+// Packing args
+static const size_t kai_mr = 1;
+static const size_t kai_nr = 4;
+static const size_t kai_kr = 16;
+static const size_t kai_sr = 2;
+// LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric))
+static const size_t kai_num_bytes_qvalue_lhs = 1;
+static const size_t kai_num_bytes_multiplier_lhs = 4;
+static const size_t kai_num_bytes_sum_lhs = 4;
+// RHS format args (num.
bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is +// asymmetric)) +static const size_t kai_num_bytes_recip_qvalue_rhs = 2; +static const size_t kai_num_bytes_multiplier_rhs = 4; +static const size_t kai_num_bytes_offset_rhs = 4; + +// DST format args +static const size_t kai_num_bytes_dst_value = 4; +// Extra args +static const size_t kai_num_bytes_bias = 4; +static const size_t kai_k_multiple_of = 32; + +inline static size_t kai_get_k_roundedup(size_t k) { + return kai_roundup(k, kai_k_multiple_of); +} + +inline static size_t kai_get_lhs_packed_stride(size_t k) { + const size_t k_internal = kai_get_k_roundedup(k); + size_t lhs_packed_stride = kai_mr * ((k_internal * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs); + + return lhs_packed_stride; +} + +inline static size_t kai_get_rhs_packed_stride(size_t k) { + const size_t k_internal = kai_get_k_roundedup(k); + size_t rhs_packed_stride = kai_nr * (k_internal / kai_num_bytes_recip_qvalue_rhs); + + rhs_packed_stride += kai_nr * kai_num_bytes_multiplier_rhs; + // Since the bias is packed with the RHS matrix, the stride is adjusted with the number of bytes of the bias + rhs_packed_stride += kai_nr * kai_num_bytes_bias; + + return rhs_packed_stride; +} + +size_t kai_get_m_step_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod(void) { + return kai_m_step; +} + +size_t kai_get_n_step_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod(void) { + return kai_n_step; +} + +size_t kai_get_mr_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod(void) { + return kai_mr; +} + +size_t kai_get_nr_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod(void) { + return kai_nr; +} + +size_t kai_get_kr_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod(void) { + return kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod(size_t m_idx, size_t k) { + KAI_ASSUME((m_idx % kai_m_step) == 0); + + return (m_idx / kai_mr) * kai_get_lhs_packed_stride(k); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod(size_t n_idx, size_t k) { + KAI_ASSUME((n_idx % kai_n_step) == 0); + + return (n_idx / kai_nr) * kai_get_rhs_packed_stride(k); +} + +size_t kai_get_dst_offset_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSUME((m_idx % kai_m_step) == 0); + KAI_ASSUME((n_idx % kai_n_step) == 0); + + return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride; +} + +size_t kai_get_dst_size_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod(size_t m, size_t n) { + return m * n * kai_num_bytes_dst_value; +} + +void kai_run_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod( + size_t m, // + size_t n, // + size_t k, // + const void* restrict lhs_packed, // + const void* restrict rhs_packed, // + float* dst, // NOLINT(readability-non-const-parameter) + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max) { + KAI_ASSUME(dst_stride_col == sizeof(float)); + + if (m == 0) { + return; + } + + const size_t kai_bl = 32; + const size_t k_internal = kai_get_k_roundedup(k); + const size_t num_blocks = k_internal / kai_bl; + const float clamp_vals[2] = {scalar_min, scalar_max}; + + __asm__ __volatile__( + "mov x26, #0x20\n" + "mov x20, #0x8\n" + "movi v30.16b, #0xf0\n" + "mov x25, %x[m]\n" + "madd x26, %x[num_blocks], 
x26, x20\n" + "1:" // Row loop + "mov x24, %x[rhs_packed]\n" + "mov x23, %x[n]\n" + "add x22, %x[dst], %x[dst_stride_row]\n" + "2:" // Column loop + "mov x21, %x[lhs_packed]\n" + "movi v29.4s, #0x0\n" + "movi v28.4s, #0x0\n" + "mov x20, %x[num_blocks]\n" + "3:" // Sub block loop + "ldr q27, [x24, #0x0]\n" + "ldr q26, [x24, #0x10]\n" + "subs x20, x20, #0x1\n" + "ld1r { v25.2d }, [x21], #0x8\n" + "ldr q24, [x24, #0x20]\n" + "ldr q23, [x24, #0x30]\n" + "add x24, x24, #0x40\n" + "ld1r { v22.2d }, [x21], #0x8\n" + "ld1r { v21.2d }, [x21], #0x8\n" + "shl v20.16b, v27.16b, #0x4\n" + "shl v19.16b, v26.16b, #0x4\n" + "ld1r { v18.2d }, [x21], #0x8\n" + "shl v17.16b, v24.16b, #0x4\n" + "and v27.16b, v27.16b, v30.16b\n" + "shl v16.16b, v23.16b, #0x4\n" + "and v26.16b, v26.16b, v30.16b\n" + ".inst 0x4e99969d // sdot v29.4s, v20.16b, v25.16b\n" + ".inst 0x4e99967c // sdot v28.4s, v19.16b, v25.16b\n" + "and v24.16b, v24.16b, v30.16b\n" + "and v23.16b, v23.16b, v30.16b\n" + ".inst 0x4e96963d // sdot v29.4s, v17.16b, v22.16b\n" + ".inst 0x4e96961c // sdot v28.4s, v16.16b, v22.16b\n" + ".inst 0x4e95977d // sdot v29.4s, v27.16b, v21.16b\n" + ".inst 0x4e95975c // sdot v28.4s, v26.16b, v21.16b\n" + ".inst 0x4e92971d // sdot v29.4s, v24.16b, v18.16b\n" + ".inst 0x4e9296fc // sdot v28.4s, v23.16b, v18.16b\n" + "bgt 3b\n" + "ldr q20, [x24, #0x10]\n" + "ldr q19, [x24, #0x20]\n" + "add x21, x21, #0x4\n" + "addp v29.4s, v29.4s, v28.4s\n" + "ld1r { v16.4s }, [x21]\n" + "ld1r { v18.4s }, [%x[clamp_vals]]\n" + "add x20, %x[clamp_vals], #0x4\n" + "cmp x23, #0x4\n" + "ld1r { v17.4s }, [x20]\n" + "add x24, x24, #0x30\n" + "scvtf v29.4s, v29.4s\n" + "fmul v20.4s, v20.4s, v16.4s\n" + "fmul v16.4s, v29.4s, v20.4s\n" + "fadd v16.4s, v16.4s, v19.4s\n" + "fmax v16.4s, v16.4s, v18.4s\n" + "fmin v16.4s, v16.4s, v17.4s\n" + "blt 4f\n" + "str q16, [%x[dst], #0x0]\n" + "b 7f\n" + "4:" // Partial output + "mov x20, %x[dst]\n" + "tbz x23, #1, 5f\n" + "st1 { v16.d }[0], [x20], #0x8\n" + "tbz x23, #0, 6f\n" + "st1 { v16.s }[2], [x20]\n" + "b 6f\n" + "5:" // Output block 0: partial_1_0 + "st1 { v16.s }[0], [x20]\n" + "6:" // Output block 0: Done + "7:" // Stores done + "subs x23, x23, #0x4\n" + "add %x[dst], %x[dst], #0x10\n" + "bgt 2b\n" + "subs x25, x25, #0x1\n" + "add %x[lhs_packed], %x[lhs_packed], x26\n" + "mov %x[dst], x22\n" + "bgt 1b\n" + : [dst] "+&r"(dst), [lhs_packed] "+&r"(lhs_packed) + : [clamp_vals] "r"(clamp_vals), [dst_stride_row] "r"(dst_stride_row), [m] "r"(m), [n] "r"(n), + [num_blocks] "r"(num_blocks), [rhs_packed] "r"(rhs_packed) + : "cc", "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", + "v29", "v30", "x20", "x21", "x22", "x23", "x24", "x25", "x26"); + /* + KernelArgs args; + + args.dst = dst; + args.lhs_packed = lhs_packed; + args.rhs_packed = rhs_packed; + args.clamp_vals = clamp_vals; + args.dst_stride_row = dst_stride_row; + args.m = m; + args.n = n; + args.num_blocks = num_blocks; + + kai_kernel_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod(&args); + */ +} + +#endif // Architectural features check. 
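
For reference, a minimal sketch of how the 1x4 dotprod micro-kernel above is typically driven once the LHS and RHS have been packed. This is not part of the patch: the variables m, n, k, lhs_packed, rhs_packed, dst and dst_stride_row are assumed to be set up by the caller, and clamping is disabled here by passing -FLT_MAX/FLT_MAX (from <float.h>):

    #include <float.h>

    // Process one m_step (= 1) row of the output per call; the offset helpers
    // translate matrix indices into byte offsets within the packed buffers.
    // Splitting the N dimension (e.g. for multi-threading) would use
    // kai_get_rhs_packed_offset_...() the same way.
    const size_t m_step = kai_get_m_step_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod();
    for (size_t m_idx = 0; m_idx < m; m_idx += m_step) {
        const void* lhs_ptr = (const char*)lhs_packed +
            kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod(m_idx, k);
        float* dst_ptr = (float*)((char*)dst +
            kai_get_dst_offset_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod(m_idx, 0, dst_stride_row));

        kai_run_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod(
            m_step, n, k, lhs_ptr, rhs_packed, dst_ptr, dst_stride_row, sizeof(float), -FLT_MAX, FLT_MAX);
    }
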
diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod.h b/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod.h
new file mode 100644
index 00000000..ebf3cd6b
--- /dev/null
+++ b/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod.h
@@ -0,0 +1,139 @@
+//
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <stddef.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+/// Micro-kernel dependencies
+///
+/// -# @ref kai_lhs_quant_pack_qsi8dxp_f32 to dynamically quantize and pack the LHS matrix in a single step.
+/// -# @ref kai_rhs_pack_nxk_qai4cxp_qau4c32s1s0 to pack the RHS NxK matrix.
+/// -# @ref kai_rhs_pack_kxn_qai4cxp_qau4c32s1s0 to pack the RHS KxN matrix.
+
+/// --------------------------------------------------
+
+/// Gets the m step value.
+/// The micro-kernel can process any M values. However, the starting M index to
+/// be processed must be a multiple of m step.
+///
+/// @return the m step value
+size_t kai_get_m_step_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod(void);
+
+/// Gets the n step value.
+/// The micro-kernel can process any N values. However, the starting N index to
+/// be processed must be a multiple of n step.
+///
+/// @return the n step value
+size_t kai_get_n_step_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod(void);
+
+/// Gets the mr value, which must be used to pack the LHS matrix.
+///
+/// @return the mr value
+size_t kai_get_mr_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod(void);
+
+/// Gets the nr value, which must be used to pack the RHS matrix.
+///
+/// @return the nr value
+size_t kai_get_nr_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod(void);
+
+/// Gets the kr value, which must be used to pack the LHS and RHS matrices.
+///
+/// @return the kr value
+size_t kai_get_kr_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod(void);
+
+/// Gets the sr value, which must be used to pack the LHS and RHS matrices.
+///
+/// @return the sr value
+size_t kai_get_sr_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod(void);
+
+/// Gets the offset in bytes for the packed LHS matrix,
+/// which contains the packed Quantized Symmetric Signed 8-bit with per-row quantization (qsi8dx) values.
+///
+/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel.
+///
+/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of m_step.
+/// @param[in] k Total number of columns in the LHS matrix (not packed).
+///
+/// @return the offset in bytes to the packed LHS matrix
+size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod(
+    size_t m_idx,  //
+    size_t k);     //
+
+/// Gets the offset in bytes for the packed RHS matrix,
+/// which contains the packed Quantized Asymmetric Signed 4-bit with per-channel quantization (qai4cx) values.
+///
+/// @param[in] n_idx Column index in the RHS matrix (not packed). It must be a multiple of n_step.
+/// @param[in] k The common dimension between the LHS and RHS matrix (K).
+///
+/// @return the offset in bytes to the packed RHS matrix
+size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod(
+    size_t n_idx,  //
+    size_t k);     //
+
+/// Gets the offset in bytes for the DST matrix
+///
+/// @param[in] m_idx Row index in the DST matrix. It must be a multiple of m_step.
+/// @param[in] n_idx Column index in the DST matrix. It must be a multiple of n_step.
+/// @param[in] dst_stride The number of bytes in each row of the DST matrix
+///
+/// @return the DST offset in bytes
+size_t kai_get_dst_offset_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod(
+    size_t m_idx,        //
+    size_t n_idx,        //
+    size_t dst_stride);  //
+
+/// Gets the size in bytes for the destination (DST) matrix.
+///
+/// @param[in] m Number of rows in the destination (DST) matrix.
+/// @param[in] n Number of columns in the destination (DST) matrix.
+///
+/// @return the destination (DST) matrix size in bytes
+size_t kai_get_dst_size_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod(
+    size_t m,  //
+    size_t n);  //
+
+/// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation.
+///
+/// LHS matrix: Quantized Symmetric Signed 8-bit with per-row quantization (qsi8dx) and packed.
+/// RHS matrix: Quantized Asymmetric Signed 4-bit with per-channel quantization (qai4cx) and packed.
+/// Output tile: (rows x cols) = m_step x n_step.
+///
+/// Note: Please refer to the get functions for m_step and n_step for the exact values.
+///
+/// Features used: dotprod
+///
+/// @param[in] m The number of output rows written.
+/// @param[in] n The number of output columns written.
+/// @param[in] k The number of channels. The common dimension between the LHS and RHS matrix.
+/// @param[in] lhs_packed The LHS packed matrix. The micro-kernel to pack the native LHS matrix is reported at the
+///                       top of this file.
+/// @param[in] rhs_packed The RHS packed matrix. The micro-kernel to pack the native RHS matrix is reported at the
+///                       top of this file.
+/// @param[out] dst The DST matrix.
+/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix.
+/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(float) bytes.
+/// @param[in] scalar_min Min value used to clamp the final result.
+/// @param[in] scalar_max Max value used to clamp the final result.
+void kai_run_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod( + size_t m, // + size_t n, // + size_t k, // + const void* lhs_packed, // + const void* rhs_packed, // + float* dst, // + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max); // + +#ifdef __cplusplus +} +#endif // __cplusplus diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod_asm.S new file mode 100644 index 00000000..4861c74d --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod_asm.S @@ -0,0 +1,153 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod) +KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod) + stp x20, x21, [sp, -80]! 
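+    // Callee-saved registers x20-x28 are spilled to an 80-byte stack frame here and restored before the final ret (AAPCS64).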
+ stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + mov x15, #0x20 + movi v30.16b, #0xf0 + mov x21, #0x8 + ldr x14, [x0, #0x40] + ldr x13, [x0, #0x38] + ldr x20, [x0, #0x28] + ldr x12, [x0, #0x8] + ldr x11, [x0, #0x10] + ldr x10, [x0, #0x30] + mul x15, x14, x15 + ldr x9, [x0, #0x0] + ldr x28, [x0, #0x20] + ldr x27, [x0, #0x18] + mov x26, x20 + madd x15, x13, x15, x21 +KAI_ASM_LABEL(label_1) // Row loop + mov x25, x11 + mov x24, x10 + add x23, x9, x28 +KAI_ASM_LABEL(label_2) // Column loop + mov x22, x12 + mov x21, x13 +KAI_ASM_LABEL(label_3) // Block loop + movi v29.4s, #0x0 + movi v28.4s, #0x0 + mov x20, x14 +KAI_ASM_LABEL(label_4) // Sub block loop + ldr q27, [x25, #0x0] + ldr q26, [x25, #0x10] + subs x20, x20, #0x1 + ld1r { v25.2d }, [x22], #0x8 + ldr q24, [x25, #0x20] + ldr q23, [x25, #0x30] + add x25, x25, #0x40 + ld1r { v22.2d }, [x22], #0x8 + ld1r { v21.2d }, [x22], #0x8 + shl v20.16b, v27.16b, #0x4 + shl v19.16b, v26.16b, #0x4 + ld1r { v18.2d }, [x22], #0x8 + shl v17.16b, v24.16b, #0x4 + and v27.16b, v27.16b, v30.16b + shl v16.16b, v23.16b, #0x4 + and v26.16b, v26.16b, v30.16b + KAI_ASM_INST(0x4e99969d) // sdot v29.4s, v20.16b, v25.16b + KAI_ASM_INST(0x4e99967c) // sdot v28.4s, v19.16b, v25.16b + and v24.16b, v24.16b, v30.16b + and v23.16b, v23.16b, v30.16b + KAI_ASM_INST(0x4e96963d) // sdot v29.4s, v17.16b, v22.16b + KAI_ASM_INST(0x4e96961c) // sdot v28.4s, v16.16b, v22.16b + KAI_ASM_INST(0x4e95977d) // sdot v29.4s, v27.16b, v21.16b + KAI_ASM_INST(0x4e95975c) // sdot v28.4s, v26.16b, v21.16b + KAI_ASM_INST(0x4e92971d) // sdot v29.4s, v24.16b, v18.16b + KAI_ASM_INST(0x4e9296fc) // sdot v28.4s, v23.16b, v18.16b + bgt label_4 + ldr q17, [x25, #0x10] + add x22, x22, #0x4 + addp v29.4s, v29.4s, v28.4s + sub x21, x21, #0x1 + ld1r { v16.4s }, [x22] + add x22, x22, #0x4 + add x25, x25, #0x20 + scvtf v29.4s, v29.4s + fmul v17.4s, v17.4s, v16.4s + fmul v19.4s, v29.4s, v17.4s + cbnz x21, label_3 + ldr q18, [x25, #0x0] + ld1r { v17.4s }, [x27] + add x20, x27, #0x4 + cmp x24, #0x4 + ld1r { v16.4s }, [x20] + add x25, x25, #0x10 + fadd v19.4s, v19.4s, v18.4s + fmax v19.4s, v19.4s, v17.4s + fmin v19.4s, v19.4s, v16.4s + blt label_5 + str q19, [x9, #0x0] + b label_8 +KAI_ASM_LABEL(label_5) // Partial output + mov x20, x9 + tbz x24, #1, label_6 + st1 { v19.d }[0], [x20], #0x8 + tbz x24, #0, label_7 + st1 { v19.s }[2], [x20] + b label_7 +KAI_ASM_LABEL(label_6) // Output block 0: partial_1_0 + st1 { v19.s }[0], [x20] +KAI_ASM_LABEL(label_7) // Output block 0: Done +KAI_ASM_LABEL(label_8) // Stores done + subs x24, x24, #0x4 + add x9, x9, #0x10 + bgt label_2 + subs x26, x26, #0x1 + add x12, x12, x15 + mov x9, x23 + bgt label_1 + ldp x22, x23, [sp, 16] + ldp x24, x25, [sp, 32] + ldp x26, x27, [sp, 48] + ldr x28, [sp, 64] + ldp x20, x21, [sp], 80 + ret + KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod) + + KAI_ASM_END diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp_qai4cxp_interface.h b/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp_qai4cxp_interface.h new file mode 100644 index 00000000..1fd19e26 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp_qai4cxp_interface.h @@ -0,0 +1,52 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// All 
micro-kernels variants of the same type share the same interfaces +// In this case, the micro-kernel type is: matmul_clamp_f32_qsi8dxp_qai4cxp + +/// Micro-kernel helper functions ("get" methods) +typedef size_t (*kai_matmul_clamp_f32_qsi8dxp_qai4cxp_get_m_step_func_t)(void); +typedef size_t (*kai_matmul_clamp_f32_qsi8dxp_qai4cxp_get_n_step_func_t)(void); +typedef size_t (*kai_matmul_clamp_f32_qsi8dxp_qai4cxp_get_mr_func_t)(void); +typedef size_t (*kai_matmul_clamp_f32_qsi8dxp_qai4cxp_get_nr_func_t)(void); +typedef size_t (*kai_matmul_clamp_f32_qsi8dxp_qai4cxp_get_kr_func_t)(void); +typedef size_t (*kai_matmul_clamp_f32_qsi8dxp_qai4cxp_get_sr_func_t)(void); +typedef size_t (*kai_matmul_clamp_f32_qsi8dxp_qai4cxp_get_lhs_packed_offset_func_t)(size_t m_idx, size_t k); +typedef size_t (*kai_matmul_clamp_f32_qsi8dxp_qai4cxp_get_rhs_packed_offset_func_t)(size_t n_idx, size_t k); +typedef size_t (*kai_matmul_clamp_f32_qsi8dxp_qai4cxp_get_dst_offset_func_t)( + size_t m_idx, size_t n_idx, size_t dst_stride); +typedef size_t (*kai_matmul_clamp_f32_qsi8dxp_qai4cxp_get_dst_size_func_t)(size_t m, size_t n); + +/// Micro-kernel core function ("run" method) +typedef void (*kai_matmul_clamp_f32_qsi8dxp_qai4cxp_run_matmul_func_t)( + size_t m, size_t n, size_t k, const void* lhs_p, const void* rhs_p, float* dst, size_t dst_stride_row, + size_t dst_stride_col, float scalar_min, float scalar_max); + +/// Micro-kernel interface +struct kai_matmul_clamp_f32_qsi8dxp_qai4cxp_ukernel { + kai_matmul_clamp_f32_qsi8dxp_qai4cxp_get_m_step_func_t get_m_step; + kai_matmul_clamp_f32_qsi8dxp_qai4cxp_get_n_step_func_t get_n_step; + kai_matmul_clamp_f32_qsi8dxp_qai4cxp_get_mr_func_t get_mr; + kai_matmul_clamp_f32_qsi8dxp_qai4cxp_get_nr_func_t get_nr; + kai_matmul_clamp_f32_qsi8dxp_qai4cxp_get_kr_func_t get_kr; + kai_matmul_clamp_f32_qsi8dxp_qai4cxp_get_sr_func_t get_sr; + kai_matmul_clamp_f32_qsi8dxp_qai4cxp_get_lhs_packed_offset_func_t get_lhs_packed_offset; + kai_matmul_clamp_f32_qsi8dxp_qai4cxp_get_rhs_packed_offset_func_t get_rhs_packed_offset; + kai_matmul_clamp_f32_qsi8dxp_qai4cxp_get_dst_offset_func_t get_dst_offset; + kai_matmul_clamp_f32_qsi8dxp_qai4cxp_get_dst_size_func_t get_dst_size; + kai_matmul_clamp_f32_qsi8dxp_qai4cxp_run_matmul_func_t run_matmul; +}; + +#ifdef __cplusplus +} +#endif -- GitLab From 56f75ff3b30fcc1ae1025580ff58dc924606fc7d Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Thu, 27 Mar 2025 15:09:36 +0000 Subject: [PATCH 02/15] Add QSI8DX LHS packing kernel Signed-off-by: Anitha Raj --- ..._lhs_quant_pack_qsi8dxpscalef32_f32_neon.c | 152 ++++++++++++++++++ ..._lhs_quant_pack_qsi8dxpscalef32_f32_neon.h | 78 +++++++++ 2 files changed, 230 insertions(+) create mode 100644 kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8dxpscalef32_f32_neon.c create mode 100644 kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8dxpscalef32_f32_neon.h diff --git a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8dxpscalef32_f32_neon.c b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8dxpscalef32_f32_neon.c new file mode 100644 index 00000000..d38b27d2 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8dxpscalef32_f32_neon.c @@ -0,0 +1,152 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if !defined(__aarch64__) +#error This file must be compiled for AArch64. +#else // Architectural features check. 
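+// Layout of the packed LHS buffer written by this kernel, per group of mr
+// rows (as implemented in kai_run_lhs_quant_pack_qsi8dxpscalef32_f32_neon below):
+//
+//   [ int8 data: mr x k_internal ][ f32 scaled row sums: mr ][ f32 reciprocal scales: mr ]
+//
+// The int8 data is interleaved in chunks of k_block_len = kr / sr values,
+// cycling over the mr rows. Each row is padded with copies of its last valid
+// value up to k_internal, but only the first k values contribute to the sum.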
+
+#include "kai_lhs_quant_pack_qsi8dxpscalef32_f32_neon.h"
+
+#include <arm_neon.h>
+#include <float.h>
+#include <math.h>
+#include <stdint.h>
+
+#include "kai/kai_common.h"
+
+static const size_t kai_num_bytes_per_multiplier = sizeof(float);
+static const size_t kai_num_bytes_per_offset = sizeof(int32_t);
+
+inline static size_t kai_k_roundedup(size_t k, size_t kr, size_t sr) {
+    // Since we pack a float and int32 value at the end of the row,
+    // we must make sure that k is a multiple of 4 for memory alignment.
+    size_t kr_sr_roundedup4 = kai_roundup(kr * sr, 4);
+    return kai_roundup(k, kr_sr_roundedup4);
+}
+
+inline static size_t kai_get_lhs_packed_stride(size_t k, size_t mr, size_t kr, size_t sr) {
+    const size_t k_internal = kai_k_roundedup(k, kr, sr);
+
+    KAI_ASSERT((k_internal % 2) == 0);
+
+    return mr * (k_internal * sizeof(int8_t) + kai_num_bytes_per_multiplier + kai_num_bytes_per_offset);
+}
+
+size_t kai_get_m_step_lhs_quant_pack_qsi8dxpscalef32_f32_neon(size_t mr) {
+    KAI_UNUSED(mr);
+    return 1;
+}
+
+size_t kai_get_lhs_offset_lhs_quant_pack_qsi8dxpscalef32_f32_neon(size_t m_idx, size_t lhs_stride) {
+    return m_idx * lhs_stride;
+}
+
+size_t kai_get_lhs_packed_offset_lhs_quant_pack_qsi8dxpscalef32_f32_neon(
+    size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) {
+    // It always points to the beginning of the row
+    return (m_idx / mr) * kai_get_lhs_packed_stride(k, mr, kr, sr);
+}
+
+size_t kai_get_lhs_packed_size_lhs_quant_pack_qsi8dxpscalef32_f32_neon(
+    size_t m, size_t k, size_t mr, size_t kr, size_t sr) {
+    const size_t num_rows = kai_roundup(m, mr) / mr;
+
+    return (num_rows * kai_get_lhs_packed_stride(k, mr, kr, sr));
+}
+
+void kai_run_lhs_quant_pack_qsi8dxpscalef32_f32_neon(
+    size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* restrict lhs,
+    size_t lhs_stride, void* restrict lhs_packed) {
+    KAI_ASSERT((kr % sr) == 0);
+
+    if (m == 0) {
+        return;
+    }
+
+    const size_t num_rows = m;
+
+    const float* src_ptr = lhs;
+
+    const size_t dst_stride = kai_get_lhs_packed_stride(k, mr, kr, sr);
+    const size_t k_internal = kai_k_roundedup(k, kr, sr);
+    const int32_t k_block_len = (int32_t)(kr / sr);
+    for (size_t row_idx = 0; row_idx < num_rows; ++row_idx) {
+        float absmax = -FLT_MAX;
+        // Find the absolute maximum of the row
+        int32_t k_idx = 0;
+#if defined(__aarch64__)
+        float32x4_t vabsmax = vdupq_n_f32(-FLT_MAX);
+        for (; k_idx <= ((int32_t)k - 8); k_idx += 8) {
+            const float32x4_t src0_0 = vld1q_f32(src_ptr + 0 + (size_t)k_idx);
+            const float32x4_t src0_1 = vld1q_f32(src_ptr + 4 + (size_t)k_idx);
+            // Calculate the max
+            vabsmax = vmaxq_f32(vabsq_f32(src0_0), vmaxq_f32(vabsmax, vabsq_f32(src0_1)));
+        }
+        // Get the max
+        absmax = vmaxvq_f32(vabsmax);
+#endif
+        for (; k_idx < (int32_t)k; ++k_idx) {
+            const float src0_0 = *(src_ptr + (size_t)k_idx);
+            absmax = KAI_MAX(fabsf(src0_0), absmax);
+        }
+        // Maximum int8 value
+        const float qmax = (float)INT8_MAX;
+        const float scale0 = absmax == 0.0F ? 0.0F : qmax / absmax;
+        // Reciprocal of the scale, used to de-quantize
+        const float recip_scale0 = scale0 ?
1.0F / scale0 : 0.0F; + + const size_t dst_x = ((row_idx + m_idx_start) % mr); + uint8_t* dst_ptr = (uint8_t*)lhs_packed + (dst_x * k_block_len * sizeof(int8_t)); + int32_t qsum = 0; + + // Quantize the channels + k_idx = 0; + for (; k_idx < (int32_t)k_internal; k_idx += k_block_len) { + for (size_t k_block_idx = 0; k_block_idx < (size_t)k_block_len; ++k_block_idx) { + // Clamp at the last valid k-index + const size_t k_idx_start = KAI_MIN((size_t)k_idx + k_block_idx, k - 1); + const float src0_0 = *(src_ptr + k_idx_start); + + // Scale the values + int32_t v0_s32 = (int32_t)(roundf(src0_0 * scale0)); + + v0_s32 = KAI_MAX(v0_s32, INT8_MIN); + v0_s32 = KAI_MIN(v0_s32, INT8_MAX); + + *((int8_t*)(dst_ptr)) = (int8_t)v0_s32; + if ((size_t)k_idx + k_block_idx <= k - 1) { + qsum += v0_s32; + } + dst_ptr += sizeof(int8_t); + } + dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t); + } + + dst_ptr = (uint8_t*)lhs_packed + mr * (k_internal * sizeof(int8_t)); + + dst_ptr += dst_x * kai_num_bytes_per_offset; + + // LHS offset at the beginning of the row + *((float*)(dst_ptr)) = (float)(((float)qsum) * recip_scale0); + + // Assuming the same sizeof() for kai_num_bytes_per_offset and kai_num_bytes_per_multiplier + KAI_ASSERT(kai_num_bytes_per_offset == kai_num_bytes_per_multiplier); + + dst_ptr += mr * kai_num_bytes_per_offset; + + // Store the scale quantization params + *((float*)(dst_ptr)) = recip_scale0; + + src_ptr += (lhs_stride / sizeof(float)); + + // Move to the next row if we have interleaved all Mr rows + if ((((row_idx + 1) + m_idx_start) % mr) == 0) { + lhs_packed = (void*)((uint8_t*)lhs_packed + dst_stride); + } + } +} + +#endif diff --git a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8dxpscalef32_f32_neon.h b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8dxpscalef32_f32_neon.h new file mode 100644 index 00000000..9bafa553 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8dxpscalef32_f32_neon.h @@ -0,0 +1,78 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/// Gets the m step value. +/// The micro-kernel can process any M values. However, the starting M index to +/// be processed must be a multiple of m step. +/// +/// @param[in] mr The number of M rows to interleave on the same output row. +/// +/// @return the m step value +size_t kai_get_m_step_lhs_quant_pack_qsi8dxpscalef32_f32_neon(size_t mr); + +/// Gets the offset in bytes for the LHS matrix (not packed) +/// +/// This function should be called before passing the pointer to the LHS matrix to the micro-kernel. +/// +/// @param[in] m_idx Row index in the LHS matrix (not packed). +/// @param[in] lhs_stride The number of bytes in in each row of the LHS matrix (not packed) +/// +/// @return the offset in bytes to the LHS matrix +size_t kai_get_lhs_offset_lhs_quant_pack_qsi8dxpscalef32_f32_neon(size_t m_idx, size_t lhs_stride); + +/// Gets the offset in bytes for the packed LHS matrix, +/// which contains the packed 8-bit quantized symmetric per-row (qsi8dx) values. +/// +/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. +/// +/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of mr. +/// @param[in] k Total number of columns in the LHS matrix (not packed). +/// @param[in] mr The number of M rows to interleave on the same output row. 
+/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. +/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. +/// +/// @return the offset in bytes to the packed LHS matrix +size_t kai_get_lhs_packed_offset_lhs_quant_pack_qsi8dxpscalef32_f32_neon( + size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr); + +/// Gets the size in bytes for the quantized and packed LHS matrix +/// +/// @param[in] m Total number of rows in the LHS matrix (not packed). +/// @param[in] k Total number of columns in the LHS matrix (not packed). +/// @param[in] mr The number of M rows to interleave on the same output row. +/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. +/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. +/// +/// @return the packed LHS matrix size in bytes +size_t kai_get_lhs_packed_size_lhs_quant_pack_qsi8dxpscalef32_f32_neon( + size_t m, size_t k, size_t mr, size_t kr, size_t sr); + +/// Run the micro-kernel to quantize and pack the LHS matrix. +/// +/// @param[in] m The number of output rows written. +/// @param[in] k The number of channels. The common dimension of LHS & RHS. It must be multiple of 8. +/// @param[in] mr The number of M rows to interleave on the same output row. +/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. +/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. +/// However, kr must be multiple of sr. +/// @param[in] m_idx_start The starting M index. +/// @param[in] lhs LHS of the vector-by-matrix. +/// @param[in] lhs_stride Stride in bytes between two rows of LHS. +/// @param[out] lhs_packed The quantized and packed LHS matrix. +void kai_run_lhs_quant_pack_qsi8dxpscalef32_f32_neon( + size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs, size_t lhs_stride, + void* lhs_packed); + +#ifdef __cplusplus +} +#endif -- GitLab From f1fcfff2e6ffbfba11bcebbc28b2fe85148fcfdc Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Thu, 27 Mar 2025 15:11:45 +0000 Subject: [PATCH 03/15] Add QAI4CX RHS packing kernel Signed-off-by: Anitha Raj --- ..._nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon.c | 140 ++++++++++++++++++ ..._nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon.h | 106 +++++++++++++ 2 files changed, 246 insertions(+) create mode 100644 kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon.c create mode 100644 kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon.h diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon.c new file mode 100644 index 00000000..36f88223 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon.c @@ -0,0 +1,140 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if !defined(__aarch64__) +#error This file must be compiled for AArch64. +#else // Architectural features check. 
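+// Layout of the packed RHS buffer written by this kernel, per block of nr
+// source rows (as implemented in kai_run_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon below):
+//
+//   [ int4 data: nr x (k_internal / 2) bytes ][ f32 zero points: nr ][ f32 scales: nr ][ f32 biases: nr ]
+//
+// Each byte packs the values for K-indices k0 and k0 + 16; the XOR with 0x88
+// converts the zero-point-8 unsigned nibbles to signed int4, and the scales
+// are pre-multiplied by 1/16 (0.0625F) to fold in the 4-bit left shift the
+// matmul micro-kernels use when unpacking the nibbles.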
+#include "kai_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon.h" + +#include +#include + +#include "kai/kai_common.h" +static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); +static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); +static const size_t kai_num_bytes_bias = sizeof(float); + +inline static size_t kai_k_roundedup(size_t k, size_t kr, size_t sr) { + // Since we pack a float and int32 value at the end of the row, + // we must make sure that k is a multiple of 4 for alignment + size_t kr_sr_roundedup4 = kai_roundup(kr * sr, 4); + return kai_roundup(k, kr_sr_roundedup4); +} + +size_t kai_get_rhs_offset_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon(size_t n_idx, size_t rhs_stride) { + return n_idx * rhs_stride; +} + +size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon( + size_t k, size_t nr, size_t kr, size_t sr) { + const size_t k_internal = kai_k_roundedup(k, kr, sr); + KAI_ASSERT((k_internal % 2) == 0); + + return nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias); +} + +size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon( + size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr) { + KAI_ASSERT((n_idx % nr) == 0); + + return (n_idx / nr) * kai_get_rhs_packed_stride_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon(k, nr, kr, sr); +} + +size_t kai_get_rhs_packed_size_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon( + size_t n, size_t k, size_t nr, size_t kr, size_t sr) { + const size_t num_rows = kai_roundup(n, nr) / nr; + + return num_rows * kai_get_rhs_packed_stride_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon(k, nr, kr, sr); +} + +void kai_run_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon( + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const void* zero, + const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, + const struct kai_rhs_pack_nxk_qai4cxp_params* params) { + KAI_ASSERT(num_groups == 1); + KAI_ASSERT(extra_bytes == 0); + KAI_ASSERT((kr % sr) == 0); + KAI_ASSERT(rhs != NULL); + KAI_ASSERT(scale != NULL); + KAI_ASSERT(rhs_packed != NULL); + KAI_ASSERT(params != NULL); + + KAI_ASSERT(params->rhs_zero_point == 8); + KAI_ASSERT(params->lhs_zero_point == 1); + + const int32_t rhs_zero_point = params->rhs_zero_point; + const size_t rhs_stride = kai_roundup(k, 2) / 2; + const size_t rhs_packed_stride = + kai_get_rhs_packed_stride_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon(k, nr, kr, sr); + + const size_t k_internal = kai_k_roundedup(k, kr, sr); + const size_t dst_num_rows = kai_roundup(n, nr) / nr; + const size_t dst_num_bytes_per_row = nr * (k_internal / 2); + const size_t block_length_in_bytes = kr / sr; + const size_t k_interleaved_v = 16U; + for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; ++dst_row_idx) { + uint8_t* dst_row = (uint8_t*)rhs_packed + dst_row_idx * rhs_packed_stride; + for (size_t dst_byte_idx = 0; dst_byte_idx < dst_num_bytes_per_row; ++dst_byte_idx) { + const size_t block_idx = dst_byte_idx / block_length_in_bytes; + const size_t block_byte_idx = dst_byte_idx % block_length_in_bytes; + const size_t super_block_idx = block_idx / nr; + const size_t nr_idx = block_idx % nr; + const size_t k_adjustment = + ((block_byte_idx + super_block_idx * block_length_in_bytes) / k_interleaved_v) * k_interleaved_v; + const size_t k0_idx = block_byte_idx + super_block_idx * block_length_in_bytes + k_adjustment; + const size_t k1_idx = k0_idx + 
k_interleaved_v; + const size_t n0_idx = dst_row_idx * nr + nr_idx; + // Clamp the index to avoid out-of-bound reads + const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1); + const size_t src_addr_byte0 = (k0_idx / 2) + n0_valid_idx * rhs_stride; + const size_t src_addr_byte1 = (k1_idx / 2) + n0_valid_idx * rhs_stride; + uint8_t byte0 = rhs_zero_point | rhs_zero_point << 4; + uint8_t byte1 = rhs_zero_point | rhs_zero_point << 4; + if (k0_idx < k) { + byte0 = rhs[src_addr_byte0]; + } + if (k1_idx < k) { + byte1 = rhs[src_addr_byte1]; + } + const size_t shift_right_x0 = (k0_idx % 2) * 4; + const size_t shift_right_x1 = (k1_idx % 2) * 4; + const uint8_t src_x0_lo = (byte0 >> shift_right_x0) & 0x0F; + const uint8_t src_x0_hi = (byte1 >> shift_right_x1) & 0x0F; + const uint8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4); + *dst_row = dst_qs0 ^ 0x88; + dst_row += sizeof(uint8_t); + } + // Adjust the zero point + if (zero == NULL) { + memset(dst_row, 0, nr * kai_num_bytes_bias); + dst_row += nr * sizeof(float); + } else { + for (size_t i = 0; i < nr; ++i) { + const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); + *((float*)(dst_row)) = *((const float*)zero + src_row_idx); + dst_row += sizeof(float); + } + } + // Adjust the scales + for (size_t i = 0; i < nr; ++i) { + // Clamp the row index to avoid out-of-bound reads + const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); + *((float*)(dst_row)) = *((const float*)scale + src_row_idx) * 0.0625F; + dst_row += sizeof(float); + } + // Set the bias + if (bias == NULL) { + memset(dst_row, 0, nr * kai_num_bytes_bias); + } else { + for (size_t i = 0; i < nr; ++i) { + // Clamp the row index to avoid out-of-bound reads + const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); + ((float*)dst_row)[i] = *((const float*)bias + src_row_idx); + } + } + } +} +#endif diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon.h b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon.h new file mode 100644 index 00000000..141eea79 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon.h @@ -0,0 +1,106 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include + +#include "kai/kai_common.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#ifndef kai_rhs_pack_nxk_qai4cxp_params +#define kai_rhs_pack_nxk_qai4cxp_params kai_rhs_pack_qs4cxs1s0_param +#endif + +/// Gets the offset in bytes for the RHS matrix (not packed). +/// +/// @note The int4 values are stored in a N x K matrix. Two int4 values are stored in one byte. +/// The lower order part of the byte (low) holds the first nibble (K-index + 0). +/// The higher order of the byte holds the second nibble (K-index + 1). +/// +/// @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of n_step. +/// @param[in] rhs_stride The number of bytes in in each row of the RHS matrix (not packed) +/// +/// @return the offset in bytes to the RHS matrix (not packed) +size_t kai_get_rhs_offset_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon(size_t n_idx, size_t rhs_stride); + +/// Get the row stride in bytes to the packed RHS matrix +/// +/// @param[in] k In the RHS matrix (not packed), K is the number of columns. +/// @param[in] nr The number of columns written by the matmul micro-kernel. 
+/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel.
+/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr.
+///
+/// @return the stride in bytes to the packed RHS matrix
+size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon(
+    size_t k, size_t nr, size_t kr, size_t sr);
+
+/// Gets the offset in bytes for the packed RHS matrix, which contains the packed 4-bit quantized asymmetric
+/// per-channel (qai4cx) values.
+///
+/// @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of n_step.
+/// @param[in] k In the RHS matrix (not packed), K is the number of columns.
+/// @param[in] nr The number of columns written by the matmul micro-kernel.
+/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel.
+/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr.
+///
+/// @return the offset in bytes to the packed RHS matrix
+size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon(
+    size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr);
+
+/// @brief Gets the size in bytes for the packed RHS matrix
+///
+/// @param[in] n The number of rows in the RHS matrix (not packed).
+/// @param[in] k The number of columns in the RHS matrix (not packed).
+/// @param[in] nr The number of columns written by the matmul micro-kernel.
+/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel.
+/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr.
+///
+/// @return the packed RHS matrix size in bytes
+size_t kai_get_rhs_packed_size_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon(
+    size_t n, size_t k, size_t nr, size_t kr, size_t sr);
+
+/// Run the micro-kernel to pack the RHS matrix.
+///
+/// @note The int4 values are stored in a N x K matrix. Two int4 values are stored in one byte.
+///       The lower order part of the byte (low) holds the first nibble (K-index + 0).
+///       The higher order of the byte holds the second nibble (K-index + 1).
+///
+/// @param[in] num_groups The number of groups. It must be 1.
+/// @param[in] n The number of rows.
+/// @param[in] k The common dimension between the LHS and RHS matrix (K). It must be an even value.
+/// @param[in] nr The number of N rows to interleave on the same output row.
+/// @param[in] kr The number of K values loaded in the single inner most loop of the matmul micro-kernel.
+/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr.
+///                However, kr must be a multiple of sr.
+/// @param[in] rhs The RHS matrix containing the 4-bit values.
+///                Size in bytes is expected to be greater than or equal to (n * k) / 2.
+/// @param[in] zero The zero points (one float per output channel).
+/// @param[in] bias The biases.
+/// @param[in] scale The scale for each output channel.
+/// @param[out] rhs_packed The packed RHS matrix.
+/// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix.
+/// @param[in] params Parameters for the micro-kernel.
+void kai_run_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon( + size_t num_groups, // + size_t n, // + size_t k, // + size_t nr, // + size_t kr, // + size_t sr, // + const uint8_t* rhs, // + const void* zero, // + const void* bias, // + const void* scale, // + void* rhs_packed, // + size_t extra_bytes, // + const struct kai_rhs_pack_nxk_qai4cxp_params* params); + +#ifdef __cplusplus +} +#endif -- GitLab From c5b7493d42accbd32f55c8de72c8df010395b821 Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Thu, 27 Mar 2025 15:45:08 +0000 Subject: [PATCH 04/15] Add kai_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm kernel Signed-off-by: Anitha Raj --- ..._f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm.c | 529 ++++++++++++++++++ ..._f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm.h | 139 +++++ ..._qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm_asm.S | 438 +++++++++++++++ 3 files changed, 1106 insertions(+) create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm.c create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm.h create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm_asm.S diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm.c new file mode 100644 index 00000000..1932f71e --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm.c @@ -0,0 +1,529 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#if !defined(__aarch64__) && !defined(__ARM_FEATURE_MATMUL_INT8) && !defined(_M_ARM64) +#error "I8mm extension required to compile this micro-kernel" +#else // Architectural features check. + +#include "kai_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm.h" + +#include +#include + +#include "kai/kai_common.h" + +typedef struct { + float* dst; + const void* lhs_packed; + const void* rhs_packed; + const float* clamp_vals; + size_t dst_stride_row; + size_t m; + size_t n; + size_t num_blocks; +} KernelArgs; + +void kai_kernel_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm(KernelArgs* args_ptr); + +// Compute args +static const size_t kai_m_step = 8; +static const size_t kai_n_step = 4; +// Packing args +static const size_t kai_mr = 4; +static const size_t kai_nr = 4; +static const size_t kai_kr = 16; +static const size_t kai_sr = 2; +// LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric)) +static const size_t kai_num_bytes_qvalue_lhs = 1; +static const size_t kai_num_bytes_multiplier_lhs = 4; +// RHS format args (num. 
bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is +// asymmetric)) +static const size_t kai_num_bytes_recip_qvalue_rhs = 2; +static const size_t kai_num_bytes_multiplier_rhs = 4; +// DST format args +static const size_t kai_num_bytes_dst_value = 4; +// Extra args +static const size_t kai_num_bytes_bias = 4; +static const size_t kai_k_multiple_of = 32; + +inline static size_t kai_get_k_roundedup(size_t k) { + return kai_roundup(k, kai_k_multiple_of); +} + +inline static size_t kai_get_lhs_packed_stride(size_t k) { + const size_t k_internal = kai_get_k_roundedup(k); + size_t lhs_packed_stride = kai_mr * ((k_internal * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs); + + return lhs_packed_stride; +} + +inline static size_t kai_get_rhs_packed_stride(size_t k) { + const size_t k_internal = kai_get_k_roundedup(k); + size_t rhs_packed_stride = kai_nr * (k_internal / kai_num_bytes_recip_qvalue_rhs); + + rhs_packed_stride += kai_nr * kai_num_bytes_multiplier_rhs; + // Since the bias is packed with the RHS matrix, the stride is adjusted with the number of bytes of the bias + rhs_packed_stride += kai_nr * kai_num_bytes_bias; + + return rhs_packed_stride; +} + +size_t kai_get_m_step_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm(void) { + return kai_m_step; +} + +size_t kai_get_n_step_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm(void) { + return kai_n_step; +} + +size_t kai_get_mr_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm(void) { + return kai_mr; +} + +size_t kai_get_nr_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm(void) { + return kai_nr; +} + +size_t kai_get_kr_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm(void) { + return kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm(size_t m_idx, size_t k) { + KAI_ASSUME((m_idx % kai_m_step) == 0); + + return (m_idx / kai_mr) * kai_get_lhs_packed_stride(k); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm(size_t n_idx, size_t k) { + KAI_ASSUME((n_idx % kai_n_step) == 0); + + return (n_idx / kai_nr) * kai_get_rhs_packed_stride(k); +} + +size_t kai_get_dst_offset_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSUME((m_idx % kai_m_step) == 0); + KAI_ASSUME((n_idx % kai_n_step) == 0); + + return (n_idx * kai_num_bytes_dst_value) + m_idx * dst_stride; +} + +size_t kai_get_dst_size_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm(size_t m, size_t n) { + return m * n * kai_num_bytes_dst_value; +} + +void kai_run_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm( + size_t m, // + size_t n, // + size_t k, // + const void* restrict lhs_packed, // + const void* restrict rhs_packed, // + float* restrict dst, // NOLINT(readability-non-const-parameter) + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max) { + KAI_ASSUME(dst_stride_col == sizeof(float)); + if (m == 0) { + return; + } + const size_t kai_bl = 32; + const size_t k_internal = kai_get_k_roundedup(k); + size_t num_blocks = k_internal / kai_bl; + const float clamp_vals[2] = {scalar_min, scalar_max}; + + __asm__ __volatile__( + "mov x12, %x[m]\n" + "mov x11, #0x80\n" + "movi v11.16b, #0xf0\n" + "mov x20, #0x20\n" + "cmp x12, #0x8\n" + "madd x11, %x[num_blocks], x11, x20\n" + "blt 10f\n" + "1:" // Row loop + "mov x10, 
%x[rhs_packed]\n" + "mov x9, %x[n]\n" + "add x28, %x[dst], %x[dst_stride_row], LSL #3\n" + "2:" // Column loop + "mov x22, %x[lhs_packed]\n" + "movi v10.4s, #0x0\n" + "movi v9.4s, #0x0\n" + "mov x21, %x[num_blocks]\n" + "movi v8.4s, #0x0\n" + "movi v7.4s, #0x0\n" + "movi v6.4s, #0x0\n" + "movi v5.4s, #0x0\n" + "add x20, x22, x11\n" + "movi v4.4s, #0x0\n" + "movi v3.4s, #0x0\n" + "3:" // Sub block loop + "ldr q2, [x10, #0x0]\n" + "ldr q1, [x10, #0x10]\n" + "subs x21, x21, #0x1\n" + "ldr q20, [x22, #0x0]\n" + "ldr q19, [x22, #0x10]\n" + "ldr q18, [x20, #0x0]\n" + "ldr q0, [x20, #0x10]\n" + "ldr q31, [x10, #0x20]\n" + "ldr q30, [x10, #0x30]\n" + "shl v17.16b, v2.16b, #0x4\n" + "shl v16.16b, v1.16b, #0x4\n" + "ldr q29, [x22, #0x20]\n" + "ldr q28, [x22, #0x30]\n" + "and v2.16b, v2.16b, v11.16b\n" + "and v1.16b, v1.16b, v11.16b\n" + "ldr q27, [x20, #0x20]\n" + "ldr q26, [x20, #0x30]\n" + "add x10, x10, #0x40\n" + "ldr q25, [x22, #0x40]\n" + "ldr q24, [x22, #0x50]\n" + ".inst 0x4e91a68a // smmla v10.4s, v20.16b, v17.16b\n" + ".inst 0x4e90a689 // smmla v9.4s, v20.16b, v16.16b\n" + "ldr q23, [x20, #0x40]\n" + "ldr q22, [x20, #0x50]\n" + ".inst 0x4e91a668 // smmla v8.4s, v19.16b, v17.16b\n" + ".inst 0x4e90a667 // smmla v7.4s, v19.16b, v16.16b\n" + "ldr q21, [x22, #0x60]\n" + "ldr q20, [x22, #0x70]\n" + ".inst 0x4e91a646 // smmla v6.4s, v18.16b, v17.16b\n" + ".inst 0x4e90a645 // smmla v5.4s, v18.16b, v16.16b\n" + "ldr q19, [x20, #0x60]\n" + "ldr q18, [x20, #0x70]\n" + ".inst 0x4e91a404 // smmla v4.4s, v0.16b, v17.16b\n" + ".inst 0x4e90a403 // smmla v3.4s, v0.16b, v16.16b\n" + "shl v17.16b, v31.16b, #0x4\n" + "shl v16.16b, v30.16b, #0x4\n" + "add x22, x22, #0x80\n" + "add x20, x20, #0x80\n" + "and v31.16b, v31.16b, v11.16b\n" + "and v30.16b, v30.16b, v11.16b\n" + ".inst 0x4e91a7aa // smmla v10.4s, v29.16b, v17.16b\n" + ".inst 0x4e90a7a9 // smmla v9.4s, v29.16b, v16.16b\n" + ".inst 0x4e91a788 // smmla v8.4s, v28.16b, v17.16b\n" + ".inst 0x4e90a787 // smmla v7.4s, v28.16b, v16.16b\n" + ".inst 0x4e91a766 // smmla v6.4s, v27.16b, v17.16b\n" + ".inst 0x4e90a765 // smmla v5.4s, v27.16b, v16.16b\n" + ".inst 0x4e91a744 // smmla v4.4s, v26.16b, v17.16b\n" + ".inst 0x4e90a743 // smmla v3.4s, v26.16b, v16.16b\n" + ".inst 0x4e82a72a // smmla v10.4s, v25.16b, v2.16b\n" + ".inst 0x4e81a729 // smmla v9.4s, v25.16b, v1.16b\n" + ".inst 0x4e82a708 // smmla v8.4s, v24.16b, v2.16b\n" + ".inst 0x4e81a707 // smmla v7.4s, v24.16b, v1.16b\n" + ".inst 0x4e82a6e6 // smmla v6.4s, v23.16b, v2.16b\n" + ".inst 0x4e81a6e5 // smmla v5.4s, v23.16b, v1.16b\n" + ".inst 0x4e82a6c4 // smmla v4.4s, v22.16b, v2.16b\n" + ".inst 0x4e81a6c3 // smmla v3.4s, v22.16b, v1.16b\n" + ".inst 0x4e9fa6aa // smmla v10.4s, v21.16b, v31.16b\n" + ".inst 0x4e9ea6a9 // smmla v9.4s, v21.16b, v30.16b\n" + ".inst 0x4e9fa688 // smmla v8.4s, v20.16b, v31.16b\n" + ".inst 0x4e9ea687 // smmla v7.4s, v20.16b, v30.16b\n" + ".inst 0x4e9fa666 // smmla v6.4s, v19.16b, v31.16b\n" + ".inst 0x4e9ea665 // smmla v5.4s, v19.16b, v30.16b\n" + ".inst 0x4e9fa644 // smmla v4.4s, v18.16b, v31.16b\n" + ".inst 0x4e9ea643 // smmla v3.4s, v18.16b, v30.16b\n" + "bgt 3b\n" + "ldr q24, [x10, #0x10]\n" + "add x22, x22, #0x10\n" + "uzp1 v23.2d, v10.2d, v9.2d\n" + "uzp2 v22.2d, v10.2d, v9.2d\n" + "ldr q16, [x22, #0x0]\n" + "uzp1 v21.2d, v8.2d, v7.2d\n" + "uzp2 v20.2d, v8.2d, v7.2d\n" + "add x10, x10, #0x20\n" + "scvtf v23.4s, v23.4s\n" + "scvtf v22.4s, v22.4s\n" + "fmul v19.4s, v24.4s, v16.s[0]\n" + "fmul v18.4s, v24.4s, v16.s[1]\n" + "fmul v17.4s, v24.4s, v16.s[2]\n" + "scvtf v21.4s, 
v21.4s\n" + "fmul v16.4s, v24.4s, v16.s[3]\n" + "scvtf v20.4s, v20.4s\n" + "fmul v10.4s, v23.4s, v19.4s\n" + "fmul v9.4s, v22.4s, v18.4s\n" + "fmul v8.4s, v21.4s, v17.4s\n" + "fmul v7.4s, v20.4s, v16.4s\n" + "add x20, x20, #0x10\n" + "uzp1 v23.2d, v6.2d, v5.2d\n" + "uzp2 v22.2d, v6.2d, v5.2d\n" + "ldr q16, [x20, #0x0]\n" + "uzp1 v21.2d, v4.2d, v3.2d\n" + "uzp2 v20.2d, v4.2d, v3.2d\n" + "fmul v19.4s, v24.4s, v16.s[0]\n" + "scvtf v23.4s, v23.4s\n" + "fmul v18.4s, v24.4s, v16.s[1]\n" + "scvtf v22.4s, v22.4s\n" + "fmul v17.4s, v24.4s, v16.s[2]\n" + "scvtf v21.4s, v21.4s\n" + "fmul v16.4s, v24.4s, v16.s[3]\n" + "scvtf v20.4s, v20.4s\n" + "fmul v6.4s, v23.4s, v19.4s\n" + "fmul v5.4s, v22.4s, v18.4s\n" + "fmul v4.4s, v21.4s, v17.4s\n" + "fmul v3.4s, v20.4s, v16.4s\n" + "ldr q18, [x10, #0x0]\n" + "ld1r { v17.4s }, [%x[clamp_vals]]\n" + "add x20, %x[clamp_vals], #0x4\n" + "cmp x9, #0x4\n" + "ld1r { v16.4s }, [x20]\n" + "add x10, x10, #0x10\n" + "fadd v10.4s, v10.4s, v18.4s\n" + "fadd v9.4s, v9.4s, v18.4s\n" + "fadd v8.4s, v8.4s, v18.4s\n" + "fadd v7.4s, v7.4s, v18.4s\n" + "fadd v6.4s, v6.4s, v18.4s\n" + "fadd v5.4s, v5.4s, v18.4s\n" + "fadd v4.4s, v4.4s, v18.4s\n" + "fadd v3.4s, v3.4s, v18.4s\n" + "fmax v10.4s, v10.4s, v17.4s\n" + "fmax v9.4s, v9.4s, v17.4s\n" + "fmax v8.4s, v8.4s, v17.4s\n" + "fmax v7.4s, v7.4s, v17.4s\n" + "fmax v6.4s, v6.4s, v17.4s\n" + "fmax v5.4s, v5.4s, v17.4s\n" + "fmax v4.4s, v4.4s, v17.4s\n" + "fmax v3.4s, v3.4s, v17.4s\n" + "fmin v10.4s, v10.4s, v16.4s\n" + "fmin v9.4s, v9.4s, v16.4s\n" + "fmin v8.4s, v8.4s, v16.4s\n" + "fmin v7.4s, v7.4s, v16.4s\n" + "fmin v6.4s, v6.4s, v16.4s\n" + "fmin v5.4s, v5.4s, v16.4s\n" + "fmin v4.4s, v4.4s, v16.4s\n" + "fmin v3.4s, v3.4s, v16.4s\n" + "blt 6f\n" + "mov x20, %x[dst]\n" + "str q10, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q9, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q8, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q7, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q6, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q5, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q4, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q3, [x20, #0x0]\n" + "b 9f\n" + "6:" // Partial output + "mov x27, %x[dst]\n" + "add x26, x27, %x[dst_stride_row], LSL #2\n" + "add x25, x26, %x[dst_stride_row], LSL #1\n" + "add x24, x26, %x[dst_stride_row]\n" + "add x23, x25, %x[dst_stride_row]\n" + "add x22, x27, %x[dst_stride_row], LSL #1\n" + "add x21, x27, %x[dst_stride_row]\n" + "add x20, x22, %x[dst_stride_row]\n" + "tbz x9, #1, 7f\n" + "st1 { v3.d }[0], [x23], #0x8\n" + "st1 { v4.d }[0], [x25], #0x8\n" + "st1 { v5.d }[0], [x24], #0x8\n" + "st1 { v6.d }[0], [x26], #0x8\n" + "st1 { v7.d }[0], [x20], #0x8\n" + "st1 { v8.d }[0], [x22], #0x8\n" + "st1 { v9.d }[0], [x21], #0x8\n" + "st1 { v10.d }[0], [x27], #0x8\n" + "tbz x9, #0, 8f\n" + "st1 { v3.s }[2], [x23]\n" + "st1 { v4.s }[2], [x25]\n" + "st1 { v5.s }[2], [x24]\n" + "st1 { v6.s }[2], [x26]\n" + "st1 { v7.s }[2], [x20]\n" + "st1 { v8.s }[2], [x22]\n" + "st1 { v9.s }[2], [x21]\n" + "st1 { v10.s }[2], [x27]\n" + "b 8f\n" + "7:" // Output block 0: partial_1_0 + "st1 { v3.s }[0], [x23]\n" + "st1 { v4.s }[0], [x25]\n" + "st1 { v5.s }[0], [x24]\n" + "st1 { v6.s }[0], [x26]\n" + "st1 { v7.s }[0], [x20]\n" + "st1 { v8.s }[0], [x22]\n" + "st1 { v9.s }[0], [x21]\n" + "st1 { v10.s }[0], [x27]\n" + "8:" // Output block 0: Done + "9:" // Output stage exit + "subs x9, x9, #0x4\n" + "add %x[dst], %x[dst], #0x10\n" + 
"bgt 2b\n" + "mov x20, #0x2\n" + "sub x12, x12, #0x8\n" + "cmp x12, #0x8\n" + "mov %x[dst], x28\n" + "madd %x[lhs_packed], x20, x11, %x[lhs_packed]\n" + "bge 1b\n" + "10:" // Row loop skip + "cbz x12, 19f\n" + "11:" // Row tail: Row loop + "mov x26, %x[rhs_packed]\n" + "mov x25, %x[n]\n" + "add x24, %x[dst], %x[dst_stride_row], LSL #2\n" + "12:" // Row tail: Column loop + "mov x22, %x[lhs_packed]\n" + "movi v10.4s, #0x0\n" + "movi v9.4s, #0x0\n" + "mov x20, %x[num_blocks]\n" + "movi v8.4s, #0x0\n" + "movi v7.4s, #0x0\n" + "13:" // Row tail: Sub block loop + "ldr q31, [x26, #0x0]\n" + "ldr q30, [x26, #0x10]\n" + "subs x20, x20, #0x1\n" + "ldr q29, [x22, #0x0]\n" + "ldr q28, [x22, #0x10]\n" + "ldr q27, [x26, #0x20]\n" + "ldr q26, [x26, #0x30]\n" + "add x26, x26, #0x40\n" + "ldr q25, [x22, #0x20]\n" + "ldr q24, [x22, #0x30]\n" + "shl v23.16b, v31.16b, #0x4\n" + "shl v22.16b, v30.16b, #0x4\n" + "ldr q21, [x22, #0x40]\n" + "ldr q20, [x22, #0x50]\n" + "and v31.16b, v31.16b, v11.16b\n" + "and v30.16b, v30.16b, v11.16b\n" + "ldr q19, [x22, #0x60]\n" + "ldr q18, [x22, #0x70]\n" + "shl v17.16b, v27.16b, #0x4\n" + "shl v16.16b, v26.16b, #0x4\n" + ".inst 0x4e97a7aa // smmla v10.4s, v29.16b, v23.16b\n" + ".inst 0x4e96a7a9 // smmla v9.4s, v29.16b, v22.16b\n" + "and v27.16b, v27.16b, v11.16b\n" + "add x22, x22, #0x80\n" + ".inst 0x4e97a788 // smmla v8.4s, v28.16b, v23.16b\n" + ".inst 0x4e96a787 // smmla v7.4s, v28.16b, v22.16b\n" + "and v26.16b, v26.16b, v11.16b\n" + ".inst 0x4e91a72a // smmla v10.4s, v25.16b, v17.16b\n" + ".inst 0x4e90a729 // smmla v9.4s, v25.16b, v16.16b\n" + ".inst 0x4e91a708 // smmla v8.4s, v24.16b, v17.16b\n" + ".inst 0x4e90a707 // smmla v7.4s, v24.16b, v16.16b\n" + ".inst 0x4e9fa6aa // smmla v10.4s, v21.16b, v31.16b\n" + ".inst 0x4e9ea6a9 // smmla v9.4s, v21.16b, v30.16b\n" + ".inst 0x4e9fa688 // smmla v8.4s, v20.16b, v31.16b\n" + ".inst 0x4e9ea687 // smmla v7.4s, v20.16b, v30.16b\n" + ".inst 0x4e9ba66a // smmla v10.4s, v19.16b, v27.16b\n" + ".inst 0x4e9aa669 // smmla v9.4s, v19.16b, v26.16b\n" + ".inst 0x4e9ba648 // smmla v8.4s, v18.16b, v27.16b\n" + ".inst 0x4e9aa647 // smmla v7.4s, v18.16b, v26.16b\n" + "bgt 13b\n" + "ldr q24, [x26, #0x10]\n" + "add x22, x22, #0x10\n" + "uzp1 v23.2d, v10.2d, v9.2d\n" + "uzp2 v22.2d, v10.2d, v9.2d\n" + "ldr q16, [x22, #0x0]\n" + "uzp1 v21.2d, v8.2d, v7.2d\n" + "uzp2 v20.2d, v8.2d, v7.2d\n" + "add x26, x26, #0x20\n" + "scvtf v23.4s, v23.4s\n" + "scvtf v22.4s, v22.4s\n" + "fmul v19.4s, v24.4s, v16.s[0]\n" + "fmul v18.4s, v24.4s, v16.s[1]\n" + "fmul v17.4s, v24.4s, v16.s[2]\n" + "scvtf v21.4s, v21.4s\n" + "fmul v16.4s, v24.4s, v16.s[3]\n" + "scvtf v20.4s, v20.4s\n" + "fmul v10.4s, v23.4s, v19.4s\n" + "fmul v9.4s, v22.4s, v18.4s\n" + "fmul v8.4s, v21.4s, v17.4s\n" + "fmul v7.4s, v20.4s, v16.4s\n" + "ldr q18, [x26, #0x0]\n" + "ld1r { v17.4s }, [%x[clamp_vals]]\n" + "add x20, %x[clamp_vals], #0x4\n" + "cmp x25, #0x4\n" + "ld1r { v16.4s }, [x20]\n" + "add x26, x26, #0x10\n" + "fadd v10.4s, v10.4s, v18.4s\n" + "fadd v9.4s, v9.4s, v18.4s\n" + "fadd v8.4s, v8.4s, v18.4s\n" + "fadd v7.4s, v7.4s, v18.4s\n" + "fmax v10.4s, v10.4s, v17.4s\n" + "fmax v9.4s, v9.4s, v17.4s\n" + "fmax v8.4s, v8.4s, v17.4s\n" + "fmax v7.4s, v7.4s, v17.4s\n" + "fmin v10.4s, v10.4s, v16.4s\n" + "fmin v9.4s, v9.4s, v16.4s\n" + "fmin v8.4s, v8.4s, v16.4s\n" + "fmin v7.4s, v7.4s, v16.4s\n" + "blt 15f\n" + "mov x20, %x[dst]\n" + "cmp x12, #0x1\n" + "str q10, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 18f\n" + "cmp x12, #0x2\n" + "str q9, [x20, #0x0]\n" + "add x20, x20, 
%x[dst_stride_row]\n" + "ble 18f\n" + "cmp x12, #0x3\n" + "str q8, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 18f\n" + "str q7, [x20, #0x0]\n" + "b 18f\n" + "15:" // Row tail: Partial output + "mov x23, %x[dst]\n" + "cmp x12, #0x1\n" + "add x22, x23, %x[dst_stride_row]\n" + "csel x22, x22, x23, GT\n" + "cmp x12, #0x2\n" + "add x21, x23, %x[dst_stride_row], LSL #1\n" + "csel x21, x21, x22, GT\n" + "cmp x12, #0x3\n" + "add x20, x21, %x[dst_stride_row]\n" + "csel x20, x20, x21, GT\n" + "tbz x25, #1, 16f\n" + "st1 { v7.d }[0], [x20], #0x8\n" + "st1 { v8.d }[0], [x21], #0x8\n" + "st1 { v9.d }[0], [x22], #0x8\n" + "st1 { v10.d }[0], [x23], #0x8\n" + "tbz x25, #0, 17f\n" + "st1 { v7.s }[2], [x20]\n" + "st1 { v8.s }[2], [x21]\n" + "st1 { v9.s }[2], [x22]\n" + "st1 { v10.s }[2], [x23]\n" + "b 17f\n" + "16:" // Row tail: Output block 0: partial_1_0 + "st1 { v7.s }[0], [x20]\n" + "st1 { v8.s }[0], [x21]\n" + "st1 { v9.s }[0], [x22]\n" + "st1 { v10.s }[0], [x23]\n" + "17:" // Row tail: Output block 0: Done + "18:" // Row tail: Output stage exit + "subs x25, x25, #0x4\n" + "add %x[dst], %x[dst], #0x10\n" + "bgt 12b\n" + "subs x12, x12, #0x4\n" + "add %x[lhs_packed], %x[lhs_packed], x11\n" + "mov %x[dst], x24\n" + "bgt 11b\n" + "19:" // Row tail: Row loop skip + : [dst] "+&r"(dst), [lhs_packed] "+&r"(lhs_packed) + : [clamp_vals] "r"(clamp_vals), [dst_stride_row] "r"(dst_stride_row), [m] "r"(m), [n] "r"(n), + [num_blocks] "r"(num_blocks), [rhs_packed] "r"(rhs_packed) + : "cc", "memory", "v0", "v1", "v10", "v11", "v16", "v17", "v18", "v19", "v2", "v20", "v21", "v22", "v23", "v24", + "v25", "v26", "v27", "v28", "v29", "v3", "v30", "v31", "v4", "v5", "v6", "v7", "v8", "v9", "x10", "x11", + "x12", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "x9"); + // KernelArgs args; + + // args.dst = dst; + // args.lhs_packed = lhs_packed; + // args.rhs_packed = rhs_packed; + // args.clamp_vals = clamp_vals; + // args.dst_stride_row = dst_stride_row; + // args.m = m; + // args.n = n; + + // kai_kernel_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm(&args); +} + +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm.h b/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm.h new file mode 100644 index 00000000..b59c69e2 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm.h @@ -0,0 +1,139 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Micro-kernel dependencies +/// +/// -# @ref kai_lhs_quant_pack_qsi8dxp_f32 to dynamically quantize and pack the LHS matrix in a single step. +/// -# @ref kai_rhs_pack_nxk_qai4cxp_ to pack the RHS NxK matrix. +/// -# @ref kai_rhs_pack_kxn_qai4cxp_ to pack the RHS KxN matrix. + +/// -------------------------------------------------- + +/// Gets the m step value. +/// The micro-kernel can process any M values. However, the starting M index to +/// be processed must be a multiple of m step. +/// +/// @return the m step value +size_t kai_get_m_step_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm(void); + +/// Gets the n step value. +/// The micro-kernel can process any N values. 
However, the starting N index to
+/// be processed must be a multiple of n step.
+///
+/// @return the n step
+size_t kai_get_n_step_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm(void);
+
+/// Gets the mr value, which must be used to pack the LHS matrix
+///
+/// @return the mr value
+size_t kai_get_mr_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm(void);
+
+/// Gets the nr value, which must be used to pack the RHS matrix.
+///
+/// @return the nr value
+size_t kai_get_nr_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm(void);
+
+/// Gets the kr value, which must be used to pack the LHS and RHS matrices
+///
+/// @return the kr value
+size_t kai_get_kr_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm(void);
+
+/// Gets the sr value, which must be used to pack the LHS and RHS matrices
+///
+/// @return the sr value
+size_t kai_get_sr_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm(void);
+
+/// Gets the offset in bytes for the packed LHS matrix,
+/// which contains the packed Quantized Symmetric Signed 8-bit with per-row quantization (qsi8dx) values.
+///
+/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel.
+///
+/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of m_step.
+/// @param[in] k     Total number of columns in the LHS matrix (not packed).
+///
+/// @return the offset in bytes to the packed LHS matrix
+size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm(
+    size_t m_idx,  //
+    size_t k);     //
+
+/// Gets the offset in bytes for the packed RHS matrix,
+/// which contains the packed Quantized Asymmetric Signed 4-bit with per-channel quantization (qai4cx) values.
+///
+/// @param[in] n_idx Column index in the RHS matrix (not packed). It must be a multiple of n_step.
+/// @param[in] k     The common dimension between the LHS and RHS matrix (K).
+///
+/// @return the offset in bytes to the packed RHS matrix
+size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm(
+    size_t n_idx,  //
+    size_t k);     //
+
+/// Gets the offset in bytes for the DST matrix
+///
+/// @param[in] m_idx      Row index in the DST matrix. It must be a multiple of m_step.
+/// @param[in] n_idx      Column index in the DST matrix. It must be a multiple of n_step.
+/// @param[in] dst_stride The number of bytes in each row of the DST matrix
+///
+/// @return the DST offset in bytes
+size_t kai_get_dst_offset_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm(
+    size_t m_idx,        //
+    size_t n_idx,        //
+    size_t dst_stride);  //
+
+/// Gets the size in bytes for the destination (DST) matrix.
+///
+/// @param[in] m Number of rows in the destination (DST) matrix.
+/// @param[in] n Number of columns in the destination (DST) matrix.
+///
+/// @return the destination (DST) matrix size in bytes
+size_t kai_get_dst_size_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm(
+    size_t m,   //
+    size_t n);  //
+
+/// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation.
+///
+/// LHS matrix: Quantized Symmetric Signed 8-bit with per-row quantization (qsi8dx) and packed.
+/// RHS matrix: Quantized Asymmetric Signed 4-bit with per-channel quantization (qai4cx) and packed.
+/// Output tile: (rows x cols) = m_step x n_step.
+///
+/// Note: Please refer to the get functions for m_step and n_step for the exact values.
+///
+/// Features used: i8mm
+///
+/// @param[in] m The number of output rows written.
+/// @param[in] n The number of output columns written. +/// @param[in] k The number of channels. The common dimension between the LHS and RHS matrix. +/// @param[in] lhs_packed The LHS packed matrix. The micro-kernel to pack the native LHS matrix is reported at the +/// top of this file. +/// @param[in] rhs_packed The RHS packed matrix. The micro-kernel to pack the native RHS matrix is reported at the +/// top of this file. +/// @param[out] dst The DST matrix. +/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. +/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(float) bytes. +/// @param[in] scalar_min Min value used to clamp the final result. +/// @param[in] scalar_max Max value used to clamp the final result. +void kai_run_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm( + size_t m, // + size_t n, // + size_t k, // + const void* lhs_packed, // + const void* rhs_packed, // + float* dst, // + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max); // + +#ifdef __cplusplus +} +#endif // __cplusplus diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm_asm.S new file mode 100644 index 00000000..9d867705 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm_asm.S @@ -0,0 +1,438 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#if defined(_MSC_VER) + #define KAI_ASM_GLOBAL(name) GLOBAL name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) name PROC + #define KAI_ASM_FUNCTION_END(name) ENDP + + #define KAI_ASM_CODE(name) AREA name, CODE, READONLY + #define KAI_ASM_ALIGN + #define KAI_ASM_LABEL(name) name + #define KAI_ASM_INST(hex) DCD hex + #define KAI_ASM_END END +#else + #if defined(__APPLE__) + #define KAI_ASM_GLOBAL(name) .globl _##name + #define KAI_ASM_FUNCTION_TYPE(name) + #define KAI_ASM_FUNCTION_LABEL(name) _##name: + #define KAI_ASM_FUNCTION_END(name) + #else + #define KAI_ASM_GLOBAL(name) .global name + #define KAI_ASM_FUNCTION_TYPE(name) .type name, %function + #define KAI_ASM_FUNCTION_LABEL(name) name: + #define KAI_ASM_FUNCTION_END(name) .size name, .-name + #endif + + #define KAI_ASM_CODE(name) .text + #define KAI_ASM_ALIGN .p2align 4,,11 + #define KAI_ASM_LABEL(name) name: + #define KAI_ASM_INST(hex) .inst hex + #define KAI_ASM_END +#endif + + KAI_ASM_CODE(matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm) + KAI_ASM_ALIGN + + KAI_ASM_GLOBAL(kai_kernel_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm) + +KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm) +KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm) + stp x20, x21, [sp, -112]! 
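+    // Prologue (continued): x20-x28 and d8-d11 are callee-saved under AAPCS64 and
+    // are spilled here before the kernel body clobbers them; x0 carries the pointer
+    // to the KernelArgs structure, whose fields are loaded into registers below.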
+ stp x22, x23, [sp, 16] + stp x24, x25, [sp, 32] + stp x26, x27, [sp, 48] + str x28, [sp, 64] + stp d10, d11, [sp, 72] + stp d8, d9, [sp, 88] + mov x7, #0x80 + movi v11.16b, #0xf0 + mov x21, #0x20 + ldr x20, [x0, #0x28] + ldr x8, [x0, #0x38] + ldr x17, [x0, #0x8] + ldr x16, [x0, #0x10] + ldr x15, [x0, #0x30] + ldr x14, [x0, #0x0] + mov x13, x20 + ldr x12, [x0, #0x20] + madd x7, x8, x7, x21 + ldr x11, [x0, #0x18] + cmp x13, #0x8 + blt label_10 +KAI_ASM_LABEL(label_1) // Row loop + mov x10, x16 + mov x9, x15 + add x28, x14, x12, LSL #3 +KAI_ASM_LABEL(label_2) // Column loop + mov x22, x17 + movi v10.4s, #0x0 + movi v9.4s, #0x0 + mov x21, x8 + movi v8.4s, #0x0 + movi v7.4s, #0x0 + movi v6.4s, #0x0 + movi v5.4s, #0x0 + add x20, x22, x7 + movi v4.4s, #0x0 + movi v3.4s, #0x0 +KAI_ASM_LABEL(label_3) // Sub block loop + ldr q2, [x10, #0x0] + ldr q1, [x10, #0x10] + subs x21, x21, #0x1 + ldr q20, [x22, #0x0] + ldr q19, [x22, #0x10] + ldr q18, [x20, #0x0] + ldr q0, [x20, #0x10] + ldr q31, [x10, #0x20] + ldr q30, [x10, #0x30] + shl v17.16b, v2.16b, #0x4 + shl v16.16b, v1.16b, #0x4 + ldr q29, [x22, #0x20] + ldr q28, [x22, #0x30] + and v2.16b, v2.16b, v11.16b + and v1.16b, v1.16b, v11.16b + ldr q27, [x20, #0x20] + ldr q26, [x20, #0x30] + add x10, x10, #0x40 + ldr q25, [x22, #0x40] + ldr q24, [x22, #0x50] + KAI_ASM_INST(0x4e91a68a) // smmla v10.4s, v20.16b, v17.16b + KAI_ASM_INST(0x4e90a689) // smmla v9.4s, v20.16b, v16.16b + ldr q23, [x20, #0x40] + ldr q22, [x20, #0x50] + KAI_ASM_INST(0x4e91a668) // smmla v8.4s, v19.16b, v17.16b + KAI_ASM_INST(0x4e90a667) // smmla v7.4s, v19.16b, v16.16b + ldr q21, [x22, #0x60] + ldr q20, [x22, #0x70] + KAI_ASM_INST(0x4e91a646) // smmla v6.4s, v18.16b, v17.16b + KAI_ASM_INST(0x4e90a645) // smmla v5.4s, v18.16b, v16.16b + ldr q19, [x20, #0x60] + ldr q18, [x20, #0x70] + KAI_ASM_INST(0x4e91a404) // smmla v4.4s, v0.16b, v17.16b + KAI_ASM_INST(0x4e90a403) // smmla v3.4s, v0.16b, v16.16b + shl v17.16b, v31.16b, #0x4 + shl v16.16b, v30.16b, #0x4 + add x22, x22, #0x80 + add x20, x20, #0x80 + and v31.16b, v31.16b, v11.16b + and v30.16b, v30.16b, v11.16b + KAI_ASM_INST(0x4e91a7aa) // smmla v10.4s, v29.16b, v17.16b + KAI_ASM_INST(0x4e90a7a9) // smmla v9.4s, v29.16b, v16.16b + KAI_ASM_INST(0x4e91a788) // smmla v8.4s, v28.16b, v17.16b + KAI_ASM_INST(0x4e90a787) // smmla v7.4s, v28.16b, v16.16b + KAI_ASM_INST(0x4e91a766) // smmla v6.4s, v27.16b, v17.16b + KAI_ASM_INST(0x4e90a765) // smmla v5.4s, v27.16b, v16.16b + KAI_ASM_INST(0x4e91a744) // smmla v4.4s, v26.16b, v17.16b + KAI_ASM_INST(0x4e90a743) // smmla v3.4s, v26.16b, v16.16b + KAI_ASM_INST(0x4e82a72a) // smmla v10.4s, v25.16b, v2.16b + KAI_ASM_INST(0x4e81a729) // smmla v9.4s, v25.16b, v1.16b + KAI_ASM_INST(0x4e82a708) // smmla v8.4s, v24.16b, v2.16b + KAI_ASM_INST(0x4e81a707) // smmla v7.4s, v24.16b, v1.16b + KAI_ASM_INST(0x4e82a6e6) // smmla v6.4s, v23.16b, v2.16b + KAI_ASM_INST(0x4e81a6e5) // smmla v5.4s, v23.16b, v1.16b + KAI_ASM_INST(0x4e82a6c4) // smmla v4.4s, v22.16b, v2.16b + KAI_ASM_INST(0x4e81a6c3) // smmla v3.4s, v22.16b, v1.16b + KAI_ASM_INST(0x4e9fa6aa) // smmla v10.4s, v21.16b, v31.16b + KAI_ASM_INST(0x4e9ea6a9) // smmla v9.4s, v21.16b, v30.16b + KAI_ASM_INST(0x4e9fa688) // smmla v8.4s, v20.16b, v31.16b + KAI_ASM_INST(0x4e9ea687) // smmla v7.4s, v20.16b, v30.16b + KAI_ASM_INST(0x4e9fa666) // smmla v6.4s, v19.16b, v31.16b + KAI_ASM_INST(0x4e9ea665) // smmla v5.4s, v19.16b, v30.16b + KAI_ASM_INST(0x4e9fa644) // smmla v4.4s, v18.16b, v31.16b + KAI_ASM_INST(0x4e9ea643) // smmla v3.4s, v18.16b, v30.16b + bgt label_3 + 
ldr q24, [x10, #0x10] + add x22, x22, #0x10 + uzp1 v23.2d, v10.2d, v9.2d + uzp2 v22.2d, v10.2d, v9.2d + ldr q16, [x22, #0x0] + uzp1 v21.2d, v8.2d, v7.2d + uzp2 v20.2d, v8.2d, v7.2d + add x10, x10, #0x20 + scvtf v23.4s, v23.4s + scvtf v22.4s, v22.4s + fmul v19.4s, v24.4s, v16.s[0] + fmul v18.4s, v24.4s, v16.s[1] + fmul v17.4s, v24.4s, v16.s[2] + scvtf v21.4s, v21.4s + fmul v16.4s, v24.4s, v16.s[3] + scvtf v20.4s, v20.4s + fmul v10.4s, v23.4s, v19.4s + fmul v9.4s, v22.4s, v18.4s + fmul v8.4s, v21.4s, v17.4s + fmul v7.4s, v20.4s, v16.4s + add x20, x20, #0x10 + uzp1 v23.2d, v6.2d, v5.2d + uzp2 v22.2d, v6.2d, v5.2d + ldr q16, [x20, #0x0] + uzp1 v21.2d, v4.2d, v3.2d + uzp2 v20.2d, v4.2d, v3.2d + fmul v19.4s, v24.4s, v16.s[0] + scvtf v23.4s, v23.4s + fmul v18.4s, v24.4s, v16.s[1] + scvtf v22.4s, v22.4s + fmul v17.4s, v24.4s, v16.s[2] + scvtf v21.4s, v21.4s + fmul v16.4s, v24.4s, v16.s[3] + scvtf v20.4s, v20.4s + fmul v6.4s, v23.4s, v19.4s + fmul v5.4s, v22.4s, v18.4s + fmul v4.4s, v21.4s, v17.4s + fmul v3.4s, v20.4s, v16.4s + ldr q18, [x10, #0x0] + ld1r { v17.4s }, [x11] + add x20, x11, #0x4 + cmp x9, #0x4 + ld1r { v16.4s }, [x20] + add x10, x10, #0x10 + fadd v10.4s, v10.4s, v18.4s + fadd v9.4s, v9.4s, v18.4s + fadd v8.4s, v8.4s, v18.4s + fadd v7.4s, v7.4s, v18.4s + fadd v6.4s, v6.4s, v18.4s + fadd v5.4s, v5.4s, v18.4s + fadd v4.4s, v4.4s, v18.4s + fadd v3.4s, v3.4s, v18.4s + fmax v10.4s, v10.4s, v17.4s + fmax v9.4s, v9.4s, v17.4s + fmax v8.4s, v8.4s, v17.4s + fmax v7.4s, v7.4s, v17.4s + fmax v6.4s, v6.4s, v17.4s + fmax v5.4s, v5.4s, v17.4s + fmax v4.4s, v4.4s, v17.4s + fmax v3.4s, v3.4s, v17.4s + fmin v10.4s, v10.4s, v16.4s + fmin v9.4s, v9.4s, v16.4s + fmin v8.4s, v8.4s, v16.4s + fmin v7.4s, v7.4s, v16.4s + fmin v6.4s, v6.4s, v16.4s + fmin v5.4s, v5.4s, v16.4s + fmin v4.4s, v4.4s, v16.4s + fmin v3.4s, v3.4s, v16.4s + blt label_6 + mov x20, x14 + str q10, [x20, #0x0] + add x20, x20, x12 + str q9, [x20, #0x0] + add x20, x20, x12 + str q8, [x20, #0x0] + add x20, x20, x12 + str q7, [x20, #0x0] + add x20, x20, x12 + str q6, [x20, #0x0] + add x20, x20, x12 + str q5, [x20, #0x0] + add x20, x20, x12 + str q4, [x20, #0x0] + add x20, x20, x12 + str q3, [x20, #0x0] + b label_9 +KAI_ASM_LABEL(label_6) // Partial output + mov x27, x14 + add x26, x27, x12, LSL #2 + add x25, x26, x12, LSL #1 + add x24, x26, x12 + add x23, x25, x12 + add x22, x27, x12, LSL #1 + add x21, x27, x12 + add x20, x22, x12 + tbz x9, #1, label_7 + st1 { v3.d }[0], [x23], #0x8 + st1 { v4.d }[0], [x25], #0x8 + st1 { v5.d }[0], [x24], #0x8 + st1 { v6.d }[0], [x26], #0x8 + st1 { v7.d }[0], [x20], #0x8 + st1 { v8.d }[0], [x22], #0x8 + st1 { v9.d }[0], [x21], #0x8 + st1 { v10.d }[0], [x27], #0x8 + tbz x9, #0, label_8 + st1 { v3.s }[2], [x23] + st1 { v4.s }[2], [x25] + st1 { v5.s }[2], [x24] + st1 { v6.s }[2], [x26] + st1 { v7.s }[2], [x20] + st1 { v8.s }[2], [x22] + st1 { v9.s }[2], [x21] + st1 { v10.s }[2], [x27] + b label_8 +KAI_ASM_LABEL(label_7) // Output block 0: partial_1_0 + st1 { v3.s }[0], [x23] + st1 { v4.s }[0], [x25] + st1 { v5.s }[0], [x24] + st1 { v6.s }[0], [x26] + st1 { v7.s }[0], [x20] + st1 { v8.s }[0], [x22] + st1 { v9.s }[0], [x21] + st1 { v10.s }[0], [x27] +KAI_ASM_LABEL(label_8) // Output block 0: Done +KAI_ASM_LABEL(label_9) // Output stage exit + subs x9, x9, #0x4 + add x14, x14, #0x10 + bgt label_2 + mov x20, #0x2 + sub x13, x13, #0x8 + cmp x13, #0x8 + mov x14, x28 + madd x17, x20, x7, x17 + bge label_1 +KAI_ASM_LABEL(label_10) // Row loop skip + cbz x13, label_19 +KAI_ASM_LABEL(label_11) // Row tail: Row loop + mov 
x26, x16 + mov x25, x15 + add x24, x14, x12, LSL #2 +KAI_ASM_LABEL(label_12) // Row tail: Column loop + mov x22, x17 + movi v10.4s, #0x0 + movi v9.4s, #0x0 + mov x20, x8 + movi v8.4s, #0x0 + movi v7.4s, #0x0 +KAI_ASM_LABEL(label_13) // Row tail: Sub block loop + ldr q31, [x26, #0x0] + ldr q30, [x26, #0x10] + subs x20, x20, #0x1 + ldr q29, [x22, #0x0] + ldr q28, [x22, #0x10] + ldr q27, [x26, #0x20] + ldr q26, [x26, #0x30] + add x26, x26, #0x40 + ldr q25, [x22, #0x20] + ldr q24, [x22, #0x30] + shl v23.16b, v31.16b, #0x4 + shl v22.16b, v30.16b, #0x4 + ldr q21, [x22, #0x40] + ldr q20, [x22, #0x50] + and v31.16b, v31.16b, v11.16b + and v30.16b, v30.16b, v11.16b + ldr q19, [x22, #0x60] + ldr q18, [x22, #0x70] + shl v17.16b, v27.16b, #0x4 + shl v16.16b, v26.16b, #0x4 + KAI_ASM_INST(0x4e97a7aa) // smmla v10.4s, v29.16b, v23.16b + KAI_ASM_INST(0x4e96a7a9) // smmla v9.4s, v29.16b, v22.16b + and v27.16b, v27.16b, v11.16b + add x22, x22, #0x80 + KAI_ASM_INST(0x4e97a788) // smmla v8.4s, v28.16b, v23.16b + KAI_ASM_INST(0x4e96a787) // smmla v7.4s, v28.16b, v22.16b + and v26.16b, v26.16b, v11.16b + KAI_ASM_INST(0x4e91a72a) // smmla v10.4s, v25.16b, v17.16b + KAI_ASM_INST(0x4e90a729) // smmla v9.4s, v25.16b, v16.16b + KAI_ASM_INST(0x4e91a708) // smmla v8.4s, v24.16b, v17.16b + KAI_ASM_INST(0x4e90a707) // smmla v7.4s, v24.16b, v16.16b + KAI_ASM_INST(0x4e9fa6aa) // smmla v10.4s, v21.16b, v31.16b + KAI_ASM_INST(0x4e9ea6a9) // smmla v9.4s, v21.16b, v30.16b + KAI_ASM_INST(0x4e9fa688) // smmla v8.4s, v20.16b, v31.16b + KAI_ASM_INST(0x4e9ea687) // smmla v7.4s, v20.16b, v30.16b + KAI_ASM_INST(0x4e9ba66a) // smmla v10.4s, v19.16b, v27.16b + KAI_ASM_INST(0x4e9aa669) // smmla v9.4s, v19.16b, v26.16b + KAI_ASM_INST(0x4e9ba648) // smmla v8.4s, v18.16b, v27.16b + KAI_ASM_INST(0x4e9aa647) // smmla v7.4s, v18.16b, v26.16b + bgt label_13 + ldr q24, [x26, #0x10] + add x22, x22, #0x10 + uzp1 v23.2d, v10.2d, v9.2d + uzp2 v22.2d, v10.2d, v9.2d + ldr q16, [x22, #0x0] + uzp1 v21.2d, v8.2d, v7.2d + uzp2 v20.2d, v8.2d, v7.2d + add x26, x26, #0x20 + scvtf v23.4s, v23.4s + scvtf v22.4s, v22.4s + fmul v19.4s, v24.4s, v16.s[0] + fmul v18.4s, v24.4s, v16.s[1] + fmul v17.4s, v24.4s, v16.s[2] + scvtf v21.4s, v21.4s + fmul v16.4s, v24.4s, v16.s[3] + scvtf v20.4s, v20.4s + fmul v10.4s, v23.4s, v19.4s + fmul v9.4s, v22.4s, v18.4s + fmul v8.4s, v21.4s, v17.4s + fmul v7.4s, v20.4s, v16.4s + ldr q18, [x26, #0x0] + ld1r { v17.4s }, [x11] + add x20, x11, #0x4 + cmp x25, #0x4 + ld1r { v16.4s }, [x20] + add x26, x26, #0x10 + fadd v10.4s, v10.4s, v18.4s + fadd v9.4s, v9.4s, v18.4s + fadd v8.4s, v8.4s, v18.4s + fadd v7.4s, v7.4s, v18.4s + fmax v10.4s, v10.4s, v17.4s + fmax v9.4s, v9.4s, v17.4s + fmax v8.4s, v8.4s, v17.4s + fmax v7.4s, v7.4s, v17.4s + fmin v10.4s, v10.4s, v16.4s + fmin v9.4s, v9.4s, v16.4s + fmin v8.4s, v8.4s, v16.4s + fmin v7.4s, v7.4s, v16.4s + blt label_15 + mov x20, x14 + cmp x13, #0x1 + str q10, [x20, #0x0] + add x20, x20, x12 + ble label_18 + cmp x13, #0x2 + str q9, [x20, #0x0] + add x20, x20, x12 + ble label_18 + cmp x13, #0x3 + str q8, [x20, #0x0] + add x20, x20, x12 + ble label_18 + str q7, [x20, #0x0] + b label_18 +KAI_ASM_LABEL(label_15) // Row tail: Partial output + mov x23, x14 + cmp x13, #0x1 + add x22, x23, x12 + csel x22, x22, x23, GT + cmp x13, #0x2 + add x21, x23, x12, LSL #1 + csel x21, x21, x22, GT + cmp x13, #0x3 + add x20, x21, x12 + csel x20, x20, x21, GT + tbz x25, #1, label_16 + st1 { v7.d }[0], [x20], #0x8 + st1 { v8.d }[0], [x21], #0x8 + st1 { v9.d }[0], [x22], #0x8 + st1 { v10.d }[0], [x23], #0x8 + tbz 
x25, #0, label_17
+    st1 { v7.s }[2], [x20]
+    st1 { v8.s }[2], [x21]
+    st1 { v9.s }[2], [x22]
+    st1 { v10.s }[2], [x23]
+    b label_17
+KAI_ASM_LABEL(label_16) // Row tail: Output block 0: partial_1_0
+    st1 { v7.s }[0], [x20]
+    st1 { v8.s }[0], [x21]
+    st1 { v9.s }[0], [x22]
+    st1 { v10.s }[0], [x23]
+KAI_ASM_LABEL(label_17) // Row tail: Output block 0: Done
+KAI_ASM_LABEL(label_18) // Row tail: Output stage exit
+    subs x25, x25, #0x4
+    add x14, x14, #0x10
+    bgt label_12
+    subs x13, x13, #0x4
+    add x17, x17, x7
+    mov x14, x24
+    bgt label_11
+KAI_ASM_LABEL(label_19) // Row tail: Row loop skip
+    ldp x22, x23, [sp, 16]
+    ldp x24, x25, [sp, 32]
+    ldp x26, x27, [sp, 48]
+    ldr x28, [sp, 64]
+    ldp d10, d11, [sp, 72]
+    ldp d8, d9, [sp, 88]
+    ldp x20, x21, [sp], 112
+    ret
+    KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm)
+
+    KAI_ASM_END
--
GitLab


From e4f7be6c1f2f22cbfc4971dbb70ee923bdd3de84 Mon Sep 17 00:00:00 2001
From: Anitha Raj
Date: Thu, 27 Mar 2025 15:46:14 +0000
Subject: [PATCH 05/15] Add unit tests for matmul_clamp_f32_qsi8dxp_qai4cxp

Signed-off-by: Anitha Raj
---
 .../matmul_clamp_f32_qsi8dxp_qai4cxp_test.cpp | 210 ++++++++++++++++++
 1 file changed, 210 insertions(+)
 create mode 100644 test/tests/matmul_clamp_f32_qsi8dxp_qai4cxp_test.cpp

diff --git a/test/tests/matmul_clamp_f32_qsi8dxp_qai4cxp_test.cpp b/test/tests/matmul_clamp_f32_qsi8dxp_qai4cxp_test.cpp
new file mode 100644
index 00000000..bef78512
--- /dev/null
+++ b/test/tests/matmul_clamp_f32_qsi8dxp_qai4cxp_test.cpp
@@ -0,0 +1,210 @@
+//
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <gtest/gtest.h>
+
+#include <array>
+#include <cstddef>
+#include <cstdint>
+#include <cstdlib>
+#include <limits>
+#include <sstream>
+#include <string>
+#include <tuple>
+#include <vector>
+
+#include "kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod.h"
+#include "kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm.h"
+#include "kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp_qai4cxp_interface.h"
+#include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8dxpscalef32_f32_neon.h"
+#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon.h"
+#include "test/common/cpu_info.hpp"
+#include "test/common/float16.hpp"
+#include "test/common/int4.hpp"
+#include "test/common/matrix_portion.hpp"
+#include "test/common/memory.hpp"
+#include "test/common/round.hpp"
+#include "test/common/test_suite.hpp"
+#include "test/reference/cast.hpp"
+#include "test/reference/fill.hpp"
+#include "test/reference/matmul.hpp"
+#include "test/reference/pack.hpp"
+#include "test/reference/quantize.hpp"
+
+namespace kai::test {
+
+static const std::array<UkernelVariant<kai_matmul_clamp_f32_qsi8dxp_qai4cxp_ukernel>, 2>
+    variants_kai_matmul_clamp_f32_qsi8dxp_qai4cxp = {
+        {{UKERNEL_MATMUL_VARIANT(clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod),
+          "kai_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod", cpu_has_dotprod},
+         {UKERNEL_MATMUL_VARIANT(clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm),
+          "kai_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm", cpu_has_i8mm}}};
+
+class MatMulTest_f32_qsi8dxp_qai4cxp : public ::testing::TestWithParam<MatMulTestPortionedParams> {};
+
+TEST_P(MatMulTest_f32_qsi8dxp_qai4cxp, EndToEnd_NullBias) {
+    const auto& [variant_index, matmul_shape, portion] = GetParam();
+    const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qsi8dxp_qai4cxp.at(variant_index);
+
+    if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) {
+        GTEST_SKIP() << "Kernel not supported";
+    }
+
+    const std::uint64_t seed = 0;
+
+    const size_t M = matmul_shape.m;
+    const size_t N = matmul_shape.n;
+    const size_t K = matmul_shape.k;
+
+    const auto mr = ukernel_variant.interface.get_mr();
+    const auto nr = ukernel_variant.interface.get_nr();
+    const auto kr = ukernel_variant.interface.get_kr();
+    const auto sr = ukernel_variant.interface.get_sr();
+
+    if (mr == 1 && M > 1) {
+        GTEST_SKIP() << "Kernel does not support M != 1";
+    }
+
+    auto m_step = ukernel_variant.interface.get_m_step();
+    ASSERT_TRUE(m_step % mr == 0);
+
+    auto n_step = ukernel_variant.interface.get_n_step();
+    ASSERT_TRUE(n_step % nr == 0);
+
+    const auto rect = portion.compute_portion(M, N, m_step, n_step);
+    if (rect.height() == 0 || rect.width() == 0) {
+        GTEST_SKIP() << "Test Portion size is 0!";
+    }
+
+    // Generates input data
+    const auto ref_lhs = fill_random<float>(M * K, seed + 0);
+    const auto ref_rhs = fill_random<float>(N * K, seed + 1);
+
+    // Runs the reference implementation.
+    // * Quantizes the LHS matrix using 8-bit symmetric quantization.
+    // * Quantizes the RHS matrix using 4-bit asymmetric quantization.
+    // * Performs GEMM.
+    const auto [ref_lhs_qvalues, ref_lhs_scales] =
+        quantize_symmetric_per_block_dynamic<float, int8_t, float>(ref_lhs.data(), M, K, K);
+    const auto [ref_rhs_qai4, ref_rhs_scales, ref_rhs_zero_points] =
+        quantize_asymmetric_per_block_dynamic<float, Int4, float, int32_t>(ref_rhs.data(), N, K, K);
+
+    const auto ref_dst = matmul_clamp_nt_t<int8_t, float, int32_t, Int4, float, int32_t, float, int32_t, float>(
+        M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), nullptr, K, ref_rhs_qai4.data(), ref_rhs_scales.data(),
+        ref_rhs_zero_points.data(), K, nullptr, std::numeric_limits<float>::lowest(),
+        std::numeric_limits<float>::max());
+
+    // Runs the LHS packing micro-kernel.
+    const auto lhs_start_row = rect.start_row();
+    const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qsi8dxpscalef32_f32_neon(M, K, mr, kr, sr);
+    std::vector<uint8_t> imp_packed_lhs(imp_packed_lhs_size);
+
+    auto lhs_stride = K * sizeof(float);
+    auto lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qsi8dxpscalef32_f32_neon(lhs_start_row, lhs_stride);
+    auto lhs_packed_offset =
+        kai_get_lhs_packed_offset_lhs_quant_pack_qsi8dxpscalef32_f32_neon(lhs_start_row, K, mr, kr, sr);
+    auto lhs_matmul_offset = ukernel_variant.interface.get_lhs_packed_offset(lhs_start_row, K);
+
+    ASSERT_EQ(lhs_packed_offset, lhs_matmul_offset);
+
+    kai_run_lhs_quant_pack_qsi8dxpscalef32_f32_neon(
+        rect.height() /* m */, K, mr, kr, sr, 0, reinterpret_cast<const float*>(ref_lhs.data() + lhs_offset),
+        lhs_stride, imp_packed_lhs.data() + lhs_packed_offset);
+
+    // Prepare the zero-point offsets, as the RHS packing kernel expects the scaled zero-points in float.
+    const size_t ref_zp_size = N;
+    const size_t ref_zp_size_in_bytes = ref_zp_size * sizeof(float);
+    std::vector<uint8_t> ref_rhs_zp_f32(ref_zp_size_in_bytes);
+    for (size_t i = 0; i < ref_zp_size; ++i) {
+        reinterpret_cast<float*>(ref_rhs_zp_f32.data())[i] =
+            -reinterpret_cast<const int32_t*>(ref_rhs_zero_points.data())[i] *
+            reinterpret_cast<const float*>(ref_rhs_scales.data())[i];
+    }
+    // Runs the RHS packing micro-kernel.
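+    // The packing kernel below consumes unsigned 4-bit (qsu4) values, so the signed
+    // reference data is converted first; this pairs with the rhs_zero_point = 8
+    // parameter passed to the packing kernel.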
+    const auto ref_rhs_qau4 = cast_qsu4_qsi4(ref_rhs_qai4.data(), N * K);
+
+    const auto imp_packed_rhs_size =
+        kai_get_rhs_packed_size_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon(N, K, nr, kr, sr);
+    std::vector<uint8_t> imp_packed_rhs(imp_packed_rhs_size);
+    const auto rhs_start_row = rect.start_col();
+    auto rhs_packed_offset =
+        kai_get_rhs_packed_offset_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon(rhs_start_row, K, nr, kr, sr);
+    auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K);
+    ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset);
+
+    const kai_rhs_pack_nxk_qai4cxp_params params{.lhs_zero_point = 1, .rhs_zero_point = 8};
+    kai_run_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon(
+        1, N, K, nr, kr, sr, ref_rhs_qau4.data(), ref_rhs_zp_f32.data(), nullptr, ref_rhs_scales.data(),
+        imp_packed_rhs.data(), 0, &params);
+
+    const auto dst_stride_row = N * sizeof(float);
+    const auto dst_stride_col = sizeof(float);
+    const auto dst_offset =
+        ukernel_variant.interface.get_dst_offset(rect.start_row(), rect.start_col(), dst_stride_row);
+    const auto ref_dst_offset = rect.start_row() * dst_stride_row + rect.start_col() * dst_stride_col;
+    ASSERT_EQ(dst_offset, ref_dst_offset);
+
+    // Runs the GEMM micro-kernel.
+    const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N);
+    ASSERT_EQ(imp_dst_size, ref_dst.size());
+    std::vector<uint8_t> imp_dst(imp_dst_size);
+    ukernel_variant.interface.run_matmul(
+        rect.height(), rect.width(), K, imp_packed_lhs.data() + lhs_matmul_offset,
+        imp_packed_rhs.data() + rhs_matmul_offset, reinterpret_cast<float*>(imp_dst.data() + dst_offset),
+        dst_stride_row, dst_stride_col, std::numeric_limits<float>::lowest(), std::numeric_limits<float>::max());
+
+    // Compares the output of the micro-kernels against the output of the reference implementation for the portion
+    // tested.
+    for (size_t y = 0; y < rect.height(); ++y) {
+        for (size_t x = 0; x < rect.width(); ++x) {
+            const auto imp_value =
+                read_array<float>(imp_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col()));
+            const auto ref_value =
+                read_array<float>(ref_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col()));
+            const auto rel_error =
+                ref_value != 0 ? std::abs((imp_value - ref_value) / ref_value) : std::abs(imp_value);
+            if (rel_error > 0.0001F) {
+                ASSERT_EQ(imp_value, ref_value);
+            }
+        }
+    }
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    MatMul, MatMulTest_f32_qsi8dxp_qai4cxp,
+    testing::Combine(
+        testing::Range<size_t>(0, variants_kai_matmul_clamp_f32_qsi8dxp_qai4cxp.size()),
+        testing::Values(
+            MatMulShape{1, 2, 32},    //
+            MatMulShape{32, 64, 64},  //
+            MatMulShape{16, 32, 64},  //
+            MatMulShape{4, 4, 32},    //
+            MatMulShape{15, 32, 32},  //
+            MatMulShape{77, 99, 64}),
+        testing::Values(
+            MatrixPortion(0, 0, 1, 1),         // Full matrix.
+            MatrixPortion(0, 0, 1, 0.25),      // Leftmost portion.
+            MatrixPortion(0, 0.75, 1, 1),      // Rightmost portion.
+            MatrixPortion(0, 0.5, 1, 0.8),     // Somewhere Middle
+            MatrixPortion(0.75, 0.75, 1, 1),   // Bottom-right corner.
+            MatrixPortion(0.75, 0, 1, 1),      // Partial rows
+            MatrixPortion(0.4, 0.5, 0.6, 0.8)  // Somewhere Middle
+            )),
+    [](const auto& info) {
+        const auto variant_idx = std::get<0>(info.param);
+        const std::string name{variants_kai_matmul_clamp_f32_qsi8dxp_qai4cxp.at(variant_idx).name};
+        const auto shape = std::get<MatMulShape>(info.param);
+        const auto portion = std::get<2>(info.param);
+
+        std::stringstream sstream;
+        sstream << name << "__M_" << shape.m << "__N_" << shape.n << "__K_" << shape.k  //
+                << "__PortionStartRow_" << static_cast<int>(portion.start_row() * 1000)  //
+                << "__PortionStartCol_" << static_cast<int>(portion.start_col() * 1000)  //
+                << "__PortionHeight_" << static_cast<int>(portion.height() * 1000)  //
+                << "__PortionWidth_" << static_cast<int>(portion.width() * 1000);
+        return sstream.str();
+    });
+
+}  // namespace kai::test
--
GitLab


From 702ea8040d69166e01d3be59dc41de313c0eb823 Mon Sep 17 00:00:00 2001
From: Anitha Raj
Date: Tue, 1 Apr 2025 11:35:38 +0100
Subject: [PATCH 06/15] Update tests and Cmake files

Signed-off-by: Anitha Raj
---
 CMakeLists.txt              | 7 +++++++
 test/reference/quantize.cpp | 1 +
 2 files changed, 8 insertions(+)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 8e0c7130..b14daa97 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -137,6 +137,7 @@ set(KLEIDIAI_FILES_NEON
     kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.c
     kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla_asm.S
     kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c
+    kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8dxpscalef32_f32_neon.c
     kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.c
     kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon.c
     kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c
@@ -144,6 +145,7 @@ set(KLEIDIAI_FILES_NEON
     kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.c
     kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon.c
     kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32pscalef32_f32_neon.c
+    kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon.c
 )
 
 set(KLEIDIAI_FILES_NEON_DOTPROD_ASM
@@ -163,6 +165,8 @@ set(KLEIDIAI_FILES_NEON_DOTPROD_ASM
     kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p8x4_4x8_neon_dotprod_asm.S
     kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod_asm.S
     kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod.c
+    kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod_asm.S
+    kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod.c
 )
 
 set(KLEIDIAI_FILES_NEON_DOTPROD
@@ -191,6 +195,8 @@ set(KLEIDIAI_FILES_NEON_I8MM_ASM
     kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm_asm.S
     kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm_asm.S
     kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm.c
+    kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm_asm.S
+    kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm.c
 )
 
 set(KLEIDIAI_FILES_NEON_I8MM
@@ -369,6 +375,7 @@ if(KLEIDIAI_BUILD_TESTS)
         test/tests/matmul_clamp_f32_qsi8d32p_qai4c32p_test.cpp
         test/tests/matmul_clamp_f16_qsi8d32p_qai4c32p_test.cpp
         test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp
+        test/tests/matmul_clamp_f32_qsi8dxp_qai4cxp_test.cpp
         test/tests/matmul_test.cpp
     )
 
diff --git a/test/reference/quantize.cpp b/test/reference/quantize.cpp
index 025f677b..498d26a8 100644
--- a/test/reference/quantize.cpp
+++ b/test/reference/quantize.cpp
@@ -294,4 +294,5 @@ quantize_asymmetric_per_block_dynamic(
 template std::tuple<std::vector<uint8_t>, std::vector<uint8_t>, std::vector<uint8_t>>
 quantize_asymmetric_per_block_dynamic<float, Int4, float, int32_t>(
     const void* src, size_t height, size_t width, size_t quant_width);
+
 }  // namespace kai::test
--
GitLab


From 7994ad2d6d5f2c995835428f0abd2857fdbf5a94 Mon Sep 17 00:00:00 2001
From: Anitha Raj
Date: Tue, 1 Apr 2025 12:51:41 +0100
Subject: [PATCH 07/15] Update matmul stride calculation

Signed-off-by: Anitha Raj
---
 ..._matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm.c | 7 ++++++-
 .../kai_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon.c | 6 +++---
 .../kai_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon.h | 2 +-
 3 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm.c
index 1932f71e..076ed06d 100644
--- a/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm.c
+++ b/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm.c
@@ -38,10 +38,13 @@ static const size_t kai_sr = 2;
 // LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric))
 static const size_t kai_num_bytes_qvalue_lhs = 1;
 static const size_t kai_num_bytes_multiplier_lhs = 4;
+static const size_t kai_num_bytes_asymmetry_lhs = 4;
+
 // RHS format args (num.
bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is // asymmetric)) static const size_t kai_num_bytes_recip_qvalue_rhs = 2; static const size_t kai_num_bytes_multiplier_rhs = 4; +static const size_t kai_num_bytes_asymmetry_rhs = 4; // DST format args static const size_t kai_num_bytes_dst_value = 4; // Extra args @@ -54,7 +57,8 @@ inline static size_t kai_get_k_roundedup(size_t k) { inline static size_t kai_get_lhs_packed_stride(size_t k) { const size_t k_internal = kai_get_k_roundedup(k); - size_t lhs_packed_stride = kai_mr * ((k_internal * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs); + size_t lhs_packed_stride = + kai_mr * ((k_internal * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs + kai_num_bytes_asymmetry_lhs); return lhs_packed_stride; } @@ -66,6 +70,7 @@ inline static size_t kai_get_rhs_packed_stride(size_t k) { rhs_packed_stride += kai_nr * kai_num_bytes_multiplier_rhs; // Since the bias is packed with the RHS matrix, the stride is adjusted with the number of bytes of the bias rhs_packed_stride += kai_nr * kai_num_bytes_bias; + rhs_packed_stride += kai_nr * kai_num_bytes_asymmetry_rhs; return rhs_packed_stride; } diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon.c index 36f88223..c255e6b5 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon.c @@ -13,7 +13,7 @@ #include #include "kai/kai_common.h" -static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); +static const size_t kai_num_bytes_offset_rhs = sizeof(float); static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); static const size_t kai_num_bytes_bias = sizeof(float); @@ -33,7 +33,7 @@ size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neo const size_t k_internal = kai_k_roundedup(k, kr, sr); KAI_ASSERT((k_internal % 2) == 0); - return nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias); + return nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_offset_rhs + kai_num_bytes_bias); } size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon( @@ -109,7 +109,7 @@ void kai_run_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon( } // Adjust the zero point if (zero == NULL) { - memset(dst_row, 0, nr * kai_num_bytes_bias); + memset(dst_row, 0, nr * kai_num_bytes_offset_rhs); dst_row += nr * sizeof(float); } else { for (size_t i = 0; i < nr; ++i) { diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon.h b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon.h index 141eea79..2a84598a 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon.h +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon.h @@ -74,7 +74,7 @@ size_t kai_get_rhs_packed_size_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon( /// @param[in] num_groups The number of groups. It must be 1. /// @param[in] n The number of rows. /// @param[in] k The common dimension between the LHS and RHS matrix (K). It must be an even value. -/// @param[in] nr The number of N rows to interleave on the same output output row. +/// @param[in] nr The number of N rows to interleave on the same output row. 
/// @param[in] kr The number of K values loaded in the single inner most loop of the matmul micro-kernel. /// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. /// However, kr must be multiple of sr. -- GitLab From dc5a255715853be365014a3a455187bee997bca2 Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Thu, 10 Apr 2025 16:26:49 +0100 Subject: [PATCH 08/15] Update the matmul asm kernels Signed-off-by: Anitha Raj --- ...2_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod.c | 100 +---- ...2_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod.h | 3 +- ...i8dxp1x8_qai4cxp4x8_1x4_neon_dotprod_asm.S | 125 +++--- ..._f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm.c | 415 +----------------- ..._f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm.h | 3 +- ..._qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm_asm.S | 65 ++- 6 files changed, 131 insertions(+), 580 deletions(-) diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod.c index db13bb65..a0fa31ef 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod.c @@ -10,7 +10,6 @@ #include "kai_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod.h" #include -#include #include "kai/kai_common.h" @@ -23,7 +22,6 @@ typedef struct { size_t m; size_t n; size_t num_blocks; - size_t num_subblocks; } KernelArgs; void kai_kernel_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod(KernelArgs* args_ptr); @@ -52,13 +50,16 @@ static const size_t kai_num_bytes_dst_value = 4; static const size_t kai_num_bytes_bias = 4; static const size_t kai_k_multiple_of = 32; +static const size_t kai_bl = 32; + inline static size_t kai_get_k_roundedup(size_t k) { return kai_roundup(k, kai_k_multiple_of); } inline static size_t kai_get_lhs_packed_stride(size_t k) { const size_t k_internal = kai_get_k_roundedup(k); - size_t lhs_packed_stride = kai_mr * ((k_internal * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs); + size_t lhs_packed_stride = + kai_mr * ((k_internal * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs + kai_num_bytes_sum_lhs); return lhs_packed_stride; } @@ -70,6 +71,7 @@ inline static size_t kai_get_rhs_packed_stride(size_t k) { rhs_packed_stride += kai_nr * kai_num_bytes_multiplier_rhs; // Since the bias is packed with the RHS matrix, the stride is adjusted with the number of bytes of the bias rhs_packed_stride += kai_nr * kai_num_bytes_bias; + rhs_packed_stride += kai_nr * kai_num_bytes_offset_rhs; return rhs_packed_stride; } @@ -128,108 +130,21 @@ void kai_run_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod( size_t k, // const void* restrict lhs_packed, // const void* restrict rhs_packed, // - float* dst, // NOLINT(readability-non-const-parameter) + float* restrict dst, // NOLINT(readability-non-const-parameter) size_t dst_stride_row, // size_t dst_stride_col, // float scalar_min, // float scalar_max) { KAI_ASSUME(dst_stride_col == sizeof(float)); - + KAI_ASSUME(m == 1); if (m == 0) { return; } - const size_t kai_bl = 32; const size_t k_internal = kai_get_k_roundedup(k); const size_t num_blocks = k_internal / kai_bl; const float clamp_vals[2] = {scalar_min, scalar_max}; - __asm__ __volatile__( - "mov x26, #0x20\n" - "mov x20, #0x8\n" - "movi v30.16b, #0xf0\n" - "mov x25, %x[m]\n" - "madd 
x26, %x[num_blocks], x26, x20\n" - "1:" // Row loop - "mov x24, %x[rhs_packed]\n" - "mov x23, %x[n]\n" - "add x22, %x[dst], %x[dst_stride_row]\n" - "2:" // Column loop - "mov x21, %x[lhs_packed]\n" - "movi v29.4s, #0x0\n" - "movi v28.4s, #0x0\n" - "mov x20, %x[num_blocks]\n" - "3:" // Sub block loop - "ldr q27, [x24, #0x0]\n" - "ldr q26, [x24, #0x10]\n" - "subs x20, x20, #0x1\n" - "ld1r { v25.2d }, [x21], #0x8\n" - "ldr q24, [x24, #0x20]\n" - "ldr q23, [x24, #0x30]\n" - "add x24, x24, #0x40\n" - "ld1r { v22.2d }, [x21], #0x8\n" - "ld1r { v21.2d }, [x21], #0x8\n" - "shl v20.16b, v27.16b, #0x4\n" - "shl v19.16b, v26.16b, #0x4\n" - "ld1r { v18.2d }, [x21], #0x8\n" - "shl v17.16b, v24.16b, #0x4\n" - "and v27.16b, v27.16b, v30.16b\n" - "shl v16.16b, v23.16b, #0x4\n" - "and v26.16b, v26.16b, v30.16b\n" - ".inst 0x4e99969d // sdot v29.4s, v20.16b, v25.16b\n" - ".inst 0x4e99967c // sdot v28.4s, v19.16b, v25.16b\n" - "and v24.16b, v24.16b, v30.16b\n" - "and v23.16b, v23.16b, v30.16b\n" - ".inst 0x4e96963d // sdot v29.4s, v17.16b, v22.16b\n" - ".inst 0x4e96961c // sdot v28.4s, v16.16b, v22.16b\n" - ".inst 0x4e95977d // sdot v29.4s, v27.16b, v21.16b\n" - ".inst 0x4e95975c // sdot v28.4s, v26.16b, v21.16b\n" - ".inst 0x4e92971d // sdot v29.4s, v24.16b, v18.16b\n" - ".inst 0x4e9296fc // sdot v28.4s, v23.16b, v18.16b\n" - "bgt 3b\n" - "ldr q20, [x24, #0x10]\n" - "ldr q19, [x24, #0x20]\n" - "add x21, x21, #0x4\n" - "addp v29.4s, v29.4s, v28.4s\n" - "ld1r { v16.4s }, [x21]\n" - "ld1r { v18.4s }, [%x[clamp_vals]]\n" - "add x20, %x[clamp_vals], #0x4\n" - "cmp x23, #0x4\n" - "ld1r { v17.4s }, [x20]\n" - "add x24, x24, #0x30\n" - "scvtf v29.4s, v29.4s\n" - "fmul v20.4s, v20.4s, v16.4s\n" - "fmul v16.4s, v29.4s, v20.4s\n" - "fadd v16.4s, v16.4s, v19.4s\n" - "fmax v16.4s, v16.4s, v18.4s\n" - "fmin v16.4s, v16.4s, v17.4s\n" - "blt 4f\n" - "str q16, [%x[dst], #0x0]\n" - "b 7f\n" - "4:" // Partial output - "mov x20, %x[dst]\n" - "tbz x23, #1, 5f\n" - "st1 { v16.d }[0], [x20], #0x8\n" - "tbz x23, #0, 6f\n" - "st1 { v16.s }[2], [x20]\n" - "b 6f\n" - "5:" // Output block 0: partial_1_0 - "st1 { v16.s }[0], [x20]\n" - "6:" // Output block 0: Done - "7:" // Stores done - "subs x23, x23, #0x4\n" - "add %x[dst], %x[dst], #0x10\n" - "bgt 2b\n" - "subs x25, x25, #0x1\n" - "add %x[lhs_packed], %x[lhs_packed], x26\n" - "mov %x[dst], x22\n" - "bgt 1b\n" - : [dst] "+&r"(dst), [lhs_packed] "+&r"(lhs_packed) - : [clamp_vals] "r"(clamp_vals), [dst_stride_row] "r"(dst_stride_row), [m] "r"(m), [n] "r"(n), - [num_blocks] "r"(num_blocks), [rhs_packed] "r"(rhs_packed) - : "cc", "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", - "v29", "v30", "x20", "x21", "x22", "x23", "x24", "x25", "x26"); - /* KernelArgs args; args.dst = dst; @@ -242,7 +157,6 @@ void kai_run_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod( args.num_blocks = num_blocks; kai_kernel_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod(&args); - */ } #endif // Architectural features check. 
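The getter, packing, and run entry points added above are meant to be used as a three-step pipeline. The following is a minimal, single-threaded C sketch of that call sequence for the 1xN dotprod variant, not part of the patch: the helper name, the malloc-based buffer management, and the assumption that the caller already holds the RHS as unsigned 4-bit (qau4) data with precomputed scaled zero-points (-zero_point * scale, in f32, as the unit test in patch 05 prepares them) are illustrative only.

    #include <float.h>
    #include <stddef.h>
    #include <stdint.h>
    #include <stdlib.h>

    #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod.h"
    #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8dxpscalef32_f32_neon.h"
    #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon.h"

    // Hypothetical helper: multiply a 1xK f32 vector by a NxK qai4cx weight matrix.
    void run_qsi8dx_qai4cx_gemv_example(
        const float* lhs_f32,     // 1 x k, row-major f32 activations
        const uint8_t* rhs_qau4,  // n x k, unsigned 4-bit, two values per byte
        const float* rhs_scales,  // n per-channel scales
        const float* rhs_zp_f32,  // n precomputed scaled zero-points (-zp * scale)
        float* dst,               // 1 x n f32 output
        size_t n, size_t k) {
        const size_t m = 1;  // This dotprod variant is a GEMV micro-kernel (m_step == 1).
        const size_t mr = kai_get_mr_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod();
        const size_t nr = kai_get_nr_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod();
        const size_t kr = kai_get_kr_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod();
        const size_t sr = kai_get_sr_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod();

        // Step 1: dynamically quantize and pack the f32 LHS (qsi8dx: per-row scale and sum).
        void* lhs_packed =
            malloc(kai_get_lhs_packed_size_lhs_quant_pack_qsi8dxpscalef32_f32_neon(m, k, mr, kr, sr));
        kai_run_lhs_quant_pack_qsi8dxpscalef32_f32_neon(
            m, k, mr, kr, sr, 0 /* m_idx_start */, lhs_f32, k * sizeof(float), lhs_packed);

        // Step 2: pack the NxK RHS together with its scales, scaled zero-points and (null) bias.
        void* rhs_packed =
            malloc(kai_get_rhs_packed_size_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon(n, k, nr, kr, sr));
        const struct kai_rhs_pack_nxk_qai4cxp_params params = {.lhs_zero_point = 1, .rhs_zero_point = 8};
        kai_run_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon(
            1 /* num_groups */, n, k, nr, kr, sr, rhs_qau4, rhs_zp_f32, NULL /* bias */, rhs_scales,
            rhs_packed, 0 /* extra bytes */, &params);

        // Step 3: run the matmul + clamp; dst_stride_col must be sizeof(float).
        kai_run_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod(
            m, n, k, lhs_packed, rhs_packed, dst, n * sizeof(float), sizeof(float), -FLT_MAX, FLT_MAX);

        free(lhs_packed);
        free(rhs_packed);
    }

For tiled or multi-threaded dispatch, the kai_get_lhs_packed_offset_*, kai_get_rhs_packed_offset_* and kai_get_dst_offset_* getters declared above provide the per-tile byte offsets into the packed buffers and the destination.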
diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod.h b/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod.h index ebf3cd6b..e0b85813 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod.h @@ -15,8 +15,7 @@ extern "C" { /// Micro-kernel dependencies /// /// -# @ref kai_lhs_quant_pack_qsi8dxp_f32 to dynamically quantize and pack the LHS matrix in a single step. -/// -# @ref kai_rhs_pack_nxk_qai4cxp_qau4c32s1s0 to pack the RHS NxK matrix. -/// -# @ref kai_rhs_pack_kxn_qai4cxp_qau4c32s1s0 to pack the RHS KxN matrix. +/// -# @ref kai_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon to pack the RHS NxK matrix. /// -------------------------------------------------- diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod_asm.S index 4861c74d..2f192212 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod_asm.S +++ b/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod_asm.S @@ -47,45 +47,41 @@ KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neo stp x24, x25, [sp, 32] stp x26, x27, [sp, 48] str x28, [sp, 64] - mov x15, #0x20 + mov x13, #0x20 movi v30.16b, #0xf0 mov x21, #0x8 - ldr x14, [x0, #0x40] - ldr x13, [x0, #0x38] + ldr x12, [x0, #0x38] ldr x20, [x0, #0x28] - ldr x12, [x0, #0x8] - ldr x11, [x0, #0x10] - ldr x10, [x0, #0x30] - mul x15, x14, x15 - ldr x9, [x0, #0x0] - ldr x28, [x0, #0x20] - ldr x27, [x0, #0x18] - mov x26, x20 - madd x15, x13, x15, x21 + ldr x11, [x0, #0x8] + ldr x10, [x0, #0x10] + ldr x9, [x0, #0x30] + ldr x28, [x0, #0x0] + ldr x27, [x0, #0x20] + madd x13, x12, x13, x21 + ldr x26, [x0, #0x18] + mov x25, x20 KAI_ASM_LABEL(label_1) // Row loop - mov x25, x11 mov x24, x10 - add x23, x9, x28 + mov x23, x9 + add x22, x28, x27 KAI_ASM_LABEL(label_2) // Column loop - mov x22, x12 - mov x21, x13 -KAI_ASM_LABEL(label_3) // Block loop + mov x21, x11 movi v29.4s, #0x0 movi v28.4s, #0x0 - mov x20, x14 -KAI_ASM_LABEL(label_4) // Sub block loop - ldr q27, [x25, #0x0] - ldr q26, [x25, #0x10] + mov x20, x12 +KAI_ASM_LABEL(label_3) // Sub block loop + ldr q27, [x24, #0x0] + ldr q26, [x24, #0x10] subs x20, x20, #0x1 - ld1r { v25.2d }, [x22], #0x8 - ldr q24, [x25, #0x20] - ldr q23, [x25, #0x30] - add x25, x25, #0x40 - ld1r { v22.2d }, [x22], #0x8 - ld1r { v21.2d }, [x22], #0x8 + ld1r { v25.2d }, [x21], #0x8 + ldr q24, [x24, #0x20] + ldr q23, [x24, #0x30] + add x24, x24, #0x40 + ld1r { v22.2d }, [x21], #0x8 + ld1r { v21.2d }, [x21], #0x8 shl v20.16b, v27.16b, #0x4 shl v19.16b, v26.16b, #0x4 - ld1r { v18.2d }, [x22], #0x8 + ld1r { v18.2d }, [x21], #0x8 shl v17.16b, v24.16b, #0x4 and v27.16b, v27.16b, v30.16b shl v16.16b, v23.16b, #0x4 @@ -100,47 +96,46 @@ KAI_ASM_LABEL(label_4) // Sub block loop KAI_ASM_INST(0x4e95975c) // sdot v28.4s, v26.16b, v21.16b KAI_ASM_INST(0x4e92971d) // sdot v29.4s, v24.16b, v18.16b KAI_ASM_INST(0x4e9296fc) // sdot v28.4s, v23.16b, v18.16b - bgt label_4 - ldr q17, [x25, #0x10] - add x22, x22, #0x4 + 
bgt label_3 + ld1r { v22.4s }, [x21] + ldr q16, [x24, #0x0] + add x21, x21, #0x4 addp v29.4s, v29.4s, v28.4s - sub x21, x21, #0x1 - ld1r { v16.4s }, [x22] - add x22, x22, #0x4 - add x25, x25, #0x20 + ld1r { v21.4s }, [x21] + ldr q20, [x24, #0x10] + add x20, x26, #0x4 + cmp x23, #0x4 + ldr q19, [x24, #0x20] + ld1r { v18.4s }, [x26] + add x24, x24, #0x30 + ld1r { v17.4s }, [x20] + fmul v16.4s, v16.4s, v22.s[0] scvtf v29.4s, v29.4s - fmul v17.4s, v17.4s, v16.4s - fmul v19.4s, v29.4s, v17.4s - cbnz x21, label_3 - ldr q18, [x25, #0x0] - ld1r { v17.4s }, [x27] - add x20, x27, #0x4 - cmp x24, #0x4 - ld1r { v16.4s }, [x20] - add x25, x25, #0x10 - fadd v19.4s, v19.4s, v18.4s - fmax v19.4s, v19.4s, v17.4s - fmin v19.4s, v19.4s, v16.4s - blt label_5 - str q19, [x9, #0x0] - b label_8 -KAI_ASM_LABEL(label_5) // Partial output - mov x20, x9 - tbz x24, #1, label_6 - st1 { v19.d }[0], [x20], #0x8 - tbz x24, #0, label_7 - st1 { v19.s }[2], [x20] + fmul v20.4s, v20.4s, v21.4s + fmla v16.4s, v29.4s, v20.4s + fadd v16.4s, v16.4s, v19.4s + fmax v16.4s, v16.4s, v18.4s + fmin v16.4s, v16.4s, v17.4s + blt label_4 + str q16, [x28, #0x0] b label_7 -KAI_ASM_LABEL(label_6) // Output block 0: partial_1_0 - st1 { v19.s }[0], [x20] -KAI_ASM_LABEL(label_7) // Output block 0: Done -KAI_ASM_LABEL(label_8) // Stores done - subs x24, x24, #0x4 - add x9, x9, #0x10 +KAI_ASM_LABEL(label_4) // Partial output + mov x20, x28 + tbz x23, #1, label_5 + st1 { v16.d }[0], [x20], #0x8 + tbz x23, #0, label_6 + st1 { v16.s }[2], [x20] + b label_6 +KAI_ASM_LABEL(label_5) // Output block 0: partial_1_0 + st1 { v16.s }[0], [x20] +KAI_ASM_LABEL(label_6) // Output block 0: Done +KAI_ASM_LABEL(label_7) // Stores done + subs x23, x23, #0x4 + add x28, x28, #0x10 bgt label_2 - subs x26, x26, #0x1 - add x12, x12, x15 - mov x9, x23 + subs x25, x25, #0x1 + add x11, x11, x13 + mov x28, x22 bgt label_1 ldp x22, x23, [sp, 16] ldp x24, x25, [sp, 32] diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm.c index 076ed06d..f80f2cb7 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm.c @@ -10,7 +10,6 @@ #include "kai_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm.h" #include -#include #include "kai/kai_common.h" @@ -38,19 +37,21 @@ static const size_t kai_sr = 2; // LHS format args (num. bytes per value, multiplier, zero_point (if asymmetric)) static const size_t kai_num_bytes_qvalue_lhs = 1; static const size_t kai_num_bytes_multiplier_lhs = 4; -static const size_t kai_num_bytes_asymmetry_lhs = 4; - +static const size_t kai_num_bytes_sum_lhs = 4; // RHS format args (num. 
bytes per value, multiplier, zero_point (if asymmetric), and reduction sum (if LHS is // asymmetric)) static const size_t kai_num_bytes_recip_qvalue_rhs = 2; static const size_t kai_num_bytes_multiplier_rhs = 4; -static const size_t kai_num_bytes_asymmetry_rhs = 4; +static const size_t kai_num_bytes_offset_rhs = 4; + // DST format args static const size_t kai_num_bytes_dst_value = 4; // Extra args static const size_t kai_num_bytes_bias = 4; static const size_t kai_k_multiple_of = 32; +static const size_t kai_bl = 32; + inline static size_t kai_get_k_roundedup(size_t k) { return kai_roundup(k, kai_k_multiple_of); } @@ -58,7 +59,7 @@ inline static size_t kai_get_k_roundedup(size_t k) { inline static size_t kai_get_lhs_packed_stride(size_t k) { const size_t k_internal = kai_get_k_roundedup(k); size_t lhs_packed_stride = - kai_mr * ((k_internal * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs + kai_num_bytes_asymmetry_lhs); + kai_mr * ((k_internal * kai_num_bytes_qvalue_lhs) + kai_num_bytes_multiplier_lhs + kai_num_bytes_sum_lhs); return lhs_packed_stride; } @@ -70,7 +71,7 @@ inline static size_t kai_get_rhs_packed_stride(size_t k) { rhs_packed_stride += kai_nr * kai_num_bytes_multiplier_rhs; // Since the bias is packed with the RHS matrix, the stride is adjusted with the number of bytes of the bias rhs_packed_stride += kai_nr * kai_num_bytes_bias; - rhs_packed_stride += kai_nr * kai_num_bytes_asymmetry_rhs; + rhs_packed_stride += kai_nr * kai_num_bytes_offset_rhs; return rhs_packed_stride; } @@ -138,397 +139,23 @@ void kai_run_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm( if (m == 0) { return; } - const size_t kai_bl = 32; + const size_t k_internal = kai_get_k_roundedup(k); - size_t num_blocks = k_internal / kai_bl; + const size_t num_blocks = k_internal / kai_bl; const float clamp_vals[2] = {scalar_min, scalar_max}; - __asm__ __volatile__( - "mov x12, %x[m]\n" - "mov x11, #0x80\n" - "movi v11.16b, #0xf0\n" - "mov x20, #0x20\n" - "cmp x12, #0x8\n" - "madd x11, %x[num_blocks], x11, x20\n" - "blt 10f\n" - "1:" // Row loop - "mov x10, %x[rhs_packed]\n" - "mov x9, %x[n]\n" - "add x28, %x[dst], %x[dst_stride_row], LSL #3\n" - "2:" // Column loop - "mov x22, %x[lhs_packed]\n" - "movi v10.4s, #0x0\n" - "movi v9.4s, #0x0\n" - "mov x21, %x[num_blocks]\n" - "movi v8.4s, #0x0\n" - "movi v7.4s, #0x0\n" - "movi v6.4s, #0x0\n" - "movi v5.4s, #0x0\n" - "add x20, x22, x11\n" - "movi v4.4s, #0x0\n" - "movi v3.4s, #0x0\n" - "3:" // Sub block loop - "ldr q2, [x10, #0x0]\n" - "ldr q1, [x10, #0x10]\n" - "subs x21, x21, #0x1\n" - "ldr q20, [x22, #0x0]\n" - "ldr q19, [x22, #0x10]\n" - "ldr q18, [x20, #0x0]\n" - "ldr q0, [x20, #0x10]\n" - "ldr q31, [x10, #0x20]\n" - "ldr q30, [x10, #0x30]\n" - "shl v17.16b, v2.16b, #0x4\n" - "shl v16.16b, v1.16b, #0x4\n" - "ldr q29, [x22, #0x20]\n" - "ldr q28, [x22, #0x30]\n" - "and v2.16b, v2.16b, v11.16b\n" - "and v1.16b, v1.16b, v11.16b\n" - "ldr q27, [x20, #0x20]\n" - "ldr q26, [x20, #0x30]\n" - "add x10, x10, #0x40\n" - "ldr q25, [x22, #0x40]\n" - "ldr q24, [x22, #0x50]\n" - ".inst 0x4e91a68a // smmla v10.4s, v20.16b, v17.16b\n" - ".inst 0x4e90a689 // smmla v9.4s, v20.16b, v16.16b\n" - "ldr q23, [x20, #0x40]\n" - "ldr q22, [x20, #0x50]\n" - ".inst 0x4e91a668 // smmla v8.4s, v19.16b, v17.16b\n" - ".inst 0x4e90a667 // smmla v7.4s, v19.16b, v16.16b\n" - "ldr q21, [x22, #0x60]\n" - "ldr q20, [x22, #0x70]\n" - ".inst 0x4e91a646 // smmla v6.4s, v18.16b, v17.16b\n" - ".inst 0x4e90a645 // smmla v5.4s, v18.16b, v16.16b\n" - "ldr q19, [x20, #0x60]\n" - "ldr q18, [x20, 
#0x70]\n" - ".inst 0x4e91a404 // smmla v4.4s, v0.16b, v17.16b\n" - ".inst 0x4e90a403 // smmla v3.4s, v0.16b, v16.16b\n" - "shl v17.16b, v31.16b, #0x4\n" - "shl v16.16b, v30.16b, #0x4\n" - "add x22, x22, #0x80\n" - "add x20, x20, #0x80\n" - "and v31.16b, v31.16b, v11.16b\n" - "and v30.16b, v30.16b, v11.16b\n" - ".inst 0x4e91a7aa // smmla v10.4s, v29.16b, v17.16b\n" - ".inst 0x4e90a7a9 // smmla v9.4s, v29.16b, v16.16b\n" - ".inst 0x4e91a788 // smmla v8.4s, v28.16b, v17.16b\n" - ".inst 0x4e90a787 // smmla v7.4s, v28.16b, v16.16b\n" - ".inst 0x4e91a766 // smmla v6.4s, v27.16b, v17.16b\n" - ".inst 0x4e90a765 // smmla v5.4s, v27.16b, v16.16b\n" - ".inst 0x4e91a744 // smmla v4.4s, v26.16b, v17.16b\n" - ".inst 0x4e90a743 // smmla v3.4s, v26.16b, v16.16b\n" - ".inst 0x4e82a72a // smmla v10.4s, v25.16b, v2.16b\n" - ".inst 0x4e81a729 // smmla v9.4s, v25.16b, v1.16b\n" - ".inst 0x4e82a708 // smmla v8.4s, v24.16b, v2.16b\n" - ".inst 0x4e81a707 // smmla v7.4s, v24.16b, v1.16b\n" - ".inst 0x4e82a6e6 // smmla v6.4s, v23.16b, v2.16b\n" - ".inst 0x4e81a6e5 // smmla v5.4s, v23.16b, v1.16b\n" - ".inst 0x4e82a6c4 // smmla v4.4s, v22.16b, v2.16b\n" - ".inst 0x4e81a6c3 // smmla v3.4s, v22.16b, v1.16b\n" - ".inst 0x4e9fa6aa // smmla v10.4s, v21.16b, v31.16b\n" - ".inst 0x4e9ea6a9 // smmla v9.4s, v21.16b, v30.16b\n" - ".inst 0x4e9fa688 // smmla v8.4s, v20.16b, v31.16b\n" - ".inst 0x4e9ea687 // smmla v7.4s, v20.16b, v30.16b\n" - ".inst 0x4e9fa666 // smmla v6.4s, v19.16b, v31.16b\n" - ".inst 0x4e9ea665 // smmla v5.4s, v19.16b, v30.16b\n" - ".inst 0x4e9fa644 // smmla v4.4s, v18.16b, v31.16b\n" - ".inst 0x4e9ea643 // smmla v3.4s, v18.16b, v30.16b\n" - "bgt 3b\n" - "ldr q24, [x10, #0x10]\n" - "add x22, x22, #0x10\n" - "uzp1 v23.2d, v10.2d, v9.2d\n" - "uzp2 v22.2d, v10.2d, v9.2d\n" - "ldr q16, [x22, #0x0]\n" - "uzp1 v21.2d, v8.2d, v7.2d\n" - "uzp2 v20.2d, v8.2d, v7.2d\n" - "add x10, x10, #0x20\n" - "scvtf v23.4s, v23.4s\n" - "scvtf v22.4s, v22.4s\n" - "fmul v19.4s, v24.4s, v16.s[0]\n" - "fmul v18.4s, v24.4s, v16.s[1]\n" - "fmul v17.4s, v24.4s, v16.s[2]\n" - "scvtf v21.4s, v21.4s\n" - "fmul v16.4s, v24.4s, v16.s[3]\n" - "scvtf v20.4s, v20.4s\n" - "fmul v10.4s, v23.4s, v19.4s\n" - "fmul v9.4s, v22.4s, v18.4s\n" - "fmul v8.4s, v21.4s, v17.4s\n" - "fmul v7.4s, v20.4s, v16.4s\n" - "add x20, x20, #0x10\n" - "uzp1 v23.2d, v6.2d, v5.2d\n" - "uzp2 v22.2d, v6.2d, v5.2d\n" - "ldr q16, [x20, #0x0]\n" - "uzp1 v21.2d, v4.2d, v3.2d\n" - "uzp2 v20.2d, v4.2d, v3.2d\n" - "fmul v19.4s, v24.4s, v16.s[0]\n" - "scvtf v23.4s, v23.4s\n" - "fmul v18.4s, v24.4s, v16.s[1]\n" - "scvtf v22.4s, v22.4s\n" - "fmul v17.4s, v24.4s, v16.s[2]\n" - "scvtf v21.4s, v21.4s\n" - "fmul v16.4s, v24.4s, v16.s[3]\n" - "scvtf v20.4s, v20.4s\n" - "fmul v6.4s, v23.4s, v19.4s\n" - "fmul v5.4s, v22.4s, v18.4s\n" - "fmul v4.4s, v21.4s, v17.4s\n" - "fmul v3.4s, v20.4s, v16.4s\n" - "ldr q18, [x10, #0x0]\n" - "ld1r { v17.4s }, [%x[clamp_vals]]\n" - "add x20, %x[clamp_vals], #0x4\n" - "cmp x9, #0x4\n" - "ld1r { v16.4s }, [x20]\n" - "add x10, x10, #0x10\n" - "fadd v10.4s, v10.4s, v18.4s\n" - "fadd v9.4s, v9.4s, v18.4s\n" - "fadd v8.4s, v8.4s, v18.4s\n" - "fadd v7.4s, v7.4s, v18.4s\n" - "fadd v6.4s, v6.4s, v18.4s\n" - "fadd v5.4s, v5.4s, v18.4s\n" - "fadd v4.4s, v4.4s, v18.4s\n" - "fadd v3.4s, v3.4s, v18.4s\n" - "fmax v10.4s, v10.4s, v17.4s\n" - "fmax v9.4s, v9.4s, v17.4s\n" - "fmax v8.4s, v8.4s, v17.4s\n" - "fmax v7.4s, v7.4s, v17.4s\n" - "fmax v6.4s, v6.4s, v17.4s\n" - "fmax v5.4s, v5.4s, v17.4s\n" - "fmax v4.4s, v4.4s, v17.4s\n" - "fmax v3.4s, v3.4s, v17.4s\n" - "fmin 
v10.4s, v10.4s, v16.4s\n" - "fmin v9.4s, v9.4s, v16.4s\n" - "fmin v8.4s, v8.4s, v16.4s\n" - "fmin v7.4s, v7.4s, v16.4s\n" - "fmin v6.4s, v6.4s, v16.4s\n" - "fmin v5.4s, v5.4s, v16.4s\n" - "fmin v4.4s, v4.4s, v16.4s\n" - "fmin v3.4s, v3.4s, v16.4s\n" - "blt 6f\n" - "mov x20, %x[dst]\n" - "str q10, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q9, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q8, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q7, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q6, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q5, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q4, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q3, [x20, #0x0]\n" - "b 9f\n" - "6:" // Partial output - "mov x27, %x[dst]\n" - "add x26, x27, %x[dst_stride_row], LSL #2\n" - "add x25, x26, %x[dst_stride_row], LSL #1\n" - "add x24, x26, %x[dst_stride_row]\n" - "add x23, x25, %x[dst_stride_row]\n" - "add x22, x27, %x[dst_stride_row], LSL #1\n" - "add x21, x27, %x[dst_stride_row]\n" - "add x20, x22, %x[dst_stride_row]\n" - "tbz x9, #1, 7f\n" - "st1 { v3.d }[0], [x23], #0x8\n" - "st1 { v4.d }[0], [x25], #0x8\n" - "st1 { v5.d }[0], [x24], #0x8\n" - "st1 { v6.d }[0], [x26], #0x8\n" - "st1 { v7.d }[0], [x20], #0x8\n" - "st1 { v8.d }[0], [x22], #0x8\n" - "st1 { v9.d }[0], [x21], #0x8\n" - "st1 { v10.d }[0], [x27], #0x8\n" - "tbz x9, #0, 8f\n" - "st1 { v3.s }[2], [x23]\n" - "st1 { v4.s }[2], [x25]\n" - "st1 { v5.s }[2], [x24]\n" - "st1 { v6.s }[2], [x26]\n" - "st1 { v7.s }[2], [x20]\n" - "st1 { v8.s }[2], [x22]\n" - "st1 { v9.s }[2], [x21]\n" - "st1 { v10.s }[2], [x27]\n" - "b 8f\n" - "7:" // Output block 0: partial_1_0 - "st1 { v3.s }[0], [x23]\n" - "st1 { v4.s }[0], [x25]\n" - "st1 { v5.s }[0], [x24]\n" - "st1 { v6.s }[0], [x26]\n" - "st1 { v7.s }[0], [x20]\n" - "st1 { v8.s }[0], [x22]\n" - "st1 { v9.s }[0], [x21]\n" - "st1 { v10.s }[0], [x27]\n" - "8:" // Output block 0: Done - "9:" // Output stage exit - "subs x9, x9, #0x4\n" - "add %x[dst], %x[dst], #0x10\n" - "bgt 2b\n" - "mov x20, #0x2\n" - "sub x12, x12, #0x8\n" - "cmp x12, #0x8\n" - "mov %x[dst], x28\n" - "madd %x[lhs_packed], x20, x11, %x[lhs_packed]\n" - "bge 1b\n" - "10:" // Row loop skip - "cbz x12, 19f\n" - "11:" // Row tail: Row loop - "mov x26, %x[rhs_packed]\n" - "mov x25, %x[n]\n" - "add x24, %x[dst], %x[dst_stride_row], LSL #2\n" - "12:" // Row tail: Column loop - "mov x22, %x[lhs_packed]\n" - "movi v10.4s, #0x0\n" - "movi v9.4s, #0x0\n" - "mov x20, %x[num_blocks]\n" - "movi v8.4s, #0x0\n" - "movi v7.4s, #0x0\n" - "13:" // Row tail: Sub block loop - "ldr q31, [x26, #0x0]\n" - "ldr q30, [x26, #0x10]\n" - "subs x20, x20, #0x1\n" - "ldr q29, [x22, #0x0]\n" - "ldr q28, [x22, #0x10]\n" - "ldr q27, [x26, #0x20]\n" - "ldr q26, [x26, #0x30]\n" - "add x26, x26, #0x40\n" - "ldr q25, [x22, #0x20]\n" - "ldr q24, [x22, #0x30]\n" - "shl v23.16b, v31.16b, #0x4\n" - "shl v22.16b, v30.16b, #0x4\n" - "ldr q21, [x22, #0x40]\n" - "ldr q20, [x22, #0x50]\n" - "and v31.16b, v31.16b, v11.16b\n" - "and v30.16b, v30.16b, v11.16b\n" - "ldr q19, [x22, #0x60]\n" - "ldr q18, [x22, #0x70]\n" - "shl v17.16b, v27.16b, #0x4\n" - "shl v16.16b, v26.16b, #0x4\n" - ".inst 0x4e97a7aa // smmla v10.4s, v29.16b, v23.16b\n" - ".inst 0x4e96a7a9 // smmla v9.4s, v29.16b, v22.16b\n" - "and v27.16b, v27.16b, v11.16b\n" - "add x22, x22, #0x80\n" - ".inst 0x4e97a788 // smmla v8.4s, v28.16b, v23.16b\n" - ".inst 0x4e96a787 // smmla v7.4s, v28.16b, v22.16b\n" - "and v26.16b, v26.16b, 
v11.16b\n" - ".inst 0x4e91a72a // smmla v10.4s, v25.16b, v17.16b\n" - ".inst 0x4e90a729 // smmla v9.4s, v25.16b, v16.16b\n" - ".inst 0x4e91a708 // smmla v8.4s, v24.16b, v17.16b\n" - ".inst 0x4e90a707 // smmla v7.4s, v24.16b, v16.16b\n" - ".inst 0x4e9fa6aa // smmla v10.4s, v21.16b, v31.16b\n" - ".inst 0x4e9ea6a9 // smmla v9.4s, v21.16b, v30.16b\n" - ".inst 0x4e9fa688 // smmla v8.4s, v20.16b, v31.16b\n" - ".inst 0x4e9ea687 // smmla v7.4s, v20.16b, v30.16b\n" - ".inst 0x4e9ba66a // smmla v10.4s, v19.16b, v27.16b\n" - ".inst 0x4e9aa669 // smmla v9.4s, v19.16b, v26.16b\n" - ".inst 0x4e9ba648 // smmla v8.4s, v18.16b, v27.16b\n" - ".inst 0x4e9aa647 // smmla v7.4s, v18.16b, v26.16b\n" - "bgt 13b\n" - "ldr q24, [x26, #0x10]\n" - "add x22, x22, #0x10\n" - "uzp1 v23.2d, v10.2d, v9.2d\n" - "uzp2 v22.2d, v10.2d, v9.2d\n" - "ldr q16, [x22, #0x0]\n" - "uzp1 v21.2d, v8.2d, v7.2d\n" - "uzp2 v20.2d, v8.2d, v7.2d\n" - "add x26, x26, #0x20\n" - "scvtf v23.4s, v23.4s\n" - "scvtf v22.4s, v22.4s\n" - "fmul v19.4s, v24.4s, v16.s[0]\n" - "fmul v18.4s, v24.4s, v16.s[1]\n" - "fmul v17.4s, v24.4s, v16.s[2]\n" - "scvtf v21.4s, v21.4s\n" - "fmul v16.4s, v24.4s, v16.s[3]\n" - "scvtf v20.4s, v20.4s\n" - "fmul v10.4s, v23.4s, v19.4s\n" - "fmul v9.4s, v22.4s, v18.4s\n" - "fmul v8.4s, v21.4s, v17.4s\n" - "fmul v7.4s, v20.4s, v16.4s\n" - "ldr q18, [x26, #0x0]\n" - "ld1r { v17.4s }, [%x[clamp_vals]]\n" - "add x20, %x[clamp_vals], #0x4\n" - "cmp x25, #0x4\n" - "ld1r { v16.4s }, [x20]\n" - "add x26, x26, #0x10\n" - "fadd v10.4s, v10.4s, v18.4s\n" - "fadd v9.4s, v9.4s, v18.4s\n" - "fadd v8.4s, v8.4s, v18.4s\n" - "fadd v7.4s, v7.4s, v18.4s\n" - "fmax v10.4s, v10.4s, v17.4s\n" - "fmax v9.4s, v9.4s, v17.4s\n" - "fmax v8.4s, v8.4s, v17.4s\n" - "fmax v7.4s, v7.4s, v17.4s\n" - "fmin v10.4s, v10.4s, v16.4s\n" - "fmin v9.4s, v9.4s, v16.4s\n" - "fmin v8.4s, v8.4s, v16.4s\n" - "fmin v7.4s, v7.4s, v16.4s\n" - "blt 15f\n" - "mov x20, %x[dst]\n" - "cmp x12, #0x1\n" - "str q10, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "ble 18f\n" - "cmp x12, #0x2\n" - "str q9, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "ble 18f\n" - "cmp x12, #0x3\n" - "str q8, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "ble 18f\n" - "str q7, [x20, #0x0]\n" - "b 18f\n" - "15:" // Row tail: Partial output - "mov x23, %x[dst]\n" - "cmp x12, #0x1\n" - "add x22, x23, %x[dst_stride_row]\n" - "csel x22, x22, x23, GT\n" - "cmp x12, #0x2\n" - "add x21, x23, %x[dst_stride_row], LSL #1\n" - "csel x21, x21, x22, GT\n" - "cmp x12, #0x3\n" - "add x20, x21, %x[dst_stride_row]\n" - "csel x20, x20, x21, GT\n" - "tbz x25, #1, 16f\n" - "st1 { v7.d }[0], [x20], #0x8\n" - "st1 { v8.d }[0], [x21], #0x8\n" - "st1 { v9.d }[0], [x22], #0x8\n" - "st1 { v10.d }[0], [x23], #0x8\n" - "tbz x25, #0, 17f\n" - "st1 { v7.s }[2], [x20]\n" - "st1 { v8.s }[2], [x21]\n" - "st1 { v9.s }[2], [x22]\n" - "st1 { v10.s }[2], [x23]\n" - "b 17f\n" - "16:" // Row tail: Output block 0: partial_1_0 - "st1 { v7.s }[0], [x20]\n" - "st1 { v8.s }[0], [x21]\n" - "st1 { v9.s }[0], [x22]\n" - "st1 { v10.s }[0], [x23]\n" - "17:" // Row tail: Output block 0: Done - "18:" // Row tail: Output stage exit - "subs x25, x25, #0x4\n" - "add %x[dst], %x[dst], #0x10\n" - "bgt 12b\n" - "subs x12, x12, #0x4\n" - "add %x[lhs_packed], %x[lhs_packed], x11\n" - "mov %x[dst], x24\n" - "bgt 11b\n" - "19:" // Row tail: Row loop skip - : [dst] "+&r"(dst), [lhs_packed] "+&r"(lhs_packed) - : [clamp_vals] "r"(clamp_vals), [dst_stride_row] "r"(dst_stride_row), [m] "r"(m), [n] "r"(n), - [num_blocks] 
"r"(num_blocks), [rhs_packed] "r"(rhs_packed) - : "cc", "memory", "v0", "v1", "v10", "v11", "v16", "v17", "v18", "v19", "v2", "v20", "v21", "v22", "v23", "v24", - "v25", "v26", "v27", "v28", "v29", "v3", "v30", "v31", "v4", "v5", "v6", "v7", "v8", "v9", "x10", "x11", - "x12", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "x9"); - // KernelArgs args; - - // args.dst = dst; - // args.lhs_packed = lhs_packed; - // args.rhs_packed = rhs_packed; - // args.clamp_vals = clamp_vals; - // args.dst_stride_row = dst_stride_row; - // args.m = m; - // args.n = n; - - // kai_kernel_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm(&args); + KernelArgs args; + + args.dst = dst; + args.lhs_packed = lhs_packed; + args.rhs_packed = rhs_packed; + args.clamp_vals = clamp_vals; + args.dst_stride_row = dst_stride_row; + args.m = m; + args.n = n; + args.num_blocks = num_blocks; + + kai_kernel_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm(&args); } #endif // Architectural features check. diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm.h b/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm.h index b59c69e2..04e32914 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm.h @@ -15,8 +15,7 @@ extern "C" { /// Micro-kernel dependencies /// /// -# @ref kai_lhs_quant_pack_qsi8dxp_f32 to dynamically quantize and pack the LHS matrix in a single step. -/// -# @ref kai_rhs_pack_nxk_qai4cxp_ to pack the RHS NxK matrix. -/// -# @ref kai_rhs_pack_kxn_qai4cxp_ to pack the RHS KxN matrix. +/// -# @ref kai_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon to pack the RHS NxK matrix. 
/// -------------------------------------------------- diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm_asm.S index 9d867705..67d5d80a 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm_asm.S +++ b/kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm_asm.S @@ -146,32 +146,43 @@ KAI_ASM_LABEL(label_3) // Sub block loop KAI_ASM_INST(0x4e9fa644) // smmla v4.4s, v18.16b, v31.16b KAI_ASM_INST(0x4e9ea643) // smmla v3.4s, v18.16b, v30.16b bgt label_3 - ldr q24, [x10, #0x10] + ld1 { v17.4s }, [x22] + ldr q25, [x10, #0x0] add x22, x22, #0x10 uzp1 v23.2d, v10.2d, v9.2d - uzp2 v22.2d, v10.2d, v9.2d + ldr q24, [x10, #0x10] ldr q16, [x22, #0x0] + uzp2 v22.2d, v10.2d, v9.2d uzp1 v21.2d, v8.2d, v7.2d uzp2 v20.2d, v8.2d, v7.2d add x10, x10, #0x20 - scvtf v23.4s, v23.4s - scvtf v22.4s, v22.4s + fmul v10.4s, v25.4s, v17.s[0] + fmul v9.4s, v25.4s, v17.s[1] + fmul v8.4s, v25.4s, v17.s[2] + fmul v7.4s, v25.4s, v17.s[3] fmul v19.4s, v24.4s, v16.s[0] + scvtf v23.4s, v23.4s fmul v18.4s, v24.4s, v16.s[1] + scvtf v22.4s, v22.4s fmul v17.4s, v24.4s, v16.s[2] scvtf v21.4s, v21.4s fmul v16.4s, v24.4s, v16.s[3] scvtf v20.4s, v20.4s - fmul v10.4s, v23.4s, v19.4s - fmul v9.4s, v22.4s, v18.4s - fmul v8.4s, v21.4s, v17.4s - fmul v7.4s, v20.4s, v16.4s + fmla v10.4s, v23.4s, v19.4s + fmla v9.4s, v22.4s, v18.4s + fmla v8.4s, v21.4s, v17.4s + fmla v7.4s, v20.4s, v16.4s + ld1 { v17.4s }, [x20] add x20, x20, #0x10 uzp1 v23.2d, v6.2d, v5.2d uzp2 v22.2d, v6.2d, v5.2d ldr q16, [x20, #0x0] uzp1 v21.2d, v4.2d, v3.2d uzp2 v20.2d, v4.2d, v3.2d + fmul v6.4s, v25.4s, v17.s[0] + fmul v5.4s, v25.4s, v17.s[1] + fmul v4.4s, v25.4s, v17.s[2] + fmul v3.4s, v25.4s, v17.s[3] fmul v19.4s, v24.4s, v16.s[0] scvtf v23.4s, v23.4s fmul v18.4s, v24.4s, v16.s[1] @@ -180,10 +191,10 @@ KAI_ASM_LABEL(label_3) // Sub block loop scvtf v21.4s, v21.4s fmul v16.4s, v24.4s, v16.s[3] scvtf v20.4s, v20.4s - fmul v6.4s, v23.4s, v19.4s - fmul v5.4s, v22.4s, v18.4s - fmul v4.4s, v21.4s, v17.4s - fmul v3.4s, v20.4s, v16.4s + fmla v6.4s, v23.4s, v19.4s + fmla v5.4s, v22.4s, v18.4s + fmla v4.4s, v21.4s, v17.4s + fmla v3.4s, v20.4s, v16.4s ldr q18, [x10, #0x0] ld1r { v17.4s }, [x11] add x20, x11, #0x4 @@ -334,26 +345,32 @@ KAI_ASM_LABEL(label_13) // Row tail: Sub block loop KAI_ASM_INST(0x4e9ba648) // smmla v8.4s, v18.16b, v27.16b KAI_ASM_INST(0x4e9aa647) // smmla v7.4s, v18.16b, v26.16b bgt label_13 - ldr q24, [x26, #0x10] + ld1 { v18.4s }, [x22] + ldr q17, [x26, #0x0] add x22, x22, #0x10 - uzp1 v23.2d, v10.2d, v9.2d - uzp2 v22.2d, v10.2d, v9.2d + uzp1 v24.2d, v10.2d, v9.2d + ldr q23, [x26, #0x10] ldr q16, [x22, #0x0] + uzp2 v22.2d, v10.2d, v9.2d uzp1 v21.2d, v8.2d, v7.2d uzp2 v20.2d, v8.2d, v7.2d add x26, x26, #0x20 - scvtf v23.4s, v23.4s + fmul v10.4s, v17.4s, v18.s[0] + fmul v9.4s, v17.4s, v18.s[1] + fmul v8.4s, v17.4s, v18.s[2] + fmul v7.4s, v17.4s, v18.s[3] + fmul v19.4s, v23.4s, v16.s[0] + scvtf v24.4s, v24.4s + fmul v18.4s, v23.4s, v16.s[1] scvtf v22.4s, v22.4s - fmul v19.4s, v24.4s, v16.s[0] - fmul v18.4s, v24.4s, v16.s[1] - fmul v17.4s, v24.4s, v16.s[2] + fmul v17.4s, v23.4s, v16.s[2] scvtf v21.4s, v21.4s - fmul v16.4s, v24.4s, v16.s[3] + fmul v16.4s, v23.4s, v16.s[3] scvtf v20.4s, v20.4s - fmul v10.4s, v23.4s, v19.4s - fmul v9.4s, v22.4s, v18.4s 
- fmul v8.4s, v21.4s, v17.4s - fmul v7.4s, v20.4s, v16.4s + fmla v10.4s, v24.4s, v19.4s + fmla v9.4s, v22.4s, v18.4s + fmla v8.4s, v21.4s, v17.4s + fmla v7.4s, v20.4s, v16.4s ldr q18, [x26, #0x0] ld1r { v17.4s }, [x11] add x20, x11, #0x4 -- GitLab From d709897b6a897090a94d97e5e0bb495e54818c8f Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Thu, 10 Apr 2025 17:17:41 +0100 Subject: [PATCH 09/15] Clean up unit tests and packing kernels Signed-off-by: Anitha Raj --- ..._lhs_quant_pack_qsi8dxpscalef32_f32_neon.c | 41 ++++++------ ..._nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon.h | 4 +- .../matmul_clamp_f32_qsi8dxp_qai4cxp_test.cpp | 63 +++++++++++++------ 3 files changed, 66 insertions(+), 42 deletions(-) diff --git a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8dxpscalef32_f32_neon.c b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8dxpscalef32_f32_neon.c index d38b27d2..16b28b39 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8dxpscalef32_f32_neon.c +++ b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8dxpscalef32_f32_neon.c @@ -17,8 +17,8 @@ #include "kai/kai_common.h" -static const size_t kai_num_bytes_per_multiplier = sizeof(float); -static const size_t kai_num_bytes_per_offset = sizeof(int32_t); +static const size_t kai_num_bytes_sum = sizeof(float); +static const size_t kai_num_bytes_multiplier = sizeof(float); inline static size_t kai_k_roundedup(size_t k, size_t kr, size_t sr) { // Since we pack a float and int32 value at the end of the row, @@ -32,7 +32,7 @@ inline static size_t kai_get_lhs_packed_stride(size_t k, size_t mr, size_t kr, s KAI_ASSERT((k_internal % 2) == 0); - return mr * (k_internal * sizeof(int8_t) + kai_num_bytes_per_multiplier + kai_num_bytes_per_offset); + return mr * (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier + kai_num_bytes_sum); } size_t kai_get_m_step_lhs_quant_pack_qsi8dxpscalef32_f32_neon(size_t mr) { @@ -66,18 +66,18 @@ void kai_run_lhs_quant_pack_qsi8dxpscalef32_f32_neon( return; } - const size_t num_rows = m; - const float* src_ptr = lhs; + const size_t lhs_packed_stride = kai_get_lhs_packed_stride(k, mr, kr, sr); + const size_t num_rows = m; - const size_t dst_stride = kai_get_lhs_packed_stride(k, mr, kr, sr); const size_t k_internal = kai_k_roundedup(k, kr, sr); const int32_t k_block_len = (int32_t)(kr / sr); for (size_t row_idx = 0; row_idx < num_rows; ++row_idx) { + const size_t dst_idx = ((row_idx + m_idx_start) % mr); + uint8_t* dst_ptr = (uint8_t*)lhs_packed + (dst_idx * k_block_len * sizeof(int8_t)); float absmax = -FLT_MAX; - // Find min/max for each channel + // Find absmax for each channel int32_t k_idx = 0; -#if defined(__aarch64__) float32x4_t vabsmax = vdupq_n_f32(-FLT_MAX); for (; k_idx <= ((int32_t)k - 8); k_idx += 8) { const float32x4_t src0_0 = vld1q_f32(src_ptr + 0 + (size_t)k_idx); @@ -85,21 +85,21 @@ void kai_run_lhs_quant_pack_qsi8dxpscalef32_f32_neon( // Calculate the max vabsmax = vmaxq_f32(vabsq_f32(src0_0), vmaxq_f32(vabsmax, vabsq_f32(src0_1))); } - // Get the max/min + // Get the absmax absmax = vmaxvq_f32(vabsmax); -#endif + for (; k_idx < (int32_t)k; ++k_idx) { const float src0_0 = *(src_ptr + (size_t)k_idx); absmax = KAI_MAX(src0_0, absmax); } - // Maximum int8 values + + // Maximum/minimum int8 values const float qmax = (float)INT8_MAX; + + // Get the scale and reciprocal to quantize const float scale0 = absmax == 0.0F ? 0.0F : qmax / absmax; - // Reciprocal to quantize const float recip_scale0 = scale0 ? 
1.0F / scale0 : 0.0F; - const size_t dst_x = ((row_idx + m_idx_start) % mr); - uint8_t* dst_ptr = (uint8_t*)lhs_packed + (dst_x * k_block_len * sizeof(int8_t)); int32_t qsum = 0; // Quantize the channels @@ -127,15 +127,15 @@ void kai_run_lhs_quant_pack_qsi8dxpscalef32_f32_neon( dst_ptr = (uint8_t*)lhs_packed + mr * (k_internal * sizeof(int8_t)); - dst_ptr += dst_x * kai_num_bytes_per_offset; + dst_ptr += dst_idx * kai_num_bytes_sum; // LHS offset at the beginning of the row *((float*)(dst_ptr)) = (float)(((float)qsum) * recip_scale0); - // Assuming the same sizeof() for kai_num_bytes_per_offset and kai_num_bytes_per_multiplier - KAI_ASSERT(kai_num_bytes_per_offset == kai_num_bytes_per_multiplier); + // Assuming the same sizeof() for kai_num_bytes_sum and kai_num_bytes_multiplier + KAI_ASSERT(kai_num_bytes_sum == kai_num_bytes_multiplier); - dst_ptr += mr * kai_num_bytes_per_offset; + dst_ptr += mr * kai_num_bytes_sum; // Store the scale quantization params *((float*)(dst_ptr)) = recip_scale0; @@ -144,9 +144,8 @@ void kai_run_lhs_quant_pack_qsi8dxpscalef32_f32_neon( // Move to the next row if we have interleaved all Mr rows if ((((row_idx + 1) + m_idx_start) % mr) == 0) { - lhs_packed = (void*)((uint8_t*)lhs_packed + dst_stride); + lhs_packed = (void*)((uint8_t*)lhs_packed + lhs_packed_stride); } } } - -#endif +#endif // Architectural features check. diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon.h b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon.h index 2a84598a..92c77a72 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon.h +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon.h @@ -40,8 +40,8 @@ size_t kai_get_rhs_offset_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon(size_ size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon( size_t k, size_t nr, size_t kr, size_t sr); -/// Gets the offset in bytes for the packed RHS matrix, which contains the packed 4-bit quantized symmetric per-channel -/// (qsu4cx) values. +/// Gets the offset in bytes for the packed RHS matrix, which contains the packed 4-bit quantized asymmetric per-channel +/// (qai4cx) values. /// /// @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of n_step. /// @param[in] k In the RHS matrix (not packed), K is the number of columns. 
diff --git a/test/tests/matmul_clamp_f32_qsi8dxp_qai4cxp_test.cpp b/test/tests/matmul_clamp_f32_qsi8dxp_qai4cxp_test.cpp index bef78512..56741449 100644 --- a/test/tests/matmul_clamp_f32_qsi8dxp_qai4cxp_test.cpp +++ b/test/tests/matmul_clamp_f32_qsi8dxp_qai4cxp_test.cpp @@ -29,6 +29,7 @@ #include "test/common/round.hpp" #include "test/common/test_suite.hpp" #include "test/reference/cast.hpp" +#include "test/reference/clamp.hpp" #include "test/reference/fill.hpp" #include "test/reference/matmul.hpp" #include "test/reference/pack.hpp" @@ -43,17 +44,17 @@ static const std::array {}; +class MatMulTest_f32_qsi8dxp_qai4cxp : public ::testing::TestWithParam {}; -TEST_P(MatMulTest_f32_qsi8dxp_qai4cxp, EndToEnd_NullBias) { - const auto& [variant_index, matmul_shape, portion] = GetParam(); +TEST_P(MatMulTest_f32_qsi8dxp_qai4cxp, EndToEnd) { + const auto& [variant_index, matmul_shape, portion, has_bias] = GetParam(); const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qsi8dxp_qai4cxp.at(variant_index); if (ukernel_variant.fn_is_supported && !ukernel_variant.fn_is_supported()) { - GTEST_SKIP() << "Kernel not supported"; + GTEST_SKIP() << "CPU features are not supported by current CPU"; } - const std::uint64_t seed = 0; + const std::uint32_t seed = 0; const size_t M = matmul_shape.m; const size_t N = matmul_shape.n; @@ -76,13 +77,17 @@ TEST_P(MatMulTest_f32_qsi8dxp_qai4cxp, EndToEnd_NullBias) { const auto rect = portion.compute_portion(M, N, m_step, n_step); if (rect.height() == 0 || rect.width() == 0) { - GTEST_SKIP() << "Test Portion size is 0!"; + GTEST_SKIP() << "Empty dimension of matrix(" << rect.width() << "," << rect.height() << ")"; } - // Generates input data + // Generates input data. const auto ref_lhs = fill_random(M * K, seed + 0); const auto ref_rhs = fill_random(N * K, seed + 1); + std::vector ref_biases; + if (has_bias) { + ref_biases = fill_random(N, seed + 2); + } // Runs the reference implementation. // * Quantizes the LHS matrix using 8-bit symmetric quantization. // * Quantizes the RHS matrix using 8-bit asymmetric quantization. @@ -92,10 +97,16 @@ TEST_P(MatMulTest_f32_qsi8dxp_qai4cxp, EndToEnd_NullBias) { const auto [ref_rhs_qai4, ref_rhs_scales, ref_rhs_zero_points] = quantize_asymmetric_per_block_dynamic(ref_rhs.data(), N, K, K); - const auto ref_dst = matmul_clamp_nt_t( - M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), nullptr, K, ref_rhs_qai4.data(), ref_rhs_scales.data(), - ref_rhs_zero_points.data(), K, nullptr, std::numeric_limits::lowest(), - std::numeric_limits::max()); + const auto ref_dst_no_clamp = + matmul_nt_t_quantized( + M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), nullptr, 1, K, ref_rhs_qai4.data(), + ref_rhs_scales.data(), ref_rhs_zero_points.data(), 1, K, has_bias ? ref_biases.data() : nullptr, nullptr, + nullptr, 1); + + // Clamps the reference output. + const auto clamp_ratio = 0.8F; + const auto [clamp_min, clamp_max] = find_clamp_range(ref_dst_no_clamp.data(), M * N, clamp_ratio); + const auto ref_dst = clamp(ref_dst_no_clamp.data(), M * N, clamp_min, clamp_max); // Runs the LHS packing micro-kernel. const auto lhs_start_row = rect.start_row(); @@ -123,7 +134,8 @@ TEST_P(MatMulTest_f32_qsi8dxp_qai4cxp, EndToEnd_NullBias) { -reinterpret_cast(ref_rhs_zero_points.data())[i] * reinterpret_cast(ref_rhs_scales.data())[i]; } - // Runs the RHS packing micro-kernel. 
+ + // Cast to unsigned int const auto ref_rhs_qau4 = cast_qsu4_qsi4(ref_rhs_qai4.data(), N * K); const auto imp_packed_rhs_size = @@ -135,10 +147,14 @@ TEST_P(MatMulTest_f32_qsi8dxp_qai4cxp, EndToEnd_NullBias) { auto rhs_matmul_offset = ukernel_variant.interface.get_rhs_packed_offset(rhs_start_row, K); ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset); - const kai_rhs_pack_nxk_qai4cxp_params params{.lhs_zero_point = 1, .rhs_zero_point = 8}; + // Runs the RHS packing micro-kernel. + kai_rhs_pack_nxk_qai4cxp_params params{}; + params.lhs_zero_point = 1; + params.rhs_zero_point = 8; + kai_run_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon( - 1, N, K, nr, kr, sr, ref_rhs_qau4.data(), ref_rhs_zp_f32.data(), nullptr, ref_rhs_scales.data(), - imp_packed_rhs.data(), 0, ¶ms); + 1, N, K, nr, kr, sr, ref_rhs_qau4.data(), ref_rhs_zp_f32.data(), has_bias ? ref_biases.data() : nullptr, + ref_rhs_scales.data(), imp_packed_rhs.data(), 0, ¶ms); const auto dst_stride_row = N * sizeof(float); const auto dst_stride_col = sizeof(float); @@ -154,7 +170,7 @@ TEST_P(MatMulTest_f32_qsi8dxp_qai4cxp, EndToEnd_NullBias) { ukernel_variant.interface.run_matmul( rect.height(), rect.width(), K, imp_packed_lhs.data() + lhs_matmul_offset, imp_packed_rhs.data() + rhs_matmul_offset, reinterpret_cast(imp_dst.data() + dst_offset), - dst_stride_row, dst_stride_col, std::numeric_limits::lowest(), std::numeric_limits::max()); + dst_stride_row, dst_stride_col, clamp_min, clamp_max); // Compares the output of the micro-kernels against the output of the reference implementation for the portion // tested. @@ -178,9 +194,15 @@ INSTANTIATE_TEST_SUITE_P( testing::Range(0, variants_kai_matmul_clamp_f32_qsi8dxp_qai4cxp.size()), testing::Values( MatMulShape{1, 2, 32}, // + MatMulShape{1, 3, 32}, // + MatMulShape{1, 4, 32}, // + MatMulShape{1, 5, 32}, // + MatMulShape{3, 3, 32}, // + MatMulShape{4, 4, 32}, // + MatMulShape{5, 5, 32}, // MatMulShape{32, 64, 64}, // MatMulShape{16, 32, 64}, // - MatMulShape{4, 4, 32}, // + MatMulShape{8, 32, 64}, // MatMulShape{15, 32, 32}, // MatMulShape{77, 99, 64}), testing::Values( @@ -191,19 +213,22 @@ INSTANTIATE_TEST_SUITE_P( MatrixPortion(0.75, 0.75, 1, 1), // Bottom-right corner. MatrixPortion(0.75, 0, 1, 1), // Partial rows MatrixPortion(0.4, 0.5, 0.6, 0.8) // Somewhere Middle - )), + ), + testing::Bool()), [](const auto& info) { const auto variant_idx = std::get<0>(info.param); const std::string name{variants_kai_matmul_clamp_f32_qsi8dxp_qai4cxp.at(variant_idx).name}; const auto shape = std::get(info.param); const auto portion = std::get<2>(info.param); + const auto has_bias = std::get<3>(info.param); std::stringstream sstream; sstream << name << "__M_" << shape.m << "__N_" << shape.n << "__K_" << shape.k // << "__PortionStartRow_" << static_cast(portion.start_row() * 1000) // << "__PortionStartCol_" << static_cast(portion.start_col() * 1000) // << "__PortionHeight_" << static_cast(portion.height() * 1000) // - << "__PortionWidth_" << static_cast(portion.width() * 1000); + << "__PortionWidth_" << static_cast(portion.width() * 1000) // + << (has_bias ? 
"__Bias" : ""); return sstream.str(); }); -- GitLab From 85cca0d5dd4c49919f4d400a9a680051be97f1df Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Thu, 10 Apr 2025 17:30:10 +0100 Subject: [PATCH 10/15] Update Changelog and Build.bazel files Signed-off-by: Anitha Raj --- CHANGELOG.md | 2 ++ CMakeLists.txt | 2 +- kai/ukernels/matmul/BUILD.bazel | 4 ++++ 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0b5d8567..3c963468 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,8 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo - Matrix multiplication (1xN) Micro-kernels of QSI8D32 LHS and QAI4C32 RHS with F32 output, optimized for FEAT_DotProd. - Matrix multiplication (MxN) Micro-kernels of QSI8D32 LHS and QAI4C32 RHS with F16 output, optimized for FEAT_I8MM. - Matrix multiplication (1xN) Micro-kernels of QSI8D32 LHS and QAI4C32 RHS with F16 output, optimized for FEAT_DotProd. + - Matrix multiplication (MxN) Micro-kernels of QSI8DX LHS and QAI4CX RHS with F32 output, optimized for FEAT_I8MM. + - Matrix multiplication (1xN) Micro-kernels of QSI8DX LHS and QAI4CX RHS with F32 output, optimized for FEAT_DotProd. ## v1.6.0 diff --git a/CMakeLists.txt b/CMakeLists.txt index b14daa97..9acfeac1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -137,6 +137,7 @@ set(KLEIDIAI_FILES_NEON kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.c kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla_asm.S kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c + kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32pscalef32_f32_neon.c kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8dxpscalef32_f32_neon.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon.c @@ -144,7 +145,6 @@ set(KLEIDIAI_FILES_NEON kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon.c kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.c kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon.c - kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32pscalef32_f32_neon.c kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon.c ) diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 79bfac8b..ba1dc5bd 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -30,10 +30,12 @@ SCALAR_KERNELS = [ NEON_KERNELS = [ "pack/kai_lhs_quant_pack_qsi8d32p_f32_neon", "pack/kai_lhs_quant_pack_qsi8d32pscalef32_f32_neon", + "pack/kai_lhs_quant_pack_qsi8dxpscalef32_f32_neon", "pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon", "pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0", "pack/kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon", "pack/kai_rhs_pack_nxk_qai4c32p_qau4c32s0s1_f32_f32_f32_neon", + "pack/kai_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon", "pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0", "pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon", "pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0", @@ -105,6 +107,7 @@ DOTPROD_KERNELS_ASM = [ "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod", "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p8x4_4x8_neon_dotprod", "matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qai4c32p4x8_1x4_neon_dotprod", + 
"matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp1x8_qai4cxp4x8_1x4_neon_dotprod", ] # buildifier: keep sorted @@ -125,6 +128,7 @@ I8MM_KERNELS_ASM = [ "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8_neon_i8mm", "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm", "matmul_clamp_f32_qsi8d32p_qai4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qai4c32p4x8_8x4_neon_i8mm", + "matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm", ] # buildifier: keep sorted -- GitLab From ddaf79d6ad9d663061c1624846d893547e13d334 Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Thu, 10 Apr 2025 18:08:23 +0100 Subject: [PATCH 11/15] Clang tidy fix Signed-off-by: Anitha Raj --- .../matmul/pack/kai_lhs_quant_pack_qsi8dxpscalef32_f32_neon.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8dxpscalef32_f32_neon.c b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8dxpscalef32_f32_neon.c index 16b28b39..b98d2509 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8dxpscalef32_f32_neon.c +++ b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8dxpscalef32_f32_neon.c @@ -130,7 +130,7 @@ void kai_run_lhs_quant_pack_qsi8dxpscalef32_f32_neon( dst_ptr += dst_idx * kai_num_bytes_sum; // LHS offset at the beginning of the row - *((float*)(dst_ptr)) = (float)(((float)qsum) * recip_scale0); + *((float*)(dst_ptr)) = ((float)qsum) * recip_scale0; // Assuming the same sizeof() for kai_num_bytes_sum and kai_num_bytes_multiplier KAI_ASSERT(kai_num_bytes_sum == kai_num_bytes_multiplier); -- GitLab From d41ee243f9e6b5ea659715ad2113dc481ddb5888 Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Fri, 11 Apr 2025 01:12:05 +0100 Subject: [PATCH 12/15] Add odd shapes to unit test Signed-off-by: Anitha Raj --- test/reference/quantize.cpp | 1 - .../matmul_clamp_f32_qsi8dxp_qai4cxp_test.cpp | 15 ++++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/test/reference/quantize.cpp b/test/reference/quantize.cpp index 498d26a8..025f677b 100644 --- a/test/reference/quantize.cpp +++ b/test/reference/quantize.cpp @@ -294,5 +294,4 @@ quantize_asymmetric_per_block_dynamic( template std::tuple, std::vector, std::vector> quantize_asymmetric_per_block_dynamic( const void* src, size_t height, size_t width, size_t quant_width); - } // namespace kai::test diff --git a/test/tests/matmul_clamp_f32_qsi8dxp_qai4cxp_test.cpp b/test/tests/matmul_clamp_f32_qsi8dxp_qai4cxp_test.cpp index 56741449..2a82bd5d 100644 --- a/test/tests/matmul_clamp_f32_qsi8dxp_qai4cxp_test.cpp +++ b/test/tests/matmul_clamp_f32_qsi8dxp_qai4cxp_test.cpp @@ -33,6 +33,7 @@ #include "test/reference/fill.hpp" #include "test/reference/matmul.hpp" #include "test/reference/pack.hpp" +#include "test/reference/pad.hpp" #include "test/reference/quantize.hpp" namespace kai::test { @@ -137,7 +138,8 @@ TEST_P(MatMulTest_f32_qsi8dxp_qai4cxp, EndToEnd) { // Cast to unsigned int const auto ref_rhs_qau4 = cast_qsu4_qsi4(ref_rhs_qai4.data(), N * K); - + const auto ref_rhs_qau4_padded = pad_row( + ref_rhs_qau4.data(), N, K, K, round_up_multiple(K, 2), round_up_division(N * round_up_multiple(K, 2), 2)); const auto imp_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon(N, K, nr, kr, sr); std::vector imp_packed_rhs(imp_packed_rhs_size); @@ -153,7 +155,7 @@ TEST_P(MatMulTest_f32_qsi8dxp_qai4cxp, EndToEnd) { params.rhs_zero_point = 8; 
kai_run_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon( - 1, N, K, nr, kr, sr, ref_rhs_qau4.data(), ref_rhs_zp_f32.data(), has_bias ? ref_biases.data() : nullptr, + 1, N, K, nr, kr, sr, ref_rhs_qau4_padded.data(), ref_rhs_zp_f32.data(), has_bias ? ref_biases.data() : nullptr, ref_rhs_scales.data(), imp_packed_rhs.data(), 0, ¶ms); const auto dst_stride_row = N * sizeof(float); @@ -199,12 +201,11 @@ INSTANTIATE_TEST_SUITE_P( MatMulShape{1, 5, 32}, // MatMulShape{3, 3, 32}, // MatMulShape{4, 4, 32}, // - MatMulShape{5, 5, 32}, // - MatMulShape{32, 64, 64}, // - MatMulShape{16, 32, 64}, // + MatMulShape{51, 15, 25}, // + MatMulShape{15, 35, 31}, // MatMulShape{8, 32, 64}, // - MatMulShape{15, 32, 32}, // - MatMulShape{77, 99, 64}), + MatMulShape{15, 32, 33}, // + MatMulShape{77, 99, 65}), testing::Values( MatrixPortion(0, 0, 1, 1), // Full matrix. MatrixPortion(0, 0, 1, 0.25), // Leftmost portion. -- GitLab From 1d27fb41bfb65c5938bff1814ba58a0b78e01bb4 Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Fri, 11 Apr 2025 01:13:07 +0100 Subject: [PATCH 13/15] Optimize LHS packing micro-kernel Signed-off-by: Anitha Raj --- ..._lhs_quant_pack_qsi8dxpscalef32_f32_neon.c | 57 +++++++++++++++---- 1 file changed, 47 insertions(+), 10 deletions(-) diff --git a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8dxpscalef32_f32_neon.c b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8dxpscalef32_f32_neon.c index b98d2509..5dc8c2c5 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8dxpscalef32_f32_neon.c +++ b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8dxpscalef32_f32_neon.c @@ -61,20 +61,24 @@ void kai_run_lhs_quant_pack_qsi8dxpscalef32_f32_neon( size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* restrict lhs, size_t lhs_stride, void* restrict lhs_packed) { KAI_ASSERT((kr % sr) == 0); + KAI_ASSUME((kr / sr) % 8 == 0); if (m == 0) { return; } - const float* src_ptr = lhs; - const size_t lhs_packed_stride = kai_get_lhs_packed_stride(k, mr, kr, sr); const size_t num_rows = m; + const float* src_ptr = lhs; + + const size_t lhs_packed_stride = kai_get_lhs_packed_stride(k, mr, kr, sr); const size_t k_internal = kai_k_roundedup(k, kr, sr); const int32_t k_block_len = (int32_t)(kr / sr); + + const int32_t num_blocks_k = (int32_t)(k / k_block_len); + const int32_t num_blocks_k_internal = (int32_t)(k_internal / k_block_len); + for (size_t row_idx = 0; row_idx < num_rows; ++row_idx) { - const size_t dst_idx = ((row_idx + m_idx_start) % mr); - uint8_t* dst_ptr = (uint8_t*)lhs_packed + (dst_idx * k_block_len * sizeof(int8_t)); float absmax = -FLT_MAX; // Find absmax for each channel int32_t k_idx = 0; @@ -100,14 +104,47 @@ void kai_run_lhs_quant_pack_qsi8dxpscalef32_f32_neon( const float scale0 = absmax == 0.0F ? 0.0F : qmax / absmax; const float recip_scale0 = scale0 ? 
1.0F / scale0 : 0.0F; - int32_t qsum = 0; + const size_t dst_x = ((row_idx + m_idx_start) % mr); + + uint8_t* dst_ptr = (uint8_t*)lhs_packed + (dst_x * k_block_len * sizeof(int8_t)); // Quantize the channels - k_idx = 0; - for (; k_idx < (int32_t)k_internal; k_idx += k_block_len) { + int32_t block_idx = 0; + int32_t qsum = 0; + for (; block_idx < num_blocks_k; ++block_idx) { + // Clamp at the last valid k-index + const int32_t k_idx_start = block_idx * k_block_len; + + const float32x4_t src_0 = vld1q_f32(src_ptr + k_idx_start); + const float32x4_t src_1 = vld1q_f32(src_ptr + k_idx_start + 4); + + // Scale the values + float32x4_t v0_f32 = vmulq_n_f32(src_0, scale0); + float32x4_t v1_f32 = vmulq_n_f32(src_1, scale0); + int32x4_t v0_s32 = vcvtnq_s32_f32(v0_f32); + int32x4_t v1_s32 = vcvtnq_s32_f32(v1_f32); + + int16x4_t v0_s16 = vqmovn_s32(v0_s32); + int16x4_t v1_s16 = vqmovn_s32(v1_s32); + int16x8_t v_s16 = vcombine_s16(v0_s16, v1_s16); + + v_s16 = vmaxq_s16(v_s16, vdupq_n_s16(INT8_MIN)); + v_s16 = vminq_s16(v_s16, vdupq_n_s16(INT8_MAX)); + + // Update the sum + qsum += vaddvq_s16(v_s16); + + int8x8_t v0_s8 = vqmovn_s16(v_s16); + vst1_s8((int8_t*)(dst_ptr), v0_s8); + dst_ptr += 8 * sizeof(int8_t); + dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t); + } + for (; block_idx < num_blocks_k_internal; ++block_idx) { + // left over k for (size_t k_block_idx = 0; k_block_idx < (size_t)k_block_len; ++k_block_idx) { // Clamp at the last valid k-index - const size_t k_idx_start = KAI_MIN((size_t)k_idx + k_block_idx, k - 1); + const size_t k_idx_start = KAI_MIN((size_t)(block_idx * k_block_len) + k_block_idx, k - 1); + const float src0_0 = *(src_ptr + k_idx_start); // Scale the values @@ -117,7 +154,7 @@ void kai_run_lhs_quant_pack_qsi8dxpscalef32_f32_neon( v0_s32 = KAI_MIN(v0_s32, INT8_MAX); *((int8_t*)(dst_ptr)) = (int8_t)v0_s32; - if ((size_t)k_idx + k_block_idx <= k - 1) { + if ((size_t)(block_idx * k_block_len) + k_block_idx <= k - 1) { qsum += v0_s32; } dst_ptr += sizeof(int8_t); @@ -127,7 +164,7 @@ void kai_run_lhs_quant_pack_qsi8dxpscalef32_f32_neon( dst_ptr = (uint8_t*)lhs_packed + mr * (k_internal * sizeof(int8_t)); - dst_ptr += dst_idx * kai_num_bytes_sum; + dst_ptr += dst_x * kai_num_bytes_sum; // LHS offset at the beginning of the row *((float*)(dst_ptr)) = ((float)qsum) * recip_scale0; -- GitLab From f86ce74e9b224d42afba5967d4119d6346cca705 Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Fri, 11 Apr 2025 12:07:58 +0100 Subject: [PATCH 14/15] Update LHS optimizations as per review comments Signed-off-by: Anitha Raj --- ..._lhs_quant_pack_qsi8dxpscalef32_f32_neon.c | 25 +++++++++++-------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8dxpscalef32_f32_neon.c b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8dxpscalef32_f32_neon.c index 5dc8c2c5..5fa582a7 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8dxpscalef32_f32_neon.c +++ b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8dxpscalef32_f32_neon.c @@ -87,8 +87,11 @@ void kai_run_lhs_quant_pack_qsi8dxpscalef32_f32_neon( const float32x4_t src0_0 = vld1q_f32(src_ptr + 0 + (size_t)k_idx); const float32x4_t src0_1 = vld1q_f32(src_ptr + 4 + (size_t)k_idx); // Calculate the max - vabsmax = vmaxq_f32(vabsq_f32(src0_0), vmaxq_f32(vabsmax, vabsq_f32(src0_1))); + float32x4_t vabsmax0 = vmaxq_f32(vabsmax, vabsq_f32(src0_0)); + float32x4_t vabsmax1 = vmaxq_f32(vabsmax, vabsq_f32(src0_1)); + vabsmax = vmaxq_f32(vabsmax0, vabsmax1); } + // Get the absmax absmax = 
vmaxvq_f32(vabsmax); @@ -119,25 +122,25 @@ void kai_run_lhs_quant_pack_qsi8dxpscalef32_f32_neon( const float32x4_t src_1 = vld1q_f32(src_ptr + k_idx_start + 4); // Scale the values - float32x4_t v0_f32 = vmulq_n_f32(src_0, scale0); - float32x4_t v1_f32 = vmulq_n_f32(src_1, scale0); - int32x4_t v0_s32 = vcvtnq_s32_f32(v0_f32); - int32x4_t v1_s32 = vcvtnq_s32_f32(v1_f32); + const float32x4_t v0_f32 = vmulq_n_f32(src_0, scale0); + const float32x4_t v1_f32 = vmulq_n_f32(src_1, scale0); + const int32x4_t v0_s32 = vcvtnq_s32_f32(v0_f32); + const int32x4_t v1_s32 = vcvtnq_s32_f32(v1_f32); - int16x4_t v0_s16 = vqmovn_s32(v0_s32); - int16x4_t v1_s16 = vqmovn_s32(v1_s32); + const int16x4_t v0_s16 = vqmovn_s32(v0_s32); + const int16x4_t v1_s16 = vqmovn_s32(v1_s32); int16x8_t v_s16 = vcombine_s16(v0_s16, v1_s16); v_s16 = vmaxq_s16(v_s16, vdupq_n_s16(INT8_MIN)); v_s16 = vminq_s16(v_s16, vdupq_n_s16(INT8_MAX)); - // Update the sum - qsum += vaddvq_s16(v_s16); - - int8x8_t v0_s8 = vqmovn_s16(v_s16); + const int8x8_t v0_s8 = vqmovn_s16(v_s16); vst1_s8((int8_t*)(dst_ptr), v0_s8); dst_ptr += 8 * sizeof(int8_t); dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t); + + // Update the sum + qsum += vaddlv_s8(v0_s8); } for (; block_idx < num_blocks_k_internal; ++block_idx) { // left over k -- GitLab From 72639914d7e7498594b4c172712c79aceaf36055 Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Wed, 23 Apr 2025 12:56:19 +0100 Subject: [PATCH 15/15] Update test validation and remove redundant clamping Signed-off-by: Anitha Raj --- ...i_lhs_quant_pack_qsi8dxpscalef32_f32_neon.c | 3 --- .../matmul_clamp_f32_qsi8dxp_qai4cxp_test.cpp | 18 ++++++------------ 2 files changed, 6 insertions(+), 15 deletions(-) diff --git a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8dxpscalef32_f32_neon.c b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8dxpscalef32_f32_neon.c index 5fa582a7..f6d28e96 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8dxpscalef32_f32_neon.c +++ b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8dxpscalef32_f32_neon.c @@ -131,9 +131,6 @@ void kai_run_lhs_quant_pack_qsi8dxpscalef32_f32_neon( const int16x4_t v1_s16 = vqmovn_s32(v1_s32); int16x8_t v_s16 = vcombine_s16(v0_s16, v1_s16); - v_s16 = vmaxq_s16(v_s16, vdupq_n_s16(INT8_MIN)); - v_s16 = vminq_s16(v_s16, vdupq_n_s16(INT8_MAX)); - const int8x8_t v0_s8 = vqmovn_s16(v_s16); vst1_s8((int8_t*)(dst_ptr), v0_s8); dst_ptr += 8 * sizeof(int8_t); diff --git a/test/tests/matmul_clamp_f32_qsi8dxp_qai4cxp_test.cpp b/test/tests/matmul_clamp_f32_qsi8dxp_qai4cxp_test.cpp index 2a82bd5d..253afd4f 100644 --- a/test/tests/matmul_clamp_f32_qsi8dxp_qai4cxp_test.cpp +++ b/test/tests/matmul_clamp_f32_qsi8dxp_qai4cxp_test.cpp @@ -21,7 +21,9 @@ #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp_qai4cxp_interface.h" #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8dxpscalef32_f32_neon.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon.h" +#include "test/common/compare.hpp" #include "test/common/cpu_info.hpp" +#include "test/common/data_format.hpp" #include "test/common/float16.hpp" #include "test/common/int4.hpp" #include "test/common/matrix_portion.hpp" @@ -176,18 +178,10 @@ TEST_P(MatMulTest_f32_qsi8dxp_qai4cxp, EndToEnd) { // Compares the output of the micro-kernels against the output of the reference implementation for the portion // tested. 
- for (size_t y = 0; y < rect.height(); ++y) { - for (size_t x = 0; x < rect.width(); ++x) { - const auto imp_value = - read_array(imp_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col())); - const auto ref_value = - read_array(ref_dst.data(), (rect.start_row() + y) * N + (x + rect.start_col())); - const auto rel_error = ref_value != 0 ? std::abs((imp_value - ref_value) / ref_value) : imp_value; - if (rel_error > 0.0001F) { - ASSERT_EQ(imp_value, ref_value); - } - } - } + DefaultMismatchHandler handler(0, 0.02, 0, 0.05); + DataFormat dst_format = DataFormat(DataType::FP32); + const auto success = compare(imp_dst.data(), ref_dst.data(), dst_format, M, N, rect, handler); + ASSERT_TRUE(success); } INSTANTIATE_TEST_SUITE_P( -- GitLab
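
Taken together, the series leaves callers with the following flow: query the packing geometry from the matmul micro-kernel, pack LHS and RHS, then run the matmul. The sketch below wires the new micro-kernels together for an F32 destination. It is a minimal sketch under stated assumptions: K is assumed even so the 4-bit RHS needs no padding, rhs_qau4 is the NxK unsigned 4-bit matrix packed two values per byte (as cast_qsu4_qsi4 produces in the tests), error handling is omitted, the mr/nr/kr/sr getter names follow the convention shown for the dotprod kernel, and the LHS packed-size getter name is assumed by analogy with the RHS one used in the tests — check the shipped headers before use.

// Sketch: end-to-end use of the qsi8dxp/qai4cxp i8mm micro-kernel (F32 output).
#include <float.h>
#include <stdint.h>
#include <stdlib.h>

#include "kai/ukernels/matmul/matmul_clamp_f32_qsi8dxp_qai4cxp/kai_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm.h"
#include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8dxpscalef32_f32_neon.h"
#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon.h"

void matmul_f32_qsi8dxp_qai4cxp_sketch(
    size_t m, size_t n, size_t k, const float* lhs, const uint8_t* rhs_qau4,
    const float* rhs_zp_f32, const float* rhs_scales, const float* bias, float* dst) {
    // Packing geometry is dictated by the matmul micro-kernel.
    const size_t mr = kai_get_mr_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm();
    const size_t nr = kai_get_nr_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm();
    const size_t kr = kai_get_kr_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm();
    const size_t sr = kai_get_sr_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm();

    // Assumed helper name for the LHS size, mirroring the RHS size getter
    // used in the unit tests above.
    void* lhs_packed =
        malloc(kai_get_lhs_packed_size_lhs_quant_pack_qsi8dxpscalef32_f32_neon(m, k, mr, kr, sr));
    void* rhs_packed =
        malloc(kai_get_rhs_packed_size_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon(n, k, nr, kr, sr));

    // Dynamically quantize and pack the LHS: per-row int8 values plus the
    // scaled row sum and reciprocal scale.
    kai_run_lhs_quant_pack_qsi8dxpscalef32_f32_neon(
        m, k, mr, kr, sr, 0 /* m_idx_start */, lhs, k * sizeof(float), lhs_packed);

    // Pack the RHS: unsigned 4-bit storage with per-channel scales,
    // zero-points (as f32) and an optional bias, as in the tests.
    kai_rhs_pack_nxk_qai4cxp_params params = {.lhs_zero_point = 1, .rhs_zero_point = 8};
    kai_run_rhs_pack_nxk_qai4cxp_qau4cxs1s0_f32_f32_f32_neon(
        1, n, k, nr, kr, sr, rhs_qau4, rhs_zp_f32, bias, rhs_scales, rhs_packed, 0, &params);

    // Run the i8mm matmul; clamping is effectively disabled by passing the
    // full float range.
    kai_run_matmul_clamp_f32_qsi8dxp4x8_qai4cxp4x8_8x4_neon_i8mm(
        m, n, k, lhs_packed, rhs_packed, dst,
        n * sizeof(float) /* dst_stride_row */, sizeof(float) /* dst_stride_col */,
        -FLT_MAX, FLT_MAX);

    free(lhs_packed);
    free(rhs_packed);
}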