From 2917b8797e4e0143e0c1f47b77e93ac266902f1e Mon Sep 17 00:00:00 2001
From: Gian Marco Iodice
Date: Mon, 10 Jun 2024 15:26:13 +0100
Subject: [PATCH 01/29] Matmul int4 micro-kernels for QAI8DX (LHS) x QSI4C32 (RHS) -> F32

- The LHS matrix is quantized (Q), asymmetric (A), signed 8-bit (I8), with
  per-row (DX) quantization parameters
- The RHS matrix is quantized (Q), symmetric (S), signed 4-bit (I4), with
  per-block quantization parameters
- The destination is F32
- Implement micro-kernels to perform the matrix multiplication
- Implement a micro-kernel to pack the RHS matrix
- No tests are added in this PR; they will be added in a separate PR

Signed-off-by: Gian Marco Iodice
---
 ...i8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c | 209 +++++++
 ...i8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h | 145 +++++
 ...i8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c | 254 +++++++++
 ...i8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h | 145 +++++
 ..._qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c | 519 ++++++++++++++++++
 ..._qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h | 145 +++++
 ..._qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c | 378 +++++++++++++
 ..._qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h | 145 +++++
 ...mul_clamp_f32_qai8dxp_qsi4c32p_interface.h |  52 ++
 .../kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c  | 166 ++++++
 .../kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.h  | 113 ++++
 11 files changed, 2271 insertions(+)
 create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c
 create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h
 create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c
 create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h
 create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c
 create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h
 create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c
 create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h
 create mode 100644 kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h
 create mode 100644 kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c
 create mode 100644 kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.h

diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c
new file mode 100644
index 00000000..47fee4d3
--- /dev/null
+++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c
@@ -0,0 +1,209 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#if !defined(__ARM_FEATURE_DOTPROD)
+#error "Dotprod extension required to compile this micro-kernel"
+#else
+#include "kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
+
+#include <stddef.h>
+#include <stdint.h>
+
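+// Packed-operand layout implied by the stride helpers below (a reading of the
+// formulas, not separate documentation): each group of kai_mr packed LHS rows
+// stores k int8 values per row plus one f32 multiplier and one int32 offset
+// per row; each group of kai_nr packed RHS rows stores, per block of bl
+// values, bl/2 bytes of nibble-packed int4 data plus one f32 multiplier,
+// followed by one f32 sum per row.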
+#include "kai/kai_common.h" + +static const size_t kai_m_step = 1; +static const size_t kai_n_step = 4; +static const size_t kai_mr = 1; +static const size_t kai_nr = 4; +static const size_t kai_kr = 16; +static const size_t kai_sr = 2; +static const size_t kai_k0 = 32; +static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); +static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); +static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); +static const size_t kai_num_bytes_sum_rhs = sizeof(float); + +inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) { + KAI_ASSERT((k % 2) == 0); + KAI_ASSERT((k % bl) == 0); + return k / bl; +} + +inline static size_t kai_lhs_packed_stride(size_t k) { + KAI_ASSERT((k % 2) == 0); + KAI_ASSERT((k % kai_kr) == 0); + + return kai_mr * (k * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); +} + +inline static size_t kai_rhs_packed_stride(size_t k, size_t bl) { + KAI_ASSERT((k % 2) == 0); + KAI_ASSERT((k % kai_kr) == 0); + KAI_ASSERT((bl % kai_kr) == 0); + KAI_ASSERT((k % bl) == 0); + + const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); + const size_t num_bytes_per_block = (bl / 2) + kai_num_bytes_multiplier_rhs; + + return kai_nr * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs); +} + +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(void) { + return kai_m_step; +} + +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(void) { + return kai_n_step; +} + +size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(void) { + return kai_mr; +} + +size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(void) { + return kai_nr; +} + +size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(void) { + return kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(size_t m_idx, size_t k) { + KAI_ASSERT((m_idx % kai_m_step) == 0); + + return (m_idx / kai_m_step) * kai_lhs_packed_stride(k); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod( + size_t n_idx, size_t k, size_t bl) { + KAI_ASSERT((n_idx % kai_n_step) == 0); + KAI_ASSERT(bl == 32); + + // Temporary assert + KAI_ASSERT((k % kai_k0) == 0); + + return (n_idx / kai_n_step) * kai_rhs_packed_stride(k, bl); +} + +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSERT((m_idx % kai_m_step) == 0); + KAI_ASSERT((n_idx % kai_n_step) == 0); + + return (n_idx * sizeof(float)) + m_idx * dst_stride; +} + +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(size_t m, size_t n) { + // Temporary assert + KAI_ASSERT((n % kai_nr) == 0); + + return m * n * sizeof(float); +} + +void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod( + size_t m, size_t n, size_t k, size_t bl, const void* restrict lhs_packed, const void* restrict rhs_packed, + float* restrict dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { + // Temporary asserts + KAI_ASSERT(n % kai_nr == 0); + KAI_ASSERT(k % kai_k0 == 0); + KAI_ASSERT(bl == 32); + KAI_ASSERT(dst_stride_col == sizeof(float)); + + if (m == 0) { + return; + } + + size_t num_subblocks = bl / 32; 
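+    // Note: bl is asserted to be 32 above, so num_subblocks evaluates to 1 and
+    // the sub-block loop in the assembly below runs exactly once per block.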
+ size_t num_blocks = k / bl; + + float clamp_vals[2] = {scalar_min, scalar_max}; + + __asm__ __volatile__( + "movi v31.16b, #0xf0\n" + "1:" // Column loop + "mov x22, %x[lhs_packed]\n" + "movi v30.16b, #0x0\n" + "mov x21, %x[num_blocks]\n" + "2:" // Block loop + "movi v29.4s, #0x0\n" + "movi v28.4s, #0x0\n" + "mov x20, %x[num_subblocks]\n" + "3:" // Sub block loop + "ldr q27, [%x[rhs_packed], #0x0]\n" + "ldr q26, [%x[rhs_packed], #0x10]\n" + "subs x20, x20, #0x1\n" + "ld1r { v25.2d }, [x22], #0x8\n" + "ldr q24, [%x[rhs_packed], #0x20]\n" + "ldr q23, [%x[rhs_packed], #0x30]\n" + "add %x[rhs_packed], %x[rhs_packed], #0x40\n" + "ld1r { v22.2d }, [x22], #0x8\n" + "ld1r { v21.2d }, [x22], #0x8\n" + "shl v20.16b, v27.16b, #0x4\n" + "shl v19.16b, v26.16b, #0x4\n" + "ld1r { v18.2d }, [x22], #0x8\n" + "shl v17.16b, v24.16b, #0x4\n" + "and v27.16b, v27.16b, v31.16b\n" + "shl v16.16b, v23.16b, #0x4\n" + "and v26.16b, v26.16b, v31.16b\n" + ".inst 0x4e99969d // sdot v29.4s, v20.16b, v25.16b\n" + ".inst 0x4e99967c // sdot v28.4s, v19.16b, v25.16b\n" + "and v24.16b, v24.16b, v31.16b\n" + "and v23.16b, v23.16b, v31.16b\n" + ".inst 0x4e96963d // sdot v29.4s, v17.16b, v22.16b\n" + ".inst 0x4e96961c // sdot v28.4s, v16.16b, v22.16b\n" + ".inst 0x4e95977d // sdot v29.4s, v27.16b, v21.16b\n" + ".inst 0x4e95975c // sdot v28.4s, v26.16b, v21.16b\n" + ".inst 0x4e92971d // sdot v29.4s, v24.16b, v18.16b\n" + ".inst 0x4e9296fc // sdot v28.4s, v23.16b, v18.16b\n" + "bgt 3b\n" + "ldr q16, [%x[rhs_packed], #0x0]\n" + "addp v29.4s, v29.4s, v28.4s\n" + "sub x21, x21, #0x1\n" + "add %x[rhs_packed], %x[rhs_packed], #0x10\n" + "scvtf v29.4s, v29.4s\n" + "fmla v30.4s, v29.4s, v16.4s\n" + "cbnz x21, 2b\n" + "ld1r { v20.4s }, [x22]\n" + "ldr q19, [%x[rhs_packed], #0x0]\n" + "add x22, x22, #0x4\n" + "add x20, %x[clamp_vals], #0x4\n" + "ld1r { v18.4s }, [x22]\n" + "ld1r { v17.4s }, [%x[clamp_vals]]\n" + "cmp %x[n], #0x4\n" + "add %x[rhs_packed], %x[rhs_packed], #0x10\n" + "ld1r { v16.4s }, [x20]\n" + "scvtf v20.4s, v20.4s\n" + "fmla v30.4s, v19.4s, v20.s[0]\n" + "fmul v30.4s, v30.4s, v18.4s\n" + "fmax v30.4s, v30.4s, v17.4s\n" + "fmin v30.4s, v30.4s, v16.4s\n" + "blt 4f\n" + "str q30, [%x[dst], #0x0]\n" + "b 7f\n" + "4:" // Partial output + "mov x20, %x[dst]\n" + "tbz %x[n], #1, 5f\n" + "st1 { v30.d }[0], [x20], #0x8\n" + "tbz %x[n], #0, 6f\n" + "st1 { v30.s }[2], [x20]\n" + "b 6f\n" + "5:" // Output block 0: partial_1_0 + "st1 { v30.s }[0], [x20]\n" + "6:" // Output block 0: Done + "7:" // Stores done + "subs %x[n], %x[n], #0x4\n" + "add %x[dst], %x[dst], #0x10\n" + "bgt 1b\n" + : [dst] "+&r"(dst), [n] "+&r"(n), [rhs_packed] "+&r"(rhs_packed) + : [clamp_vals] "r"(clamp_vals), [lhs_packed] "r"(lhs_packed), [num_blocks] "r"(num_blocks), + [num_subblocks] "r"(num_subblocks) + : "cc", "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", + "v29", "v30", "v31", "x20", "x21", "x22"); +} +#endif // Architectural feature check diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h new file mode 100644 index 00000000..ce0aa648 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h @@ -0,0 +1,145 @@ + +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// 
SPDX-License-Identifier: Apache-2.0
+//
+#pragma once

+#ifndef __cplusplus
+#include <stdbool.h>
+#endif
+#include <stddef.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/// Micro-kernel dependencies
+///
+/// -# kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix
+/// -# kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0 to pack the RHS matrix

+/// --------------------------------------------------

+/// Gets the m step value.
+/// The micro-kernel can process any M value. However, the starting M index to
+/// be processed must be a multiple of m step.
+///
+/// @return the m step value
+size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(void);
+
+/// Gets the n step value.
+/// The micro-kernel can process any N value. However, the starting N index to
+/// be processed must be a multiple of n step.
+///
+/// @return the n step value
+size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(void);
+
+/// Gets the mr value, which must be used to pack the LHS matrix.
+///
+/// @return the mr value
+size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(void);
+
+/// Gets the nr value, which must be used to pack the RHS matrix.
+///
+/// @return the nr value
+size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(void);
+
+/// Gets the kr value, which must be used to pack the LHS and RHS matrices.
+///
+/// @return the kr value
+size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(void);
+
+/// Gets the sr value, which must be used to pack the LHS and RHS matrices.
+///
+/// @return the sr value
+size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(void);
+
+/// Gets the offset in bytes for the packed LHS matrix,
+/// which contains the packed Signed 8-bit quantized asymmetric per-row (qai8dx) values.
+///
+/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel.
+///
+/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of 1.
+/// @param[in] k Total number of columns in the LHS matrix (not packed).
+///
+/// @return the offset in bytes to the packed LHS matrix
+size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(
+    size_t m_idx,  //
+    size_t k);     //
+
+/// Gets the offset in bytes for the packed RHS matrix,
+/// which contains the packed Signed 4-bit quantized symmetric per-block (qsi4c32) values.
+///
+/// @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of 4.
+/// @param[in] k The common dimension between the LHS and RHS matrix (K).
+/// @param[in] bl Block length. It must be a multiple of 32.
+///
+/// @return the offset in bytes to the packed RHS matrix
+size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(
+    size_t n_idx,  //
+    size_t k,      //
+    size_t bl);    //
+
+/// Gets the offset in bytes for the DST matrix.
+///
+/// @param[in] m_idx Row index in the DST matrix. It must be a multiple of 1.
+/// @param[in] n_idx Column index in the DST matrix. It must be a multiple of 4.
+/// @param[in] dst_stride The number of bytes in each row of the DST matrix.
+///
+/// @return the DST offset in bytes
+size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(
+    size_t m_idx,        //
+    size_t n_idx,        //
+    size_t dst_stride);  //
+
+/// Gets the size in bytes for the destination (DST) matrix.
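+/// The destination matrix is a plain (not packed) F32 matrix, so the size is m * n * sizeof(float).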
+///
+/// @param[in] m Number of rows in the destination (DST) matrix.
+/// @param[in] n Number of columns in the destination (DST) matrix.
+///
+/// @return the destination (DST) matrix size in bytes
+size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(
+    size_t m,   //
+    size_t n);  //
+
+/// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation.
+///
+/// LHS matrix: Signed 8-bit quantized asymmetric per-row (qai8dx) and packed.
+/// RHS matrix: Signed 4-bit quantized symmetric per-block (qsi4c32) and packed.
+/// Output tile: (rows x cols) = 1 x 4
+/// Accumulation performed in a single for loop: 32 values
+/// Extension used: dotprod
+///
+/// @param[in] m The number of output rows written.
+/// @param[in] n The number of output columns written.
+/// @param[in] k The number of channels. The common dimension between the LHS and RHS matrix.
+/// @param[in] bl Block length. It must be a multiple of 32.
+/// @param[in] lhs_packed The LHS packed matrix.
+///                       When the activations are dynamically quantized, you can obtain this matrix
+///                       by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel, which performs
+///                       both the dynamic quantization to 8-bit and the activation packing in a single step.
+/// @param[in] rhs_packed The RHS packed matrix, which is obtained by calling
+///                       @ref kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0
+/// @param[out] dst The DST matrix.
+/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix.
+/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(float).
+/// @param[in] scalar_min Min value used to clamp the final result.
+/// @param[in] scalar_max Max value used to clamp the final result.
+void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(
+    size_t m,                //
+    size_t n,                //
+    size_t k,                //
+    size_t bl,               //
+    const void* lhs_packed,  //
+    const void* rhs_packed,  //
+    float* dst,              //
+    size_t dst_stride_row,   //
+    size_t dst_stride_col,   //
+    float scalar_min,        //
+    float scalar_max);       //
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c
new file mode 100644
index 00000000..89d0d164
--- /dev/null
+++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c
@@ -0,0 +1,254 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#if !defined(__ARM_FEATURE_DOTPROD)
+#error "Dotprod extension required to compile this micro-kernel"
+#else
+#include "kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h"
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include "kai/kai_common.h"
+
+static const size_t kai_m_step = 1;
+static const size_t kai_n_step = 8;
+static const size_t kai_mr = 1;
+static const size_t kai_nr = 8;
+static const size_t kai_kr = 16;
+static const size_t kai_sr = 2;
+static const size_t kai_k0 = 32;
+static const size_t kai_num_bytes_multiplier_lhs = sizeof(float);
+static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);
+static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t);
+static const size_t kai_num_bytes_sum_rhs = sizeof(float);
+
+inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) {
+    KAI_ASSERT((k % 2) == 0);
+    
KAI_ASSERT((k % bl) == 0); + return k / bl; +} + +inline static size_t kai_lhs_packed_stride(size_t k) { + KAI_ASSERT((k % 2) == 0); + KAI_ASSERT((k % kai_kr) == 0); + + return kai_mr * (k * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); +} + +inline static size_t kai_rhs_packed_stride(size_t k, size_t bl) { + KAI_ASSERT((k % 2) == 0); + KAI_ASSERT((k % kai_kr) == 0); + KAI_ASSERT((bl % kai_kr) == 0); + KAI_ASSERT((k % bl) == 0); + + const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); + const size_t num_bytes_per_block = (bl / 2) + kai_num_bytes_multiplier_rhs; + + return kai_nr * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs); +} + +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(void) { + return kai_m_step; +} + +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(void) { + return kai_n_step; +} + +size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(void) { + return kai_mr; +} + +size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(void) { + return kai_nr; +} + +size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(void) { + return kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(size_t m_idx, size_t k) { + KAI_ASSERT((m_idx % kai_m_step) == 0); + + return (m_idx / kai_m_step) * kai_lhs_packed_stride(k); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod( + size_t n_idx, size_t k, size_t bl) { + KAI_ASSERT((n_idx % kai_n_step) == 0); + KAI_ASSERT(bl == 32); + + // Temporary assert + KAI_ASSERT((k % kai_k0) == 0); + + return (n_idx / kai_n_step) * kai_rhs_packed_stride(k, bl); +} + +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSERT((m_idx % kai_m_step) == 0); + KAI_ASSERT((n_idx % kai_n_step) == 0); + + return (n_idx * sizeof(float)) + m_idx * dst_stride; +} + +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(size_t m, size_t n) { + // Temporary assert + KAI_ASSERT((n % kai_nr) == 0); + + return m * n * sizeof(float); +} + +void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod( + size_t m, size_t n, size_t k, size_t bl, const void* restrict lhs_packed, const void* restrict rhs_packed, + float* restrict dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { + // Temporary asserts + KAI_ASSERT(n % kai_nr == 0); + KAI_ASSERT(k % kai_k0 == 0); + KAI_ASSERT(bl == 32); + KAI_ASSERT(dst_stride_col == sizeof(float)); + + if (m == 0) { + return; + } + + size_t num_subblocks = bl / 32; + size_t num_blocks = k / bl; + + float clamp_vals[2] = {scalar_min, scalar_max}; + + __asm__ __volatile__( + "movi v7.16b, #0xf0\n" + "1:" // Column loop + "mov x22, %x[lhs_packed]\n" + "movi v6.16b, #0x0\n" + "movi v5.16b, #0x0\n" + "mov x21, %x[num_blocks]\n" + "2:" // Block loop + "movi v4.4s, #0x0\n" + "movi v3.4s, #0x0\n" + "mov x20, %x[num_subblocks]\n" + "movi v2.4s, #0x0\n" + "movi v1.4s, #0x0\n" + "3:" // Sub block loop + "ldr q0, [%x[rhs_packed], #0x0]\n" + "ldr q31, [%x[rhs_packed], #0x10]\n" + "subs x20, x20, #0x1\n" + "ldr q30, [%x[rhs_packed], #0x20]\n" + "ldr q29, [%x[rhs_packed], #0x30]\n" + "ld1r { v28.2d 
}, [x22], #0x8\n" + "ldr q27, [%x[rhs_packed], #0x40]\n" + "ldr q26, [%x[rhs_packed], #0x50]\n" + "ldr q25, [%x[rhs_packed], #0x60]\n" + "shl v24.16b, v0.16b, #0x4\n" + "shl v18.16b, v31.16b, #0x4\n" + "ldr q23, [%x[rhs_packed], #0x70]\n" + "shl v17.16b, v30.16b, #0x4\n" + "shl v16.16b, v29.16b, #0x4\n" + "add %x[rhs_packed], %x[rhs_packed], #0x80\n" + "ld1r { v22.2d }, [x22], #0x8\n" + "shl v21.16b, v27.16b, #0x4\n" + "and v0.16b, v0.16b, v7.16b\n" + "ld1r { v20.2d }, [x22], #0x8\n" + "ld1r { v19.2d }, [x22], #0x8\n" + ".inst 0x4e9c9704 // sdot v4.4s, v24.16b, v28.16b\n" + ".inst 0x4e9c9643 // sdot v3.4s, v18.16b, v28.16b\n" + "shl v18.16b, v26.16b, #0x4\n" + ".inst 0x4e9c9622 // sdot v2.4s, v17.16b, v28.16b\n" + ".inst 0x4e9c9601 // sdot v1.4s, v16.16b, v28.16b\n" + "shl v17.16b, v25.16b, #0x4\n" + "shl v16.16b, v23.16b, #0x4\n" + "and v31.16b, v31.16b, v7.16b\n" + "and v30.16b, v30.16b, v7.16b\n" + "and v29.16b, v29.16b, v7.16b\n" + ".inst 0x4e9696a4 // sdot v4.4s, v21.16b, v22.16b\n" + ".inst 0x4e969643 // sdot v3.4s, v18.16b, v22.16b\n" + "and v27.16b, v27.16b, v7.16b\n" + ".inst 0x4e969622 // sdot v2.4s, v17.16b, v22.16b\n" + ".inst 0x4e969601 // sdot v1.4s, v16.16b, v22.16b\n" + "and v26.16b, v26.16b, v7.16b\n" + "and v25.16b, v25.16b, v7.16b\n" + "and v23.16b, v23.16b, v7.16b\n" + ".inst 0x4e949404 // sdot v4.4s, v0.16b, v20.16b\n" + ".inst 0x4e9497e3 // sdot v3.4s, v31.16b, v20.16b\n" + ".inst 0x4e9497c2 // sdot v2.4s, v30.16b, v20.16b\n" + ".inst 0x4e9497a1 // sdot v1.4s, v29.16b, v20.16b\n" + ".inst 0x4e939764 // sdot v4.4s, v27.16b, v19.16b\n" + ".inst 0x4e939743 // sdot v3.4s, v26.16b, v19.16b\n" + ".inst 0x4e939722 // sdot v2.4s, v25.16b, v19.16b\n" + ".inst 0x4e9396e1 // sdot v1.4s, v23.16b, v19.16b\n" + "bgt 3b\n" + "ldr q17, [%x[rhs_packed], #0x0]\n" + "ldr q16, [%x[rhs_packed], #0x10]\n" + "addp v4.4s, v4.4s, v3.4s\n" + "addp v2.4s, v2.4s, v1.4s\n" + "sub x21, x21, #0x1\n" + "add %x[rhs_packed], %x[rhs_packed], #0x20\n" + "scvtf v4.4s, v4.4s\n" + "scvtf v2.4s, v2.4s\n" + "fmla v6.4s, v4.4s, v17.4s\n" + "fmla v5.4s, v2.4s, v16.4s\n" + "cbnz x21, 2b\n" + "ld1r { v21.4s }, [x22]\n" + "ldr q20, [%x[rhs_packed], #0x0]\n" + "add x22, x22, #0x4\n" + "add x20, %x[clamp_vals], #0x4\n" + "ldr q19, [%x[rhs_packed], #0x10]\n" + "ld1r { v18.4s }, [x22]\n" + "cmp %x[n], #0x8\n" + "add %x[rhs_packed], %x[rhs_packed], #0x20\n" + "ld1r { v17.4s }, [%x[clamp_vals]]\n" + "ld1r { v16.4s }, [x20]\n" + "scvtf v21.4s, v21.4s\n" + "fmla v6.4s, v20.4s, v21.s[0]\n" + "fmla v5.4s, v19.4s, v21.s[0]\n" + "fmul v6.4s, v6.4s, v18.4s\n" + "fmax v6.4s, v6.4s, v17.4s\n" + "fmul v5.4s, v5.4s, v18.4s\n" + "fmax v5.4s, v5.4s, v17.4s\n" + "fmin v6.4s, v6.4s, v16.4s\n" + "fmin v5.4s, v5.4s, v16.4s\n" + "blt 4f\n" + "str q6, [%x[dst], #0x0]\n" + "str q5, [%x[dst], #0x10]\n" + "b 9f\n" + "4:" // Partial output + "mov x20, %x[dst]\n" + "tbz %x[n], #2, 6f\n" + "st1 { v6.4s }, [x20], #0x10\n" + "tbz %x[n], #1, 5f\n" + "st1 { v5.d }[0], [x20], #0x8\n" + "tbz %x[n], #0, 8f\n" + "st1 { v5.s }[2], [x20]\n" + "b 8f\n" + "5:" // Output block 0: partial_1_4 + "tbz %x[n], #0, 8f\n" + "st1 { v5.s }[0], [x20]\n" + "b 8f\n" + "6:" // Output block 0: partial_2_0 + "tbz %x[n], #1, 7f\n" + "st1 { v6.d }[0], [x20], #0x8\n" + "tbz %x[n], #0, 8f\n" + "st1 { v6.s }[2], [x20]\n" + "b 8f\n" + "7:" // Output block 0: partial_1_0 + "st1 { v6.s }[0], [x20]\n" + "8:" // Output block 0: Done + "9:" // Stores done + "subs %x[n], %x[n], #0x8\n" + "add %x[dst], %x[dst], #0x20\n" + "bgt 1b\n" + : [dst] "+&r"(dst), [n] "+&r"(n), [rhs_packed] 
"+&r"(rhs_packed) + : [clamp_vals] "r"(clamp_vals), [lhs_packed] "r"(lhs_packed), [num_blocks] "r"(num_blocks), + [num_subblocks] "r"(num_subblocks) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v20", "v21", + "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22"); +} +#endif // Architectural feature check diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h new file mode 100644 index 00000000..bbe749f1 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h @@ -0,0 +1,145 @@ + +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#ifndef __cplusplus +#include +#endif +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/// Micro-kernel dependencies +/// +/// -# kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix +/// -# kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0 to pack the RHS matrix + +/// -------------------------------------------------- + +/// Gets the m step value. +/// The micro-kernel can process any M values. However, the starting M index to +/// be processed must be a multiple of m step. +/// +/// @return the m step value +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(void); + +/// Gets the n step value. +/// The micro-kernel can process any N values. However, the starting N index to +/// be processed must be a multiple of n step. +/// +/// @return the n step +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(void); + +/// Gets the mr value, which must be used to pack the LHS matrix +/// +/// @return the mr value +size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(void); + +/// Gets the nr value, which must be used to pack the RHS matrix +/// +/// @return the nr value +size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(void); + +/// Gets the kr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the kr value +size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(void); + +/// Gets the sr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the sr value +size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(void); + +/// Gets the offset in bytes for the packed LHS matrix, +/// which contains the packed Signed 8-bit quantized asymmetric per-row (qai8dx) values. +/// +/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. +/// +/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of 1 +/// @param[in] k Total number of columns in the LHS matrix (not packed). +/// +/// @return the offset in bytes to the packed LHS matrix +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod( + size_t m_idx, // + size_t k); // + +/// Gets the offset in bytes for the packed RHS matrix, +/// which contains the packed Signed 4-bit quantized symmetric per-channel (qsi4c32) values. +/// +/// @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of 8. 
+/// @param[in] k The common dimension between the LHS and RHS matrix (K).
+/// @param[in] bl Block length. It must be a multiple of 32.
+///
+/// @return the offset in bytes to the packed RHS matrix
+size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(
+    size_t n_idx,  //
+    size_t k,      //
+    size_t bl);    //
+
+/// Gets the offset in bytes for the DST matrix.
+///
+/// @param[in] m_idx Row index in the DST matrix. It must be a multiple of 1.
+/// @param[in] n_idx Column index in the DST matrix. It must be a multiple of 8.
+/// @param[in] dst_stride The number of bytes in each row of the DST matrix.
+///
+/// @return the DST offset in bytes
+size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(
+    size_t m_idx,        //
+    size_t n_idx,        //
+    size_t dst_stride);  //
+
+/// Gets the size in bytes for the destination (DST) matrix.
+///
+/// @param[in] m Number of rows in the destination (DST) matrix.
+/// @param[in] n Number of columns in the destination (DST) matrix.
+///
+/// @return the destination (DST) matrix size in bytes
+size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(
+    size_t m,   //
+    size_t n);  //
+
+/// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation.
+///
+/// LHS matrix: Signed 8-bit quantized asymmetric per-row (qai8dx) and packed.
+/// RHS matrix: Signed 4-bit quantized symmetric per-block (qsi4c32) and packed.
+/// Output tile: (rows x cols) = 1 x 8
+/// Accumulation performed in a single for loop: 32 values
+/// Extension used: dotprod
+///
+/// @param[in] m The number of output rows written.
+/// @param[in] n The number of output columns written.
+/// @param[in] k The number of channels. The common dimension between the LHS and RHS matrix.
+/// @param[in] bl Block length. It must be a multiple of 32.
+/// @param[in] lhs_packed The LHS packed matrix.
+///                       When the activations are dynamically quantized, you can obtain this matrix
+///                       by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel, which performs
+///                       both the dynamic quantization to 8-bit and the activation packing in a single step.
+/// @param[in] rhs_packed The RHS packed matrix, which is obtained by calling
+///                       @ref kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0
+/// @param[out] dst The DST matrix.
+/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix.
+/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(float).
+/// @param[in] scalar_min Min value used to clamp the final result.
+/// @param[in] scalar_max Max value used to clamp the final result.
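+///
+/// Example: a minimal sketch of driving this kernel over a whole matrix. The
+/// buffer names (lhs_p, rhs_p) are illustrative; they stand for matrices packed
+/// with the micro-kernels listed at the top of this file. Since the assembly in
+/// this implementation only consumes m as a zero-size check, each call writes
+/// one output row:
+///
+/// \code{.c}
+/// const size_t dst_stride_row = n * sizeof(float);
+/// for (size_t m_idx = 0; m_idx < m; ++m_idx) {
+///     const void* lhs = (const char*)lhs_p +
+///         kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(m_idx, k);
+///     float* out = (float*)((char*)dst +
+///         kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(m_idx, 0, dst_stride_row));
+///     kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(
+///         1, n, k, 32, lhs, rhs_p, out, dst_stride_row, sizeof(float), scalar_min, scalar_max);
+/// }
+/// \endcode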
+void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(
+    size_t m,                //
+    size_t n,                //
+    size_t k,                //
+    size_t bl,               //
+    const void* lhs_packed,  //
+    const void* rhs_packed,  //
+    float* dst,              //
+    size_t dst_stride_row,   //
+    size_t dst_stride_col,   //
+    float scalar_min,        //
+    float scalar_max);       //
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c
new file mode 100644
index 00000000..876b647b
--- /dev/null
+++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c
@@ -0,0 +1,519 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#if !defined(__ARM_FEATURE_MATMUL_INT8)
+#error "I8mm extension required to compile this micro-kernel"
+#else
+#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h"
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include "kai/kai_common.h"
+
+static const size_t kai_m_step = 8;
+static const size_t kai_n_step = 4;
+static const size_t kai_mr = 4;
+static const size_t kai_nr = 4;
+static const size_t kai_kr = 16;
+static const size_t kai_sr = 2;
+static const size_t kai_k0 = 32;
+static const size_t kai_num_bytes_multiplier_lhs = sizeof(float);
+static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);
+static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t);
+static const size_t kai_num_bytes_sum_rhs = sizeof(float);
+
+inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) {
+    KAI_ASSERT((k % 2) == 0);
+    KAI_ASSERT((k % bl) == 0);
+    return k / bl;
+}
+
+inline static size_t kai_lhs_packed_stride(size_t k) {
+    KAI_ASSERT((k % 2) == 0);
+    KAI_ASSERT((k % kai_kr) == 0);
+
+    return kai_mr * (k * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs);
+}
+
+inline static size_t kai_rhs_packed_stride(size_t k, size_t bl) {
+    KAI_ASSERT((k % 2) == 0);
+    KAI_ASSERT((k % kai_kr) == 0);
+    KAI_ASSERT((bl % kai_kr) == 0);
+    KAI_ASSERT((k % bl) == 0);
+
+    const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl);
+    const size_t num_bytes_per_block = (bl / 2) + kai_num_bytes_multiplier_rhs;
+
+    return kai_nr * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs);
+}
+
+size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(void) {
+    return kai_m_step;
+}
+
+size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(void) {
+    return kai_n_step;
+}
+
+size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(void) {
+    return kai_mr;
+}
+
+size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(void) {
+    return kai_nr;
+}
+
+size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(void) {
+    return kai_kr;
+}
+
+size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(void) {
+    return kai_sr;
+}
+
+size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(size_t m_idx, size_t k) {
+    KAI_ASSERT((m_idx % kai_m_step) == 0);
+
+    return (m_idx / kai_m_step) * kai_lhs_packed_stride(k);
+}
+
+size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(
+    size_t n_idx, size_t k, size_t bl) {
+    KAI_ASSERT((n_idx % kai_n_step) == 0);
+    KAI_ASSERT(bl == 32);
+
+    // 
Temporary assert + KAI_ASSERT((k % kai_k0) == 0); + + return (n_idx / kai_n_step) * kai_rhs_packed_stride(k, bl); +} + +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSERT((m_idx % kai_m_step) == 0); + KAI_ASSERT((n_idx % kai_n_step) == 0); + + return (n_idx * sizeof(float)) + m_idx * dst_stride; +} + +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(size_t m, size_t n) { + // Temporary assert + KAI_ASSERT((n % kai_nr) == 0); + + return m * n * sizeof(float); +} + +void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( + size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed, float* dst, + size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { + size_t num_subblocks = bl / 32; + size_t num_blocks = k / bl; + + float clamp_vals[2] = {scalar_min, scalar_max}; + + __asm__ __volatile__( + "mov x12, %x[m]\n" + "mov x11, #0x80\n" + "movi v10.16b, #0xf0\n" + "mov x20, #0x20\n" + "cmp x12, #0x8\n" + "madd x11, %x[num_blocks], x11, x20\n" + "blt 11f\n" + "1:" // Row loop + "mov x10, %x[rhs_packed]\n" + "mov x9, %x[n]\n" + "add x28, %x[dst], %x[dst_stride_row], LSL #3\n" + "2:" // Column loop + "mov x23, %x[lhs_packed]\n" + "movi v6.16b, #0x0\n" + "movi v31.16b, #0x0\n" + "mov x22, %x[num_blocks]\n" + "movi v1.16b, #0x0\n" + "movi v2.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v22.16b, #0x0\n" + "movi v12.16b, #0x0\n" + "add x21, x23, x11\n" + "3:" // Block loop + "movi v8.4s, #0x0\n" + "movi v11.4s, #0x0\n" + "mov x20, %x[num_subblocks]\n" + "movi v17.4s, #0x0\n" + "movi v4.4s, #0x0\n" + "movi v23.4s, #0x0\n" + "movi v15.4s, #0x0\n" + "movi v27.4s, #0x0\n" + "movi v5.4s, #0x0\n" + "4:" // Sub block loop + "ldr q3, [x10, #0x0]\n" + "ldr q19, [x10, #0x10]\n" + "subs x20, x20, #0x1\n" + "ldr q7, [x23, #0x0]\n" + "ldr q28, [x23, #0x10]\n" + "ldr q14, [x21, #0x0]\n" + "ldr q18, [x21, #0x10]\n" + "ldr q29, [x10, #0x20]\n" + "ldr q25, [x10, #0x30]\n" + "shl v20.16b, v3.16b, #0x4\n" + "shl v16.16b, v19.16b, #0x4\n" + "ldr q24, [x23, #0x20]\n" + "ldr q9, [x23, #0x30]\n" + "and v3.16b, v3.16b, v10.16b\n" + "and v19.16b, v19.16b, v10.16b\n" + "ldr q0, [x21, #0x20]\n" + "ldr q26, [x21, #0x30]\n" + "add x10, x10, #0x40\n" + "ldr q30, [x23, #0x40]\n" + ".inst 0x4e94a4e8 // smmla v8.4s, v7.16b, v20.16b\n" + ".inst 0x4e90a4eb // smmla v11.4s, v7.16b, v16.16b\n" + "ldr q7, [x23, #0x50]\n" + ".inst 0x4e94a791 // smmla v17.4s, v28.16b, v20.16b\n" + ".inst 0x4e90a784 // smmla v4.4s, v28.16b, v16.16b\n" + "ldr q28, [x21, #0x40]\n" + ".inst 0x4e94a5d7 // smmla v23.4s, v14.16b, v20.16b\n" + ".inst 0x4e90a5cf // smmla v15.4s, v14.16b, v16.16b\n" + "ldr q14, [x21, #0x50]\n" + ".inst 0x4e94a65b // smmla v27.4s, v18.16b, v20.16b\n" + "ldr q20, [x23, #0x60]\n" + ".inst 0x4e90a645 // smmla v5.4s, v18.16b, v16.16b\n" + "ldr q18, [x23, #0x70]\n" + "shl v16.16b, v29.16b, #0x4\n" + "and v29.16b, v29.16b, v10.16b\n" + "add x23, x23, #0x80\n" + ".inst 0x4e90a708 // smmla v8.4s, v24.16b, v16.16b\n" + ".inst 0x4e90a531 // smmla v17.4s, v9.16b, v16.16b\n" + ".inst 0x4e90a417 // smmla v23.4s, v0.16b, v16.16b\n" + ".inst 0x4e90a75b // smmla v27.4s, v26.16b, v16.16b\n" + "ldr q16, [x21, #0x60]\n" + ".inst 0x4e83a7c8 // smmla v8.4s, v30.16b, v3.16b\n" + ".inst 0x4e83a4f1 // smmla v17.4s, v7.16b, v3.16b\n" + ".inst 0x4e83a797 // smmla v23.4s, v28.16b, v3.16b\n" + ".inst 0x4e83a5db // smmla v27.4s, v14.16b, 
v3.16b\n" + "ldr q3, [x21, #0x70]\n" + "add x21, x21, #0x80\n" + ".inst 0x4e9da688 // smmla v8.4s, v20.16b, v29.16b\n" + ".inst 0x4e9da651 // smmla v17.4s, v18.16b, v29.16b\n" + ".inst 0x4e9da617 // smmla v23.4s, v16.16b, v29.16b\n" + ".inst 0x4e9da47b // smmla v27.4s, v3.16b, v29.16b\n" + "shl v29.16b, v25.16b, #0x4\n" + "and v25.16b, v25.16b, v10.16b\n" + ".inst 0x4e9da70b // smmla v11.4s, v24.16b, v29.16b\n" + ".inst 0x4e9da524 // smmla v4.4s, v9.16b, v29.16b\n" + ".inst 0x4e9da40f // smmla v15.4s, v0.16b, v29.16b\n" + ".inst 0x4e9da745 // smmla v5.4s, v26.16b, v29.16b\n" + ".inst 0x4e93a7cb // smmla v11.4s, v30.16b, v19.16b\n" + ".inst 0x4e93a4e4 // smmla v4.4s, v7.16b, v19.16b\n" + ".inst 0x4e93a78f // smmla v15.4s, v28.16b, v19.16b\n" + ".inst 0x4e93a5c5 // smmla v5.4s, v14.16b, v19.16b\n" + ".inst 0x4e99a68b // smmla v11.4s, v20.16b, v25.16b\n" + ".inst 0x4e99a644 // smmla v4.4s, v18.16b, v25.16b\n" + ".inst 0x4e99a60f // smmla v15.4s, v16.16b, v25.16b\n" + ".inst 0x4e99a465 // smmla v5.4s, v3.16b, v25.16b\n" + "bgt 4b\n" + "ldr q20, [x10, #0x0]\n" + "uzp1 v29.2d, v8.2d, v11.2d\n" + "uzp2 v18.2d, v8.2d, v11.2d\n" + "add x10, x10, #0x10\n" + "uzp1 v30.2d, v17.2d, v4.2d\n" + "uzp2 v16.2d, v17.2d, v4.2d\n" + "scvtf v29.4s, v29.4s\n" + "scvtf v18.4s, v18.4s\n" + "scvtf v30.4s, v30.4s\n" + "scvtf v16.4s, v16.4s\n" + "fmla v6.4s, v29.4s, v20.4s\n" + "fmla v31.4s, v18.4s, v20.4s\n" + "fmla v1.4s, v30.4s, v20.4s\n" + "fmla v2.4s, v16.4s, v20.4s\n" + "uzp1 v4.2d, v23.2d, v15.2d\n" + "uzp2 v18.2d, v23.2d, v15.2d\n" + "uzp1 v15.2d, v27.2d, v5.2d\n" + "uzp2 v16.2d, v27.2d, v5.2d\n" + "scvtf v4.4s, v4.4s\n" + "scvtf v18.4s, v18.4s\n" + "scvtf v15.4s, v15.4s\n" + "scvtf v16.4s, v16.4s\n" + "fmla v13.4s, v4.4s, v20.4s\n" + "fmla v21.4s, v18.4s, v20.4s\n" + "fmla v22.4s, v15.4s, v20.4s\n" + "fmla v12.4s, v16.4s, v20.4s\n" + "subs x22, x22, #0x1\n" + "bgt 3b\n" + "ld1 { v29.4s }, [x23]\n" + "ld1 { v4.4s }, [x21]\n" + "add x23, x23, #0x10\n" + "add x21, x21, #0x10\n" + "ldr q20, [x10, #0x0]\n" + "ldr q8, [x23, #0x0]\n" + "add x20, %x[clamp_vals], #0x4\n" + "cmp x9, #0x4\n" + "ldr q18, [x21, #0x0]\n" + "ld1r { v28.4s }, [%x[clamp_vals]]\n" + "add x10, x10, #0x10\n" + "ld1r { v16.4s }, [x20]\n" + "scvtf v29.4s, v29.4s\n" + "scvtf v4.4s, v4.4s\n" + "fmla v6.4s, v20.4s, v29.s[0]\n" + "fmla v31.4s, v20.4s, v29.s[1]\n" + "fmla v1.4s, v20.4s, v29.s[2]\n" + "fmla v2.4s, v20.4s, v29.s[3]\n" + "fmla v13.4s, v20.4s, v4.s[0]\n" + "fmla v21.4s, v20.4s, v4.s[1]\n" + "fmla v22.4s, v20.4s, v4.s[2]\n" + "fmla v12.4s, v20.4s, v4.s[3]\n" + "fmul v6.4s, v6.4s, v8.s[0]\n" + "fmul v31.4s, v31.4s, v8.s[1]\n" + "fmul v1.4s, v1.4s, v8.s[2]\n" + "fmul v2.4s, v2.4s, v8.s[3]\n" + "fmul v13.4s, v13.4s, v18.s[0]\n" + "fmul v21.4s, v21.4s, v18.s[1]\n" + "fmul v22.4s, v22.4s, v18.s[2]\n" + "fmul v12.4s, v12.4s, v18.s[3]\n" + "fmax v6.4s, v6.4s, v28.4s\n" + "fmax v31.4s, v31.4s, v28.4s\n" + "fmax v1.4s, v1.4s, v28.4s\n" + "fmax v2.4s, v2.4s, v28.4s\n" + "fmax v13.4s, v13.4s, v28.4s\n" + "fmax v21.4s, v21.4s, v28.4s\n" + "fmax v22.4s, v22.4s, v28.4s\n" + "fmax v12.4s, v12.4s, v28.4s\n" + "fmin v6.4s, v6.4s, v16.4s\n" + "fmin v31.4s, v31.4s, v16.4s\n" + "fmin v1.4s, v1.4s, v16.4s\n" + "fmin v2.4s, v2.4s, v16.4s\n" + "fmin v13.4s, v13.4s, v16.4s\n" + "fmin v21.4s, v21.4s, v16.4s\n" + "fmin v22.4s, v22.4s, v16.4s\n" + "fmin v12.4s, v12.4s, v16.4s\n" + "blt 7f\n" + "mov x20, %x[dst]\n" + "str q6, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q31, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q1, [x20, 
#0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q2, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q13, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q21, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q22, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q12, [x20, #0x0]\n" + "b 10f\n" + "7:" // Partial output + "mov x27, %x[dst]\n" + "add x26, x27, %x[dst_stride_row], LSL #2\n" + "add x25, x26, %x[dst_stride_row], LSL #1\n" + "add x24, x26, %x[dst_stride_row]\n" + "add x23, x25, %x[dst_stride_row]\n" + "add x22, x27, %x[dst_stride_row], LSL #1\n" + "add x21, x27, %x[dst_stride_row]\n" + "add x20, x22, %x[dst_stride_row]\n" + "tbz x9, #1, 8f\n" + "st1 { v12.d }[0], [x23], #0x8\n" + "st1 { v22.d }[0], [x25], #0x8\n" + "st1 { v21.d }[0], [x24], #0x8\n" + "st1 { v13.d }[0], [x26], #0x8\n" + "st1 { v2.d }[0], [x20], #0x8\n" + "st1 { v1.d }[0], [x22], #0x8\n" + "st1 { v31.d }[0], [x21], #0x8\n" + "st1 { v6.d }[0], [x27], #0x8\n" + "tbz x9, #0, 9f\n" + "st1 { v12.s }[2], [x23]\n" + "st1 { v22.s }[2], [x25]\n" + "st1 { v21.s }[2], [x24]\n" + "st1 { v13.s }[2], [x26]\n" + "st1 { v2.s }[2], [x20]\n" + "st1 { v1.s }[2], [x22]\n" + "st1 { v31.s }[2], [x21]\n" + "st1 { v6.s }[2], [x27]\n" + "b 9f\n" + "8:" // Output block 0: partial_1_0 + "st1 { v12.s }[0], [x23]\n" + "st1 { v22.s }[0], [x25]\n" + "st1 { v21.s }[0], [x24]\n" + "st1 { v13.s }[0], [x26]\n" + "st1 { v2.s }[0], [x20]\n" + "st1 { v1.s }[0], [x22]\n" + "st1 { v31.s }[0], [x21]\n" + "st1 { v6.s }[0], [x27]\n" + "9:" // Output block 0: Done + "10:" // Output stage exit + "subs x9, x9, #0x4\n" + "add %x[dst], %x[dst], #0x10\n" + "bgt 2b\n" + "mov x20, #0x2\n" + "sub x12, x12, #0x8\n" + "cmp x12, #0x8\n" + "mov %x[dst], x28\n" + "madd %x[lhs_packed], x20, x11, %x[lhs_packed]\n" + "bge 1b\n" + "11:" // Row loop skip + "cbz x12, 21f\n" + "12:" // Row tail: Row loop + "mov x26, %x[rhs_packed]\n" + "mov x25, %x[n]\n" + "add x24, %x[dst], %x[dst_stride_row], LSL #2\n" + "13:" // Row tail: Column loop + "movi v6.16b, #0x0\n" + "movi v31.16b, #0x0\n" + "mov x23, %x[lhs_packed]\n" + "mov x21, %x[num_blocks]\n" + "movi v1.16b, #0x0\n" + "movi v2.16b, #0x0\n" + "14:" // Row tail: Block loop + "movi v8.4s, #0x0\n" + "movi v11.4s, #0x0\n" + "mov x20, %x[num_subblocks]\n" + "movi v17.4s, #0x0\n" + "movi v4.4s, #0x0\n" + "15:" // Row tail: Sub block loop + "ldr q13, [x26, #0x0]\n" + "ldr q5, [x26, #0x10]\n" + "subs x20, x20, #0x1\n" + "ldr q15, [x23, #0x0]\n" + "ldr q30, [x23, #0x10]\n" + "ldr q29, [x26, #0x20]\n" + "ldr q18, [x26, #0x30]\n" + "add x26, x26, #0x40\n" + "ldr q27, [x23, #0x20]\n" + "ldr q26, [x23, #0x30]\n" + "shl v25.16b, v13.16b, #0x4\n" + "shl v24.16b, v5.16b, #0x4\n" + "ldr q28, [x23, #0x40]\n" + "ldr q20, [x23, #0x50]\n" + "and v13.16b, v13.16b, v10.16b\n" + "and v5.16b, v5.16b, v10.16b\n" + "ldr q0, [x23, #0x60]\n" + "ldr q12, [x23, #0x70]\n" + "shl v9.16b, v29.16b, #0x4\n" + "shl v16.16b, v18.16b, #0x4\n" + ".inst 0x4e99a5e8 // smmla v8.4s, v15.16b, v25.16b\n" + ".inst 0x4e98a5eb // smmla v11.4s, v15.16b, v24.16b\n" + "and v29.16b, v29.16b, v10.16b\n" + "add x23, x23, #0x80\n" + ".inst 0x4e99a7d1 // smmla v17.4s, v30.16b, v25.16b\n" + ".inst 0x4e98a7c4 // smmla v4.4s, v30.16b, v24.16b\n" + "and v18.16b, v18.16b, v10.16b\n" + ".inst 0x4e89a768 // smmla v8.4s, v27.16b, v9.16b\n" + ".inst 0x4e90a76b // smmla v11.4s, v27.16b, v16.16b\n" + ".inst 0x4e89a751 // smmla v17.4s, v26.16b, v9.16b\n" + ".inst 0x4e90a744 // smmla v4.4s, v26.16b, v16.16b\n" + ".inst 0x4e8da788 // smmla 
v8.4s, v28.16b, v13.16b\n" + ".inst 0x4e85a78b // smmla v11.4s, v28.16b, v5.16b\n" + ".inst 0x4e8da691 // smmla v17.4s, v20.16b, v13.16b\n" + ".inst 0x4e85a684 // smmla v4.4s, v20.16b, v5.16b\n" + ".inst 0x4e9da408 // smmla v8.4s, v0.16b, v29.16b\n" + ".inst 0x4e92a40b // smmla v11.4s, v0.16b, v18.16b\n" + ".inst 0x4e9da591 // smmla v17.4s, v12.16b, v29.16b\n" + ".inst 0x4e92a584 // smmla v4.4s, v12.16b, v18.16b\n" + "bgt 15b\n" + "ldr q20, [x26, #0x0]\n" + "uzp1 v14.2d, v8.2d, v11.2d\n" + "uzp2 v18.2d, v8.2d, v11.2d\n" + "add x26, x26, #0x10\n" + "uzp1 v8.2d, v17.2d, v4.2d\n" + "uzp2 v16.2d, v17.2d, v4.2d\n" + "scvtf v14.4s, v14.4s\n" + "scvtf v18.4s, v18.4s\n" + "scvtf v8.4s, v8.4s\n" + "scvtf v16.4s, v16.4s\n" + "fmla v6.4s, v14.4s, v20.4s\n" + "fmla v31.4s, v18.4s, v20.4s\n" + "fmla v1.4s, v8.4s, v20.4s\n" + "fmla v2.4s, v16.4s, v20.4s\n" + "subs x21, x21, #0x1\n" + "bgt 14b\n" + "ld1 { v20.4s }, [x23]\n" + "ldr q11, [x26, #0x0]\n" + "add x23, x23, #0x10\n" + "add x20, %x[clamp_vals], #0x4\n" + "ldr q18, [x23, #0x0]\n" + "ld1r { v8.4s }, [%x[clamp_vals]]\n" + "cmp x25, #0x4\n" + "add x26, x26, #0x10\n" + "ld1r { v16.4s }, [x20]\n" + "scvtf v20.4s, v20.4s\n" + "fmla v6.4s, v11.4s, v20.s[0]\n" + "fmla v31.4s, v11.4s, v20.s[1]\n" + "fmla v1.4s, v11.4s, v20.s[2]\n" + "fmla v2.4s, v11.4s, v20.s[3]\n" + "fmul v6.4s, v6.4s, v18.s[0]\n" + "fmul v31.4s, v31.4s, v18.s[1]\n" + "fmax v6.4s, v6.4s, v8.4s\n" + "fmul v1.4s, v1.4s, v18.s[2]\n" + "fmul v2.4s, v2.4s, v18.s[3]\n" + "fmax v31.4s, v31.4s, v8.4s\n" + "fmin v6.4s, v6.4s, v16.4s\n" + "fmax v1.4s, v1.4s, v8.4s\n" + "fmax v2.4s, v2.4s, v8.4s\n" + "fmin v31.4s, v31.4s, v16.4s\n" + "fmin v1.4s, v1.4s, v16.4s\n" + "fmin v2.4s, v2.4s, v16.4s\n" + "blt 17f\n" + "mov x20, %x[dst]\n" + "cmp x12, #0x1\n" + "str q6, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 20f\n" + "cmp x12, #0x2\n" + "str q31, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 20f\n" + "cmp x12, #0x3\n" + "str q1, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 20f\n" + "str q2, [x20, #0x0]\n" + "b 20f\n" + "17:" // Row tail: Partial output + "mov x23, %x[dst]\n" + "cmp x12, #0x1\n" + "add x22, x23, %x[dst_stride_row]\n" + "csel x22, x22, x23, GE\n" + "cmp x12, #0x2\n" + "add x21, x23, %x[dst_stride_row], LSL #1\n" + "csel x21, x21, x22, GE\n" + "cmp x12, #0x3\n" + "add x20, x21, %x[dst_stride_row]\n" + "csel x20, x20, x21, GE\n" + "tbz x25, #1, 18f\n" + "st1 { v2.d }[0], [x20], #0x8\n" + "st1 { v1.d }[0], [x21], #0x8\n" + "st1 { v31.d }[0], [x22], #0x8\n" + "st1 { v6.d }[0], [x23], #0x8\n" + "tbz x25, #0, 19f\n" + "st1 { v2.s }[2], [x20]\n" + "st1 { v1.s }[2], [x21]\n" + "st1 { v31.s }[2], [x22]\n" + "st1 { v6.s }[2], [x23]\n" + "b 19f\n" + "18:" // Row tail: Output block 0: partial_1_0 + "st1 { v2.s }[0], [x20]\n" + "st1 { v1.s }[0], [x21]\n" + "st1 { v31.s }[0], [x22]\n" + "st1 { v6.s }[0], [x23]\n" + "19:" // Row tail: Output block 0: Done + "20:" // Row tail: Output stage exit + "subs x25, x25, #0x4\n" + "add %x[dst], %x[dst], #0x10\n" + "bgt 13b\n" + "subs x12, x12, #0x4\n" + "add %x[lhs_packed], %x[lhs_packed], x11\n" + "mov %x[dst], x24\n" + "bgt 12b\n" + "21:" // Row tail: Row loop skip + : [dst] "+&r"(dst), [lhs_packed] "+&r"(lhs_packed) + : [clamp_vals] "r"(clamp_vals), [dst_stride_row] "r"(dst_stride_row), [m] "r"(m), [n] "r"(n), + [num_blocks] "r"(num_blocks), [num_subblocks] "r"(num_subblocks), [rhs_packed] "r"(rhs_packed) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", 
"v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", + "v30", "v31", "x9", "x10", "x11", "x12", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"); +} +#endif // Architectural feature check diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h new file mode 100644 index 00000000..8b20a31d --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h @@ -0,0 +1,145 @@ + +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#ifndef __cplusplus +#include +#endif +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/// Micro-kernel dependencies +/// +/// -# kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix +/// -# kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0 to pack the RHS matrix + +/// -------------------------------------------------- + +/// Gets the m step value. +/// The micro-kernel can process any M values. However, the starting M index to +/// be processed must be a multiple of m step. +/// +/// @return the m step value +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(void); + +/// Gets the n step value. +/// The micro-kernel can process any N values. However, the starting N index to +/// be processed must be a multiple of n step. +/// +/// @return the n step +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(void); + +/// Gets the mr value, which must be used to pack the LHS matrix +/// +/// @return the mr value +size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(void); + +/// Gets the nr value, which must be used to pack the RHS matrix +/// +/// @return the nr value +size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(void); + +/// Gets the kr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the kr value +size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(void); + +/// Gets the sr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the sr value +size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(void); + +/// Gets the offset in bytes for the packed LHS matrix, +/// which contains the packed Signed 8-bit quantized asymmetric per-row (qai8dx) values. +/// +/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. +/// +/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of 8 +/// @param[in] k Total number of columns in the LHS matrix (not packed). +/// +/// @return the offset in bytes to the packed LHS matrix +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( + size_t m_idx, // + size_t k); // + +/// Gets the offset in bytes for the packed RHS matrix, +/// which contains the packed Signed 4-bit quantized symmetric per-channel (qsi4c32) values. +/// +/// @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of 4. +/// @param[in] k The common dimension between the LHS and RHS matrix (K). +/// @param[in] bl Block length. It must be a multiple of 32. 
+///
+/// @return the offset in bytes to the packed RHS matrix
+size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(
+    size_t n_idx,  //
+    size_t k,      //
+    size_t bl);    //
+
+/// Gets the offset in bytes for the DST matrix.
+///
+/// @param[in] m_idx Row index in the DST matrix. It must be a multiple of 8.
+/// @param[in] n_idx Column index in the DST matrix. It must be a multiple of 4.
+/// @param[in] dst_stride The number of bytes in each row of the DST matrix.
+///
+/// @return the DST offset in bytes
+size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(
+    size_t m_idx,        //
+    size_t n_idx,        //
+    size_t dst_stride);  //
+
+/// Gets the size in bytes for the destination (DST) matrix.
+///
+/// @param[in] m Number of rows in the destination (DST) matrix.
+/// @param[in] n Number of columns in the destination (DST) matrix.
+///
+/// @return the destination (DST) matrix size in bytes
+size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(
+    size_t m,   //
+    size_t n);  //
+
+/// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation.
+///
+/// LHS matrix: Signed 8-bit quantized asymmetric per-row (qai8dx) and packed.
+/// RHS matrix: Signed 4-bit quantized symmetric per-block (qsi4c32) and packed.
+/// Output tile: (rows x cols) = 8 x 4
+/// Accumulation performed in a single for loop: 32 values
+/// Extension used: i8mm
+///
+/// @param[in] m The number of output rows written.
+/// @param[in] n The number of output columns written.
+/// @param[in] k The number of channels. The common dimension between the LHS and RHS matrix.
+/// @param[in] bl Block length. It must be a multiple of 32.
+/// @param[in] lhs_packed The LHS packed matrix.
+///                       When the activations are dynamically quantized, you can obtain this matrix
+///                       by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel, which performs
+///                       both the dynamic quantization to 8-bit and the activation packing in a single step.
+/// @param[in] rhs_packed The RHS packed matrix, which is obtained by calling
+///                       @ref kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0
+/// @param[out] dst The DST matrix.
+/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix.
+/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(float).
+/// @param[in] scalar_min Min value used to clamp the final result.
+/// @param[in] scalar_max Max value used to clamp the final result.
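+///
+/// Unlike the 1xN dotprod variants in this patch, this kernel loops over both
+/// rows and columns internally, so a single call can produce the entire m x n
+/// output. A sketch of splitting the N dimension across two workers (buffer
+/// names are illustrative; both workers share the same packed LHS):
+///
+/// \code{.c}
+/// const size_t n_step = kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm();
+/// const size_t n0 = ((n / 2) / n_step) * n_step;  // split point, aligned to n_step
+/// const void* rhs1 = (const char*)rhs_p +
+///     kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(n0, k, 32);
+/// float* dst1 = (float*)((char*)dst +
+///     kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(0, n0, dst_stride_row));
+/// // Worker 0 computes columns [0, n0); worker 1 computes columns [n0, n).
+/// kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(
+///     m, n0, k, 32, lhs_p, rhs_p, dst, dst_stride_row, sizeof(float), scalar_min, scalar_max);
+/// kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(
+///     m, n - n0, k, 32, lhs_p, rhs1, dst1, dst_stride_row, sizeof(float), scalar_min, scalar_max);
+/// \endcode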
+void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(
+    size_t m,                //
+    size_t n,                //
+    size_t k,                //
+    size_t bl,               //
+    const void* lhs_packed,  //
+    const void* rhs_packed,  //
+    float* dst,              //
+    size_t dst_stride_row,   //
+    size_t dst_stride_col,   //
+    float scalar_min,        //
+    float scalar_max);       //
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c
new file mode 100644
index 00000000..07f3558a
--- /dev/null
+++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c
@@ -0,0 +1,378 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#if !defined(__ARM_FEATURE_MATMUL_INT8)
+#error "I8mm extension required to compile this micro-kernel"
+#else
+#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h"
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include "kai/kai_common.h"
+
+static const size_t kai_m_step = 4;
+static const size_t kai_n_step = 8;
+static const size_t kai_mr = 4;
+static const size_t kai_nr = 8;
+static const size_t kai_kr = 16;
+static const size_t kai_sr = 2;
+static const size_t kai_k0 = 32;
+static const size_t kai_num_bytes_multiplier_lhs = sizeof(float);
+static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);
+static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t);
+static const size_t kai_num_bytes_sum_rhs = sizeof(float);
+
+inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) {
+    KAI_ASSERT((k % 2) == 0);
+    KAI_ASSERT((k % bl) == 0);
+    return k / bl;
+}
+
+inline static size_t kai_lhs_packed_stride(size_t k) {
+    KAI_ASSERT((k % 2) == 0);
+    KAI_ASSERT((k % kai_kr) == 0);
+
+    return kai_mr * (k * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs);
+}
+
+inline static size_t kai_rhs_packed_stride(size_t k, size_t bl) {
+    KAI_ASSERT((k % 2) == 0);
+    KAI_ASSERT((k % kai_kr) == 0);
+    KAI_ASSERT((bl % kai_kr) == 0);
+    KAI_ASSERT((k % bl) == 0);
+
+    const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl);
+    const size_t num_bytes_per_block = (bl / 2) + kai_num_bytes_multiplier_rhs;
+
+    return kai_nr * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs);
+}
+
+size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(void) {
+    return kai_m_step;
+}
+
+size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(void) {
+    return kai_n_step;
+}
+
+size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(void) {
+    return kai_mr;
+}
+
+size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(void) {
+    return kai_nr;
+}
+
+size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(void) {
+    return kai_kr;
+}
+
+size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(void) {
+    return kai_sr;
+}
+
+size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(size_t m_idx, size_t k) {
+    KAI_ASSERT((m_idx % kai_m_step) == 0);
+
+    return (m_idx / kai_m_step) * kai_lhs_packed_stride(k);
+}
+
+size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(
+    size_t n_idx, size_t k, size_t bl) {
+    KAI_ASSERT((n_idx % kai_n_step) == 0);
+    KAI_ASSERT(bl == 32);
+
+    // 
Temporary assert + KAI_ASSERT((k % kai_k0) == 0); + + return (n_idx / kai_n_step) * kai_rhs_packed_stride(k, bl); +} + +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSERT((m_idx % kai_m_step) == 0); + KAI_ASSERT((n_idx % kai_n_step) == 0); + + return (n_idx * sizeof(float)) + m_idx * dst_stride; +} + +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(size_t m, size_t n) { + // Temporary assert + KAI_ASSERT((n % kai_nr) == 0); + + return m * n * sizeof(float); +} + +void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm( + size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed, float* dst, + size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { + size_t num_subblocks = bl / 32; + size_t num_blocks = k / bl; + + float clamp_vals[2] = {scalar_min, scalar_max}; + + __asm__ __volatile__( + "mov x28, #0x80\n" + "mov x20, #0x20\n" + "movi v12.16b, #0xf0\n" + "mov x27, %x[m]\n" + "madd x28, %x[num_blocks], x28, x20\n" + "cbz x27, 12f\n" + "1:" // Row loop + "mov x26, %x[rhs_packed]\n" + "mov x25, %x[n]\n" + "add x24, %x[dst], %x[dst_stride_row], LSL #2\n" + "2:" // Column loop + "movi v16.16b, #0x0\n" + "movi v5.16b, #0x0\n" + "mov x22, %x[lhs_packed]\n" + "mov x21, %x[num_blocks]\n" + "movi v13.16b, #0x0\n" + "movi v11.16b, #0x0\n" + "movi v15.16b, #0x0\n" + "movi v27.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "3:" // Block loop + "movi v8.4s, #0x0\n" + "movi v3.4s, #0x0\n" + "mov x20, %x[num_subblocks]\n" + "movi v31.4s, #0x0\n" + "movi v0.4s, #0x0\n" + "movi v30.4s, #0x0\n" + "movi v4.4s, #0x0\n" + "movi v29.4s, #0x0\n" + "movi v1.4s, #0x0\n" + "4:" // Sub block loop + "ldr q7, [x26, #0x0]\n" + "ldr q6, [x26, #0x10]\n" + "subs x20, x20, #0x1\n" + "ldr q22, [x26, #0x20]\n" + "ldr q26, [x26, #0x30]\n" + "ldr q20, [x22, #0x0]\n" + "ldr q18, [x22, #0x10]\n" + "ldr q23, [x26, #0x40]\n" + "ldr q17, [x26, #0x50]\n" + "shl v28.16b, v7.16b, #0x4\n" + "shl v2.16b, v6.16b, #0x4\n" + "ldr q25, [x26, #0x60]\n" + "ldr q24, [x26, #0x70]\n" + "shl v19.16b, v22.16b, #0x4\n" + "shl v9.16b, v26.16b, #0x4\n" + "ldr q21, [x22, #0x20]\n" + "and v7.16b, v7.16b, v12.16b\n" + "and v6.16b, v6.16b, v12.16b\n" + "add x26, x26, #0x80\n" + ".inst 0x4e9ca688 // smmla v8.4s, v20.16b, v28.16b\n" + ".inst 0x4e82a69f // smmla v31.4s, v20.16b, v2.16b\n" + "and v22.16b, v22.16b, v12.16b\n" + ".inst 0x4e93a683 // smmla v3.4s, v20.16b, v19.16b\n" + ".inst 0x4e89a680 // smmla v0.4s, v20.16b, v9.16b\n" + "ldr q20, [x22, #0x30]\n" + "and v26.16b, v26.16b, v12.16b\n" + ".inst 0x4e9ca65e // smmla v30.4s, v18.16b, v28.16b\n" + "ldr q28, [x22, #0x40]\n" + ".inst 0x4e82a65d // smmla v29.4s, v18.16b, v2.16b\n" + "ldr q2, [x22, #0x50]\n" + ".inst 0x4e93a644 // smmla v4.4s, v18.16b, v19.16b\n" + "ldr q19, [x22, #0x60]\n" + ".inst 0x4e89a641 // smmla v1.4s, v18.16b, v9.16b\n" + "ldr q18, [x22, #0x70]\n" + "shl v9.16b, v23.16b, #0x4\n" + "and v23.16b, v23.16b, v12.16b\n" + "add x22, x22, #0x80\n" + ".inst 0x4e89a6a8 // smmla v8.4s, v21.16b, v9.16b\n" + ".inst 0x4e89a69e // smmla v30.4s, v20.16b, v9.16b\n" + "shl v9.16b, v17.16b, #0x4\n" + "and v17.16b, v17.16b, v12.16b\n" + ".inst 0x4e89a6bf // smmla v31.4s, v21.16b, v9.16b\n" + ".inst 0x4e89a69d // smmla v29.4s, v20.16b, v9.16b\n" + "shl v9.16b, v25.16b, #0x4\n" + "and v25.16b, v25.16b, v12.16b\n" + ".inst 0x4e87a788 // smmla v8.4s, v28.16b, v7.16b\n" + ".inst 0x4e87a45e 
// smmla v30.4s, v2.16b, v7.16b\n" + "shl v7.16b, v24.16b, #0x4\n" + "and v24.16b, v24.16b, v12.16b\n" + ".inst 0x4e89a6a3 // smmla v3.4s, v21.16b, v9.16b\n" + ".inst 0x4e89a684 // smmla v4.4s, v20.16b, v9.16b\n" + ".inst 0x4e86a79f // smmla v31.4s, v28.16b, v6.16b\n" + ".inst 0x4e86a45d // smmla v29.4s, v2.16b, v6.16b\n" + ".inst 0x4e87a6a0 // smmla v0.4s, v21.16b, v7.16b\n" + ".inst 0x4e87a681 // smmla v1.4s, v20.16b, v7.16b\n" + ".inst 0x4e97a668 // smmla v8.4s, v19.16b, v23.16b\n" + ".inst 0x4e97a65e // smmla v30.4s, v18.16b, v23.16b\n" + ".inst 0x4e96a783 // smmla v3.4s, v28.16b, v22.16b\n" + ".inst 0x4e96a444 // smmla v4.4s, v2.16b, v22.16b\n" + ".inst 0x4e91a67f // smmla v31.4s, v19.16b, v17.16b\n" + ".inst 0x4e91a65d // smmla v29.4s, v18.16b, v17.16b\n" + ".inst 0x4e9aa780 // smmla v0.4s, v28.16b, v26.16b\n" + ".inst 0x4e9aa441 // smmla v1.4s, v2.16b, v26.16b\n" + ".inst 0x4e99a663 // smmla v3.4s, v19.16b, v25.16b\n" + ".inst 0x4e99a644 // smmla v4.4s, v18.16b, v25.16b\n" + ".inst 0x4e98a660 // smmla v0.4s, v19.16b, v24.16b\n" + ".inst 0x4e98a641 // smmla v1.4s, v18.16b, v24.16b\n" + "bgt 4b\n" + "ldr q6, [x26, #0x0]\n" + "ldr q24, [x26, #0x10]\n" + "uzp1 v22.2d, v8.2d, v31.2d\n" + "uzp2 v28.2d, v8.2d, v31.2d\n" + "uzp1 v21.2d, v3.2d, v0.2d\n" + "uzp2 v20.2d, v3.2d, v0.2d\n" + "add x26, x26, #0x20\n" + "uzp1 v19.2d, v30.2d, v29.2d\n" + "uzp2 v18.2d, v30.2d, v29.2d\n" + "uzp1 v17.2d, v4.2d, v1.2d\n" + "uzp2 v0.2d, v4.2d, v1.2d\n" + "scvtf v22.4s, v22.4s\n" + "scvtf v21.4s, v21.4s\n" + "scvtf v28.4s, v28.4s\n" + "scvtf v20.4s, v20.4s\n" + "scvtf v19.4s, v19.4s\n" + "scvtf v17.4s, v17.4s\n" + "scvtf v18.4s, v18.4s\n" + "scvtf v0.4s, v0.4s\n" + "fmla v16.4s, v22.4s, v6.4s\n" + "fmla v5.4s, v21.4s, v24.4s\n" + "fmla v13.4s, v28.4s, v6.4s\n" + "fmla v11.4s, v20.4s, v24.4s\n" + "fmla v15.4s, v19.4s, v6.4s\n" + "fmla v27.4s, v17.4s, v24.4s\n" + "fmla v10.4s, v18.4s, v6.4s\n" + "fmla v14.4s, v0.4s, v24.4s\n" + "subs x21, x21, #0x1\n" + "bgt 3b\n" + "ld1 { v21.4s }, [x22]\n" + "ldr q20, [x26, #0x0]\n" + "add x22, x22, #0x10\n" + "add x20, %x[clamp_vals], #0x4\n" + "ldr q19, [x26, #0x10]\n" + "ldr q18, [x22, #0x0]\n" + "cmp x25, #0x8\n" + "add x26, x26, #0x20\n" + "ld1r { v17.4s }, [%x[clamp_vals]]\n" + "ld1r { v31.4s }, [x20]\n" + "scvtf v21.4s, v21.4s\n" + "fmla v16.4s, v20.4s, v21.s[0]\n" + "fmla v5.4s, v19.4s, v21.s[0]\n" + "fmla v13.4s, v20.4s, v21.s[1]\n" + "fmla v11.4s, v19.4s, v21.s[1]\n" + "fmla v15.4s, v20.4s, v21.s[2]\n" + "fmla v27.4s, v19.4s, v21.s[2]\n" + "fmla v10.4s, v20.4s, v21.s[3]\n" + "fmla v14.4s, v19.4s, v21.s[3]\n" + "fmul v16.4s, v16.4s, v18.s[0]\n" + "fmul v5.4s, v5.4s, v18.s[0]\n" + "fmul v13.4s, v13.4s, v18.s[1]\n" + "fmul v11.4s, v11.4s, v18.s[1]\n" + "fmul v15.4s, v15.4s, v18.s[2]\n" + "fmul v27.4s, v27.4s, v18.s[2]\n" + "fmul v10.4s, v10.4s, v18.s[3]\n" + "fmul v14.4s, v14.4s, v18.s[3]\n" + "fmax v16.4s, v16.4s, v17.4s\n" + "fmax v5.4s, v5.4s, v17.4s\n" + "fmax v13.4s, v13.4s, v17.4s\n" + "fmax v11.4s, v11.4s, v17.4s\n" + "fmax v15.4s, v15.4s, v17.4s\n" + "fmax v27.4s, v27.4s, v17.4s\n" + "fmax v10.4s, v10.4s, v17.4s\n" + "fmax v14.4s, v14.4s, v17.4s\n" + "fmin v16.4s, v16.4s, v31.4s\n" + "fmin v5.4s, v5.4s, v31.4s\n" + "fmin v13.4s, v13.4s, v31.4s\n" + "fmin v11.4s, v11.4s, v31.4s\n" + "fmin v15.4s, v15.4s, v31.4s\n" + "fmin v27.4s, v27.4s, v31.4s\n" + "fmin v10.4s, v10.4s, v31.4s\n" + "fmin v14.4s, v14.4s, v31.4s\n" + "blt 6f\n" + "mov x20, %x[dst]\n" + "cmp x27, #0x1\n" + "str q16, [x20, #0x0]\n" + "str q5, [x20, #0x10]\n" + "add x20, x20, 
%x[dst_stride_row]\n" + "ble 11f\n" + "cmp x27, #0x2\n" + "str q13, [x20, #0x0]\n" + "str q11, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 11f\n" + "cmp x27, #0x3\n" + "str q15, [x20, #0x0]\n" + "str q27, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 11f\n" + "str q10, [x20, #0x0]\n" + "str q14, [x20, #0x10]\n" + "b 11f\n" + "6:" // Partial output + "mov x23, %x[dst]\n" + "cmp x27, #0x1\n" + "add x22, x23, %x[dst_stride_row]\n" + "csel x22, x22, x23, GE\n" + "cmp x27, #0x2\n" + "add x21, x23, %x[dst_stride_row], LSL #1\n" + "csel x21, x21, x22, GE\n" + "cmp x27, #0x3\n" + "add x20, x21, %x[dst_stride_row]\n" + "csel x20, x20, x21, GE\n" + "tbz x25, #2, 8f\n" + "st1 { v10.4s }, [x20], #0x10\n" + "st1 { v15.4s }, [x21], #0x10\n" + "st1 { v13.4s }, [x22], #0x10\n" + "st1 { v16.4s }, [x23], #0x10\n" + "tbz x25, #1, 7f\n" + "st1 { v14.d }[0], [x20], #0x8\n" + "st1 { v27.d }[0], [x21], #0x8\n" + "st1 { v11.d }[0], [x22], #0x8\n" + "st1 { v5.d }[0], [x23], #0x8\n" + "tbz x25, #0, 10f\n" + "st1 { v14.s }[2], [x20]\n" + "st1 { v27.s }[2], [x21]\n" + "st1 { v11.s }[2], [x22]\n" + "st1 { v5.s }[2], [x23]\n" + "b 10f\n" + "7:" // Output block 0: partial_1_4 + "tbz x25, #0, 10f\n" + "st1 { v14.s }[0], [x20]\n" + "st1 { v27.s }[0], [x21]\n" + "st1 { v11.s }[0], [x22]\n" + "st1 { v5.s }[0], [x23]\n" + "b 10f\n" + "8:" // Output block 0: partial_2_0 + "tbz x25, #1, 9f\n" + "st1 { v10.d }[0], [x20], #0x8\n" + "st1 { v15.d }[0], [x21], #0x8\n" + "st1 { v13.d }[0], [x22], #0x8\n" + "st1 { v16.d }[0], [x23], #0x8\n" + "tbz x25, #0, 10f\n" + "st1 { v10.s }[2], [x20]\n" + "st1 { v15.s }[2], [x21]\n" + "st1 { v13.s }[2], [x22]\n" + "st1 { v16.s }[2], [x23]\n" + "b 10f\n" + "9:" // Output block 0: partial_1_0 + "st1 { v10.s }[0], [x20]\n" + "st1 { v15.s }[0], [x21]\n" + "st1 { v13.s }[0], [x22]\n" + "st1 { v16.s }[0], [x23]\n" + "10:" // Output block 0: Done + "11:" // Output stage exit + "subs x25, x25, #0x8\n" + "add %x[dst], %x[dst], #0x20\n" + "bgt 2b\n" + "subs x27, x27, #0x4\n" + "add %x[lhs_packed], %x[lhs_packed], x28\n" + "mov %x[dst], x24\n" + "bgt 1b\n" + "12:" // Row loop skip + : [dst] "+&r"(dst), [lhs_packed] "+&r"(lhs_packed) + : [clamp_vals] "r"(clamp_vals), [dst_stride_row] "r"(dst_stride_row), [m] "r"(m), [n] "r"(n), + [num_blocks] "r"(num_blocks), [num_subblocks] "r"(num_subblocks), [rhs_packed] "r"(rhs_packed) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", + "v30", "v31", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"); +} +#endif // Architectural feature check diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h new file mode 100644 index 00000000..31614599 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h @@ -0,0 +1,145 @@ + +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#ifndef __cplusplus +#include +#endif +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/// Micro-kernel dependencies +/// +/// -# kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix +/// -# 
kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0 to pack the RHS matrix
+
+/// --------------------------------------------------
+
+/// Gets the m step value.
+/// The micro-kernel can process any M value. However, the starting M index to
+/// be processed must be a multiple of m step.
+///
+/// @return the m step value
+size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(void);
+
+/// Gets the n step value.
+/// The micro-kernel can process any N value. However, the starting N index to
+/// be processed must be a multiple of n step.
+///
+/// @return the n step value
+size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(void);
+
+/// Gets the mr value, which must be used to pack the LHS matrix
+///
+/// @return the mr value
+size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(void);
+
+/// Gets the nr value, which must be used to pack the RHS matrix
+///
+/// @return the nr value
+size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(void);
+
+/// Gets the kr value, which must be used to pack the LHS and RHS matrices
+///
+/// @return the kr value
+size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(void);
+
+/// Gets the sr value, which must be used to pack the LHS and RHS matrices
+///
+/// @return the sr value
+size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(void);
+
+/// Gets the offset in bytes for the packed LHS matrix,
+/// which contains the packed Signed 8-bit quantized asymmetric per-row (qai8dx) values.
+///
+/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel.
+///
+/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of 4.
+/// @param[in] k     Total number of columns in the LHS matrix (not packed).
+///
+/// @return the offset in bytes to the packed LHS matrix
+size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(
+    size_t m_idx,  //
+    size_t k);     //
+
+/// Gets the offset in bytes for the packed RHS matrix,
+/// which contains the packed Signed 4-bit quantized symmetric per-block (qsi4c32) values.
+///
+/// @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of 8.
+/// @param[in] k     The common dimension between the LHS and RHS matrix (K).
+/// @param[in] bl    Block length. It must be a multiple of 32.
+///
+/// @return the offset in bytes to the packed RHS matrix
+size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(
+    size_t n_idx,  //
+    size_t k,      //
+    size_t bl);    //
+
+/// Gets the offset in bytes for the DST matrix
+///
+/// @param[in] m_idx      Row index in the DST matrix. It must be a multiple of 4.
+/// @param[in] n_idx      Column index in the DST matrix. It must be a multiple of 8.
+/// @param[in] dst_stride The number of bytes in each row of the DST matrix
+///
+/// @return the DST offset in bytes
+size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(
+    size_t m_idx,        //
+    size_t n_idx,        //
+    size_t dst_stride);  //
+
+/// Gets the size in bytes for the destination (DST) matrix.
+///
+/// @param[in] m Number of rows in the destination (DST) matrix.
+/// @param[in] n Number of columns in the destination (DST) matrix.
+///
+/// @return the destination (DST) matrix size in bytes
+size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(
+    size_t m,   //
+    size_t n);  //
+
+/// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation.
+///
+/// LHS matrix: Signed 8-bit quantized asymmetric per-row (qai8dx) and packed
+/// RHS matrix: Signed 4-bit quantized symmetric per-block (qsi4c32) and packed.
+/// Output tile: (rows x cols) = 4 x 8
+/// Accumulation performed in a single for loop: 32
+/// Extension used: i8mm
+///
+/// @param[in]  m              The number of output rows written.
+/// @param[in]  n              The number of output columns written.
+/// @param[in]  k              The number of channels. The common dimension between the LHS and RHS matrix.
+/// @param[in]  bl             Block length. It must be a multiple of 32.
+/// @param[in]  lhs_packed     The LHS packed matrix.
+///                            When the activations are dynamically quantized, you can obtain this matrix
+///                            by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel, which performs
+///                            both the dynamic quantization to 8-bit and the activation packing in a single step.
+/// @param[in]  rhs_packed     The RHS packed matrix, which is obtained by calling @ref
+///                            kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0
+/// @param[out] dst            The DST matrix.
+/// @param[in]  dst_stride_row Stride in bytes between two rows of the DST matrix.
+/// @param[in]  dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(float).
+/// @param[in]  scalar_min     Min value used to clamp the final result.
+/// @param[in]  scalar_max     Max value used to clamp the final result.
+void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(
+    size_t m,                //
+    size_t n,                //
+    size_t k,                //
+    size_t bl,               //
+    const void* lhs_packed,  //
+    const void* rhs_packed,  //
+    float* dst,              //
+    size_t dst_stride_row,   //
+    size_t dst_stride_col,   //
+    float scalar_min,        //
+    float scalar_max);       //
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h
new file mode 100644
index 00000000..fdc7b16e
--- /dev/null
+++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h
@@ -0,0 +1,52 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#pragma once
+
+#include <stddef.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// All micro-kernel variants of the same type share the same interface
+// In this case, the micro-kernel type is: matmul_clamp_f32_qai8dxp_qsi4c32p
+
+/// Micro-kernel helper functions ("get" methods)
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsi4c32p_get_m_step_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsi4c32p_get_n_step_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsi4c32p_get_mr_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsi4c32p_get_nr_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsi4c32p_get_kr_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsi4c32p_get_sr_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsi4c32p_get_lhs_packed_offset_func_t)(size_t m_idx, size_t k);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsi4c32p_get_rhs_packed_offset_func_t)(size_t n_idx, size_t k, size_t bl);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsi4c32p_get_dst_offset_func_t)(
+    size_t m_idx, size_t n_idx, size_t dst_stride);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsi4c32p_get_dst_size_func_t)(size_t m, size_t n);
+
+/// Micro-kernel core function ("run" method)
+typedef void (*kai_matmul_clamp_f32_qai8dxp_qsi4c32p_run_matmul_func_t)(
+    size_t m, size_t n, size_t k, size_t bl, const void* lhs_p, const void* rhs_p, float* dst, size_t dst_stride_row,
+    size_t dst_stride_col, float scalar_min, float scalar_max);
+
+/// Micro-kernel interface
+struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel {
+    kai_matmul_clamp_f32_qai8dxp_qsi4c32p_get_m_step_func_t get_m_step;
+    kai_matmul_clamp_f32_qai8dxp_qsi4c32p_get_n_step_func_t get_n_step;
+    kai_matmul_clamp_f32_qai8dxp_qsi4c32p_get_mr_func_t get_mr;
+    kai_matmul_clamp_f32_qai8dxp_qsi4c32p_get_nr_func_t get_nr;
+    kai_matmul_clamp_f32_qai8dxp_qsi4c32p_get_kr_func_t get_kr;
+    kai_matmul_clamp_f32_qai8dxp_qsi4c32p_get_sr_func_t get_sr;
+    kai_matmul_clamp_f32_qai8dxp_qsi4c32p_get_lhs_packed_offset_func_t get_lhs_packed_offset;
+    kai_matmul_clamp_f32_qai8dxp_qsi4c32p_get_rhs_packed_offset_func_t get_rhs_packed_offset;
+    kai_matmul_clamp_f32_qai8dxp_qsi4c32p_get_dst_offset_func_t get_dst_offset;
+    kai_matmul_clamp_f32_qai8dxp_qsi4c32p_get_dst_size_func_t get_dst_size;
+    kai_matmul_clamp_f32_qai8dxp_qsi4c32p_run_matmul_func_t run_matmul;
+};
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c
new file mode 100644
index 00000000..93614b02
--- /dev/null
+++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c
@@ -0,0 +1,166 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#include "kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.h"
+
+#include <stddef.h>
+#include <stdint.h>
+#include <string.h>
+
+#include "kai/kai_common.h"
+
+static const size_t kai_num_bytes_sum_rhs = sizeof(float);
+static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);
+
+inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) {
+    KAI_ASSERT((k % 2) == 0);
+    KAI_ASSERT((k % bl) == 0);
+    return k / bl;
+}
+
+inline static size_t kai_rhs_stride(size_t k, size_t bl) {
+    KAI_ASSERT((k % 2) == 0);
+
+    const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl);
+    const size_t num_bytes_per_block = (bl / 2) + kai_num_bytes_multiplier_rhs;
+
+    return num_bytes_per_block * num_blocks_per_row;
+}
+
+inline static size_t kai_rhs_packed_stride(size_t k, size_t nr, size_t kr, size_t bl) {
+    KAI_ASSERT((k % 2) == 0);
+    KAI_ASSERT((k % kr) == 0);
+    KAI_ASSERT((k % bl) == 0);
+    KAI_ASSERT((bl % kr) == 0);
+
+    const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl);
+    const size_t num_bytes_per_block = (bl / 2) + kai_num_bytes_multiplier_rhs;
+
+    return nr * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs);
+}
+
+inline static size_t kai_rhs_packed_offset_end_of_all_blocks(size_t k, size_t nr, size_t kr, size_t bl) {
+    KAI_ASSERT((k % 2) == 0);
+    KAI_ASSERT((k % kr) == 0);
+    KAI_ASSERT((k % bl) == 0);
+    KAI_ASSERT((bl % kr) == 0);
+
+    const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl);
+    const size_t num_bytes_per_block = (bl / 2) + kai_num_bytes_multiplier_rhs;
+
+    return (nr * num_bytes_per_block * num_blocks_per_row);
+}
+
+size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(size_t n_idx, size_t rhs_stride) {
+    return n_idx * rhs_stride;
+}
+
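+// Editor's note - worked example for the packed stride (illustrative values,
+// not part of the original patch): with k = 64, bl = 32, nr = 4 and kr = 16,
+// each row holds kai_num_blocks_per_row(64, 32) = 2 blocks of
+// (32 / 2) + sizeof(float) = 20 bytes, plus one 4-byte reduction sum per row,
+// so kai_rhs_packed_stride(64, 4, 16, 32) = 4 * ((20 * 2) + 4) = 176 bytes, and
+// kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(8, 64, 4, 16, 32)
+// below returns (8 / 4) * 176 = 352.
+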
+size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
+    size_t n_idx, size_t k, size_t nr, size_t kr, size_t bl) {
+    KAI_ASSERT((k % 2) == 0);
+    KAI_ASSERT((k % kr) == 0);
+    KAI_ASSERT((k % bl) == 0);
+    KAI_ASSERT((n_idx % nr) == 0);
+
+    return (n_idx / nr) * kai_rhs_packed_stride(k, nr, kr, bl);
+}
+
+size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(size_t n, size_t k, size_t nr, size_t kr, size_t bl) {
+    KAI_ASSERT((k % 2) == 0);
+    KAI_ASSERT((k % kr) == 0);
+    KAI_ASSERT((k % bl) == 0);
+    KAI_ASSERT((n % nr) == 0);
+
+    const size_t num_rows = n / nr;
+
+    return num_rows * kai_rhs_packed_stride(k, nr, kr, bl);
+}
+
+void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
+    size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
+    const int32_t* bias, void* rhs_packed, size_t extra_bytes,
+    const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0_params* params) {
+    // Temporary asserts
+    KAI_ASSERT(num_groups == 1);
+    KAI_ASSERT((k % 2) == 0);
+    KAI_ASSERT((n % nr) == 0);
+    KAI_ASSERT((k % kr) == 0);
+    KAI_ASSERT((k % bl) == 0);
+    KAI_ASSERT(bias == NULL);
+    KAI_ASSERT(extra_bytes == 0);
+
+    KAI_ASSERT(sr == 2);
+    KAI_ASSERT(kr >= 1 && kr <= 16);
+    KAI_ASSERT(rhs != NULL);
+    KAI_ASSERT(rhs_packed != NULL);
+    KAI_ASSERT(params != NULL);
+    KAI_ASSERT(params->rhs_zero_point == 8);
+    KAI_ASSERT(params->lhs_zero_point == 1);
+
+    // Note: The input matrix (rhs) is expected with:
+    // "k" columns and "n" rows (NxK)
+
+    const size_t rhs_zero_point = params->rhs_zero_point;
+    const size_t rhs_stride = kai_rhs_stride(k, bl);
+    const size_t rhs_packed_stride = kai_rhs_packed_stride(k, nr, kr, bl);
+    const size_t rhs_packed_offset_end_of_all_blocks = kai_rhs_packed_offset_end_of_all_blocks(k, nr, kr, bl);
+    const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl);
+    const size_t num_segments_per_block = bl / kr;
+    const size_t num_bytes_per_segment = kr / 2;
+
+    for (size_t y = 0; y < n; y += nr) {
+        const uint8_t* src_row = (const uint8_t*)rhs + y * rhs_stride;
+        uint8_t* dst_row = (uint8_t*)rhs_packed + (y / nr) * rhs_packed_stride;
+
+        float* sums = (float*)(dst_row + rhs_packed_offset_end_of_all_blocks);
+
+        // Initialize the RHS reduction sums to zero
+        memset(sums, 0, nr * kai_num_bytes_sum_rhs);
+
+        for (size_t x = 0; x < num_blocks_per_row; ++x) {
+            // Store the scales at the end of the block
+            float* scales = (float*)(dst_row + num_segments_per_block * num_bytes_per_segment * nr);
+
+            for (size_t i = 0; i < nr; ++i) {
+                scales[i] = *((const float*)(src_row + i * rhs_stride));
+            }
+            src_row += kai_num_bytes_multiplier_rhs;
+
+            // Store the segments
+            for (size_t s = 0; s < num_segments_per_block; ++s) {
+                for (size_t i = 0; i < nr; ++i) {
+                    memcpy(dst_row + i * num_bytes_per_segment, src_row + i * rhs_stride, num_bytes_per_segment);
+
+                    for (size_t b = 0; b < num_bytes_per_segment; ++b) {
+                        uint8_t qs = dst_row[i * num_bytes_per_segment + b];
+
+                        const int32_t x0 = (qs & 0x0F) - 8;
+                        const int32_t x1 = (qs >> 4) - 8;
+                        const float d = scales[i];
+
+                        sums[i] += x0 * d;
+                        sums[i] += x1 * d;
+
+                        // Convert both nibbles from unsigned (zero point 8) to signed
+                        // int4 by flipping the top bit of each nibble (XOR with 0x88)
+                        dst_row[i * num_bytes_per_segment + b] = qs ^ 0x88;
+                    }
+                }
+
+                src_row += num_bytes_per_segment;
+                dst_row += num_bytes_per_segment * nr;
+            }
+
+            for (size_t i = 0; i < nr; ++i) {
+                // The matmul micro-kernels unpack the nibbles with a 4-bit left shift
+                // (values scaled by 16), so pre-scale the block scale by 1 / 16
+                scales[i] *= 0.0625F;
+            }
+
+            dst_row += (kai_num_bytes_multiplier_rhs * nr);
+        }
+    }
+}
diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.h b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.h
new file mode 100644
index 00000000..193bd2bd
--- /dev/null
+++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.h
@@ -0,0 +1,113 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#pragma once
+
+#ifndef __cplusplus
+#include <stdbool.h>
+#endif
+#include <stddef.h>
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0_params {
+    int8_t lhs_zero_point;
+    uint8_t rhs_zero_point;
+};
+
+/// Gets the offset in bytes for the RHS matrix (not packed), which holds
+/// the int4 values in an N x K matrix, where N is the number of rows and K is the number of columns.
+///
+/// Two int4 K values are stored in one byte. These values are stored in blocks, where each block
+/// has its own scale factor. The scale factor is expected to be a f32 value and stored at the end of each block.
+/// The first byte in the block holds the K-index + 0 and K-index + 16 values.
+/// The K-index + 0 value is stored in the lower order part of the byte (low nibble) while
+/// the K-index + 16 value is stored in the higher order part (high nibble).
+/// For example, if the block length is 32, the values are stored in the following order:
+/// |byte(s16, s0),byte(s17, s1),byte(s18, s2),...,byte(s31, s15),float32(scale)|
+///
+/// @param[in] n_idx      Row index in the RHS matrix (not packed).
+/// @param[in] rhs_stride The number of bytes in each row of the RHS matrix (not packed)
+///
+/// @return the offset in bytes to the RHS matrix (not packed)
+size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
+    size_t n_idx,        //
+    size_t rhs_stride);  //
+
+/// Gets the offset in bytes for the packed RHS matrix.
+///
+/// @param[in] n_idx Row index in the RHS matrix (not packed).
+/// @param[in] k     The common dimension between the LHS and RHS matrix (K)
+/// @param[in] nr    The number of columns written by the matmul micro-kernel
+/// @param[in] kr    The number of columns loaded in the single innermost loop of the matmul micro-kernel.
+/// @param[in] bl    The block length, which defines the number of K values stored in a single block. It must be a
+///                  multiple of 32.
+///
+/// @return the offset in bytes to the packed RHS matrix
+size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
+    size_t n_idx,  //
+    size_t k,      //
+    size_t nr,     //
+    size_t kr,     //
+    size_t bl);    //
+
+/// Gets the size in bytes for the quantized and packed RHS matrix.
+///
+/// @param[in] n  The number of rows in the RHS matrix (not packed)
+/// @param[in] k  The number of columns in the RHS matrix (not packed).
+/// @param[in] nr The number of columns written by the matmul micro-kernel
+/// @param[in] kr The number of columns loaded in the single innermost loop of the matmul micro-kernel.
+/// @param[in] bl The block length, which defines the number of K values stored in a single block. It must be a
+///               multiple of 32.
+///
+/// @return the packed RHS matrix size in bytes
+size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
+    size_t n,    //
+    size_t k,    //
+    size_t nr,   //
+    size_t kr,   //
+    size_t bl);  //
+
+/// Runs the RHS packing micro-kernel.
+///
+/// The int4 values are stored in an N x K matrix, where N is the number of rows and K is the number of columns.
+/// Two int4 values are stored in one byte. The lower order part of the byte (low) holds
+/// the first nibble (K-index + 0). The higher order of the byte holds the second nibble (K-index + 1).
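+///
+/// For example (editor's illustration), a source byte 0x2B holds the K-index + 0
+/// value in its low nibble and the K-index + 1 value in its high nibble; with the
+/// rhs_zero_point of 8 assumed by this micro-kernel, they decode as:
+///
+///   const int32_t v0 = (0x2B & 0x0F) - 8;  // low nibble:  11 - 8 =  3
+///   const int32_t v1 = (0x2B >> 4) - 8;    // high nibble:  2 - 8 = -6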
+///
+/// @param[in] num_groups  The number of groups. It must be 1.
+/// @param[in] n           The number of columns of the output matrix (N).
+/// @param[in] k           The common dimension between the LHS and RHS matrix (K).
+/// @param[in] nr          The number of columns written by the matmul micro-kernel.
+/// @param[in] kr          The number of columns loaded in the single innermost loop of the matmul micro-kernel.
+/// @param[in] sr          The number of kr splits. It can be 1 (no splits) up to kr.
+///                        However, kr must be a multiple of sr.
+/// @param[in] bl          The block length, which defines the number of
+///                        K values stored in a single block. It must be a multiple of 32.
+/// @param[in] rhs         The RHS matrix containing the 4-bit values.
+///                        Size in bytes is expected to be greater than or equal to n * k * (sizeof(uint8_t) / 2).
+/// @param[in] bias        The biases.
+/// @param[out] rhs_packed The packed RHS matrix.
+/// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix.
+/// @param[in] params      Parameters for the micro-kernel.
+void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
+    size_t num_groups,    //
+    size_t n,             //
+    size_t k,             //
+    size_t nr,            //
+    size_t kr,            //
+    size_t sr,            //
+    size_t bl,            //
+    const uint8_t* rhs,   //
+    const int32_t* bias,  //
+    void* rhs_packed,     //
+    size_t extra_bytes,   //
+    const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0_params* params);  //
+
+#ifdef __cplusplus
+}
+#endif
--
GitLab

From db39547013ac498619d30f4ffaa05827402c03b2 Mon Sep 17 00:00:00 2001
From: Gian Marco Iodice
Date: Thu, 13 Jun 2024 12:14:05 +0100
Subject: [PATCH 02/29] Add example for the Int4 matmul blockwise micro-kernel

- Add example to demonstrate how to call the Int4 matmul micro-kernels
  with per-row LHS quantization and per-block RHS quantization
- Extend RHS packing function to support f16 scales
- Extend the 1x4 and 1x8 Int4 matmul micro-kernels to deal with LHS
  matrices with more than one row

Signed-off-by: Gian Marco Iodice
---
 .../CMakeLists.txt                            |  30 ++
 .../matmul_clamp_f32_qai8dxp_qsi4c32p.cpp     | 469 ++++++++++++++++++
 kai/kai_common.h                              |  24 +
 ...i8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c | 182 +++----
 ...i8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c |   4 +-
 ..._qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c |  12 +-
 ..._qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c |  12 +-
 .../kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c  |  72 ++-
 .../kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.h  |  38 +-
 9 files changed, 712 insertions(+), 131 deletions(-)
 create mode 100644 examples/matmul_clamp_f32_qai8dxp_qsi4c32p/CMakeLists.txt
 create mode 100644 examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp

diff --git a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/CMakeLists.txt b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/CMakeLists.txt
new file mode 100644
index 00000000..4fb9edd5
--- /dev/null
+++ b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/CMakeLists.txt
@@ -0,0 +1,30 @@
+#
+# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+
+cmake_minimum_required(VERSION 3.16)
+
+set(CMAKE_CXX_STANDARD 17)
+set(KLEIDIAI_PATH ../../)
+set(MATMUL_PACK_PATH ${KLEIDIAI_PATH}/kai/ukernels/matmul/pack/)
+set(MATMUL_PATH ${KLEIDIAI_PATH}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/)
+
+# KleidiAI include directories
+include_directories(
+    ${KLEIDIAI_PATH}
+    ${MATMUL_PACK_PATH}
+    ${MATMUL_PATH})
+
+# Files required to build the executable
+add_executable(matmul_clamp_f32_qai8dxp_qsi4c32p
+    matmul_clamp_f32_qai8dxp_qsi4c32p.cpp
+    ${KLEIDIAI_PATH}/kai/kai_common.h
+
${MATMUL_PACK_PATH}/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c + ${MATMUL_PACK_PATH}/kai_lhs_quant_pack_qai8dxp_f32.c + ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c + ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c + ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c + ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c) + diff --git a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp new file mode 100644 index 00000000..9a175c82 --- /dev/null +++ b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp @@ -0,0 +1,469 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#if !defined(__ARM_FEATURE_DOTPROD) && !defined(__ARM_FEATURE_MATMUL_INT8) +#error "Dotprod and I8mm extensions required to compile this example" +#else +#include +#include +#include +#include +#include +#include + +// Include micro-kernel variants +#include "kai_lhs_quant_pack_qai8dxp_f32.h" +#include "kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h" +#include "kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h" +#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h" +#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h" +#include "kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h" +#include "kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.h" + +#define INT4_MIN (-8) +#define INT4_MAX (7) + +// Micro-kernel interface +struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp { + kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel ukernel; + std::string name = {}; +}; + +kai_matmul_ukernel_f32_qa8dxp_qs4cxp ukernel_variants[] = { + {kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + "matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod"}, + {kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + 
kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + "matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod"}, + {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + "matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm"}, + {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + "matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm"}, +}; + +// Number of micro-kernel variants stored in the array +const size_t num_ukernel_variants = sizeof(ukernel_variants) / sizeof(ukernel_variants[0]); + +static inline size_t num_blocks_per_row(size_t k, size_t bl) { + return k / bl; +} + +static inline size_t num_bytes_per_block(size_t bl) { + return (bl / 2) + sizeof(float); +} + +static void fill_uniform_random(size_t num_rows, size_t num_cols, float* dst, size_t seed) { + std::srand(seed); + + // Fill the array with random values between -1 and 1 + for (int i = 0; i < num_rows * num_cols; i++) { + if (i % 2 == 0) + dst[i] = (float)((double)std::rand() / RAND_MAX) * 2 - 1; + else + dst[i] = 0; + } +} + +static void quant_qs4c32_f32(size_t n, size_t k, size_t bl, const float* rhs_f32, uint8_t* rhs_qs4c32) { + const size_t num_blocks_row = num_blocks_per_row(k, bl); + const size_t num_bytes_block = num_bytes_per_block(bl); + const size_t dst_stride = num_blocks_row * num_bytes_block; + + for (size_t row_idx = 0; row_idx < n; ++row_idx) { + const float* src_ptr = rhs_f32 + row_idx * k; + + uint8_t* dst_ptr = (uint8_t*)rhs_qs4c32 + row_idx * dst_stride; + + for (size_t block_idx = 0; block_idx < num_blocks_row; ++block_idx) { + float amax = 0.0f; + float max = 0.0f; + + for (size_t b = 0; b < bl; ++b) { + const float src0_0 = src_ptr[block_idx * bl + b]; + const float asrc0_0 = fabsf(src0_0); + + if (amax < asrc0_0) { + amax = asrc0_0; + max = src0_0; + } + } + + const float scale = max / -8.0; + const float recip_scale = scale ? 
1.0f / scale : 0.0f; + + // Store the scale at the beginning of the block + *((float*)dst_ptr) = scale; + dst_ptr += sizeof(float); + + const size_t block_size = 32; + const size_t num_subblocks = bl / 32; + + for (size_t subblock_idx = 0; subblock_idx < num_subblocks; ++subblock_idx) { + for (size_t i = 0; i < block_size / 2; ++i) { + const size_t src_base_addr = block_idx * bl + i + subblock_idx * block_size; + float v0_f32 = src_ptr[src_base_addr]; + float v1_f32 = src_ptr[src_base_addr + block_size / 2]; + + v0_f32 *= recip_scale; + v1_f32 *= recip_scale; + + const uint8_t v0_u8 = (uint8_t)std::min((int8_t)15, (int8_t)(v0_f32 + 8.0f)); + const uint8_t v1_u8 = (uint8_t)std::min((int8_t)15, (int8_t)(v1_f32 + 8.0f)); + + const uint8_t rhs_v0 = (v1_u8 << 4) | v0_u8; + + dst_ptr[0] = rhs_v0; + dst_ptr += sizeof(uint8_t); + } + } + } + } +}; + +static void ref_quant_qa8dx_f32(size_t m, size_t k, const float* lhs_f32, int8_t* lhs_qa8dx) { + const size_t dst_stride = (k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t)); + + for (size_t row_idx = 0; row_idx < m; ++row_idx) { + const float* src_ptr = lhs_f32 + row_idx * k; + + float max0 = -FLT_MAX; + float min0 = FLT_MAX; + + // Find min/max for each channel + for (size_t k_idx = 0; k_idx < k; ++k_idx) { + const float src0_0 = src_ptr[k_idx]; + + max0 = std::max(src0_0, max0); + min0 = std::min(src0_0, min0); + } + + // Maximum/minimum int8 values + const float qmin = (float)INT8_MIN; + const float qmax = (float)INT8_MAX; + + const float rmin0 = std::min(0.0f, min0); + const float rmax0 = std::max(0.0f, max0); + + const float scale0 = rmin0 == rmax0 ? 1.f : (qmax - qmin) / (rmax0 - rmin0); + + // Reciprocal to quantize + const float recip_scale0 = scale0 ? 1.0f / scale0 : 0.0f; + + const float descaled_min0 = rmin0 * scale0; + const float descaled_max0 = rmax0 * scale0; + + const float zero_point_from_min_error0 = qmin + descaled_min0; + const float zero_point_from_max_error0 = qmax + descaled_max0; + + float zero_point0 = + zero_point_from_min_error0 + zero_point_from_max_error0 > 0 ? 
qmin - descaled_min0 : qmax - descaled_max0; + + zero_point0 = std::max(zero_point0, qmin); + zero_point0 = std::min(zero_point0, qmax); + + // Round to nearest integer + const int32_t nudged_zero_point0 = lrintf(zero_point0); + + int8_t* dst_ptr = (int8_t*)lhs_qa8dx + row_idx * dst_stride; + + // LHS offset at the beginning of the row + *((float*)(dst_ptr)) = recip_scale0; + dst_ptr += sizeof(float); + *((int32_t*)(dst_ptr)) = -nudged_zero_point0; + dst_ptr += sizeof(int32_t); + + // Quantize the channels + for (size_t k_idx = 0; k_idx < k; ++k_idx) { + const float src0_0 = src_ptr[k_idx]; + + // Scale the values + int32_t v0_s32 = (int32_t)(round(src0_0 * scale0)); + + v0_s32 = v0_s32 + nudged_zero_point0; + v0_s32 = std::max(v0_s32, INT8_MIN); + v0_s32 = std::min(v0_s32, INT8_MAX); + dst_ptr[0] = (int8_t)v0_s32; + dst_ptr += sizeof(int8_t); + } + } +}; + +static void ref_matmul_f32_qa8dx_qs4c32( + size_t m, size_t n, size_t k, size_t bl, const int8_t* lhs_qa8dx, const uint8_t* rhs_qs4c32, float* dst_f32, + float scalar_min, float scalar_max) { + const size_t num_blocks_row = num_blocks_per_row(k, bl); + const size_t num_bytes_block = num_bytes_per_block(bl); + + const size_t lhs_stride = k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t); + const size_t rhs_stride = num_blocks_row * num_bytes_block; + + for (size_t row_idx = 0; row_idx < m; ++row_idx) { + const int8_t* lhs_ptr_start = lhs_qa8dx + row_idx * lhs_stride; + + for (size_t col_idx = 0; col_idx < n; ++col_idx) { + // Main f32 accumulator + float main_acc = 0.0f; + + const int8_t* lhs_ptr = lhs_ptr_start; + const uint8_t* rhs_ptr = rhs_qs4c32 + col_idx * rhs_stride; + + // Get the LHS quantization parameters stored at the + // beginning of each row + const float lhs_scale = *(const float*)lhs_ptr; + lhs_ptr += sizeof(float); + + const int32_t lhs_offset = *(const int32_t*)lhs_ptr; + lhs_ptr += sizeof(int32_t); + + for (size_t block_idx = 0; block_idx < num_blocks_row; ++block_idx) { + const float rhs_scale = *(const float*)rhs_ptr; + rhs_ptr += sizeof(float); + + int32_t iacc = 0; + + const size_t block_size = 32; + const size_t num_subblocks = bl / 32; + + for (size_t subblock_idx = 0; subblock_idx < num_subblocks; ++subblock_idx) { + for (size_t i = 0; i < block_size / 2; ++i) { + // Get the LHS values + const int32_t lhs_v0 = (int32_t)lhs_ptr[0]; + const int32_t lhs_v1 = (int32_t)lhs_ptr[block_size / 2]; + + // Get the RHS values + const uint8_t rhs_byte = rhs_ptr[0]; + + // Unpack the RHS values + const int32_t rhs_v0 = (((int32_t)(rhs_byte & 0x0F)) - 8); + const int32_t rhs_v1 = (((int32_t)(rhs_byte >> 4)) - 8); + + iacc += lhs_v0 * rhs_v0; + iacc += lhs_v1 * rhs_v1; + iacc += lhs_offset * rhs_v0; + iacc += lhs_offset * rhs_v1; + + lhs_ptr += 1; + rhs_ptr += 1; + } + + lhs_ptr += (block_size / 2); + } + + main_acc += iacc * rhs_scale; + } + + main_acc = main_acc * lhs_scale; + + // Clamp (min-max) operation + main_acc = std::max(main_acc, scalar_min); + main_acc = std::min(main_acc, scalar_max); + + dst_f32[0] = main_acc; + dst_f32 += 1; + } + } +}; + +static bool is_output_correct(size_t num_rows, size_t num_cols, float tolerance, const float* ref, const float* act) { + bool is_valid = true; + + for (size_t i = 0; i < num_rows * num_cols; ++i) { + if (std::fabs(ref[i] - act[i]) > tolerance) { + const size_t x = i % num_cols; + const size_t y = i / num_cols; + printf("ERROR![%ld][%ld]: ref=%.5f vs. 
act=%.5f\n", y, x, ref[i], act[i]); + is_valid = false; + } + } + return is_valid; +} + +int main(int argc, char** argv) { + const size_t m = 8; + const size_t n = 8; + const size_t k = 256; + const size_t bl = 32; + const size_t num_blocks_per_row = k / bl; + const size_t num_byte_per_block = bl / 2 + sizeof(float); + const size_t seed_lhs = 4568; + const size_t seed_rhs = seed_lhs + 4; + + const size_t lhs_native_size_f32 = m * k * sizeof(float); + const size_t rhs_native_size_f32 = n * k * sizeof(float); + const size_t rhs_native_size_qs4c32 = n * num_blocks_per_row * num_byte_per_block; + + // Allocate the memory + uint8_t* lhs_native_mtx_f32 = new uint8_t[lhs_native_size_f32]; + uint8_t* rhs_native_mtx_f32 = new uint8_t[rhs_native_size_f32]; + uint8_t* rhs_native_mtx_qs4c32 = new uint8_t[rhs_native_size_qs4c32]; + + fill_uniform_random(m, k, (float*)lhs_native_mtx_f32, seed_lhs); + fill_uniform_random(n, k, (float*)rhs_native_mtx_f32, seed_rhs); + + quant_qs4c32_f32(n, k, bl, (const float*)rhs_native_mtx_f32, (uint8_t*)rhs_native_mtx_qs4c32); + + delete[] rhs_native_mtx_f32; + + //----------- REFERENCE IMPLEMENTATION + //------------------------------------ + //------------------------------------ + // Memory sizes for the reference implementation + // After dynamically quantized the LHS matrix, we have the scale and offset for each + // row. The scale (f32) and offset (int32) are stored at the beginning of each row + const size_t lhs_ref_size_qa8dx = m * (k + sizeof(int32_t) + sizeof(float)); + const size_t dst_ref_size_f32 = m * n * sizeof(float); + + uint8_t* lhs_ref_mtx_qa8dx = new uint8_t[lhs_ref_size_qa8dx]; + uint8_t* dst_ref_mtx_f32 = new uint8_t[dst_ref_size_f32]; + + ref_quant_qa8dx_f32(m, k, (const float*)lhs_native_mtx_f32, (int8_t*)lhs_ref_mtx_qa8dx); + + ref_matmul_f32_qa8dx_qs4c32( + m, n, k, // Dimensions + bl, // Block length + (const int8_t*)lhs_ref_mtx_qa8dx, // LHS + (const uint8_t*)rhs_native_mtx_qs4c32, // RHS + (float*)dst_ref_mtx_f32, // DST + -FLT_MAX, FLT_MAX); // Min and max for the clamp operation + + // Remove the unnecessary buffer + delete[] lhs_ref_mtx_qa8dx; + + //----------- END REFERENCE IMPLEMENTATION + //------------------------------------ + //------------------------------------ + + //----------- MICRO-KERNELS TESTS + //------------------------------------ + //------------------------------------ + for (size_t idx_variant = 0; idx_variant < num_ukernel_variants; ++idx_variant) { + std::cout << "Testing " << ukernel_variants[idx_variant].name << std::endl; + + // Get the packing parameters + const size_t mr = ukernel_variants[idx_variant].ukernel.get_mr(); + const size_t nr = ukernel_variants[idx_variant].ukernel.get_nr(); + const size_t kr = ukernel_variants[idx_variant].ukernel.get_kr(); + const size_t sr = ukernel_variants[idx_variant].ukernel.get_sr(); + + // Get the size in bytes for the packed matrices + const size_t lhs_packed_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(m, k, mr, kr, sr); + const size_t rhs_packed_size = + kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(n, k, nr, kr, bl, F32); + const size_t dst_size = ukernel_variants[idx_variant].ukernel.get_dst_size(m, n); + + // Allocate the matrices + uint8_t* lhs_packed_mtx_qa8dx = new uint8_t[lhs_packed_size]; + uint8_t* rhs_packed_mtx_qs4cx = new uint8_t[rhs_packed_size]; + uint8_t* dst_act_mtx_f32 = new uint8_t[dst_size]; + + // If the RHS matrix contains constant values, the packing can be performed + // only once + struct 
kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0_params params;
+        params.lhs_zero_point = 1;
+        params.rhs_zero_point = 8;
+        params.scale_dt = F32;
+
+        // RHS packing
+        kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
+            1, n, k,                                  // Dimensions
+            nr, kr, sr,                               // Packing arguments
+            bl,                                       // Block length
+            (const uint8_t*)(rhs_native_mtx_qs4c32),  // RHS
+            NULL,                                     // Bias
+            rhs_packed_mtx_qs4cx,                     // RHS packed
+            0, &params);
+
+        // LHS packing
+        kai_run_lhs_quant_pack_qai8dxp_f32(
+            m, k,                              // Dimensions
+            mr, kr, sr, 0,                     // Packing arguments
+            (const float*)lhs_native_mtx_f32,  // LHS
+            k * sizeof(float),                 // LHS stride
+            lhs_packed_mtx_qa8dx);             // LHS packed
+
+        // Matmul
+        {
+            const size_t dst_stride = n * sizeof(float);
+            const size_t lhs_offset = ukernel_variants[idx_variant].ukernel.get_lhs_packed_offset(0, k);
+            const size_t rhs_offset = ukernel_variants[idx_variant].ukernel.get_rhs_packed_offset(0, k, bl);
+            const size_t dst_offset = ukernel_variants[idx_variant].ukernel.get_dst_offset(0, 0, dst_stride);
+
+            const void* lhs_ptr = (const void*)((const char*)lhs_packed_mtx_qa8dx + lhs_offset);
+            const void* rhs_ptr = (const void*)((const char*)rhs_packed_mtx_qs4cx + rhs_offset);
+            float* dst_ptr = (float*)((uint8_t*)dst_act_mtx_f32 + dst_offset);
+
+            ukernel_variants[idx_variant].ukernel.run_matmul(
+                m, n, k,           // Dimensions
+                bl,                // Block length
+                lhs_ptr,           // LHS packed
+                rhs_ptr,           // RHS packed
+                dst_ptr,           // DST
+                dst_stride,        // DST stride (row)
+                sizeof(float),     // DST stride (col)
+                -FLT_MAX, FLT_MAX  // Min and max for the clamp operation
+            );
+        }
+
+        const bool is_valid =
+            is_output_correct(m, n, 0.0001f, (const float*)dst_ref_mtx_f32, (const float*)dst_act_mtx_f32);
+
+        if (is_valid) {
+            printf("TEST[%ld] = PASSED\n", idx_variant);
+        } else {
+            printf("TEST[%ld] = FAILED\n", idx_variant);
+        }
+        delete[] lhs_packed_mtx_qa8dx;
+        delete[] rhs_packed_mtx_qs4cx;
+        delete[] dst_act_mtx_f32;
+    }
+    delete[] lhs_native_mtx_f32;
+    delete[] rhs_native_mtx_qs4c32;
+    delete[] dst_ref_mtx_f32;
+}
+
+//----------- END MICRO-KERNELS TESTS
+//------------------------------------
+//------------------------------------
+
+#endif // Architectural feature check
diff --git a/kai/kai_common.h b/kai/kai_common.h
index 60df8162..306cb962 100644
--- a/kai/kai_common.h
+++ b/kai/kai_common.h
@@ -50,6 +50,30 @@ extern "C" {
 #define KAI_MIN(a, b) (((a) < (b)) ? (a) : (b))
 #define KAI_MAX(a, b) (((a) > (b)) ?
(a) : (b)) +/// KleidiAI data types +/// Format: (reserved)|(num-bytes)|(type)|(variant-type) +enum kai_datatype { + Unknown = 0x0000, + F32 = 0x0411, + F16 = 0x0212, + Bf16 = 0x0213, + Int32 = 0x0421, + Int16 = 0x0222, + Int8 = 0x0124, + Uint32 = 0x0431, + Uint16 = 0x0232, + Uint8 = 0x0134, + Bool = 0x0441 +}; + +/// Gets number of bytes for a given data type +/// @param[in] dt KleidiAI data type +/// +/// @return the numbers of bytes for the data type +inline static size_t kai_num_bytes_datatype(enum kai_datatype dt) { + return (size_t)(dt >> 8); +} + /// Converts a scalar f16 value to f32 /// @param[in] f16 The f16 value /// diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c index 47fee4d3..245e8353 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c @@ -83,7 +83,7 @@ size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_ size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod( size_t n_idx, size_t k, size_t bl) { KAI_ASSERT((n_idx % kai_n_step) == 0); - KAI_ASSERT(bl == 32); + KAI_ASSERT((bl % 32) == 0); // Temporary assert KAI_ASSERT((k % kai_k0) == 0); @@ -100,9 +100,6 @@ size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_do } size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(size_t m, size_t n) { - // Temporary assert - KAI_ASSERT((n % kai_nr) == 0); - return m * n * sizeof(float); } @@ -110,9 +107,8 @@ void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod( size_t m, size_t n, size_t k, size_t bl, const void* restrict lhs_packed, const void* restrict rhs_packed, float* restrict dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { // Temporary asserts - KAI_ASSERT(n % kai_nr == 0); KAI_ASSERT(k % kai_k0 == 0); - KAI_ASSERT(bl == 32); + KAI_ASSERT((bl % 32) == 0); KAI_ASSERT(dst_stride_col == sizeof(float)); if (m == 0) { @@ -124,86 +120,98 @@ void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod( float clamp_vals[2] = {scalar_min, scalar_max}; - __asm__ __volatile__( - "movi v31.16b, #0xf0\n" - "1:" // Column loop - "mov x22, %x[lhs_packed]\n" - "movi v30.16b, #0x0\n" - "mov x21, %x[num_blocks]\n" - "2:" // Block loop - "movi v29.4s, #0x0\n" - "movi v28.4s, #0x0\n" - "mov x20, %x[num_subblocks]\n" - "3:" // Sub block loop - "ldr q27, [%x[rhs_packed], #0x0]\n" - "ldr q26, [%x[rhs_packed], #0x10]\n" - "subs x20, x20, #0x1\n" - "ld1r { v25.2d }, [x22], #0x8\n" - "ldr q24, [%x[rhs_packed], #0x20]\n" - "ldr q23, [%x[rhs_packed], #0x30]\n" - "add %x[rhs_packed], %x[rhs_packed], #0x40\n" - "ld1r { v22.2d }, [x22], #0x8\n" - "ld1r { v21.2d }, [x22], #0x8\n" - "shl v20.16b, v27.16b, #0x4\n" - "shl v19.16b, v26.16b, #0x4\n" - "ld1r { v18.2d }, [x22], #0x8\n" - "shl v17.16b, v24.16b, #0x4\n" - "and v27.16b, v27.16b, v31.16b\n" - "shl v16.16b, v23.16b, #0x4\n" - "and v26.16b, v26.16b, v31.16b\n" - ".inst 0x4e99969d // sdot v29.4s, v20.16b, v25.16b\n" - ".inst 0x4e99967c // sdot v28.4s, v19.16b, v25.16b\n" - "and v24.16b, v24.16b, v31.16b\n" - "and v23.16b, v23.16b, v31.16b\n" - ".inst 0x4e96963d // 
sdot v29.4s, v17.16b, v22.16b\n" - ".inst 0x4e96961c // sdot v28.4s, v16.16b, v22.16b\n" - ".inst 0x4e95977d // sdot v29.4s, v27.16b, v21.16b\n" - ".inst 0x4e95975c // sdot v28.4s, v26.16b, v21.16b\n" - ".inst 0x4e92971d // sdot v29.4s, v24.16b, v18.16b\n" - ".inst 0x4e9296fc // sdot v28.4s, v23.16b, v18.16b\n" - "bgt 3b\n" - "ldr q16, [%x[rhs_packed], #0x0]\n" - "addp v29.4s, v29.4s, v28.4s\n" - "sub x21, x21, #0x1\n" - "add %x[rhs_packed], %x[rhs_packed], #0x10\n" - "scvtf v29.4s, v29.4s\n" - "fmla v30.4s, v29.4s, v16.4s\n" - "cbnz x21, 2b\n" - "ld1r { v20.4s }, [x22]\n" - "ldr q19, [%x[rhs_packed], #0x0]\n" - "add x22, x22, #0x4\n" - "add x20, %x[clamp_vals], #0x4\n" - "ld1r { v18.4s }, [x22]\n" - "ld1r { v17.4s }, [%x[clamp_vals]]\n" - "cmp %x[n], #0x4\n" - "add %x[rhs_packed], %x[rhs_packed], #0x10\n" - "ld1r { v16.4s }, [x20]\n" - "scvtf v20.4s, v20.4s\n" - "fmla v30.4s, v19.4s, v20.s[0]\n" - "fmul v30.4s, v30.4s, v18.4s\n" - "fmax v30.4s, v30.4s, v17.4s\n" - "fmin v30.4s, v30.4s, v16.4s\n" - "blt 4f\n" - "str q30, [%x[dst], #0x0]\n" - "b 7f\n" - "4:" // Partial output - "mov x20, %x[dst]\n" - "tbz %x[n], #1, 5f\n" - "st1 { v30.d }[0], [x20], #0x8\n" - "tbz %x[n], #0, 6f\n" - "st1 { v30.s }[2], [x20]\n" - "b 6f\n" - "5:" // Output block 0: partial_1_0 - "st1 { v30.s }[0], [x20]\n" - "6:" // Output block 0: Done - "7:" // Stores done - "subs %x[n], %x[n], #0x4\n" - "add %x[dst], %x[dst], #0x10\n" - "bgt 1b\n" - : [dst] "+&r"(dst), [n] "+&r"(n), [rhs_packed] "+&r"(rhs_packed) - : [clamp_vals] "r"(clamp_vals), [lhs_packed] "r"(lhs_packed), [num_blocks] "r"(num_blocks), - [num_subblocks] "r"(num_subblocks) - : "cc", "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", - "v29", "v30", "v31", "x20", "x21", "x22"); + const void* lhs_packed_start = lhs_packed; + const void* rhs_packed_start = rhs_packed; + float* dst_start = dst; + const size_t n_original = n; + + for (size_t row_idx = 0; row_idx < m; ++row_idx) { + n = n_original; + lhs_packed = lhs_packed_start + row_idx * kai_lhs_packed_stride(k); + rhs_packed = rhs_packed_start; + dst = dst_start + row_idx * (dst_stride_row / sizeof(float)); + + __asm__ __volatile__( + "movi v31.16b, #0xf0\n" + "1:" // Column loop + "mov x22, %x[lhs_packed]\n" + "movi v30.16b, #0x0\n" + "mov x21, %x[num_blocks]\n" + "2:" // Block loop + "movi v29.4s, #0x0\n" + "movi v28.4s, #0x0\n" + "mov x20, %x[num_subblocks]\n" + "3:" // Sub block loop + "ldr q27, [%x[rhs_packed], #0x0]\n" + "ldr q26, [%x[rhs_packed], #0x10]\n" + "subs x20, x20, #0x1\n" + "ld1r { v25.2d }, [x22], #0x8\n" + "ldr q24, [%x[rhs_packed], #0x20]\n" + "ldr q23, [%x[rhs_packed], #0x30]\n" + "add %x[rhs_packed], %x[rhs_packed], #0x40\n" + "ld1r { v22.2d }, [x22], #0x8\n" + "ld1r { v21.2d }, [x22], #0x8\n" + "shl v20.16b, v27.16b, #0x4\n" + "shl v19.16b, v26.16b, #0x4\n" + "ld1r { v18.2d }, [x22], #0x8\n" + "shl v17.16b, v24.16b, #0x4\n" + "and v27.16b, v27.16b, v31.16b\n" + "shl v16.16b, v23.16b, #0x4\n" + "and v26.16b, v26.16b, v31.16b\n" + ".inst 0x4e99969d // sdot v29.4s, v20.16b, v25.16b\n" + ".inst 0x4e99967c // sdot v28.4s, v19.16b, v25.16b\n" + "and v24.16b, v24.16b, v31.16b\n" + "and v23.16b, v23.16b, v31.16b\n" + ".inst 0x4e96963d // sdot v29.4s, v17.16b, v22.16b\n" + ".inst 0x4e96961c // sdot v28.4s, v16.16b, v22.16b\n" + ".inst 0x4e95977d // sdot v29.4s, v27.16b, v21.16b\n" + ".inst 0x4e95975c // sdot v28.4s, v26.16b, v21.16b\n" + ".inst 0x4e92971d // sdot v29.4s, v24.16b, v18.16b\n" + ".inst 0x4e9296fc // sdot v28.4s, v23.16b, 
v18.16b\n" + "bgt 3b\n" + "ldr q16, [%x[rhs_packed], #0x0]\n" + "addp v29.4s, v29.4s, v28.4s\n" + "sub x21, x21, #0x1\n" + "add %x[rhs_packed], %x[rhs_packed], #0x10\n" + "scvtf v29.4s, v29.4s\n" + "fmla v30.4s, v29.4s, v16.4s\n" + "cbnz x21, 2b\n" + "ld1r { v20.4s }, [x22]\n" + "ldr q19, [%x[rhs_packed], #0x0]\n" + "add x22, x22, #0x4\n" + "add x20, %x[clamp_vals], #0x4\n" + "ld1r { v18.4s }, [x22]\n" + "ld1r { v17.4s }, [%x[clamp_vals]]\n" + "cmp %x[n], #0x4\n" + "add %x[rhs_packed], %x[rhs_packed], #0x10\n" + "ld1r { v16.4s }, [x20]\n" + "scvtf v20.4s, v20.4s\n" + "fmla v30.4s, v19.4s, v20.s[0]\n" + "fmul v30.4s, v30.4s, v18.4s\n" + "fmax v30.4s, v30.4s, v17.4s\n" + "fmin v30.4s, v30.4s, v16.4s\n" + "blt 4f\n" + "str q30, [%x[dst], #0x0]\n" + "b 7f\n" + "4:" // Partial output + "mov x20, %x[dst]\n" + "tbz %x[n], #1, 5f\n" + "st1 { v30.d }[0], [x20], #0x8\n" + "tbz %x[n], #0, 6f\n" + "st1 { v30.s }[2], [x20]\n" + "b 6f\n" + "5:" // Output block 0: partial_1_0 + "st1 { v30.s }[0], [x20]\n" + "6:" // Output block 0: Done + "7:" // Stores done + "subs %x[n], %x[n], #0x4\n" + "add %x[dst], %x[dst], #0x10\n" + "bgt 1b\n" + : [dst] "+&r"(dst), [n] "+&r"(n), [rhs_packed] "+&r"(rhs_packed) + : [clamp_vals] "r"(clamp_vals), [lhs_packed] "r"(lhs_packed), [num_blocks] "r"(num_blocks), + [num_subblocks] "r"(num_subblocks) + : "cc", "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", + "v29", "v30", "v31", "x20", "x21", "x22"); + } } #endif // Architectural feature check diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c index 89d0d164..245eeeb1 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c @@ -83,7 +83,7 @@ size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_ size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod( size_t n_idx, size_t k, size_t bl) { KAI_ASSERT((n_idx % kai_n_step) == 0); - KAI_ASSERT(bl == 32); + KAI_ASSERT((bl % 32) == 0); // Temporary assert KAI_ASSERT((k % kai_k0) == 0); @@ -112,7 +112,7 @@ void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod( // Temporary asserts KAI_ASSERT(n % kai_nr == 0); KAI_ASSERT(k % kai_k0 == 0); - KAI_ASSERT(bl == 32); + KAI_ASSERT((bl % 32) == 0); KAI_ASSERT(dst_stride_col == sizeof(float)); if (m == 0) { diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c index 876b647b..0b3574f5 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c @@ -83,7 +83,7 @@ size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_ size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( size_t n_idx, size_t k, size_t bl) { KAI_ASSERT((n_idx % kai_n_step) == 0); - KAI_ASSERT(bl == 32); + 
KAI_ASSERT((bl % 32) == 0); // Temporary assert KAI_ASSERT((k % kai_k0) == 0); @@ -109,6 +109,16 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { + // Temporary asserts + KAI_ASSERT(n % kai_nr == 0); + KAI_ASSERT(k % kai_k0 == 0); + KAI_ASSERT((bl % 32) == 0); + KAI_ASSERT(dst_stride_col == sizeof(float)); + + if (m == 0) { + return; + } + size_t num_subblocks = bl / 32; size_t num_blocks = k / bl; diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c index 07f3558a..7ba6efa9 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c @@ -83,7 +83,7 @@ size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_ size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm( size_t n_idx, size_t k, size_t bl) { KAI_ASSERT((n_idx % kai_n_step) == 0); - KAI_ASSERT(bl == 32); + KAI_ASSERT((bl % 32) == 0); // Temporary assert KAI_ASSERT((k % kai_k0) == 0); @@ -109,6 +109,16 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm( size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { + // Temporary asserts + KAI_ASSERT(n % kai_nr == 0); + KAI_ASSERT(k % kai_k0 == 0); + KAI_ASSERT((bl % 32) == 0); + KAI_ASSERT(dst_stride_col == sizeof(float)); + + if (m == 0) { + return; + } + size_t num_subblocks = bl / 32; size_t num_blocks = k / bl; diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c index 93614b02..7d7cc4d0 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c @@ -12,7 +12,6 @@ #include "kai/kai_common.h" static const size_t kai_num_bytes_sum_rhs = sizeof(float); -static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) { KAI_ASSERT((k % 2) == 0); @@ -20,35 +19,40 @@ inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) { return k / bl; } -inline static size_t kai_rhs_stride(size_t k, size_t bl) { +inline static size_t kai_num_bytes_per_block(size_t bl, size_t num_bytes_multiplier_rhs) { + return (bl / 2) + num_bytes_multiplier_rhs; +} + +inline static size_t kai_rhs_stride(size_t k, size_t bl, size_t num_bytes_multiplier_rhs) { KAI_ASSERT((k % 2) == 0); const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); - const size_t num_bytes_per_block = (bl / 2) + kai_num_bytes_multiplier_rhs; + const size_t num_bytes_per_block = kai_num_bytes_per_block(bl, num_bytes_multiplier_rhs); return num_bytes_per_block * num_blocks_per_row; } -inline static size_t 
kai_rhs_packed_stride(size_t k, size_t nr, size_t kr, size_t bl) {
+inline static size_t kai_rhs_packed_stride(size_t k, size_t nr, size_t kr, size_t bl, size_t num_bytes_multiplier_rhs) {
     KAI_ASSERT((k % 2) == 0);
     KAI_ASSERT((k % kr) == 0);
     KAI_ASSERT((k % bl) == 0);
     KAI_ASSERT((bl % kr) == 0);
 
     const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl);
-    const size_t num_bytes_per_block = (bl / 2) + kai_num_bytes_multiplier_rhs;
+    const size_t num_bytes_per_block = kai_num_bytes_per_block(bl, num_bytes_multiplier_rhs);
 
     return nr * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs);
 }
 
-inline static size_t kai_rhs_packed_offset_end_of_all_blocks(size_t k, size_t nr, size_t kr, size_t bl) {
+inline static size_t kai_rhs_packed_offset_end_of_all_blocks(
+    size_t k, size_t nr, size_t kr, size_t bl, size_t num_bytes_multiplier_rhs) {
     KAI_ASSERT((k % 2) == 0);
     KAI_ASSERT((k % kr) == 0);
     KAI_ASSERT((k % bl) == 0);
     KAI_ASSERT((bl % kr) == 0);
 
     const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl);
-    const size_t num_bytes_per_block = (bl / 2) + kai_num_bytes_multiplier_rhs;
+    const size_t num_bytes_per_block = kai_num_bytes_per_block(bl, num_bytes_multiplier_rhs);
 
     return (nr * num_bytes_per_block * num_blocks_per_row);
 }
 
@@ -58,33 +62,38 @@ size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(size_t n_idx, size_
 }
 
 size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
-    size_t n_idx, size_t k, size_t nr, size_t kr, size_t bl) {
+    size_t n_idx, size_t k, size_t nr, size_t kr, size_t bl, enum kai_datatype scale_dt) {
     KAI_ASSERT((k % 2) == 0);
     KAI_ASSERT((k % kr) == 0);
     KAI_ASSERT((k % bl) == 0);
     KAI_ASSERT((n_idx % nr) == 0);
+    KAI_ASSERT(scale_dt == F32 || scale_dt == F16);
 
     KAI_UNUSED(kr);
 
-    return (n_idx / nr) * kai_rhs_packed_stride(k, nr, kr, bl);
+    const size_t num_bytes_scale = kai_num_bytes_datatype(scale_dt);
+    return (n_idx / nr) * kai_rhs_packed_stride(k, nr, kr, bl, num_bytes_scale);
 }
 
-size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(size_t n, size_t k, size_t nr, size_t kr, size_t bl) {
+size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
+    size_t n, size_t k, size_t nr, size_t kr, size_t bl, enum kai_datatype scale_dt) {
     KAI_ASSERT((k % 2) == 0);
     KAI_ASSERT((k % kr) == 0);
     KAI_ASSERT((k % bl) == 0);
     KAI_ASSERT((n % nr) == 0);
+    KAI_ASSERT(scale_dt == F32 || scale_dt == F16);
 
     KAI_UNUSED(kr);
 
+    const size_t num_bytes_scale = kai_num_bytes_datatype(scale_dt);
     const size_t num_rows = n / nr;
 
-    return num_rows * kai_rhs_packed_stride(k, nr, kr, bl);
+    return num_rows * kai_rhs_packed_stride(k, nr, kr, bl, num_bytes_scale);
 }
 
 void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
     size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
-    const int32_t* bias, void* rhs_packed, size_t extra_bytes,
+    const float* bias, void* rhs_packed, size_t extra_bytes,
     const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0_params* params) {
     // Temporary asserts
     KAI_ASSERT(num_groups == 1);
@@ -102,17 +111,21 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
     KAI_ASSERT(params != NULL);
     KAI_ASSERT(params->rhs_zero_point == 8);
     KAI_ASSERT(params->lhs_zero_point == 1);
+    KAI_ASSERT(params->scale_dt == F32 || params->scale_dt == F16);
 
     // Note: The input matrix (rhs) is expected with:
     // "k" columns and "n" rows (NxK)
 
+    const size_t num_bytes_multiplier_rhs = kai_num_bytes_datatype(params->scale_dt);
     const size_t rhs_zero_point = params->rhs_zero_point;
-    const size_t rhs_stride = kai_rhs_stride(k, bl);
-    const size_t rhs_packed_stride = kai_rhs_packed_stride(k, nr, kr, bl);
-    const size_t rhs_packed_offset_end_of_all_blocks = kai_rhs_packed_offset_end_of_all_blocks(k, nr, kr, bl);
+    const size_t rhs_stride = kai_rhs_stride(k, bl, num_bytes_multiplier_rhs);
+    const size_t rhs_packed_stride = kai_rhs_packed_stride(k, nr, kr, bl, num_bytes_multiplier_rhs);
+    const size_t rhs_packed_offset_end_of_all_blocks =
+        kai_rhs_packed_offset_end_of_all_blocks(k, nr, kr, bl, num_bytes_multiplier_rhs);
 
     const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl);
     const size_t num_segments_per_block = bl / kr;
     const size_t num_bytes_per_segment = kr / 2;
+    const bool is_scale_f32 = params->scale_dt == F32;
 
     for (size_t y = 0; y < n; y += nr) {
         const uint8_t* src_row = (const uint8_t*)rhs + y * rhs_stride;
@@ -125,12 +138,12 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
 
         for (size_t x = 0; x < num_blocks_per_row; ++x) {
             // Store the scales at the end of the block
-            float* scales = (float*)(dst_row + num_segments_per_block * num_bytes_per_segment * nr);
+            uint8_t* scales = (dst_row + (bl / 2) * nr);
 
             for (size_t i = 0; i < nr; ++i) {
-                scales[i] = *((const float*)(src_row + i * rhs_stride));
+                memcpy(scales + i * num_bytes_multiplier_rhs, src_row + i * rhs_stride, num_bytes_multiplier_rhs);
             }
-            src_row += kai_num_bytes_multiplier_rhs;
+            src_row += num_bytes_multiplier_rhs;
 
             // Store the segments
             for (size_t s = 0; s < num_segments_per_block; ++s) {
@@ -142,13 +155,19 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
                     const int32_t x0 = (qs & 0x0F) - 8;
                     const int32_t x1 = (qs >> 4) - 8;
 
-                    const float d = scales[i];
-
-                    sums[i] += x0 * d;
-                    sums[i] += x1 * d;
 
                     // Add offset (0x88)
                     dst_row[i * num_bytes_per_segment + b] = qs ^ 0x88;
+
+                    float d = 0.0F;
+                    if (is_scale_f32) {
+                        d = ((float*)scales)[i];
+                    } else {
+                        d = kai_f16_to_f32(((uint16_t*)scales)[i]);
+                    }
+
+                    sums[i] += x0 * d;
+                    sums[i] += x1 * d;
                 }
             }
 
@@ -157,10 +176,15 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
         }
 
         for (size_t i = 0; i < nr; ++i) {
-            scales[i] *= 0.0625F;
+            if (is_scale_f32) {
+                ((float*)scales)[i] *= 0.0625F;
+            } else {
+                const float d = kai_f16_to_f32(((uint16_t*)scales)[i]);
+                ((uint16_t*)scales)[i] = kai_f32_to_f16(d * 0.0625F);
+            }
         }
 
-        dst_row += (kai_num_bytes_multiplier_rhs * nr);
+        dst_row += (num_bytes_multiplier_rhs * nr);
         }
     }
 }
diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.h b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.h
index 193bd2bd..816862f5 100644
--- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.h
+++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.h
@@ -11,6 +11,8 @@
 #include <stddef.h>
 #include <stdint.h>
 
+#include "kai/kai_common.h"
+
 #ifdef __cplusplus
 extern "C" {
 #endif
@@ -18,6 +20,7 @@ extern "C" {
 struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0_params {
     int8_t lhs_zero_point;
     uint8_t rhs_zero_point;
+    enum kai_datatype scale_dt;
 };
 
 /// Gets the offset in bytes for the RHS matrix (not packed), which holds
@@ -41,20 +44,22 @@ size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
 
 /// Gets the offset in bytes for the packed RHS matrix.
 ///
-/// @param[in] n_idx Row index in the RHS matrix (not packed).
-/// @param[in] k The common dimension between the LHS and RHS matrix (K)
-/// @param[in] nr The number of columns written by the matmul micro-kernel
-/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel.
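As a reading aid for the packing routine above: each group of nr rows is written block by block, with the (bl / 2) * nr bytes of 4-bit data followed by the nr per-block scales (F32 or F16, selected by params->scale_dt), and the nr per-row sums appended once after the last block. The 0.0625F (1/16) factor folded into the scales compensates for the matmul micro-kernels processing the 4-bit values shifted left by four bits. A minimal sketch of the resulting offset arithmetic, assuming the kai_* helpers defined in this file (illustrative only, not part of the patch):

    #include <stddef.h>

    // Hypothetical helper: byte offset of the scales of block x inside packed
    // row group g. block_stride matches kai_num_bytes_per_block(bl, ...) * nr.
    static size_t example_scales_offset(
        size_t g, size_t x, size_t k, size_t nr, size_t kr, size_t bl, size_t num_bytes_multiplier_rhs) {
        const size_t block_stride = ((bl / 2) + num_bytes_multiplier_rhs) * nr;
        return g * kai_rhs_packed_stride(k, nr, kr, bl, num_bytes_multiplier_rhs)  // start of the row group
            + x * block_stride                                                     // blocks already written
            + (bl / 2) * nr;                                                       // skip this block's 4-bit data
    }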
-/// @param[in] bl The block length, which defines the number of K values stored in a single block. It must be a +/// @param[in] n_idx Row index in the RHS matrix (not packed). +/// @param[in] k The common dimension between the LHS and RHS matrix (K) +/// @param[in] nr The number of columns written by the matmul micro-kernel +/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. +/// @param[in] bl The block length, which defines the number of K values stored in a single block. It must be a /// multiple of 32. +/// @param[in] scale_dt Block scale data type /// /// @return the offset in bytes to the packed RHS matrix size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( - size_t n_idx, // - size_t k, // - size_t nr, // - size_t kr, // - size_t bl); // + size_t n_idx, // + size_t k, // + size_t nr, // + size_t kr, // + size_t bl, // + enum kai_datatype scale_dt); // /// Gets the size in bytes for the quantized and packed RHS matrix. /// @@ -67,11 +72,12 @@ size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( /// /// @return the packed RHS matrix size in bytes size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( - size_t n, // - size_t k, // - size_t nr, // - size_t kr, // - size_t bl); // + size_t n, // + size_t k, // + size_t nr, // + size_t kr, // + size_t bl, // + enum kai_datatype scale_dt); // /// Runs the RHS packing micro-kernel. /// @@ -103,7 +109,7 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( size_t sr, // size_t bl, // const uint8_t* rhs, // - const int32_t* bias, // + const float* bias, // void* rhs_packed, // size_t extra_bytes, // const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0_params* params); // -- GitLab From b05c1d6e56a16c2ad426fc5eb7c10a9f1fe2c4c7 Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Mon, 17 Jun 2024 14:42:15 +0100 Subject: [PATCH 03/29] Extend Int4 matmul micro-kernels to add bias before storing the result - Extend packing function to pack the bias at the end of each packed matrix - Fix Int4 matmul micro-kernels to work on any block length multiple of 32 - Extend Int4 matmul micro-kernels to add the bias before storing the result The 1x4 variant remains to be fixed Signed-off-by: Gian Marco Iodice --- .../matmul_clamp_f32_qai8dxp_qsi4c32p.cpp | 5 +- ...i8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c | 190 +++--- ...i8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c | 103 ++-- ..._qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c | 572 +++++++++--------- ..._qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c | 352 +++++------ .../kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c | 17 +- 6 files changed, 648 insertions(+), 591 deletions(-) diff --git a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp index 9a175c82..aea682a7 100644 --- a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp +++ b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp @@ -98,10 +98,7 @@ static void fill_uniform_random(size_t num_rows, size_t num_cols, float* dst, si // Fill the array with random values between -1 and 1 for (int i = 0; i < num_rows * num_cols; i++) { - if (i % 2 == 0) - dst[i] = (float)((double)std::rand() / RAND_MAX) * 2 - 1; - else - dst[i] = 0; + dst[i] = (float)((double)std::rand() / RAND_MAX) * 2 - 1; } } diff --git 
a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c index 245e8353..038644e8 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c @@ -24,6 +24,7 @@ static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); static const size_t kai_num_bytes_sum_rhs = sizeof(float); +static const size_t kai_num_bytes_bias = sizeof(float); inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) { KAI_ASSERT((k % 2) == 0); @@ -47,7 +48,7 @@ inline static size_t kai_rhs_packed_stride(size_t k, size_t bl) { const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); const size_t num_bytes_per_block = (bl / 2) + kai_num_bytes_multiplier_rhs; - return kai_nr * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs); + return kai_nr * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs + kai_num_bytes_bias); } size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(void) { @@ -120,98 +121,99 @@ void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod( float clamp_vals[2] = {scalar_min, scalar_max}; - const void* lhs_packed_start = lhs_packed; - const void* rhs_packed_start = rhs_packed; - float* dst_start = dst; - const size_t n_original = n; - - for (size_t row_idx = 0; row_idx < m; ++row_idx) { - n = n_original; - lhs_packed = lhs_packed_start + row_idx * kai_lhs_packed_stride(k); - rhs_packed = rhs_packed_start; - dst = dst_start + row_idx * (dst_stride_row / sizeof(float)); - - __asm__ __volatile__( - "movi v31.16b, #0xf0\n" - "1:" // Column loop - "mov x22, %x[lhs_packed]\n" - "movi v30.16b, #0x0\n" - "mov x21, %x[num_blocks]\n" - "2:" // Block loop - "movi v29.4s, #0x0\n" - "movi v28.4s, #0x0\n" - "mov x20, %x[num_subblocks]\n" - "3:" // Sub block loop - "ldr q27, [%x[rhs_packed], #0x0]\n" - "ldr q26, [%x[rhs_packed], #0x10]\n" - "subs x20, x20, #0x1\n" - "ld1r { v25.2d }, [x22], #0x8\n" - "ldr q24, [%x[rhs_packed], #0x20]\n" - "ldr q23, [%x[rhs_packed], #0x30]\n" - "add %x[rhs_packed], %x[rhs_packed], #0x40\n" - "ld1r { v22.2d }, [x22], #0x8\n" - "ld1r { v21.2d }, [x22], #0x8\n" - "shl v20.16b, v27.16b, #0x4\n" - "shl v19.16b, v26.16b, #0x4\n" - "ld1r { v18.2d }, [x22], #0x8\n" - "shl v17.16b, v24.16b, #0x4\n" - "and v27.16b, v27.16b, v31.16b\n" - "shl v16.16b, v23.16b, #0x4\n" - "and v26.16b, v26.16b, v31.16b\n" - ".inst 0x4e99969d // sdot v29.4s, v20.16b, v25.16b\n" - ".inst 0x4e99967c // sdot v28.4s, v19.16b, v25.16b\n" - "and v24.16b, v24.16b, v31.16b\n" - "and v23.16b, v23.16b, v31.16b\n" - ".inst 0x4e96963d // sdot v29.4s, v17.16b, v22.16b\n" - ".inst 0x4e96961c // sdot v28.4s, v16.16b, v22.16b\n" - ".inst 0x4e95977d // sdot v29.4s, v27.16b, v21.16b\n" - ".inst 0x4e95975c // sdot v28.4s, v26.16b, v21.16b\n" - ".inst 0x4e92971d // sdot v29.4s, v24.16b, v18.16b\n" - ".inst 0x4e9296fc // sdot v28.4s, v23.16b, v18.16b\n" - "bgt 3b\n" - "ldr q16, [%x[rhs_packed], #0x0]\n" - "addp v29.4s, v29.4s, v28.4s\n" - "sub x21, x21, #0x1\n" - "add %x[rhs_packed], %x[rhs_packed], #0x10\n" - "scvtf v29.4s, 
v29.4s\n" - "fmla v30.4s, v29.4s, v16.4s\n" - "cbnz x21, 2b\n" - "ld1r { v20.4s }, [x22]\n" - "ldr q19, [%x[rhs_packed], #0x0]\n" - "add x22, x22, #0x4\n" - "add x20, %x[clamp_vals], #0x4\n" - "ld1r { v18.4s }, [x22]\n" - "ld1r { v17.4s }, [%x[clamp_vals]]\n" - "cmp %x[n], #0x4\n" - "add %x[rhs_packed], %x[rhs_packed], #0x10\n" - "ld1r { v16.4s }, [x20]\n" - "scvtf v20.4s, v20.4s\n" - "fmla v30.4s, v19.4s, v20.s[0]\n" - "fmul v30.4s, v30.4s, v18.4s\n" - "fmax v30.4s, v30.4s, v17.4s\n" - "fmin v30.4s, v30.4s, v16.4s\n" - "blt 4f\n" - "str q30, [%x[dst], #0x0]\n" - "b 7f\n" - "4:" // Partial output - "mov x20, %x[dst]\n" - "tbz %x[n], #1, 5f\n" - "st1 { v30.d }[0], [x20], #0x8\n" - "tbz %x[n], #0, 6f\n" - "st1 { v30.s }[2], [x20]\n" - "b 6f\n" - "5:" // Output block 0: partial_1_0 - "st1 { v30.s }[0], [x20]\n" - "6:" // Output block 0: Done - "7:" // Stores done - "subs %x[n], %x[n], #0x4\n" - "add %x[dst], %x[dst], #0x10\n" - "bgt 1b\n" - : [dst] "+&r"(dst), [n] "+&r"(n), [rhs_packed] "+&r"(rhs_packed) - : [clamp_vals] "r"(clamp_vals), [lhs_packed] "r"(lhs_packed), [num_blocks] "r"(num_blocks), - [num_subblocks] "r"(num_subblocks) - : "cc", "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", - "v29", "v30", "v31", "x20", "x21", "x22"); - } + __asm__ __volatile__( + "mov x27, #0x20\n" + "mov x20, #0x8\n" + "movi v31.16b, #0xf0\n" + "mov x26, %x[m]\n" + "mul x27, %x[num_subblocks], x27\n" + "madd x27, %x[num_blocks], x27, x20\n" + "1:" // Row loop + "mov x25, %x[rhs_packed]\n" + "mov x24, %x[n]\n" + "add x23, %x[dst], %x[dst_stride_row]\n" + "2:" // Column loop + "mov x22, %x[lhs_packed]\n" + "movi v30.16b, #0x0\n" + "mov x21, %x[num_blocks]\n" + "3:" // Block loop + "movi v29.4s, #0x0\n" + "movi v28.4s, #0x0\n" + "mov x20, %x[num_subblocks]\n" + "4:" // Sub block loop + "ldr q27, [x25, #0x0]\n" + "ldr q26, [x25, #0x10]\n" + "subs x20, x20, #0x1\n" + "ld1r { v25.2d }, [x22], #0x8\n" + "ldr q24, [x25, #0x20]\n" + "ldr q23, [x25, #0x30]\n" + "add x25, x25, #0x40\n" + "ld1r { v22.2d }, [x22], #0x8\n" + "ld1r { v21.2d }, [x22], #0x8\n" + "shl v20.16b, v27.16b, #0x4\n" + "shl v19.16b, v26.16b, #0x4\n" + "ld1r { v18.2d }, [x22], #0x8\n" + "shl v17.16b, v24.16b, #0x4\n" + "and v27.16b, v27.16b, v31.16b\n" + "shl v16.16b, v23.16b, #0x4\n" + "and v26.16b, v26.16b, v31.16b\n" + ".inst 0x4e99969d // sdot v29.4s, v20.16b, v25.16b\n" + ".inst 0x4e99967c // sdot v28.4s, v19.16b, v25.16b\n" + "and v24.16b, v24.16b, v31.16b\n" + "and v23.16b, v23.16b, v31.16b\n" + ".inst 0x4e96963d // sdot v29.4s, v17.16b, v22.16b\n" + ".inst 0x4e96961c // sdot v28.4s, v16.16b, v22.16b\n" + ".inst 0x4e95977d // sdot v29.4s, v27.16b, v21.16b\n" + ".inst 0x4e95975c // sdot v28.4s, v26.16b, v21.16b\n" + ".inst 0x4e92971d // sdot v29.4s, v24.16b, v18.16b\n" + ".inst 0x4e9296fc // sdot v28.4s, v23.16b, v18.16b\n" + "bgt 4b\n" + "ldr q16, [x25, #0x0]\n" + "addp v29.4s, v29.4s, v28.4s\n" + "sub x21, x21, #0x1\n" + "add x25, x25, #0x10\n" + "scvtf v29.4s, v29.4s\n" + "fmla v30.4s, v29.4s, v16.4s\n" + "cbnz x21, 3b\n" + "ld1r { v20.4s }, [x22]\n" + "ldr q19, [x25, #0x0]\n" + "add x22, x22, #0x4\n" + "add x20, %x[clamp_vals], #0x4\n" + "ld1r { v18.4s }, [x22]\n" + "ld1r { v17.4s }, [%x[clamp_vals]]\n" + "cmp x24, #0x4\n" + "add x25, x25, #0x10\n" + "ld1r { v16.4s }, [x20]\n" + "scvtf v20.4s, v20.4s\n" + "fmla v30.4s, v19.4s, v20.s[0]\n" + "fmul v30.4s, v30.4s, v18.4s\n" + "fmax v30.4s, v30.4s, v17.4s\n" + "fmin v30.4s, v30.4s, v16.4s\n" + "blt 5f\n" + "str q30, [%x[dst], #0x0]\n" + "b 
8f\n" + "5:" // Partial output + "mov x20, %x[dst]\n" + "tbz x24, #1, 6f\n" + "st1 { v30.d }[0], [x20], #0x8\n" + "tbz x24, #0, 7f\n" + "st1 { v30.s }[2], [x20]\n" + "b 7f\n" + "6:" // Output block 0: partial_1_0 + "st1 { v30.s }[0], [x20]\n" + "7:" // Output block 0: Done + "8:" // Stores done + "subs x24, x24, #0x4\n" + "add %x[dst], %x[dst], #0x10\n" + "bgt 2b\n" + "subs x26, x26, #0x1\n" + "add %x[lhs_packed], %x[lhs_packed], x27\n" + "mov %x[dst], x23\n" + "bgt 1b\n" + : [dst] "+&r"(dst), [lhs_packed] "+&r"(lhs_packed) + : [clamp_vals] "r"(clamp_vals), [dst_stride_row] "r"(dst_stride_row), [m] "r"(m), [n] "r"(n), + [num_blocks] "r"(num_blocks), [num_subblocks] "r"(num_subblocks), [rhs_packed] "r"(rhs_packed) + : "cc", "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", + "v29", "v30", "v31", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27"); } #endif // Architectural feature check diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c index 245eeeb1..c0364c66 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c @@ -24,6 +24,7 @@ static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); static const size_t kai_num_bytes_sum_rhs = sizeof(float); +static const size_t kai_num_bytes_bias = sizeof(float); inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) { KAI_ASSERT((k % 2) == 0); @@ -47,7 +48,7 @@ inline static size_t kai_rhs_packed_stride(size_t k, size_t bl) { const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); const size_t num_bytes_per_block = (bl / 2) + kai_num_bytes_multiplier_rhs; - return kai_nr * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs); + return kai_nr * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs + kai_num_bytes_bias); } size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(void) { @@ -125,34 +126,43 @@ void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod( float clamp_vals[2] = {scalar_min, scalar_max}; __asm__ __volatile__( + "mov x27, #0x20\n" + "mov x20, #0x8\n" "movi v7.16b, #0xf0\n" - "1:" // Column loop + "mov x26, %x[m]\n" + "mul x27, %x[num_subblocks], x27\n" + "madd x27, %x[num_blocks], x27, x20\n" + "1:" // Row loop + "mov x25, %x[rhs_packed]\n" + "mov x24, %x[n]\n" + "add x23, %x[dst], %x[dst_stride_row]\n" + "2:" // Column loop "mov x22, %x[lhs_packed]\n" "movi v6.16b, #0x0\n" "movi v5.16b, #0x0\n" "mov x21, %x[num_blocks]\n" - "2:" // Block loop + "3:" // Block loop "movi v4.4s, #0x0\n" "movi v3.4s, #0x0\n" "mov x20, %x[num_subblocks]\n" "movi v2.4s, #0x0\n" "movi v1.4s, #0x0\n" - "3:" // Sub block loop - "ldr q0, [%x[rhs_packed], #0x0]\n" - "ldr q31, [%x[rhs_packed], #0x10]\n" + "4:" // Sub block loop + "ldr q0, [x25, #0x0]\n" + "ldr q31, [x25, #0x10]\n" "subs x20, x20, #0x1\n" - "ldr q30, [%x[rhs_packed], #0x20]\n" - "ldr q29, [%x[rhs_packed], #0x30]\n" + "ldr q30, [x25, #0x20]\n" + "ldr q29, [x25, #0x30]\n" "ld1r { v28.2d }, [x22], #0x8\n" - "ldr 
q27, [%x[rhs_packed], #0x40]\n" - "ldr q26, [%x[rhs_packed], #0x50]\n" - "ldr q25, [%x[rhs_packed], #0x60]\n" + "ldr q27, [x25, #0x40]\n" + "ldr q26, [x25, #0x50]\n" + "ldr q25, [x25, #0x60]\n" "shl v24.16b, v0.16b, #0x4\n" "shl v18.16b, v31.16b, #0x4\n" - "ldr q23, [%x[rhs_packed], #0x70]\n" + "ldr q23, [x25, #0x70]\n" "shl v17.16b, v30.16b, #0x4\n" "shl v16.16b, v29.16b, #0x4\n" - "add %x[rhs_packed], %x[rhs_packed], #0x80\n" + "add x25, x25, #0x80\n" "ld1r { v22.2d }, [x22], #0x8\n" "shl v21.16b, v27.16b, #0x4\n" "and v0.16b, v0.16b, v7.16b\n" @@ -184,26 +194,26 @@ void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod( ".inst 0x4e939743 // sdot v3.4s, v26.16b, v19.16b\n" ".inst 0x4e939722 // sdot v2.4s, v25.16b, v19.16b\n" ".inst 0x4e9396e1 // sdot v1.4s, v23.16b, v19.16b\n" - "bgt 3b\n" - "ldr q17, [%x[rhs_packed], #0x0]\n" - "ldr q16, [%x[rhs_packed], #0x10]\n" + "bgt 4b\n" + "ldr q17, [x25, #0x0]\n" + "ldr q16, [x25, #0x10]\n" "addp v4.4s, v4.4s, v3.4s\n" "addp v2.4s, v2.4s, v1.4s\n" "sub x21, x21, #0x1\n" - "add %x[rhs_packed], %x[rhs_packed], #0x20\n" + "add x25, x25, #0x20\n" "scvtf v4.4s, v4.4s\n" "scvtf v2.4s, v2.4s\n" "fmla v6.4s, v4.4s, v17.4s\n" "fmla v5.4s, v2.4s, v16.4s\n" - "cbnz x21, 2b\n" + "cbnz x21, 3b\n" "ld1r { v21.4s }, [x22]\n" - "ldr q20, [%x[rhs_packed], #0x0]\n" + "ldr q20, [x25, #0x0]\n" "add x22, x22, #0x4\n" "add x20, %x[clamp_vals], #0x4\n" - "ldr q19, [%x[rhs_packed], #0x10]\n" + "ldr q19, [x25, #0x10]\n" "ld1r { v18.4s }, [x22]\n" - "cmp %x[n], #0x8\n" - "add %x[rhs_packed], %x[rhs_packed], #0x20\n" + "cmp x24, #0x8\n" + "add x25, x25, #0x20\n" "ld1r { v17.4s }, [%x[clamp_vals]]\n" "ld1r { v16.4s }, [x20]\n" "scvtf v21.4s, v21.4s\n" @@ -215,40 +225,45 @@ void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod( "fmax v5.4s, v5.4s, v17.4s\n" "fmin v6.4s, v6.4s, v16.4s\n" "fmin v5.4s, v5.4s, v16.4s\n" - "blt 4f\n" + "blt 5f\n" "str q6, [%x[dst], #0x0]\n" "str q5, [%x[dst], #0x10]\n" - "b 9f\n" - "4:" // Partial output + "b 10f\n" + "5:" // Partial output "mov x20, %x[dst]\n" - "tbz %x[n], #2, 6f\n" + "tbz x24, #2, 7f\n" "st1 { v6.4s }, [x20], #0x10\n" - "tbz %x[n], #1, 5f\n" + "tbz x24, #1, 6f\n" "st1 { v5.d }[0], [x20], #0x8\n" - "tbz %x[n], #0, 8f\n" + "tbz x24, #0, 9f\n" "st1 { v5.s }[2], [x20]\n" - "b 8f\n" - "5:" // Output block 0: partial_1_4 - "tbz %x[n], #0, 8f\n" + "b 9f\n" + "6:" // Output block 0: partial_1_4 + "tbz x24, #0, 9f\n" "st1 { v5.s }[0], [x20]\n" - "b 8f\n" - "6:" // Output block 0: partial_2_0 - "tbz %x[n], #1, 7f\n" + "b 9f\n" + "7:" // Output block 0: partial_2_0 + "tbz x24, #1, 8f\n" "st1 { v6.d }[0], [x20], #0x8\n" - "tbz %x[n], #0, 8f\n" + "tbz x24, #0, 9f\n" "st1 { v6.s }[2], [x20]\n" - "b 8f\n" - "7:" // Output block 0: partial_1_0 + "b 9f\n" + "8:" // Output block 0: partial_1_0 "st1 { v6.s }[0], [x20]\n" - "8:" // Output block 0: Done - "9:" // Stores done - "subs %x[n], %x[n], #0x8\n" + "9:" // Output block 0: Done + "10:" // Stores done + "subs x24, x24, #0x8\n" "add %x[dst], %x[dst], #0x20\n" + "bgt 2b\n" + "subs x26, x26, #0x1\n" + "add %x[lhs_packed], %x[lhs_packed], x27\n" + "mov %x[dst], x23\n" "bgt 1b\n" - : [dst] "+&r"(dst), [n] "+&r"(n), [rhs_packed] "+&r"(rhs_packed) - : [clamp_vals] "r"(clamp_vals), [lhs_packed] "r"(lhs_packed), [num_blocks] "r"(num_blocks), - [num_subblocks] "r"(num_subblocks) + : [dst] "+&r"(dst), [lhs_packed] "+&r"(lhs_packed) + : [clamp_vals] "r"(clamp_vals), [dst_stride_row] "r"(dst_stride_row), [m] "r"(m), [n] "r"(n), + [num_blocks] "r"(num_blocks), 
[num_subblocks] "r"(num_subblocks), [rhs_packed] "r"(rhs_packed) : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v20", "v21", - "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22"); + "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23", "x24", + "x25", "x26", "x27"); } #endif // Architectural feature check diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c index 0b3574f5..437b90c2 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c @@ -24,6 +24,7 @@ static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); static const size_t kai_num_bytes_sum_rhs = sizeof(float); +static const size_t kai_num_bytes_bias = sizeof(float); inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) { KAI_ASSERT((k % 2) == 0); @@ -47,7 +48,7 @@ inline static size_t kai_rhs_packed_stride(size_t k, size_t bl) { const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); const size_t num_bytes_per_block = (bl / 2) + kai_num_bytes_multiplier_rhs; - return kai_nr * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs); + return kai_nr * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs + kai_num_bytes_bias); } size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(void) { @@ -125,12 +126,13 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( float clamp_vals[2] = {scalar_min, scalar_max}; __asm__ __volatile__( - "mov x12, %x[m]\n" - "mov x11, #0x80\n" - "movi v10.16b, #0xf0\n" + "mov x12, #0x80\n" + "mov x11, %x[m]\n" + "movi v15.16b, #0xf0\n" "mov x20, #0x20\n" - "cmp x12, #0x8\n" - "madd x11, %x[num_blocks], x11, x20\n" + "mul x12, %x[num_subblocks], x12\n" + "cmp x11, #0x8\n" + "madd x12, %x[num_blocks], x12, x20\n" "blt 11f\n" "1:" // Row loop "mov x10, %x[rhs_packed]\n" @@ -138,183 +140,192 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( "add x28, %x[dst], %x[dst_stride_row], LSL #3\n" "2:" // Column loop "mov x23, %x[lhs_packed]\n" - "movi v6.16b, #0x0\n" - "movi v31.16b, #0x0\n" - "mov x22, %x[num_blocks]\n" - "movi v1.16b, #0x0\n" - "movi v2.16b, #0x0\n" - "movi v13.16b, #0x0\n" - "movi v21.16b, #0x0\n" "movi v22.16b, #0x0\n" "movi v12.16b, #0x0\n" - "add x21, x23, x11\n" + "mov x22, %x[num_blocks]\n" + "movi v13.16b, #0x0\n" + "movi v20.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "movi v5.16b, #0x0\n" + "movi v7.16b, #0x0\n" + "movi v31.16b, #0x0\n" + "add x21, x23, x12\n" "3:" // Block loop - "movi v8.4s, #0x0\n" - "movi v11.4s, #0x0\n" + "movi v3.4s, #0x0\n" + "movi v0.4s, #0x0\n" "mov x20, %x[num_subblocks]\n" - "movi v17.4s, #0x0\n" + "movi v2.4s, #0x0\n" + "movi v1.4s, #0x0\n" + "movi v30.4s, #0x0\n" + "movi v8.4s, #0x0\n" "movi v4.4s, #0x0\n" - "movi v23.4s, #0x0\n" - "movi v15.4s, #0x0\n" - "movi v27.4s, #0x0\n" - "movi v5.4s, #0x0\n" + "movi v10.4s, #0x0\n" "4:" // Sub block loop - "ldr q3, [x10, #0x0]\n" - "ldr q19, [x10, #0x10]\n" + 
"ldr q16, [x10, #0x0]\n" + "ldr q29, [x10, #0x10]\n" "subs x20, x20, #0x1\n" - "ldr q7, [x23, #0x0]\n" - "ldr q28, [x23, #0x10]\n" - "ldr q14, [x21, #0x0]\n" - "ldr q18, [x21, #0x10]\n" - "ldr q29, [x10, #0x20]\n" - "ldr q25, [x10, #0x30]\n" - "shl v20.16b, v3.16b, #0x4\n" - "shl v16.16b, v19.16b, #0x4\n" - "ldr q24, [x23, #0x20]\n" - "ldr q9, [x23, #0x30]\n" - "and v3.16b, v3.16b, v10.16b\n" - "and v19.16b, v19.16b, v10.16b\n" - "ldr q0, [x21, #0x20]\n" - "ldr q26, [x21, #0x30]\n" + "ldr q24, [x23, #0x0]\n" + "ldr q23, [x23, #0x10]\n" + "ldr q28, [x21, #0x0]\n" + "ldr q6, [x21, #0x10]\n" + "ldr q27, [x10, #0x20]\n" + "ldr q11, [x10, #0x30]\n" + "shl v18.16b, v16.16b, #0x4\n" + "shl v21.16b, v29.16b, #0x4\n" + "ldr q17, [x23, #0x20]\n" + "ldr q19, [x23, #0x30]\n" + "and v16.16b, v16.16b, v15.16b\n" + "and v29.16b, v29.16b, v15.16b\n" + "ldr q26, [x21, #0x20]\n" + "ldr q9, [x21, #0x30]\n" "add x10, x10, #0x40\n" - "ldr q30, [x23, #0x40]\n" - ".inst 0x4e94a4e8 // smmla v8.4s, v7.16b, v20.16b\n" - ".inst 0x4e90a4eb // smmla v11.4s, v7.16b, v16.16b\n" - "ldr q7, [x23, #0x50]\n" - ".inst 0x4e94a791 // smmla v17.4s, v28.16b, v20.16b\n" - ".inst 0x4e90a784 // smmla v4.4s, v28.16b, v16.16b\n" - "ldr q28, [x21, #0x40]\n" - ".inst 0x4e94a5d7 // smmla v23.4s, v14.16b, v20.16b\n" - ".inst 0x4e90a5cf // smmla v15.4s, v14.16b, v16.16b\n" - "ldr q14, [x21, #0x50]\n" - ".inst 0x4e94a65b // smmla v27.4s, v18.16b, v20.16b\n" - "ldr q20, [x23, #0x60]\n" - ".inst 0x4e90a645 // smmla v5.4s, v18.16b, v16.16b\n" - "ldr q18, [x23, #0x70]\n" - "shl v16.16b, v29.16b, #0x4\n" - "and v29.16b, v29.16b, v10.16b\n" + "ldr q25, [x23, #0x40]\n" + ".inst 0x4e92a703 // smmla v3.4s, v24.16b, v18.16b\n" + ".inst 0x4e95a700 // smmla v0.4s, v24.16b, v21.16b\n" + "ldr q24, [x23, #0x50]\n" + ".inst 0x4e92a6e2 // smmla v2.4s, v23.16b, v18.16b\n" + ".inst 0x4e95a6e1 // smmla v1.4s, v23.16b, v21.16b\n" + "ldr q23, [x21, #0x40]\n" + ".inst 0x4e92a79e // smmla v30.4s, v28.16b, v18.16b\n" + ".inst 0x4e95a788 // smmla v8.4s, v28.16b, v21.16b\n" + "ldr q28, [x21, #0x50]\n" + ".inst 0x4e92a4c4 // smmla v4.4s, v6.16b, v18.16b\n" + "ldr q18, [x23, #0x60]\n" + ".inst 0x4e95a4ca // smmla v10.4s, v6.16b, v21.16b\n" + "ldr q6, [x23, #0x70]\n" + "shl v21.16b, v27.16b, #0x4\n" + "and v27.16b, v27.16b, v15.16b\n" "add x23, x23, #0x80\n" - ".inst 0x4e90a708 // smmla v8.4s, v24.16b, v16.16b\n" - ".inst 0x4e90a531 // smmla v17.4s, v9.16b, v16.16b\n" - ".inst 0x4e90a417 // smmla v23.4s, v0.16b, v16.16b\n" - ".inst 0x4e90a75b // smmla v27.4s, v26.16b, v16.16b\n" - "ldr q16, [x21, #0x60]\n" - ".inst 0x4e83a7c8 // smmla v8.4s, v30.16b, v3.16b\n" - ".inst 0x4e83a4f1 // smmla v17.4s, v7.16b, v3.16b\n" - ".inst 0x4e83a797 // smmla v23.4s, v28.16b, v3.16b\n" - ".inst 0x4e83a5db // smmla v27.4s, v14.16b, v3.16b\n" - "ldr q3, [x21, #0x70]\n" + ".inst 0x4e95a623 // smmla v3.4s, v17.16b, v21.16b\n" + ".inst 0x4e95a662 // smmla v2.4s, v19.16b, v21.16b\n" + ".inst 0x4e95a75e // smmla v30.4s, v26.16b, v21.16b\n" + ".inst 0x4e95a524 // smmla v4.4s, v9.16b, v21.16b\n" + "ldr q21, [x21, #0x60]\n" + ".inst 0x4e90a723 // smmla v3.4s, v25.16b, v16.16b\n" + ".inst 0x4e90a702 // smmla v2.4s, v24.16b, v16.16b\n" + ".inst 0x4e90a6fe // smmla v30.4s, v23.16b, v16.16b\n" + ".inst 0x4e90a784 // smmla v4.4s, v28.16b, v16.16b\n" + "ldr q16, [x21, #0x70]\n" "add x21, x21, #0x80\n" - ".inst 0x4e9da688 // smmla v8.4s, v20.16b, v29.16b\n" - ".inst 0x4e9da651 // smmla v17.4s, v18.16b, v29.16b\n" - ".inst 0x4e9da617 // smmla v23.4s, v16.16b, v29.16b\n" - ".inst 0x4e9da47b // smmla 
v27.4s, v3.16b, v29.16b\n" - "shl v29.16b, v25.16b, #0x4\n" - "and v25.16b, v25.16b, v10.16b\n" - ".inst 0x4e9da70b // smmla v11.4s, v24.16b, v29.16b\n" - ".inst 0x4e9da524 // smmla v4.4s, v9.16b, v29.16b\n" - ".inst 0x4e9da40f // smmla v15.4s, v0.16b, v29.16b\n" - ".inst 0x4e9da745 // smmla v5.4s, v26.16b, v29.16b\n" - ".inst 0x4e93a7cb // smmla v11.4s, v30.16b, v19.16b\n" - ".inst 0x4e93a4e4 // smmla v4.4s, v7.16b, v19.16b\n" - ".inst 0x4e93a78f // smmla v15.4s, v28.16b, v19.16b\n" - ".inst 0x4e93a5c5 // smmla v5.4s, v14.16b, v19.16b\n" - ".inst 0x4e99a68b // smmla v11.4s, v20.16b, v25.16b\n" - ".inst 0x4e99a644 // smmla v4.4s, v18.16b, v25.16b\n" - ".inst 0x4e99a60f // smmla v15.4s, v16.16b, v25.16b\n" - ".inst 0x4e99a465 // smmla v5.4s, v3.16b, v25.16b\n" + ".inst 0x4e9ba643 // smmla v3.4s, v18.16b, v27.16b\n" + ".inst 0x4e9ba4c2 // smmla v2.4s, v6.16b, v27.16b\n" + ".inst 0x4e9ba6be // smmla v30.4s, v21.16b, v27.16b\n" + ".inst 0x4e9ba604 // smmla v4.4s, v16.16b, v27.16b\n" + "shl v27.16b, v11.16b, #0x4\n" + "and v11.16b, v11.16b, v15.16b\n" + ".inst 0x4e9ba620 // smmla v0.4s, v17.16b, v27.16b\n" + ".inst 0x4e9ba661 // smmla v1.4s, v19.16b, v27.16b\n" + ".inst 0x4e9ba748 // smmla v8.4s, v26.16b, v27.16b\n" + ".inst 0x4e9ba52a // smmla v10.4s, v9.16b, v27.16b\n" + ".inst 0x4e9da720 // smmla v0.4s, v25.16b, v29.16b\n" + ".inst 0x4e9da701 // smmla v1.4s, v24.16b, v29.16b\n" + ".inst 0x4e9da6e8 // smmla v8.4s, v23.16b, v29.16b\n" + ".inst 0x4e9da78a // smmla v10.4s, v28.16b, v29.16b\n" + ".inst 0x4e8ba640 // smmla v0.4s, v18.16b, v11.16b\n" + ".inst 0x4e8ba4c1 // smmla v1.4s, v6.16b, v11.16b\n" + ".inst 0x4e8ba6a8 // smmla v8.4s, v21.16b, v11.16b\n" + ".inst 0x4e8ba60a // smmla v10.4s, v16.16b, v11.16b\n" "bgt 4b\n" - "ldr q20, [x10, #0x0]\n" - "uzp1 v29.2d, v8.2d, v11.2d\n" - "uzp2 v18.2d, v8.2d, v11.2d\n" + "ldr q11, [x10, #0x0]\n" + "uzp1 v19.2d, v3.2d, v0.2d\n" + "uzp2 v18.2d, v3.2d, v0.2d\n" "add x10, x10, #0x10\n" - "uzp1 v30.2d, v17.2d, v4.2d\n" - "uzp2 v16.2d, v17.2d, v4.2d\n" - "scvtf v29.4s, v29.4s\n" + "uzp1 v17.2d, v2.2d, v1.2d\n" + "uzp2 v16.2d, v2.2d, v1.2d\n" + "scvtf v19.4s, v19.4s\n" "scvtf v18.4s, v18.4s\n" - "scvtf v30.4s, v30.4s\n" + "scvtf v17.4s, v17.4s\n" "scvtf v16.4s, v16.4s\n" - "fmla v6.4s, v29.4s, v20.4s\n" - "fmla v31.4s, v18.4s, v20.4s\n" - "fmla v1.4s, v30.4s, v20.4s\n" - "fmla v2.4s, v16.4s, v20.4s\n" - "uzp1 v4.2d, v23.2d, v15.2d\n" - "uzp2 v18.2d, v23.2d, v15.2d\n" - "uzp1 v15.2d, v27.2d, v5.2d\n" - "uzp2 v16.2d, v27.2d, v5.2d\n" - "scvtf v4.4s, v4.4s\n" + "fmla v22.4s, v19.4s, v11.4s\n" + "fmla v12.4s, v18.4s, v11.4s\n" + "fmla v13.4s, v17.4s, v11.4s\n" + "fmla v20.4s, v16.4s, v11.4s\n" + "uzp1 v19.2d, v30.2d, v8.2d\n" + "uzp2 v18.2d, v30.2d, v8.2d\n" + "uzp1 v17.2d, v4.2d, v10.2d\n" + "uzp2 v16.2d, v4.2d, v10.2d\n" + "scvtf v19.4s, v19.4s\n" "scvtf v18.4s, v18.4s\n" - "scvtf v15.4s, v15.4s\n" + "scvtf v17.4s, v17.4s\n" "scvtf v16.4s, v16.4s\n" - "fmla v13.4s, v4.4s, v20.4s\n" - "fmla v21.4s, v18.4s, v20.4s\n" - "fmla v22.4s, v15.4s, v20.4s\n" - "fmla v12.4s, v16.4s, v20.4s\n" + "fmla v14.4s, v19.4s, v11.4s\n" + "fmla v5.4s, v18.4s, v11.4s\n" + "fmla v7.4s, v17.4s, v11.4s\n" + "fmla v31.4s, v16.4s, v11.4s\n" "subs x22, x22, #0x1\n" "bgt 3b\n" - "ld1 { v29.4s }, [x23]\n" - "ld1 { v4.4s }, [x21]\n" + "ld1 { v24.4s }, [x23]\n" + "ld1 { v23.4s }, [x21]\n" "add x23, x23, #0x10\n" "add x21, x21, #0x10\n" - "ldr q20, [x10, #0x0]\n" - "ldr q8, [x23, #0x0]\n" + "ldr q21, [x10, #0x0]\n" + "ldr q26, [x23, #0x0]\n" "add x20, %x[clamp_vals], #0x4\n" "cmp x9, #0x4\n" 
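As a reading aid for the sub-block loop above: each smmla instruction multiplies a 2x8 tile of int8 LHS values by a second 2x8 tile holding two RHS columns and accumulates the 2x2 int32 result into one vector, which is why the accumulators are de-interleaved with uzp1/uzp2 before the float epilogue. A scalar model of a single smmla step (an illustration of the FEAT_I8MM semantics, not kernel code):

    #include <stdint.h>

    // Reference model of SMMLA Vd.4S, Vn.16B, Vm.16B: a and b are the 2x8
    // int8 tiles held in Vn and Vm; acc is the 2x2 int32 tile held in Vd.
    static void smmla_model(int32_t acc[2][2], const int8_t a[2][8], const int8_t b[2][8]) {
        for (int r = 0; r < 2; ++r) {
            for (int c = 0; c < 2; ++c) {
                for (int i = 0; i < 8; ++i) {
                    acc[r][c] += (int32_t)a[r][i] * (int32_t)b[c][i];
                }
            }
        }
    }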
- "ldr q18, [x21, #0x0]\n" - "ld1r { v28.4s }, [%x[clamp_vals]]\n" - "add x10, x10, #0x10\n" + "ldr q19, [x21, #0x0]\n" + "ldr q18, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + "ld1r { v17.4s }, [%x[clamp_vals]]\n" "ld1r { v16.4s }, [x20]\n" - "scvtf v29.4s, v29.4s\n" - "scvtf v4.4s, v4.4s\n" - "fmla v6.4s, v20.4s, v29.s[0]\n" - "fmla v31.4s, v20.4s, v29.s[1]\n" - "fmla v1.4s, v20.4s, v29.s[2]\n" - "fmla v2.4s, v20.4s, v29.s[3]\n" - "fmla v13.4s, v20.4s, v4.s[0]\n" - "fmla v21.4s, v20.4s, v4.s[1]\n" - "fmla v22.4s, v20.4s, v4.s[2]\n" - "fmla v12.4s, v20.4s, v4.s[3]\n" - "fmul v6.4s, v6.4s, v8.s[0]\n" - "fmul v31.4s, v31.4s, v8.s[1]\n" - "fmul v1.4s, v1.4s, v8.s[2]\n" - "fmul v2.4s, v2.4s, v8.s[3]\n" - "fmul v13.4s, v13.4s, v18.s[0]\n" - "fmul v21.4s, v21.4s, v18.s[1]\n" - "fmul v22.4s, v22.4s, v18.s[2]\n" - "fmul v12.4s, v12.4s, v18.s[3]\n" - "fmax v6.4s, v6.4s, v28.4s\n" - "fmax v31.4s, v31.4s, v28.4s\n" - "fmax v1.4s, v1.4s, v28.4s\n" - "fmax v2.4s, v2.4s, v28.4s\n" - "fmax v13.4s, v13.4s, v28.4s\n" - "fmax v21.4s, v21.4s, v28.4s\n" - "fmax v22.4s, v22.4s, v28.4s\n" - "fmax v12.4s, v12.4s, v28.4s\n" - "fmin v6.4s, v6.4s, v16.4s\n" - "fmin v31.4s, v31.4s, v16.4s\n" - "fmin v1.4s, v1.4s, v16.4s\n" - "fmin v2.4s, v2.4s, v16.4s\n" - "fmin v13.4s, v13.4s, v16.4s\n" - "fmin v21.4s, v21.4s, v16.4s\n" + "scvtf v24.4s, v24.4s\n" + "scvtf v23.4s, v23.4s\n" + "fmla v22.4s, v21.4s, v24.s[0]\n" + "fmla v12.4s, v21.4s, v24.s[1]\n" + "fmla v13.4s, v21.4s, v24.s[2]\n" + "fmla v20.4s, v21.4s, v24.s[3]\n" + "fmla v14.4s, v21.4s, v23.s[0]\n" + "fmla v5.4s, v21.4s, v23.s[1]\n" + "fmla v7.4s, v21.4s, v23.s[2]\n" + "fmla v31.4s, v21.4s, v23.s[3]\n" + "fmul v22.4s, v22.4s, v26.s[0]\n" + "fmul v12.4s, v12.4s, v26.s[1]\n" + "fmul v13.4s, v13.4s, v26.s[2]\n" + "fmul v20.4s, v20.4s, v26.s[3]\n" + "fmul v14.4s, v14.4s, v19.s[0]\n" + "fmul v5.4s, v5.4s, v19.s[1]\n" + "fadd v22.4s, v22.4s, v18.4s\n" + "fmul v7.4s, v7.4s, v19.s[2]\n" + "fmul v31.4s, v31.4s, v19.s[3]\n" + "fadd v12.4s, v12.4s, v18.4s\n" + "fadd v13.4s, v13.4s, v18.4s\n" + "fadd v20.4s, v20.4s, v18.4s\n" + "fadd v14.4s, v14.4s, v18.4s\n" + "fadd v5.4s, v5.4s, v18.4s\n" + "fadd v7.4s, v7.4s, v18.4s\n" + "fadd v31.4s, v31.4s, v18.4s\n" + "fmax v22.4s, v22.4s, v17.4s\n" + "fmax v12.4s, v12.4s, v17.4s\n" + "fmax v13.4s, v13.4s, v17.4s\n" + "fmax v20.4s, v20.4s, v17.4s\n" + "fmax v14.4s, v14.4s, v17.4s\n" + "fmax v5.4s, v5.4s, v17.4s\n" + "fmax v7.4s, v7.4s, v17.4s\n" + "fmax v31.4s, v31.4s, v17.4s\n" "fmin v22.4s, v22.4s, v16.4s\n" "fmin v12.4s, v12.4s, v16.4s\n" + "fmin v13.4s, v13.4s, v16.4s\n" + "fmin v20.4s, v20.4s, v16.4s\n" + "fmin v14.4s, v14.4s, v16.4s\n" + "fmin v5.4s, v5.4s, v16.4s\n" + "fmin v7.4s, v7.4s, v16.4s\n" + "fmin v31.4s, v31.4s, v16.4s\n" "blt 7f\n" "mov x20, %x[dst]\n" - "str q6, [x20, #0x0]\n" + "str q22, [x20, #0x0]\n" "add x20, x20, %x[dst_stride_row]\n" - "str q31, [x20, #0x0]\n" + "str q12, [x20, #0x0]\n" "add x20, x20, %x[dst_stride_row]\n" - "str q1, [x20, #0x0]\n" + "str q13, [x20, #0x0]\n" "add x20, x20, %x[dst_stride_row]\n" - "str q2, [x20, #0x0]\n" + "str q20, [x20, #0x0]\n" "add x20, x20, %x[dst_stride_row]\n" - "str q13, [x20, #0x0]\n" + "str q14, [x20, #0x0]\n" "add x20, x20, %x[dst_stride_row]\n" - "str q21, [x20, #0x0]\n" + "str q5, [x20, #0x0]\n" "add x20, x20, %x[dst_stride_row]\n" - "str q22, [x20, #0x0]\n" + "str q7, [x20, #0x0]\n" "add x20, x20, %x[dst_stride_row]\n" - "str q12, [x20, #0x0]\n" + "str q31, [x20, #0x0]\n" "b 10f\n" "7:" // Partial output "mov x27, %x[dst]\n" @@ -326,196 +337,201 @@ void 
kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( "add x21, x27, %x[dst_stride_row]\n" "add x20, x22, %x[dst_stride_row]\n" "tbz x9, #1, 8f\n" - "st1 { v12.d }[0], [x23], #0x8\n" - "st1 { v22.d }[0], [x25], #0x8\n" - "st1 { v21.d }[0], [x24], #0x8\n" - "st1 { v13.d }[0], [x26], #0x8\n" - "st1 { v2.d }[0], [x20], #0x8\n" - "st1 { v1.d }[0], [x22], #0x8\n" - "st1 { v31.d }[0], [x21], #0x8\n" - "st1 { v6.d }[0], [x27], #0x8\n" + "st1 { v31.d }[0], [x23], #0x8\n" + "st1 { v7.d }[0], [x25], #0x8\n" + "st1 { v5.d }[0], [x24], #0x8\n" + "st1 { v14.d }[0], [x26], #0x8\n" + "st1 { v20.d }[0], [x20], #0x8\n" + "st1 { v13.d }[0], [x22], #0x8\n" + "st1 { v12.d }[0], [x21], #0x8\n" + "st1 { v22.d }[0], [x27], #0x8\n" "tbz x9, #0, 9f\n" - "st1 { v12.s }[2], [x23]\n" - "st1 { v22.s }[2], [x25]\n" - "st1 { v21.s }[2], [x24]\n" - "st1 { v13.s }[2], [x26]\n" - "st1 { v2.s }[2], [x20]\n" - "st1 { v1.s }[2], [x22]\n" - "st1 { v31.s }[2], [x21]\n" - "st1 { v6.s }[2], [x27]\n" + "st1 { v31.s }[2], [x23]\n" + "st1 { v7.s }[2], [x25]\n" + "st1 { v5.s }[2], [x24]\n" + "st1 { v14.s }[2], [x26]\n" + "st1 { v20.s }[2], [x20]\n" + "st1 { v13.s }[2], [x22]\n" + "st1 { v12.s }[2], [x21]\n" + "st1 { v22.s }[2], [x27]\n" "b 9f\n" "8:" // Output block 0: partial_1_0 - "st1 { v12.s }[0], [x23]\n" - "st1 { v22.s }[0], [x25]\n" - "st1 { v21.s }[0], [x24]\n" - "st1 { v13.s }[0], [x26]\n" - "st1 { v2.s }[0], [x20]\n" - "st1 { v1.s }[0], [x22]\n" - "st1 { v31.s }[0], [x21]\n" - "st1 { v6.s }[0], [x27]\n" + "st1 { v31.s }[0], [x23]\n" + "st1 { v7.s }[0], [x25]\n" + "st1 { v5.s }[0], [x24]\n" + "st1 { v14.s }[0], [x26]\n" + "st1 { v20.s }[0], [x20]\n" + "st1 { v13.s }[0], [x22]\n" + "st1 { v12.s }[0], [x21]\n" + "st1 { v22.s }[0], [x27]\n" "9:" // Output block 0: Done "10:" // Output stage exit "subs x9, x9, #0x4\n" "add %x[dst], %x[dst], #0x10\n" "bgt 2b\n" "mov x20, #0x2\n" - "sub x12, x12, #0x8\n" - "cmp x12, #0x8\n" + "sub x11, x11, #0x8\n" + "cmp x11, #0x8\n" "mov %x[dst], x28\n" - "madd %x[lhs_packed], x20, x11, %x[lhs_packed]\n" + "madd %x[lhs_packed], x20, x12, %x[lhs_packed]\n" "bge 1b\n" "11:" // Row loop skip - "cbz x12, 21f\n" + "cbz x11, 21f\n" "12:" // Row tail: Row loop "mov x26, %x[rhs_packed]\n" "mov x25, %x[n]\n" "add x24, %x[dst], %x[dst_stride_row], LSL #2\n" "13:" // Row tail: Column loop - "movi v6.16b, #0x0\n" - "movi v31.16b, #0x0\n" + "movi v22.16b, #0x0\n" + "movi v12.16b, #0x0\n" "mov x23, %x[lhs_packed]\n" "mov x21, %x[num_blocks]\n" - "movi v1.16b, #0x0\n" - "movi v2.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "movi v20.16b, #0x0\n" "14:" // Row tail: Block loop - "movi v8.4s, #0x0\n" - "movi v11.4s, #0x0\n" + "movi v3.4s, #0x0\n" + "movi v0.4s, #0x0\n" "mov x20, %x[num_subblocks]\n" - "movi v17.4s, #0x0\n" - "movi v4.4s, #0x0\n" + "movi v2.4s, #0x0\n" + "movi v1.4s, #0x0\n" "15:" // Row tail: Sub block loop - "ldr q13, [x26, #0x0]\n" - "ldr q5, [x26, #0x10]\n" + "ldr q31, [x26, #0x0]\n" + "ldr q8, [x26, #0x10]\n" "subs x20, x20, #0x1\n" - "ldr q15, [x23, #0x0]\n" - "ldr q30, [x23, #0x10]\n" - "ldr q29, [x26, #0x20]\n" - "ldr q18, [x26, #0x30]\n" + "ldr q30, [x23, #0x0]\n" + "ldr q29, [x23, #0x10]\n" + "ldr q6, [x26, #0x20]\n" + "ldr q27, [x26, #0x30]\n" "add x26, x26, #0x40\n" - "ldr q27, [x23, #0x20]\n" - "ldr q26, [x23, #0x30]\n" - "shl v25.16b, v13.16b, #0x4\n" - "shl v24.16b, v5.16b, #0x4\n" - "ldr q28, [x23, #0x40]\n" - "ldr q20, [x23, #0x50]\n" - "and v13.16b, v13.16b, v10.16b\n" - "and v5.16b, v5.16b, v10.16b\n" - "ldr q0, [x23, #0x60]\n" - "ldr q12, [x23, #0x70]\n" - "shl v9.16b, 
v29.16b, #0x4\n" - "shl v16.16b, v18.16b, #0x4\n" - ".inst 0x4e99a5e8 // smmla v8.4s, v15.16b, v25.16b\n" - ".inst 0x4e98a5eb // smmla v11.4s, v15.16b, v24.16b\n" - "and v29.16b, v29.16b, v10.16b\n" + "ldr q26, [x23, #0x20]\n" + "ldr q25, [x23, #0x30]\n" + "shl v24.16b, v31.16b, #0x4\n" + "shl v23.16b, v8.16b, #0x4\n" + "ldr q9, [x23, #0x40]\n" + "ldr q11, [x23, #0x50]\n" + "and v31.16b, v31.16b, v15.16b\n" + "and v8.16b, v8.16b, v15.16b\n" + "ldr q19, [x23, #0x60]\n" + "ldr q18, [x23, #0x70]\n" + "shl v17.16b, v6.16b, #0x4\n" + "shl v16.16b, v27.16b, #0x4\n" + ".inst 0x4e98a7c3 // smmla v3.4s, v30.16b, v24.16b\n" + ".inst 0x4e97a7c0 // smmla v0.4s, v30.16b, v23.16b\n" + "and v6.16b, v6.16b, v15.16b\n" "add x23, x23, #0x80\n" - ".inst 0x4e99a7d1 // smmla v17.4s, v30.16b, v25.16b\n" - ".inst 0x4e98a7c4 // smmla v4.4s, v30.16b, v24.16b\n" - "and v18.16b, v18.16b, v10.16b\n" - ".inst 0x4e89a768 // smmla v8.4s, v27.16b, v9.16b\n" - ".inst 0x4e90a76b // smmla v11.4s, v27.16b, v16.16b\n" - ".inst 0x4e89a751 // smmla v17.4s, v26.16b, v9.16b\n" - ".inst 0x4e90a744 // smmla v4.4s, v26.16b, v16.16b\n" - ".inst 0x4e8da788 // smmla v8.4s, v28.16b, v13.16b\n" - ".inst 0x4e85a78b // smmla v11.4s, v28.16b, v5.16b\n" - ".inst 0x4e8da691 // smmla v17.4s, v20.16b, v13.16b\n" - ".inst 0x4e85a684 // smmla v4.4s, v20.16b, v5.16b\n" - ".inst 0x4e9da408 // smmla v8.4s, v0.16b, v29.16b\n" - ".inst 0x4e92a40b // smmla v11.4s, v0.16b, v18.16b\n" - ".inst 0x4e9da591 // smmla v17.4s, v12.16b, v29.16b\n" - ".inst 0x4e92a584 // smmla v4.4s, v12.16b, v18.16b\n" + ".inst 0x4e98a7a2 // smmla v2.4s, v29.16b, v24.16b\n" + ".inst 0x4e97a7a1 // smmla v1.4s, v29.16b, v23.16b\n" + "and v27.16b, v27.16b, v15.16b\n" + ".inst 0x4e91a743 // smmla v3.4s, v26.16b, v17.16b\n" + ".inst 0x4e90a740 // smmla v0.4s, v26.16b, v16.16b\n" + ".inst 0x4e91a722 // smmla v2.4s, v25.16b, v17.16b\n" + ".inst 0x4e90a721 // smmla v1.4s, v25.16b, v16.16b\n" + ".inst 0x4e9fa523 // smmla v3.4s, v9.16b, v31.16b\n" + ".inst 0x4e88a520 // smmla v0.4s, v9.16b, v8.16b\n" + ".inst 0x4e9fa562 // smmla v2.4s, v11.16b, v31.16b\n" + ".inst 0x4e88a561 // smmla v1.4s, v11.16b, v8.16b\n" + ".inst 0x4e86a663 // smmla v3.4s, v19.16b, v6.16b\n" + ".inst 0x4e9ba660 // smmla v0.4s, v19.16b, v27.16b\n" + ".inst 0x4e86a642 // smmla v2.4s, v18.16b, v6.16b\n" + ".inst 0x4e9ba641 // smmla v1.4s, v18.16b, v27.16b\n" "bgt 15b\n" - "ldr q20, [x26, #0x0]\n" - "uzp1 v14.2d, v8.2d, v11.2d\n" - "uzp2 v18.2d, v8.2d, v11.2d\n" + "ldr q11, [x26, #0x0]\n" + "uzp1 v19.2d, v3.2d, v0.2d\n" + "uzp2 v18.2d, v3.2d, v0.2d\n" "add x26, x26, #0x10\n" - "uzp1 v8.2d, v17.2d, v4.2d\n" - "uzp2 v16.2d, v17.2d, v4.2d\n" - "scvtf v14.4s, v14.4s\n" + "uzp1 v17.2d, v2.2d, v1.2d\n" + "uzp2 v16.2d, v2.2d, v1.2d\n" + "scvtf v19.4s, v19.4s\n" "scvtf v18.4s, v18.4s\n" - "scvtf v8.4s, v8.4s\n" + "scvtf v17.4s, v17.4s\n" "scvtf v16.4s, v16.4s\n" - "fmla v6.4s, v14.4s, v20.4s\n" - "fmla v31.4s, v18.4s, v20.4s\n" - "fmla v1.4s, v8.4s, v20.4s\n" - "fmla v2.4s, v16.4s, v20.4s\n" + "fmla v22.4s, v19.4s, v11.4s\n" + "fmla v12.4s, v18.4s, v11.4s\n" + "fmla v13.4s, v17.4s, v11.4s\n" + "fmla v20.4s, v16.4s, v11.4s\n" "subs x21, x21, #0x1\n" "bgt 14b\n" - "ld1 { v20.4s }, [x23]\n" - "ldr q11, [x26, #0x0]\n" + "ld1 { v21.4s }, [x23]\n" + "ldr q1, [x26, #0x0]\n" "add x23, x23, #0x10\n" "add x20, %x[clamp_vals], #0x4\n" - "ldr q18, [x23, #0x0]\n" - "ld1r { v8.4s }, [%x[clamp_vals]]\n" + "ldr q19, [x23, #0x0]\n" + "ldr q18, [x26, #0x10]\n" "cmp x25, #0x4\n" - "add x26, x26, #0x10\n" + "add x26, x26, #0x20\n" + "ld1r { v17.4s 
}, [%x[clamp_vals]]\n" "ld1r { v16.4s }, [x20]\n" - "scvtf v20.4s, v20.4s\n" - "fmla v6.4s, v11.4s, v20.s[0]\n" - "fmla v31.4s, v11.4s, v20.s[1]\n" - "fmla v1.4s, v11.4s, v20.s[2]\n" - "fmla v2.4s, v11.4s, v20.s[3]\n" - "fmul v6.4s, v6.4s, v18.s[0]\n" - "fmul v31.4s, v31.4s, v18.s[1]\n" - "fmax v6.4s, v6.4s, v8.4s\n" - "fmul v1.4s, v1.4s, v18.s[2]\n" - "fmul v2.4s, v2.4s, v18.s[3]\n" - "fmax v31.4s, v31.4s, v8.4s\n" - "fmin v6.4s, v6.4s, v16.4s\n" - "fmax v1.4s, v1.4s, v8.4s\n" - "fmax v2.4s, v2.4s, v8.4s\n" - "fmin v31.4s, v31.4s, v16.4s\n" - "fmin v1.4s, v1.4s, v16.4s\n" - "fmin v2.4s, v2.4s, v16.4s\n" + "scvtf v21.4s, v21.4s\n" + "fmla v22.4s, v1.4s, v21.s[0]\n" + "fmla v12.4s, v1.4s, v21.s[1]\n" + "fmla v13.4s, v1.4s, v21.s[2]\n" + "fmla v20.4s, v1.4s, v21.s[3]\n" + "fmul v22.4s, v22.4s, v19.s[0]\n" + "fmul v12.4s, v12.4s, v19.s[1]\n" + "fmul v13.4s, v13.4s, v19.s[2]\n" + "fadd v22.4s, v22.4s, v18.4s\n" + "fmul v20.4s, v20.4s, v19.s[3]\n" + "fadd v12.4s, v12.4s, v18.4s\n" + "fadd v13.4s, v13.4s, v18.4s\n" + "fadd v20.4s, v20.4s, v18.4s\n" + "fmax v22.4s, v22.4s, v17.4s\n" + "fmax v12.4s, v12.4s, v17.4s\n" + "fmax v13.4s, v13.4s, v17.4s\n" + "fmax v20.4s, v20.4s, v17.4s\n" + "fmin v22.4s, v22.4s, v16.4s\n" + "fmin v12.4s, v12.4s, v16.4s\n" + "fmin v13.4s, v13.4s, v16.4s\n" + "fmin v20.4s, v20.4s, v16.4s\n" "blt 17f\n" "mov x20, %x[dst]\n" - "cmp x12, #0x1\n" - "str q6, [x20, #0x0]\n" + "cmp x11, #0x1\n" + "str q22, [x20, #0x0]\n" "add x20, x20, %x[dst_stride_row]\n" "ble 20f\n" - "cmp x12, #0x2\n" - "str q31, [x20, #0x0]\n" + "cmp x11, #0x2\n" + "str q12, [x20, #0x0]\n" "add x20, x20, %x[dst_stride_row]\n" "ble 20f\n" - "cmp x12, #0x3\n" - "str q1, [x20, #0x0]\n" + "cmp x11, #0x3\n" + "str q13, [x20, #0x0]\n" "add x20, x20, %x[dst_stride_row]\n" "ble 20f\n" - "str q2, [x20, #0x0]\n" + "str q20, [x20, #0x0]\n" "b 20f\n" "17:" // Row tail: Partial output "mov x23, %x[dst]\n" - "cmp x12, #0x1\n" + "cmp x11, #0x1\n" "add x22, x23, %x[dst_stride_row]\n" "csel x22, x22, x23, GE\n" - "cmp x12, #0x2\n" + "cmp x11, #0x2\n" "add x21, x23, %x[dst_stride_row], LSL #1\n" "csel x21, x21, x22, GE\n" - "cmp x12, #0x3\n" + "cmp x11, #0x3\n" "add x20, x21, %x[dst_stride_row]\n" "csel x20, x20, x21, GE\n" "tbz x25, #1, 18f\n" - "st1 { v2.d }[0], [x20], #0x8\n" - "st1 { v1.d }[0], [x21], #0x8\n" - "st1 { v31.d }[0], [x22], #0x8\n" - "st1 { v6.d }[0], [x23], #0x8\n" + "st1 { v20.d }[0], [x20], #0x8\n" + "st1 { v13.d }[0], [x21], #0x8\n" + "st1 { v12.d }[0], [x22], #0x8\n" + "st1 { v22.d }[0], [x23], #0x8\n" "tbz x25, #0, 19f\n" - "st1 { v2.s }[2], [x20]\n" - "st1 { v1.s }[2], [x21]\n" - "st1 { v31.s }[2], [x22]\n" - "st1 { v6.s }[2], [x23]\n" + "st1 { v20.s }[2], [x20]\n" + "st1 { v13.s }[2], [x21]\n" + "st1 { v12.s }[2], [x22]\n" + "st1 { v22.s }[2], [x23]\n" "b 19f\n" "18:" // Row tail: Output block 0: partial_1_0 - "st1 { v2.s }[0], [x20]\n" - "st1 { v1.s }[0], [x21]\n" - "st1 { v31.s }[0], [x22]\n" - "st1 { v6.s }[0], [x23]\n" + "st1 { v20.s }[0], [x20]\n" + "st1 { v13.s }[0], [x21]\n" + "st1 { v12.s }[0], [x22]\n" + "st1 { v22.s }[0], [x23]\n" "19:" // Row tail: Output block 0: Done "20:" // Row tail: Output stage exit "subs x25, x25, #0x4\n" "add %x[dst], %x[dst], #0x10\n" "bgt 13b\n" - "subs x12, x12, #0x4\n" - "add %x[lhs_packed], %x[lhs_packed], x11\n" + "subs x11, x11, #0x4\n" + "add %x[lhs_packed], %x[lhs_packed], x12\n" "mov %x[dst], x24\n" "bgt 12b\n" "21:" // Row tail: Row loop skip diff --git 
a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c index 7ba6efa9..c0070035 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c @@ -24,6 +24,7 @@ static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); static const size_t kai_num_bytes_sum_rhs = sizeof(float); +static const size_t kai_num_bytes_bias = sizeof(float); inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) { KAI_ASSERT((k % 2) == 0); @@ -47,7 +48,7 @@ inline static size_t kai_rhs_packed_stride(size_t k, size_t bl) { const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); const size_t num_bytes_per_block = (bl / 2) + kai_num_bytes_multiplier_rhs; - return kai_nr * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs); + return kai_nr * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs + kai_num_bytes_bias); } size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(void) { @@ -127,8 +128,9 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm( __asm__ __volatile__( "mov x28, #0x80\n" "mov x20, #0x20\n" - "movi v12.16b, #0xf0\n" + "movi v14.16b, #0xf0\n" "mov x27, %x[m]\n" + "mul x28, %x[num_subblocks], x28\n" "madd x28, %x[num_blocks], x28, x20\n" "cbz x27, 12f\n" "1:" // Row loop @@ -136,186 +138,196 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm( "mov x25, %x[n]\n" "add x24, %x[dst], %x[dst_stride_row], LSL #2\n" "2:" // Column loop - "movi v16.16b, #0x0\n" - "movi v5.16b, #0x0\n" + "movi v27.16b, #0x0\n" + "movi v12.16b, #0x0\n" "mov x22, %x[lhs_packed]\n" "mov x21, %x[num_blocks]\n" - "movi v13.16b, #0x0\n" "movi v11.16b, #0x0\n" "movi v15.16b, #0x0\n" - "movi v27.16b, #0x0\n" - "movi v10.16b, #0x0\n" - "movi v14.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "movi v5.16b, #0x0\n" + "movi v8.16b, #0x0\n" + "movi v1.16b, #0x0\n" "3:" // Block loop - "movi v8.4s, #0x0\n" - "movi v3.4s, #0x0\n" + "movi v7.4s, #0x0\n" + "movi v10.4s, #0x0\n" "mov x20, %x[num_subblocks]\n" - "movi v31.4s, #0x0\n" + "movi v3.4s, #0x0\n" "movi v0.4s, #0x0\n" - "movi v30.4s, #0x0\n" + "movi v2.4s, #0x0\n" + "movi v25.4s, #0x0\n" "movi v4.4s, #0x0\n" - "movi v29.4s, #0x0\n" - "movi v1.4s, #0x0\n" + "movi v28.4s, #0x0\n" "4:" // Sub block loop - "ldr q7, [x26, #0x0]\n" - "ldr q6, [x26, #0x10]\n" + "ldr q21, [x26, #0x0]\n" + "ldr q30, [x26, #0x10]\n" "subs x20, x20, #0x1\n" "ldr q22, [x26, #0x20]\n" - "ldr q26, [x26, #0x30]\n" + "ldr q29, [x26, #0x30]\n" "ldr q20, [x22, #0x0]\n" - "ldr q18, [x22, #0x10]\n" - "ldr q23, [x26, #0x40]\n" - "ldr q17, [x26, #0x50]\n" - "shl v28.16b, v7.16b, #0x4\n" - "shl v2.16b, v6.16b, #0x4\n" - "ldr q25, [x26, #0x60]\n" + "ldr q23, [x22, #0x10]\n" + "ldr q19, [x26, #0x40]\n" + "ldr q9, [x26, #0x50]\n" + "shl v26.16b, v21.16b, #0x4\n" + "shl v6.16b, v30.16b, #0x4\n" + "ldr q18, [x26, #0x60]\n" "ldr q24, [x26, #0x70]\n" - "shl v19.16b, v22.16b, #0x4\n" - "shl v9.16b, v26.16b, #0x4\n" - "ldr q21, [x22, #0x20]\n" - "and v7.16b, v7.16b, v12.16b\n" - "and v6.16b, v6.16b, v12.16b\n" + "shl v31.16b, 
v22.16b, #0x4\n" + "shl v16.16b, v29.16b, #0x4\n" + "ldr q17, [x22, #0x20]\n" + "and v21.16b, v21.16b, v14.16b\n" + "and v30.16b, v30.16b, v14.16b\n" "add x26, x26, #0x80\n" - ".inst 0x4e9ca688 // smmla v8.4s, v20.16b, v28.16b\n" - ".inst 0x4e82a69f // smmla v31.4s, v20.16b, v2.16b\n" - "and v22.16b, v22.16b, v12.16b\n" - ".inst 0x4e93a683 // smmla v3.4s, v20.16b, v19.16b\n" - ".inst 0x4e89a680 // smmla v0.4s, v20.16b, v9.16b\n" + ".inst 0x4e9aa687 // smmla v7.4s, v20.16b, v26.16b\n" + ".inst 0x4e86a683 // smmla v3.4s, v20.16b, v6.16b\n" + "and v22.16b, v22.16b, v14.16b\n" + ".inst 0x4e9fa68a // smmla v10.4s, v20.16b, v31.16b\n" + ".inst 0x4e90a680 // smmla v0.4s, v20.16b, v16.16b\n" "ldr q20, [x22, #0x30]\n" - "and v26.16b, v26.16b, v12.16b\n" - ".inst 0x4e9ca65e // smmla v30.4s, v18.16b, v28.16b\n" - "ldr q28, [x22, #0x40]\n" - ".inst 0x4e82a65d // smmla v29.4s, v18.16b, v2.16b\n" - "ldr q2, [x22, #0x50]\n" - ".inst 0x4e93a644 // smmla v4.4s, v18.16b, v19.16b\n" - "ldr q19, [x22, #0x60]\n" - ".inst 0x4e89a641 // smmla v1.4s, v18.16b, v9.16b\n" - "ldr q18, [x22, #0x70]\n" - "shl v9.16b, v23.16b, #0x4\n" - "and v23.16b, v23.16b, v12.16b\n" + "and v29.16b, v29.16b, v14.16b\n" + ".inst 0x4e9aa6e2 // smmla v2.4s, v23.16b, v26.16b\n" + "ldr q26, [x22, #0x40]\n" + ".inst 0x4e86a6e4 // smmla v4.4s, v23.16b, v6.16b\n" + "ldr q6, [x22, #0x50]\n" + ".inst 0x4e9fa6f9 // smmla v25.4s, v23.16b, v31.16b\n" + "ldr q31, [x22, #0x60]\n" + ".inst 0x4e90a6fc // smmla v28.4s, v23.16b, v16.16b\n" + "ldr q23, [x22, #0x70]\n" + "shl v16.16b, v19.16b, #0x4\n" + "and v19.16b, v19.16b, v14.16b\n" "add x22, x22, #0x80\n" - ".inst 0x4e89a6a8 // smmla v8.4s, v21.16b, v9.16b\n" - ".inst 0x4e89a69e // smmla v30.4s, v20.16b, v9.16b\n" - "shl v9.16b, v17.16b, #0x4\n" - "and v17.16b, v17.16b, v12.16b\n" - ".inst 0x4e89a6bf // smmla v31.4s, v21.16b, v9.16b\n" - ".inst 0x4e89a69d // smmla v29.4s, v20.16b, v9.16b\n" - "shl v9.16b, v25.16b, #0x4\n" - "and v25.16b, v25.16b, v12.16b\n" - ".inst 0x4e87a788 // smmla v8.4s, v28.16b, v7.16b\n" - ".inst 0x4e87a45e // smmla v30.4s, v2.16b, v7.16b\n" - "shl v7.16b, v24.16b, #0x4\n" - "and v24.16b, v24.16b, v12.16b\n" - ".inst 0x4e89a6a3 // smmla v3.4s, v21.16b, v9.16b\n" - ".inst 0x4e89a684 // smmla v4.4s, v20.16b, v9.16b\n" - ".inst 0x4e86a79f // smmla v31.4s, v28.16b, v6.16b\n" - ".inst 0x4e86a45d // smmla v29.4s, v2.16b, v6.16b\n" - ".inst 0x4e87a6a0 // smmla v0.4s, v21.16b, v7.16b\n" - ".inst 0x4e87a681 // smmla v1.4s, v20.16b, v7.16b\n" - ".inst 0x4e97a668 // smmla v8.4s, v19.16b, v23.16b\n" - ".inst 0x4e97a65e // smmla v30.4s, v18.16b, v23.16b\n" - ".inst 0x4e96a783 // smmla v3.4s, v28.16b, v22.16b\n" - ".inst 0x4e96a444 // smmla v4.4s, v2.16b, v22.16b\n" - ".inst 0x4e91a67f // smmla v31.4s, v19.16b, v17.16b\n" - ".inst 0x4e91a65d // smmla v29.4s, v18.16b, v17.16b\n" - ".inst 0x4e9aa780 // smmla v0.4s, v28.16b, v26.16b\n" - ".inst 0x4e9aa441 // smmla v1.4s, v2.16b, v26.16b\n" - ".inst 0x4e99a663 // smmla v3.4s, v19.16b, v25.16b\n" - ".inst 0x4e99a644 // smmla v4.4s, v18.16b, v25.16b\n" - ".inst 0x4e98a660 // smmla v0.4s, v19.16b, v24.16b\n" - ".inst 0x4e98a641 // smmla v1.4s, v18.16b, v24.16b\n" + ".inst 0x4e90a627 // smmla v7.4s, v17.16b, v16.16b\n" + ".inst 0x4e90a682 // smmla v2.4s, v20.16b, v16.16b\n" + "shl v16.16b, v9.16b, #0x4\n" + "and v9.16b, v9.16b, v14.16b\n" + ".inst 0x4e90a623 // smmla v3.4s, v17.16b, v16.16b\n" + ".inst 0x4e90a684 // smmla v4.4s, v20.16b, v16.16b\n" + "shl v16.16b, v18.16b, #0x4\n" + "and v18.16b, v18.16b, v14.16b\n" + ".inst 0x4e95a747 // smmla 
v7.4s, v26.16b, v21.16b\n" + ".inst 0x4e95a4c2 // smmla v2.4s, v6.16b, v21.16b\n" + "shl v21.16b, v24.16b, #0x4\n" + "and v24.16b, v24.16b, v14.16b\n" + ".inst 0x4e90a62a // smmla v10.4s, v17.16b, v16.16b\n" + ".inst 0x4e90a699 // smmla v25.4s, v20.16b, v16.16b\n" + ".inst 0x4e9ea743 // smmla v3.4s, v26.16b, v30.16b\n" + ".inst 0x4e9ea4c4 // smmla v4.4s, v6.16b, v30.16b\n" + ".inst 0x4e95a620 // smmla v0.4s, v17.16b, v21.16b\n" + ".inst 0x4e95a69c // smmla v28.4s, v20.16b, v21.16b\n" + ".inst 0x4e93a7e7 // smmla v7.4s, v31.16b, v19.16b\n" + ".inst 0x4e93a6e2 // smmla v2.4s, v23.16b, v19.16b\n" + ".inst 0x4e96a74a // smmla v10.4s, v26.16b, v22.16b\n" + ".inst 0x4e96a4d9 // smmla v25.4s, v6.16b, v22.16b\n" + ".inst 0x4e89a7e3 // smmla v3.4s, v31.16b, v9.16b\n" + ".inst 0x4e89a6e4 // smmla v4.4s, v23.16b, v9.16b\n" + ".inst 0x4e9da740 // smmla v0.4s, v26.16b, v29.16b\n" + ".inst 0x4e9da4dc // smmla v28.4s, v6.16b, v29.16b\n" + ".inst 0x4e92a7ea // smmla v10.4s, v31.16b, v18.16b\n" + ".inst 0x4e92a6f9 // smmla v25.4s, v23.16b, v18.16b\n" + ".inst 0x4e98a7e0 // smmla v0.4s, v31.16b, v24.16b\n" + ".inst 0x4e98a6fc // smmla v28.4s, v23.16b, v24.16b\n" "bgt 4b\n" - "ldr q6, [x26, #0x0]\n" - "ldr q24, [x26, #0x10]\n" - "uzp1 v22.2d, v8.2d, v31.2d\n" - "uzp2 v28.2d, v8.2d, v31.2d\n" - "uzp1 v21.2d, v3.2d, v0.2d\n" - "uzp2 v20.2d, v3.2d, v0.2d\n" + "ldr q31, [x26, #0x0]\n" + "ldr q20, [x26, #0x10]\n" + "uzp1 v21.2d, v7.2d, v3.2d\n" + "uzp2 v23.2d, v7.2d, v3.2d\n" + "uzp1 v22.2d, v10.2d, v0.2d\n" + "uzp2 v30.2d, v10.2d, v0.2d\n" "add x26, x26, #0x20\n" - "uzp1 v19.2d, v30.2d, v29.2d\n" - "uzp2 v18.2d, v30.2d, v29.2d\n" - "uzp1 v17.2d, v4.2d, v1.2d\n" - "uzp2 v0.2d, v4.2d, v1.2d\n" - "scvtf v22.4s, v22.4s\n" + "uzp1 v7.2d, v2.2d, v4.2d\n" + "uzp2 v18.2d, v2.2d, v4.2d\n" + "uzp1 v17.2d, v25.2d, v28.2d\n" + "uzp2 v16.2d, v25.2d, v28.2d\n" "scvtf v21.4s, v21.4s\n" - "scvtf v28.4s, v28.4s\n" - "scvtf v20.4s, v20.4s\n" - "scvtf v19.4s, v19.4s\n" + "scvtf v22.4s, v22.4s\n" + "scvtf v23.4s, v23.4s\n" + "scvtf v30.4s, v30.4s\n" + "scvtf v7.4s, v7.4s\n" "scvtf v17.4s, v17.4s\n" "scvtf v18.4s, v18.4s\n" - "scvtf v0.4s, v0.4s\n" - "fmla v16.4s, v22.4s, v6.4s\n" - "fmla v5.4s, v21.4s, v24.4s\n" - "fmla v13.4s, v28.4s, v6.4s\n" - "fmla v11.4s, v20.4s, v24.4s\n" - "fmla v15.4s, v19.4s, v6.4s\n" - "fmla v27.4s, v17.4s, v24.4s\n" - "fmla v10.4s, v18.4s, v6.4s\n" - "fmla v14.4s, v0.4s, v24.4s\n" + "scvtf v16.4s, v16.4s\n" + "fmla v27.4s, v21.4s, v31.4s\n" + "fmla v12.4s, v22.4s, v20.4s\n" + "fmla v11.4s, v23.4s, v31.4s\n" + "fmla v15.4s, v30.4s, v20.4s\n" + "fmla v13.4s, v7.4s, v31.4s\n" + "fmla v5.4s, v17.4s, v20.4s\n" + "fmla v8.4s, v18.4s, v31.4s\n" + "fmla v1.4s, v16.4s, v20.4s\n" "subs x21, x21, #0x1\n" "bgt 3b\n" - "ld1 { v21.4s }, [x22]\n" - "ldr q20, [x26, #0x0]\n" + "ld1 { v23.4s }, [x22]\n" + "ldr q22, [x26, #0x0]\n" "add x22, x22, #0x10\n" "add x20, %x[clamp_vals], #0x4\n" - "ldr q19, [x26, #0x10]\n" - "ldr q18, [x22, #0x0]\n" + "ldr q2, [x26, #0x10]\n" + "ldr q20, [x22, #0x0]\n" "cmp x25, #0x8\n" - "add x26, x26, #0x20\n" + "ldr q19, [x26, #0x20]\n" + "ldr q18, [x26, #0x30]\n" + "add x26, x26, #0x40\n" "ld1r { v17.4s }, [%x[clamp_vals]]\n" - "ld1r { v31.4s }, [x20]\n" - "scvtf v21.4s, v21.4s\n" - "fmla v16.4s, v20.4s, v21.s[0]\n" - "fmla v5.4s, v19.4s, v21.s[0]\n" - "fmla v13.4s, v20.4s, v21.s[1]\n" - "fmla v11.4s, v19.4s, v21.s[1]\n" - "fmla v15.4s, v20.4s, v21.s[2]\n" - "fmla v27.4s, v19.4s, v21.s[2]\n" - "fmla v10.4s, v20.4s, v21.s[3]\n" - "fmla v14.4s, v19.4s, v21.s[3]\n" - "fmul v16.4s, v16.4s, 
v18.s[0]\n" - "fmul v5.4s, v5.4s, v18.s[0]\n" - "fmul v13.4s, v13.4s, v18.s[1]\n" - "fmul v11.4s, v11.4s, v18.s[1]\n" - "fmul v15.4s, v15.4s, v18.s[2]\n" - "fmul v27.4s, v27.4s, v18.s[2]\n" - "fmul v10.4s, v10.4s, v18.s[3]\n" - "fmul v14.4s, v14.4s, v18.s[3]\n" - "fmax v16.4s, v16.4s, v17.4s\n" - "fmax v5.4s, v5.4s, v17.4s\n" - "fmax v13.4s, v13.4s, v17.4s\n" + "ld1r { v16.4s }, [x20]\n" + "scvtf v23.4s, v23.4s\n" + "fmla v27.4s, v22.4s, v23.s[0]\n" + "fmla v12.4s, v2.4s, v23.s[0]\n" + "fmla v11.4s, v22.4s, v23.s[1]\n" + "fmla v15.4s, v2.4s, v23.s[1]\n" + "fmla v13.4s, v22.4s, v23.s[2]\n" + "fmla v5.4s, v2.4s, v23.s[2]\n" + "fmla v8.4s, v22.4s, v23.s[3]\n" + "fmla v1.4s, v2.4s, v23.s[3]\n" + "fmul v27.4s, v27.4s, v20.s[0]\n" + "fmul v12.4s, v12.4s, v20.s[0]\n" + "fmul v11.4s, v11.4s, v20.s[1]\n" + "fmul v15.4s, v15.4s, v20.s[1]\n" + "fmul v13.4s, v13.4s, v20.s[2]\n" + "fmul v5.4s, v5.4s, v20.s[2]\n" + "fmul v8.4s, v8.4s, v20.s[3]\n" + "fmul v1.4s, v1.4s, v20.s[3]\n" + "fadd v27.4s, v27.4s, v19.4s\n" + "fadd v12.4s, v12.4s, v18.4s\n" + "fadd v11.4s, v11.4s, v19.4s\n" + "fadd v15.4s, v15.4s, v18.4s\n" + "fadd v13.4s, v13.4s, v19.4s\n" + "fadd v5.4s, v5.4s, v18.4s\n" + "fadd v8.4s, v8.4s, v19.4s\n" + "fadd v1.4s, v1.4s, v18.4s\n" + "fmax v27.4s, v27.4s, v17.4s\n" + "fmax v12.4s, v12.4s, v17.4s\n" "fmax v11.4s, v11.4s, v17.4s\n" "fmax v15.4s, v15.4s, v17.4s\n" - "fmax v27.4s, v27.4s, v17.4s\n" - "fmax v10.4s, v10.4s, v17.4s\n" - "fmax v14.4s, v14.4s, v17.4s\n" - "fmin v16.4s, v16.4s, v31.4s\n" - "fmin v5.4s, v5.4s, v31.4s\n" - "fmin v13.4s, v13.4s, v31.4s\n" - "fmin v11.4s, v11.4s, v31.4s\n" - "fmin v15.4s, v15.4s, v31.4s\n" - "fmin v27.4s, v27.4s, v31.4s\n" - "fmin v10.4s, v10.4s, v31.4s\n" - "fmin v14.4s, v14.4s, v31.4s\n" + "fmax v13.4s, v13.4s, v17.4s\n" + "fmax v5.4s, v5.4s, v17.4s\n" + "fmax v8.4s, v8.4s, v17.4s\n" + "fmax v1.4s, v1.4s, v17.4s\n" + "fmin v27.4s, v27.4s, v16.4s\n" + "fmin v12.4s, v12.4s, v16.4s\n" + "fmin v11.4s, v11.4s, v16.4s\n" + "fmin v15.4s, v15.4s, v16.4s\n" + "fmin v13.4s, v13.4s, v16.4s\n" + "fmin v5.4s, v5.4s, v16.4s\n" + "fmin v8.4s, v8.4s, v16.4s\n" + "fmin v1.4s, v1.4s, v16.4s\n" "blt 6f\n" "mov x20, %x[dst]\n" "cmp x27, #0x1\n" - "str q16, [x20, #0x0]\n" - "str q5, [x20, #0x10]\n" + "str q27, [x20, #0x0]\n" + "str q12, [x20, #0x10]\n" "add x20, x20, %x[dst_stride_row]\n" "ble 11f\n" "cmp x27, #0x2\n" - "str q13, [x20, #0x0]\n" - "str q11, [x20, #0x10]\n" + "str q11, [x20, #0x0]\n" + "str q15, [x20, #0x10]\n" "add x20, x20, %x[dst_stride_row]\n" "ble 11f\n" "cmp x27, #0x3\n" - "str q15, [x20, #0x0]\n" - "str q27, [x20, #0x10]\n" + "str q13, [x20, #0x0]\n" + "str q5, [x20, #0x10]\n" "add x20, x20, %x[dst_stride_row]\n" "ble 11f\n" - "str q10, [x20, #0x0]\n" - "str q14, [x20, #0x10]\n" + "str q8, [x20, #0x0]\n" + "str q1, [x20, #0x10]\n" "b 11f\n" "6:" // Partial output "mov x23, %x[dst]\n" @@ -329,45 +341,45 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm( "add x20, x21, %x[dst_stride_row]\n" "csel x20, x20, x21, GE\n" "tbz x25, #2, 8f\n" - "st1 { v10.4s }, [x20], #0x10\n" - "st1 { v15.4s }, [x21], #0x10\n" - "st1 { v13.4s }, [x22], #0x10\n" - "st1 { v16.4s }, [x23], #0x10\n" + "st1 { v8.4s }, [x20], #0x10\n" + "st1 { v13.4s }, [x21], #0x10\n" + "st1 { v11.4s }, [x22], #0x10\n" + "st1 { v27.4s }, [x23], #0x10\n" "tbz x25, #1, 7f\n" - "st1 { v14.d }[0], [x20], #0x8\n" - "st1 { v27.d }[0], [x21], #0x8\n" - "st1 { v11.d }[0], [x22], #0x8\n" - "st1 { v5.d }[0], [x23], #0x8\n" + "st1 { v1.d }[0], [x20], #0x8\n" + "st1 { v5.d }[0], [x21], 
#0x8\n"
+        "st1 { v15.d }[0], [x22], #0x8\n"
+        "st1 { v12.d }[0], [x23], #0x8\n"
         "tbz x25, #0, 10f\n"
-        "st1 { v14.s }[2], [x20]\n"
-        "st1 { v27.s }[2], [x21]\n"
-        "st1 { v11.s }[2], [x22]\n"
-        "st1 { v5.s }[2], [x23]\n"
+        "st1 { v1.s }[2], [x20]\n"
+        "st1 { v5.s }[2], [x21]\n"
+        "st1 { v15.s }[2], [x22]\n"
+        "st1 { v12.s }[2], [x23]\n"
         "b 10f\n"
         "7:"  // Output block 0: partial_1_4
         "tbz x25, #0, 10f\n"
-        "st1 { v14.s }[0], [x20]\n"
-        "st1 { v27.s }[0], [x21]\n"
-        "st1 { v11.s }[0], [x22]\n"
-        "st1 { v5.s }[0], [x23]\n"
+        "st1 { v1.s }[0], [x20]\n"
+        "st1 { v5.s }[0], [x21]\n"
+        "st1 { v15.s }[0], [x22]\n"
+        "st1 { v12.s }[0], [x23]\n"
         "b 10f\n"
         "8:"  // Output block 0: partial_2_0
         "tbz x25, #1, 9f\n"
-        "st1 { v10.d }[0], [x20], #0x8\n"
-        "st1 { v15.d }[0], [x21], #0x8\n"
-        "st1 { v13.d }[0], [x22], #0x8\n"
-        "st1 { v16.d }[0], [x23], #0x8\n"
+        "st1 { v8.d }[0], [x20], #0x8\n"
+        "st1 { v13.d }[0], [x21], #0x8\n"
+        "st1 { v11.d }[0], [x22], #0x8\n"
+        "st1 { v27.d }[0], [x23], #0x8\n"
         "tbz x25, #0, 10f\n"
-        "st1 { v10.s }[2], [x20]\n"
-        "st1 { v15.s }[2], [x21]\n"
-        "st1 { v13.s }[2], [x22]\n"
-        "st1 { v16.s }[2], [x23]\n"
+        "st1 { v8.s }[2], [x20]\n"
+        "st1 { v13.s }[2], [x21]\n"
+        "st1 { v11.s }[2], [x22]\n"
+        "st1 { v27.s }[2], [x23]\n"
         "b 10f\n"
         "9:"  // Output block 0: partial_1_0
-        "st1 { v10.s }[0], [x20]\n"
-        "st1 { v15.s }[0], [x21]\n"
-        "st1 { v13.s }[0], [x22]\n"
-        "st1 { v16.s }[0], [x23]\n"
+        "st1 { v8.s }[0], [x20]\n"
+        "st1 { v13.s }[0], [x21]\n"
+        "st1 { v11.s }[0], [x22]\n"
+        "st1 { v27.s }[0], [x23]\n"
         "10:"  // Output block 0: Done
         "11:"  // Output stage exit
         "subs x25, x25, #0x8\n"
diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c
index 7d7cc4d0..9cf2ef99 100644
--- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c
+++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c
@@ -12,6 +12,7 @@
 #include "kai/kai_common.h"
 
 static const size_t kai_num_bytes_sum_rhs = sizeof(float);
+static const size_t kai_num_bytes_bias = sizeof(float);
 
 inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) {
     KAI_ASSERT((k % 2) == 0);
@@ -41,7 +42,7 @@ inline static size_t kai_rhs_packed_stride(size_t k, size_t nr, size_t kr, size_
     const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl);
     const size_t num_bytes_per_block = kai_num_bytes_per_block(bl, num_bytes_multiplier_rhs);
 
-    return nr * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs);
+    return nr * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs + kai_num_bytes_bias);
 }
 
 inline static size_t kai_rhs_packed_offset_end_of_all_blocks(
@@ -186,5 +187,19 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
             dst_row += (num_bytes_multiplier_rhs * nr);
         }
+
+        // Skip the row sum
+        dst_row += (kai_num_bytes_sum_rhs * nr);
+
+        // Set the bias, one value per output row handled by this nr-wide group
+        if (bias == NULL) {
+            memset(dst_row, 0, nr * kai_num_bytes_bias);
+        } else {
+            for (size_t i = 0; i < nr; ++i) {
+                ((float*)dst_row)[i] = bias[y + i];
+            }
+        }
+
+        dst_row += (kai_num_bytes_bias * nr);
     }
 }
-- 
GitLab


From 712bfd51b23cef4a87e1115dc1a3d38d49ed46d0 Mon Sep 17 00:00:00 2001
From: Gian Marco Iodice 
Date: Tue, 18 Jun 2024 14:06:00 +0100
Subject: [PATCH 04/29] Add bias support in all int4 matmul micro-kernels

Signed-off-by: Gian Marco Iodice 
---
 .../matmul_clamp_f32_qai8dxp_qsi4c32p.cpp     |  4 +++-
 ...i8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c | 18 +++++++-------
 ...i8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c | 24 
+++++++++++-------- 3 files changed, 27 insertions(+), 19 deletions(-) diff --git a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp index aea682a7..0d6f451b 100644 --- a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp +++ b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp @@ -320,7 +320,7 @@ int main(int argc, char** argv) { const size_t m = 8; const size_t n = 8; const size_t k = 256; - const size_t bl = 32; + const size_t bl = 64; const size_t num_blocks_per_row = k / bl; const size_t num_byte_per_block = bl / 2 + sizeof(float); const size_t seed_lhs = 4568; @@ -394,6 +394,8 @@ int main(int argc, char** argv) { uint8_t* rhs_packed_mtx_qs4cx = new uint8_t[rhs_packed_size]; uint8_t* dst_act_mtx_f32 = new uint8_t[dst_size]; + memset(dst_act_mtx_f32, 0, dst_size); + // If the RHS matrix contains constant values, the packing can be performed // only once struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0_params params; diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c index 038644e8..7bf3b70e 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c @@ -175,18 +175,20 @@ void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod( "scvtf v29.4s, v29.4s\n" "fmla v30.4s, v29.4s, v16.4s\n" "cbnz x21, 3b\n" - "ld1r { v20.4s }, [x22]\n" - "ldr q19, [x25, #0x0]\n" + "ld1r { v21.4s }, [x22]\n" + "ldr q20, [x25, #0x0]\n" "add x22, x22, #0x4\n" "add x20, %x[clamp_vals], #0x4\n" - "ld1r { v18.4s }, [x22]\n" - "ld1r { v17.4s }, [%x[clamp_vals]]\n" + "ld1r { v19.4s }, [x22]\n" + "ldr q18, [x25, #0x10]\n" "cmp x24, #0x4\n" - "add x25, x25, #0x10\n" + "add x25, x25, #0x20\n" + "ld1r { v17.4s }, [%x[clamp_vals]]\n" "ld1r { v16.4s }, [x20]\n" - "scvtf v20.4s, v20.4s\n" - "fmla v30.4s, v19.4s, v20.s[0]\n" - "fmul v30.4s, v30.4s, v18.4s\n" + "scvtf v21.4s, v21.4s\n" + "fmla v30.4s, v20.4s, v21.s[0]\n" + "fmul v30.4s, v30.4s, v19.4s\n" + "fadd v30.4s, v30.4s, v18.4s\n" "fmax v30.4s, v30.4s, v17.4s\n" "fmin v30.4s, v30.4s, v16.4s\n" "blt 5f\n" diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c index c0364c66..6171ea49 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c @@ -206,22 +206,26 @@ void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod( "fmla v6.4s, v4.4s, v17.4s\n" "fmla v5.4s, v2.4s, v16.4s\n" "cbnz x21, 3b\n" - "ld1r { v21.4s }, [x22]\n" - "ldr q20, [x25, #0x0]\n" + "ld1r { v23.4s }, [x22]\n" + "ldr q22, [x25, #0x0]\n" "add x22, x22, #0x4\n" "add x20, %x[clamp_vals], #0x4\n" - "ldr q19, [x25, #0x10]\n" - "ld1r { v18.4s }, [x22]\n" + "ldr q21, [x25, #0x10]\n" + "ld1r { v20.4s }, [x22]\n" "cmp x24, 
#0x8\n"
-        "add x25, x25, #0x20\n"
+        "ldr q19, [x25, #0x20]\n"
+        "ldr q18, [x25, #0x30]\n"
+        "add x25, x25, #0x40\n"
         "ld1r { v17.4s }, [%x[clamp_vals]]\n"
         "ld1r { v16.4s }, [x20]\n"
-        "scvtf v21.4s, v21.4s\n"
-        "fmla v6.4s, v20.4s, v21.s[0]\n"
-        "fmla v5.4s, v19.4s, v21.s[0]\n"
-        "fmul v6.4s, v6.4s, v18.4s\n"
+        "scvtf v23.4s, v23.4s\n"
+        "fmla v6.4s, v22.4s, v23.s[0]\n"
+        "fmla v5.4s, v21.4s, v23.s[0]\n"
+        "fmul v6.4s, v6.4s, v20.4s\n"
+        "fadd v6.4s, v6.4s, v19.4s\n"
+        "fmul v5.4s, v5.4s, v20.4s\n"
+        "fadd v5.4s, v5.4s, v18.4s\n"
         "fmax v6.4s, v6.4s, v17.4s\n"
-        "fmul v5.4s, v5.4s, v18.4s\n"
         "fmax v5.4s, v5.4s, v17.4s\n"
         "fmin v6.4s, v6.4s, v16.4s\n"
         "fmin v5.4s, v5.4s, v16.4s\n"
-- 
GitLab

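Note: with the bias wired through all the micro-kernels above, every variant now finishes a row of output with the same epilogue: the int32 block accumulators are dequantized with the per-block RHS scales inside the loop, the LHS zero point is folded in through the pre-computed RHS row sums, the per-row LHS scale is applied, the bias is added, and the result is clamped. A minimal scalar model of that sequence follows; all names are illustrative and none of them are library symbols.

    #include <stdint.h>

    static inline float clamp_f32(float v, float v_min, float v_max) {
        return v < v_min ? v_min : (v > v_max ? v_max : v);
    }

    // acc_blocks:  sum over all blocks of (int32 dot product) x (RHS block scale)
    // rhs_row_sum: reduction sum of the RHS row, pre-multiplied by the block scales during packing
    // lhs_offset:  per-row zero point of the dynamically quantized LHS
    // lhs_scale:   per-row scale of the dynamically quantized LHS
    static float dst_epilogue(
        float acc_blocks, float rhs_row_sum, int32_t lhs_offset, float lhs_scale, float bias, float v_min,
        float v_max) {
        float acc = acc_blocks + (float)lhs_offset * rhs_row_sum;  // the "fmla" with the RHS reduction sums
        acc *= lhs_scale;                                          // the "fmul" with the per-row LHS scale
        acc += bias;                                               // the "fadd" introduced by this patch
        return clamp_f32(acc, v_min, v_max);                       // the "fmax"/"fmin" clamp pair
    }

Because the bias is stored once per packed RHS row, right after the reduction sums, the extra cost in the kernels is a single vector load and add per output tile.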
From df35405be987bd5404ad96e0ba8b5040dd412126 Mon Sep 17 00:00:00 2001
From: Gian Marco Iodice 
Date: Fri, 21 Jun 2024 10:52:52 +0100
Subject: [PATCH 05/29] Add profiler in the example

- Add support for measuring the ukernel performance
- Add missing stdint.h header file in kai_common.h

Signed-off-by: Gian Marco Iodice 
---
 .../matmul_clamp_f32_qai8dxp_qsi4c32p.cpp | 23 ++++++++++++++-----
 1 file changed, 17 insertions(+), 6 deletions(-)

diff --git a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp
index 0d6f451b..60acd5b4 100644
--- a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp
+++ b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp
@@ -8,6 +8,7 @@
 #else
 #include 
 #include 
+#include <chrono>
 #include 
 #include 
 #include 
@@ -317,8 +318,8 @@ static bool is_output_correct(size_t num_rows, size_t num_cols, float tolerance,
 }
 
 int main(int argc, char** argv) {
-    const size_t m = 8;
-    const size_t n = 8;
+    const size_t m = 37;
+    const size_t n = 1024;
     const size_t k = 256;
     const size_t bl = 64;
     const size_t num_blocks_per_row = k / bl;
@@ -326,6 +327,8 @@
     const size_t seed_lhs = 4568;
     const size_t seed_rhs = seed_lhs + 4;
 
+    std::cout << "------------" << std::endl;
+
     const size_t lhs_native_size_f32 = m * k * sizeof(float);
     const size_t rhs_native_size_f32 = n * k * sizeof(float);
     const size_t rhs_native_size_qs4c32 = n * num_blocks_per_row * num_byte_per_block;
@@ -375,8 +378,6 @@
     //------------------------------------
     //------------------------------------
     for (size_t idx_variant = 0; idx_variant < num_ukernel_variants; ++idx_variant) {
-        std::cout << "Testing " << ukernel_variants[idx_variant].name << std::endl;
-
         // Get the packing parameters
         const size_t mr = ukernel_variants[idx_variant].ukernel.get_mr();
         const size_t nr = ukernel_variants[idx_variant].ukernel.get_nr();
@@ -413,6 +414,8 @@
             rhs_packed_mtx_qs4cx,  // RHS packed
             0, &params);
 
+        const auto time_s = std::chrono::high_resolution_clock::now();
+
         // LHS packing
         kai_run_lhs_quant_pack_qai8dxp_f32(
             m, k,  // Dimensions
@@ -444,14 +447,22 @@
         );
     }
 
+        const auto time_e = std::chrono::high_resolution_clock::now();
+
+        const auto elap = std::chrono::duration_cast<std::chrono::microseconds>(time_e - time_s);
+
         const bool is_valid =
             is_output_correct(m, n, 0.0001f, (const float*)dst_ref_mtx_f32, (const float*)dst_act_mtx_f32);
 
+        std::cout << "TEST[" << idx_variant << "]: Dynamic quantization + matmul" << std::endl;
+        std::cout << "- ukernel: " << ukernel_variants[idx_variant].name << std::endl;
         if (is_valid) {
-            printf("TEST[%ld] = PASSED\n", idx_variant);
+            std::cout << "- Status: PASSED" << std::endl;
+            std::cout << "- Performance: " << elap.count() << " us" << std::endl;
         } else {
-            printf("TEST[%ld] = FAILED\n", idx_variant);
+            std::cout << "- Status: FAILED" << std::endl;
         }
+        std::cout << "------------" << std::endl;
 
         delete[] lhs_packed_mtx_qa8dx;
         delete[] rhs_packed_mtx_qs4cx;
         delete[] dst_act_mtx_f32;
-- 
GitLab

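Note: the profiler intentionally times the LHS quantization-and-packing together with the matmul, since with dynamic quantization the LHS must be re-packed on every call. The same measurement pattern in plain C, sketched here with POSIX clock_gettime; the timed region is a placeholder, not a fixed API:

    #include <time.h>

    static double elapsed_us(struct timespec t_s, struct timespec t_e) {
        return (double)(t_e.tv_sec - t_s.tv_sec) * 1e6 + (double)(t_e.tv_nsec - t_s.tv_nsec) * 1e-3;
    }

    // Mirrors the time_s/time_e pair above: one wall-clock measurement around
    // the whole packing + matmul region.
    static double profile_region(void (*region)(void* ctx), void* ctx) {
        struct timespec time_s;
        struct timespec time_e;
        clock_gettime(CLOCK_MONOTONIC, &time_s);
        region(ctx);  // e.g. kai_run_lhs_quant_pack_* followed by kai_run_matmul_*
        clock_gettime(CLOCK_MONOTONIC, &time_e);
        return elapsed_us(time_s, time_e);
    }

A single cold run, as in the example, includes warm-up effects; averaging several repetitions would give a steadier figure.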
From 2bd19d787eab3bb04b6e53da2a6fd5d2da8dd6ee Mon Sep 17 00:00:00 2001
From: Gian Marco Iodice 
Date: Fri, 9 Aug 2024 15:54:09 +0100
Subject: [PATCH 06/29] Refactoring code to support bf16 quantization scale parameters

Signed-off-by: Gian Marco Iodice 
---
 .../matmul_clamp_f32_qai8dxp_qsi4c32p.cpp     | 33 +++++++++---
 kai/kai_common.h                              |  8 +++
 .../kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c  | 54 +++++++++++--------
 .../kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.h  |  7 ++-
 4 files changed, 70 insertions(+), 32 deletions(-)

diff --git a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp
index 60acd5b4..f73d92ec 100644
--- a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp
+++ b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp
@@ -103,15 +103,19 @@ static void fill_uniform_random(size_t num_rows, size_t num_cols, float* dst, si
     }
 }
 
-static void quant_qs4c32_f32(size_t n, size_t k, size_t bl, const float* rhs_f32, uint8_t* rhs_qs4c32) {
+static void quant_qs4c32_f32(
+    size_t n, size_t k, size_t bl, const float* rhs_f32, uint8_t* rhs_qs4c32, uint8_t* rhs_with_no_scale_qs4c32,
+    float* rhs_scales_f32) {
     const size_t num_blocks_row = num_blocks_per_row(k, bl);
     const size_t num_bytes_block = num_bytes_per_block(bl);
     const size_t dst_stride = num_blocks_row * num_bytes_block;
+    const size_t dst_with_no_scale_stride = num_blocks_row * (bl / 2);
 
     for (size_t row_idx = 0; row_idx < n; ++row_idx) {
         const float* src_ptr = rhs_f32 + row_idx * k;
 
         uint8_t* dst_ptr = (uint8_t*)rhs_qs4c32 + row_idx * dst_stride;
+        uint8_t* dst_with_no_scale_ptr = (uint8_t*)rhs_with_no_scale_qs4c32 + row_idx * dst_with_no_scale_stride;
 
         for (size_t block_idx = 0; block_idx < num_blocks_row; ++block_idx) {
             float amax = 0.0f;
@@ -132,7 +136,10 @@ static void quant_qs4c32_f32(size_t n, size_t k, size_t bl, const float* rhs_f32
 
             // Store the scale at the beginning of the block
             *((float*)dst_ptr) = scale;
+            *rhs_scales_f32 = scale;
+
             dst_ptr += sizeof(float);
+            rhs_scales_f32 += 1;
 
             const size_t block_size = 32;
             const size_t num_subblocks = bl / 32;
@@ -153,6 +160,9 @@
 
                     dst_ptr[0] = rhs_v0;
                     dst_ptr += sizeof(uint8_t);
+
+                    dst_with_no_scale_ptr[0] = rhs_v0;
+                    dst_with_no_scale_ptr += sizeof(uint8_t);
                 }
             }
         }
@@ -332,16 +342,22 @@ int main(int argc, char** argv) {
     const size_t lhs_native_size_f32 = m * k * sizeof(float);
     const size_t rhs_native_size_f32 = n * k * sizeof(float);
     const size_t rhs_native_size_qs4c32 = n * num_blocks_per_row * num_byte_per_block;
+    const size_t rhs_native_with_no_scale_size_qs4c32 = n * num_blocks_per_row * (bl / 2);
+    const size_t rhs_scales_size_f32 = n * num_blocks_per_row * sizeof(float);
 
     // Allocate the memory
     uint8_t* lhs_native_mtx_f32 = new uint8_t[lhs_native_size_f32];
     uint8_t* rhs_native_mtx_f32 = new uint8_t[rhs_native_size_f32];
     uint8_t* rhs_native_mtx_qs4c32 = new uint8_t[rhs_native_size_qs4c32];
+    uint8_t* rhs_native_with_no_scale_mtx_qs4c32 = new uint8_t[rhs_native_with_no_scale_size_qs4c32];
+    uint8_t* rhs_scales_mtx_f32 = new uint8_t[rhs_scales_size_f32];
 
     fill_uniform_random(m, k, (float*)lhs_native_mtx_f32, seed_lhs);
     fill_uniform_random(n, k, (float*)rhs_native_mtx_f32, seed_rhs);
 
-    quant_qs4c32_f32(n, k, bl, (const float*)rhs_native_mtx_f32, (uint8_t*)rhs_native_mtx_qs4c32);
+    quant_qs4c32_f32(
+        n, k, bl, (const float*)rhs_native_mtx_f32, (uint8_t*)rhs_native_mtx_qs4c32,
+        rhs_native_with_no_scale_mtx_qs4c32, (float*)rhs_scales_mtx_f32);
 
     delete[] rhs_native_mtx_f32;
 
@@ -406,12 +422,13 @@ int main(int argc, char** argv) {
 
         // RHS packing
         kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
-            1, n, k,                                  // Dimensions
-            nr, kr, sr,                               // Packing arguments
-            bl,                                       // Block length
-            (const uint8_t*)(rhs_native_mtx_qs4c32),  // RHS
-            NULL,                                     // Bias
-            rhs_packed_mtx_qs4cx,                     // RHS packed
+            1, n, k,                                                // Dimensions
+            nr, kr, sr,                                             // Packing arguments
+            bl,                                                     // Block length
+            (const uint8_t*)(rhs_native_with_no_scale_mtx_qs4c32),  // RHS
+            NULL,                                                   // Bias
+            rhs_scales_mtx_f32,                                     // Scale
+            rhs_packed_mtx_qs4cx,                                   // RHS packed
             0, &params);
 
         const auto time_s = std::chrono::high_resolution_clock::now();
 
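Note: quant_qs4c32_f32 now produces three views of the RHS in one pass: blocks with inline f32 scales (for the reference path), scale-free blocks, and a flat per-block scale array for the new packing interface. One block of that quantization, modeled in scalar C; the example's exact scale rule is not reproduced above, so the max-abs mapping below is an assumption, while the nibble order follows the byte(s16, s0) layout documented for the packing micro-kernel:

    #include <math.h>
    #include <stddef.h>
    #include <stdint.h>

    // Quantize one 32-value block to unsigned 4-bit nibbles plus one f32 scale.
    static void quant_block_qs4c32(const float* src, uint8_t* dst_qs, float* dst_scale) {
        float amax = 0.0f;  // largest magnitude in the block
        for (size_t i = 0; i < 32; ++i) {
            amax = fmaxf(amax, fabsf(src[i]));
        }
        const float scale = amax / 7.0f;  // assumed mapping of the peak onto the int4 range
        const float recip = scale != 0.0f ? 1.0f / scale : 0.0f;
        for (size_t i = 0; i < 16; ++i) {
            // Zero-point 8 keeps the stored nibbles unsigned (the qsu4 encoding)
            int32_t v0 = (int32_t)lrintf(src[i] * recip) + 8;
            int32_t v1 = (int32_t)lrintf(src[i + 16] * recip) + 8;
            v0 = v0 < 0 ? 0 : (v0 > 15 ? 15 : v0);
            v1 = v1 < 0 ? 0 : (v1 > 15 ? 15 : v1);
            dst_qs[i] = (uint8_t)(v0 | (v1 << 4));  // byte(s16, s0): K+0 low nibble, K+16 high nibble
        }
        *dst_scale = scale;
    }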
diff --git a/kai/kai_common.h b/kai/kai_common.h
index 306cb962..58107ae1 100644
--- a/kai/kai_common.h
+++ b/kai/kai_common.h
@@ -86,6 +86,18 @@ inline static float kai_cast_f32_f16(uint16_t f16) {
 #endif
 }
 
+/// Converts a scalar bf16 value to f32
+/// @param[in] bf16 The bf16 value
+///
+/// @return the f32 value
+inline static float kai_bf16_to_f32(uint16_t bf16) {
+    // bf16 is the upper half of the f32 bit pattern: shift the bits, do not
+    // convert the integer value
+    union {
+        uint32_t i;
+        float f;
+    } u = {.i = (uint32_t)bf16 << 16};
+    return u.f;
+}
+
 /// Converts a scalar f32 value to f16
 /// @param[in] f32 The f32 value
 ///
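Note: bf16 keeps the sign bit, the 8 exponent bits, and the top 7 mantissa bits of an f32, so widening to f32 only repositions bits; no arithmetic conversion is involved. A small self-check, assuming kai/kai_common.h is included and using inputs that are exactly representable in bf16:

    #include <assert.h>
    #include <stdint.h>

    static void check_bf16_to_f32(void) {
        assert(kai_bf16_to_f32((uint16_t)0x3F80) == 1.0f);       // f32 bits 0x3F800000
        assert(kai_bf16_to_f32((uint16_t)0xBF80) == -1.0f);      // sign bit preserved
        assert(kai_bf16_to_f32((uint16_t)0x0000) == 0.0f);
        assert(kai_bf16_to_f32((uint16_t)0x4049) == 3.140625f);  // pi truncated to bf16
    }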
diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c
index 9cf2ef99..99366d64 100644
--- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c
+++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c
@@ -24,11 +24,13 @@ inline static size_t kai_num_bytes_per_block(size_t bl, size_t num_bytes_multipl
     return (bl / 2) + num_bytes_multiplier_rhs;
 }
 
-inline static size_t kai_rhs_stride(size_t k, size_t bl, size_t num_bytes_multiplier_rhs) {
+inline static size_t kai_rhs_stride(size_t k, size_t bl) {
     KAI_ASSERT((k % 2) == 0);
 
     const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl);
-    const size_t num_bytes_per_block = kai_num_bytes_per_block(bl, num_bytes_multiplier_rhs);
+    // The RHS matrix (not packed) does not include the scale factors.
+    // Therefore, the number of bytes per scale must be 0.
+    const size_t num_bytes_per_block = kai_num_bytes_per_block(bl, 0);
 
     return num_bytes_per_block * num_blocks_per_row;
 }
 
@@ -68,7 +70,7 @@ size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
     KAI_ASSERT((k % kr) == 0);
     KAI_ASSERT((k % bl) == 0);
     KAI_ASSERT((n_idx % nr) == 0);
-    KAI_ASSERT(scale_dt == F32 || scale_dt == F16);
+    KAI_ASSERT(scale_dt == F32 || scale_dt == F16 || scale_dt == Bf16);
 
     KAI_UNUSED(kr);
 
@@ -82,7 +84,7 @@ size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
     KAI_ASSERT((k % kr) == 0);
     KAI_ASSERT((k % bl) == 0);
     KAI_ASSERT((n % nr) == 0);
-    KAI_ASSERT(scale_dt == F32 || scale_dt == F16);
+    KAI_ASSERT(scale_dt == F32 || scale_dt == F16 || scale_dt == Bf16);
 
     KAI_UNUSED(kr);
 
@@ -94,7 +96,7 @@ size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
 
 void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
     size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
-    const float* bias, void* rhs_packed, size_t extra_bytes,
+    const float* bias, const void* scale, void* rhs_packed, size_t extra_bytes,
     const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0_params* params) {
     // Temporary asserts
     KAI_ASSERT(num_groups == 1);
@@ -112,39 +114,41 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
     KAI_ASSERT(params != NULL);
     KAI_ASSERT(params->rhs_zero_point == 8);
     KAI_ASSERT(params->lhs_zero_point == 1);
-    KAI_ASSERT(params->scale_dt == F32 || params->scale_dt == F16);
+    KAI_ASSERT(params->scale_dt == F32 || params->scale_dt == F16 || params->scale_dt == Bf16);
 
     // Note: The input matrix (rhs) is expected with:
     // "k" columns and "n" rows (NxK)
 
     const size_t num_bytes_multiplier_rhs = kai_num_bytes_datatype(params->scale_dt);
     const size_t rhs_zero_point = params->rhs_zero_point;
-    const size_t rhs_stride = kai_rhs_stride(k, bl, num_bytes_multiplier_rhs);
+    const size_t rhs_stride = kai_rhs_stride(k, bl);
     const size_t rhs_packed_stride = kai_rhs_packed_stride(k, nr, kr, bl, num_bytes_multiplier_rhs);
     const size_t rhs_packed_offset_end_of_all_blocks =
         kai_rhs_packed_offset_end_of_all_blocks(k, nr, kr, bl, num_bytes_multiplier_rhs);
     const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl);
     const size_t num_segments_per_block = bl / kr;
     const size_t num_bytes_per_segment = kr / 2;
-    const bool is_scale_f32 = params->scale_dt == F32;
+    const enum kai_datatype scale_dt = params->scale_dt;
 
     for (size_t y = 0; y < n; y += nr) {
         const uint8_t* src_row = (const uint8_t*)rhs + y * rhs_stride;
         uint8_t* dst_row = (uint8_t*)rhs_packed + (y / nr) * rhs_packed_stride;
 
-        float* sums = (float*)(dst_row + rhs_packed_offset_end_of_all_blocks);
+        float* dst_sums = (float*)(dst_row + rhs_packed_offset_end_of_all_blocks);
 
         // Initialize the RHS reduction sums to zero
-        memset(sums, 0, nr * kai_num_bytes_sum_rhs);
+        memset(dst_sums, 0, nr * kai_num_bytes_sum_rhs);
 
         for (size_t x = 0; x < num_blocks_per_row; ++x) {
             // Store the scales at the end of the block
-            uint8_t* scales = (dst_row + (bl / 2) * nr);
+            uint8_t* dst_scales = (dst_row + (bl / 2) * nr);
 
             for (size_t i = 0; i < nr; ++i) {
-                memcpy(scales + i * num_bytes_multiplier_rhs, src_row + i * rhs_stride, num_bytes_multiplier_rhs);
+                memcpy(
+                    dst_scales + i * num_bytes_multiplier_rhs,                              //
+                    scale + ((y + i) * num_blocks_per_row + x) * num_bytes_multiplier_rhs,  //
+                    num_bytes_multiplier_rhs);                                              //
             }
-            src_row += num_bytes_multiplier_rhs;
 
             // Store the segments
             for (size_t s = 0; s < num_segments_per_block; ++s) {
@@ -161,14 +165,18 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
                     dst_row[i * num_bytes_per_segment + b] = qs ^ 0x88;
 
                     float d = 0.0F;
-                    if (is_scale_f32) {
-                        d = ((float*)scales)[i];
+                    if (scale_dt == F32) {
+                        d = ((float*)dst_scales)[i];
+                    } else if (scale_dt == F16) {
+                        d = kai_f16_to_f32(((uint16_t*)dst_scales)[i]);
+                    } else if (scale_dt == Bf16) {
+                        d = kai_bf16_to_f32(((uint16_t*)dst_scales)[i]);
                     } else {
-                        d = kai_f16_to_f32(((uint16_t*)scales)[i]);
+                        KAI_ERROR("Unsupported scale data type");
                     }
 
-                    sums[i] += x0 * d;
-                    sums[i] += x1 * d;
+                    dst_sums[i] += x0 * d;
+                    dst_sums[i] += x1 * d;
                 }
             }
 
@@ -177,11 +185,13 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
 
         for (size_t i = 0; i < nr; ++i) {
-            if (is_scale_f32) {
-                ((float*)scales)[i] *= 0.0625F;
+            if (scale_dt == F32) {
+                ((float*)dst_scales)[i] *= 0.0625F;
+            } else if (scale_dt == F16) {
+                const float d = kai_f16_to_f32(((uint16_t*)dst_scales)[i]);
+                ((float*)dst_scales)[i] = kai_f32_to_f16(d * 0.0625F);
             } else {
-                const float d = kai_f16_to_f32(((uint16_t*)scales)[i]);
-                ((float*)scales)[i] = kai_f32_to_f16(d * 0.0625F);
+                KAI_ERROR("Unsupported scale data type");
             }
         }
 
diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.h b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.h
index 816862f5..6f9291f5 100644
--- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.h
+++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.h
@@ -27,12 +27,12 @@ struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0_params {
 /// the int4 values in a N x K matrix, where N is number of rows and K is the number of columns.
 ///
 /// Two int4 K values are stored in one byte. These values are stored in blocks, where each block
-/// has it own scale factor. The scale factor is expected to be a f32 value and stored at the end of each block.
+/// has its own scale factor. The quantization scale factors are stored in a separate buffer.
 /// The first byte in the block holds the K-index + 0 and K-index + 16 values.
 /// The K-index + 0 value is stored in the lower order part of the byte (low nibble) while
 /// the K-index + 16 value is stored in the higher order part (high nibble).
 /// For example, if the block length is 32, the values are stored in the following order:
-/// |byte(s16, s0),byte(s17, s1),byte(s18, s2),...,byte(s31, s15),float32(scale)|
+/// |byte(s16, s0),byte(s17, s1),byte(s18, s2),...,byte(s31, s15)|
 ///
 /// @param[in] n_idx Row index in the RHS matrix (not packed).
 /// @param[in] rhs_stride The number of bytes in each row of the RHS matrix (not packed)
@@ -97,6 +97,8 @@ size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
 /// @param[in] rhs The RHS matrix containing the 4-bit values.
 ///                Size in bytes is expected to be greater than or equal to n * k * (sizeof(uint8_t) / 2).
 /// @param[in] bias The biases.
+/// @param[in] scale The per-block quantization scales.
+///                  The scale data type must be provided with the params object.
 /// @param[out] rhs_packed The packed RHS matrix.
 /// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix.
 /// @param[in] params Parameters for the micro-kernel. 
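Note: a caller drives this entry point as in the example update above: scale-free quantized nibbles, a separate scale buffer whose element type matches params.scale_dt, and an optional bias. A minimal sketch, assuming the buffers were sized with the kai_get_*_size helpers (buffer names are the caller's own); the updated declaration follows below:

    struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0_params params;
    params.lhs_zero_point = 1;
    params.rhs_zero_point = 8;
    params.scale_dt = F32;  // F16 and Bf16 are accepted as well

    kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
        1, n, k,               // Dimensions
        nr, kr, sr,            // Packing arguments
        bl,                    // Block length
        rhs_qs4c32_no_scales,  // RHS (two int4 values per byte, no inline scales)
        NULL,                  // Bias (NULL packs a zero bias)
        rhs_scales,            // Per-block scales, element type = params.scale_dt
        rhs_packed,            // Packed RHS output
        0, &params);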
@@ -110,6 +112,7 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( size_t bl, // const uint8_t* rhs, // const float* bias, // + const void* scale, // void* rhs_packed, // size_t extra_bytes, // const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0_params* params); // -- GitLab From ffa423dde085f17e3af29f4b68ea2e0a2a26ea6f Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Tue, 13 Aug 2024 15:54:43 +0100 Subject: [PATCH 07/29] Moved div by 16 optimization from RHS packing to matmul kernel Signed-off-by: Anitha Raj --- ...i8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c | 17 +- ...i8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c | 28 +- ..._qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c | 502 +++++++++--------- ..._qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c | 342 ++++++------ .../kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c | 20 +- .../kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.h | 2 + 6 files changed, 458 insertions(+), 453 deletions(-) diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c index 7bf3b70e..34161786 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c @@ -123,10 +123,12 @@ void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod( __asm__ __volatile__( "mov x27, #0x20\n" + "mov x21, #0x3d800000\n" + "movi v0.16b, #0xf0\n" "mov x20, #0x8\n" - "movi v31.16b, #0xf0\n" "mov x26, %x[m]\n" "mul x27, %x[num_subblocks], x27\n" + "dup v31.4s, w21\n" "madd x27, %x[num_blocks], x27, x20\n" "1:" // Row loop "mov x25, %x[rhs_packed]\n" @@ -154,13 +156,13 @@ void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod( "shl v19.16b, v26.16b, #0x4\n" "ld1r { v18.2d }, [x22], #0x8\n" "shl v17.16b, v24.16b, #0x4\n" - "and v27.16b, v27.16b, v31.16b\n" + "and v27.16b, v27.16b, v0.16b\n" "shl v16.16b, v23.16b, #0x4\n" - "and v26.16b, v26.16b, v31.16b\n" + "and v26.16b, v26.16b, v0.16b\n" ".inst 0x4e99969d // sdot v29.4s, v20.16b, v25.16b\n" ".inst 0x4e99967c // sdot v28.4s, v19.16b, v25.16b\n" - "and v24.16b, v24.16b, v31.16b\n" - "and v23.16b, v23.16b, v31.16b\n" + "and v24.16b, v24.16b, v0.16b\n" + "and v23.16b, v23.16b, v0.16b\n" ".inst 0x4e96963d // sdot v29.4s, v17.16b, v22.16b\n" ".inst 0x4e96961c // sdot v28.4s, v16.16b, v22.16b\n" ".inst 0x4e95977d // sdot v29.4s, v27.16b, v21.16b\n" @@ -172,6 +174,7 @@ void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod( "addp v29.4s, v29.4s, v28.4s\n" "sub x21, x21, #0x1\n" "add x25, x25, #0x10\n" + "fmul v16.4s, v16.4s, v31.4s\n" "scvtf v29.4s, v29.4s\n" "fmla v30.4s, v29.4s, v16.4s\n" "cbnz x21, 3b\n" @@ -215,7 +218,7 @@ void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod( : [dst] "+&r"(dst), [lhs_packed] "+&r"(lhs_packed) : [clamp_vals] "r"(clamp_vals), [dst_stride_row] "r"(dst_stride_row), [m] "r"(m), [n] "r"(n), [num_blocks] "r"(num_blocks), [num_subblocks] "r"(num_subblocks), [rhs_packed] "r"(rhs_packed) - : "cc", "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", - "v29", "v30", "v31", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27"); + : "cc", "memory", "v0", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", + "v28", 
"v29", "v30", "v31", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27"); } #endif // Architectural feature check diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c index 6171ea49..92e44e02 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c @@ -127,10 +127,12 @@ void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod( __asm__ __volatile__( "mov x27, #0x20\n" + "mov x21, #0x3d800000\n" + "movi v8.16b, #0xf0\n" "mov x20, #0x8\n" - "movi v7.16b, #0xf0\n" "mov x26, %x[m]\n" "mul x27, %x[num_subblocks], x27\n" + "dup v7.4s, w21\n" "madd x27, %x[num_blocks], x27, x20\n" "1:" // Row loop "mov x25, %x[rhs_packed]\n" @@ -165,7 +167,7 @@ void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod( "add x25, x25, #0x80\n" "ld1r { v22.2d }, [x22], #0x8\n" "shl v21.16b, v27.16b, #0x4\n" - "and v0.16b, v0.16b, v7.16b\n" + "and v0.16b, v0.16b, v8.16b\n" "ld1r { v20.2d }, [x22], #0x8\n" "ld1r { v19.2d }, [x22], #0x8\n" ".inst 0x4e9c9704 // sdot v4.4s, v24.16b, v28.16b\n" @@ -175,17 +177,17 @@ void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod( ".inst 0x4e9c9601 // sdot v1.4s, v16.16b, v28.16b\n" "shl v17.16b, v25.16b, #0x4\n" "shl v16.16b, v23.16b, #0x4\n" - "and v31.16b, v31.16b, v7.16b\n" - "and v30.16b, v30.16b, v7.16b\n" - "and v29.16b, v29.16b, v7.16b\n" + "and v31.16b, v31.16b, v8.16b\n" + "and v30.16b, v30.16b, v8.16b\n" + "and v29.16b, v29.16b, v8.16b\n" ".inst 0x4e9696a4 // sdot v4.4s, v21.16b, v22.16b\n" ".inst 0x4e969643 // sdot v3.4s, v18.16b, v22.16b\n" - "and v27.16b, v27.16b, v7.16b\n" + "and v27.16b, v27.16b, v8.16b\n" ".inst 0x4e969622 // sdot v2.4s, v17.16b, v22.16b\n" ".inst 0x4e969601 // sdot v1.4s, v16.16b, v22.16b\n" - "and v26.16b, v26.16b, v7.16b\n" - "and v25.16b, v25.16b, v7.16b\n" - "and v23.16b, v23.16b, v7.16b\n" + "and v26.16b, v26.16b, v8.16b\n" + "and v25.16b, v25.16b, v8.16b\n" + "and v23.16b, v23.16b, v8.16b\n" ".inst 0x4e949404 // sdot v4.4s, v0.16b, v20.16b\n" ".inst 0x4e9497e3 // sdot v3.4s, v31.16b, v20.16b\n" ".inst 0x4e9497c2 // sdot v2.4s, v30.16b, v20.16b\n" @@ -201,6 +203,8 @@ void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod( "addp v2.4s, v2.4s, v1.4s\n" "sub x21, x21, #0x1\n" "add x25, x25, #0x20\n" + "fmul v17.4s, v17.4s, v7.4s\n" + "fmul v16.4s, v16.4s, v7.4s\n" "scvtf v4.4s, v4.4s\n" "scvtf v2.4s, v2.4s\n" "fmla v6.4s, v4.4s, v17.4s\n" @@ -266,8 +270,8 @@ void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod( : [dst] "+&r"(dst), [lhs_packed] "+&r"(lhs_packed) : [clamp_vals] "r"(clamp_vals), [dst_stride_row] "r"(dst_stride_row), [m] "r"(m), [n] "r"(n), [num_blocks] "r"(num_blocks), [num_subblocks] "r"(num_subblocks), [rhs_packed] "r"(rhs_packed) - : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v20", "v21", - "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23", "x24", - "x25", "x26", "x27"); + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v16", "v17", "v18", "v19", "v20", + "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", 
"v31", "x20", "x21", "x22", "x23", + "x24", "x25", "x26", "x27"); } #endif // Architectural feature check diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c index 437b90c2..76b4d1da 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c @@ -128,10 +128,12 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( __asm__ __volatile__( "mov x12, #0x80\n" "mov x11, %x[m]\n" - "movi v15.16b, #0xf0\n" + "movi v14.16b, #0xf0\n" + "mov x21, #0x3d800000\n" "mov x20, #0x20\n" "mul x12, %x[num_subblocks], x12\n" "cmp x11, #0x8\n" + "dup v20.4s, w21\n" "madd x12, %x[num_blocks], x12, x20\n" "blt 11f\n" "1:" // Row loop @@ -140,190 +142,191 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( "add x28, %x[dst], %x[dst_stride_row], LSL #3\n" "2:" // Column loop "mov x23, %x[lhs_packed]\n" - "movi v22.16b, #0x0\n" - "movi v12.16b, #0x0\n" - "mov x22, %x[num_blocks]\n" "movi v13.16b, #0x0\n" - "movi v20.16b, #0x0\n" - "movi v14.16b, #0x0\n" + "movi v23.16b, #0x0\n" + "mov x22, %x[num_blocks]\n" + "movi v29.16b, #0x0\n" + "movi v15.16b, #0x0\n" "movi v5.16b, #0x0\n" - "movi v7.16b, #0x0\n" + "movi v3.16b, #0x0\n" + "movi v9.16b, #0x0\n" "movi v31.16b, #0x0\n" "add x21, x23, x12\n" "3:" // Block loop - "movi v3.4s, #0x0\n" - "movi v0.4s, #0x0\n" - "mov x20, %x[num_subblocks]\n" - "movi v2.4s, #0x0\n" "movi v1.4s, #0x0\n" - "movi v30.4s, #0x0\n" - "movi v8.4s, #0x0\n" "movi v4.4s, #0x0\n" + "mov x20, %x[num_subblocks]\n" "movi v10.4s, #0x0\n" + "movi v24.4s, #0x0\n" + "movi v16.4s, #0x0\n" + "movi v6.4s, #0x0\n" + "movi v12.4s, #0x0\n" + "movi v0.4s, #0x0\n" "4:" // Sub block loop - "ldr q16, [x10, #0x0]\n" - "ldr q29, [x10, #0x10]\n" + "ldr q2, [x10, #0x0]\n" + "ldr q17, [x10, #0x10]\n" "subs x20, x20, #0x1\n" - "ldr q24, [x23, #0x0]\n" - "ldr q23, [x23, #0x10]\n" - "ldr q28, [x21, #0x0]\n" - "ldr q6, [x21, #0x10]\n" - "ldr q27, [x10, #0x20]\n" - "ldr q11, [x10, #0x30]\n" - "shl v18.16b, v16.16b, #0x4\n" - "shl v21.16b, v29.16b, #0x4\n" - "ldr q17, [x23, #0x20]\n" - "ldr q19, [x23, #0x30]\n" - "and v16.16b, v16.16b, v15.16b\n" - "and v29.16b, v29.16b, v15.16b\n" - "ldr q26, [x21, #0x20]\n" - "ldr q9, [x21, #0x30]\n" + "ldr q26, [x23, #0x0]\n" + "ldr q21, [x23, #0x10]\n" + "ldr q19, [x21, #0x0]\n" + "ldr q18, [x21, #0x10]\n" + "ldr q7, [x10, #0x20]\n" + "ldr q28, [x10, #0x30]\n" + "shl v25.16b, v2.16b, #0x4\n" + "shl v11.16b, v17.16b, #0x4\n" + "ldr q30, [x23, #0x20]\n" + "ldr q22, [x23, #0x30]\n" + "and v2.16b, v2.16b, v14.16b\n" + "and v17.16b, v17.16b, v14.16b\n" + "ldr q8, [x21, #0x20]\n" + "ldr q27, [x21, #0x30]\n" "add x10, x10, #0x40\n" - "ldr q25, [x23, #0x40]\n" - ".inst 0x4e92a703 // smmla v3.4s, v24.16b, v18.16b\n" - ".inst 0x4e95a700 // smmla v0.4s, v24.16b, v21.16b\n" - "ldr q24, [x23, #0x50]\n" - ".inst 0x4e92a6e2 // smmla v2.4s, v23.16b, v18.16b\n" - ".inst 0x4e95a6e1 // smmla v1.4s, v23.16b, v21.16b\n" - "ldr q23, [x21, #0x40]\n" - ".inst 0x4e92a79e // smmla v30.4s, v28.16b, v18.16b\n" - ".inst 0x4e95a788 // smmla v8.4s, v28.16b, v21.16b\n" - "ldr q28, [x21, #0x50]\n" - ".inst 0x4e92a4c4 // smmla v4.4s, v6.16b, v18.16b\n" - "ldr q18, [x23, #0x60]\n" - ".inst 
0x4e95a4ca // smmla v10.4s, v6.16b, v21.16b\n" - "ldr q6, [x23, #0x70]\n" - "shl v21.16b, v27.16b, #0x4\n" - "and v27.16b, v27.16b, v15.16b\n" + ".inst 0x4e99a741 // smmla v1.4s, v26.16b, v25.16b\n" + ".inst 0x4e8ba744 // smmla v4.4s, v26.16b, v11.16b\n" + "ldr q26, [x23, #0x40]\n" + ".inst 0x4e99a6aa // smmla v10.4s, v21.16b, v25.16b\n" + ".inst 0x4e8ba6b8 // smmla v24.4s, v21.16b, v11.16b\n" + "ldr q21, [x23, #0x50]\n" + ".inst 0x4e99a670 // smmla v16.4s, v19.16b, v25.16b\n" + ".inst 0x4e8ba666 // smmla v6.4s, v19.16b, v11.16b\n" + "ldr q19, [x21, #0x40]\n" + ".inst 0x4e99a64c // smmla v12.4s, v18.16b, v25.16b\n" + "ldr q25, [x21, #0x50]\n" + ".inst 0x4e8ba640 // smmla v0.4s, v18.16b, v11.16b\n" + "ldr q11, [x23, #0x60]\n" + "shl v18.16b, v7.16b, #0x4\n" + "and v7.16b, v7.16b, v14.16b\n" + ".inst 0x4e92a7c1 // smmla v1.4s, v30.16b, v18.16b\n" + ".inst 0x4e92a6ca // smmla v10.4s, v22.16b, v18.16b\n" + ".inst 0x4e92a510 // smmla v16.4s, v8.16b, v18.16b\n" + ".inst 0x4e92a76c // smmla v12.4s, v27.16b, v18.16b\n" + "ldr q18, [x23, #0x70]\n" "add x23, x23, #0x80\n" - ".inst 0x4e95a623 // smmla v3.4s, v17.16b, v21.16b\n" - ".inst 0x4e95a662 // smmla v2.4s, v19.16b, v21.16b\n" - ".inst 0x4e95a75e // smmla v30.4s, v26.16b, v21.16b\n" - ".inst 0x4e95a524 // smmla v4.4s, v9.16b, v21.16b\n" - "ldr q21, [x21, #0x60]\n" - ".inst 0x4e90a723 // smmla v3.4s, v25.16b, v16.16b\n" - ".inst 0x4e90a702 // smmla v2.4s, v24.16b, v16.16b\n" - ".inst 0x4e90a6fe // smmla v30.4s, v23.16b, v16.16b\n" - ".inst 0x4e90a784 // smmla v4.4s, v28.16b, v16.16b\n" - "ldr q16, [x21, #0x70]\n" + ".inst 0x4e82a741 // smmla v1.4s, v26.16b, v2.16b\n" + ".inst 0x4e82a6aa // smmla v10.4s, v21.16b, v2.16b\n" + ".inst 0x4e82a670 // smmla v16.4s, v19.16b, v2.16b\n" + ".inst 0x4e82a72c // smmla v12.4s, v25.16b, v2.16b\n" + "shl v2.16b, v28.16b, #0x4\n" + "and v28.16b, v28.16b, v14.16b\n" + ".inst 0x4e82a7c4 // smmla v4.4s, v30.16b, v2.16b\n" + "ldr q30, [x21, #0x60]\n" + ".inst 0x4e82a6d8 // smmla v24.4s, v22.16b, v2.16b\n" + "ldr q22, [x21, #0x70]\n" "add x21, x21, #0x80\n" - ".inst 0x4e9ba643 // smmla v3.4s, v18.16b, v27.16b\n" - ".inst 0x4e9ba4c2 // smmla v2.4s, v6.16b, v27.16b\n" - ".inst 0x4e9ba6be // smmla v30.4s, v21.16b, v27.16b\n" - ".inst 0x4e9ba604 // smmla v4.4s, v16.16b, v27.16b\n" - "shl v27.16b, v11.16b, #0x4\n" - "and v11.16b, v11.16b, v15.16b\n" - ".inst 0x4e9ba620 // smmla v0.4s, v17.16b, v27.16b\n" - ".inst 0x4e9ba661 // smmla v1.4s, v19.16b, v27.16b\n" - ".inst 0x4e9ba748 // smmla v8.4s, v26.16b, v27.16b\n" - ".inst 0x4e9ba52a // smmla v10.4s, v9.16b, v27.16b\n" - ".inst 0x4e9da720 // smmla v0.4s, v25.16b, v29.16b\n" - ".inst 0x4e9da701 // smmla v1.4s, v24.16b, v29.16b\n" - ".inst 0x4e9da6e8 // smmla v8.4s, v23.16b, v29.16b\n" - ".inst 0x4e9da78a // smmla v10.4s, v28.16b, v29.16b\n" - ".inst 0x4e8ba640 // smmla v0.4s, v18.16b, v11.16b\n" - ".inst 0x4e8ba4c1 // smmla v1.4s, v6.16b, v11.16b\n" - ".inst 0x4e8ba6a8 // smmla v8.4s, v21.16b, v11.16b\n" - ".inst 0x4e8ba60a // smmla v10.4s, v16.16b, v11.16b\n" + ".inst 0x4e82a506 // smmla v6.4s, v8.16b, v2.16b\n" + ".inst 0x4e82a760 // smmla v0.4s, v27.16b, v2.16b\n" + ".inst 0x4e87a561 // smmla v1.4s, v11.16b, v7.16b\n" + ".inst 0x4e87a64a // smmla v10.4s, v18.16b, v7.16b\n" + ".inst 0x4e87a7d0 // smmla v16.4s, v30.16b, v7.16b\n" + ".inst 0x4e91a744 // smmla v4.4s, v26.16b, v17.16b\n" + ".inst 0x4e91a6b8 // smmla v24.4s, v21.16b, v17.16b\n" + ".inst 0x4e87a6cc // smmla v12.4s, v22.16b, v7.16b\n" + ".inst 0x4e91a666 // smmla v6.4s, v19.16b, v17.16b\n" + ".inst 0x4e91a720 
// smmla v0.4s, v25.16b, v17.16b\n" + ".inst 0x4e9ca564 // smmla v4.4s, v11.16b, v28.16b\n" + ".inst 0x4e9ca658 // smmla v24.4s, v18.16b, v28.16b\n" + ".inst 0x4e9ca7c6 // smmla v6.4s, v30.16b, v28.16b\n" + ".inst 0x4e9ca6c0 // smmla v0.4s, v22.16b, v28.16b\n" "bgt 4b\n" - "ldr q11, [x10, #0x0]\n" - "uzp1 v19.2d, v3.2d, v0.2d\n" - "uzp2 v18.2d, v3.2d, v0.2d\n" + "ldr q19, [x10, #0x0]\n" + "uzp1 v2.2d, v1.2d, v4.2d\n" + "uzp2 v28.2d, v1.2d, v4.2d\n" "add x10, x10, #0x10\n" - "uzp1 v17.2d, v2.2d, v1.2d\n" - "uzp2 v16.2d, v2.2d, v1.2d\n" - "scvtf v19.4s, v19.4s\n" - "scvtf v18.4s, v18.4s\n" + "uzp1 v17.2d, v10.2d, v24.2d\n" + "uzp2 v4.2d, v10.2d, v24.2d\n" + "fmul v19.4s, v19.4s, v20.4s\n" + "scvtf v2.4s, v2.4s\n" + "scvtf v28.4s, v28.4s\n" "scvtf v17.4s, v17.4s\n" - "scvtf v16.4s, v16.4s\n" - "fmla v22.4s, v19.4s, v11.4s\n" - "fmla v12.4s, v18.4s, v11.4s\n" - "fmla v13.4s, v17.4s, v11.4s\n" - "fmla v20.4s, v16.4s, v11.4s\n" - "uzp1 v19.2d, v30.2d, v8.2d\n" - "uzp2 v18.2d, v30.2d, v8.2d\n" - "uzp1 v17.2d, v4.2d, v10.2d\n" - "uzp2 v16.2d, v4.2d, v10.2d\n" - "scvtf v19.4s, v19.4s\n" - "scvtf v18.4s, v18.4s\n" + "scvtf v4.4s, v4.4s\n" + "fmla v13.4s, v2.4s, v19.4s\n" + "fmla v23.4s, v28.4s, v19.4s\n" + "fmla v29.4s, v17.4s, v19.4s\n" + "fmla v15.4s, v4.4s, v19.4s\n" + "uzp1 v4.2d, v16.2d, v6.2d\n" + "uzp2 v6.2d, v16.2d, v6.2d\n" + "uzp1 v17.2d, v12.2d, v0.2d\n" + "uzp2 v24.2d, v12.2d, v0.2d\n" + "scvtf v4.4s, v4.4s\n" + "scvtf v6.4s, v6.4s\n" "scvtf v17.4s, v17.4s\n" - "scvtf v16.4s, v16.4s\n" - "fmla v14.4s, v19.4s, v11.4s\n" - "fmla v5.4s, v18.4s, v11.4s\n" - "fmla v7.4s, v17.4s, v11.4s\n" - "fmla v31.4s, v16.4s, v11.4s\n" + "scvtf v24.4s, v24.4s\n" + "fmla v5.4s, v4.4s, v19.4s\n" + "fmla v3.4s, v6.4s, v19.4s\n" + "fmla v9.4s, v17.4s, v19.4s\n" + "fmla v31.4s, v24.4s, v19.4s\n" "subs x22, x22, #0x1\n" "bgt 3b\n" "ld1 { v24.4s }, [x23]\n" - "ld1 { v23.4s }, [x21]\n" + "ld1 { v22.4s }, [x21]\n" "add x23, x23, #0x10\n" "add x21, x21, #0x10\n" "ldr q21, [x10, #0x0]\n" - "ldr q26, [x23, #0x0]\n" + "ldr q11, [x23, #0x0]\n" "add x20, %x[clamp_vals], #0x4\n" "cmp x9, #0x4\n" "ldr q19, [x21, #0x0]\n" - "ldr q18, [x10, #0x10]\n" + "ldr q0, [x10, #0x10]\n" "add x10, x10, #0x20\n" "ld1r { v17.4s }, [%x[clamp_vals]]\n" - "ld1r { v16.4s }, [x20]\n" + "ld1r { v4.4s }, [x20]\n" "scvtf v24.4s, v24.4s\n" - "scvtf v23.4s, v23.4s\n" - "fmla v22.4s, v21.4s, v24.s[0]\n" - "fmla v12.4s, v21.4s, v24.s[1]\n" - "fmla v13.4s, v21.4s, v24.s[2]\n" - "fmla v20.4s, v21.4s, v24.s[3]\n" - "fmla v14.4s, v21.4s, v23.s[0]\n" - "fmla v5.4s, v21.4s, v23.s[1]\n" - "fmla v7.4s, v21.4s, v23.s[2]\n" - "fmla v31.4s, v21.4s, v23.s[3]\n" - "fmul v22.4s, v22.4s, v26.s[0]\n" - "fmul v12.4s, v12.4s, v26.s[1]\n" - "fmul v13.4s, v13.4s, v26.s[2]\n" - "fmul v20.4s, v20.4s, v26.s[3]\n" - "fmul v14.4s, v14.4s, v19.s[0]\n" - "fmul v5.4s, v5.4s, v19.s[1]\n" - "fadd v22.4s, v22.4s, v18.4s\n" - "fmul v7.4s, v7.4s, v19.s[2]\n" + "scvtf v22.4s, v22.4s\n" + "fmla v13.4s, v21.4s, v24.s[0]\n" + "fmla v23.4s, v21.4s, v24.s[1]\n" + "fmla v29.4s, v21.4s, v24.s[2]\n" + "fmla v15.4s, v21.4s, v24.s[3]\n" + "fmla v5.4s, v21.4s, v22.s[0]\n" + "fmla v3.4s, v21.4s, v22.s[1]\n" + "fmla v9.4s, v21.4s, v22.s[2]\n" + "fmla v31.4s, v21.4s, v22.s[3]\n" + "fmul v13.4s, v13.4s, v11.s[0]\n" + "fmul v23.4s, v23.4s, v11.s[1]\n" + "fmul v29.4s, v29.4s, v11.s[2]\n" + "fmul v15.4s, v15.4s, v11.s[3]\n" + "fmul v5.4s, v5.4s, v19.s[0]\n" + "fmul v3.4s, v3.4s, v19.s[1]\n" + "fadd v13.4s, v13.4s, v0.4s\n" + "fmul v9.4s, v9.4s, v19.s[2]\n" "fmul v31.4s, v31.4s, v19.s[3]\n" - "fadd 
v12.4s, v12.4s, v18.4s\n" - "fadd v13.4s, v13.4s, v18.4s\n" - "fadd v20.4s, v20.4s, v18.4s\n" - "fadd v14.4s, v14.4s, v18.4s\n" - "fadd v5.4s, v5.4s, v18.4s\n" - "fadd v7.4s, v7.4s, v18.4s\n" - "fadd v31.4s, v31.4s, v18.4s\n" - "fmax v22.4s, v22.4s, v17.4s\n" - "fmax v12.4s, v12.4s, v17.4s\n" + "fadd v23.4s, v23.4s, v0.4s\n" + "fadd v29.4s, v29.4s, v0.4s\n" + "fadd v15.4s, v15.4s, v0.4s\n" + "fadd v5.4s, v5.4s, v0.4s\n" + "fadd v3.4s, v3.4s, v0.4s\n" + "fadd v9.4s, v9.4s, v0.4s\n" + "fadd v31.4s, v31.4s, v0.4s\n" "fmax v13.4s, v13.4s, v17.4s\n" - "fmax v20.4s, v20.4s, v17.4s\n" - "fmax v14.4s, v14.4s, v17.4s\n" + "fmax v23.4s, v23.4s, v17.4s\n" + "fmax v29.4s, v29.4s, v17.4s\n" + "fmax v15.4s, v15.4s, v17.4s\n" "fmax v5.4s, v5.4s, v17.4s\n" - "fmax v7.4s, v7.4s, v17.4s\n" + "fmax v3.4s, v3.4s, v17.4s\n" + "fmax v9.4s, v9.4s, v17.4s\n" "fmax v31.4s, v31.4s, v17.4s\n" - "fmin v22.4s, v22.4s, v16.4s\n" - "fmin v12.4s, v12.4s, v16.4s\n" - "fmin v13.4s, v13.4s, v16.4s\n" - "fmin v20.4s, v20.4s, v16.4s\n" - "fmin v14.4s, v14.4s, v16.4s\n" - "fmin v5.4s, v5.4s, v16.4s\n" - "fmin v7.4s, v7.4s, v16.4s\n" - "fmin v31.4s, v31.4s, v16.4s\n" + "fmin v13.4s, v13.4s, v4.4s\n" + "fmin v23.4s, v23.4s, v4.4s\n" + "fmin v29.4s, v29.4s, v4.4s\n" + "fmin v15.4s, v15.4s, v4.4s\n" + "fmin v5.4s, v5.4s, v4.4s\n" + "fmin v3.4s, v3.4s, v4.4s\n" + "fmin v9.4s, v9.4s, v4.4s\n" + "fmin v31.4s, v31.4s, v4.4s\n" "blt 7f\n" "mov x20, %x[dst]\n" - "str q22, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q12, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" "str q13, [x20, #0x0]\n" "add x20, x20, %x[dst_stride_row]\n" - "str q20, [x20, #0x0]\n" + "str q23, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q29, [x20, #0x0]\n" "add x20, x20, %x[dst_stride_row]\n" - "str q14, [x20, #0x0]\n" + "str q15, [x20, #0x0]\n" "add x20, x20, %x[dst_stride_row]\n" "str q5, [x20, #0x0]\n" "add x20, x20, %x[dst_stride_row]\n" - "str q7, [x20, #0x0]\n" + "str q3, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q9, [x20, #0x0]\n" "add x20, x20, %x[dst_stride_row]\n" "str q31, [x20, #0x0]\n" "b 10f\n" @@ -338,32 +341,32 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( "add x20, x22, %x[dst_stride_row]\n" "tbz x9, #1, 8f\n" "st1 { v31.d }[0], [x23], #0x8\n" - "st1 { v7.d }[0], [x25], #0x8\n" - "st1 { v5.d }[0], [x24], #0x8\n" - "st1 { v14.d }[0], [x26], #0x8\n" - "st1 { v20.d }[0], [x20], #0x8\n" - "st1 { v13.d }[0], [x22], #0x8\n" - "st1 { v12.d }[0], [x21], #0x8\n" - "st1 { v22.d }[0], [x27], #0x8\n" + "st1 { v9.d }[0], [x25], #0x8\n" + "st1 { v3.d }[0], [x24], #0x8\n" + "st1 { v5.d }[0], [x26], #0x8\n" + "st1 { v15.d }[0], [x20], #0x8\n" + "st1 { v29.d }[0], [x22], #0x8\n" + "st1 { v23.d }[0], [x21], #0x8\n" + "st1 { v13.d }[0], [x27], #0x8\n" "tbz x9, #0, 9f\n" "st1 { v31.s }[2], [x23]\n" - "st1 { v7.s }[2], [x25]\n" - "st1 { v5.s }[2], [x24]\n" - "st1 { v14.s }[2], [x26]\n" - "st1 { v20.s }[2], [x20]\n" - "st1 { v13.s }[2], [x22]\n" - "st1 { v12.s }[2], [x21]\n" - "st1 { v22.s }[2], [x27]\n" + "st1 { v9.s }[2], [x25]\n" + "st1 { v3.s }[2], [x24]\n" + "st1 { v5.s }[2], [x26]\n" + "st1 { v15.s }[2], [x20]\n" + "st1 { v29.s }[2], [x22]\n" + "st1 { v23.s }[2], [x21]\n" + "st1 { v13.s }[2], [x27]\n" "b 9f\n" "8:" // Output block 0: partial_1_0 "st1 { v31.s }[0], [x23]\n" - "st1 { v7.s }[0], [x25]\n" - "st1 { v5.s }[0], [x24]\n" - "st1 { v14.s }[0], [x26]\n" - "st1 { v20.s }[0], [x20]\n" - "st1 { v13.s }[0], [x22]\n" - "st1 { v12.s }[0], [x21]\n" - "st1 { v22.s }[0], [x27]\n" 
+ "st1 { v9.s }[0], [x25]\n" + "st1 { v3.s }[0], [x24]\n" + "st1 { v5.s }[0], [x26]\n" + "st1 { v15.s }[0], [x20]\n" + "st1 { v29.s }[0], [x22]\n" + "st1 { v23.s }[0], [x21]\n" + "st1 { v13.s }[0], [x27]\n" "9:" // Output block 0: Done "10:" // Output stage exit "subs x9, x9, #0x4\n" @@ -382,77 +385,78 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( "mov x25, %x[n]\n" "add x24, %x[dst], %x[dst_stride_row], LSL #2\n" "13:" // Row tail: Column loop - "movi v22.16b, #0x0\n" - "movi v12.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "movi v23.16b, #0x0\n" "mov x23, %x[lhs_packed]\n" "mov x21, %x[num_blocks]\n" - "movi v13.16b, #0x0\n" - "movi v20.16b, #0x0\n" + "movi v29.16b, #0x0\n" + "movi v15.16b, #0x0\n" "14:" // Row tail: Block loop - "movi v3.4s, #0x0\n" - "movi v0.4s, #0x0\n" - "mov x20, %x[num_subblocks]\n" - "movi v2.4s, #0x0\n" "movi v1.4s, #0x0\n" + "movi v4.4s, #0x0\n" + "mov x20, %x[num_subblocks]\n" + "movi v10.4s, #0x0\n" + "movi v24.4s, #0x0\n" "15:" // Row tail: Sub block loop - "ldr q31, [x26, #0x0]\n" - "ldr q8, [x26, #0x10]\n" + "ldr q0, [x26, #0x0]\n" + "ldr q31, [x26, #0x10]\n" "subs x20, x20, #0x1\n" "ldr q30, [x23, #0x0]\n" - "ldr q29, [x23, #0x10]\n" - "ldr q6, [x26, #0x20]\n" + "ldr q3, [x23, #0x10]\n" + "ldr q8, [x26, #0x20]\n" "ldr q27, [x26, #0x30]\n" "add x26, x26, #0x40\n" "ldr q26, [x23, #0x20]\n" "ldr q25, [x23, #0x30]\n" - "shl v24.16b, v31.16b, #0x4\n" - "shl v23.16b, v8.16b, #0x4\n" - "ldr q9, [x23, #0x40]\n" - "ldr q11, [x23, #0x50]\n" - "and v31.16b, v31.16b, v15.16b\n" - "and v8.16b, v8.16b, v15.16b\n" + "shl v22.16b, v0.16b, #0x4\n" + "shl v16.16b, v31.16b, #0x4\n" + "ldr q21, [x23, #0x40]\n" + "ldr q7, [x23, #0x50]\n" + "and v0.16b, v0.16b, v14.16b\n" + "and v31.16b, v31.16b, v14.16b\n" "ldr q19, [x23, #0x60]\n" "ldr q18, [x23, #0x70]\n" - "shl v17.16b, v6.16b, #0x4\n" - "shl v16.16b, v27.16b, #0x4\n" - ".inst 0x4e98a7c3 // smmla v3.4s, v30.16b, v24.16b\n" - ".inst 0x4e97a7c0 // smmla v0.4s, v30.16b, v23.16b\n" - "and v6.16b, v6.16b, v15.16b\n" + "shl v17.16b, v8.16b, #0x4\n" + "shl v2.16b, v27.16b, #0x4\n" + ".inst 0x4e96a7c1 // smmla v1.4s, v30.16b, v22.16b\n" + ".inst 0x4e90a7c4 // smmla v4.4s, v30.16b, v16.16b\n" + "and v8.16b, v8.16b, v14.16b\n" "add x23, x23, #0x80\n" - ".inst 0x4e98a7a2 // smmla v2.4s, v29.16b, v24.16b\n" - ".inst 0x4e97a7a1 // smmla v1.4s, v29.16b, v23.16b\n" - "and v27.16b, v27.16b, v15.16b\n" - ".inst 0x4e91a743 // smmla v3.4s, v26.16b, v17.16b\n" - ".inst 0x4e90a740 // smmla v0.4s, v26.16b, v16.16b\n" - ".inst 0x4e91a722 // smmla v2.4s, v25.16b, v17.16b\n" - ".inst 0x4e90a721 // smmla v1.4s, v25.16b, v16.16b\n" - ".inst 0x4e9fa523 // smmla v3.4s, v9.16b, v31.16b\n" - ".inst 0x4e88a520 // smmla v0.4s, v9.16b, v8.16b\n" - ".inst 0x4e9fa562 // smmla v2.4s, v11.16b, v31.16b\n" - ".inst 0x4e88a561 // smmla v1.4s, v11.16b, v8.16b\n" - ".inst 0x4e86a663 // smmla v3.4s, v19.16b, v6.16b\n" - ".inst 0x4e9ba660 // smmla v0.4s, v19.16b, v27.16b\n" - ".inst 0x4e86a642 // smmla v2.4s, v18.16b, v6.16b\n" - ".inst 0x4e9ba641 // smmla v1.4s, v18.16b, v27.16b\n" + ".inst 0x4e96a46a // smmla v10.4s, v3.16b, v22.16b\n" + ".inst 0x4e90a478 // smmla v24.4s, v3.16b, v16.16b\n" + "and v27.16b, v27.16b, v14.16b\n" + ".inst 0x4e91a741 // smmla v1.4s, v26.16b, v17.16b\n" + ".inst 0x4e82a744 // smmla v4.4s, v26.16b, v2.16b\n" + ".inst 0x4e91a72a // smmla v10.4s, v25.16b, v17.16b\n" + ".inst 0x4e82a738 // smmla v24.4s, v25.16b, v2.16b\n" + ".inst 0x4e80a6a1 // smmla v1.4s, v21.16b, v0.16b\n" + ".inst 0x4e9fa6a4 // smmla v4.4s, v21.16b, 
v31.16b\n" + ".inst 0x4e80a4ea // smmla v10.4s, v7.16b, v0.16b\n" + ".inst 0x4e9fa4f8 // smmla v24.4s, v7.16b, v31.16b\n" + ".inst 0x4e88a661 // smmla v1.4s, v19.16b, v8.16b\n" + ".inst 0x4e9ba664 // smmla v4.4s, v19.16b, v27.16b\n" + ".inst 0x4e88a64a // smmla v10.4s, v18.16b, v8.16b\n" + ".inst 0x4e9ba658 // smmla v24.4s, v18.16b, v27.16b\n" "bgt 15b\n" - "ldr q11, [x26, #0x0]\n" - "uzp1 v19.2d, v3.2d, v0.2d\n" - "uzp2 v18.2d, v3.2d, v0.2d\n" + "ldr q7, [x26, #0x0]\n" + "uzp1 v19.2d, v1.2d, v4.2d\n" + "uzp2 v18.2d, v1.2d, v4.2d\n" "add x26, x26, #0x10\n" - "uzp1 v17.2d, v2.2d, v1.2d\n" - "uzp2 v16.2d, v2.2d, v1.2d\n" + "uzp1 v17.2d, v10.2d, v24.2d\n" + "uzp2 v16.2d, v10.2d, v24.2d\n" + "fmul v7.4s, v7.4s, v20.4s\n" "scvtf v19.4s, v19.4s\n" "scvtf v18.4s, v18.4s\n" "scvtf v17.4s, v17.4s\n" "scvtf v16.4s, v16.4s\n" - "fmla v22.4s, v19.4s, v11.4s\n" - "fmla v12.4s, v18.4s, v11.4s\n" - "fmla v13.4s, v17.4s, v11.4s\n" - "fmla v20.4s, v16.4s, v11.4s\n" + "fmla v13.4s, v19.4s, v7.4s\n" + "fmla v23.4s, v18.4s, v7.4s\n" + "fmla v29.4s, v17.4s, v7.4s\n" + "fmla v15.4s, v16.4s, v7.4s\n" "subs x21, x21, #0x1\n" "bgt 14b\n" "ld1 { v21.4s }, [x23]\n" - "ldr q1, [x26, #0x0]\n" + "ldr q10, [x26, #0x0]\n" "add x23, x23, #0x10\n" "add x20, %x[clamp_vals], #0x4\n" "ldr q19, [x23, #0x0]\n" @@ -462,69 +466,69 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( "ld1r { v17.4s }, [%x[clamp_vals]]\n" "ld1r { v16.4s }, [x20]\n" "scvtf v21.4s, v21.4s\n" - "fmla v22.4s, v1.4s, v21.s[0]\n" - "fmla v12.4s, v1.4s, v21.s[1]\n" - "fmla v13.4s, v1.4s, v21.s[2]\n" - "fmla v20.4s, v1.4s, v21.s[3]\n" - "fmul v22.4s, v22.4s, v19.s[0]\n" - "fmul v12.4s, v12.4s, v19.s[1]\n" - "fmul v13.4s, v13.4s, v19.s[2]\n" - "fadd v22.4s, v22.4s, v18.4s\n" - "fmul v20.4s, v20.4s, v19.s[3]\n" - "fadd v12.4s, v12.4s, v18.4s\n" + "fmla v13.4s, v10.4s, v21.s[0]\n" + "fmla v23.4s, v10.4s, v21.s[1]\n" + "fmla v29.4s, v10.4s, v21.s[2]\n" + "fmla v15.4s, v10.4s, v21.s[3]\n" + "fmul v13.4s, v13.4s, v19.s[0]\n" + "fmul v23.4s, v23.4s, v19.s[1]\n" + "fmul v29.4s, v29.4s, v19.s[2]\n" "fadd v13.4s, v13.4s, v18.4s\n" - "fadd v20.4s, v20.4s, v18.4s\n" - "fmax v22.4s, v22.4s, v17.4s\n" - "fmax v12.4s, v12.4s, v17.4s\n" + "fmul v15.4s, v15.4s, v19.s[3]\n" + "fadd v23.4s, v23.4s, v18.4s\n" + "fadd v29.4s, v29.4s, v18.4s\n" + "fadd v15.4s, v15.4s, v18.4s\n" "fmax v13.4s, v13.4s, v17.4s\n" - "fmax v20.4s, v20.4s, v17.4s\n" - "fmin v22.4s, v22.4s, v16.4s\n" - "fmin v12.4s, v12.4s, v16.4s\n" + "fmax v23.4s, v23.4s, v17.4s\n" + "fmax v29.4s, v29.4s, v17.4s\n" + "fmax v15.4s, v15.4s, v17.4s\n" "fmin v13.4s, v13.4s, v16.4s\n" - "fmin v20.4s, v20.4s, v16.4s\n" + "fmin v23.4s, v23.4s, v16.4s\n" + "fmin v29.4s, v29.4s, v16.4s\n" + "fmin v15.4s, v15.4s, v16.4s\n" "blt 17f\n" "mov x20, %x[dst]\n" "cmp x11, #0x1\n" - "str q22, [x20, #0x0]\n" + "str q13, [x20, #0x0]\n" "add x20, x20, %x[dst_stride_row]\n" "ble 20f\n" "cmp x11, #0x2\n" - "str q12, [x20, #0x0]\n" + "str q23, [x20, #0x0]\n" "add x20, x20, %x[dst_stride_row]\n" "ble 20f\n" "cmp x11, #0x3\n" - "str q13, [x20, #0x0]\n" + "str q29, [x20, #0x0]\n" "add x20, x20, %x[dst_stride_row]\n" "ble 20f\n" - "str q20, [x20, #0x0]\n" + "str q15, [x20, #0x0]\n" "b 20f\n" "17:" // Row tail: Partial output "mov x23, %x[dst]\n" "cmp x11, #0x1\n" "add x22, x23, %x[dst_stride_row]\n" - "csel x22, x22, x23, GE\n" + "csel x22, x22, x23, GT\n" "cmp x11, #0x2\n" "add x21, x23, %x[dst_stride_row], LSL #1\n" - "csel x21, x21, x22, GE\n" + "csel x21, x21, x22, GT\n" "cmp x11, #0x3\n" "add x20, x21, 
%x[dst_stride_row]\n" - "csel x20, x20, x21, GE\n" + "csel x20, x20, x21, GT\n" "tbz x25, #1, 18f\n" - "st1 { v20.d }[0], [x20], #0x8\n" - "st1 { v13.d }[0], [x21], #0x8\n" - "st1 { v12.d }[0], [x22], #0x8\n" - "st1 { v22.d }[0], [x23], #0x8\n" + "st1 { v15.d }[0], [x20], #0x8\n" + "st1 { v29.d }[0], [x21], #0x8\n" + "st1 { v23.d }[0], [x22], #0x8\n" + "st1 { v13.d }[0], [x23], #0x8\n" "tbz x25, #0, 19f\n" - "st1 { v20.s }[2], [x20]\n" - "st1 { v13.s }[2], [x21]\n" - "st1 { v12.s }[2], [x22]\n" - "st1 { v22.s }[2], [x23]\n" + "st1 { v15.s }[2], [x20]\n" + "st1 { v29.s }[2], [x21]\n" + "st1 { v23.s }[2], [x22]\n" + "st1 { v13.s }[2], [x23]\n" "b 19f\n" "18:" // Row tail: Output block 0: partial_1_0 - "st1 { v20.s }[0], [x20]\n" - "st1 { v13.s }[0], [x21]\n" - "st1 { v12.s }[0], [x22]\n" - "st1 { v22.s }[0], [x23]\n" + "st1 { v15.s }[0], [x20]\n" + "st1 { v29.s }[0], [x21]\n" + "st1 { v23.s }[0], [x22]\n" + "st1 { v13.s }[0], [x23]\n" "19:" // Row tail: Output block 0: Done "20:" // Row tail: Output stage exit "subs x25, x25, #0x4\n" diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c index c0070035..e5472d21 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c @@ -127,10 +127,12 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm( __asm__ __volatile__( "mov x28, #0x80\n" + "mov x21, #0x3d800000\n" + "movi v17.16b, #0xf0\n" "mov x20, #0x20\n" - "movi v14.16b, #0xf0\n" "mov x27, %x[m]\n" "mul x28, %x[num_subblocks], x28\n" + "dup v15.4s, w21\n" "madd x28, %x[num_blocks], x28, x20\n" "cbz x27, 12f\n" "1:" // Row loop @@ -138,248 +140,250 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm( "mov x25, %x[n]\n" "add x24, %x[dst], %x[dst_stride_row], LSL #2\n" "2:" // Column loop - "movi v27.16b, #0x0\n" + "movi v14.16b, #0x0\n" "movi v12.16b, #0x0\n" "mov x22, %x[lhs_packed]\n" "mov x21, %x[num_blocks]\n" "movi v11.16b, #0x0\n" - "movi v15.16b, #0x0\n" "movi v13.16b, #0x0\n" - "movi v5.16b, #0x0\n" - "movi v8.16b, #0x0\n" - "movi v1.16b, #0x0\n" + "movi v18.16b, #0x0\n" + "movi v27.16b, #0x0\n" + "movi v29.16b, #0x0\n" + "movi v9.16b, #0x0\n" "3:" // Block loop - "movi v7.4s, #0x0\n" - "movi v10.4s, #0x0\n" - "mov x20, %x[num_subblocks]\n" - "movi v3.4s, #0x0\n" "movi v0.4s, #0x0\n" - "movi v2.4s, #0x0\n" - "movi v25.4s, #0x0\n" + "movi v5.4s, #0x0\n" + "mov x20, %x[num_subblocks]\n" + "movi v6.4s, #0x0\n" + "movi v8.4s, #0x0\n" "movi v4.4s, #0x0\n" - "movi v28.4s, #0x0\n" + "movi v7.4s, #0x0\n" + "movi v22.4s, #0x0\n" + "movi v2.4s, #0x0\n" "4:" // Sub block loop - "ldr q21, [x26, #0x0]\n" - "ldr q30, [x26, #0x10]\n" + "ldr q24, [x26, #0x0]\n" + "ldr q23, [x26, #0x10]\n" "subs x20, x20, #0x1\n" - "ldr q22, [x26, #0x20]\n" - "ldr q29, [x26, #0x30]\n" + "ldr q31, [x26, #0x20]\n" + "ldr q16, [x26, #0x30]\n" "ldr q20, [x22, #0x0]\n" - "ldr q23, [x22, #0x10]\n" + "ldr q3, [x22, #0x10]\n" "ldr q19, [x26, #0x40]\n" - "ldr q9, [x26, #0x50]\n" - "shl v26.16b, v21.16b, #0x4\n" - "shl v6.16b, v30.16b, #0x4\n" - "ldr q18, [x26, #0x60]\n" - "ldr q24, [x26, #0x70]\n" - "shl v31.16b, v22.16b, #0x4\n" - "shl v16.16b, v29.16b, #0x4\n" - "ldr q17, [x22, #0x20]\n" - "and v21.16b, 
v21.16b, v14.16b\n" - "and v30.16b, v30.16b, v14.16b\n" + "ldr q28, [x26, #0x50]\n" + "shl v1.16b, v24.16b, #0x4\n" + "shl v21.16b, v23.16b, #0x4\n" + "ldr q25, [x26, #0x60]\n" + "ldr q10, [x26, #0x70]\n" + "shl v30.16b, v31.16b, #0x4\n" + "shl v26.16b, v16.16b, #0x4\n" + "and v24.16b, v24.16b, v17.16b\n" + "and v23.16b, v23.16b, v17.16b\n" "add x26, x26, #0x80\n" - ".inst 0x4e9aa687 // smmla v7.4s, v20.16b, v26.16b\n" - ".inst 0x4e86a683 // smmla v3.4s, v20.16b, v6.16b\n" - "and v22.16b, v22.16b, v14.16b\n" - ".inst 0x4e9fa68a // smmla v10.4s, v20.16b, v31.16b\n" - ".inst 0x4e90a680 // smmla v0.4s, v20.16b, v16.16b\n" + ".inst 0x4e81a680 // smmla v0.4s, v20.16b, v1.16b\n" + ".inst 0x4e95a686 // smmla v6.4s, v20.16b, v21.16b\n" + ".inst 0x4e81a464 // smmla v4.4s, v3.16b, v1.16b\n" + "ldr q1, [x22, #0x20]\n" + "and v31.16b, v31.16b, v17.16b\n" + ".inst 0x4e9ea685 // smmla v5.4s, v20.16b, v30.16b\n" + ".inst 0x4e9aa688 // smmla v8.4s, v20.16b, v26.16b\n" "ldr q20, [x22, #0x30]\n" - "and v29.16b, v29.16b, v14.16b\n" - ".inst 0x4e9aa6e2 // smmla v2.4s, v23.16b, v26.16b\n" - "ldr q26, [x22, #0x40]\n" - ".inst 0x4e86a6e4 // smmla v4.4s, v23.16b, v6.16b\n" - "ldr q6, [x22, #0x50]\n" - ".inst 0x4e9fa6f9 // smmla v25.4s, v23.16b, v31.16b\n" - "ldr q31, [x22, #0x60]\n" - ".inst 0x4e90a6fc // smmla v28.4s, v23.16b, v16.16b\n" - "ldr q23, [x22, #0x70]\n" - "shl v16.16b, v19.16b, #0x4\n" - "and v19.16b, v19.16b, v14.16b\n" + "and v16.16b, v16.16b, v17.16b\n" + ".inst 0x4e95a476 // smmla v22.4s, v3.16b, v21.16b\n" + "ldr q21, [x22, #0x40]\n" + ".inst 0x4e9ea467 // smmla v7.4s, v3.16b, v30.16b\n" + "ldr q30, [x22, #0x50]\n" + ".inst 0x4e9aa462 // smmla v2.4s, v3.16b, v26.16b\n" + "ldr q26, [x22, #0x60]\n" + "shl v3.16b, v19.16b, #0x4\n" + "and v19.16b, v19.16b, v17.16b\n" + ".inst 0x4e83a420 // smmla v0.4s, v1.16b, v3.16b\n" + ".inst 0x4e83a684 // smmla v4.4s, v20.16b, v3.16b\n" + "ldr q3, [x22, #0x70]\n" "add x22, x22, #0x80\n" - ".inst 0x4e90a627 // smmla v7.4s, v17.16b, v16.16b\n" - ".inst 0x4e90a682 // smmla v2.4s, v20.16b, v16.16b\n" - "shl v16.16b, v9.16b, #0x4\n" - "and v9.16b, v9.16b, v14.16b\n" - ".inst 0x4e90a623 // smmla v3.4s, v17.16b, v16.16b\n" - ".inst 0x4e90a684 // smmla v4.4s, v20.16b, v16.16b\n" - "shl v16.16b, v18.16b, #0x4\n" - "and v18.16b, v18.16b, v14.16b\n" - ".inst 0x4e95a747 // smmla v7.4s, v26.16b, v21.16b\n" - ".inst 0x4e95a4c2 // smmla v2.4s, v6.16b, v21.16b\n" - "shl v21.16b, v24.16b, #0x4\n" - "and v24.16b, v24.16b, v14.16b\n" - ".inst 0x4e90a62a // smmla v10.4s, v17.16b, v16.16b\n" - ".inst 0x4e90a699 // smmla v25.4s, v20.16b, v16.16b\n" - ".inst 0x4e9ea743 // smmla v3.4s, v26.16b, v30.16b\n" - ".inst 0x4e9ea4c4 // smmla v4.4s, v6.16b, v30.16b\n" - ".inst 0x4e95a620 // smmla v0.4s, v17.16b, v21.16b\n" - ".inst 0x4e95a69c // smmla v28.4s, v20.16b, v21.16b\n" - ".inst 0x4e93a7e7 // smmla v7.4s, v31.16b, v19.16b\n" - ".inst 0x4e93a6e2 // smmla v2.4s, v23.16b, v19.16b\n" - ".inst 0x4e96a74a // smmla v10.4s, v26.16b, v22.16b\n" - ".inst 0x4e96a4d9 // smmla v25.4s, v6.16b, v22.16b\n" - ".inst 0x4e89a7e3 // smmla v3.4s, v31.16b, v9.16b\n" - ".inst 0x4e89a6e4 // smmla v4.4s, v23.16b, v9.16b\n" - ".inst 0x4e9da740 // smmla v0.4s, v26.16b, v29.16b\n" - ".inst 0x4e9da4dc // smmla v28.4s, v6.16b, v29.16b\n" - ".inst 0x4e92a7ea // smmla v10.4s, v31.16b, v18.16b\n" - ".inst 0x4e92a6f9 // smmla v25.4s, v23.16b, v18.16b\n" - ".inst 0x4e98a7e0 // smmla v0.4s, v31.16b, v24.16b\n" - ".inst 0x4e98a6fc // smmla v28.4s, v23.16b, v24.16b\n" + ".inst 0x4e98a6a0 // smmla v0.4s, v21.16b, v24.16b\n" 
+ ".inst 0x4e98a7c4 // smmla v4.4s, v30.16b, v24.16b\n" + "shl v24.16b, v28.16b, #0x4\n" + "and v28.16b, v28.16b, v17.16b\n" + ".inst 0x4e98a426 // smmla v6.4s, v1.16b, v24.16b\n" + ".inst 0x4e98a696 // smmla v22.4s, v20.16b, v24.16b\n" + "shl v24.16b, v25.16b, #0x4\n" + "and v25.16b, v25.16b, v17.16b\n" + ".inst 0x4e93a740 // smmla v0.4s, v26.16b, v19.16b\n" + ".inst 0x4e93a464 // smmla v4.4s, v3.16b, v19.16b\n" + "shl v19.16b, v10.16b, #0x4\n" + "and v10.16b, v10.16b, v17.16b\n" + ".inst 0x4e98a425 // smmla v5.4s, v1.16b, v24.16b\n" + ".inst 0x4e98a687 // smmla v7.4s, v20.16b, v24.16b\n" + ".inst 0x4e97a6a6 // smmla v6.4s, v21.16b, v23.16b\n" + ".inst 0x4e97a7d6 // smmla v22.4s, v30.16b, v23.16b\n" + ".inst 0x4e93a428 // smmla v8.4s, v1.16b, v19.16b\n" + ".inst 0x4e93a682 // smmla v2.4s, v20.16b, v19.16b\n" + ".inst 0x4e9fa6a5 // smmla v5.4s, v21.16b, v31.16b\n" + ".inst 0x4e9fa7c7 // smmla v7.4s, v30.16b, v31.16b\n" + ".inst 0x4e9ca746 // smmla v6.4s, v26.16b, v28.16b\n" + ".inst 0x4e9ca476 // smmla v22.4s, v3.16b, v28.16b\n" + ".inst 0x4e90a6a8 // smmla v8.4s, v21.16b, v16.16b\n" + ".inst 0x4e90a7c2 // smmla v2.4s, v30.16b, v16.16b\n" + ".inst 0x4e99a745 // smmla v5.4s, v26.16b, v25.16b\n" + ".inst 0x4e99a467 // smmla v7.4s, v3.16b, v25.16b\n" + ".inst 0x4e8aa748 // smmla v8.4s, v26.16b, v10.16b\n" + ".inst 0x4e8aa462 // smmla v2.4s, v3.16b, v10.16b\n" "bgt 4b\n" - "ldr q31, [x26, #0x0]\n" - "ldr q20, [x26, #0x10]\n" - "uzp1 v21.2d, v7.2d, v3.2d\n" - "uzp2 v23.2d, v7.2d, v3.2d\n" - "uzp1 v22.2d, v10.2d, v0.2d\n" - "uzp2 v30.2d, v10.2d, v0.2d\n" + "ldr q26, [x26, #0x0]\n" + "ldr q28, [x26, #0x10]\n" + "uzp1 v25.2d, v0.2d, v6.2d\n" + "uzp2 v24.2d, v0.2d, v6.2d\n" + "uzp1 v23.2d, v5.2d, v8.2d\n" + "uzp2 v21.2d, v5.2d, v8.2d\n" "add x26, x26, #0x20\n" - "uzp1 v7.2d, v2.2d, v4.2d\n" - "uzp2 v18.2d, v2.2d, v4.2d\n" - "uzp1 v17.2d, v25.2d, v28.2d\n" - "uzp2 v16.2d, v25.2d, v28.2d\n" - "scvtf v21.4s, v21.4s\n" - "scvtf v22.4s, v22.4s\n" + "uzp1 v20.2d, v4.2d, v22.2d\n" + "uzp2 v19.2d, v4.2d, v22.2d\n" + "uzp1 v3.2d, v7.2d, v2.2d\n" + "uzp2 v4.2d, v7.2d, v2.2d\n" + "fmul v26.4s, v26.4s, v15.4s\n" + "fmul v28.4s, v28.4s, v15.4s\n" + "scvtf v25.4s, v25.4s\n" "scvtf v23.4s, v23.4s\n" - "scvtf v30.4s, v30.4s\n" - "scvtf v7.4s, v7.4s\n" - "scvtf v17.4s, v17.4s\n" - "scvtf v18.4s, v18.4s\n" - "scvtf v16.4s, v16.4s\n" - "fmla v27.4s, v21.4s, v31.4s\n" - "fmla v12.4s, v22.4s, v20.4s\n" - "fmla v11.4s, v23.4s, v31.4s\n" - "fmla v15.4s, v30.4s, v20.4s\n" - "fmla v13.4s, v7.4s, v31.4s\n" - "fmla v5.4s, v17.4s, v20.4s\n" - "fmla v8.4s, v18.4s, v31.4s\n" - "fmla v1.4s, v16.4s, v20.4s\n" + "scvtf v24.4s, v24.4s\n" + "scvtf v21.4s, v21.4s\n" + "scvtf v20.4s, v20.4s\n" + "scvtf v3.4s, v3.4s\n" + "scvtf v19.4s, v19.4s\n" + "scvtf v4.4s, v4.4s\n" + "fmla v14.4s, v25.4s, v26.4s\n" + "fmla v12.4s, v23.4s, v28.4s\n" + "fmla v11.4s, v24.4s, v26.4s\n" + "fmla v13.4s, v21.4s, v28.4s\n" + "fmla v18.4s, v20.4s, v26.4s\n" + "fmla v27.4s, v3.4s, v28.4s\n" + "fmla v29.4s, v19.4s, v26.4s\n" + "fmla v9.4s, v4.4s, v28.4s\n" "subs x21, x21, #0x1\n" "bgt 3b\n" "ld1 { v23.4s }, [x22]\n" "ldr q22, [x26, #0x0]\n" "add x22, x22, #0x10\n" "add x20, %x[clamp_vals], #0x4\n" - "ldr q2, [x26, #0x10]\n" + "ldr q21, [x26, #0x10]\n" "ldr q20, [x22, #0x0]\n" "cmp x25, #0x8\n" "ldr q19, [x26, #0x20]\n" - "ldr q18, [x26, #0x30]\n" + "ldr q25, [x26, #0x30]\n" "add x26, x26, #0x40\n" - "ld1r { v17.4s }, [%x[clamp_vals]]\n" - "ld1r { v16.4s }, [x20]\n" + "ld1r { v31.4s }, [%x[clamp_vals]]\n" + "ld1r { v3.4s }, [x20]\n" "scvtf v23.4s, v23.4s\n" - 
"fmla v27.4s, v22.4s, v23.s[0]\n" - "fmla v12.4s, v2.4s, v23.s[0]\n" + "fmla v14.4s, v22.4s, v23.s[0]\n" + "fmla v12.4s, v21.4s, v23.s[0]\n" "fmla v11.4s, v22.4s, v23.s[1]\n" - "fmla v15.4s, v2.4s, v23.s[1]\n" - "fmla v13.4s, v22.4s, v23.s[2]\n" - "fmla v5.4s, v2.4s, v23.s[2]\n" - "fmla v8.4s, v22.4s, v23.s[3]\n" - "fmla v1.4s, v2.4s, v23.s[3]\n" - "fmul v27.4s, v27.4s, v20.s[0]\n" + "fmla v13.4s, v21.4s, v23.s[1]\n" + "fmla v18.4s, v22.4s, v23.s[2]\n" + "fmla v27.4s, v21.4s, v23.s[2]\n" + "fmla v29.4s, v22.4s, v23.s[3]\n" + "fmla v9.4s, v21.4s, v23.s[3]\n" + "fmul v14.4s, v14.4s, v20.s[0]\n" "fmul v12.4s, v12.4s, v20.s[0]\n" "fmul v11.4s, v11.4s, v20.s[1]\n" - "fmul v15.4s, v15.4s, v20.s[1]\n" - "fmul v13.4s, v13.4s, v20.s[2]\n" - "fmul v5.4s, v5.4s, v20.s[2]\n" - "fmul v8.4s, v8.4s, v20.s[3]\n" - "fmul v1.4s, v1.4s, v20.s[3]\n" - "fadd v27.4s, v27.4s, v19.4s\n" - "fadd v12.4s, v12.4s, v18.4s\n" + "fmul v13.4s, v13.4s, v20.s[1]\n" + "fmul v18.4s, v18.4s, v20.s[2]\n" + "fmul v27.4s, v27.4s, v20.s[2]\n" + "fmul v29.4s, v29.4s, v20.s[3]\n" + "fmul v9.4s, v9.4s, v20.s[3]\n" + "fadd v14.4s, v14.4s, v19.4s\n" + "fadd v12.4s, v12.4s, v25.4s\n" "fadd v11.4s, v11.4s, v19.4s\n" - "fadd v15.4s, v15.4s, v18.4s\n" - "fadd v13.4s, v13.4s, v19.4s\n" - "fadd v5.4s, v5.4s, v18.4s\n" - "fadd v8.4s, v8.4s, v19.4s\n" - "fadd v1.4s, v1.4s, v18.4s\n" - "fmax v27.4s, v27.4s, v17.4s\n" - "fmax v12.4s, v12.4s, v17.4s\n" - "fmax v11.4s, v11.4s, v17.4s\n" - "fmax v15.4s, v15.4s, v17.4s\n" - "fmax v13.4s, v13.4s, v17.4s\n" - "fmax v5.4s, v5.4s, v17.4s\n" - "fmax v8.4s, v8.4s, v17.4s\n" - "fmax v1.4s, v1.4s, v17.4s\n" - "fmin v27.4s, v27.4s, v16.4s\n" - "fmin v12.4s, v12.4s, v16.4s\n" - "fmin v11.4s, v11.4s, v16.4s\n" - "fmin v15.4s, v15.4s, v16.4s\n" - "fmin v13.4s, v13.4s, v16.4s\n" - "fmin v5.4s, v5.4s, v16.4s\n" - "fmin v8.4s, v8.4s, v16.4s\n" - "fmin v1.4s, v1.4s, v16.4s\n" + "fadd v13.4s, v13.4s, v25.4s\n" + "fadd v18.4s, v18.4s, v19.4s\n" + "fadd v27.4s, v27.4s, v25.4s\n" + "fadd v29.4s, v29.4s, v19.4s\n" + "fadd v9.4s, v9.4s, v25.4s\n" + "fmax v14.4s, v14.4s, v31.4s\n" + "fmax v12.4s, v12.4s, v31.4s\n" + "fmax v11.4s, v11.4s, v31.4s\n" + "fmax v13.4s, v13.4s, v31.4s\n" + "fmax v18.4s, v18.4s, v31.4s\n" + "fmax v27.4s, v27.4s, v31.4s\n" + "fmax v29.4s, v29.4s, v31.4s\n" + "fmax v9.4s, v9.4s, v31.4s\n" + "fmin v14.4s, v14.4s, v3.4s\n" + "fmin v12.4s, v12.4s, v3.4s\n" + "fmin v11.4s, v11.4s, v3.4s\n" + "fmin v13.4s, v13.4s, v3.4s\n" + "fmin v18.4s, v18.4s, v3.4s\n" + "fmin v27.4s, v27.4s, v3.4s\n" + "fmin v29.4s, v29.4s, v3.4s\n" + "fmin v9.4s, v9.4s, v3.4s\n" "blt 6f\n" "mov x20, %x[dst]\n" "cmp x27, #0x1\n" - "str q27, [x20, #0x0]\n" + "str q14, [x20, #0x0]\n" "str q12, [x20, #0x10]\n" "add x20, x20, %x[dst_stride_row]\n" "ble 11f\n" "cmp x27, #0x2\n" "str q11, [x20, #0x0]\n" - "str q15, [x20, #0x10]\n" + "str q13, [x20, #0x10]\n" "add x20, x20, %x[dst_stride_row]\n" "ble 11f\n" "cmp x27, #0x3\n" - "str q13, [x20, #0x0]\n" - "str q5, [x20, #0x10]\n" + "str q18, [x20, #0x0]\n" + "str q27, [x20, #0x10]\n" "add x20, x20, %x[dst_stride_row]\n" "ble 11f\n" - "str q8, [x20, #0x0]\n" - "str q1, [x20, #0x10]\n" + "str q29, [x20, #0x0]\n" + "str q9, [x20, #0x10]\n" "b 11f\n" "6:" // Partial output "mov x23, %x[dst]\n" "cmp x27, #0x1\n" "add x22, x23, %x[dst_stride_row]\n" - "csel x22, x22, x23, GE\n" + "csel x22, x22, x23, GT\n" "cmp x27, #0x2\n" "add x21, x23, %x[dst_stride_row], LSL #1\n" - "csel x21, x21, x22, GE\n" + "csel x21, x21, x22, GT\n" "cmp x27, #0x3\n" "add x20, x21, %x[dst_stride_row]\n" - "csel x20, x20, 
x21, GE\n" + "csel x20, x20, x21, GT\n" "tbz x25, #2, 8f\n" - "st1 { v8.4s }, [x20], #0x10\n" - "st1 { v13.4s }, [x21], #0x10\n" + "st1 { v29.4s }, [x20], #0x10\n" + "st1 { v18.4s }, [x21], #0x10\n" "st1 { v11.4s }, [x22], #0x10\n" - "st1 { v27.4s }, [x23], #0x10\n" + "st1 { v14.4s }, [x23], #0x10\n" "tbz x25, #1, 7f\n" - "st1 { v1.d }[0], [x20], #0x8\n" - "st1 { v5.d }[0], [x21], #0x8\n" - "st1 { v15.d }[0], [x22], #0x8\n" + "st1 { v9.d }[0], [x20], #0x8\n" + "st1 { v27.d }[0], [x21], #0x8\n" + "st1 { v13.d }[0], [x22], #0x8\n" "st1 { v12.d }[0], [x23], #0x8\n" "tbz x25, #0, 10f\n" - "st1 { v1.s }[2], [x20]\n" - "st1 { v5.s }[2], [x21]\n" - "st1 { v15.s }[2], [x22]\n" + "st1 { v9.s }[2], [x20]\n" + "st1 { v27.s }[2], [x21]\n" + "st1 { v13.s }[2], [x22]\n" "st1 { v12.s }[2], [x23]\n" "b 10f\n" "7:" // Output block 0: partial_1_4 "tbz x25, #0, 10f\n" - "st1 { v1.s }[0], [x20]\n" - "st1 { v5.s }[0], [x21]\n" - "st1 { v15.s }[0], [x22]\n" + "st1 { v9.s }[0], [x20]\n" + "st1 { v27.s }[0], [x21]\n" + "st1 { v13.s }[0], [x22]\n" "st1 { v12.s }[0], [x23]\n" "b 10f\n" "8:" // Output block 0: partial_2_0 "tbz x25, #1, 9f\n" - "st1 { v8.d }[0], [x20], #0x8\n" - "st1 { v13.d }[0], [x21], #0x8\n" + "st1 { v29.d }[0], [x20], #0x8\n" + "st1 { v18.d }[0], [x21], #0x8\n" "st1 { v11.d }[0], [x22], #0x8\n" - "st1 { v27.d }[0], [x23], #0x8\n" + "st1 { v14.d }[0], [x23], #0x8\n" "tbz x25, #0, 10f\n" - "st1 { v8.s }[2], [x20]\n" - "st1 { v13.s }[2], [x21]\n" + "st1 { v29.s }[2], [x20]\n" + "st1 { v18.s }[2], [x21]\n" "st1 { v11.s }[2], [x22]\n" - "st1 { v27.s }[2], [x23]\n" + "st1 { v14.s }[2], [x23]\n" "b 10f\n" "9:" // Output block 0: partial_1_0 - "st1 { v8.s }[0], [x20]\n" - "st1 { v13.s }[0], [x21]\n" + "st1 { v29.s }[0], [x20]\n" + "st1 { v18.s }[0], [x21]\n" "st1 { v11.s }[0], [x22]\n" - "st1 { v27.s }[0], [x23]\n" + "st1 { v14.s }[0], [x23]\n" "10:" // Output block 0: Done "11:" // Output stage exit "subs x25, x25, #0x8\n" diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c index 99366d64..b43cd2b3 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c @@ -35,12 +35,13 @@ inline static size_t kai_rhs_stride(size_t k, size_t bl) { return num_bytes_per_block * num_blocks_per_row; } -inline static size_t kai_rhs_packed_stride(size_t k, size_t nr, size_t kr, size_t bl, size_t num_bytes_multiplier_rhs) { +inline static size_t kai_rhs_packed_stride(size_t k, size_t nr, size_t kr, size_t bl, enum kai_datatype scale_dt) { KAI_ASSERT((k % 2) == 0); KAI_ASSERT((k % kr) == 0); KAI_ASSERT((k % bl) == 0); KAI_ASSERT((bl % kr) == 0); + const size_t num_bytes_multiplier_rhs = kai_num_bytes_datatype(scale_dt); const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); const size_t num_bytes_per_block = kai_num_bytes_per_block(bl, num_bytes_multiplier_rhs); @@ -74,8 +75,7 @@ size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( KAI_UNUSED(kr); - const size_t num_bytes_scale = kai_num_bytes_datatype(scale_dt); - return (n_idx / nr) * kai_rhs_packed_stride(k, nr, kr, bl, num_bytes_scale); + return (n_idx / nr) * kai_rhs_packed_stride(k, nr, kr, bl, scale_dt); } size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( @@ -88,7 +88,6 @@ size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( KAI_UNUSED(kr); - const size_t num_bytes_scale = kai_num_bytes_datatype(scale_dt); const size_t num_rows 
= n / nr;
 
     return num_rows * kai_rhs_packed_stride(k, nr, kr, bl, scale_dt);
 
@@ -122,7 +121,7 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
     const size_t num_bytes_multiplier_rhs = kai_num_bytes_datatype(params->scale_dt);
     const size_t rhs_zero_point = params->rhs_zero_point;
     const size_t rhs_stride = kai_rhs_stride(k, bl);
-    const size_t rhs_packed_stride = kai_rhs_packed_stride(k, nr, kr, bl, num_bytes_multiplier_rhs);
+    const size_t rhs_packed_stride = kai_rhs_packed_stride(k, nr, kr, bl, params->scale_dt);
     const size_t rhs_packed_offset_end_of_all_blocks =
         kai_rhs_packed_offset_end_of_all_blocks(k, nr, kr, bl, num_bytes_multiplier_rhs);
     const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl);
@@ -184,17 +183,6 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
             dst_row += num_bytes_per_segment * nr;
         }
 
-        for (size_t i = 0; i < nr; ++i) {
-            if (scale_dt == F32) {
-                ((float*)dst_scales)[i] *= 0.0625F;
-            } else if (scale_dt == F16) {
-                const float d = kai_f16_to_f32(((uint16_t*)dst_scales)[i]);
-                ((float*)dst_scales)[i] = kai_f32_to_f16(d * 0.0625F);
-            } else {
-                KAI_ERROR("Unsupported scale data type");
-            }
-        }
-
         dst_row += (num_bytes_multiplier_rhs * nr);
     }
 
diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.h b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.h
index 6f9291f5..9132d9be 100644
--- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.h
+++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.h
@@ -69,6 +69,7 @@ size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
 /// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel.
 /// @param[in] bl The block length, which defines the number of K values stored in a single block. It must be a multiple
 ///           of 32.
+/// @param[in] scale_dt Block scale data type.
 ///
 /// @return the packed RHS matrix size in bytes
 size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
@@ -99,6 +100,7 @@ size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
 /// @param[in] bias The biases.
 /// @param[in] scale The per-block quantization scales.
 ///                  The scale data type must be provided with the params object.
+///                  Supported scale data types are FP32, FP16 and BF16.
 /// @param[out] rhs_packed The packed RHS matrix.
 /// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix.
 /// @param[in] params Parameters for the micro-kernel.
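Why the pack-time scale adjustment above can be removed: the micro-kernels decode each 4-bit value into the upper nibble of a byte (the shl #4 / and 0xf0 pattern in the assembly), so every integer accumulator comes out 16x the true dot product. The deleted loop compensated by folding 1/16 into each block scale while packing; after this change the kernels apply the factor themselves, which is what the "mov x21, #0x3d800000" / "dup" pair added to the kernel diffs does (0x3D800000 is the IEEE-754 encoding of 0.0625f, i.e. 2^-4).

For reference, a minimal sketch in C of how a caller is expected to drive this packing routine now that scale_dt flows through the size query as well. It assumes the argument order used by the example program updated later in this series; pack_rhs_qsi4c32 and its parameters are illustrative names, and error handling is omitted.

#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>

#include "kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.h"

/* Illustrative sketch (not part of the patch): pack an n x k RHS matrix of
 * 4-bit values whose per-block scales are stored as f32. The nr/kr/sr values
 * must match the matmul micro-kernel that will consume the packed buffer. */
static void pack_rhs_qsi4c32(
    size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl,
    const uint8_t* rhs_qsu4,   /* raw RHS, 4-bit values, no scales interleaved */
    const void* rhs_scales) {  /* one scale per 32-value block, type given by scale_dt */
    struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0_params params;
    params.lhs_zero_point = 1;
    params.rhs_zero_point = 8;
    params.scale_dt = F32;     /* block-scale data type, as used elsewhere in this series */

    /* The size query now takes scale_dt so the packed stride can account for
     * the scale width instead of a raw byte count. */
    const size_t rhs_packed_size =
        kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(n, k, nr, kr, bl, params.scale_dt);
    uint8_t* rhs_packed = (uint8_t*)malloc(rhs_packed_size);

    kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
        1, n, k, nr, kr, sr,   /* num_groups = 1, assumed as in the example program */
        bl,                    /* block length */
        rhs_qsu4,              /* RHS */
        NULL,                  /* bias */
        rhs_scales,            /* per-block scales */
        rhs_packed,            /* RHS packed */
        0,                     /* extra_bytes */
        &params);

    /* ... hand rhs_packed to the matmul micro-kernel, then release it ... */
    free(rhs_packed);
}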
-- 
GitLab


From b545c6765d45f39084261ea1f0d4e1bedfba3ea2 Mon Sep 17 00:00:00 2001
From: Anitha Raj 
Date: Wed, 14 Aug 2024 17:00:33 +0100
Subject: [PATCH 08/29] Update the matmul ukernels to support Bf16 quantization scale parameters

Signed-off-by: Anitha Raj 
---
 .../matmul_clamp_f32_qai8dxp_qsi4c32p.cpp     |  30 +-
 kai/kai_common.h                              |   5 +-
 ...i8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c |   9 +-
 ...i8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c |  13 +-
 ..._qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c | 524 +++++++++---------
 ..._qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c | 295 +++++-----
 .../kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c  |  22 +-
 7 files changed, 455 insertions(+), 443 deletions(-)

diff --git a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp
index f73d92ec..a24535cf 100644
--- a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp
+++ b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp
@@ -94,6 +94,11 @@ static inline size_t num_bytes_per_block(size_t bl) {
     return (bl / 2) + sizeof(float);
 }
 
+static inline uint16_t kai_f32_to_bf16(float f32) {
+    const uint32_t* i32 = reinterpret_cast<const uint32_t*>(&f32);
+    uint16_t bf16 = (*i32 >> 16);
+    return bf16;
+}
 
 static void fill_uniform_random(size_t num_rows, size_t num_cols, float* dst, size_t seed) {
     std::srand(seed);
@@ -105,7 +110,7 @@ static void fill_uniform_random(size_t num_rows, size_t num_cols, float* dst, size_t seed) {
 
 static void quant_qs4c32_f32(
     size_t n, size_t k, size_t bl, const float* rhs_f32, uint8_t* rhs_qs4c32, uint8_t* rhs_with_no_scale_qs4c32,
-    float* rhs_scales_f32) {
+    uint16_t* rhs_scales_bf16) {
     const size_t num_blocks_row = num_blocks_per_row(k, bl);
     const size_t num_bytes_block = num_bytes_per_block(bl);
     const size_t dst_stride = num_blocks_row * num_bytes_block;
@@ -135,11 +140,12 @@ static void quant_qs4c32_f32(
         const float recip_scale = scale ?
1.0f / scale : 0.0f;
 
         // Store the scale at the beginning of the block
-        *((float*)dst_ptr) = scale;
-        *rhs_scales_f32 = scale;
+        uint16_t bf16_scale = kai_f32_to_bf16(scale);
+        *((float*)dst_ptr) = kai_bf16_to_f32(bf16_scale);
+        *rhs_scales_bf16 = bf16_scale;
         dst_ptr += sizeof(float);
 
-        rhs_scales_f32 += 1;
+        rhs_scales_bf16 += 1;
 
         const size_t block_size = 32;
         const size_t num_subblocks = bl / 32;
@@ -167,7 +173,7 @@ static void quant_qs4c32_f32(
             }
         }
     }
-};
+}
 
 static void ref_quant_qa8dx_f32(size_t m, size_t k, const float* lhs_f32, int8_t* lhs_qa8dx) {
     const size_t dst_stride = (k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t));
@@ -235,7 +241,7 @@ static void ref_quant_qa8dx_f32(size_t m, size_t k, const float* lhs_f32, int8_t* lhs_qa8dx) {
             dst_ptr += sizeof(int8_t);
         }
     }
-};
+}
 
 static void ref_matmul_f32_qa8dx_qs4c32(
     size_t m, size_t n, size_t k, size_t bl, const int8_t* lhs_qa8dx, const uint8_t* rhs_qs4c32, float* dst_f32,
@@ -343,21 +349,21 @@ int main(int argc, char** argv) {
     const size_t rhs_native_size_f32 = n * k * sizeof(float);
     const size_t rhs_native_size_qs4c32 = n * num_blocks_per_row * num_byte_per_block;
     const size_t rhs_native_with_no_scale_size_qs4c32 = n * num_blocks_per_row * (bl / 2);
-    const size_t rhs_scales_size_f32 = n * num_blocks_per_row * sizeof(float);
+    const size_t rhs_scales_size_bf16 = n * num_blocks_per_row * sizeof(uint16_t);
 
     // Allocate the memory
     uint8_t* lhs_native_mtx_f32 = new uint8_t[lhs_native_size_f32];
     uint8_t* rhs_native_mtx_f32 = new uint8_t[rhs_native_size_f32];
     uint8_t* rhs_native_mtx_qs4c32 = new uint8_t[rhs_native_size_qs4c32];
     uint8_t* rhs_native_with_no_scale_mtx_qs4c32 = new uint8_t[rhs_native_with_no_scale_size_qs4c32];
-    uint8_t* rhs_scales_mtx_f32 = new uint8_t[rhs_scales_size_f32];
+    uint8_t* rhs_scales_mtx_bf16 = new uint8_t[rhs_scales_size_bf16];
 
     fill_uniform_random(m, k, (float*)lhs_native_mtx_f32, seed_lhs);
     fill_uniform_random(n, k, (float*)rhs_native_mtx_f32, seed_rhs);
 
     quant_qs4c32_f32(
         n, k, bl, (const float*)rhs_native_mtx_f32, (uint8_t*)rhs_native_mtx_qs4c32,
-        rhs_native_with_no_scale_mtx_qs4c32, (float*)rhs_scales_mtx_f32);
+        rhs_native_with_no_scale_mtx_qs4c32, (uint16_t*)rhs_scales_mtx_bf16);
 
     delete[] rhs_native_mtx_f32;
 
@@ -403,7 +409,7 @@ int main(int argc, char** argv) {
         // Get the size in bytes for the packed matrices
         const size_t lhs_packed_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(m, k, mr, kr, sr);
         const size_t rhs_packed_size =
-            kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(n, k, nr, kr, bl, F32);
+            kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(n, k, nr, kr, bl, Bf16);
         const size_t dst_size = ukernel_variants[idx_variant].ukernel.get_dst_size(m, n);
 
         // Allocate the matrices
@@ -418,7 +424,7 @@ int main(int argc, char** argv) {
         struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0_params params;
         params.lhs_zero_point = 1;
         params.rhs_zero_point = 8;
-        params.scale_dt = F32;
+        params.scale_dt = Bf16;
 
         // RHS packing
         kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
@@ -427,7 +433,7 @@ int main(int argc, char** argv) {
             bl,                                                      // Block length
             (const uint8_t*)(rhs_native_with_no_scale_mtx_qs4c32),  // RHS
             NULL,                                                    // Bias
-            rhs_scales_mtx_f32,                                      // Scale
+            rhs_scales_mtx_bf16,                                     // Scale
            rhs_packed_mtx_qs4cx,                                    // RHS packed
             0, &params);
 
diff --git a/kai/kai_common.h b/kai/kai_common.h
index 58107ae1..790c2d26 100644
--- a/kai/kai_common.h
+++ b/kai/kai_common.h
@@ -91,7 +91,10 @@ inline static float kai_cast_f32_f16(uint16_t f16) {
 ///
 /// @return the f32 value
 inline static float kai_bf16_to_f32(uint16_t bf16) {
-    
return bf16 << 16; + const uint32_t i32 = (bf16 << 16); + float f32; + memcpy(&f32, &i32, sizeof(i32)); + return f32; } /// Converts a scalar f32 value to f16 diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c index 34161786..38d496a3 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c @@ -21,7 +21,7 @@ static const size_t kai_kr = 16; static const size_t kai_sr = 2; static const size_t kai_k0 = 32; static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); -static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); +static const size_t kai_num_bytes_multiplier_rhs = sizeof(uint16_t); static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); static const size_t kai_num_bytes_sum_rhs = sizeof(float); static const size_t kai_num_bytes_bias = sizeof(float); @@ -170,12 +170,13 @@ void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod( ".inst 0x4e92971d // sdot v29.4s, v24.16b, v18.16b\n" ".inst 0x4e9296fc // sdot v28.4s, v23.16b, v18.16b\n" "bgt 4b\n" - "ldr q16, [x25, #0x0]\n" + "ldr d16, [x25, #0x0]\n" "addp v29.4s, v29.4s, v28.4s\n" "sub x21, x21, #0x1\n" - "add x25, x25, #0x10\n" - "fmul v16.4s, v16.4s, v31.4s\n" + "add x25, x25, #0x8\n" + "shll v16.4s, v16.4h, #0x10\n" "scvtf v29.4s, v29.4s\n" + "fmul v16.4s, v16.4s, v31.4s\n" "fmla v30.4s, v29.4s, v16.4s\n" "cbnz x21, 3b\n" "ld1r { v21.4s }, [x22]\n" diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c index 92e44e02..46147992 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c @@ -21,7 +21,7 @@ static const size_t kai_kr = 16; static const size_t kai_sr = 2; static const size_t kai_k0 = 32; static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); -static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); +static const size_t kai_num_bytes_multiplier_rhs = sizeof(uint16_t); static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); static const size_t kai_num_bytes_sum_rhs = sizeof(float); static const size_t kai_num_bytes_bias = sizeof(float); @@ -197,16 +197,17 @@ void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod( ".inst 0x4e939722 // sdot v2.4s, v25.16b, v19.16b\n" ".inst 0x4e9396e1 // sdot v1.4s, v23.16b, v19.16b\n" "bgt 4b\n" - "ldr q17, [x25, #0x0]\n" - "ldr q16, [x25, #0x10]\n" + "ldr q16, [x25, #0x0]\n" "addp v4.4s, v4.4s, v3.4s\n" "addp v2.4s, v2.4s, v1.4s\n" "sub x21, x21, #0x1\n" - "add x25, x25, #0x20\n" - "fmul v17.4s, v17.4s, v7.4s\n" - "fmul v16.4s, v16.4s, v7.4s\n" + "add x25, x25, #0x10\n" + "shll v17.4s, v16.4h, #0x10\n" + "shll2 v16.4s, v16.8h, #0x10\n" "scvtf v4.4s, v4.4s\n" "scvtf v2.4s, v2.4s\n" + "fmul v17.4s, v17.4s, v7.4s\n" + "fmul v16.4s, v16.4s, v7.4s\n" "fmla v6.4s, v4.4s, v17.4s\n" "fmla v5.4s, v2.4s, v16.4s\n" "cbnz x21, 
3b\n" diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c index 76b4d1da..91b4f4f2 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c @@ -21,7 +21,7 @@ static const size_t kai_kr = 16; static const size_t kai_sr = 2; static const size_t kai_k0 = 32; static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); -static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); +static const size_t kai_num_bytes_multiplier_rhs = sizeof(uint16_t); static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); static const size_t kai_num_bytes_sum_rhs = sizeof(float); static const size_t kai_num_bytes_bias = sizeof(float); @@ -128,12 +128,12 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( __asm__ __volatile__( "mov x12, #0x80\n" "mov x11, %x[m]\n" - "movi v14.16b, #0xf0\n" + "movi v15.16b, #0xf0\n" "mov x21, #0x3d800000\n" "mov x20, #0x20\n" "mul x12, %x[num_subblocks], x12\n" "cmp x11, #0x8\n" - "dup v20.4s, w21\n" + "dup v24.4s, w21\n" "madd x12, %x[num_blocks], x12, x20\n" "blt 11f\n" "1:" // Row loop @@ -142,193 +142,194 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( "add x28, %x[dst], %x[dst_stride_row], LSL #3\n" "2:" // Column loop "mov x23, %x[lhs_packed]\n" + "movi v12.16b, #0x0\n" "movi v13.16b, #0x0\n" - "movi v23.16b, #0x0\n" "mov x22, %x[num_blocks]\n" - "movi v29.16b, #0x0\n" - "movi v15.16b, #0x0\n" + "movi v22.16b, #0x0\n" + "movi v14.16b, #0x0\n" "movi v5.16b, #0x0\n" - "movi v3.16b, #0x0\n" - "movi v9.16b, #0x0\n" - "movi v31.16b, #0x0\n" + "movi v0.16b, #0x0\n" + "movi v30.16b, #0x0\n" + "movi v8.16b, #0x0\n" "add x21, x23, x12\n" "3:" // Block loop - "movi v1.4s, #0x0\n" - "movi v4.4s, #0x0\n" - "mov x20, %x[num_subblocks]\n" - "movi v10.4s, #0x0\n" - "movi v24.4s, #0x0\n" - "movi v16.4s, #0x0\n" "movi v6.4s, #0x0\n" - "movi v12.4s, #0x0\n" - "movi v0.4s, #0x0\n" + "movi v10.4s, #0x0\n" + "mov x20, %x[num_subblocks]\n" + "movi v4.4s, #0x0\n" + "movi v18.4s, #0x0\n" + "movi v31.4s, #0x0\n" + "movi v3.4s, #0x0\n" + "movi v7.4s, #0x0\n" + "movi v23.4s, #0x0\n" "4:" // Sub block loop "ldr q2, [x10, #0x0]\n" - "ldr q17, [x10, #0x10]\n" + "ldr q20, [x10, #0x10]\n" "subs x20, x20, #0x1\n" - "ldr q26, [x23, #0x0]\n" - "ldr q21, [x23, #0x10]\n" - "ldr q19, [x21, #0x0]\n" - "ldr q18, [x21, #0x10]\n" - "ldr q7, [x10, #0x20]\n" - "ldr q28, [x10, #0x30]\n" - "shl v25.16b, v2.16b, #0x4\n" - "shl v11.16b, v17.16b, #0x4\n" - "ldr q30, [x23, #0x20]\n" - "ldr q22, [x23, #0x30]\n" - "and v2.16b, v2.16b, v14.16b\n" - "and v17.16b, v17.16b, v14.16b\n" - "ldr q8, [x21, #0x20]\n" - "ldr q27, [x21, #0x30]\n" + "ldr q25, [x23, #0x0]\n" + "ldr q11, [x23, #0x10]\n" + "ldr q9, [x21, #0x0]\n" + "ldr q19, [x21, #0x10]\n" + "ldr q1, [x10, #0x20]\n" + "ldr q29, [x10, #0x30]\n" + "shl v27.16b, v2.16b, #0x4\n" + "shl v21.16b, v20.16b, #0x4\n" + "ldr q17, [x23, #0x20]\n" + "ldr q26, [x23, #0x30]\n" + "and v2.16b, v2.16b, v15.16b\n" + "and v20.16b, v20.16b, v15.16b\n" + "ldr q28, [x21, #0x20]\n" + "ldr q16, [x21, #0x30]\n" "add x10, x10, #0x40\n" - ".inst 0x4e99a741 // smmla v1.4s, v26.16b, v25.16b\n" - ".inst 0x4e8ba744 // smmla v4.4s, v26.16b, v11.16b\n" - 
"ldr q26, [x23, #0x40]\n" - ".inst 0x4e99a6aa // smmla v10.4s, v21.16b, v25.16b\n" - ".inst 0x4e8ba6b8 // smmla v24.4s, v21.16b, v11.16b\n" - "ldr q21, [x23, #0x50]\n" - ".inst 0x4e99a670 // smmla v16.4s, v19.16b, v25.16b\n" - ".inst 0x4e8ba666 // smmla v6.4s, v19.16b, v11.16b\n" - "ldr q19, [x21, #0x40]\n" - ".inst 0x4e99a64c // smmla v12.4s, v18.16b, v25.16b\n" - "ldr q25, [x21, #0x50]\n" - ".inst 0x4e8ba640 // smmla v0.4s, v18.16b, v11.16b\n" - "ldr q11, [x23, #0x60]\n" - "shl v18.16b, v7.16b, #0x4\n" - "and v7.16b, v7.16b, v14.16b\n" - ".inst 0x4e92a7c1 // smmla v1.4s, v30.16b, v18.16b\n" - ".inst 0x4e92a6ca // smmla v10.4s, v22.16b, v18.16b\n" - ".inst 0x4e92a510 // smmla v16.4s, v8.16b, v18.16b\n" - ".inst 0x4e92a76c // smmla v12.4s, v27.16b, v18.16b\n" - "ldr q18, [x23, #0x70]\n" + ".inst 0x4e9ba726 // smmla v6.4s, v25.16b, v27.16b\n" + ".inst 0x4e95a72a // smmla v10.4s, v25.16b, v21.16b\n" + "ldr q25, [x23, #0x40]\n" + ".inst 0x4e9ba564 // smmla v4.4s, v11.16b, v27.16b\n" + ".inst 0x4e95a572 // smmla v18.4s, v11.16b, v21.16b\n" + "ldr q11, [x23, #0x50]\n" + ".inst 0x4e9ba53f // smmla v31.4s, v9.16b, v27.16b\n" + ".inst 0x4e95a523 // smmla v3.4s, v9.16b, v21.16b\n" + "ldr q9, [x21, #0x40]\n" + ".inst 0x4e9ba667 // smmla v7.4s, v19.16b, v27.16b\n" + "ldr q27, [x21, #0x50]\n" + ".inst 0x4e95a677 // smmla v23.4s, v19.16b, v21.16b\n" + "ldr q21, [x23, #0x60]\n" + "shl v19.16b, v1.16b, #0x4\n" + "and v1.16b, v1.16b, v15.16b\n" + ".inst 0x4e93a626 // smmla v6.4s, v17.16b, v19.16b\n" + ".inst 0x4e93a744 // smmla v4.4s, v26.16b, v19.16b\n" + ".inst 0x4e93a79f // smmla v31.4s, v28.16b, v19.16b\n" + ".inst 0x4e93a607 // smmla v7.4s, v16.16b, v19.16b\n" + "ldr q19, [x23, #0x70]\n" "add x23, x23, #0x80\n" - ".inst 0x4e82a741 // smmla v1.4s, v26.16b, v2.16b\n" - ".inst 0x4e82a6aa // smmla v10.4s, v21.16b, v2.16b\n" - ".inst 0x4e82a670 // smmla v16.4s, v19.16b, v2.16b\n" - ".inst 0x4e82a72c // smmla v12.4s, v25.16b, v2.16b\n" - "shl v2.16b, v28.16b, #0x4\n" - "and v28.16b, v28.16b, v14.16b\n" - ".inst 0x4e82a7c4 // smmla v4.4s, v30.16b, v2.16b\n" - "ldr q30, [x21, #0x60]\n" - ".inst 0x4e82a6d8 // smmla v24.4s, v22.16b, v2.16b\n" - "ldr q22, [x21, #0x70]\n" + ".inst 0x4e82a726 // smmla v6.4s, v25.16b, v2.16b\n" + ".inst 0x4e82a564 // smmla v4.4s, v11.16b, v2.16b\n" + ".inst 0x4e82a53f // smmla v31.4s, v9.16b, v2.16b\n" + ".inst 0x4e82a767 // smmla v7.4s, v27.16b, v2.16b\n" + "shl v2.16b, v29.16b, #0x4\n" + "and v29.16b, v29.16b, v15.16b\n" + ".inst 0x4e82a62a // smmla v10.4s, v17.16b, v2.16b\n" + "ldr q17, [x21, #0x60]\n" + ".inst 0x4e82a752 // smmla v18.4s, v26.16b, v2.16b\n" + "ldr q26, [x21, #0x70]\n" "add x21, x21, #0x80\n" - ".inst 0x4e82a506 // smmla v6.4s, v8.16b, v2.16b\n" - ".inst 0x4e82a760 // smmla v0.4s, v27.16b, v2.16b\n" - ".inst 0x4e87a561 // smmla v1.4s, v11.16b, v7.16b\n" - ".inst 0x4e87a64a // smmla v10.4s, v18.16b, v7.16b\n" - ".inst 0x4e87a7d0 // smmla v16.4s, v30.16b, v7.16b\n" - ".inst 0x4e91a744 // smmla v4.4s, v26.16b, v17.16b\n" - ".inst 0x4e91a6b8 // smmla v24.4s, v21.16b, v17.16b\n" - ".inst 0x4e87a6cc // smmla v12.4s, v22.16b, v7.16b\n" - ".inst 0x4e91a666 // smmla v6.4s, v19.16b, v17.16b\n" - ".inst 0x4e91a720 // smmla v0.4s, v25.16b, v17.16b\n" - ".inst 0x4e9ca564 // smmla v4.4s, v11.16b, v28.16b\n" - ".inst 0x4e9ca658 // smmla v24.4s, v18.16b, v28.16b\n" - ".inst 0x4e9ca7c6 // smmla v6.4s, v30.16b, v28.16b\n" - ".inst 0x4e9ca6c0 // smmla v0.4s, v22.16b, v28.16b\n" + ".inst 0x4e82a783 // smmla v3.4s, v28.16b, v2.16b\n" + ".inst 0x4e82a617 // smmla v23.4s, v16.16b, 
v2.16b\n" + ".inst 0x4e81a6a6 // smmla v6.4s, v21.16b, v1.16b\n" + ".inst 0x4e81a664 // smmla v4.4s, v19.16b, v1.16b\n" + ".inst 0x4e81a63f // smmla v31.4s, v17.16b, v1.16b\n" + ".inst 0x4e94a72a // smmla v10.4s, v25.16b, v20.16b\n" + ".inst 0x4e94a572 // smmla v18.4s, v11.16b, v20.16b\n" + ".inst 0x4e81a747 // smmla v7.4s, v26.16b, v1.16b\n" + ".inst 0x4e94a523 // smmla v3.4s, v9.16b, v20.16b\n" + ".inst 0x4e94a777 // smmla v23.4s, v27.16b, v20.16b\n" + ".inst 0x4e9da6aa // smmla v10.4s, v21.16b, v29.16b\n" + ".inst 0x4e9da672 // smmla v18.4s, v19.16b, v29.16b\n" + ".inst 0x4e9da623 // smmla v3.4s, v17.16b, v29.16b\n" + ".inst 0x4e9da757 // smmla v23.4s, v26.16b, v29.16b\n" "bgt 4b\n" - "ldr q19, [x10, #0x0]\n" - "uzp1 v2.2d, v1.2d, v4.2d\n" - "uzp2 v28.2d, v1.2d, v4.2d\n" - "add x10, x10, #0x10\n" - "uzp1 v17.2d, v10.2d, v24.2d\n" - "uzp2 v4.2d, v10.2d, v24.2d\n" - "fmul v19.4s, v19.4s, v20.4s\n" - "scvtf v2.4s, v2.4s\n" - "scvtf v28.4s, v28.4s\n" + "ldr d20, [x10, #0x0]\n" + "uzp1 v21.2d, v6.2d, v10.2d\n" + "uzp2 v19.2d, v6.2d, v10.2d\n" + "add x10, x10, #0x8\n" + "uzp1 v17.2d, v4.2d, v18.2d\n" + "uzp2 v16.2d, v4.2d, v18.2d\n" + "shll v20.4s, v20.4h, #0x10\n" + "scvtf v21.4s, v21.4s\n" + "scvtf v19.4s, v19.4s\n" "scvtf v17.4s, v17.4s\n" - "scvtf v4.4s, v4.4s\n" - "fmla v13.4s, v2.4s, v19.4s\n" - "fmla v23.4s, v28.4s, v19.4s\n" - "fmla v29.4s, v17.4s, v19.4s\n" - "fmla v15.4s, v4.4s, v19.4s\n" - "uzp1 v4.2d, v16.2d, v6.2d\n" - "uzp2 v6.2d, v16.2d, v6.2d\n" - "uzp1 v17.2d, v12.2d, v0.2d\n" - "uzp2 v24.2d, v12.2d, v0.2d\n" - "scvtf v4.4s, v4.4s\n" - "scvtf v6.4s, v6.4s\n" + "scvtf v16.4s, v16.4s\n" + "fmul v20.4s, v20.4s, v24.4s\n" + "fmla v12.4s, v21.4s, v20.4s\n" + "fmla v13.4s, v19.4s, v20.4s\n" + "fmla v22.4s, v17.4s, v20.4s\n" + "fmla v14.4s, v16.4s, v20.4s\n" + "uzp1 v19.2d, v31.2d, v3.2d\n" + "uzp2 v18.2d, v31.2d, v3.2d\n" + "uzp1 v17.2d, v7.2d, v23.2d\n" + "uzp2 v16.2d, v7.2d, v23.2d\n" + "scvtf v19.4s, v19.4s\n" + "scvtf v18.4s, v18.4s\n" "scvtf v17.4s, v17.4s\n" - "scvtf v24.4s, v24.4s\n" - "fmla v5.4s, v4.4s, v19.4s\n" - "fmla v3.4s, v6.4s, v19.4s\n" - "fmla v9.4s, v17.4s, v19.4s\n" - "fmla v31.4s, v24.4s, v19.4s\n" + "scvtf v16.4s, v16.4s\n" + "fmla v5.4s, v19.4s, v20.4s\n" + "fmla v0.4s, v18.4s, v20.4s\n" + "fmla v30.4s, v17.4s, v20.4s\n" + "fmla v8.4s, v16.4s, v20.4s\n" "subs x22, x22, #0x1\n" "bgt 3b\n" - "ld1 { v24.4s }, [x23]\n" - "ld1 { v22.4s }, [x21]\n" + "ld1 { v23.4s }, [x23]\n" + "ld1 { v1.4s }, [x21]\n" "add x23, x23, #0x10\n" "add x21, x21, #0x10\n" "ldr q21, [x10, #0x0]\n" - "ldr q11, [x23, #0x0]\n" + "ldr q20, [x23, #0x0]\n" "add x20, %x[clamp_vals], #0x4\n" "cmp x9, #0x4\n" "ldr q19, [x21, #0x0]\n" - "ldr q0, [x10, #0x10]\n" + "ldr q18, [x10, #0x10]\n" "add x10, x10, #0x20\n" "ld1r { v17.4s }, [%x[clamp_vals]]\n" - "ld1r { v4.4s }, [x20]\n" - "scvtf v24.4s, v24.4s\n" - "scvtf v22.4s, v22.4s\n" - "fmla v13.4s, v21.4s, v24.s[0]\n" - "fmla v23.4s, v21.4s, v24.s[1]\n" - "fmla v29.4s, v21.4s, v24.s[2]\n" - "fmla v15.4s, v21.4s, v24.s[3]\n" - "fmla v5.4s, v21.4s, v22.s[0]\n" - "fmla v3.4s, v21.4s, v22.s[1]\n" - "fmla v9.4s, v21.4s, v22.s[2]\n" - "fmla v31.4s, v21.4s, v22.s[3]\n" - "fmul v13.4s, v13.4s, v11.s[0]\n" - "fmul v23.4s, v23.4s, v11.s[1]\n" - "fmul v29.4s, v29.4s, v11.s[2]\n" - "fmul v15.4s, v15.4s, v11.s[3]\n" + "ld1r { v16.4s }, [x20]\n" + "scvtf v23.4s, v23.4s\n" + "scvtf v1.4s, v1.4s\n" + "fmla v12.4s, v21.4s, v23.s[0]\n" + "fmla v13.4s, v21.4s, v23.s[1]\n" + "fmla v22.4s, v21.4s, v23.s[2]\n" + "fmla v14.4s, v21.4s, v23.s[3]\n" + "fmla v5.4s, v21.4s, 
v1.s[0]\n" + "fmla v0.4s, v21.4s, v1.s[1]\n" + "fmla v30.4s, v21.4s, v1.s[2]\n" + "fmla v8.4s, v21.4s, v1.s[3]\n" + "fmul v12.4s, v12.4s, v20.s[0]\n" + "fmul v13.4s, v13.4s, v20.s[1]\n" + "fmul v22.4s, v22.4s, v20.s[2]\n" + "fmul v14.4s, v14.4s, v20.s[3]\n" "fmul v5.4s, v5.4s, v19.s[0]\n" - "fmul v3.4s, v3.4s, v19.s[1]\n" - "fadd v13.4s, v13.4s, v0.4s\n" - "fmul v9.4s, v9.4s, v19.s[2]\n" - "fmul v31.4s, v31.4s, v19.s[3]\n" - "fadd v23.4s, v23.4s, v0.4s\n" - "fadd v29.4s, v29.4s, v0.4s\n" - "fadd v15.4s, v15.4s, v0.4s\n" - "fadd v5.4s, v5.4s, v0.4s\n" - "fadd v3.4s, v3.4s, v0.4s\n" - "fadd v9.4s, v9.4s, v0.4s\n" - "fadd v31.4s, v31.4s, v0.4s\n" + "fmul v0.4s, v0.4s, v19.s[1]\n" + "fadd v12.4s, v12.4s, v18.4s\n" + "fmul v30.4s, v30.4s, v19.s[2]\n" + "fmul v8.4s, v8.4s, v19.s[3]\n" + "fadd v13.4s, v13.4s, v18.4s\n" + "fadd v22.4s, v22.4s, v18.4s\n" + "fadd v14.4s, v14.4s, v18.4s\n" + "fadd v5.4s, v5.4s, v18.4s\n" + "fadd v0.4s, v0.4s, v18.4s\n" + "fadd v30.4s, v30.4s, v18.4s\n" + "fadd v8.4s, v8.4s, v18.4s\n" + "fmax v12.4s, v12.4s, v17.4s\n" "fmax v13.4s, v13.4s, v17.4s\n" - "fmax v23.4s, v23.4s, v17.4s\n" - "fmax v29.4s, v29.4s, v17.4s\n" - "fmax v15.4s, v15.4s, v17.4s\n" + "fmax v22.4s, v22.4s, v17.4s\n" + "fmax v14.4s, v14.4s, v17.4s\n" "fmax v5.4s, v5.4s, v17.4s\n" - "fmax v3.4s, v3.4s, v17.4s\n" - "fmax v9.4s, v9.4s, v17.4s\n" - "fmax v31.4s, v31.4s, v17.4s\n" - "fmin v13.4s, v13.4s, v4.4s\n" - "fmin v23.4s, v23.4s, v4.4s\n" - "fmin v29.4s, v29.4s, v4.4s\n" - "fmin v15.4s, v15.4s, v4.4s\n" - "fmin v5.4s, v5.4s, v4.4s\n" - "fmin v3.4s, v3.4s, v4.4s\n" - "fmin v9.4s, v9.4s, v4.4s\n" - "fmin v31.4s, v31.4s, v4.4s\n" + "fmax v0.4s, v0.4s, v17.4s\n" + "fmax v30.4s, v30.4s, v17.4s\n" + "fmax v8.4s, v8.4s, v17.4s\n" + "fmin v12.4s, v12.4s, v16.4s\n" + "fmin v13.4s, v13.4s, v16.4s\n" + "fmin v22.4s, v22.4s, v16.4s\n" + "fmin v14.4s, v14.4s, v16.4s\n" + "fmin v5.4s, v5.4s, v16.4s\n" + "fmin v0.4s, v0.4s, v16.4s\n" + "fmin v30.4s, v30.4s, v16.4s\n" + "fmin v8.4s, v8.4s, v16.4s\n" "blt 7f\n" "mov x20, %x[dst]\n" - "str q13, [x20, #0x0]\n" + "str q12, [x20, #0x0]\n" "add x20, x20, %x[dst_stride_row]\n" - "str q23, [x20, #0x0]\n" + "str q13, [x20, #0x0]\n" "add x20, x20, %x[dst_stride_row]\n" - "str q29, [x20, #0x0]\n" + "str q22, [x20, #0x0]\n" "add x20, x20, %x[dst_stride_row]\n" - "str q15, [x20, #0x0]\n" + "str q14, [x20, #0x0]\n" "add x20, x20, %x[dst_stride_row]\n" "str q5, [x20, #0x0]\n" "add x20, x20, %x[dst_stride_row]\n" - "str q3, [x20, #0x0]\n" + "str q0, [x20, #0x0]\n" "add x20, x20, %x[dst_stride_row]\n" - "str q9, [x20, #0x0]\n" + "str q30, [x20, #0x0]\n" "add x20, x20, %x[dst_stride_row]\n" - "str q31, [x20, #0x0]\n" + "str q8, [x20, #0x0]\n" "b 10f\n" "7:" // Partial output "mov x27, %x[dst]\n" @@ -340,33 +341,33 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( "add x21, x27, %x[dst_stride_row]\n" "add x20, x22, %x[dst_stride_row]\n" "tbz x9, #1, 8f\n" - "st1 { v31.d }[0], [x23], #0x8\n" - "st1 { v9.d }[0], [x25], #0x8\n" - "st1 { v3.d }[0], [x24], #0x8\n" + "st1 { v8.d }[0], [x23], #0x8\n" + "st1 { v30.d }[0], [x25], #0x8\n" + "st1 { v0.d }[0], [x24], #0x8\n" "st1 { v5.d }[0], [x26], #0x8\n" - "st1 { v15.d }[0], [x20], #0x8\n" - "st1 { v29.d }[0], [x22], #0x8\n" - "st1 { v23.d }[0], [x21], #0x8\n" - "st1 { v13.d }[0], [x27], #0x8\n" + "st1 { v14.d }[0], [x20], #0x8\n" + "st1 { v22.d }[0], [x22], #0x8\n" + "st1 { v13.d }[0], [x21], #0x8\n" + "st1 { v12.d }[0], [x27], #0x8\n" "tbz x9, #0, 9f\n" - "st1 { v31.s }[2], [x23]\n" - "st1 { v9.s }[2], [x25]\n" - "st1 
{ v3.s }[2], [x24]\n" + "st1 { v8.s }[2], [x23]\n" + "st1 { v30.s }[2], [x25]\n" + "st1 { v0.s }[2], [x24]\n" "st1 { v5.s }[2], [x26]\n" - "st1 { v15.s }[2], [x20]\n" - "st1 { v29.s }[2], [x22]\n" - "st1 { v23.s }[2], [x21]\n" - "st1 { v13.s }[2], [x27]\n" + "st1 { v14.s }[2], [x20]\n" + "st1 { v22.s }[2], [x22]\n" + "st1 { v13.s }[2], [x21]\n" + "st1 { v12.s }[2], [x27]\n" "b 9f\n" "8:" // Output block 0: partial_1_0 - "st1 { v31.s }[0], [x23]\n" - "st1 { v9.s }[0], [x25]\n" - "st1 { v3.s }[0], [x24]\n" + "st1 { v8.s }[0], [x23]\n" + "st1 { v30.s }[0], [x25]\n" + "st1 { v0.s }[0], [x24]\n" "st1 { v5.s }[0], [x26]\n" - "st1 { v15.s }[0], [x20]\n" - "st1 { v29.s }[0], [x22]\n" - "st1 { v23.s }[0], [x21]\n" - "st1 { v13.s }[0], [x27]\n" + "st1 { v14.s }[0], [x20]\n" + "st1 { v22.s }[0], [x22]\n" + "st1 { v13.s }[0], [x21]\n" + "st1 { v12.s }[0], [x27]\n" "9:" // Output block 0: Done "10:" // Output stage exit "subs x9, x9, #0x4\n" @@ -385,78 +386,79 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( "mov x25, %x[n]\n" "add x24, %x[dst], %x[dst_stride_row], LSL #2\n" "13:" // Row tail: Column loop + "movi v12.16b, #0x0\n" "movi v13.16b, #0x0\n" - "movi v23.16b, #0x0\n" "mov x23, %x[lhs_packed]\n" "mov x21, %x[num_blocks]\n" - "movi v29.16b, #0x0\n" - "movi v15.16b, #0x0\n" + "movi v22.16b, #0x0\n" + "movi v14.16b, #0x0\n" "14:" // Row tail: Block loop - "movi v1.4s, #0x0\n" - "movi v4.4s, #0x0\n" - "mov x20, %x[num_subblocks]\n" + "movi v6.4s, #0x0\n" "movi v10.4s, #0x0\n" - "movi v24.4s, #0x0\n" + "mov x20, %x[num_subblocks]\n" + "movi v4.4s, #0x0\n" + "movi v18.4s, #0x0\n" "15:" // Row tail: Sub block loop "ldr q0, [x26, #0x0]\n" "ldr q31, [x26, #0x10]\n" "subs x20, x20, #0x1\n" - "ldr q30, [x23, #0x0]\n" - "ldr q3, [x23, #0x10]\n" - "ldr q8, [x26, #0x20]\n" - "ldr q27, [x26, #0x30]\n" + "ldr q11, [x23, #0x0]\n" + "ldr q30, [x23, #0x10]\n" + "ldr q29, [x26, #0x20]\n" + "ldr q28, [x26, #0x30]\n" "add x26, x26, #0x40\n" - "ldr q26, [x23, #0x20]\n" - "ldr q25, [x23, #0x30]\n" - "shl v22.16b, v0.16b, #0x4\n" - "shl v16.16b, v31.16b, #0x4\n" - "ldr q21, [x23, #0x40]\n" - "ldr q7, [x23, #0x50]\n" - "and v0.16b, v0.16b, v14.16b\n" - "and v31.16b, v31.16b, v14.16b\n" - "ldr q19, [x23, #0x60]\n" - "ldr q18, [x23, #0x70]\n" - "shl v17.16b, v8.16b, #0x4\n" - "shl v2.16b, v27.16b, #0x4\n" - ".inst 0x4e96a7c1 // smmla v1.4s, v30.16b, v22.16b\n" - ".inst 0x4e90a7c4 // smmla v4.4s, v30.16b, v16.16b\n" - "and v8.16b, v8.16b, v14.16b\n" + "ldr q27, [x23, #0x20]\n" + "ldr q26, [x23, #0x30]\n" + "shl v25.16b, v0.16b, #0x4\n" + "shl v23.16b, v31.16b, #0x4\n" + "ldr q1, [x23, #0x40]\n" + "ldr q21, [x23, #0x50]\n" + "and v0.16b, v0.16b, v15.16b\n" + "and v31.16b, v31.16b, v15.16b\n" + "ldr q20, [x23, #0x60]\n" + "ldr q19, [x23, #0x70]\n" + "shl v17.16b, v29.16b, #0x4\n" + "shl v16.16b, v28.16b, #0x4\n" + ".inst 0x4e99a566 // smmla v6.4s, v11.16b, v25.16b\n" + ".inst 0x4e97a56a // smmla v10.4s, v11.16b, v23.16b\n" + "and v29.16b, v29.16b, v15.16b\n" "add x23, x23, #0x80\n" - ".inst 0x4e96a46a // smmla v10.4s, v3.16b, v22.16b\n" - ".inst 0x4e90a478 // smmla v24.4s, v3.16b, v16.16b\n" - "and v27.16b, v27.16b, v14.16b\n" - ".inst 0x4e91a741 // smmla v1.4s, v26.16b, v17.16b\n" - ".inst 0x4e82a744 // smmla v4.4s, v26.16b, v2.16b\n" - ".inst 0x4e91a72a // smmla v10.4s, v25.16b, v17.16b\n" - ".inst 0x4e82a738 // smmla v24.4s, v25.16b, v2.16b\n" - ".inst 0x4e80a6a1 // smmla v1.4s, v21.16b, v0.16b\n" - ".inst 0x4e9fa6a4 // smmla v4.4s, v21.16b, v31.16b\n" - ".inst 0x4e80a4ea // smmla v10.4s, v7.16b, 
v0.16b\n" - ".inst 0x4e9fa4f8 // smmla v24.4s, v7.16b, v31.16b\n" - ".inst 0x4e88a661 // smmla v1.4s, v19.16b, v8.16b\n" - ".inst 0x4e9ba664 // smmla v4.4s, v19.16b, v27.16b\n" - ".inst 0x4e88a64a // smmla v10.4s, v18.16b, v8.16b\n" - ".inst 0x4e9ba658 // smmla v24.4s, v18.16b, v27.16b\n" + ".inst 0x4e99a7c4 // smmla v4.4s, v30.16b, v25.16b\n" + ".inst 0x4e97a7d2 // smmla v18.4s, v30.16b, v23.16b\n" + "and v28.16b, v28.16b, v15.16b\n" + ".inst 0x4e91a766 // smmla v6.4s, v27.16b, v17.16b\n" + ".inst 0x4e90a76a // smmla v10.4s, v27.16b, v16.16b\n" + ".inst 0x4e91a744 // smmla v4.4s, v26.16b, v17.16b\n" + ".inst 0x4e90a752 // smmla v18.4s, v26.16b, v16.16b\n" + ".inst 0x4e80a426 // smmla v6.4s, v1.16b, v0.16b\n" + ".inst 0x4e9fa42a // smmla v10.4s, v1.16b, v31.16b\n" + ".inst 0x4e80a6a4 // smmla v4.4s, v21.16b, v0.16b\n" + ".inst 0x4e9fa6b2 // smmla v18.4s, v21.16b, v31.16b\n" + ".inst 0x4e9da686 // smmla v6.4s, v20.16b, v29.16b\n" + ".inst 0x4e9ca68a // smmla v10.4s, v20.16b, v28.16b\n" + ".inst 0x4e9da664 // smmla v4.4s, v19.16b, v29.16b\n" + ".inst 0x4e9ca672 // smmla v18.4s, v19.16b, v28.16b\n" "bgt 15b\n" - "ldr q7, [x26, #0x0]\n" - "uzp1 v19.2d, v1.2d, v4.2d\n" - "uzp2 v18.2d, v1.2d, v4.2d\n" - "add x26, x26, #0x10\n" - "uzp1 v17.2d, v10.2d, v24.2d\n" - "uzp2 v16.2d, v10.2d, v24.2d\n" - "fmul v7.4s, v7.4s, v20.4s\n" + "ldr d16, [x26, #0x0]\n" + "uzp1 v21.2d, v6.2d, v10.2d\n" + "uzp2 v20.2d, v6.2d, v10.2d\n" + "add x26, x26, #0x8\n" + "uzp1 v19.2d, v4.2d, v18.2d\n" + "uzp2 v17.2d, v4.2d, v18.2d\n" + "shll v16.4s, v16.4h, #0x10\n" + "scvtf v21.4s, v21.4s\n" + "scvtf v20.4s, v20.4s\n" "scvtf v19.4s, v19.4s\n" - "scvtf v18.4s, v18.4s\n" "scvtf v17.4s, v17.4s\n" - "scvtf v16.4s, v16.4s\n" - "fmla v13.4s, v19.4s, v7.4s\n" - "fmla v23.4s, v18.4s, v7.4s\n" - "fmla v29.4s, v17.4s, v7.4s\n" - "fmla v15.4s, v16.4s, v7.4s\n" + "fmul v16.4s, v16.4s, v24.4s\n" + "fmla v12.4s, v21.4s, v16.4s\n" + "fmla v13.4s, v20.4s, v16.4s\n" + "fmla v22.4s, v19.4s, v16.4s\n" + "fmla v14.4s, v17.4s, v16.4s\n" "subs x21, x21, #0x1\n" "bgt 14b\n" "ld1 { v21.4s }, [x23]\n" - "ldr q10, [x26, #0x0]\n" + "ldr q20, [x26, #0x0]\n" "add x23, x23, #0x10\n" "add x20, %x[clamp_vals], #0x4\n" "ldr q19, [x23, #0x0]\n" @@ -466,41 +468,41 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( "ld1r { v17.4s }, [%x[clamp_vals]]\n" "ld1r { v16.4s }, [x20]\n" "scvtf v21.4s, v21.4s\n" - "fmla v13.4s, v10.4s, v21.s[0]\n" - "fmla v23.4s, v10.4s, v21.s[1]\n" - "fmla v29.4s, v10.4s, v21.s[2]\n" - "fmla v15.4s, v10.4s, v21.s[3]\n" - "fmul v13.4s, v13.4s, v19.s[0]\n" - "fmul v23.4s, v23.4s, v19.s[1]\n" - "fmul v29.4s, v29.4s, v19.s[2]\n" + "fmla v12.4s, v20.4s, v21.s[0]\n" + "fmla v13.4s, v20.4s, v21.s[1]\n" + "fmla v22.4s, v20.4s, v21.s[2]\n" + "fmla v14.4s, v20.4s, v21.s[3]\n" + "fmul v12.4s, v12.4s, v19.s[0]\n" + "fmul v13.4s, v13.4s, v19.s[1]\n" + "fmul v22.4s, v22.4s, v19.s[2]\n" + "fadd v12.4s, v12.4s, v18.4s\n" + "fmul v14.4s, v14.4s, v19.s[3]\n" "fadd v13.4s, v13.4s, v18.4s\n" - "fmul v15.4s, v15.4s, v19.s[3]\n" - "fadd v23.4s, v23.4s, v18.4s\n" - "fadd v29.4s, v29.4s, v18.4s\n" - "fadd v15.4s, v15.4s, v18.4s\n" + "fadd v22.4s, v22.4s, v18.4s\n" + "fadd v14.4s, v14.4s, v18.4s\n" + "fmax v12.4s, v12.4s, v17.4s\n" "fmax v13.4s, v13.4s, v17.4s\n" - "fmax v23.4s, v23.4s, v17.4s\n" - "fmax v29.4s, v29.4s, v17.4s\n" - "fmax v15.4s, v15.4s, v17.4s\n" + "fmax v22.4s, v22.4s, v17.4s\n" + "fmax v14.4s, v14.4s, v17.4s\n" + "fmin v12.4s, v12.4s, v16.4s\n" "fmin v13.4s, v13.4s, v16.4s\n" - "fmin v23.4s, v23.4s, v16.4s\n" - 
"fmin v29.4s, v29.4s, v16.4s\n" - "fmin v15.4s, v15.4s, v16.4s\n" + "fmin v22.4s, v22.4s, v16.4s\n" + "fmin v14.4s, v14.4s, v16.4s\n" "blt 17f\n" "mov x20, %x[dst]\n" "cmp x11, #0x1\n" - "str q13, [x20, #0x0]\n" + "str q12, [x20, #0x0]\n" "add x20, x20, %x[dst_stride_row]\n" "ble 20f\n" "cmp x11, #0x2\n" - "str q23, [x20, #0x0]\n" + "str q13, [x20, #0x0]\n" "add x20, x20, %x[dst_stride_row]\n" "ble 20f\n" "cmp x11, #0x3\n" - "str q29, [x20, #0x0]\n" + "str q22, [x20, #0x0]\n" "add x20, x20, %x[dst_stride_row]\n" "ble 20f\n" - "str q15, [x20, #0x0]\n" + "str q14, [x20, #0x0]\n" "b 20f\n" "17:" // Row tail: Partial output "mov x23, %x[dst]\n" @@ -514,21 +516,21 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( "add x20, x21, %x[dst_stride_row]\n" "csel x20, x20, x21, GT\n" "tbz x25, #1, 18f\n" - "st1 { v15.d }[0], [x20], #0x8\n" - "st1 { v29.d }[0], [x21], #0x8\n" - "st1 { v23.d }[0], [x22], #0x8\n" - "st1 { v13.d }[0], [x23], #0x8\n" + "st1 { v14.d }[0], [x20], #0x8\n" + "st1 { v22.d }[0], [x21], #0x8\n" + "st1 { v13.d }[0], [x22], #0x8\n" + "st1 { v12.d }[0], [x23], #0x8\n" "tbz x25, #0, 19f\n" - "st1 { v15.s }[2], [x20]\n" - "st1 { v29.s }[2], [x21]\n" - "st1 { v23.s }[2], [x22]\n" - "st1 { v13.s }[2], [x23]\n" + "st1 { v14.s }[2], [x20]\n" + "st1 { v22.s }[2], [x21]\n" + "st1 { v13.s }[2], [x22]\n" + "st1 { v12.s }[2], [x23]\n" "b 19f\n" "18:" // Row tail: Output block 0: partial_1_0 - "st1 { v15.s }[0], [x20]\n" - "st1 { v29.s }[0], [x21]\n" - "st1 { v23.s }[0], [x22]\n" - "st1 { v13.s }[0], [x23]\n" + "st1 { v14.s }[0], [x20]\n" + "st1 { v22.s }[0], [x21]\n" + "st1 { v13.s }[0], [x22]\n" + "st1 { v12.s }[0], [x23]\n" "19:" // Row tail: Output block 0: Done "20:" // Row tail: Output stage exit "subs x25, x25, #0x4\n" diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c index e5472d21..b2f41a15 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c @@ -21,7 +21,7 @@ static const size_t kai_kr = 16; static const size_t kai_sr = 2; static const size_t kai_k0 = 32; static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); -static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); +static const size_t kai_num_bytes_multiplier_rhs = sizeof(uint16_t); static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); static const size_t kai_num_bytes_sum_rhs = sizeof(float); static const size_t kai_num_bytes_bias = sizeof(float); @@ -132,7 +132,7 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm( "mov x20, #0x20\n" "mov x27, %x[m]\n" "mul x28, %x[num_subblocks], x28\n" - "dup v15.4s, w21\n" + "dup v14.4s, w21\n" "madd x28, %x[num_blocks], x28, x20\n" "cbz x27, 12f\n" "1:" // Row loop @@ -140,7 +140,7 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm( "mov x25, %x[n]\n" "add x24, %x[dst], %x[dst_stride_row], LSL #2\n" "2:" // Column loop - "movi v14.16b, #0x0\n" + "movi v1.16b, #0x0\n" "movi v12.16b, #0x0\n" "mov x22, %x[lhs_packed]\n" "mov x21, %x[num_blocks]\n" @@ -148,175 +148,176 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm( "movi v13.16b, #0x0\n" "movi v18.16b, #0x0\n" "movi v27.16b, #0x0\n" - 
"movi v29.16b, #0x0\n" - "movi v9.16b, #0x0\n" + "movi v28.16b, #0x0\n" + "movi v4.16b, #0x0\n" "3:" // Block loop - "movi v0.4s, #0x0\n" - "movi v5.4s, #0x0\n" + "movi v21.4s, #0x0\n" + "movi v30.4s, #0x0\n" "mov x20, %x[num_subblocks]\n" - "movi v6.4s, #0x0\n" - "movi v8.4s, #0x0\n" - "movi v4.4s, #0x0\n" + "movi v24.4s, #0x0\n" + "movi v23.4s, #0x0\n" "movi v7.4s, #0x0\n" - "movi v22.4s, #0x0\n" + "movi v3.4s, #0x0\n" "movi v2.4s, #0x0\n" + "movi v8.4s, #0x0\n" "4:" // Sub block loop - "ldr q24, [x26, #0x0]\n" - "ldr q23, [x26, #0x10]\n" + "ldr q6, [x26, #0x0]\n" + "ldr q0, [x26, #0x10]\n" "subs x20, x20, #0x1\n" - "ldr q31, [x26, #0x20]\n" - "ldr q16, [x26, #0x30]\n" - "ldr q20, [x22, #0x0]\n" - "ldr q3, [x22, #0x10]\n" - "ldr q19, [x26, #0x40]\n" - "ldr q28, [x26, #0x50]\n" - "shl v1.16b, v24.16b, #0x4\n" - "shl v21.16b, v23.16b, #0x4\n" + "ldr q10, [x26, #0x20]\n" + "ldr q26, [x26, #0x30]\n" + "ldr q22, [x22, #0x0]\n" + "ldr q20, [x22, #0x10]\n" + "ldr q31, [x26, #0x40]\n" + "ldr q15, [x26, #0x50]\n" + "shl v29.16b, v6.16b, #0x4\n" + "shl v9.16b, v0.16b, #0x4\n" "ldr q25, [x26, #0x60]\n" - "ldr q10, [x26, #0x70]\n" - "shl v30.16b, v31.16b, #0x4\n" - "shl v26.16b, v16.16b, #0x4\n" - "and v24.16b, v24.16b, v17.16b\n" - "and v23.16b, v23.16b, v17.16b\n" + "ldr q16, [x26, #0x70]\n" + "shl v5.16b, v10.16b, #0x4\n" + "shl v19.16b, v26.16b, #0x4\n" + "and v6.16b, v6.16b, v17.16b\n" + "and v0.16b, v0.16b, v17.16b\n" "add x26, x26, #0x80\n" - ".inst 0x4e81a680 // smmla v0.4s, v20.16b, v1.16b\n" - ".inst 0x4e95a686 // smmla v6.4s, v20.16b, v21.16b\n" - ".inst 0x4e81a464 // smmla v4.4s, v3.16b, v1.16b\n" - "ldr q1, [x22, #0x20]\n" + ".inst 0x4e9da6d5 // smmla v21.4s, v22.16b, v29.16b\n" + ".inst 0x4e89a6d8 // smmla v24.4s, v22.16b, v9.16b\n" + ".inst 0x4e9da687 // smmla v7.4s, v20.16b, v29.16b\n" + "ldr q29, [x22, #0x20]\n" + "and v10.16b, v10.16b, v17.16b\n" + ".inst 0x4e85a6de // smmla v30.4s, v22.16b, v5.16b\n" + ".inst 0x4e93a6d7 // smmla v23.4s, v22.16b, v19.16b\n" + "ldr q22, [x22, #0x30]\n" + "and v26.16b, v26.16b, v17.16b\n" + ".inst 0x4e89a682 // smmla v2.4s, v20.16b, v9.16b\n" + "ldr q9, [x22, #0x40]\n" + ".inst 0x4e85a683 // smmla v3.4s, v20.16b, v5.16b\n" + "ldr q5, [x22, #0x50]\n" + ".inst 0x4e93a688 // smmla v8.4s, v20.16b, v19.16b\n" + "ldr q19, [x22, #0x60]\n" + "shl v20.16b, v31.16b, #0x4\n" "and v31.16b, v31.16b, v17.16b\n" - ".inst 0x4e9ea685 // smmla v5.4s, v20.16b, v30.16b\n" - ".inst 0x4e9aa688 // smmla v8.4s, v20.16b, v26.16b\n" - "ldr q20, [x22, #0x30]\n" - "and v16.16b, v16.16b, v17.16b\n" - ".inst 0x4e95a476 // smmla v22.4s, v3.16b, v21.16b\n" - "ldr q21, [x22, #0x40]\n" - ".inst 0x4e9ea467 // smmla v7.4s, v3.16b, v30.16b\n" - "ldr q30, [x22, #0x50]\n" - ".inst 0x4e9aa462 // smmla v2.4s, v3.16b, v26.16b\n" - "ldr q26, [x22, #0x60]\n" - "shl v3.16b, v19.16b, #0x4\n" - "and v19.16b, v19.16b, v17.16b\n" - ".inst 0x4e83a420 // smmla v0.4s, v1.16b, v3.16b\n" - ".inst 0x4e83a684 // smmla v4.4s, v20.16b, v3.16b\n" - "ldr q3, [x22, #0x70]\n" + ".inst 0x4e94a7b5 // smmla v21.4s, v29.16b, v20.16b\n" + ".inst 0x4e94a6c7 // smmla v7.4s, v22.16b, v20.16b\n" + "ldr q20, [x22, #0x70]\n" "add x22, x22, #0x80\n" - ".inst 0x4e98a6a0 // smmla v0.4s, v21.16b, v24.16b\n" - ".inst 0x4e98a7c4 // smmla v4.4s, v30.16b, v24.16b\n" - "shl v24.16b, v28.16b, #0x4\n" - "and v28.16b, v28.16b, v17.16b\n" - ".inst 0x4e98a426 // smmla v6.4s, v1.16b, v24.16b\n" - ".inst 0x4e98a696 // smmla v22.4s, v20.16b, v24.16b\n" - "shl v24.16b, v25.16b, #0x4\n" + ".inst 0x4e86a535 // smmla v21.4s, v9.16b, v6.16b\n" 
+ ".inst 0x4e86a4a7 // smmla v7.4s, v5.16b, v6.16b\n" + "shl v6.16b, v15.16b, #0x4\n" + "and v15.16b, v15.16b, v17.16b\n" + ".inst 0x4e86a7b8 // smmla v24.4s, v29.16b, v6.16b\n" + ".inst 0x4e86a6c2 // smmla v2.4s, v22.16b, v6.16b\n" + "shl v6.16b, v25.16b, #0x4\n" "and v25.16b, v25.16b, v17.16b\n" - ".inst 0x4e93a740 // smmla v0.4s, v26.16b, v19.16b\n" - ".inst 0x4e93a464 // smmla v4.4s, v3.16b, v19.16b\n" - "shl v19.16b, v10.16b, #0x4\n" - "and v10.16b, v10.16b, v17.16b\n" - ".inst 0x4e98a425 // smmla v5.4s, v1.16b, v24.16b\n" - ".inst 0x4e98a687 // smmla v7.4s, v20.16b, v24.16b\n" - ".inst 0x4e97a6a6 // smmla v6.4s, v21.16b, v23.16b\n" - ".inst 0x4e97a7d6 // smmla v22.4s, v30.16b, v23.16b\n" - ".inst 0x4e93a428 // smmla v8.4s, v1.16b, v19.16b\n" - ".inst 0x4e93a682 // smmla v2.4s, v20.16b, v19.16b\n" - ".inst 0x4e9fa6a5 // smmla v5.4s, v21.16b, v31.16b\n" - ".inst 0x4e9fa7c7 // smmla v7.4s, v30.16b, v31.16b\n" - ".inst 0x4e9ca746 // smmla v6.4s, v26.16b, v28.16b\n" - ".inst 0x4e9ca476 // smmla v22.4s, v3.16b, v28.16b\n" - ".inst 0x4e90a6a8 // smmla v8.4s, v21.16b, v16.16b\n" - ".inst 0x4e90a7c2 // smmla v2.4s, v30.16b, v16.16b\n" - ".inst 0x4e99a745 // smmla v5.4s, v26.16b, v25.16b\n" - ".inst 0x4e99a467 // smmla v7.4s, v3.16b, v25.16b\n" - ".inst 0x4e8aa748 // smmla v8.4s, v26.16b, v10.16b\n" - ".inst 0x4e8aa462 // smmla v2.4s, v3.16b, v10.16b\n" + ".inst 0x4e9fa675 // smmla v21.4s, v19.16b, v31.16b\n" + ".inst 0x4e9fa687 // smmla v7.4s, v20.16b, v31.16b\n" + "shl v31.16b, v16.16b, #0x4\n" + "and v16.16b, v16.16b, v17.16b\n" + ".inst 0x4e86a7be // smmla v30.4s, v29.16b, v6.16b\n" + ".inst 0x4e86a6c3 // smmla v3.4s, v22.16b, v6.16b\n" + ".inst 0x4e80a538 // smmla v24.4s, v9.16b, v0.16b\n" + ".inst 0x4e80a4a2 // smmla v2.4s, v5.16b, v0.16b\n" + ".inst 0x4e9fa7b7 // smmla v23.4s, v29.16b, v31.16b\n" + ".inst 0x4e9fa6c8 // smmla v8.4s, v22.16b, v31.16b\n" + ".inst 0x4e8aa53e // smmla v30.4s, v9.16b, v10.16b\n" + ".inst 0x4e8aa4a3 // smmla v3.4s, v5.16b, v10.16b\n" + ".inst 0x4e8fa678 // smmla v24.4s, v19.16b, v15.16b\n" + ".inst 0x4e8fa682 // smmla v2.4s, v20.16b, v15.16b\n" + ".inst 0x4e9aa537 // smmla v23.4s, v9.16b, v26.16b\n" + ".inst 0x4e9aa4a8 // smmla v8.4s, v5.16b, v26.16b\n" + ".inst 0x4e99a67e // smmla v30.4s, v19.16b, v25.16b\n" + ".inst 0x4e99a683 // smmla v3.4s, v20.16b, v25.16b\n" + ".inst 0x4e90a677 // smmla v23.4s, v19.16b, v16.16b\n" + ".inst 0x4e90a688 // smmla v8.4s, v20.16b, v16.16b\n" "bgt 4b\n" - "ldr q26, [x26, #0x0]\n" - "ldr q28, [x26, #0x10]\n" - "uzp1 v25.2d, v0.2d, v6.2d\n" - "uzp2 v24.2d, v0.2d, v6.2d\n" - "uzp1 v23.2d, v5.2d, v8.2d\n" - "uzp2 v21.2d, v5.2d, v8.2d\n" - "add x26, x26, #0x20\n" - "uzp1 v20.2d, v4.2d, v22.2d\n" - "uzp2 v19.2d, v4.2d, v22.2d\n" - "uzp1 v3.2d, v7.2d, v2.2d\n" - "uzp2 v4.2d, v7.2d, v2.2d\n" - "fmul v26.4s, v26.4s, v15.4s\n" - "fmul v28.4s, v28.4s, v15.4s\n" + "ldr q29, [x26, #0x0]\n" + "uzp1 v26.2d, v21.2d, v24.2d\n" + "uzp2 v25.2d, v21.2d, v24.2d\n" + "add x26, x26, #0x10\n" + "uzp1 v24.2d, v30.2d, v23.2d\n" + "uzp2 v23.2d, v30.2d, v23.2d\n" + "uzp1 v22.2d, v7.2d, v2.2d\n" + "uzp2 v21.2d, v7.2d, v2.2d\n" + "shll v20.4s, v29.4h, #0x10\n" + "shll2 v19.4s, v29.8h, #0x10\n" + "uzp1 v0.2d, v3.2d, v8.2d\n" + "uzp2 v8.2d, v3.2d, v8.2d\n" + "scvtf v26.4s, v26.4s\n" + "scvtf v24.4s, v24.4s\n" + "fmul v20.4s, v20.4s, v14.4s\n" + "fmul v19.4s, v19.4s, v14.4s\n" "scvtf v25.4s, v25.4s\n" "scvtf v23.4s, v23.4s\n" - "scvtf v24.4s, v24.4s\n" + "scvtf v22.4s, v22.4s\n" + "scvtf v0.4s, v0.4s\n" "scvtf v21.4s, v21.4s\n" - "scvtf v20.4s, v20.4s\n" - 
"scvtf v3.4s, v3.4s\n" - "scvtf v19.4s, v19.4s\n" - "scvtf v4.4s, v4.4s\n" - "fmla v14.4s, v25.4s, v26.4s\n" - "fmla v12.4s, v23.4s, v28.4s\n" - "fmla v11.4s, v24.4s, v26.4s\n" - "fmla v13.4s, v21.4s, v28.4s\n" - "fmla v18.4s, v20.4s, v26.4s\n" - "fmla v27.4s, v3.4s, v28.4s\n" - "fmla v29.4s, v19.4s, v26.4s\n" - "fmla v9.4s, v4.4s, v28.4s\n" + "scvtf v8.4s, v8.4s\n" + "fmla v1.4s, v26.4s, v20.4s\n" + "fmla v12.4s, v24.4s, v19.4s\n" + "fmla v11.4s, v25.4s, v20.4s\n" + "fmla v13.4s, v23.4s, v19.4s\n" + "fmla v18.4s, v22.4s, v20.4s\n" + "fmla v27.4s, v0.4s, v19.4s\n" + "fmla v28.4s, v21.4s, v20.4s\n" + "fmla v4.4s, v8.4s, v19.4s\n" "subs x21, x21, #0x1\n" "bgt 3b\n" "ld1 { v23.4s }, [x22]\n" "ldr q22, [x26, #0x0]\n" "add x22, x22, #0x10\n" "add x20, %x[clamp_vals], #0x4\n" - "ldr q21, [x26, #0x10]\n" + "ldr q9, [x26, #0x10]\n" "ldr q20, [x22, #0x0]\n" "cmp x25, #0x8\n" "ldr q19, [x26, #0x20]\n" - "ldr q25, [x26, #0x30]\n" + "ldr q21, [x26, #0x30]\n" "add x26, x26, #0x40\n" - "ld1r { v31.4s }, [%x[clamp_vals]]\n" - "ld1r { v3.4s }, [x20]\n" + "ld1r { v10.4s }, [%x[clamp_vals]]\n" + "ld1r { v30.4s }, [x20]\n" "scvtf v23.4s, v23.4s\n" - "fmla v14.4s, v22.4s, v23.s[0]\n" - "fmla v12.4s, v21.4s, v23.s[0]\n" + "fmla v1.4s, v22.4s, v23.s[0]\n" + "fmla v12.4s, v9.4s, v23.s[0]\n" "fmla v11.4s, v22.4s, v23.s[1]\n" - "fmla v13.4s, v21.4s, v23.s[1]\n" + "fmla v13.4s, v9.4s, v23.s[1]\n" "fmla v18.4s, v22.4s, v23.s[2]\n" - "fmla v27.4s, v21.4s, v23.s[2]\n" - "fmla v29.4s, v22.4s, v23.s[3]\n" - "fmla v9.4s, v21.4s, v23.s[3]\n" - "fmul v14.4s, v14.4s, v20.s[0]\n" + "fmla v27.4s, v9.4s, v23.s[2]\n" + "fmla v28.4s, v22.4s, v23.s[3]\n" + "fmla v4.4s, v9.4s, v23.s[3]\n" + "fmul v1.4s, v1.4s, v20.s[0]\n" "fmul v12.4s, v12.4s, v20.s[0]\n" "fmul v11.4s, v11.4s, v20.s[1]\n" "fmul v13.4s, v13.4s, v20.s[1]\n" "fmul v18.4s, v18.4s, v20.s[2]\n" "fmul v27.4s, v27.4s, v20.s[2]\n" - "fmul v29.4s, v29.4s, v20.s[3]\n" - "fmul v9.4s, v9.4s, v20.s[3]\n" - "fadd v14.4s, v14.4s, v19.4s\n" - "fadd v12.4s, v12.4s, v25.4s\n" + "fmul v28.4s, v28.4s, v20.s[3]\n" + "fmul v4.4s, v4.4s, v20.s[3]\n" + "fadd v1.4s, v1.4s, v19.4s\n" + "fadd v12.4s, v12.4s, v21.4s\n" "fadd v11.4s, v11.4s, v19.4s\n" - "fadd v13.4s, v13.4s, v25.4s\n" + "fadd v13.4s, v13.4s, v21.4s\n" "fadd v18.4s, v18.4s, v19.4s\n" - "fadd v27.4s, v27.4s, v25.4s\n" - "fadd v29.4s, v29.4s, v19.4s\n" - "fadd v9.4s, v9.4s, v25.4s\n" - "fmax v14.4s, v14.4s, v31.4s\n" - "fmax v12.4s, v12.4s, v31.4s\n" - "fmax v11.4s, v11.4s, v31.4s\n" - "fmax v13.4s, v13.4s, v31.4s\n" - "fmax v18.4s, v18.4s, v31.4s\n" - "fmax v27.4s, v27.4s, v31.4s\n" - "fmax v29.4s, v29.4s, v31.4s\n" - "fmax v9.4s, v9.4s, v31.4s\n" - "fmin v14.4s, v14.4s, v3.4s\n" - "fmin v12.4s, v12.4s, v3.4s\n" - "fmin v11.4s, v11.4s, v3.4s\n" - "fmin v13.4s, v13.4s, v3.4s\n" - "fmin v18.4s, v18.4s, v3.4s\n" - "fmin v27.4s, v27.4s, v3.4s\n" - "fmin v29.4s, v29.4s, v3.4s\n" - "fmin v9.4s, v9.4s, v3.4s\n" + "fadd v27.4s, v27.4s, v21.4s\n" + "fadd v28.4s, v28.4s, v19.4s\n" + "fadd v4.4s, v4.4s, v21.4s\n" + "fmax v1.4s, v1.4s, v10.4s\n" + "fmax v12.4s, v12.4s, v10.4s\n" + "fmax v11.4s, v11.4s, v10.4s\n" + "fmax v13.4s, v13.4s, v10.4s\n" + "fmax v18.4s, v18.4s, v10.4s\n" + "fmax v27.4s, v27.4s, v10.4s\n" + "fmax v28.4s, v28.4s, v10.4s\n" + "fmax v4.4s, v4.4s, v10.4s\n" + "fmin v1.4s, v1.4s, v30.4s\n" + "fmin v12.4s, v12.4s, v30.4s\n" + "fmin v11.4s, v11.4s, v30.4s\n" + "fmin v13.4s, v13.4s, v30.4s\n" + "fmin v18.4s, v18.4s, v30.4s\n" + "fmin v27.4s, v27.4s, v30.4s\n" + "fmin v28.4s, v28.4s, v30.4s\n" + "fmin v4.4s, v4.4s, v30.4s\n" 
"blt 6f\n" "mov x20, %x[dst]\n" "cmp x27, #0x1\n" - "str q14, [x20, #0x0]\n" + "str q1, [x20, #0x0]\n" "str q12, [x20, #0x10]\n" "add x20, x20, %x[dst_stride_row]\n" "ble 11f\n" @@ -330,8 +331,8 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm( "str q27, [x20, #0x10]\n" "add x20, x20, %x[dst_stride_row]\n" "ble 11f\n" - "str q29, [x20, #0x0]\n" - "str q9, [x20, #0x10]\n" + "str q28, [x20, #0x0]\n" + "str q4, [x20, #0x10]\n" "b 11f\n" "6:" // Partial output "mov x23, %x[dst]\n" @@ -345,45 +346,45 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm( "add x20, x21, %x[dst_stride_row]\n" "csel x20, x20, x21, GT\n" "tbz x25, #2, 8f\n" - "st1 { v29.4s }, [x20], #0x10\n" + "st1 { v28.4s }, [x20], #0x10\n" "st1 { v18.4s }, [x21], #0x10\n" "st1 { v11.4s }, [x22], #0x10\n" - "st1 { v14.4s }, [x23], #0x10\n" + "st1 { v1.4s }, [x23], #0x10\n" "tbz x25, #1, 7f\n" - "st1 { v9.d }[0], [x20], #0x8\n" + "st1 { v4.d }[0], [x20], #0x8\n" "st1 { v27.d }[0], [x21], #0x8\n" "st1 { v13.d }[0], [x22], #0x8\n" "st1 { v12.d }[0], [x23], #0x8\n" "tbz x25, #0, 10f\n" - "st1 { v9.s }[2], [x20]\n" + "st1 { v4.s }[2], [x20]\n" "st1 { v27.s }[2], [x21]\n" "st1 { v13.s }[2], [x22]\n" "st1 { v12.s }[2], [x23]\n" "b 10f\n" "7:" // Output block 0: partial_1_4 "tbz x25, #0, 10f\n" - "st1 { v9.s }[0], [x20]\n" + "st1 { v4.s }[0], [x20]\n" "st1 { v27.s }[0], [x21]\n" "st1 { v13.s }[0], [x22]\n" "st1 { v12.s }[0], [x23]\n" "b 10f\n" "8:" // Output block 0: partial_2_0 "tbz x25, #1, 9f\n" - "st1 { v29.d }[0], [x20], #0x8\n" + "st1 { v28.d }[0], [x20], #0x8\n" "st1 { v18.d }[0], [x21], #0x8\n" "st1 { v11.d }[0], [x22], #0x8\n" - "st1 { v14.d }[0], [x23], #0x8\n" + "st1 { v1.d }[0], [x23], #0x8\n" "tbz x25, #0, 10f\n" - "st1 { v29.s }[2], [x20]\n" + "st1 { v28.s }[2], [x20]\n" "st1 { v18.s }[2], [x21]\n" "st1 { v11.s }[2], [x22]\n" - "st1 { v14.s }[2], [x23]\n" + "st1 { v1.s }[2], [x23]\n" "b 10f\n" "9:" // Output block 0: partial_1_0 - "st1 { v29.s }[0], [x20]\n" + "st1 { v28.s }[0], [x20]\n" "st1 { v18.s }[0], [x21]\n" "st1 { v11.s }[0], [x22]\n" - "st1 { v14.s }[0], [x23]\n" + "st1 { v1.s }[0], [x23]\n" "10:" // Output block 0: Done "11:" // Output stage exit "subs x25, x25, #0x8\n" diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c index b43cd2b3..66c6eee9 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c @@ -153,7 +153,16 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( for (size_t s = 0; s < num_segments_per_block; ++s) { for (size_t i = 0; i < nr; ++i) { memcpy(dst_row + i * num_bytes_per_segment, src_row + i * rhs_stride, num_bytes_per_segment); - + float d = 0.0F; + if (scale_dt == F32) { + d = ((float*)dst_scales)[i]; + } else if (scale_dt == F16) { + d = kai_f16_to_f32(((uint16_t*)dst_scales)[i]); + } else if (scale_dt == Bf16) { + d = kai_bf16_to_f32(((uint16_t*)dst_scales)[i]); + } else { + KAI_ERROR("Unsupported scale data type"); + } for (size_t b = 0; b < num_bytes_per_segment; ++b) { uint8_t qs = dst_row[i * num_bytes_per_segment + b]; @@ -163,17 +172,6 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( // Add offset (0x88) dst_row[i * num_bytes_per_segment + b] = qs ^ 0x88; - float d = 0.0F; - if (scale_dt == F32) { - d = ((float*)dst_scales)[i]; - } else if (scale_dt == F16) { - d = kai_f16_to_f32(((uint16_t*)dst_scales)[i]); - } else if (scale_dt == Bf16) { - d = 
kai_bf16_to_f32(((uint16_t*)dst_scales)[i]); - } else { - KAI_ERROR("Unsupported scale data type"); - } - dst_sums[i] += x0 * d; dst_sums[i] += x1 * d; } -- GitLab From 00263fa8be936babc755f14c208ef6b774217ec2 Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Thu, 15 Aug 2024 15:02:06 +0100 Subject: [PATCH 09/29] Add end-to-end tests Signed-off-by: Viet-Hoa Do --- CMakeLists.txt | 3 + .../kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c | 1 - test/common/bfloat16.hpp | 13 +-- test/reference/matmul.cpp | 8 ++ test/reference/pack.cpp | 12 ++- test/reference/pack.hpp | 39 +++++++ test/reference/quantize.cpp | 5 + ...matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp | 102 ++++++++++++++++++ 8 files changed, 170 insertions(+), 13 deletions(-) create mode 100644 test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index c0d42de6..5451009c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -79,6 +79,7 @@ set(KLEIDIAI_FILES_SCALAR kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.c kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c + kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c ) set(KLEIDIAI_FILES_NEON_FP16 @@ -105,6 +106,7 @@ set(KLEIDIAI_FILES_NEON_I8MM kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.c kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c ) set(KLEIDIAI_FILES_SME @@ -194,6 +196,7 @@ if(KLEIDIAI_BUILD_TESTS) test/tests/matmul_test.cpp test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp test/tests/matmul_clamp_f32_qsi8d32p_qsi4c32p_test.cpp + test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp ) target_link_libraries(kleidiai_test diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c index 66c6eee9..7ba3039b 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c @@ -119,7 +119,6 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( // "k" columns and "n" rows (NxK) const size_t num_bytes_multiplier_rhs = kai_num_bytes_datatype(params->scale_dt); - const size_t rhs_zero_point = params->rhs_zero_point; const size_t rhs_stride = kai_rhs_stride(k, bl); const size_t rhs_packed_stride = kai_rhs_packed_stride(k, nr, kr, bl, params->scale_dt); const size_t rhs_packed_offset_end_of_all_blocks = diff --git a/test/common/bfloat16.hpp b/test/common/bfloat16.hpp index 8b61d92e..4fa8026c 100644 --- a/test/common/bfloat16.hpp +++ b/test/common/bfloat16.hpp @@ -38,10 +38,8 @@ public: BFloat16& operator=(BFloat16&&) = default; /// Creates a new object from the specified numeric value. - template , bool> = true> - explicit BFloat16(T value) : _data(0) { - const auto value_f32 = static_cast(value); - asm("bfcvt %h[output], %s[input]" : [output] "=w"(_data) : [input] "w"(value_f32)); + BFloat16(float value) : _data(0) { + asm("bfcvt %h[output], %s[input]" : [output] "=w"(_data) : [input] "w"(value)); } /// Assigns to the specified numeric value which will be converted to `bfloat16_t`. 
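A note on the BFloat16 constructor above: bfloat16 is the top half of an IEEE-754 binary32 value, and the bfcvt instruction performs the narrowing conversion with rounding. On toolchains or targets without the BF16 extension, a truncating software conversion is a common stand-in; a sketch with a hypothetical helper name (it can differ from bfcvt by one unit in the last place, since it drops the low mantissa bits instead of rounding):

#include <cstdint>
#include <cstring>

static inline std::uint16_t f32_to_bf16_truncate(float value) {
    std::uint32_t bits = 0;
    std::memcpy(&bits, &value, sizeof(bits));       // bit-cast without violating aliasing rules
    return static_cast<std::uint16_t>(bits >> 16);  // keep sign, exponent, top 7 mantissa bits
}

The widening direction, as the operator float() below shows, is exact: shifting the 16-bit pattern left by 16 reconstructs a valid binary32.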
@@ -52,9 +50,8 @@ public: return *this; } - /// Converts to numeric type `T`. - template , bool> = true> - explicit operator T() const { + /// Converts to floating-point. + operator float() const { union { float f32; uint32_t u32; @@ -62,7 +59,7 @@ public: data.u32 = static_cast(_data) << 16; - return static_cast(data.f32); + return data.f32; } /// Equality operator. diff --git a/test/reference/matmul.cpp b/test/reference/matmul.cpp index 72d03bb2..976ff049 100644 --- a/test/reference/matmul.cpp +++ b/test/reference/matmul.cpp @@ -12,6 +12,7 @@ #include #include "kai/kai_common.h" +#include "test/common/bfloat16.hpp" #include "test/common/data_format.hpp" #include "test/common/data_type.hpp" #include "test/common/float16.hpp" @@ -255,4 +256,11 @@ matmul_clamp_nt_t matmul_clamp_nt_t( + size_t m, size_t n, size_t k, // + const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, // + const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, // + const void* biases, // + float min_value, float max_value); + } // namespace kai::test diff --git a/test/reference/pack.cpp b/test/reference/pack.cpp index 0ff3904a..e4caef26 100644 --- a/test/reference/pack.cpp +++ b/test/reference/pack.cpp @@ -206,7 +206,7 @@ std::vector pack_data_scales_interleave_block( const auto num_quant_packets_x = round_up_multiple(width, quant_width) / quant_width; const auto data_bytes = height * width * size_in_bits / 8; - const auto scales_bytes = height * num_quant_packets_x * sizeof(Scale); + const auto scales_bytes = scales != nullptr ? height * num_quant_packets_x * sizeof(Scale) : 0; std::vector dst(data_bytes + scales_bytes); @@ -215,9 +215,11 @@ std::vector pack_data_scales_interleave_block( for (size_t y = 0; y < height; ++y) { for (size_t x_quant = 0; x_quant < width; x_quant += quant_width) { - write_array(dst_ptr, 0, *scales_ptr); - dst_ptr += sizeof(Scale); - ++scales_ptr; + if (scales_ptr != nullptr) { + write_array(dst_ptr, 0, *scales_ptr); + dst_ptr += sizeof(Scale); + ++scales_ptr; + } for (size_t x_element = 0; x_element < quant_width; ++x_element) { const auto x = x_quant + x_element / 2 + (x_element % 2 != 0 ? quant_width / 2 : 0); @@ -235,6 +237,8 @@ std::vector pack_data_scales_interleave_block( template std::vector pack_data_scales_interleave_block( const void* data, const void* scales, size_t height, size_t width, size_t quant_width); +template std::vector pack_data_scales_interleave_block( + const void* data, const void* scales, size_t height, size_t width, size_t quant_width); template std::vector pack_block_data_zero_points_scale_bias( diff --git a/test/reference/pack.hpp b/test/reference/pack.hpp index 8564c810..fd421f72 100644 --- a/test/reference/pack.hpp +++ b/test/reference/pack.hpp @@ -136,4 +136,43 @@ template std::vector pack_data_scales_interleave_block( const void* data, const void* scales, size_t height, size_t width, size_t quant_width); +/// Packs the quantized data with two halves of a block interleaved. +/// +/// ``` +/// Quantized data matrix: +/// +/// --->|-----------------|<--- Block width +/// | | +/// +-----------------+-----------------+----- ... +/// | q00 q01 q02 q03 | q04 q05 q06 q07 | ........ +/// | q10 q11 q12 q13 | q14 q15 q16 q17 | ........ +/// | q20 q21 q22 q23 | q24 q25 q26 q27 | ........ +/// | q30 q31 q32 q33 | q34 q35 q36 q37 | ........ +/// | ............... | ............... | ........ +/// : ............... : ............... : ........ 
+/// +/// Packed data: +/// +/// +-----------------+-----------------+----- ... +/// | q00 q02 q01 q03 | q04 q06 q05 q07 | ........ +/// | q10 q12 q11 q13 | q14 q16 q15 q17 | ........ +/// | q20 q22 q21 q23 | q24 q26 q25 q27 | ........ +/// | q30 q32 q31 q33 | q34 q36 q35 q37 | ........ +/// | ............... | ............... | ........ +/// : ............... : ............... : ........ +/// ``` +/// +/// @tparam Data The data type of the quantized value. +/// +/// @param[in] data The quantized data. +/// @param[in] height The number of rows. +/// @param[in] width The number of columns. +/// @param[in] block_width The number of columns in a block. +/// +/// @return The packed data buffer. +template +std::vector pack_data_interleave_block(const void* data, size_t height, size_t width, size_t block_width) { + return pack_data_scales_interleave_block(data, nullptr, height, width, block_width); +} + } // namespace kai::test diff --git a/test/reference/quantize.cpp b/test/reference/quantize.cpp index ad4a450f..23db28dc 100644 --- a/test/reference/quantize.cpp +++ b/test/reference/quantize.cpp @@ -13,6 +13,7 @@ #include #include +#include "test/common/bfloat16.hpp" #include "test/common/int4.hpp" #include "test/common/memory.hpp" #include "test/common/numeric_limits.hpp" @@ -126,6 +127,8 @@ template std::tuple, std::vector> quantize_symmetr const void* src, size_t height, size_t width, size_t quant_width); template std::tuple, std::vector> quantize_symmetric_per_block( const void* src, size_t height, size_t width, size_t quant_width); +template std::tuple, std::vector> quantize_symmetric_per_block( + const void* src, size_t height, size_t width, size_t quant_width); template std::tuple, std::vector> quantize_symmetric_per_block( const void* src, size_t height, size_t width, size_t quant_width); @@ -192,5 +195,7 @@ std::tuple, std::vector, std::vector> qua template std::tuple, std::vector, std::vector> quantize_asymmetric_per_block< float, int8_t, float, int32_t>(const void* src, size_t height, size_t width, size_t quant_width); +template std::tuple, std::vector, std::vector> quantize_asymmetric_per_block< + float, int8_t, BFloat16, int32_t>(const void* src, size_t height, size_t width, size_t quant_width); } // namespace kai::test diff --git a/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp b/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp new file mode 100644 index 00000000..3aab46c2 --- /dev/null +++ b/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp @@ -0,0 +1,102 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include +#include +#include + +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h" +#include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.h" +#include "test/common/bfloat16.hpp" +#include "test/common/int4.hpp" +#include "test/common/memory.hpp" +#include "test/reference/cast.hpp" +#include "test/reference/fill.hpp" +#include "test/reference/matmul.hpp" +#include "test/reference/pack.hpp" +#include "test/reference/quantize.hpp" + +namespace kai::test { + +TEST(matmul_clamp_f32_qai8dxp_qsi4c32p, EndTOEnd) { + const uint64_t seed = 0; + + const size_t M = 32; + const size_t N = 64; + const size_t K = 64; + + const auto mr = kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(); + 
const auto nr = kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(); + const auto kr = kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(); + const auto sr = kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(); + const size_t bl = 32; + + // Generates input data. + const auto ref_lhs = fill_random(M * K, seed + 0); + const auto ref_rhs = fill_random(N * K, seed + 1); + + // Runs the reference implementation. + // * Quantizes the LHS matrix using 8-bit asymmetric quantization. + // * Quantizes the RHS matrix using 4-bit symmetric quantization. + // * Performs GEMM. + const auto [ref_lhs_qvalues, ref_lhs_scales, ref_lhs_zero_points] = + quantize_asymmetric_per_block(ref_lhs.data(), M, K, K); + const auto [ref_rhs_qsi4, ref_rhs_scales] = + quantize_symmetric_per_block(ref_rhs.data(), N, K, bl); + + const auto ref_dst = matmul_clamp_nt_t( + M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), K, ref_rhs_qsi4.data(), + ref_rhs_scales.data(), nullptr, bl, nullptr, std::numeric_limits::lowest(), + std::numeric_limits::max()); + + // Runs the LHS packing micro-kernel. + const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr); + std::vector imp_packed_lhs(imp_packed_lhs_size); + kai_run_lhs_quant_pack_qai8dxp_f32( + M, K, mr, kr, sr, 0, reinterpret_cast(ref_lhs.data()), K * sizeof(float), imp_packed_lhs.data()); + + // Runs the RHS packing micro-kernel. + // * Generates the 4-bit unsigned symmetric quantized input for the micro-kernel. + // * Packs the RHS matrix. + const auto ref_rhs_qsu4 = cast_qsu4_qsi4(ref_rhs_qsi4.data(), N * K); + const auto ref_rhs_qsu4_interleaved_block = pack_data_interleave_block(ref_rhs_qsu4.data(), N, K, bl); + + const auto imp_packed_rhs_size = + kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(N, K, nr, kr, bl, kai_datatype::Bf16); + std::vector imp_packed_rhs(imp_packed_rhs_size); + const kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0_params params{ + .lhs_zero_point = 1, .rhs_zero_point = 8, .scale_dt = kai_datatype::Bf16}; + kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( + 1, N, K, nr, kr, sr, bl, ref_rhs_qsu4_interleaved_block.data(), nullptr, + reinterpret_cast(ref_rhs_scales.data()), imp_packed_rhs.data(), 0, ¶ms); + + // Runs the GEMM micro-kernel. + const auto imp_dst_size = kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(M, N); + ASSERT_EQ(imp_dst_size, ref_dst.size()); + std::vector imp_dst(imp_dst_size); + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( + M, N, K, bl, imp_packed_lhs.data(), imp_packed_rhs.data(), reinterpret_cast(imp_dst.data()), + N * sizeof(float), sizeof(float), std::numeric_limits::lowest(), std::numeric_limits::max()); + + // Compares the output of the micro-kernels against the output of the reference implementation. + for (size_t y = 0; y < M; ++y) { + for (size_t x = 0; x < N; ++x) { + const auto imp_value = read_array(imp_dst.data(), y * N + x); + const auto ref_value = read_array(ref_dst.data(), y * N + x); + const auto rel_error = ref_value != 0 ? 
std::abs((imp_value - ref_value) / ref_value) : std::abs(imp_value); + + if (rel_error > 0.0001F) { + ASSERT_EQ(imp_value, ref_value); + } + } + } +} + +} // namespace kai::test -- GitLab From 435b830a3e31bc3f7536f60c08a407752bb7a09d Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Thu, 29 Aug 2024 10:22:30 +0100 Subject: [PATCH 10/29] Change the native RHS layout - Adjust the RHS packing function to load the K values in sequential order, similar to the per-channel int4 quantization. Previously the K values were stored in blocks of 32 values - Adjust the reference implementation to store the weights and the scale quantization parameters in 2 separate buffers - Adjust the matrix multiplication reference implementation to load the RHS values and scales from 2 different buffers - Remove the restrictions on the N dimension. Signed-off-by: Gian Marco Iodice --- .../CMakeLists.txt | 1 - .../matmul_clamp_f32_qai8dxp_qsi4c32p.cpp | 244 +++++++++--------- ...i8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c | 4 - ..._qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c | 4 - ..._qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c | 4 - .../kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c | 156 +++++++---- 6 files changed, 227 insertions(+), 186 deletions(-) diff --git a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/CMakeLists.txt b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/CMakeLists.txt index 4fb9edd5..10fbcc5b 100644 --- a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/CMakeLists.txt +++ b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/CMakeLists.txt @@ -27,4 +27,3 @@ add_executable(matmul_clamp_f32_qai8dxp_qsi4c32p ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c) - diff --git a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp index a24535cf..2c6a298d 100644 --- a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp +++ b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp @@ -27,101 +27,103 @@ #define INT4_MAX (7) // Micro-kernel interface -struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp { +struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p { kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel ukernel; std::string name = {}; }; -kai_matmul_ukernel_f32_qa8dxp_qs4cxp ukernel_variants[] = { - {kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, +kai_matmul_ukernel_f32_qa8dxp_qs4c32p ukernel_variants[] = { + {{kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + 
kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod}, "matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod"}, - {kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + {{kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod}, "matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod"}, - {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + {{kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + 
kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm}, "matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm"}, - {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + {{kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm}, "matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm"}, }; // Number of micro-kernel variants stored in the array const size_t num_ukernel_variants = sizeof(ukernel_variants) / sizeof(ukernel_variants[0]); -static inline size_t num_blocks_per_row(size_t k, size_t bl) { - return k / bl; +static size_t roundup(size_t a, size_t b) { + return ((a + b - 1) / b) * b; } -static inline size_t num_bytes_per_block(size_t bl) { - return (bl / 2) + sizeof(float); +static inline size_t num_blocks_per_row(size_t k, size_t bl) { + return k / bl; } -static inline uint16_t kai_f32_to_bf16(float f32) { +static inline uint16_t convert_f32_to_bf16(float f32) { const uint32_t* i32 = reinterpret_cast(&f32); uint16_t bf16 = (*i32 >> 16); return bf16; } + +static inline float convert_bf16_to_f32(uint16_t bf16) { + const uint32_t i32 = (bf16 << 16); + float f32; + memcpy(&f32, &i32, sizeof(i32)); + return f32; +} + static void fill_uniform_random(size_t num_rows, size_t num_cols, float* dst, size_t seed) { std::srand(seed); // Fill the array with random values between -1 and 1 - 
for (int i = 0; i < num_rows * num_cols; i++) { + for (size_t i = 0; i < num_rows * num_cols; i++) { dst[i] = (float)((double)std::rand() / RAND_MAX) * 2 - 1; } } static void quant_qs4c32_f32( - size_t n, size_t k, size_t bl, const float* rhs_f32, uint8_t* rhs_qs4c32, uint8_t* rhs_with_no_scale_qs4c32, - uint16_t* rhs_scales_bf16) { + size_t n, size_t k, size_t bl, const float* rhs_f32, uint8_t* rhs_qs4c32, uint16_t* rhs_scales_bf16) { + const size_t rhs_qs4c32_stride = (roundup(k, 2) / 2); const size_t num_blocks_row = num_blocks_per_row(k, bl); - const size_t num_bytes_block = num_bytes_per_block(bl); - const size_t dst_stride = num_blocks_row * num_bytes_block; - const size_t dst_with_no_scale_stride = num_blocks_row * (bl / 2); for (size_t row_idx = 0; row_idx < n; ++row_idx) { const float* src_ptr = rhs_f32 + row_idx * k; - uint8_t* dst_ptr = (uint8_t*)rhs_qs4c32 + row_idx * dst_stride; - uint8_t* dst_with_no_scale_ptr = (uint8_t*)rhs_with_no_scale_qs4c32 + row_idx * dst_with_no_scale_stride; - for (size_t block_idx = 0; block_idx < num_blocks_row; ++block_idx) { float amax = 0.0f; float max = 0.0f; @@ -139,37 +141,34 @@ static void quant_qs4c32_f32( const float scale = max / -8.0; const float recip_scale = scale ? 1.0f / scale : 0.0f; - // Store the scale at the beginning of the block - uint16_t bf16_scale = kai_f32_to_bf16(scale); - *((float*)dst_ptr) = kai_bf16_to_f32(bf16_scale); - *rhs_scales_bf16 = bf16_scale; + // Store the scale in the dedicated buffer + *rhs_scales_bf16 = convert_f32_to_bf16(scale); - dst_ptr += sizeof(float); rhs_scales_bf16 += 1; - const size_t block_size = 32; - const size_t num_subblocks = bl / 32; - - for (size_t subblock_idx = 0; subblock_idx < num_subblocks; ++subblock_idx) { - for (size_t i = 0; i < block_size / 2; ++i) { - const size_t src_base_addr = block_idx * bl + i + subblock_idx * block_size; - float v0_f32 = src_ptr[src_base_addr]; - float v1_f32 = src_ptr[src_base_addr + block_size / 2]; + for (size_t i = 0; i < bl; ++i) { + const size_t k_idx = block_idx * bl + i; + const float src0_0 = src_ptr[k_idx]; - v0_f32 *= recip_scale; - v1_f32 *= recip_scale; + // Scale the values + int32_t v0_s32 = (int32_t)(round(src0_0 * recip_scale)); - const uint8_t v0_u8 = (uint8_t)std::min((int8_t)15, (int8_t)(v0_f32 + 8.0f)); - const uint8_t v1_u8 = (uint8_t)std::min((int8_t)15, (int8_t)(v1_f32 + 8.0f)); + // Maximum/minimum int4 values + v0_s32 = std::max(v0_s32, INT4_MIN); + v0_s32 = std::min(v0_s32, INT4_MAX); - const uint8_t rhs_v0 = (v1_u8 << 4) | v0_u8; + const uint8_t v0_u8 = (uint8_t)(v0_s32 + 8); - dst_ptr[0] = rhs_v0; - dst_ptr += sizeof(uint8_t); + const size_t dst_addr = (k_idx / 2) + row_idx * rhs_qs4c32_stride; + uint8_t rhs_v0 = rhs_qs4c32[dst_addr]; - dst_with_no_scale_ptr[0] = rhs_v0; - dst_with_no_scale_ptr += sizeof(uint8_t); + if ((k_idx % 2) == 0) { + rhs_v0 = v0_u8; + } else { + rhs_v0 |= (v0_u8 << 4); } + + rhs_qs4c32[dst_addr] = rhs_v0; } } } @@ -244,13 +243,12 @@ static void ref_quant_qa8dx_f32(size_t m, size_t k, const float* lhs_f32, int8_t } static void ref_matmul_f32_qa8dx_qs4c32( - size_t m, size_t n, size_t k, size_t bl, const int8_t* lhs_qa8dx, const uint8_t* rhs_qs4c32, float* dst_f32, - float scalar_min, float scalar_max) { + size_t m, size_t n, size_t k, size_t bl, const int8_t* lhs_qa8dx, const uint8_t* rhs_qs4c32, + const uint16_t* scale_bf16, float* dst_f32, float scalar_min, float scalar_max) { const size_t num_blocks_row = num_blocks_per_row(k, bl); - const size_t num_bytes_block = num_bytes_per_block(bl); const size_t 
lhs_stride = k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t); - const size_t rhs_stride = num_blocks_row * num_bytes_block; + const size_t rhs_stride = num_blocks_row * (bl / 2); for (size_t row_idx = 0; row_idx < m; ++row_idx) { const int8_t* lhs_ptr_start = lhs_qa8dx + row_idx * lhs_stride; @@ -271,37 +269,35 @@ static void ref_matmul_f32_qa8dx_qs4c32( lhs_ptr += sizeof(int32_t); for (size_t block_idx = 0; block_idx < num_blocks_row; ++block_idx) { - const float rhs_scale = *(const float*)rhs_ptr; - rhs_ptr += sizeof(float); + const uint16_t rhs_scale_bf16 = scale_bf16[block_idx + col_idx * num_blocks_row]; + const float rhs_scale = convert_bf16_to_f32(rhs_scale_bf16); int32_t iacc = 0; - const size_t block_size = 32; - const size_t num_subblocks = bl / 32; + for (size_t i = 0; i < bl; ++i) { + const size_t k_idx = block_idx * bl + i; - for (size_t subblock_idx = 0; subblock_idx < num_subblocks; ++subblock_idx) { - for (size_t i = 0; i < block_size / 2; ++i) { - // Get the LHS values - const int32_t lhs_v0 = (int32_t)lhs_ptr[0]; - const int32_t lhs_v1 = (int32_t)lhs_ptr[block_size / 2]; + // Get the LHS values + const int32_t lhs_v0 = (int32_t)lhs_ptr[0]; - // Get the RHS values - const uint8_t rhs_byte = rhs_ptr[0]; + // Get the RHS values + const uint8_t rhs_byte = rhs_ptr[0]; - // Unpack the RHS values - const int32_t rhs_v0 = (((int32_t)(rhs_byte & 0x0F)) - 8); - const int32_t rhs_v1 = (((int32_t)(rhs_byte >> 4)) - 8); + // Unpack the RHS values + int32_t rhs_v0 = 0; + if ((k_idx % 2) == 0) { + rhs_v0 = (((int32_t)(rhs_byte & 0x0F)) - 8); + } else { + rhs_v0 = (((int32_t)(rhs_byte >> 4)) - 8); + } - iacc += lhs_v0 * rhs_v0; - iacc += lhs_v1 * rhs_v1; - iacc += lhs_offset * rhs_v0; - iacc += lhs_offset * rhs_v1; + iacc += lhs_v0 * rhs_v0; + iacc += lhs_offset * rhs_v0; - lhs_ptr += 1; - rhs_ptr += 1; - } + lhs_ptr += 1; - lhs_ptr += (block_size / 2); + // Increment only when k_idx is not a multiple of 2 + rhs_ptr += k_idx % 2; } main_acc += iacc * rhs_scale; @@ -333,13 +329,13 @@ static bool is_output_correct(size_t num_rows, size_t num_cols, float tolerance, return is_valid; } -int main(int argc, char** argv) { +int main() { const size_t m = 37; - const size_t n = 1024; + const size_t n = 37; const size_t k = 256; const size_t bl = 64; + const size_t num_blocks_per_row = k / bl; - const size_t num_byte_per_block = bl / 2 + sizeof(float); const size_t seed_lhs = 4568; const size_t seed_rhs = seed_lhs + 4; @@ -347,23 +343,23 @@ int main(int argc, char** argv) { const size_t lhs_native_size_f32 = m * k * sizeof(float); const size_t rhs_native_size_f32 = n * k * sizeof(float); - const size_t rhs_native_size_qs4c32 = n * num_blocks_per_row * num_byte_per_block; - const size_t rhs_native_with_no_scale_size_qs4c32 = n * num_blocks_per_row * (bl / 2); + const size_t rhs_native_size_qs4c32 = n * num_blocks_per_row * (bl / 2); const size_t rhs_scales_size_bf16 = n * num_blocks_per_row * sizeof(uint16_t); // Allocate the memory uint8_t* lhs_native_mtx_f32 = new uint8_t[lhs_native_size_f32]; uint8_t* rhs_native_mtx_f32 = new uint8_t[rhs_native_size_f32]; uint8_t* rhs_native_mtx_qs4c32 = new uint8_t[rhs_native_size_qs4c32]; - uint8_t* rhs_native_with_no_scale_mtx_qs4c32 = new uint8_t[rhs_native_with_no_scale_size_qs4c32]; uint8_t* rhs_scales_mtx_bf16 = new uint8_t[rhs_scales_size_bf16]; fill_uniform_random(m, k, (float*)lhs_native_mtx_f32, seed_lhs); fill_uniform_random(n, k, (float*)rhs_native_mtx_f32, seed_rhs); quant_qs4c32_f32( - n, k, bl, (const float*)rhs_native_mtx_f32, 
(uint8_t*)rhs_native_mtx_qs4c32, - rhs_native_with_no_scale_mtx_qs4c32, (uint16_t*)rhs_scales_mtx_bf16); + n, k, bl, // Dimensions + (const float*)rhs_native_mtx_f32, // RHS (F32) + rhs_native_mtx_qs4c32, // RHS (QS4C32) + (uint16_t*)rhs_scales_mtx_bf16); // Scales (Bf16) delete[] rhs_native_mtx_f32; @@ -386,6 +382,7 @@ int main(int argc, char** argv) { bl, // Block length (const int8_t*)lhs_ref_mtx_qa8dx, // LHS (const uint8_t*)rhs_native_mtx_qs4c32, // RHS + (const uint16_t*)rhs_scales_mtx_bf16, // Scale (float*)dst_ref_mtx_f32, // DST -FLT_MAX, FLT_MAX); // Min and max for the clamp operation @@ -428,13 +425,13 @@ int main(int argc, char** argv) { // RHS packing kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( - 1, n, k, // Dimensions - nr, kr, sr, // Packing arguments - bl, // Block length - (const uint8_t*)(rhs_native_with_no_scale_mtx_qs4c32), // RHS - NULL, // Bias - rhs_scales_mtx_bf16, // Scale - rhs_packed_mtx_qs4cx, // RHS packed + 1, n, k, // Dimensions + nr, kr, sr, // Packing arguments + bl, // Block length + (const uint8_t*)(rhs_native_mtx_qs4c32), // RHS + NULL, // Bias + rhs_scales_mtx_bf16, // Scale + rhs_packed_mtx_qs4cx, // RHS packed 0, ¶ms); const auto time_s = std::chrono::high_resolution_clock::now(); @@ -492,6 +489,7 @@ int main(int argc, char** argv) { } delete[] lhs_native_mtx_f32; delete[] rhs_native_mtx_qs4c32; + delete[] rhs_scales_mtx_bf16; delete[] dst_ref_mtx_f32; } diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c index 46147992..f01277dc 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c @@ -101,9 +101,6 @@ size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_do } size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(size_t m, size_t n) { - // Temporary assert - KAI_ASSERT((n % kai_nr) == 0); - return m * n * sizeof(float); } @@ -111,7 +108,6 @@ void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod( size_t m, size_t n, size_t k, size_t bl, const void* restrict lhs_packed, const void* restrict rhs_packed, float* restrict dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { // Temporary asserts - KAI_ASSERT(n % kai_nr == 0); KAI_ASSERT(k % kai_k0 == 0); KAI_ASSERT((bl % 32) == 0); KAI_ASSERT(dst_stride_col == sizeof(float)); diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c index 91b4f4f2..9b271272 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c @@ -101,9 +101,6 @@ size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8 } size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(size_t m, size_t n) { - // Temporary assert - KAI_ASSERT((n % kai_nr) == 0); - return m * n * sizeof(float); } @@ -111,7 +108,6 
@@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { // Temporary asserts - KAI_ASSERT(n % kai_nr == 0); KAI_ASSERT(k % kai_k0 == 0); KAI_ASSERT((bl % 32) == 0); KAI_ASSERT(dst_stride_col == sizeof(float)); diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c index b2f41a15..60a12adc 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c @@ -101,9 +101,6 @@ size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8 } size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(size_t m, size_t n) { - // Temporary assert - KAI_ASSERT((n % kai_nr) == 0); - return m * n * sizeof(float); } @@ -111,7 +108,6 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm( size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { // Temporary asserts - KAI_ASSERT(n % kai_nr == 0); KAI_ASSERT(k % kai_k0 == 0); KAI_ASSERT((bl % 32) == 0); KAI_ASSERT(dst_stride_col == sizeof(float)); diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c index 7ba3039b..6b9e1435 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c @@ -29,7 +29,7 @@ inline static size_t kai_rhs_stride(size_t k, size_t bl) { const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); // The RHS matrix (not packed) does not pack the scale. 
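Because the native RHS buffer no longer stores the scales, the stride returned by kai_rhs_stride reduces to half of k: a row holds k / bl blocks and each block stores its bl four-bit values in bl / 2 bytes. A short restatement of that arithmetic under the stated assumptions, with a hypothetical function name:

#include <cstddef>

static inline std::size_t rhs_stride_bytes(std::size_t k, std::size_t bl) {
    const std::size_t num_blocks      = k / bl;  // quantized blocks per row (k % bl == 0 is asserted)
    const std::size_t bytes_per_block = bl / 2;  // two 4-bit values per byte, no scale stored
    return num_blocks * bytes_per_block;         // equals k / 2
}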
- // Therefore, the numbr of bytes per scale must be 0 + // Therefore, the number of bytes for the scale is 0 const size_t num_bytes_per_block = kai_num_bytes_per_block(bl, 0); return num_bytes_per_block * num_blocks_per_row; @@ -83,12 +83,9 @@ size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( KAI_ASSERT((k % 2) == 0); KAI_ASSERT((k % kr) == 0); KAI_ASSERT((k % bl) == 0); - KAI_ASSERT((n % nr) == 0); KAI_ASSERT(scale_dt == F32 || scale_dt == F16 || scale_dt == Bf16); - KAI_UNUSED(kr); - - const size_t num_rows = n / nr; + const size_t num_rows = kai_roundup(n, nr) / nr; return num_rows * kai_rhs_packed_stride(k, nr, kr, bl, scale_dt); } @@ -97,10 +94,8 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs, const float* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0_params* params) { - // Temporary asserts KAI_ASSERT(num_groups == 1); KAI_ASSERT((k % 2) == 0); - KAI_ASSERT((n % nr) == 0); KAI_ASSERT((k % kr) == 0); KAI_ASSERT((k % bl) == 0); KAI_ASSERT(bias == NULL); @@ -124,77 +119,138 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( const size_t rhs_packed_offset_end_of_all_blocks = kai_rhs_packed_offset_end_of_all_blocks(k, nr, kr, bl, num_bytes_multiplier_rhs); const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); - const size_t num_segments_per_block = bl / kr; - const size_t num_bytes_per_segment = kr / 2; + const size_t num_bytes_per_block = kai_num_bytes_per_block(bl, num_bytes_multiplier_rhs); + const size_t num_bytes_per_block_k = bl / 2; + const size_t dst_num_rows = kai_roundup(n, nr) / nr; + const size_t k_interleaved_v = 16U; + const size_t block_length_in_bytes = kr / sr; + const size_t scale_stride = num_bytes_multiplier_rhs * num_blocks_per_row; + + const int32_t rhs_zero_point = params->rhs_zero_point; const enum kai_datatype scale_dt = params->scale_dt; - for (size_t y = 0; y < n; y += nr) { - const uint8_t* src_row = (const uint8_t*)rhs + y * rhs_stride; - uint8_t* dst_row = (uint8_t*)rhs_packed + (y / nr) * rhs_packed_stride; + for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; ++dst_row_idx) { + // Before packing, it keeps the pointer to the first quantized block + uint8_t* dst_row = (uint8_t*)rhs_packed + dst_row_idx * rhs_packed_stride; - float* dst_sums = (float*)(dst_row + rhs_packed_offset_end_of_all_blocks); + float* sums = (float*)(dst_row + rhs_packed_offset_end_of_all_blocks); // Initialize the RHS reduction sums to zero - memset(dst_sums, 0, nr * kai_num_bytes_sum_rhs); + memset(sums, 0, nr * kai_num_bytes_sum_rhs); - for (size_t x = 0; x < num_blocks_per_row; ++x) { - // Store the scales at the end of the block - uint8_t* dst_scales = (dst_row + (bl / 2) * nr); + // Iterate over the quantized blocks + for (size_t dst_qblock_idx = 0; dst_qblock_idx < num_blocks_per_row; ++dst_qblock_idx) { + // Store the scales after packing all K values + void* rhs_packed_scale = dst_row + num_bytes_per_block_k * nr; for (size_t i = 0; i < nr; ++i) { + const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); + void* dst_scales_ptr = rhs_packed_scale + i * num_bytes_multiplier_rhs; + void* src_scales_ptr = (void*)(scale + dst_qblock_idx * num_bytes_multiplier_rhs + // + (src_row_idx * scale_stride)); // + memcpy( - dst_scales + i * num_bytes_multiplier_rhs, // - scale + ((y + i) * num_blocks_per_row + x) * num_bytes_multiplier_rhs, // - num_bytes_multiplier_rhs); // 
+ dst_scales_ptr, // + src_scales_ptr, // + num_bytes_multiplier_rhs); // } - // Store the segments - for (size_t s = 0; s < num_segments_per_block; ++s) { - for (size_t i = 0; i < nr; ++i) { - memcpy(dst_row + i * num_bytes_per_segment, src_row + i * rhs_stride, num_bytes_per_segment); - float d = 0.0F; - if (scale_dt == F32) { - d = ((float*)dst_scales)[i]; - } else if (scale_dt == F16) { - d = kai_f16_to_f32(((uint16_t*)dst_scales)[i]); - } else if (scale_dt == Bf16) { - d = kai_bf16_to_f32(((uint16_t*)dst_scales)[i]); - } else { - KAI_ERROR("Unsupported scale data type"); - } - for (size_t b = 0; b < num_bytes_per_segment; ++b) { - uint8_t qs = dst_row[i * num_bytes_per_segment + b]; + for (size_t dst_byte_idx = 0; dst_byte_idx < nr * num_bytes_per_block_k; ++dst_byte_idx) { + const size_t block_idx = dst_byte_idx / block_length_in_bytes; + const size_t block_byte_idx = dst_byte_idx % block_length_in_bytes; + const size_t super_block_idx = block_idx / nr; + const size_t nr_idx = block_idx % nr; + + const size_t k_adjustment = + ((block_byte_idx + super_block_idx * block_length_in_bytes) / k_interleaved_v) * k_interleaved_v; + const size_t k0_idx = + dst_qblock_idx * bl + block_byte_idx + super_block_idx * block_length_in_bytes + k_adjustment; + const size_t k1_idx = k0_idx + k_interleaved_v; + const size_t n0_idx = dst_row_idx * nr + nr_idx; - const int32_t x0 = (qs & 0x0F) - 8; - const int32_t x1 = (qs >> 4) - 8; + // Clamp the index to avoid out-of-bound reads + const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1); - // Add offset (0x88) - dst_row[i * num_bytes_per_segment + b] = qs ^ 0x88; + const size_t src_addr_byte0 = (k0_idx / 2) + n0_valid_idx * rhs_stride; + const size_t src_addr_byte1 = (k1_idx / 2) + n0_valid_idx * rhs_stride; + + uint8_t byte0 = rhs_zero_point | rhs_zero_point << 4; + uint8_t byte1 = rhs_zero_point | rhs_zero_point << 4; + + if (k0_idx < k) { + byte0 = rhs[src_addr_byte0]; + } - dst_sums[i] += x0 * d; - dst_sums[i] += x1 * d; + if (k1_idx < k) { + byte1 = rhs[src_addr_byte1]; + } + + // The following operations where we extract the values from the bytes + // can be also written in the following and less efficient manner: + /* + uint8_t src_x0_lo = 0; + uint8_t src_x0_hi = 0; + + if ((k0_idx % 2) == 0) { + src_x0_lo = (byte0 & 0x0F); + } else { + src_x0_lo = (byte0 >> 4); } + + if ((k1_idx % 2) == 0) { + src_x0_hi = (byte1 & 0x0F); + } else { + src_x0_hi = (byte1 >> 4); + } + */ + const size_t shift_right_x0 = (k0_idx % 2) * 4; + const size_t shift_right_x1 = (k1_idx % 2) * 4; + + const uint8_t src_x0_lo = (byte0 >> shift_right_x0) & 0x0F; + const uint8_t src_x0_hi = (byte1 >> shift_right_x1) & 0x0F; + + float d = 0.0F; + switch (scale_dt) { + case F32: + d = ((float*)rhs_packed_scale)[nr_idx]; + break; + case F16: + d = kai_f16_to_f32(((uint16_t*)rhs_packed_scale)[nr_idx]); + break; + case Bf16: + d = kai_bf16_to_f32(((uint16_t*)rhs_packed_scale)[nr_idx]); + break; + default: + KAI_ERROR("Unsupported scale data type"); + break; } - src_row += num_bytes_per_segment; - dst_row += num_bytes_per_segment * nr; - } + sums[nr_idx] += ((int32_t)src_x0_lo - rhs_zero_point) * d; + sums[nr_idx] += ((int32_t)src_x0_hi - rhs_zero_point) * d; - dst_row += (num_bytes_multiplier_rhs * nr); + const uint8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4); + + dst_row[dst_byte_idx] = dst_qs0 ^ 0x88; + } + // Move the pointer after K values + dst_row += num_bytes_per_block * nr; } - // Skip the row sum - dst_row += (kai_num_bytes_sum_rhs * nr); + // Move the pointer after the row sum + 
dst_row += kai_num_bytes_sum_rhs * nr; // Set the bias if (bias == NULL) { memset(dst_row, 0, nr * kai_num_bytes_bias); } else { for (size_t i = 0; i < nr; ++i) { - ((float*)dst_row)[i] = bias[y + nr]; + // Clamp the row index to avoid out-of-bound reads + const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); + ((float*)dst_row)[i] = bias[src_row_idx]; } } - dst_row += (kai_num_bytes_bias * nr); + // Move the pointer after the biases + dst_row += kai_num_bytes_bias * nr; } } -- GitLab From d175b929d6ad3f7385a038aefc6710db4a024dfb Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Thu, 29 Aug 2024 12:49:51 +0100 Subject: [PATCH 11/29] Update unit tests * Update native RHS layout in unit test * Add test for all microkernel variants Signed-off-by: Anitha Raj --- CMakeLists.txt | 3 + ...matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp | 225 +++++++++++++++++- 2 files changed, 224 insertions(+), 4 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 5451009c..6c7b5ba6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -98,6 +98,8 @@ set(KLEIDIAI_FILES_NEON_DOTPROD kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.c kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c ) set(KLEIDIAI_FILES_NEON_I8MM @@ -107,6 +109,7 @@ set(KLEIDIAI_FILES_NEON_I8MM kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.c kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c + kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c ) set(KLEIDIAI_FILES_SME diff --git a/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp b/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp index 3aab46c2..8988793a 100644 --- a/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp +++ b/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp @@ -11,7 +11,10 @@ #include #include +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h" #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.h" #include "test/common/bfloat16.hpp" @@ -25,7 +28,151 @@ namespace kai::test { -TEST(matmul_clamp_f32_qai8dxp_qsi4c32p, EndTOEnd) { +TEST(matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, EndTOEnd) { + const uint64_t seed = 0; + + const size_t M = 32; + const size_t N = 64; + const size_t K = 64; 
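+    // Shape note (a worked example of the blockwise layout): with K = 64 and
+    // bl = 32, each RHS row splits into K / bl = 2 quantization blocks, so the
+    // packed RHS carries two Bf16 scales per row. M, N and K are multiples of
+    // this kernel's mr/nr/kr, so the padding/clamping paths in the packing
+    // routines should not be exercised by this test.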
+
+    const auto mr = kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod();
+    const auto nr = kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod();
+    const auto kr = kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod();
+    const auto sr = kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod();
+    const size_t bl = 32;
+
+    // Generates input data.
+    const auto ref_lhs = fill_random(M * K, seed + 0);
+    const auto ref_rhs = fill_random(N * K, seed + 1);
+
+    // Runs the reference implementation.
+    //   * Quantizes the LHS matrix using 8-bit asymmetric quantization.
+    //   * Quantizes the RHS matrix using 4-bit symmetric quantization.
+    //   * Performs GEMM.
+    const auto [ref_lhs_qvalues, ref_lhs_scales, ref_lhs_zero_points] =
+        quantize_asymmetric_per_block(ref_lhs.data(), M, K, K);
+    const auto [ref_rhs_qsi4, ref_rhs_scales] =
+        quantize_symmetric_per_block(ref_rhs.data(), N, K, bl);
+
+    const auto ref_dst = matmul_clamp_nt_t(
+        M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), K, ref_rhs_qsi4.data(),
+        ref_rhs_scales.data(), nullptr, bl, nullptr, std::numeric_limits::lowest(),
+        std::numeric_limits::max());
+
+    // Runs the LHS packing micro-kernel.
+    const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr);
+    std::vector imp_packed_lhs(imp_packed_lhs_size);
+    kai_run_lhs_quant_pack_qai8dxp_f32(
+        M, K, mr, kr, sr, 0, reinterpret_cast(ref_lhs.data()), K * sizeof(float), imp_packed_lhs.data());
+
+    // Runs the RHS packing micro-kernel.
+    //   * Generates the 4-bit unsigned symmetric quantized input for the micro-kernel.
+    //   * Packs the RHS matrix.
+    const auto ref_rhs_qsu4 = cast_qsu4_qsi4(ref_rhs_qsi4.data(), N * K);
+
+    const auto imp_packed_rhs_size =
+        kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(N, K, nr, kr, bl, kai_datatype::Bf16);
+    std::vector imp_packed_rhs(imp_packed_rhs_size);
+    const kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0_params params{
+        .lhs_zero_point = 1, .rhs_zero_point = 8, .scale_dt = kai_datatype::Bf16};
+    kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
+        1, N, K, nr, kr, sr, bl, ref_rhs_qsu4.data(), nullptr, reinterpret_cast(ref_rhs_scales.data()),
+        imp_packed_rhs.data(), 0, &params);
+
+    // Runs the GEMM micro-kernel.
+    const auto imp_dst_size = kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(M, N);
+    ASSERT_EQ(imp_dst_size, ref_dst.size());
+    std::vector imp_dst(imp_dst_size);
+    kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(
+        M, N, K, bl, imp_packed_lhs.data(), imp_packed_rhs.data(), reinterpret_cast(imp_dst.data()),
+        N * sizeof(float), sizeof(float), std::numeric_limits::lowest(), std::numeric_limits::max());
+
+    // Compares the output of the micro-kernels against the output of the reference implementation.
+    for (size_t y = 0; y < M; ++y) {
+        for (size_t x = 0; x < N; ++x) {
+            const auto imp_value = read_array(imp_dst.data(), y * N + x);
+            const auto ref_value = read_array(ref_dst.data(), y * N + x);
+            const auto rel_error = ref_value != 0 ?
std::abs((imp_value - ref_value) / ref_value) : std::abs(imp_value);
+
+            if (rel_error > 0.0001F) {
+                ASSERT_EQ(imp_value, ref_value);
+            }
+        }
+    }
+}
+TEST(matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, EndTOEnd) {
+    const uint64_t seed = 0;
+
+    const size_t M = 32;
+    const size_t N = 64;
+    const size_t K = 64;
+
+    const auto mr = kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod();
+    const auto nr = kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod();
+    const auto kr = kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod();
+    const auto sr = kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod();
+    const size_t bl = 32;
+
+    // Generates input data.
+    const auto ref_lhs = fill_random(M * K, seed + 0);
+    const auto ref_rhs = fill_random(N * K, seed + 1);
+
+    // Runs the reference implementation.
+    //   * Quantizes the LHS matrix using 8-bit asymmetric quantization.
+    //   * Quantizes the RHS matrix using 4-bit symmetric quantization.
+    //   * Performs GEMM.
+    const auto [ref_lhs_qvalues, ref_lhs_scales, ref_lhs_zero_points] =
+        quantize_asymmetric_per_block(ref_lhs.data(), M, K, K);
+    const auto [ref_rhs_qsi4, ref_rhs_scales] =
+        quantize_symmetric_per_block(ref_rhs.data(), N, K, bl);
+
+    const auto ref_dst = matmul_clamp_nt_t(
+        M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), K, ref_rhs_qsi4.data(),
+        ref_rhs_scales.data(), nullptr, bl, nullptr, std::numeric_limits::lowest(),
+        std::numeric_limits::max());
+
+    // Runs the LHS packing micro-kernel.
+    const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr);
+    std::vector imp_packed_lhs(imp_packed_lhs_size);
+    kai_run_lhs_quant_pack_qai8dxp_f32(
+        M, K, mr, kr, sr, 0, reinterpret_cast(ref_lhs.data()), K * sizeof(float), imp_packed_lhs.data());
+
+    // Runs the RHS packing micro-kernel.
+    //   * Generates the 4-bit unsigned symmetric quantized input for the micro-kernel.
+    //   * Packs the RHS matrix.
+    const auto ref_rhs_qsu4 = cast_qsu4_qsi4(ref_rhs_qsi4.data(), N * K);
+
+    const auto imp_packed_rhs_size =
+        kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(N, K, nr, kr, bl, kai_datatype::Bf16);
+    std::vector imp_packed_rhs(imp_packed_rhs_size);
+    const kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0_params params{
+        .lhs_zero_point = 1, .rhs_zero_point = 8, .scale_dt = kai_datatype::Bf16};
+    kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
+        1, N, K, nr, kr, sr, bl, ref_rhs_qsu4.data(), nullptr, reinterpret_cast(ref_rhs_scales.data()),
+        imp_packed_rhs.data(), 0, &params);
+
+    // Runs the GEMM micro-kernel.
+    const auto imp_dst_size = kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(M, N);
+    ASSERT_EQ(imp_dst_size, ref_dst.size());
+    std::vector imp_dst(imp_dst_size);
+    kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(
+        M, N, K, bl, imp_packed_lhs.data(), imp_packed_rhs.data(), reinterpret_cast(imp_dst.data()),
+        N * sizeof(float), sizeof(float), std::numeric_limits::lowest(), std::numeric_limits::max());
+
+    // Compares the output of the micro-kernels against the output of the reference implementation.
+    for (size_t y = 0; y < M; ++y) {
+        for (size_t x = 0; x < N; ++x) {
+            const auto imp_value = read_array(imp_dst.data(), y * N + x);
+            const auto ref_value = read_array(ref_dst.data(), y * N + x);
+            const auto rel_error = ref_value != 0 ?
std::abs((imp_value - ref_value) / ref_value) : std::abs(imp_value);
+
+            if (rel_error > 0.0001F) {
+                ASSERT_EQ(imp_value, ref_value);
+            }
+        }
+    }
+}
+TEST(matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, EndTOEnd) {
     const uint64_t seed = 0;
 
     const size_t M = 32;
@@ -66,7 +213,6 @@ TEST(matmul_clamp_f32_qai8dxp_qsi4c32p, EndTOEnd) {
     // * Generates the 4-bit unsigned symmetric quantized input for the micro-kernel.
     // * Packs the RHS matrix.
     const auto ref_rhs_qsu4 = cast_qsu4_qsi4(ref_rhs_qsi4.data(), N * K);
-    const auto ref_rhs_qsu4_interleaved_block = pack_data_interleave_block(ref_rhs_qsu4.data(), N, K, bl);
 
     const auto imp_packed_rhs_size =
         kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(N, K, nr, kr, bl, kai_datatype::Bf16);
@@ -74,8 +220,8 @@ TEST(matmul_clamp_f32_qai8dxp_qsi4c32p, EndTOEnd) {
     const kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0_params params{
         .lhs_zero_point = 1, .rhs_zero_point = 8, .scale_dt = kai_datatype::Bf16};
     kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
-        1, N, K, nr, kr, sr, bl, ref_rhs_qsu4_interleaved_block.data(), nullptr,
-        reinterpret_cast(ref_rhs_scales.data()), imp_packed_rhs.data(), 0, &params);
+        1, N, K, nr, kr, sr, bl, ref_rhs_qsu4.data(), nullptr, reinterpret_cast(ref_rhs_scales.data()),
+        imp_packed_rhs.data(), 0, &params);
 
     // Runs the GEMM micro-kernel.
     const auto imp_dst_size = kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(M, N);
@@ -98,5 +244,76 @@ TEST(matmul_clamp_f32_qai8dxp_qsi4c32p, EndTOEnd) {
         }
     }
 }
+TEST(matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, EndTOEnd) {
+    const uint64_t seed = 0;
+
+    const size_t M = 32;
+    const size_t N = 64;
+    const size_t K = 64;
+
+    const auto mr = kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm();
+    const auto nr = kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm();
+    const auto kr = kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm();
+    const auto sr = kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm();
+    const size_t bl = 32;
+
+    // Generates input data.
+    const auto ref_lhs = fill_random(M * K, seed + 0);
+    const auto ref_rhs = fill_random(N * K, seed + 1);
+
+    // Runs the reference implementation.
+    //   * Quantizes the LHS matrix using 8-bit asymmetric quantization.
+    //   * Quantizes the RHS matrix using 4-bit symmetric quantization.
+    //   * Performs GEMM.
+    const auto [ref_lhs_qvalues, ref_lhs_scales, ref_lhs_zero_points] =
+        quantize_asymmetric_per_block(ref_lhs.data(), M, K, K);
+    const auto [ref_rhs_qsi4, ref_rhs_scales] =
+        quantize_symmetric_per_block(ref_rhs.data(), N, K, bl);
+
+    const auto ref_dst = matmul_clamp_nt_t(
+        M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), K, ref_rhs_qsi4.data(),
+        ref_rhs_scales.data(), nullptr, bl, nullptr, std::numeric_limits::lowest(),
+        std::numeric_limits::max());
+
+    // Runs the LHS packing micro-kernel.
+    const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr);
+    std::vector imp_packed_lhs(imp_packed_lhs_size);
+    kai_run_lhs_quant_pack_qai8dxp_f32(
+        M, K, mr, kr, sr, 0, reinterpret_cast(ref_lhs.data()), K * sizeof(float), imp_packed_lhs.data());
+
+    // Runs the RHS packing micro-kernel.
+    //   * Generates the 4-bit unsigned symmetric quantized input for the micro-kernel.
+    //   * Packs the RHS matrix.
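+    // (A worked example, assuming cast_qsu4_qsi4 applies the usual +8
+    // zero-point offset: the signed int4 value -3 becomes the unsigned
+    // nibble 5, consistent with rhs_zero_point = 8 in the params below.
+    // The packing kernel later re-centres each byte by XOR-ing it with 0x88
+    // and subtracts the zero point again when accumulating the reduction
+    // sums.)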
+    const auto ref_rhs_qsu4 = cast_qsu4_qsi4(ref_rhs_qsi4.data(), N * K);
+    const auto imp_packed_rhs_size =
+        kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(N, K, nr, kr, bl, kai_datatype::Bf16);
+    std::vector imp_packed_rhs(imp_packed_rhs_size);
+    const kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0_params params{
+        .lhs_zero_point = 1, .rhs_zero_point = 8, .scale_dt = kai_datatype::Bf16};
+    kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
+        1, N, K, nr, kr, sr, bl, ref_rhs_qsu4.data(), nullptr, reinterpret_cast(ref_rhs_scales.data()),
+        imp_packed_rhs.data(), 0, &params);
+
+    // Runs the GEMM micro-kernel.
+    const auto imp_dst_size = kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(M, N);
+    ASSERT_EQ(imp_dst_size, ref_dst.size());
+    std::vector imp_dst(imp_dst_size);
+    kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(
+        M, N, K, bl, imp_packed_lhs.data(), imp_packed_rhs.data(), reinterpret_cast(imp_dst.data()),
+        N * sizeof(float), sizeof(float), std::numeric_limits::lowest(), std::numeric_limits::max());
+
+    // Compares the output of the micro-kernels against the output of the reference implementation.
+    for (size_t y = 0; y < M; ++y) {
+        for (size_t x = 0; x < N; ++x) {
+            const auto imp_value = read_array(imp_dst.data(), y * N + x);
+            const auto ref_value = read_array(ref_dst.data(), y * N + x);
+            const auto rel_error = ref_value != 0 ? std::abs((imp_value - ref_value) / ref_value) : std::abs(imp_value);
+
+            if (rel_error > 0.0001F) {
+                ASSERT_EQ(imp_value, ref_value);
+            }
+        }
+    }
+}
 }  // namespace kai::test
-- 
GitLab

From 7a24ed10343655c9684cb9a3af8a4fb234a41724 Mon Sep 17 00:00:00 2001
From: Anitha Raj 
Date: Thu, 29 Aug 2024 14:36:33 +0100
Subject: [PATCH 12/29] Fix for pipeline failure

Signed-off-by: Anitha Raj 
---
 test/reference/pack.cpp | 2 +-
 test/reference/pack.hpp | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/reference/pack.cpp b/test/reference/pack.cpp
index e4caef26..221ba360 100644
--- a/test/reference/pack.cpp
+++ b/test/reference/pack.cpp
@@ -237,7 +237,7 @@ std::vector pack_data_scales_interleave_block(
 template std::vector pack_data_scales_interleave_block(
     const void* data, const void* scales, size_t height, size_t width, size_t quant_width);
-template std::vector pack_data_scales_interleave_block(
+template std::vector pack_data_scales_interleave_block(
     const void* data, const void* scales, size_t height, size_t width, size_t quant_width);
 
 template
diff --git a/test/reference/pack.hpp b/test/reference/pack.hpp
index fd421f72..10d76a7f 100644
--- a/test/reference/pack.hpp
+++ b/test/reference/pack.hpp
@@ -172,7 +172,7 @@ std::vector pack_data_scales_interleave_block(
 /// @return The packed data buffer.
 template
 std::vector pack_data_interleave_block(const void* data, size_t height, size_t width, size_t block_width) {
-    return pack_data_scales_interleave_block(data, nullptr, height, width, block_width);
+    return pack_data_scales_interleave_block(data, nullptr, height, width, block_width);
 }
 
 }  // namespace kai::test
-- 
GitLab

From f8f1a26bce2e421ac5735808b9ca20d1f13b3b65 Mon Sep 17 00:00:00 2001
From: Gian Marco Iodice 
Date: Thu, 29 Aug 2024 17:00:02 +0100
Subject: [PATCH 13/29] Remove restriction on the K dimension (part 1)

- Remove asserts in the ukernels related to K dimension
- Fix the reference implementation

This commit alone does not remove the restriction. A further change is
needed in the lhs packing function.
In particular, the per-row LHS packing function needs the bl (block length) information to get the same alignment as the RHS packed matrix Signed-off-by: Gian Marco Iodice --- .../matmul_clamp_f32_qai8dxp_qsi4c32p.cpp | 37 +++++-- ...i8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c | 38 +++---- ...i8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c | 38 +++---- ..._qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c | 38 +++---- ..._qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c | 38 +++---- .../kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c | 103 +++++++++--------- .../kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.h | 34 +++++- 7 files changed, 186 insertions(+), 140 deletions(-) diff --git a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp index 2c6a298d..eeaf981e 100644 --- a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp +++ b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp @@ -90,8 +90,8 @@ static size_t roundup(size_t a, size_t b) { return ((a + b - 1) / b) * b; } -static inline size_t num_blocks_per_row(size_t k, size_t bl) { - return k / bl; +static inline size_t get_num_blocks_per_row(size_t k, size_t bl) { + return roundup(k, bl) / bl; } static inline uint16_t convert_f32_to_bf16(float f32) { @@ -118,8 +118,8 @@ static void fill_uniform_random(size_t num_rows, size_t num_cols, float* dst, si static void quant_qs4c32_f32( size_t n, size_t k, size_t bl, const float* rhs_f32, uint8_t* rhs_qs4c32, uint16_t* rhs_scales_bf16) { + const size_t num_blocks_row = get_num_blocks_per_row(k, bl); const size_t rhs_qs4c32_stride = (roundup(k, 2) / 2); - const size_t num_blocks_row = num_blocks_per_row(k, bl); for (size_t row_idx = 0; row_idx < n; ++row_idx) { const float* src_ptr = rhs_f32 + row_idx * k; @@ -129,7 +129,13 @@ static void quant_qs4c32_f32( float max = 0.0f; for (size_t b = 0; b < bl; ++b) { - const float src0_0 = src_ptr[block_idx * bl + b]; + const size_t k_idx = block_idx * bl + b; + + if (k_idx >= k) { + break; + } + + const float src0_0 = src_ptr[k_idx]; const float asrc0_0 = fabsf(src0_0); if (amax < asrc0_0) { @@ -148,6 +154,11 @@ static void quant_qs4c32_f32( for (size_t i = 0; i < bl; ++i) { const size_t k_idx = block_idx * bl + i; + + if (k_idx >= k) { + break; + } + const float src0_0 = src_ptr[k_idx]; // Scale the values @@ -245,10 +256,10 @@ static void ref_quant_qa8dx_f32(size_t m, size_t k, const float* lhs_f32, int8_t static void ref_matmul_f32_qa8dx_qs4c32( size_t m, size_t n, size_t k, size_t bl, const int8_t* lhs_qa8dx, const uint8_t* rhs_qs4c32, const uint16_t* scale_bf16, float* dst_f32, float scalar_min, float scalar_max) { - const size_t num_blocks_row = num_blocks_per_row(k, bl); + const size_t num_blocks_row = get_num_blocks_per_row(k, bl); - const size_t lhs_stride = k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t); - const size_t rhs_stride = num_blocks_row * (bl / 2); + const size_t lhs_stride = k + sizeof(float) + sizeof(int32_t); + const size_t rhs_stride = (roundup(k, 2) / 2); for (size_t row_idx = 0; row_idx < m; ++row_idx) { const int8_t* lhs_ptr_start = lhs_qa8dx + row_idx * lhs_stride; @@ -277,6 +288,10 @@ static void ref_matmul_f32_qa8dx_qs4c32( for (size_t i = 0; i < bl; ++i) { const size_t k_idx = block_idx * bl + i; + if (k_idx >= k) { + break; + } + // Get the LHS values const int32_t lhs_v0 = (int32_t)lhs_ptr[0]; @@ -331,11 +346,11 @@ static bool is_output_correct(size_t num_rows, size_t num_cols, 
float tolerance, int main() { const size_t m = 37; - const size_t n = 37; + const size_t n = 75; const size_t k = 256; const size_t bl = 64; - const size_t num_blocks_per_row = k / bl; + const size_t num_blocks_per_row = get_num_blocks_per_row(k, bl); const size_t seed_lhs = 4568; const size_t seed_rhs = seed_lhs + 4; @@ -343,7 +358,7 @@ int main() { const size_t lhs_native_size_f32 = m * k * sizeof(float); const size_t rhs_native_size_f32 = n * k * sizeof(float); - const size_t rhs_native_size_qs4c32 = n * num_blocks_per_row * (bl / 2); + const size_t rhs_native_size_qs4c32 = n * (roundup(k, 2) / 2); const size_t rhs_scales_size_bf16 = n * num_blocks_per_row * sizeof(uint16_t); // Allocate the memory @@ -406,7 +421,7 @@ int main() { // Get the size in bytes for the packed matrices const size_t lhs_packed_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(m, k, mr, kr, sr); const size_t rhs_packed_size = - kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(n, k, nr, kr, bl, Bf16); + kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(n, k, nr, kr, sr, bl, Bf16); const size_t dst_size = ukernel_variants[idx_variant].ukernel.get_dst_size(m, n); // Allocate the matrices diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c index 38d496a3..12d9d6ea 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c @@ -19,7 +19,7 @@ static const size_t kai_mr = 1; static const size_t kai_nr = 4; static const size_t kai_kr = 16; static const size_t kai_sr = 2; -static const size_t kai_k0 = 32; +static const size_t kai_bl_multiple_of = 32; static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); static const size_t kai_num_bytes_multiplier_rhs = sizeof(uint16_t); static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); @@ -27,23 +27,28 @@ static const size_t kai_num_bytes_sum_rhs = sizeof(float); static const size_t kai_num_bytes_bias = sizeof(float); inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) { - KAI_ASSERT((k % 2) == 0); - KAI_ASSERT((k % bl) == 0); - return k / bl; + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); + return kai_roundup(k, bl) / bl; +} + +inline static size_t kai_k_roundedup(size_t k) { + // Since we pack a float and int32 value at the end of the row, + // we must make sure that k is a multiple of 4 for alignment + size_t kr_sr_roundedup4 = kai_roundup(kai_kr * kai_sr, 4); + return kai_roundup(k, kr_sr_roundedup4); } inline static size_t kai_lhs_packed_stride(size_t k) { - KAI_ASSERT((k % 2) == 0); - KAI_ASSERT((k % kai_kr) == 0); + const size_t k_internal = kai_k_roundedup(k); + + KAI_ASSERT((k_internal % 2) == 0); - return kai_mr * (k * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); + return kai_mr * (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); } inline static size_t kai_rhs_packed_stride(size_t k, size_t bl) { - KAI_ASSERT((k % 2) == 0); - KAI_ASSERT((k % kai_kr) == 0); KAI_ASSERT((bl % kai_kr) == 0); - KAI_ASSERT((k % bl) == 0); + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); const size_t 
num_bytes_per_block = (bl / 2) + kai_num_bytes_multiplier_rhs; @@ -84,10 +89,6 @@ size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_ size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod( size_t n_idx, size_t k, size_t bl) { KAI_ASSERT((n_idx % kai_n_step) == 0); - KAI_ASSERT((bl % 32) == 0); - - // Temporary assert - KAI_ASSERT((k % kai_k0) == 0); return (n_idx / kai_n_step) * kai_rhs_packed_stride(k, bl); } @@ -107,17 +108,16 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotp void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod( size_t m, size_t n, size_t k, size_t bl, const void* restrict lhs_packed, const void* restrict rhs_packed, float* restrict dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { - // Temporary asserts - KAI_ASSERT(k % kai_k0 == 0); - KAI_ASSERT((bl % 32) == 0); + KAI_ASSERT((bl % kai_kr) == 0); + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); KAI_ASSERT(dst_stride_col == sizeof(float)); if (m == 0) { return; } - size_t num_subblocks = bl / 32; - size_t num_blocks = k / bl; + size_t num_subblocks = bl / kai_bl_multiple_of; + size_t num_blocks = kai_num_blocks_per_row(k, bl); float clamp_vals[2] = {scalar_min, scalar_max}; diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c index f01277dc..edd09711 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c @@ -19,7 +19,7 @@ static const size_t kai_mr = 1; static const size_t kai_nr = 8; static const size_t kai_kr = 16; static const size_t kai_sr = 2; -static const size_t kai_k0 = 32; +static const size_t kai_bl_multiple_of = 32; static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); static const size_t kai_num_bytes_multiplier_rhs = sizeof(uint16_t); static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); @@ -27,23 +27,28 @@ static const size_t kai_num_bytes_sum_rhs = sizeof(float); static const size_t kai_num_bytes_bias = sizeof(float); inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) { - KAI_ASSERT((k % 2) == 0); - KAI_ASSERT((k % bl) == 0); - return k / bl; + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); + return kai_roundup(k, bl) / bl; +} + +inline static size_t kai_k_roundedup(size_t k) { + // Since we pack a float and int32 value at the end of the row, + // we must make sure that k is a multiple of 4 for alignment + size_t kr_sr_roundedup4 = kai_roundup(kai_kr * kai_sr, 4); + return kai_roundup(k, kr_sr_roundedup4); } inline static size_t kai_lhs_packed_stride(size_t k) { - KAI_ASSERT((k % 2) == 0); - KAI_ASSERT((k % kai_kr) == 0); + const size_t k_internal = kai_k_roundedup(k); + + KAI_ASSERT((k_internal % 2) == 0); - return kai_mr * (k * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); + return kai_mr * (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); } inline static size_t kai_rhs_packed_stride(size_t k, size_t bl) { - KAI_ASSERT((k % 2) == 0); - KAI_ASSERT((k % kai_kr) == 0); KAI_ASSERT((bl % kai_kr) == 0); - KAI_ASSERT((k % bl) == 0); + KAI_ASSERT((bl % 
kai_bl_multiple_of) == 0); const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); const size_t num_bytes_per_block = (bl / 2) + kai_num_bytes_multiplier_rhs; @@ -84,10 +89,6 @@ size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_ size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod( size_t n_idx, size_t k, size_t bl) { KAI_ASSERT((n_idx % kai_n_step) == 0); - KAI_ASSERT((bl % 32) == 0); - - // Temporary assert - KAI_ASSERT((k % kai_k0) == 0); return (n_idx / kai_n_step) * kai_rhs_packed_stride(k, bl); } @@ -107,17 +108,16 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotp void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod( size_t m, size_t n, size_t k, size_t bl, const void* restrict lhs_packed, const void* restrict rhs_packed, float* restrict dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { - // Temporary asserts - KAI_ASSERT(k % kai_k0 == 0); - KAI_ASSERT((bl % 32) == 0); + KAI_ASSERT((bl % kai_kr) == 0); + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); KAI_ASSERT(dst_stride_col == sizeof(float)); if (m == 0) { return; } - size_t num_subblocks = bl / 32; - size_t num_blocks = k / bl; + size_t num_subblocks = bl / kai_bl_multiple_of; + size_t num_blocks = kai_num_blocks_per_row(k, bl); float clamp_vals[2] = {scalar_min, scalar_max}; diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c index 9b271272..3b9cb2c1 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c @@ -19,7 +19,7 @@ static const size_t kai_mr = 4; static const size_t kai_nr = 4; static const size_t kai_kr = 16; static const size_t kai_sr = 2; -static const size_t kai_k0 = 32; +static const size_t kai_bl_multiple_of = 32; static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); static const size_t kai_num_bytes_multiplier_rhs = sizeof(uint16_t); static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); @@ -27,23 +27,28 @@ static const size_t kai_num_bytes_sum_rhs = sizeof(float); static const size_t kai_num_bytes_bias = sizeof(float); inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) { - KAI_ASSERT((k % 2) == 0); - KAI_ASSERT((k % bl) == 0); - return k / bl; + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); + return kai_roundup(k, bl) / bl; +} + +inline static size_t kai_k_roundedup(size_t k) { + // Since we pack a float and int32 value at the end of the row, + // we must make sure that k is a multiple of 4 for alignment + size_t kr_sr_roundedup4 = kai_roundup(kai_kr * kai_sr, 4); + return kai_roundup(k, kr_sr_roundedup4); } inline static size_t kai_lhs_packed_stride(size_t k) { - KAI_ASSERT((k % 2) == 0); - KAI_ASSERT((k % kai_kr) == 0); + const size_t k_internal = kai_k_roundedup(k); + + KAI_ASSERT((k_internal % 2) == 0); - return kai_mr * (k * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); + return kai_mr * (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); } inline static size_t kai_rhs_packed_stride(size_t k, size_t bl) { - KAI_ASSERT((k % 2) == 0); - KAI_ASSERT((k % kai_kr) 
== 0); KAI_ASSERT((bl % kai_kr) == 0); - KAI_ASSERT((k % bl) == 0); + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); const size_t num_bytes_per_block = (bl / 2) + kai_num_bytes_multiplier_rhs; @@ -84,10 +89,6 @@ size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_ size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( size_t n_idx, size_t k, size_t bl) { KAI_ASSERT((n_idx % kai_n_step) == 0); - KAI_ASSERT((bl % 32) == 0); - - // Temporary assert - KAI_ASSERT((k % kai_k0) == 0); return (n_idx / kai_n_step) * kai_rhs_packed_stride(k, bl); } @@ -107,17 +108,16 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { - // Temporary asserts - KAI_ASSERT(k % kai_k0 == 0); - KAI_ASSERT((bl % 32) == 0); + KAI_ASSERT((bl % kai_kr) == 0); + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); KAI_ASSERT(dst_stride_col == sizeof(float)); if (m == 0) { return; } - size_t num_subblocks = bl / 32; - size_t num_blocks = k / bl; + size_t num_subblocks = bl / kai_bl_multiple_of; + size_t num_blocks = kai_num_blocks_per_row(k, bl); float clamp_vals[2] = {scalar_min, scalar_max}; diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c index 60a12adc..4dcb1955 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c @@ -19,7 +19,7 @@ static const size_t kai_mr = 4; static const size_t kai_nr = 8; static const size_t kai_kr = 16; static const size_t kai_sr = 2; -static const size_t kai_k0 = 32; +static const size_t kai_bl_multiple_of = 32; static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); static const size_t kai_num_bytes_multiplier_rhs = sizeof(uint16_t); static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); @@ -27,23 +27,28 @@ static const size_t kai_num_bytes_sum_rhs = sizeof(float); static const size_t kai_num_bytes_bias = sizeof(float); inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) { - KAI_ASSERT((k % 2) == 0); - KAI_ASSERT((k % bl) == 0); - return k / bl; + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); + return kai_roundup(k, bl) / bl; +} + +inline static size_t kai_k_roundedup(size_t k) { + // Since we pack a float and int32 value at the end of the row, + // we must make sure that k is a multiple of 4 for alignment + size_t kr_sr_roundedup4 = kai_roundup(kai_kr * kai_sr, 4); + return kai_roundup(k, kr_sr_roundedup4); } inline static size_t kai_lhs_packed_stride(size_t k) { - KAI_ASSERT((k % 2) == 0); - KAI_ASSERT((k % kai_kr) == 0); + const size_t k_internal = kai_k_roundedup(k); + + KAI_ASSERT((k_internal % 2) == 0); - return kai_mr * (k * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); + return kai_mr * (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); } inline static size_t kai_rhs_packed_stride(size_t k, size_t bl) { - 
KAI_ASSERT((k % 2) == 0); - KAI_ASSERT((k % kai_kr) == 0); KAI_ASSERT((bl % kai_kr) == 0); - KAI_ASSERT((k % bl) == 0); + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); const size_t num_bytes_per_block = (bl / 2) + kai_num_bytes_multiplier_rhs; @@ -84,10 +89,6 @@ size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_ size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm( size_t n_idx, size_t k, size_t bl) { KAI_ASSERT((n_idx % kai_n_step) == 0); - KAI_ASSERT((bl % 32) == 0); - - // Temporary assert - KAI_ASSERT((k % kai_k0) == 0); return (n_idx / kai_n_step) * kai_rhs_packed_stride(k, bl); } @@ -107,17 +108,16 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm( size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { - // Temporary asserts - KAI_ASSERT(k % kai_k0 == 0); - KAI_ASSERT((bl % 32) == 0); + KAI_ASSERT((bl % kai_kr) == 0); + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); KAI_ASSERT(dst_stride_col == sizeof(float)); if (m == 0) { return; } - size_t num_subblocks = bl / 32; - size_t num_blocks = k / bl; + size_t num_subblocks = bl / kai_bl_multiple_of; + size_t num_blocks = kai_num_blocks_per_row(k, bl); float clamp_vals[2] = {scalar_min, scalar_max}; diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c index 6b9e1435..3eccea14 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c @@ -13,47 +13,28 @@ static const size_t kai_num_bytes_sum_rhs = sizeof(float); static const size_t kai_num_bytes_bias = sizeof(float); +static const size_t kai_nr_multiple_of = 4; +static const size_t kai_bl_multiple_of = 32; inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) { - KAI_ASSERT((k % 2) == 0); - KAI_ASSERT((k % bl) == 0); - return k / bl; + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); + return kai_roundup(k, bl) / bl; } inline static size_t kai_num_bytes_per_block(size_t bl, size_t num_bytes_multiplier_rhs) { + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); return (bl / 2) + num_bytes_multiplier_rhs; } -inline static size_t kai_rhs_stride(size_t k, size_t bl) { - KAI_ASSERT((k % 2) == 0); - - const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); - // The RHS matrix (not packed) does not pack the scale. 
- // Therefore, the number of bytes for the scale is 0 - const size_t num_bytes_per_block = kai_num_bytes_per_block(bl, 0); - - return num_bytes_per_block * num_blocks_per_row; -} - -inline static size_t kai_rhs_packed_stride(size_t k, size_t nr, size_t kr, size_t bl, enum kai_datatype scale_dt) { - KAI_ASSERT((k % 2) == 0); - KAI_ASSERT((k % kr) == 0); - KAI_ASSERT((k % bl) == 0); - KAI_ASSERT((bl % kr) == 0); - - const size_t num_bytes_multiplier_rhs = kai_num_bytes_datatype(scale_dt); - const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); - const size_t num_bytes_per_block = kai_num_bytes_per_block(bl, num_bytes_multiplier_rhs); - - return nr * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs + kai_num_bytes_bias); +inline static size_t kai_rhs_stride(size_t k) { + return kai_roundup(k, 2) / 2; } inline static size_t kai_rhs_packed_offset_end_of_all_blocks( size_t k, size_t nr, size_t kr, size_t bl, size_t num_bytes_multiplier_rhs) { - KAI_ASSERT((k % 2) == 0); - KAI_ASSERT((k % kr) == 0); - KAI_ASSERT((k % bl) == 0); KAI_ASSERT((bl % kr) == 0); + KAI_ASSERT((nr % kai_nr_multiple_of) == 0); + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); const size_t num_bytes_per_block = kai_num_bytes_per_block(bl, num_bytes_multiplier_rhs); @@ -61,33 +42,52 @@ inline static size_t kai_rhs_packed_offset_end_of_all_blocks( return (nr * num_bytes_per_block * num_blocks_per_row); } +size_t kai_get_n_step_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t nr) { + return nr; +} + size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(size_t n_idx, size_t rhs_stride) { return n_idx * rhs_stride; } -size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( - size_t n_idx, size_t k, size_t nr, size_t kr, size_t bl, enum kai_datatype scale_dt) { - KAI_ASSERT((k % 2) == 0); - KAI_ASSERT((k % kr) == 0); - KAI_ASSERT((k % bl) == 0); - KAI_ASSERT((n_idx % nr) == 0); +size_t kai_get_rhs_packed_stride_pack_nxk_qsi4c32p_qsu4c32s16s0( + size_t k, size_t nr, size_t kr, size_t sr, size_t bl, enum kai_datatype scale_dt) { + KAI_ASSERT((bl % kr) == 0); + KAI_ASSERT((nr % kai_nr_multiple_of) == 0); + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); KAI_ASSERT(scale_dt == F32 || scale_dt == F16 || scale_dt == Bf16); KAI_UNUSED(kr); + KAI_UNUSED(sr); - return (n_idx / nr) * kai_rhs_packed_stride(k, nr, kr, bl, scale_dt); + const size_t num_bytes_multiplier_rhs = kai_num_bytes_datatype(scale_dt); + const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); + const size_t num_bytes_per_block = kai_num_bytes_per_block(bl, num_bytes_multiplier_rhs); + + return nr * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs + kai_num_bytes_bias); +} + +size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( + size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, enum kai_datatype scale_dt) { + KAI_ASSERT((n_idx % nr) == 0); + KAI_ASSERT((bl % kr) == 0); + KAI_ASSERT((nr % kai_nr_multiple_of) == 0); + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); + KAI_ASSERT(scale_dt == F32 || scale_dt == F16 || scale_dt == Bf16); + + return (n_idx / nr) * kai_get_rhs_packed_stride_pack_nxk_qsi4c32p_qsu4c32s16s0(k, nr, kr, sr, bl, scale_dt); } size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( - size_t n, size_t k, size_t nr, size_t kr, size_t bl, enum kai_datatype scale_dt) { - KAI_ASSERT((k % 2) == 0); - KAI_ASSERT((k % kr) == 0); - KAI_ASSERT((k % bl) == 0); + size_t n, size_t k, size_t nr, size_t 
kr, size_t sr, size_t bl, enum kai_datatype scale_dt) { + KAI_ASSERT((bl % kr) == 0); + KAI_ASSERT((nr % kai_nr_multiple_of) == 0); + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); KAI_ASSERT(scale_dt == F32 || scale_dt == F16 || scale_dt == Bf16); const size_t num_rows = kai_roundup(n, nr) / nr; - return num_rows * kai_rhs_packed_stride(k, nr, kr, bl, scale_dt); + return num_rows * kai_get_rhs_packed_stride_pack_nxk_qsi4c32p_qsu4c32s16s0(k, nr, kr, sr, bl, scale_dt); } void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( @@ -95,36 +95,35 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( const float* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0_params* params) { KAI_ASSERT(num_groups == 1); - KAI_ASSERT((k % 2) == 0); - KAI_ASSERT((k % kr) == 0); - KAI_ASSERT((k % bl) == 0); - KAI_ASSERT(bias == NULL); KAI_ASSERT(extra_bytes == 0); - - KAI_ASSERT(sr == 2); - KAI_ASSERT(kr >= 1 && kr <= 16); KAI_ASSERT(rhs != NULL); + KAI_ASSERT(scale != NULL); KAI_ASSERT(rhs_packed != NULL); KAI_ASSERT(params != NULL); KAI_ASSERT(params->rhs_zero_point == 8); KAI_ASSERT(params->lhs_zero_point == 1); + + KAI_ASSERT((bl % kr) == 0); + KAI_ASSERT((nr % kai_nr_multiple_of) == 0); + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); KAI_ASSERT(params->scale_dt == F32 || params->scale_dt == F16 || params->scale_dt == Bf16); // Note: The input matrix (rhs) is expected with: // "k" columns and "n" rows (NxK) const size_t num_bytes_multiplier_rhs = kai_num_bytes_datatype(params->scale_dt); - const size_t rhs_stride = kai_rhs_stride(k, bl); - const size_t rhs_packed_stride = kai_rhs_packed_stride(k, nr, kr, bl, params->scale_dt); + const size_t rhs_stride = kai_rhs_stride(k); + const size_t rhs_packed_stride = + kai_get_rhs_packed_stride_pack_nxk_qsi4c32p_qsu4c32s16s0(k, nr, kr, sr, bl, params->scale_dt); const size_t rhs_packed_offset_end_of_all_blocks = kai_rhs_packed_offset_end_of_all_blocks(k, nr, kr, bl, num_bytes_multiplier_rhs); - const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); + const size_t num_qblocks_per_row = kai_num_blocks_per_row(k, bl); const size_t num_bytes_per_block = kai_num_bytes_per_block(bl, num_bytes_multiplier_rhs); const size_t num_bytes_per_block_k = bl / 2; const size_t dst_num_rows = kai_roundup(n, nr) / nr; const size_t k_interleaved_v = 16U; const size_t block_length_in_bytes = kr / sr; - const size_t scale_stride = num_bytes_multiplier_rhs * num_blocks_per_row; + const size_t scale_stride = num_bytes_multiplier_rhs * num_qblocks_per_row; const int32_t rhs_zero_point = params->rhs_zero_point; const enum kai_datatype scale_dt = params->scale_dt; @@ -139,7 +138,7 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( memset(sums, 0, nr * kai_num_bytes_sum_rhs); // Iterate over the quantized blocks - for (size_t dst_qblock_idx = 0; dst_qblock_idx < num_blocks_per_row; ++dst_qblock_idx) { + for (size_t dst_qblock_idx = 0; dst_qblock_idx < num_qblocks_per_row; ++dst_qblock_idx) { // Store the scales after packing all K values void* rhs_packed_scale = dst_row + num_bytes_per_block_k * nr; diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.h b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.h index 9132d9be..6039b588 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.h +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.h @@ -23,6 +23,15 @@ struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0_params { enum kai_datatype scale_dt; }; +/// Get the 
n step value. +/// The micro-kernel can process any N values. However, the starting N index to +/// be processed must be a multiple of n step. +/// +/// @param[in] nr The number of columns written by the matmul micro-kernel +/// +/// @return the n step value +size_t kai_get_n_step_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(size_t nr); + /// Gets the offset in bytes for the RHS matrix (not packed), which holds /// the int4 values in a N x K matrix, where N is number of rows and K is the number of columns. /// @@ -42,12 +51,32 @@ size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( size_t n_idx, // size_t rhs_stride); // +/// Get the row stride in bytes to the packed RHS matrix +/// +/// @param[in] k The number of columns in the RHS matrix (not packed). +/// @param[in] nr The number of columns written by the matmul micro-kernel. +/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. +/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. +/// @param[in] bl The block length, which defines the number of K values stored in a single block. It must be a +/// multiple of 32. +/// @param[in] scale_dt Block scale data type +/// +/// @return the stride in bytes to the packed RHS matrix +size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( + size_t k, // + size_t nr, // + size_t kr, // + size_t sr, // + size_t bl, // + enum kai_datatype scale_dt); // + /// Gets the offset in bytes for the packed RHS matrix. /// /// @param[in] n_idx Row index in the RHS matrix (not packed). /// @param[in] k The common dimension between the LHS and RHS matrix (K) /// @param[in] nr The number of columns written by the matmul micro-kernel /// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. +/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. /// @param[in] bl The block length, which defines the number of K values stored in a single block. It must be a /// multiple of 32. /// @param[in] scale_dt Block scale data type @@ -58,6 +87,7 @@ size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( size_t k, // size_t nr, // size_t kr, // + size_t sr, // size_t bl, // enum kai_datatype scale_dt); // @@ -67,6 +97,7 @@ size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( /// @param[in] k The number of columns in the RHS matrix (not packed). /// @param[in] nr The number of columns written by the matmul micro-kernel /// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. +/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. /// @param[in] bl The block length, which defines the number of K values stored in a single block. It must be a multiple /// of 32. /// @param[in] scale_dt Block scale data type @@ -77,6 +108,7 @@ size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( size_t k, // size_t nr, // size_t kr, // + size_t sr, // size_t bl, // enum kai_datatype scale_dt); // @@ -89,7 +121,7 @@ size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( /// @param[in] num_groups The number of groups. It must be 1. /// @param[in] n The number of columns of the output matrix (N). /// @param[in] k The common dimension between the LHS and RHS matrix (K). -/// @param[in] nr The number of columns written by the matmul micro-kernel. +/// @param[in] nr The number of columns written by the matmul micro-kernel. It must be a multiple of 4. 
/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. /// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. /// However, kr must be multiple of sr. -- GitLab From eb9dc3d444551bfed25a8885a6b24832d6373eb2 Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Thu, 29 Aug 2024 21:30:30 +0100 Subject: [PATCH 14/29] Include rhs_stride and rhs_scale_stride in the list of input args Signed-off-by: Gian Marco Iodice --- .../matmul_clamp_f32_qai8dxp_qsi4c32p.cpp | 20 ++++++-- .../kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c | 49 ++++++++++++++----- .../kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.h | 40 ++++++++------- 3 files changed, 74 insertions(+), 35 deletions(-) diff --git a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp index eeaf981e..fa46ac25 100644 --- a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp +++ b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp @@ -94,6 +94,15 @@ static inline size_t get_num_blocks_per_row(size_t k, size_t bl) { return roundup(k, bl) / bl; } +static inline size_t get_rhs_native_stride(size_t k) { + return roundup(k, 2) / 2; +} + +static inline size_t get_rhs_scale_stride(size_t k, size_t bl) { + const size_t num_blocks_per_row = get_num_blocks_per_row(k, bl); + return num_blocks_per_row * sizeof(uint16_t); +} + static inline uint16_t convert_f32_to_bf16(float f32) { const uint32_t* i32 = reinterpret_cast(&f32); uint16_t bf16 = (*i32 >> 16); @@ -119,7 +128,7 @@ static void fill_uniform_random(size_t num_rows, size_t num_cols, float* dst, si static void quant_qs4c32_f32( size_t n, size_t k, size_t bl, const float* rhs_f32, uint8_t* rhs_qs4c32, uint16_t* rhs_scales_bf16) { const size_t num_blocks_row = get_num_blocks_per_row(k, bl); - const size_t rhs_qs4c32_stride = (roundup(k, 2) / 2); + const size_t rhs_qs4c32_stride = get_rhs_native_stride(k); for (size_t row_idx = 0; row_idx < n; ++row_idx) { const float* src_ptr = rhs_f32 + row_idx * k; @@ -259,7 +268,7 @@ static void ref_matmul_f32_qa8dx_qs4c32( const size_t num_blocks_row = get_num_blocks_per_row(k, bl); const size_t lhs_stride = k + sizeof(float) + sizeof(int32_t); - const size_t rhs_stride = (roundup(k, 2) / 2); + const size_t rhs_stride = get_rhs_native_stride(k); for (size_t row_idx = 0; row_idx < m; ++row_idx) { const int8_t* lhs_ptr_start = lhs_qa8dx + row_idx * lhs_stride; @@ -350,7 +359,6 @@ int main() { const size_t k = 256; const size_t bl = 64; - const size_t num_blocks_per_row = get_num_blocks_per_row(k, bl); const size_t seed_lhs = 4568; const size_t seed_rhs = seed_lhs + 4; @@ -358,8 +366,8 @@ int main() { const size_t lhs_native_size_f32 = m * k * sizeof(float); const size_t rhs_native_size_f32 = n * k * sizeof(float); - const size_t rhs_native_size_qs4c32 = n * (roundup(k, 2) / 2); - const size_t rhs_scales_size_bf16 = n * num_blocks_per_row * sizeof(uint16_t); + const size_t rhs_native_size_qs4c32 = n * get_rhs_native_stride(k); + const size_t rhs_scales_size_bf16 = n * get_rhs_scale_stride(k, bl); // Allocate the memory uint8_t* lhs_native_mtx_f32 = new uint8_t[lhs_native_size_f32]; @@ -444,8 +452,10 @@ int main() { nr, kr, sr, // Packing arguments bl, // Block length (const uint8_t*)(rhs_native_mtx_qs4c32), // RHS + get_rhs_native_stride(k), // RHS stride NULL, // Bias rhs_scales_mtx_bf16, // Scale + get_rhs_scale_stride(k, bl), // Scale stride 
rhs_packed_mtx_qs4cx, // RHS packed
         0, &params);

diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c
index 3eccea14..276318d9 100644
--- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c
+++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c
@@ -26,10 +26,6 @@ inline static size_t kai_num_bytes_per_block(size_t bl, size_t num_bytes_multipl
     return (bl / 2) + num_bytes_multiplier_rhs;
 }
 
-inline static size_t kai_rhs_stride(size_t k) {
-    return kai_roundup(k, 2) / 2;
-}
-
 inline static size_t kai_rhs_packed_offset_end_of_all_blocks(
     size_t k, size_t nr, size_t kr, size_t bl, size_t num_bytes_multiplier_rhs) {
     KAI_ASSERT((bl % kr) == 0);
@@ -46,12 +42,19 @@ size_t kai_get_n_step_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t nr) {
     return nr;
 }
 
-size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(size_t n_idx, size_t rhs_stride) {
+size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
+    size_t n_idx,  //
+    size_t rhs_stride) {
     return n_idx * rhs_stride;
 }
 
 size_t kai_get_rhs_packed_stride_pack_nxk_qsi4c32p_qsu4c32s16s0(
-    size_t k, size_t nr, size_t kr, size_t sr, size_t bl, enum kai_datatype scale_dt) {
+    size_t k,   //
+    size_t nr,  //
+    size_t kr,  //
+    size_t sr,  //
+    size_t bl,  //
+    enum kai_datatype scale_dt) {
     KAI_ASSERT((bl % kr) == 0);
     KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
     KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
@@ -68,7 +71,13 @@ size_t kai_get_rhs_packed_stride_pack_nxk_qsi4c32p_qsu4c32s16s0(
 }
 
 size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
-    size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, enum kai_datatype scale_dt) {
+    size_t n_idx,  //
+    size_t k,      //
+    size_t nr,     //
+    size_t kr,     //
+    size_t sr,     //
+    size_t bl,     //
+    enum kai_datatype scale_dt) {
     KAI_ASSERT((n_idx % nr) == 0);
     KAI_ASSERT((bl % kr) == 0);
     KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
@@ -79,7 +88,13 @@ size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
 }
 
 size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
-    size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, enum kai_datatype scale_dt) {
+    size_t n,   //
+    size_t k,   //
+    size_t nr,  //
+    size_t kr,  //
+    size_t sr,  //
+    size_t bl,  //
+    enum kai_datatype scale_dt) {
     KAI_ASSERT((bl % kr) == 0);
     KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
     KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
@@ -91,8 +106,20 @@ size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
 }
 
 void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
-    size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
-    const float* bias, const void* scale, void* rhs_packed, size_t extra_bytes,
+    size_t num_groups,   //
+    size_t n,            //
+    size_t k,            //
+    size_t nr,           //
+    size_t kr,           //
+    size_t sr,           //
+    size_t bl,           //
+    const uint8_t* rhs,  //
+    size_t rhs_stride,   //
+    const float* bias,   //
+    const void* scale,   //
+    size_t scale_stride, //
+    void* rhs_packed,    //
+    size_t extra_bytes,  //
     const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0_params* params) {
     KAI_ASSERT(num_groups == 1);
     KAI_ASSERT(extra_bytes == 0);
@@ -112,7 +139,6 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
     // "k" columns and "n" rows (NxK)
     const size_t num_bytes_multiplier_rhs = kai_num_bytes_datatype(params->scale_dt);
-    const size_t rhs_stride = kai_rhs_stride(k);
     const size_t rhs_packed_stride =
         kai_get_rhs_packed_stride_pack_nxk_qsi4c32p_qsu4c32s16s0(k, nr, kr, sr, bl, 
params->scale_dt);
 
     const size_t rhs_packed_offset_end_of_all_blocks =
@@ -123,7 +149,6 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
     const size_t dst_num_rows = kai_roundup(n, nr) / nr;
     const size_t k_interleaved_v = 16U;
     const size_t block_length_in_bytes = kr / sr;
-    const size_t scale_stride = num_bytes_multiplier_rhs * num_qblocks_per_row;
 
     const int32_t rhs_zero_point = params->rhs_zero_point;
     const enum kai_datatype scale_dt = params->scale_dt;
diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.h b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.h
index 6039b588..b5bed391 100644
--- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.h
+++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.h
@@ -118,24 +118,26 @@ size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(
 /// Two int4 values are stored in one byte. The lower order part of the byte (low) holds
 /// the first nibble (K-index + 0). The higher order of the byte holds the second nibble (K-index + 1).
 ///
-/// @param[in] num_groups The number of groups. It must be 1.
-/// @param[in] n The number of columns of the output matrix (N).
-/// @param[in] k The common dimension between the LHS and RHS matrix (K).
-/// @param[in] nr The number of columns written by the matmul micro-kernel. It must be a multiple of 4.
-/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel.
-/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr.
-///            However, kr must be multiple of sr.
-/// @param[in] bl The block length, which defines the number of
-///            K values stored in a single block. It must be a multiple of 32.
-/// @param[in] rhs The RHS matrix containing the 4-bit values.
-///            Size in bytes is expected to be greater than or equal to n * k * (sizeof(uint8_t) / 2).
-/// @param[in] bias The biases.
-/// @param[in] scale The per-block quantization scales.
-///            The scale data type must be proviided with the params object.
-///            Supported scale data types are FP32, FP16 and BF16.
-/// @param[out] rhs_packed The packed RHS matrix.
-/// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix.
-/// @param[in] params Parameters for the micro-kernel.
+/// @param[in] num_groups   The number of groups. It must be 1.
+/// @param[in] n            The number of columns of the output matrix (N).
+/// @param[in] k            The common dimension between the LHS and RHS matrix (K).
+/// @param[in] nr           The number of columns written by the matmul micro-kernel. It must be a multiple of 4.
+/// @param[in] kr           The number of columns loaded in the single inner most loop of the matmul micro-kernel.
+/// @param[in] sr           The number of kr splits. It can be 1 (no splits) up to kr.
+///                         However, kr must be multiple of sr.
+/// @param[in] bl           The block length, which defines the number of
+///                         K values stored in a single block. It must be a multiple of 32.
+/// @param[in] rhs          The RHS matrix containing the 4-bit values.
+///                         Size in bytes is expected to be greater than or equal to n * k * (sizeof(uint8_t) / 2).
+/// @param[in] rhs_stride   The row stride in bytes of the RHS matrix.
+/// @param[in] bias         The biases.
+/// @param[in] scale        The per-block quantization scales.
+///                         The scale data type must be provided with the params object.
+///                         Supported scale data types are FP32, FP16 and BF16.
+/// @param[in] scale_stride The number of bytes per row in bytes of the scale matrix +/// @param[out] rhs_packed The packed RHS matrix. +/// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix. +/// @param[in] params Parameters for the micro-kernel. void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( size_t num_groups, // size_t n, // @@ -145,8 +147,10 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( size_t sr, // size_t bl, // const uint8_t* rhs, // + size_t rhs_stride, // const float* bias, // const void* scale, // + size_t scale_stride, // void* rhs_packed, // size_t extra_bytes, // const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0_params* params); // -- GitLab From 8e488240e006b69e3d60a4848a0cb8be8d64e2b6 Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Thu, 29 Aug 2024 21:39:15 +0100 Subject: [PATCH 15/29] Rename rhs packing kernel - Change the s16s0 suffix to s1s0 as the K values in the native matrix are stored in sequential order Signed-off-by: Gian Marco Iodice --- .../matmul_clamp_f32_qai8dxp_qsi4c32p.cpp | 16 +++---- ...> kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c} | 20 ++++----- ...> kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h} | 44 +++++++++---------- 3 files changed, 40 insertions(+), 40 deletions(-) rename kai/ukernels/matmul/pack/{kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c => kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c} (94%) rename kai/ukernels/matmul/pack/{kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.h => kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h} (91%) diff --git a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp index fa46ac25..9b262c01 100644 --- a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp +++ b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp @@ -21,7 +21,7 @@ #include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h" #include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h" #include "kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h" -#include "kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.h" +#include "kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h" #define INT4_MIN (-8) #define INT4_MAX (7) @@ -429,25 +429,25 @@ int main() { // Get the size in bytes for the packed matrices const size_t lhs_packed_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(m, k, mr, kr, sr); const size_t rhs_packed_size = - kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(n, k, nr, kr, sr, bl, Bf16); + kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(n, k, nr, kr, sr, bl, Bf16); const size_t dst_size = ukernel_variants[idx_variant].ukernel.get_dst_size(m, n); // Allocate the matrices uint8_t* lhs_packed_mtx_qa8dx = new uint8_t[lhs_packed_size]; - uint8_t* rhs_packed_mtx_qs4cx = new uint8_t[rhs_packed_size]; + uint8_t* rhs_packed_mtx_qs4c32 = new uint8_t[rhs_packed_size]; uint8_t* dst_act_mtx_f32 = new uint8_t[dst_size]; memset(dst_act_mtx_f32, 0, dst_size); // If the RHS matrix contains constant values, the packing can be performed // only once - struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0_params params; + struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params; params.lhs_zero_point = 1; params.rhs_zero_point = 8; params.scale_dt = Bf16; // RHS packing - kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( + kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( 1, n, k, // Dimensions nr, kr, sr, // Packing arguments bl, // Block length @@ -456,7 +456,7 
@@ int main() { NULL, // Bias rhs_scales_mtx_bf16, // Scale get_rhs_scale_stride(k, bl), // Scale stride - rhs_packed_mtx_qs4cx, // RHS packed + rhs_packed_mtx_qs4c32, // RHS packed 0, &params); const auto time_s = std::chrono::high_resolution_clock::now(); @@ -477,7 +477,7 @@ int main() { const size_t dst_offset = ukernel_variants[idx_variant].ukernel.get_dst_offset(0, 0, dst_stride); const void* lhs_ptr = (const void*)((const char*)lhs_packed_mtx_qa8dx + lhs_offset); - const void* rhs_ptr = (const void*)((const char*)rhs_packed_mtx_qs4cx + rhs_offset); + const void* rhs_ptr = (const void*)((const char*)rhs_packed_mtx_qs4c32 + rhs_offset); float* dst_ptr = (float*)((uint8_t*)dst_act_mtx_f32 + dst_offset); ukernel_variants[idx_variant].ukernel.run_matmul( @@ -509,7 +509,7 @@ int main() { } std::cout << "------------" << std::endl; delete[] lhs_packed_mtx_qa8dx; - delete[] rhs_packed_mtx_qs4cx; + delete[] rhs_packed_mtx_qs4c32; delete[] dst_act_mtx_f32; } delete[] lhs_native_mtx_f32; diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c similarity index 94% rename from kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c rename to kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c index 276318d9..687f3cef 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c @@ -3,7 +3,7 @@ // // SPDX-License-Identifier: Apache-2.0 // -#include "kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.h" +#include "kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h" #include #include @@ -42,13 +42,13 @@ size_t kai_get_n_step_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t nr) { return nr; } -size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( +size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( size_t n_idx, // size_t rhs_stride) { return n_idx * rhs_stride; } -size_t kai_get_rhs_packed_stride_pack_nxk_qsi4c32p_qsu4c32s16s0( +size_t kai_get_rhs_packed_stride_pack_nxk_qsi4c32p_qsu4c32s1s0( size_t k, // size_t nr, // size_t kr, // @@ -70,7 +70,7 @@ size_t kai_get_rhs_packed_stride_pack_nxk_qsi4c32p_qsu4c32s16s0( return nr * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs + kai_num_bytes_bias); } -size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( +size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( size_t n_idx, // size_t k, // size_t nr, // @@ -84,10 +84,10 @@ size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( KAI_ASSERT((bl % kai_bl_multiple_of) == 0); KAI_ASSERT(scale_dt == F32 || scale_dt == F16 || scale_dt == Bf16); - return (n_idx / nr) * kai_get_rhs_packed_stride_pack_nxk_qsi4c32p_qsu4c32s16s0(k, nr, kr, sr, bl, scale_dt); + return (n_idx / nr) * kai_get_rhs_packed_stride_pack_nxk_qsi4c32p_qsu4c32s1s0(k, nr, kr, sr, bl, scale_dt); } -size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( +size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( size_t n, // size_t k, // size_t nr, // @@ -102,10 +102,10 @@ size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( const size_t num_rows = kai_roundup(n, nr) / nr; - return num_rows * kai_get_rhs_packed_stride_pack_nxk_qsi4c32p_qsu4c32s16s0(k, nr, kr, sr, bl, scale_dt); + return num_rows * kai_get_rhs_packed_stride_pack_nxk_qsi4c32p_qsu4c32s1s0(k, nr, kr, sr, bl, scale_dt); } -void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( +void
kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( size_t num_groups, // size_t n, // size_t k, // @@ -120,7 +120,7 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( size_t scale_stride, // void* rhs_packed, // size_t extra_bytes, // - const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0_params* params) { + const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params* params) { KAI_ASSERT(num_groups == 1); KAI_ASSERT(extra_bytes == 0); KAI_ASSERT(rhs != NULL); @@ -140,7 +140,7 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( const size_t num_bytes_multiplier_rhs = kai_num_bytes_datatype(params->scale_dt); const size_t rhs_packed_stride = - kai_get_rhs_packed_stride_pack_nxk_qsi4c32p_qsu4c32s16s0(k, nr, kr, sr, bl, params->scale_dt); + kai_get_rhs_packed_stride_pack_nxk_qsi4c32p_qsu4c32s1s0(k, nr, kr, sr, bl, params->scale_dt); const size_t rhs_packed_offset_end_of_all_blocks = kai_rhs_packed_offset_end_of_all_blocks(k, nr, kr, bl, num_bytes_multiplier_rhs); const size_t num_qblocks_per_row = kai_num_blocks_per_row(k, bl); diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.h b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h similarity index 91% rename from kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.h rename to kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h index b5bed391..33a7486a 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.h +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h @@ -17,7 +17,7 @@ extern "C" { #endif -struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0_params { +struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params { int8_t lhs_zero_point; uint8_t rhs_zero_point; enum kai_datatype scale_dt; @@ -30,7 +30,7 @@ struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0_params { /// @param[in] nr The number of columns written by the matmul micro-kernel /// /// @return the n step value -size_t kai_get_n_step_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(size_t nr); +size_t kai_get_n_step_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(size_t nr); /// Gets the offset in bytes for the RHS matrix (not packed), which holds /// the int4 values in a N x K matrix, where N is number of rows and K is the number of columns. 
@@ -47,7 +47,7 @@ size_t kai_get_n_step_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(size_t nr); /// @param[in] rhs_stride The number of bytes in in each row of the RHS matrix (not packed) /// /// @return the offset in bytes to the RHS matrix (not packed) -size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( +size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( size_t n_idx, // size_t rhs_stride); // @@ -82,7 +82,7 @@ size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( /// @param[in] scale_dt Block scale data type /// /// @return the offset in bytes to the packed RHS matrix -size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( +size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( size_t n_idx, // size_t k, // size_t nr, // @@ -103,7 +103,7 @@ size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( /// @param[in] scale_dt Block scale data type /// /// @return the packed RHS matrix size in bytes -size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( +size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( size_t n, // size_t k, // size_t nr, // @@ -132,28 +132,28 @@ size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( /// @param[in] rhs_stride The number of bytes per row in bytes of the RHS matrix /// @param[in] bias The biases. /// @param[in] scale The per-block quantization scales. -/// The scale data type must be proviided with the params object. +/// The scale data type must be provided with the params object. /// Supported scale data types are FP32, FP16 and BF16. /// @param[in] scale_stride The number of bytes per row in bytes of the scale matrix /// @param[out] rhs_packed The packed RHS matrix. /// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix. /// @param[in] params Parameters for the micro-kernel. 
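Note: to make the new rhs_stride and scale_stride arguments concrete, a minimal caller sketch in C follows. Only the entry point, its parameter order, and the params fields come from the patch above; the shapes, the nr/kr/sr values, the buffer names, and the kai/kai_common.h include path are illustrative assumptions.

#include <stddef.h>
#include <stdint.h>

#include "kai/kai_common.h"
#include "kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h"

// Sketch: pack a native NxK int4 RHS matrix with BF16 per-block scales.
// All shapes below are hypothetical; in practice nr/kr/sr must come from
// the matmul variant that will consume the packed buffer.
static void pack_rhs_sketch(const uint8_t* rhs_qsu4, const uint16_t* rhs_scales_bf16, void* rhs_packed) {
    const size_t n = 64, k = 256, bl = 32;
    const size_t nr = 4, kr = 8, sr = 2;

    // Two int4 values per byte: row stride of the native NxK matrix.
    const size_t rhs_stride = (k + 1) / 2;

    // One BF16 scale per block of bl K values: row stride of the scale matrix.
    const size_t scale_stride = ((k + bl - 1) / bl) * sizeof(uint16_t);

    struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params;
    params.lhs_zero_point = 1;
    params.rhs_zero_point = 8;
    params.scale_dt = Bf16;

    kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
        1, n, k,                        // num_groups and dimensions
        nr, kr, sr,                     // packing arguments
        bl,                             // block length
        rhs_qsu4, rhs_stride,           // native int4 data and its explicit stride
        NULL,                           // no bias
        rhs_scales_bf16, scale_stride,  // per-block scales and their explicit stride
        rhs_packed, 0, &params);
}

The two stride expressions match what the example program computes with its get_rhs_native_stride(k) and get_rhs_scale_stride(k, bl) helpers.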
-void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( - size_t num_groups, // - size_t n, // - size_t k, // - size_t nr, // - size_t kr, // - size_t sr, // - size_t bl, // - const uint8_t* rhs, // - size_t rhs_stride, // - const float* bias, // - const void* scale, // - size_t scale_stride, // - void* rhs_packed, // - size_t extra_bytes, // - const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0_params* params); // +void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( + size_t num_groups, // + size_t n, // + size_t k, // + size_t nr, // + size_t kr, // + size_t sr, // + size_t bl, // + const uint8_t* rhs, // + size_t rhs_stride, // + const float* bias, // + const void* scale, // + size_t scale_stride, // + void* rhs_packed, // + size_t extra_bytes, // + const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params* params); // #ifdef __cplusplus } -- GitLab From 9b09324d64a83a674743847918481012b89a4ee5 Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Fri, 30 Aug 2024 10:05:13 +0100 Subject: [PATCH 16/29] Address review comments Signed-off-by: Gian Marco Iodice --- CHANGELOG.md | 3 + CMakeLists.txt | 2 +- .../CMakeLists.txt | 5 +- kai/kai_common.h | 4 +- ...i8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h | 7 +-- ...i8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h | 7 +-- ..._qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h | 7 +-- ..._qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h | 7 +-- .../kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c | 10 ++-- ...matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp | 59 ++++++++++++------- 10 files changed, 61 insertions(+), 50 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index af6db5b7..9c7a7c3b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,9 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo ## v0.3.0 - Upcoming Release - Advanced SIMD FP32 GEMM and GEMV micro kernels +- Micro-kernels to compute the matrix multiplication of dynamically quantized asymmetric signed 8-bit integer with per-row quantization (QAI8DX) LHS and quantized symmetric 4-bit unsigned integer with per-block quantization (QSU4C32) RHS. The destination matrix data type is single-precision floating-point (F32). The micro-kernels have been optimized using the Arm® CPU feature FEAT_I8MM for the matrix-by-matrix cases and FEAT_DotProd for the vector-by-matrix cases. +- RHS matrix packing micro-kernels to pack the RHS matrix holding the QSU4C32 values. +- Unit test and example for integer micro-kernels.
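Note: the changelog entry above splits the work by CPU feature, with FEAT_I8MM variants for the matrix-by-matrix (GEMM) cases and FEAT_DotProd variants for the vector-by-matrix (GEMV) cases, which suggests a simple runtime dispatch. A minimal sketch in C follows; the cpu_has_* feature-detection helpers are assumed to be supplied by the application, only one kernel pair is shown, and the kai_run_* entry-point names follow the per-kernel naming convention used throughout this patch set.

#include <float.h>
#include <stdbool.h>
#include <stddef.h>

#include "kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h"

// Assumed helpers: runtime CPU feature checks provided by the application.
extern bool cpu_has_dotprod(void);
extern bool cpu_has_i8mm(void);

// Sketch: route GEMV-shaped calls (m == 1) to the dotprod kernel and larger
// shapes to the i8mm kernel. The LHS and RHS buffers must have been packed
// with the mr/nr/kr/sr of the variant that ends up running.
static void run_matmul_dispatch(
    size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed, float* dst) {
    const size_t dst_stride_row = n * sizeof(float);

    if (m == 1 && cpu_has_dotprod()) {
        kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(
            m, n, k, bl, lhs_packed, rhs_packed, dst, dst_stride_row, sizeof(float), -FLT_MAX, FLT_MAX);
    } else if (cpu_has_i8mm()) {
        kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(
            m, n, k, bl, lhs_packed, rhs_packed, dst, dst_stride_row, sizeof(float), -FLT_MAX, FLT_MAX);
    }
}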
## v0.2.0 diff --git a/CMakeLists.txt b/CMakeLists.txt index 6c7b5ba6..07be019a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -79,7 +79,7 @@ set(KLEIDIAI_FILES_SCALAR kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.c kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c - kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c + kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c ) set(KLEIDIAI_FILES_NEON_FP16 diff --git a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/CMakeLists.txt b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/CMakeLists.txt index 10fbcc5b..eb6499e9 100644 --- a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/CMakeLists.txt +++ b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/CMakeLists.txt @@ -21,9 +21,12 @@ include_directories( add_executable(matmul_clamp_f32_qai8dxp_qsi4c32p matmul_clamp_f32_qai8dxp_qsi4c32p.cpp ${KLEIDIAI_PATH}/kai/kai_common.h - ${MATMUL_PACK_PATH}/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.c + ${MATMUL_PACK_PATH}/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c ${MATMUL_PACK_PATH}/kai_lhs_quant_pack_qai8dxp_f32.c ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c) + +target_compile_options(matmul_clamp_f32_qai8dxp_qsi4c32p + PRIVATE -march=armv8.2-a+dotprod+i8mm) diff --git a/kai/kai_common.h b/kai/kai_common.h index 790c2d26..f1cbdbb4 100644 --- a/kai/kai_common.h +++ b/kai/kai_common.h @@ -70,7 +70,7 @@ enum kai_datatype { /// @param[in] dt KleidiAI data type /// /// @return the numbers of bytes for the data type -inline static size_t kai_num_bytes_datatype(enum kai_datatype dt) { +inline static size_t kai_get_datatype_size_in_bytes(enum kai_datatype dt) { return (size_t)(dt >> 8); } @@ -90,7 +90,7 @@ inline static float kai_cast_f32_f16(uint16_t f16) { /// @param[in] bf16 The f16 value /// /// @return the f32 value -inline static float kai_bf16_to_f32(uint16_t bf16) { +inline static float kai_cast_bf16_f32(uint16_t bf16) { const uint32_t i32 = (bf16 << 16); float f32; memcpy(&f32, &i32, sizeof(i32)); diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h index ce0aa648..32ec885e 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h @@ -6,9 +6,6 @@ // #pragma once -#ifndef __cplusplus -#include -#endif #include #ifdef __cplusplus @@ -18,7 +15,7 @@ extern "C" { /// Micro-kernel dependencies /// /// -# kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix -/// -# kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0 to pack the RHS matrix +/// -# kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 to pack the RHS matrix /// -------------------------------------------------- @@ -121,7 +118,7 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotp /// by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs /// both the dynamic quantization to 8-bit 
and activation packing in a single step. /// @param[in] rhs_packed The RHS packed matrix, which is obtained by calling @ref -/// kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0 +/// kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 /// @param[out] dst The DST matrix. /// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. /// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(float). diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h index bbe749f1..b2be5c67 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h @@ -6,9 +6,6 @@ // #pragma once -#ifndef __cplusplus -#include -#endif #include #ifdef __cplusplus @@ -18,7 +15,7 @@ extern "C" { /// Micro-kernel dependencies /// /// -# kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix -/// -# kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0 to pack the RHS matrix +/// -# kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 to pack the RHS matrix /// -------------------------------------------------- @@ -121,7 +118,7 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotp /// by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs /// both the dynamic quantization to 8-bit and activation packing in a single step. /// @param[in] rhs_packed The RHS packed matrix, which is obtained by calling @ref -/// kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0 +/// kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 /// @param[out] dst The DST matrix. /// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. /// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(float). diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h index 8b20a31d..4d185f6b 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h @@ -6,9 +6,6 @@ // #pragma once -#ifndef __cplusplus -#include -#endif #include #ifdef __cplusplus @@ -18,7 +15,7 @@ extern "C" { /// Micro-kernel dependencies /// /// -# kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix -/// -# kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0 to pack the RHS matrix +/// -# kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 to pack the RHS matrix /// -------------------------------------------------- @@ -121,7 +118,7 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm /// by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs /// both the dynamic quantization to 8-bit and activation packing in a single step. /// @param[in] rhs_packed The RHS packed matrix, which is obtained by calling @ref -/// kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0 +/// kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 /// @param[out] dst The DST matrix. 
/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. /// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(float). diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h index 31614599..36c4330b 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h @@ -6,9 +6,6 @@ // #pragma once -#ifndef __cplusplus -#include -#endif #include #ifdef __cplusplus @@ -18,7 +15,7 @@ extern "C" { /// Micro-kernel dependencies /// /// -# kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix -/// -# kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0 to pack the RHS matrix +/// -# kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 to pack the RHS matrix /// -------------------------------------------------- @@ -121,7 +118,7 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm /// by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs /// both the dynamic quantization to 8-bit and activation packing in a single step. /// @param[in] rhs_packed The RHS packed matrix, which is obtained by calling @ref -/// kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0 +/// kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 /// @param[out] dst The DST matrix. /// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. /// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(float). 
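Note: the header comments above prescribe a fixed call sequence: quantize and pack the LHS with kai_lhs_quant_pack_qai8dxp_f32, then hand both packed buffers to the matmul. A condensed sketch in C (buffer sizing and allocation are assumed to be done by the caller with the corresponding kai_get_*_size getters, and the kai_get_mr/kr/sr_* getter names are assumptions that follow the per-kernel convention used in this patch set):

#include <float.h>
#include <stddef.h>

#include "kai_lhs_quant_pack_qai8dxp_f32.h"
#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h"

// Sketch: LHS packing followed by the matmul, for one kernel variant.
static void matmul_sketch(
    size_t m, size_t n, size_t k, size_t bl, const float* lhs_f32, void* lhs_packed,
    const void* rhs_packed, float* dst) {
    // Packing parameters must come from the variant that will run.
    const size_t mr = kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm();
    const size_t kr = kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm();
    const size_t sr = kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm();

    // Dynamic 8-bit quantization and packing of the LHS in a single step.
    kai_run_lhs_quant_pack_qai8dxp_f32(m, k, mr, kr, sr, 0, lhs_f32, k * sizeof(float), lhs_packed);

    // Matmul with fused clamp; -FLT_MAX/FLT_MAX effectively disables the clamp.
    kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(
        m, n, k, bl, lhs_packed, rhs_packed, dst,
        n * sizeof(float),   // DST stride (row)
        sizeof(float),       // DST stride (col)
        -FLT_MAX, FLT_MAX);  // min/max for the clamp operation
}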
diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c index 687f3cef..014a4fc9 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c @@ -38,7 +38,7 @@ inline static size_t kai_rhs_packed_offset_end_of_all_blocks( return (nr * num_bytes_per_block * num_blocks_per_row); } -size_t kai_get_n_step_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t nr) { +size_t kai_get_n_step_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(size_t nr) { return nr; } @@ -63,7 +63,7 @@ size_t kai_get_rhs_packed_stride_pack_nxk_qsi4c32p_qsu4c32s1s0( KAI_UNUSED(kr); KAI_UNUSED(sr); - const size_t num_bytes_multiplier_rhs = kai_num_bytes_datatype(scale_dt); + const size_t num_bytes_multiplier_rhs = kai_get_datatype_size_in_bytes(scale_dt); const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); const size_t num_bytes_per_block = kai_num_bytes_per_block(bl, num_bytes_multiplier_rhs); @@ -138,7 +138,7 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( // Note: The input matrix (rhs) is expected with: // "k" columns and "n" rows (NxK) - const size_t num_bytes_multiplier_rhs = kai_num_bytes_datatype(params->scale_dt); + const size_t num_bytes_multiplier_rhs = kai_get_datatype_size_in_bytes(params->scale_dt); const size_t rhs_packed_stride = kai_get_rhs_packed_stride_pack_nxk_qsi4c32p_qsu4c32s1s0(k, nr, kr, sr, bl, params->scale_dt); const size_t rhs_packed_offset_end_of_all_blocks = @@ -239,10 +239,10 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( d = ((float*)rhs_packed_scale)[nr_idx]; break; case F16: - d = kai_f16_to_f32(((uint16_t*)rhs_packed_scale)[nr_idx]); + d = kai_cast_f16_f32(((uint16_t*)rhs_packed_scale)[nr_idx]); break; case Bf16: - d = kai_bf16_to_f32(((uint16_t*)rhs_packed_scale)[nr_idx]); + d = kai_cast_bf16_f32(((uint16_t*)rhs_packed_scale)[nr_idx]); break; default: KAI_ERROR("Unsupported scale data type"); diff --git a/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp b/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp index 8988793a..5302cccc 100644 --- a/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp +++ b/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp @@ -16,10 +16,11 @@ #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h" #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h" -#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h" #include "test/common/bfloat16.hpp" #include "test/common/int4.hpp" #include "test/common/memory.hpp" +#include "test/common/round.hpp" #include "test/reference/cast.hpp" #include "test/reference/fill.hpp" #include "test/reference/matmul.hpp" @@ -70,14 +71,18 @@ TEST(matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, EndTOEnd) { // * Packs the RHS matrix. 
const auto ref_rhs_qsu4 = cast_qsu4_qsi4(ref_rhs_qsi4.data(), N * K); + const size_t ref_rhs_qsu4_stride = round_up_division(K, 2); + const size_t ref_rhs_scales_stride = round_up_division(K, bl) * sizeof(uint16_t); + const auto imp_packed_rhs_size = - kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(N, K, nr, kr, bl, kai_datatype::Bf16); + kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, bl, kai_datatype::Bf16); std::vector imp_packed_rhs(imp_packed_rhs_size); - const kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0_params params{ + const kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params{ .lhs_zero_point = 1, .rhs_zero_point = 8, .scale_dt = kai_datatype::Bf16}; - kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( - 1, N, K, nr, kr, sr, bl, ref_rhs_qsu4.data(), nullptr, reinterpret_cast(ref_rhs_scales.data()), - imp_packed_rhs.data(), 0, &params); + kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( + 1, N, K, nr, kr, sr, bl, ref_rhs_qsu4.data(), ref_rhs_qsu4_stride, nullptr, + reinterpret_cast(ref_rhs_scales.data()), ref_rhs_scales_stride, imp_packed_rhs.data(), 0, + &params); // Runs the GEMM micro-kernel. const auto imp_dst_size = kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(M, N); @@ -142,14 +147,18 @@ TEST(kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, EndTOEnd) // * Packs the RHS matrix. const auto ref_rhs_qsu4 = cast_qsu4_qsi4(ref_rhs_qsi4.data(), N * K); + const size_t ref_rhs_qsu4_stride = round_up_division(K, 2); + const size_t ref_rhs_scales_stride = round_up_division(K, bl) * sizeof(uint16_t); + const auto imp_packed_rhs_size = - kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(N, K, nr, kr, bl, kai_datatype::Bf16); + kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, bl, kai_datatype::Bf16); std::vector imp_packed_rhs(imp_packed_rhs_size); - const kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0_params params{ + const kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params{ .lhs_zero_point = 1, .rhs_zero_point = 8, .scale_dt = kai_datatype::Bf16}; - kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( - 1, N, K, nr, kr, sr, bl, ref_rhs_qsu4.data(), nullptr, reinterpret_cast(ref_rhs_scales.data()), - imp_packed_rhs.data(), 0, &params); + kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( + 1, N, K, nr, kr, sr, bl, ref_rhs_qsu4.data(), ref_rhs_qsu4_stride, nullptr, + reinterpret_cast(ref_rhs_scales.data()), ref_rhs_scales_stride, imp_packed_rhs.data(), 0, + &params); // Runs the GEMM micro-kernel. const auto imp_dst_size = kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(M, N); @@ -214,14 +223,18 @@ TEST(matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, EndTOEnd) { // * Packs the RHS matrix.
const auto ref_rhs_qsu4 = cast_qsu4_qsi4(ref_rhs_qsi4.data(), N * K); + const size_t ref_rhs_qsu4_stride = round_up_division(K, 2); + const size_t ref_rhs_scales_stride = round_up_division(K, bl) * sizeof(uint16_t); + const auto imp_packed_rhs_size = - kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(N, K, nr, kr, bl, kai_datatype::Bf16); + kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, bl, kai_datatype::Bf16); std::vector imp_packed_rhs(imp_packed_rhs_size); - const kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0_params params{ + const kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params{ .lhs_zero_point = 1, .rhs_zero_point = 8, .scale_dt = kai_datatype::Bf16}; - kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( - 1, N, K, nr, kr, sr, bl, ref_rhs_qsu4.data(), nullptr, reinterpret_cast(ref_rhs_scales.data()), - imp_packed_rhs.data(), 0, &params); + kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( + 1, N, K, nr, kr, sr, bl, ref_rhs_qsu4.data(), ref_rhs_qsu4_stride, nullptr, + reinterpret_cast(ref_rhs_scales.data()), ref_rhs_scales_stride, imp_packed_rhs.data(), 0, + &params); // Runs the GEMM micro-kernel. const auto imp_dst_size = kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(M, N); @@ -286,14 +299,18 @@ TEST(matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, EndTOEnd) { // * Packs the RHS matrix. const auto ref_rhs_qsu4 = cast_qsu4_qsi4(ref_rhs_qsi4.data(), N * K); + const size_t ref_rhs_qsu4_stride = round_up_division(K, 2); + const size_t ref_rhs_scales_stride = round_up_division(K, bl) * sizeof(uint16_t); + const auto imp_packed_rhs_size = - kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0(N, K, nr, kr, bl, kai_datatype::Bf16); + kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, bl, kai_datatype::Bf16); std::vector imp_packed_rhs(imp_packed_rhs_size); - const kai_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0_params params{ + const kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params{ .lhs_zero_point = 1, .rhs_zero_point = 8, .scale_dt = kai_datatype::Bf16}; - kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s16s0( - 1, N, K, nr, kr, sr, bl, ref_rhs_qsu4.data(), nullptr, reinterpret_cast(ref_rhs_scales.data()), - imp_packed_rhs.data(), 0, &params); + kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( + 1, N, K, nr, kr, sr, bl, ref_rhs_qsu4.data(), ref_rhs_qsu4_stride, nullptr, + reinterpret_cast(ref_rhs_scales.data()), ref_rhs_scales_stride, imp_packed_rhs.data(), 0, + &params); // Runs the GEMM micro-kernel.
const auto imp_dst_size = kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(M, N); -- GitLab From a8faa56732156d7a623d388dbc43165574072430 Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Fri, 30 Aug 2024 10:33:04 +0100 Subject: [PATCH 17/29] Trivial fixes to function names in RHS packing Signed-off-by: Anitha Raj --- .../matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c | 2 +- .../matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c index 014a4fc9..e32d5c67 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c @@ -48,7 +48,7 @@ size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( return n_idx * rhs_stride; } -size_t kai_get_rhs_packed_stride_pack_nxk_qsi4c32p_qsu4c32s1s0( +size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( size_t k, // size_t nr, // size_t kr, // diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h index 33a7486a..7e5ac118 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h @@ -62,7 +62,7 @@ size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( /// @param[in] scale_dt Block scale data type /// /// @return the stride in bytes to the packed RHS matrix -size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( +size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( size_t k, // size_t nr, // size_t kr, // -- GitLab From 3a27c3323789843acddd2f706ce87142ade3843c Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Fri, 30 Aug 2024 10:57:19 +0100 Subject: [PATCH 18/29] More fixes to RHS packing function names Signed-off-by: Anitha Raj --- .../matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c index e32d5c67..1f1416d1 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c @@ -84,7 +84,7 @@ size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( KAI_ASSERT((bl % kai_bl_multiple_of) == 0); KAI_ASSERT(scale_dt == F32 || scale_dt == F16 || scale_dt == Bf16); - return (n_idx / nr) * kai_get_rhs_packed_stride_pack_nxk_qsi4c32p_qsu4c32s1s0(k, nr, kr, sr, bl, scale_dt); + return (n_idx / nr) * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(k, nr, kr, sr, bl, scale_dt); } size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( @@ -102,7 +102,7 @@ size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( const size_t num_rows = kai_roundup(n, nr) / nr; - return num_rows * kai_get_rhs_packed_stride_pack_nxk_qsi4c32p_qsu4c32s1s0(k, nr, kr, sr, bl, scale_dt); + return num_rows * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(k, nr, kr, sr, bl, scale_dt); } void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( @@ -140,7 +140,7 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( const size_t num_bytes_multiplier_rhs = kai_get_datatype_size_in_bytes(params->scale_dt); const size_t rhs_packed_stride = - 
kai_get_rhs_packed_stride_pack_nxk_qsi4c32p_qsu4c32s1s0(k, nr, kr, sr, bl, params->scale_dt); + kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(k, nr, kr, sr, bl, params->scale_dt); const size_t rhs_packed_offset_end_of_all_blocks = kai_rhs_packed_offset_end_of_all_blocks(k, nr, kr, bl, num_bytes_multiplier_rhs); const size_t num_qblocks_per_row = kai_num_blocks_per_row(k, bl); -- GitLab From 19c6d4a00753322f3bfc1775d1ee7e4f89af8189 Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Fri, 30 Aug 2024 11:30:33 +0100 Subject: [PATCH 19/29] Address review comments * Add new files to BUILD.bazel * Add new test and example to .gitlab-ci.yml Signed-off-by: Anitha Raj --- .gitlab-ci.yml | 2 ++ kai/ukernels/matmul/BUILD.bazel | 58 +++++++++++++++++++++++++++++++++ test/BUILD.bazel | 1 + 3 files changed, 61 insertions(+) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 431edc88..4e84a67d 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -122,6 +122,7 @@ build-examples: - matmul_clamp_f16_f16_f16p - matmul_clamp_f32_qai8dxp_qsi4cxp - matmul_clamp_f32_qsi8d32p_qsi4c32p + - matmul_clamp_f32_qai8dxp_qsi4c32p script: - mkdir -p build/$EXAMPLE - cmake -G Ninja -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_BUILD_TYPE=Release -S examples/$EXAMPLE -B build/$EXAMPLE @@ -143,6 +144,7 @@ test-examples: - matmul_clamp_f16_f16_f16p - matmul_clamp_f32_qai8dxp_qsi4cxp - matmul_clamp_f32_qsi8d32p_qsi4c32p + - matmul_clamp_f32_qai8dxp_qsi4c32p script: - build/${EXAMPLE}/${EXAMPLE} | tee -a ${EXAMPLE}.log artifacts: diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index 44245175..b699af56 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -208,6 +208,58 @@ kai_c_library( cpu_uarch = kai_cpu_neon(), ) +cc_library( + name = "clamp_f32_qai8dxp_qsi4c32p_interface", + hdrs = ["matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h"], +) + +kai_c_library( + name = "clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod", + srcs = ["matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c"], + hdrs = ["matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"], + cpu_uarch = kai_cpu_dotprod(), + deps = [ + ":clamp_f32_qai8dxp_qsi4c32p_interface", + ], +) + +kai_c_library( + name = "clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod", + srcs = ["matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c"], + hdrs = ["matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h"], + cpu_uarch = kai_cpu_dotprod(), + deps = [ + ":clamp_f32_qai8dxp_qsi4c32p_interface", + ], +) + +kai_c_library( + name = "clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm", + srcs = ["matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c"], + hdrs = ["matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h"], + cpu_uarch = kai_cpu_i8mm(), + deps = [ + ":clamp_f32_qai8dxp_qsi4c32p_interface", + ], +) + +kai_c_library( + name = "clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm", + srcs = ["matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c"], + hdrs = ["matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h"], + cpu_uarch = kai_cpu_i8mm(), + deps = [ + ":clamp_f32_qai8dxp_qsi4c32p_interface", + ], 
+) + +kai_c_library( + name = "rhs_pack_nxk_qsi4c32p_qsu4c32s1s0", + srcs = ["pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c"], + hdrs = ["pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h"], + cpu_uarch = kai_cpu_neon(), +) + kai_c_library( name = "matmul", deps = [ @@ -215,12 +267,17 @@ kai_c_library( ":clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla", ":clamp_f32_f32_f32p", ":clamp_f32_f32p_f32p", + ":clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod", + ":clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod", ":clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod", ":clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod", + ":clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm", + ":clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm", ":clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm", ":clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm", ":clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm", ":clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm", + ":clamp_f32_qai8dxp_qsi4c32p_interface", ":clamp_f32_qsi8d32p_qsi4c32p_dotprod", ":clamp_f32_qsi8d32p_qsi4c32p_i8mm", ":clamp_f32_qsi8d32p_qsi4c32p_interface", @@ -230,6 +287,7 @@ kai_c_library( ":rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme", ":rhs_pack_kxn_f32pbiasf32_f32_f32_neon", ":rhs_pack_kxn_qsi4cxp_qsu4cxs1s0", + ":rhs_pack_nxk_qsi4c32p_qsu4c32s1s0", ":rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0", ":rhs_pack_nxk_qsi4cxp_qsu4cxs1s0", ], diff --git a/test/BUILD.bazel b/test/BUILD.bazel index 2ad11215..c806b5d5 100644 --- a/test/BUILD.bazel +++ b/test/BUILD.bazel @@ -53,6 +53,7 @@ kai_cxx_library( cc_test( name = "kleidiai_test", srcs = [ + "tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp", "tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp", "tests/matmul_test.cpp", ], -- GitLab From 1162d8b532458aca77211d3279fb134ad3e1343e Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Fri, 30 Aug 2024 14:17:22 +0100 Subject: [PATCH 20/29] Implement RHS packing function for the non-transposed case - Write a new ukernel to pack the RHS matrix when the RHS input matrix is non-transposed (kxn) - Extend the example to validate the non-transposed case Signed-off-by: Gian Marco Iodice --- CMakeLists.txt | 1 + .../CMakeLists.txt | 1 + .../matmul_clamp_f32_qai8dxp_qsi4c32p.cpp | 513 +++++++++++++----- .../kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c | 273 ++++++++++ .../kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h | 153 ++++++ .../kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h | 19 +- 6 files changed, 797 insertions(+), 163 deletions(-) create mode 100644 kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c create mode 100644 kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 07be019a..3e198665 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -76,6 +76,7 @@ endif() set(KLEIDIAI_FILES_SCALAR kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.c kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c + kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.c kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c diff --git a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/CMakeLists.txt b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/CMakeLists.txt index eb6499e9..47527352 100644 --- a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/CMakeLists.txt +++ b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/CMakeLists.txt @@ -21,6 +21,7 @@ include_directories( 
add_executable(matmul_clamp_f32_qai8dxp_qsi4c32p matmul_clamp_f32_qai8dxp_qsi4c32p.cpp ${KLEIDIAI_PATH}/kai/kai_common.h + ${MATMUL_PACK_PATH}/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c ${MATMUL_PACK_PATH}/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c ${MATMUL_PACK_PATH}/kai_lhs_quant_pack_qai8dxp_f32.c ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c diff --git a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp index 9b262c01..65ba5f53 100644 --- a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp +++ b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include @@ -21,11 +22,17 @@ #include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h" #include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h" #include "kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h" +#include "kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h" #include "kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h" #define INT4_MIN (-8) #define INT4_MAX (7) +enum class rhs_format { + nxk, + kxn, +}; + // Micro-kernel interface struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p { kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel ukernel; @@ -94,8 +101,8 @@ static inline size_t get_num_blocks_per_row(size_t k, size_t bl) { return roundup(k, bl) / bl; } -static inline size_t get_rhs_native_stride(size_t k) { - return roundup(k, 2) / 2; +static inline size_t get_rhs_native_stride(size_t x) { + return roundup(x, 2) / 2; } static inline size_t get_rhs_scale_stride(size_t k, size_t bl) { @@ -125,11 +132,14 @@ static void fill_uniform_random(size_t num_rows, size_t num_cols, float* dst, si } } -static void quant_qs4c32_f32( +static void quant_nxk_qs4c32_f32( size_t n, size_t k, size_t bl, const float* rhs_f32, uint8_t* rhs_qs4c32, uint16_t* rhs_scales_bf16) { const size_t num_blocks_row = get_num_blocks_per_row(k, bl); const size_t rhs_qs4c32_stride = get_rhs_native_stride(k); + // Make sure the output is filled with zeros + std::memset(rhs_qs4c32, 0, n * rhs_qs4c32_stride); + for (size_t row_idx = 0; row_idx < n; ++row_idx) { const float* src_ptr = rhs_f32 + row_idx * k; @@ -194,6 +204,88 @@ static void quant_qs4c32_f32( } } +static void quant_kxn_qs4c32_f32( + size_t n, size_t k, size_t bl, const float* rhs_f32, uint8_t* rhs_qs4c32, uint16_t* rhs_scales_bf16) { + const size_t num_blocks_row = get_num_blocks_per_row(k, bl); + const size_t rhs_qs4c32_stride = get_rhs_native_stride(n); + + // Make sure the output is filled with zeros + std::memset(rhs_qs4c32, 0, k * rhs_qs4c32_stride); + + for (size_t row_idx = 0; row_idx < n; ++row_idx) { + const float* src_ptr = rhs_f32 + row_idx * k; + + for (size_t block_idx = 0; block_idx < num_blocks_row; ++block_idx) { + float amax = 0.0f; + float max = 0.0f; + + for (size_t b = 0; b < bl; ++b) { + const size_t k_idx = block_idx * bl + b; + + if (k_idx >= k) { + break; + } + + const float src0_0 = src_ptr[k_idx]; + const float asrc0_0 = fabsf(src0_0); + + if (amax < asrc0_0) { + amax = asrc0_0; + max = src0_0; + } + } + + const float scale = max / -8.0; + const float recip_scale = scale ? 
1.0f / scale : 0.0f; + + // Store the scale in the dedicated buffer + *rhs_scales_bf16 = convert_f32_to_bf16(scale); + + rhs_scales_bf16 += 1; + + for (size_t i = 0; i < bl; ++i) { + const size_t k_idx = block_idx * bl + i; + + if (k_idx >= k) { + break; + } + + const float src0_0 = src_ptr[k_idx]; + + // Scale the values + int32_t v0_s32 = (int32_t)(round(src0_0 * recip_scale)); + + // Maximum/minimum int4 values + v0_s32 = std::max(v0_s32, INT4_MIN); + v0_s32 = std::min(v0_s32, INT4_MAX); + + const uint8_t v0_u8 = (uint8_t)(v0_s32 + 8); + + const size_t dst_addr = (row_idx / 2) + k_idx * rhs_qs4c32_stride; + uint8_t rhs_v0 = rhs_qs4c32[dst_addr]; + + if ((row_idx % 2) == 0) { + rhs_v0 = v0_u8; + } else { + rhs_v0 |= (v0_u8 << 4); + } + + rhs_qs4c32[dst_addr] = rhs_v0; + } + } + } +} + +static void quant_qs4cx_f32( + size_t n, size_t k, size_t bl, rhs_format format, const float* rhs_f32, uint8_t* rhs_qs4c32, + uint16_t* rhs_scales_bf16) { + if (rhs_format::nxk == format) { + quant_nxk_qs4c32_f32(n, k, bl, rhs_f32, rhs_qs4c32, rhs_scales_bf16); + } else { + quant_kxn_qs4c32_f32(n, k, bl, rhs_f32, rhs_qs4c32, rhs_scales_bf16); + } +}; + static void ref_quant_qa8dx_f32(size_t m, size_t k, const float* lhs_f32, int8_t* lhs_qa8dx) { const size_t dst_stride = (k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t)); @@ -262,7 +354,7 @@ static void ref_quant_qa8dx_f32(size_t m, size_t k, const float* lhs_f32, int8_t } } -static void ref_matmul_f32_qa8dx_qs4c32( +static void ref_matmul_mxn_mxk_nxk_f32_qa8dx_qs4c32( size_t m, size_t n, size_t k, size_t bl, const int8_t* lhs_qa8dx, const uint8_t* rhs_qs4c32, const uint16_t* scale_bf16, float* dst_f32, float scalar_min, float scalar_max) { const size_t num_blocks_row = get_num_blocks_per_row(k, bl); @@ -339,6 +431,93 @@ static void ref_matmul_f32_qa8dx_qs4c32( } }; +static void ref_matmul_mxn_mxk_kxn_f32_qa8dx_qs4c32( + size_t m, size_t n, size_t k, size_t bl, const int8_t* lhs_qa8dx, const uint8_t* rhs_qs4c32, + const uint16_t* scale_bf16, float* dst_f32, float scalar_min, float scalar_max) { + const size_t num_blocks_row = get_num_blocks_per_row(k, bl); + + const size_t lhs_stride = k + sizeof(float) + sizeof(int32_t); + const size_t rhs_stride = get_rhs_native_stride(n); + + for (size_t row_idx = 0; row_idx < m; ++row_idx) { + const int8_t* lhs_ptr_start = lhs_qa8dx + row_idx * lhs_stride; + + for (size_t col_idx = 0; col_idx < n; ++col_idx) { + // Main f32 accumulator + float main_acc = 0.0f; + + const int8_t* lhs_ptr = lhs_ptr_start; + const uint8_t* rhs_ptr = rhs_qs4c32 + (col_idx / 2); + + // Get the LHS quantization parameters stored at the + // beginning of each row + const float lhs_scale = *(const float*)lhs_ptr; + lhs_ptr += sizeof(float); + + const int32_t lhs_offset = *(const int32_t*)lhs_ptr; + lhs_ptr += sizeof(int32_t); + + for (size_t block_idx = 0; block_idx < num_blocks_row; ++block_idx) { + const uint16_t rhs_scale_bf16 = scale_bf16[block_idx + col_idx * num_blocks_row]; + const float rhs_scale = convert_bf16_to_f32(rhs_scale_bf16); + + int32_t iacc = 0; + + for (size_t i = 0; i < bl; ++i) { + const size_t k_idx = block_idx * bl + i; + + if (k_idx >= k) { + break; + } + + // Get the LHS values + const int32_t lhs_v0 = (int32_t)lhs_ptr[0]; + + // Get the RHS values + const uint8_t rhs_byte = rhs_ptr[0]; + + // Unpack the RHS values + int32_t rhs_v0 = 0; + if ((col_idx % 2) == 0) { + rhs_v0 = (((int32_t)(rhs_byte & 0x0F)) - 8); + } else { + rhs_v0 = (((int32_t)(rhs_byte >> 4)) - 8); + } + + iacc += lhs_v0 * rhs_v0; + iacc += 
lhs_offset * rhs_v0; + + lhs_ptr += 1; + rhs_ptr += rhs_stride; + } + + main_acc += iacc * rhs_scale; + } + + main_acc = main_acc * lhs_scale; + + // Clamp (min-max) operation + main_acc = std::max(main_acc, scalar_min); + main_acc = std::min(main_acc, scalar_max); + + dst_f32[0] = main_acc; + dst_f32 += 1; + } + } +}; + +static void ref_matmul_f32_qa8dx_qs4c32( + size_t m, size_t n, size_t k, size_t bl, rhs_format format, const int8_t* lhs_qa8dx, const uint8_t* rhs_qs4c32, + const uint16_t* rhs_scales_bf16, float* dst_f32, float scalar_min, float scalar_max) { + if (rhs_format::nxk == format) { + ref_matmul_mxn_mxk_nxk_f32_qa8dx_qs4c32( + m, n, k, bl, lhs_qa8dx, rhs_qs4c32, rhs_scales_bf16, dst_f32, scalar_min, scalar_max); + } else { + ref_matmul_mxn_mxk_kxn_f32_qa8dx_qs4c32( + m, n, k, bl, lhs_qa8dx, rhs_qs4c32, rhs_scales_bf16, dst_f32, scalar_min, scalar_max); + } +}; + static bool is_output_correct(size_t num_rows, size_t num_cols, float tolerance, const float* ref, const float* act) { bool is_valid = true; @@ -364,158 +543,192 @@ int main() { std::cout << "------------" << std::endl; - const size_t lhs_native_size_f32 = m * k * sizeof(float); - const size_t rhs_native_size_f32 = n * k * sizeof(float); - const size_t rhs_native_size_qs4c32 = n * get_rhs_native_stride(k); - const size_t rhs_scales_size_bf16 = n * get_rhs_scale_stride(k, bl); - - // Allocate the memory - uint8_t* lhs_native_mtx_f32 = new uint8_t[lhs_native_size_f32]; - uint8_t* rhs_native_mtx_f32 = new uint8_t[rhs_native_size_f32]; - uint8_t* rhs_native_mtx_qs4c32 = new uint8_t[rhs_native_size_qs4c32]; - uint8_t* rhs_scales_mtx_bf16 = new uint8_t[rhs_scales_size_bf16]; - - fill_uniform_random(m, k, (float*)lhs_native_mtx_f32, seed_lhs); - fill_uniform_random(n, k, (float*)rhs_native_mtx_f32, seed_rhs); - - quant_qs4c32_f32( - n, k, bl, // Dimensions - (const float*)rhs_native_mtx_f32, // RHS (F32) - rhs_native_mtx_qs4c32, // RHS (QS4C32) - (uint16_t*)rhs_scales_mtx_bf16); // Scales (Bf16) - - delete[] rhs_native_mtx_f32; - - //----------- REFERENCE IMPLEMENTATION - //------------------------------------ - //------------------------------------ - // Memory sizes for the reference implementation - // After dynamically quantized the LHS matrix, we have the scale and offset for each - // row. 
The scale (f32) and offset (int32) are stored at the beginning of each row - const size_t lhs_ref_size_qa8dx = m * (k + sizeof(int32_t) + sizeof(float)); - const size_t dst_ref_size_f32 = m * n * sizeof(float); - - uint8_t* lhs_ref_mtx_qa8dx = new uint8_t[lhs_ref_size_qa8dx]; - uint8_t* dst_ref_mtx_f32 = new uint8_t[dst_ref_size_f32]; - - ref_quant_qa8dx_f32(m, k, (const float*)lhs_native_mtx_f32, (int8_t*)lhs_ref_mtx_qa8dx); - - ref_matmul_f32_qa8dx_qs4c32( - m, n, k, // Dimensions - bl, // Block length - (const int8_t*)lhs_ref_mtx_qa8dx, // LHS - (const uint8_t*)rhs_native_mtx_qs4c32, // RHS - (const uint16_t*)rhs_scales_mtx_bf16, // Scale - (float*)dst_ref_mtx_f32, // DST - -FLT_MAX, FLT_MAX); // Min and max for the clamp operation - - // Remove the unnecessary buffer - delete[] lhs_ref_mtx_qa8dx; - - //----------- END REFERENCE IMPLEMENTATION - //------------------------------------ - //------------------------------------ - - //----------- MICRO-KERNELS TESTS - //------------------------------------ - //------------------------------------ - for (size_t idx_variant = 0; idx_variant < num_ukernel_variants; ++idx_variant) { - // Get the packing parameters - const size_t mr = ukernel_variants[idx_variant].ukernel.get_mr(); - const size_t nr = ukernel_variants[idx_variant].ukernel.get_nr(); - const size_t kr = ukernel_variants[idx_variant].ukernel.get_kr(); - const size_t sr = ukernel_variants[idx_variant].ukernel.get_sr(); - - // Get the size in bytes for the packed matrices - const size_t lhs_packed_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(m, k, mr, kr, sr); - const size_t rhs_packed_size = - kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(n, k, nr, kr, sr, bl, Bf16); - const size_t dst_size = ukernel_variants[idx_variant].ukernel.get_dst_size(m, n); - - // Allocate the matrices - uint8_t* lhs_packed_mtx_qa8dx = new uint8_t[lhs_packed_size]; - uint8_t* rhs_packed_mtx_qs4c32 = new uint8_t[rhs_packed_size]; - uint8_t* dst_act_mtx_f32 = new uint8_t[dst_size]; - - memset(dst_act_mtx_f32, 0, dst_size); - - // If the RHS matrix contains constant values, the packing can be performed - // only once - struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params; - params.lhs_zero_point = 1; - params.rhs_zero_point = 8; - params.scale_dt = Bf16; - - // RHS packing - kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( - 1, n, k, // Dimensions - nr, kr, sr, // Packing arguments - bl, // Block length - (const uint8_t*)(rhs_native_mtx_qs4c32), // RHS - get_rhs_native_stride(k), // RHS stride - NULL, // Bias - rhs_scales_mtx_bf16, // Scale - get_rhs_scale_stride(k, bl), // Scale stride - rhs_packed_mtx_qs4c32, // RHS packed - 0, &params); - - const auto time_s = std::chrono::high_resolution_clock::now(); - - // LHS packing - kai_run_lhs_quant_pack_qai8dxp_f32( - m, k, // Dimensions - mr, kr, sr, 0, // Packing arguments - (const float*)lhs_native_mtx_f32, // LHS - k * sizeof(float), // LHS stride - lhs_packed_mtx_qa8dx); // LHS packed - - // Matmul - { - const size_t dst_stride = n * sizeof(float); - const size_t lhs_offset = ukernel_variants[idx_variant].ukernel.get_lhs_packed_offset(0, k); - const size_t rhs_offset = ukernel_variants[idx_variant].ukernel.get_rhs_packed_offset(0, k, bl); - const size_t dst_offset = ukernel_variants[idx_variant].ukernel.get_dst_offset(0, 0, dst_stride); - - const void* lhs_ptr = (const void*)((const char*)lhs_packed_mtx_qa8dx + lhs_offset); - const void* rhs_ptr = (const void*)((const char*)rhs_packed_mtx_qs4c32 + rhs_offset); - float* dst_ptr =
(float*)((uint8_t*)dst_act_mtx_f32 + dst_offset); - - ukernel_variants[idx_variant].ukernel.run_matmul( - m, n, k, // Dimensions - bl, // Block length - lhs_ptr, // LHS packed - rhs_ptr, // RHS packed - dst_ptr, // DST - dst_stride, // DST stride (row) - sizeof(float), // DST stride (col) - -FLT_MAX, FLT_MAX // Min and max for the clamp operation - ); - } + // Iterate over the RHS format (NxK or KxN) + for (const rhs_format& format : {rhs_format::nxk, rhs_format::kxn}) { + std::cout << "Testing RHS format = " << (format == rhs_format::nxk ? "N x K" : "K x N") << std::endl; + + const size_t lhs_native_size_f32 = m * k * sizeof(float); + const size_t rhs_native_size_f32 = n * k * sizeof(float); + const size_t rhs_native_size_qs4c32 = + format == rhs_format::nxk ? n * get_rhs_native_stride(k) : k * get_rhs_native_stride(n); + const size_t rhs_scales_size_bf16 = n * get_rhs_scale_stride(k, bl); + + // Allocate the memory + uint8_t* lhs_native_mtx_f32 = new uint8_t[lhs_native_size_f32]; + uint8_t* rhs_native_mtx_f32 = new uint8_t[rhs_native_size_f32]; + uint8_t* rhs_native_mtx_qs4c32 = new uint8_t[rhs_native_size_qs4c32]; + uint8_t* rhs_scales_mtx_bf16 = new uint8_t[rhs_scales_size_bf16]; + + fill_uniform_random(m, k, (float*)lhs_native_mtx_f32, seed_lhs); + fill_uniform_random(n, k, (float*)rhs_native_mtx_f32, seed_rhs); + + quant_qs4cx_f32( + n, k, bl, // Dimensions + format, // Format (NxK or KxN) + (const float*)rhs_native_mtx_f32, // RHS (F32) + rhs_native_mtx_qs4c32, // RHS (QS4C32) + (uint16_t*)rhs_scales_mtx_bf16); // Scales (Bf16) + + delete[] rhs_native_mtx_f32; + + //----------- REFERENCE IMPLEMENTATION + //------------------------------------ + //------------------------------------ + // Memory sizes for the reference implementation + // After dynamically quantized the LHS matrix, we have the scale and offset for each + // row. 
The scale (f32) and offset (int32) are stored at the beginning of each row + const size_t lhs_ref_size_qa8dx = m * (k + sizeof(int32_t) + sizeof(float)); + const size_t dst_ref_size_f32 = m * n * sizeof(float); + + uint8_t* lhs_ref_mtx_qa8dx = new uint8_t[lhs_ref_size_qa8dx]; + uint8_t* dst_ref_mtx_f32 = new uint8_t[dst_ref_size_f32]; + + ref_quant_qa8dx_f32(m, k, (const float*)lhs_native_mtx_f32, (int8_t*)lhs_ref_mtx_qa8dx); + + ref_matmul_f32_qa8dx_qs4c32( + m, n, k, // Dimensions + bl, // Block length + format, // Format (NxK or KxN) + (const int8_t*)lhs_ref_mtx_qa8dx, // LHS + (const uint8_t*)rhs_native_mtx_qs4c32, // RHS + (const uint16_t*)rhs_scales_mtx_bf16, // Scale + (float*)dst_ref_mtx_f32, // DST + -FLT_MAX, FLT_MAX); // Min and max for the clamp operation + + // Remove the unnecessary buffer + delete[] lhs_ref_mtx_qa8dx; + + //----------- END REFERENCE IMPLEMENTATION + //------------------------------------ + //------------------------------------ + + //----------- MICRO-KERNELS TESTS + //------------------------------------ + //------------------------------------ + for (size_t idx_variant = 0; idx_variant < num_ukernel_variants; ++idx_variant) { + // Get the packing parameters + const size_t mr = ukernel_variants[idx_variant].ukernel.get_mr(); + const size_t nr = ukernel_variants[idx_variant].ukernel.get_nr(); + const size_t kr = ukernel_variants[idx_variant].ukernel.get_kr(); + const size_t sr = ukernel_variants[idx_variant].ukernel.get_sr(); + + // Get the size in bytes for the packed matrices + const size_t lhs_packed_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(m, k, mr, kr, sr); + size_t rhs_packed_size = 0; + + if (format == rhs_format::nxk) { + rhs_packed_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(n, k, nr, kr, sr, bl, Bf16); + + } else { + rhs_packed_size = kai_get_rhs_packed_size_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(n, k, nr, kr, sr, bl, Bf16); + } - const auto time_e = std::chrono::high_resolution_clock::now(); + const size_t dst_size = ukernel_variants[idx_variant].ukernel.get_dst_size(m, n); + + // Allocate the matrices + uint8_t* lhs_packed_mtx_qa8dx = new uint8_t[lhs_packed_size]; + uint8_t* rhs_packed_mtx_qs4c32 = new uint8_t[rhs_packed_size]; + uint8_t* dst_act_mtx_f32 = new uint8_t[dst_size]; + + memset(dst_act_mtx_f32, 0, dst_size); + + // If the RHS matrix contains constant values, the packing can be performed + // only once + if (format == rhs_format::nxk) { + struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params; + params.lhs_zero_point = 1; + params.rhs_zero_point = 8; + params.scale_dt = Bf16; + + kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( + 1, n, k, // Dimensions + nr, kr, sr, // Packing arguments + bl, // Block length + (const uint8_t*)(rhs_native_mtx_qs4c32), // RHS + get_rhs_native_stride(k), // RHS stride + NULL, // Bias + rhs_scales_mtx_bf16, // Scale + get_rhs_scale_stride(k, bl), // Scale stride + rhs_packed_mtx_qs4c32, // RHS packed + 0, &params); + + } else { + struct kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_params params; + params.lhs_zero_point = 1; + params.rhs_zero_point = 8; + params.scale_dt = Bf16; + + kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0( + 1, n, k, // Dimensions + nr, kr, sr, // Packing arguments + bl, // Block length + (const uint8_t*)(rhs_native_mtx_qs4c32), // RHS + get_rhs_native_stride(n), // RHS stride + NULL, // Bias + rhs_scales_mtx_bf16, // Scale + get_rhs_scale_stride(k, bl), // Scale stride + rhs_packed_mtx_qs4c32, // RHS packed + 0, &params); + } const auto time_s = std::chrono::high_resolution_clock::now(); + + // LHS packing + kai_run_lhs_quant_pack_qai8dxp_f32( + m, k, // Dimensions + mr, kr, sr, 0, // Packing arguments + (const float*)lhs_native_mtx_f32, // LHS + k * sizeof(float), // LHS stride + lhs_packed_mtx_qa8dx); // LHS packed + + // Matmul + { + const size_t dst_stride = n * sizeof(float); + const size_t lhs_offset = ukernel_variants[idx_variant].ukernel.get_lhs_packed_offset(0, k); + const size_t rhs_offset = ukernel_variants[idx_variant].ukernel.get_rhs_packed_offset(0, k, bl); + const size_t dst_offset = ukernel_variants[idx_variant].ukernel.get_dst_offset(0, 0, dst_stride); + + const void* lhs_ptr = (const void*)((const char*)lhs_packed_mtx_qa8dx + lhs_offset); + const void* rhs_ptr = (const void*)((const char*)rhs_packed_mtx_qs4c32 + rhs_offset); + float* dst_ptr = (float*)((uint8_t*)dst_act_mtx_f32 + dst_offset); + + ukernel_variants[idx_variant].ukernel.run_matmul( + m, n, k, // Dimensions + bl, // Block length + lhs_ptr, // LHS packed + rhs_ptr, // RHS packed + dst_ptr, // DST + dst_stride, // DST stride (row) + sizeof(float), // DST stride (col) + -FLT_MAX, FLT_MAX // Min and max for the clamp operation + ); + } + + const auto time_e = std::chrono::high_resolution_clock::now(); - const auto elap =
std::chrono::duration_cast(time_e - time_s); + const auto time_s = std::chrono::high_resolution_clock::now(); + + // LHS packing + kai_run_lhs_quant_pack_qai8dxp_f32( + m, k, // Dimensions + mr, kr, sr, 0, // Packing arguments + (const float*)lhs_native_mtx_f32, // LHS + k * sizeof(float), // LHS stride + lhs_packed_mtx_qa8dx); // LHS packed + + // Matmul + { + const size_t dst_stride = n * sizeof(float); + const size_t lhs_offset = ukernel_variants[idx_variant].ukernel.get_lhs_packed_offset(0, k); + const size_t rhs_offset = ukernel_variants[idx_variant].ukernel.get_rhs_packed_offset(0, k, bl); + const size_t dst_offset = ukernel_variants[idx_variant].ukernel.get_dst_offset(0, 0, dst_stride); + + const void* lhs_ptr = (const void*)((const char*)lhs_packed_mtx_qa8dx + lhs_offset); + const void* rhs_ptr = (const void*)((const char*)rhs_packed_mtx_qs4c32 + rhs_offset); + float* dst_ptr = (float*)((uint8_t*)dst_act_mtx_f32 + dst_offset); + + ukernel_variants[idx_variant].ukernel.run_matmul( + m, n, k, // Dimensions + bl, // Block length + lhs_ptr, // LHS packed + rhs_ptr, // RHS packed + dst_ptr, // DST + dst_stride, // DST stride (row) + sizeof(float), // DST stride (col) + -FLT_MAX, FLT_MAX // Min and max for the clamp operation + ); + } + + const auto time_e = std::chrono::high_resolution_clock::now(); - const bool is_valid = - is_output_correct(m, n, 0.0001f, (const float*)dst_ref_mtx_f32, (const float*)dst_act_mtx_f32); + const auto elap = std::chrono::duration_cast(time_e - time_s); - std::cout << "TEST[" << idx_variant << "]: Dynamic quantization + matmul" << std::endl; - std::cout << "- ukernel: " << ukernel_variants[idx_variant].name << std::endl; - if (is_valid) { - std::cout << "- Status: PASSED" << std::endl; - std::cout << "- Performance: " << elap.count() << " us" << std::endl; - } else { - std::cout << "Status: FAILED" << std::endl; + const bool is_valid = + is_output_correct(m, n, 0.0001f, (const float*)dst_ref_mtx_f32, (const float*)dst_act_mtx_f32); + + std::cout << "TEST[" << idx_variant << "]: Dynamic quantization + matmul" << std::endl; + std::cout << "- ukernel: " << ukernel_variants[idx_variant].name << std::endl; + if (is_valid) { + std::cout << "- Status: PASSED" << std::endl; + std::cout << "- Performance: " << elap.count() << " us" << std::endl; + } else { + std::cout << "Status: FAILED" << std::endl; + } + std::cout << "------------" << std::endl; + delete[] lhs_packed_mtx_qa8dx; + delete[] rhs_packed_mtx_qs4c32; + delete[] dst_act_mtx_f32; } - std::cout << "------------" << std::endl; - delete[] lhs_packed_mtx_qa8dx; - delete[] rhs_packed_mtx_qs4c32; - delete[] dst_act_mtx_f32; + delete[] lhs_native_mtx_f32; + delete[] rhs_native_mtx_qs4c32; + delete[] rhs_scales_mtx_bf16; + delete[] dst_ref_mtx_f32; } - delete[] lhs_native_mtx_f32; - delete[] rhs_native_mtx_qs4c32; - delete[] rhs_scales_mtx_bf16; - delete[] dst_ref_mtx_f32; } //----------- END MICRO-KERNELS TESTS diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c new file mode 100644 index 00000000..d60ef3fd --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c @@ -0,0 +1,273 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#include "kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h" + +#include +#include +#include + +#include "kai/kai_common.h" + +static const size_t kai_num_bytes_sum_rhs = sizeof(float); +static 
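[Reader's note] Before diving into the new packing source below, it helps to know the memory layout it produces. The following C sketch is inferred from the stride arithmetic in the kernel and is illustrative only, not an authoritative spec; the helper name and the scale_bytes parameter (size of the chosen scale data type) are assumptions for the example:

/*
 * One packed row group covers nr source rows:
 *
 *   repeated num_blocks_per_row times:
 *     nr * (bl / 2) bytes   int4 data, two values per byte
 *     nr * scale_bytes      per-block scales (F32, F16 or Bf16)
 *   nr * sizeof(float)      row sums, used to fold in the LHS zero point
 *   nr * sizeof(float)      biases
 */
static size_t packed_row_group_stride(size_t k, size_t nr, size_t bl, size_t scale_bytes) {
    const size_t num_blocks_per_row = (k + bl - 1) / bl; /* same as kai_roundup(k, bl) / bl */
    const size_t num_bytes_per_block = (bl / 2) + scale_bytes;
    return nr * ((num_bytes_per_block * num_blocks_per_row) + sizeof(float) + sizeof(float));
}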
diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c new file mode 100644 index 00000000..d60ef3fd --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c @@ -0,0 +1,273 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#include "kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h" + +#include <stddef.h> +#include <stdint.h> +#include <string.h> + +#include "kai/kai_common.h" + +static const size_t kai_num_bytes_sum_rhs = sizeof(float); +static const size_t kai_num_bytes_bias = sizeof(float); +static const size_t kai_nr_multiple_of = 4; +static const size_t kai_bl_multiple_of = 32; + +inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) { + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); + return kai_roundup(k, bl) / bl; +} + +inline static size_t kai_num_bytes_per_block(size_t bl, size_t num_bytes_multiplier_rhs) { + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); + return (bl / 2) + num_bytes_multiplier_rhs; +} + +inline static size_t kai_rhs_packed_offset_end_of_all_blocks( + size_t k, size_t nr, size_t kr, size_t bl, size_t num_bytes_multiplier_rhs) { + KAI_ASSERT((bl % kr) == 0); + KAI_ASSERT((nr % kai_nr_multiple_of) == 0); + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); + + const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); + const size_t num_bytes_per_block = kai_num_bytes_per_block(bl, num_bytes_multiplier_rhs); + + return (nr * num_bytes_per_block * num_blocks_per_row); +} + +size_t kai_get_n_step_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(size_t nr) { + return nr; +} + +size_t kai_get_rhs_offset_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0( + size_t n_idx, // + size_t rhs_stride) { + KAI_UNUSED(rhs_stride); + KAI_ASSERT((n_idx % 2) == 0); + return (n_idx / 2) * sizeof(int8_t); +} + +size_t kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0( + size_t k, // + size_t nr, // + size_t kr, // + size_t sr, // + size_t bl, // + enum kai_datatype scale_dt) { + KAI_ASSERT((bl % kr) == 0); + KAI_ASSERT((nr % kai_nr_multiple_of) == 0); + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); + KAI_ASSERT(scale_dt == F32 || scale_dt == F16 || scale_dt == Bf16); + + KAI_UNUSED(kr); + KAI_UNUSED(sr); + + const size_t num_bytes_multiplier_rhs = kai_get_datatype_size_in_bytes(scale_dt); + const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); + const size_t num_bytes_per_block = kai_num_bytes_per_block(bl, num_bytes_multiplier_rhs); + + return nr * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs + kai_num_bytes_bias); +} + +size_t kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0( + size_t n_idx, // + size_t k, // + size_t nr, // + size_t kr, // + size_t sr, // + size_t bl, // + enum kai_datatype scale_dt) { + KAI_ASSERT((n_idx % nr) == 0); + KAI_ASSERT((bl % kr) == 0); + KAI_ASSERT((nr % kai_nr_multiple_of) == 0); + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); + KAI_ASSERT(scale_dt == F32 || scale_dt == F16 || scale_dt == Bf16); + + return (n_idx / nr) * kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(k, nr, kr, sr, bl, scale_dt); +} + +size_t kai_get_rhs_packed_size_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0( + size_t n, // + size_t k, // + size_t nr, // + size_t kr, // + size_t sr, // + size_t bl, // + enum kai_datatype scale_dt) { + KAI_ASSERT((bl % kr) == 0); + KAI_ASSERT((nr % kai_nr_multiple_of) == 0); + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); + KAI_ASSERT(scale_dt == F32 || scale_dt == F16 || scale_dt == Bf16); + + const size_t num_rows = kai_roundup(n, nr) / nr; + + return num_rows * kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(k, nr, kr, sr, bl, scale_dt); +} + +void kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0( + size_t num_groups, // + size_t n, // + size_t k, // + size_t nr, // + size_t kr, // + size_t sr, // + size_t bl, // + const uint8_t* rhs, // + size_t rhs_stride, // + const float* bias, // + const void* scale, // + size_t scale_stride, // + void* rhs_packed, // + size_t extra_bytes, // + const struct kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_params* params) {
+ KAI_ASSERT(num_groups == 1); + KAI_ASSERT(extra_bytes == 0); + KAI_ASSERT(rhs != NULL); + KAI_ASSERT(scale != NULL); + KAI_ASSERT(rhs_packed != NULL); + KAI_ASSERT(params != NULL); + KAI_ASSERT(params->rhs_zero_point == 8); + KAI_ASSERT(params->lhs_zero_point == 1); + + KAI_ASSERT((bl % kr) == 0); + KAI_ASSERT((nr % kai_nr_multiple_of) == 0); + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); + KAI_ASSERT(params->scale_dt == F32 || params->scale_dt == F16 || params->scale_dt == Bf16); + + // Note: The input matrix (rhs) is expected with: + // "n" columns and "k" rows (kxn) + + const size_t num_bytes_multiplier_rhs = kai_get_datatype_size_in_bytes(params->scale_dt); + const size_t rhs_packed_stride = + kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(k, nr, kr, sr, bl, params->scale_dt); + const size_t rhs_packed_offset_end_of_all_blocks = + kai_rhs_packed_offset_end_of_all_blocks(k, nr, kr, bl, num_bytes_multiplier_rhs); + const size_t num_qblocks_per_row = kai_num_blocks_per_row(k, bl); + const size_t num_bytes_per_block = kai_num_bytes_per_block(bl, num_bytes_multiplier_rhs); + const size_t num_bytes_per_block_k = bl / 2; + const size_t dst_num_rows = kai_roundup(n, nr) / nr; + const size_t k_interleaved_v = 16U; + const size_t block_length_in_bytes = kr / sr; + + const int32_t rhs_zero_point = params->rhs_zero_point; + const enum kai_datatype scale_dt = params->scale_dt; + + for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; ++dst_row_idx) { + // Before packing, it keeps the pointer to the first quantized block + uint8_t* dst_row = (uint8_t*)rhs_packed + dst_row_idx * rhs_packed_stride; + + float* sums = (float*)(dst_row + rhs_packed_offset_end_of_all_blocks); + + // Initialize the RHS reduction sums to zero + memset(sums, 0, nr * kai_num_bytes_sum_rhs); + + // Iterate over the quantized blocks + for (size_t dst_qblock_idx = 0; dst_qblock_idx < num_qblocks_per_row; ++dst_qblock_idx) { + // Store the scales after packing all K values + void* rhs_packed_scale = dst_row + num_bytes_per_block_k * nr; + + for (size_t i = 0; i < nr; ++i) { + const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); + void* dst_scales_ptr = (uint8_t*)rhs_packed_scale + i * num_bytes_multiplier_rhs; + const void* src_scales_ptr = (const void*)((const uint8_t*)scale + dst_qblock_idx * num_bytes_multiplier_rhs + // + (src_row_idx * scale_stride)); // + + memcpy( + dst_scales_ptr, // + src_scales_ptr, // + num_bytes_multiplier_rhs); // + } + + for (size_t dst_byte_idx = 0; dst_byte_idx < nr * num_bytes_per_block_k; ++dst_byte_idx) { + const size_t block_idx = dst_byte_idx / block_length_in_bytes; + const size_t block_byte_idx = dst_byte_idx % block_length_in_bytes; + const size_t super_block_idx = block_idx / nr; + const size_t nr_idx = block_idx % nr; + + const size_t k_adjustment = + ((block_byte_idx + super_block_idx * block_length_in_bytes) / k_interleaved_v) * k_interleaved_v; + const size_t k0_idx = + dst_qblock_idx * bl + block_byte_idx + super_block_idx * block_length_in_bytes + k_adjustment; + const size_t k1_idx = k0_idx + k_interleaved_v; + const size_t n0_idx = dst_row_idx * nr + nr_idx; + + // Clamp the index to avoid out-of-bound reads + const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1); + + const size_t src_addr_byte0 = (n0_valid_idx / 2) + k0_idx * rhs_stride; + const size_t src_addr_byte1 = (n0_valid_idx / 2) + k1_idx * rhs_stride; + + uint8_t byte0 = rhs_zero_point | rhs_zero_point << 4; + uint8_t byte1 = rhs_zero_point | rhs_zero_point << 4; + + if (k0_idx < k) { + byte0 = rhs[src_addr_byte0]; + } + + if (k1_idx < k)
{ + byte1 = rhs[src_addr_byte1]; + } + + float d = 0.0F; + switch (scale_dt) { + case F32: + d = ((float*)rhs_packed_scale)[nr_idx]; + break; + case F16: + d = kai_cast_f16_f32(((uint16_t*)rhs_packed_scale)[nr_idx]); + break; + case Bf16: + d = kai_cast_bf16_f32(((uint16_t*)rhs_packed_scale)[nr_idx]); + break; + default: + KAI_ERROR("Unsupported scale data type"); + break; + } + + if ((n0_idx % 2) == 0) { + const uint8_t src_x0_lo = (byte0 & 0x0F); + const uint8_t src_x0_hi = (byte1 & 0x0F); + + sums[nr_idx] += ((int32_t)src_x0_lo - rhs_zero_point) * d; + sums[nr_idx] += ((int32_t)src_x0_hi - rhs_zero_point) * d; + + const uint8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4); + + dst_row[dst_byte_idx] = dst_qs0 ^ 0x88; + } else { + const uint8_t src_x1_lo = (byte0 >> 4); + const uint8_t src_x1_hi = (byte1 >> 4); + + sums[nr_idx] += ((int32_t)src_x1_lo - rhs_zero_point) * d; + sums[nr_idx] += ((int32_t)src_x1_hi - rhs_zero_point) * d; + + const uint8_t dst_qs1 = src_x1_lo | (src_x1_hi << 4); + + dst_row[dst_byte_idx] = dst_qs1 ^ 0x88; + } + } + // Move the pointer after K values + dst_row += num_bytes_per_block * nr; + } + + // Move the pointer after the row sum + dst_row += kai_num_bytes_sum_rhs * nr; + + // Set the bias + if (bias == NULL) { + memset(dst_row, 0, nr * kai_num_bytes_bias); + } else { + for (size_t i = 0; i < nr; ++i) { + // Clamp the row index to avoid out-of-bound reads + const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); + ((float*)dst_row)[i] = bias[src_row_idx]; + } + } + + // Move the pointer after the biases + dst_row += kai_num_bytes_bias * nr; + } +}
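[Reader's note] On the `^ 0x88` in the inner loop above: the source holds unsigned 4-bit values with zero point 8, and XOR-ing a packed byte with 0x88 flips the top bit of each nibble, which is the same as subtracting 8 modulo 16 from both values at once, i.e. the unsigned-to-signed int4 conversion the matmul micro-kernels expect. A self-contained C sketch (the helper name is illustrative, not part of the library):

#include <assert.h>
#include <stdint.h>

/* Convert one byte holding two unsigned int4 values (zero point 8)
 * to the packed signed form written by the packing kernel above. */
static uint8_t qsu4_pair_to_qsi4(uint8_t packed) {
    return packed ^ 0x88;
}

int main(void) {
    assert(qsu4_pair_to_qsi4(0x88) == 0x00); /* 8, 8  ->  0, 0 */
    assert(qsu4_pair_to_qsi4(0x5B) == 0xD3); /* 5, 11 -> -3, 3 */
    return 0;
}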
diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h new file mode 100644 index 00000000..e8874f98 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h @@ -0,0 +1,153 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include +#include + +#include "kai/kai_common.h" + +#ifdef __cplusplus +extern "C" { +#endif + +struct kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_params { + int8_t lhs_zero_point; + uint8_t rhs_zero_point; + enum kai_datatype scale_dt; +}; + +/// Get the n step value. +/// The micro-kernel can process any N values. However, the starting N index to +/// be processed must be a multiple of n step. +/// +/// @param[in] nr The number of columns written by the matmul micro-kernel +/// +/// @return the n step value +size_t kai_get_n_step_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(size_t nr); + +/// Gets the offset in bytes for the RHS matrix (not packed), which holds +/// the int4 values in a K x N matrix, where N is number of columns and K is the number of rows. +/// +/// Two int4 values are stored in one byte. +/// The lower order part of the byte (low) holds the first nibble (N-index + 0). +/// The higher order of the byte holds the second nibble (N-index + 1). +/// +/// @param[in] n_idx Column index in the RHS matrix (not packed). +/// @param[in] rhs_stride The number of bytes in each row of the RHS matrix (not packed) +/// +/// @return the offset in bytes to the RHS matrix (not packed) +size_t kai_get_rhs_offset_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0( + size_t n_idx, // + size_t rhs_stride); // + +/// Get the row stride in bytes to the packed RHS matrix +/// +/// @param[in] k The number of rows in the RHS matrix (not packed). +/// @param[in] nr The number of columns written by the matmul micro-kernel. +/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. +/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. +/// @param[in] bl The block length, which defines the number of K values stored in a single block. It must be a +/// multiple of 32. +/// @param[in] scale_dt Block scale data type +/// +/// @return the stride in bytes to the packed RHS matrix +size_t kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0( + size_t k, // + size_t nr, // + size_t kr, // + size_t sr, // + size_t bl, // + enum kai_datatype scale_dt); // + +/// Gets the offset in bytes for the packed RHS matrix. +/// +/// @param[in] n_idx Column index in the RHS matrix (not packed). +/// @param[in] k The number of rows in the RHS matrix (not packed). +/// @param[in] nr The number of columns written by the matmul micro-kernel +/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. +/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. +/// @param[in] bl The block length, which defines the number of K values stored in a single block. It must be a +/// multiple of 32. +/// @param[in] scale_dt Block scale data type +/// +/// @return the offset in bytes to the packed RHS matrix +size_t kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0( + size_t n_idx, // + size_t k, // + size_t nr, // + size_t kr, // + size_t sr, // + size_t bl, // + enum kai_datatype scale_dt); // + +/// Gets the size in bytes for the quantized and packed RHS matrix. +/// +/// @param[in] n The number of columns in the RHS matrix (not packed). +/// @param[in] k The number of rows in the RHS matrix (not packed). +/// @param[in] nr The number of columns written by the matmul micro-kernel. +/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. +/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. +/// @param[in] bl The block length, which defines the number of K values stored in a single block. It must be a multiple +/// of 32. +/// @param[in] scale_dt Block scale data type +/// +/// @return the packed RHS matrix size in bytes +size_t kai_get_rhs_packed_size_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0( + size_t n, // + size_t k, // + size_t nr, // + size_t kr, // + size_t sr, // + size_t bl, // + enum kai_datatype scale_dt); // + +/// Runs the RHS packing micro-kernel. +/// +/// The int4 values are stored in a K x N matrix, where N is number of columns and K is the number of rows. +/// Two int4 values are stored in one byte. The lower order part of the byte (low) holds +/// the first nibble (N-index + 0). The higher order of the byte holds the second nibble (N-index + 1). +/// +/// @param[in] num_groups The number of groups. It must be 1. +/// @param[in] n The number of columns in the RHS matrix (not packed). +/// @param[in] k The number of rows in the RHS matrix (not packed). +/// @param[in] nr The number of columns written by the matmul micro-kernel. It must be a multiple of 4. +/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. +/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. +/// However, kr must be multiple of sr. +/// @param[in] bl The block length, which defines the number of +/// K values stored in a single block. It must be a multiple of 32.
+/// @param[in] rhs The RHS matrix containing the 4-bit values. +/// Size in bytes is expected to be greater than or equal to n * k * (sizeof(uint8_t) / 2). +/// @param[in] rhs_stride The number of bytes per row of the RHS matrix +/// @param[in] bias The biases. +/// @param[in] scale The per-block quantization scales. +/// The scale data type must be provided with the params object. +/// Supported scale data types are FP32, FP16 and BF16. +/// @param[in] scale_stride The number of bytes per row of the scale matrix +/// @param[out] rhs_packed The packed RHS matrix. +/// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix. +/// @param[in] params Parameters for the micro-kernel. +void kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0( + size_t num_groups, // + size_t n, // + size_t k, // + size_t nr, // + size_t kr, // + size_t sr, // + size_t bl, // + const uint8_t* rhs, // + size_t rhs_stride, // + const float* bias, // + const void* scale, // + size_t scale_stride, // + void* rhs_packed, // + size_t extra_bytes, // + const struct kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_params* params); // + +#ifdef __cplusplus +} +#endif
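[Reader's note] The non-transposed (KxN) header above and the transposed (NxK) one updated below expose the same API surface; what changes is how a source int4 element is addressed before packing. A C sketch of the two addressing schemes implied by the kernels (illustrative helpers, not part of the library API):

#include <stddef.h>
#include <stdint.h>

/* NxK storage: one byte packs two consecutive K values of one row. */
static uint8_t get_qsu4_nxk(const uint8_t* rhs, size_t n, size_t k, size_t stride) {
    const uint8_t byte = rhs[(k / 2) + n * stride];
    return (k % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
}

/* KxN storage: one byte packs two consecutive N values of one row. */
static uint8_t get_qsu4_kxn(const uint8_t* rhs, size_t n, size_t k, size_t stride) {
    const uint8_t byte = rhs[(n / 2) + k * stride];
    return (n % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
}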
diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h index 7e5ac118..4d073878 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h @@ -5,9 +5,6 @@ // #pragma once -#ifndef __cplusplus -#include -#endif #include #include @@ -35,13 +32,9 @@ size_t kai_get_n_step_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(size_t nr); /// Gets the offset in bytes for the RHS matrix (not packed), which holds /// the int4 values in a N x K matrix, where N is number of rows and K is the number of columns. /// -/// Two int4 K values are stored in one byte. These values are stored in blocks, where each block -/// has it own scale factor. The quantization scale factors are stored in a separate buffer. -/// The first byte in the block holds the K-index + 0 and K-index + 16 values. -/// The K-index + 0 value is stored in the lower order part of the byte (low nibble) while -/// the K-index + 16 value is stored in the higher order part (high nibble). -/// For example, if the block length is 32, the values are store in the following order: -/// |byte(s16, s0),byte(s17, s1),byte(s18, s2),...,byte(s31, s15)| +/// Two int4 values are stored in one byte. +/// The lower order part of the byte (low) holds the first nibble (K-index + 0). +/// The higher order of the byte holds the second nibble (K-index + 1). /// /// @param[in] n_idx Row index in the RHS matrix (not packed). /// @param[in] rhs_stride The number of bytes in each row of the RHS matrix (not packed) @@ -73,7 +66,7 @@ size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( /// Gets the offset in bytes for the packed RHS matrix. /// /// @param[in] n_idx Row index in the RHS matrix (not packed). -/// @param[in] k The common dimension between the LHS and RHS matrix (K) +/// @param[in] k The number of columns in the RHS matrix (not packed). /// @param[in] nr The number of columns written by the matmul micro-kernel /// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. /// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. @@ -119,8 +112,8 @@ size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( /// the first nibble (K-index + 0). The higher order of the byte holds the second nibble (K-index + 1). /// /// @param[in] num_groups The number of groups. It must be 1. -/// @param[in] n The number of columns of the output matrix (N). -/// @param[in] k The common dimension between the LHS and RHS matrix (K). +/// @param[in] n The number of rows in the RHS matrix (not packed). +/// @param[in] k The number of columns in the RHS matrix (not packed). /// @param[in] nr The number of columns written by the matmul micro-kernel. It must be a multiple of 4. /// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. /// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. -- GitLab From d98d4881338a428257b70ba770240334ce136c1a Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Fri, 30 Aug 2024 16:21:53 +0100 Subject: [PATCH 21/29] Fix reference to the packing function Signed-off-by: Gian Marco Iodice --- ...atmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h | 2 +- ...atmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h | 2 +- ...i_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h | 2 +- ...i_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h index 32ec885e..b9530a1e 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h @@ -118,7 +118,7 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotp /// by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs /// both the dynamic quantization to 8-bit and activation packing in a single step.
/// @param[in] rhs_packed The RHS packed matrix, which is obtained by calling @ref -/// kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 +/// kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 or @ref kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0 /// @param[out] dst The DST matrix. /// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. /// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(float). diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h index 4d185f6b..465b9e52 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h @@ -118,7 +118,7 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm /// by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs /// both the dynamic quantization to 8-bit and activation packing in a single step. /// @param[in] rhs_packed The RHS packed matrix, which is obtained by calling @ref -/// kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 +/// kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 or @ref kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0 /// @param[out] dst The DST matrix. /// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. /// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(float). diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h index 36c4330b..647802a6 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h @@ -118,7 +118,7 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm /// by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs /// both the dynamic quantization to 8-bit and activation packing in a single step. /// @param[in] rhs_packed The RHS packed matrix, which is obtained by calling @ref -/// kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 +/// kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 or @ref kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0 /// @param[out] dst The DST matrix. /// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. /// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(float). 
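[Reader's note] The cross-reference fix above also states the integration contract: either packing routine yields the same qsi4c32p packed buffer, so a caller can branch purely on how its weights are laid out in memory. A hedged C sketch of that dispatch (the rhs_is_nxk flag and helper name are assumptions for the example; the getter names are the ones documented in the patches above):

#include <stddef.h>

#include "kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h"
#include "kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h"

/* Query the packed-RHS buffer size for either weight layout. */
static size_t get_packed_rhs_size(
    int rhs_is_nxk, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, enum kai_datatype scale_dt) {
    return rhs_is_nxk
        ? kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(n, k, nr, kr, sr, bl, scale_dt)
        : kai_get_rhs_packed_size_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(n, k, nr, kr, sr, bl, scale_dt);
}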
-- GitLab From 4ece3b0e9633f7195b4884f0c34cde642c88bfb4 Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Fri, 30 Aug 2024 16:28:14 +0100 Subject: [PATCH 22/29] Fix list of dependencies in the matmul ukernels Signed-off-by: Gian Marco Iodice --- ...atmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h | 2 +- ...i_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h | 2 +- ...i_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h index b9530a1e..6fb4736e 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h @@ -15,7 +15,7 @@ extern "C" { /// Micro-kernel dependencies /// /// -# kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix -/// -# kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 to pack the RHS matrix +/// -# kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 OR kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0 to pack the RHS matrix /// -------------------------------------------------- diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h index 465b9e52..765d84ab 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h @@ -15,7 +15,7 @@ extern "C" { /// Micro-kernel dependencies /// /// -# kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix -/// -# kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 to pack the RHS matrix +/// -# kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 OR kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0 to pack the RHS matrix /// -------------------------------------------------- diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h index 647802a6..b87c0220 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h @@ -15,7 +15,7 @@ extern "C" { /// Micro-kernel dependencies /// /// -# kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix -/// -# kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 to pack the RHS matrix +/// -# kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 OR kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0 to pack the RHS matrix /// -------------------------------------------------- -- GitLab From a17876ba7dca86cf57e06daed6accd78d7c5f666 Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Mon, 2 Sep 2024 09:51:47 +0100 Subject: [PATCH 23/29] Parameterize matmul unit tests Signed-off-by: Anitha Raj --- ...matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp | 259 ++---------------- 1 
file changed, 27 insertions(+), 232 deletions(-) diff --git a/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp b/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp index 5302cccc..abd3ef17 100644 --- a/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp +++ b/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp @@ -15,12 +15,14 @@ #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h" #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h" #include "test/common/bfloat16.hpp" #include "test/common/int4.hpp" #include "test/common/memory.hpp" #include "test/common/round.hpp" +#include "test/common/test_suite.hpp" #include "test/reference/cast.hpp" #include "test/reference/fill.hpp" #include "test/reference/matmul.hpp" @@ -29,169 +31,31 @@ namespace kai::test { -TEST(matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, EndTOEnd) { - const uint64_t seed = 0; - - const size_t M = 32; - const size_t N = 64; - const size_t K = 64; - - const auto mr = kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(); - const auto nr = kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(); - const auto kr = kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(); - const auto sr = kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(); - const size_t bl = 32; - - // Generates input data. - const auto ref_lhs = fill_random(M * K, seed + 0); - const auto ref_rhs = fill_random(N * K, seed + 1); - - // Runs the reference implementation. - // * Quantizes the LHS matrix using 8-bit asymmetric quantization. - // * Quantizes the RHS matrix using 4-bit symmetric quantization. - // * Performs GEMM. - const auto [ref_lhs_qvalues, ref_lhs_scales, ref_lhs_zero_points] = - quantize_asymmetric_per_block(ref_lhs.data(), M, K, K); - const auto [ref_rhs_qsi4, ref_rhs_scales] = - quantize_symmetric_per_block(ref_rhs.data(), N, K, bl); - - const auto ref_dst = matmul_clamp_nt_t( - M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), K, ref_rhs_qsi4.data(), - ref_rhs_scales.data(), nullptr, bl, nullptr, std::numeric_limits::lowest(), - std::numeric_limits::max()); - - // Runs the LHS packing micro-kernel. - const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr); - std::vector imp_packed_lhs(imp_packed_lhs_size); - kai_run_lhs_quant_pack_qai8dxp_f32( - M, K, mr, kr, sr, 0, reinterpret_cast(ref_lhs.data()), K * sizeof(float), imp_packed_lhs.data()); - - // Runs the RHS packing micro-kernel. - // * Generates the 4-bit unsigned symmetric quantized input for the micro-kernel. - // * Packs the RHS matrix. 
- const auto ref_rhs_qsu4 = cast_qsu4_qsi4(ref_rhs_qsi4.data(), N * K); - - const size_t ref_rhs_qsu4_stride = round_up_division(K, 2); - const size_t ref_rhs_scales_stride = round_up_division(K, bl) * sizeof(uint16_t); +static const std::array, 4> + variants_kai_matmul_clamp_f32_qai8dxp_qsi4c32p = {{ + UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod), + UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod), + UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm), + UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm), + }}; - const auto imp_packed_rhs_size = - kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, bl, kai_datatype::Bf16); - std::vector imp_packed_rhs(imp_packed_rhs_size); - const kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params{ - .lhs_zero_point = 1, .rhs_zero_point = 8, .scale_dt = kai_datatype::Bf16}; - kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( - 1, N, K, nr, kr, sr, bl, ref_rhs_qsu4.data(), ref_rhs_qsu4_stride, nullptr, - reinterpret_cast(ref_rhs_scales.data()), ref_rhs_scales_stride, imp_packed_rhs.data(), 0, - ¶ms); - - // Runs the GEMM micro-kernel. - const auto imp_dst_size = kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(M, N); - ASSERT_EQ(imp_dst_size, ref_dst.size()); - std::vector imp_dst(imp_dst_size); - kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod( - M, N, K, bl, imp_packed_lhs.data(), imp_packed_rhs.data(), reinterpret_cast(imp_dst.data()), - N * sizeof(float), sizeof(float), std::numeric_limits::lowest(), std::numeric_limits::max()); +class MatMulTest_f32_qmatmul_clamp_f32_qai8dxp_qsi4c32p : public UkernelVariantTest {}; - // Compares the output of the micro-kernels against the output of the reference implementation. - for (size_t y = 0; y < M; ++y) { - for (size_t x = 0; x < N; ++x) { - const auto imp_value = read_array(imp_dst.data(), y * N + x); - const auto ref_value = read_array(ref_dst.data(), y * N + x); - const auto rel_error = ref_value != 0 ? std::abs((imp_value - ref_value) / ref_value) : std::abs(imp_value); +TEST_P(MatMulTest_f32_qmatmul_clamp_f32_qai8dxp_qsi4c32p, EndToEnd) { + auto& [variant_index, matmul_shape] = GetParam(); + const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi4c32p.at(variant_index); - if (rel_error > 0.0001F) { - ASSERT_EQ(imp_value, ref_value); - } - } - } -} -TEST(kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, EndTOEnd) { const uint64_t seed = 0; - const size_t M = 32; - const size_t N = 64; - const size_t K = 64; - - const auto mr = kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(); - const auto nr = kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(); - const auto kr = kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(); - const auto sr = kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(); - const size_t bl = 32; - - // Generates input data. - const auto ref_lhs = fill_random(M * K, seed + 0); - const auto ref_rhs = fill_random(N * K, seed + 1); - - // Runs the reference implementation. - // * Quantizes the LHS matrix using 8-bit asymmetric quantization. - // * Quantizes the RHS matrix using 4-bit symmetric quantization. - // * Performs GEMM. 
- const auto [ref_lhs_qvalues, ref_lhs_scales, ref_lhs_zero_points] = - quantize_asymmetric_per_block(ref_lhs.data(), M, K, K); - const auto [ref_rhs_qsi4, ref_rhs_scales] = - quantize_symmetric_per_block(ref_rhs.data(), N, K, bl); - - const auto ref_dst = matmul_clamp_nt_t( - M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), K, ref_rhs_qsi4.data(), - ref_rhs_scales.data(), nullptr, bl, nullptr, std::numeric_limits::lowest(), - std::numeric_limits::max()); - - // Runs the LHS packing micro-kernel. - const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr); - std::vector imp_packed_lhs(imp_packed_lhs_size); - kai_run_lhs_quant_pack_qai8dxp_f32( - M, K, mr, kr, sr, 0, reinterpret_cast(ref_lhs.data()), K * sizeof(float), imp_packed_lhs.data()); - - // Runs the RHS packing micro-kernel. - // * Generates the 4-bit unsigned symmetric quantized input for the micro-kernel. - // * Packs the RHS matrix. - const auto ref_rhs_qsu4 = cast_qsu4_qsi4(ref_rhs_qsi4.data(), N * K); - - const size_t ref_rhs_qsu4_stride = round_up_division(K, 2); - const size_t ref_rhs_scales_stride = round_up_division(K, bl) * sizeof(uint16_t); - - const auto imp_packed_rhs_size = - kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, bl, kai_datatype::Bf16); - std::vector imp_packed_rhs(imp_packed_rhs_size); - const kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params{ - .lhs_zero_point = 1, .rhs_zero_point = 8, .scale_dt = kai_datatype::Bf16}; - kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( - 1, N, K, nr, kr, sr, bl, ref_rhs_qsu4.data(), ref_rhs_qsu4_stride, nullptr, - reinterpret_cast(ref_rhs_scales.data()), ref_rhs_scales_stride, imp_packed_rhs.data(), 0, - ¶ms); - - // Runs the GEMM micro-kernel. - const auto imp_dst_size = kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(M, N); - ASSERT_EQ(imp_dst_size, ref_dst.size()); - std::vector imp_dst(imp_dst_size); - kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod( - M, N, K, bl, imp_packed_lhs.data(), imp_packed_rhs.data(), reinterpret_cast(imp_dst.data()), - N * sizeof(float), sizeof(float), std::numeric_limits::lowest(), std::numeric_limits::max()); - - // Compares the output of the micro-kernels against the output of the reference implementation. - for (size_t y = 0; y < M; ++y) { - for (size_t x = 0; x < N; ++x) { - const auto imp_value = read_array(imp_dst.data(), y * N + x); - const auto ref_value = read_array(ref_dst.data(), y * N + x); - const auto rel_error = ref_value != 0 ? 
std::abs((imp_value - ref_value) / ref_value) : std::abs(imp_value); - - if (rel_error > 0.0001F) { - ASSERT_EQ(imp_value, ref_value); - } - } - } -} -TEST(matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, EndTOEnd) { - const uint64_t seed = 0; + const size_t M = matmul_shape.m; + const size_t N = matmul_shape.n; + const size_t K = matmul_shape.k; - const size_t M = 32; - const size_t N = 64; - const size_t K = 64; + const auto mr = ukernel_variant.interface.get_mr(); + const auto nr = ukernel_variant.interface.get_nr(); + const auto kr = ukernel_variant.interface.get_kr(); + const auto sr = ukernel_variant.interface.get_sr(); - const auto mr = kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(); - const auto nr = kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(); - const auto kr = kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(); - const auto sr = kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(); const size_t bl = 32; // Generates input data. @@ -237,10 +101,10 @@ TEST(matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, EndTOEnd) { ¶ms); // Runs the GEMM micro-kernel. - const auto imp_dst_size = kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(M, N); + const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N); ASSERT_EQ(imp_dst_size, ref_dst.size()); std::vector imp_dst(imp_dst_size); - kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( + ukernel_variant.interface.run_matmul( M, N, K, bl, imp_packed_lhs.data(), imp_packed_rhs.data(), reinterpret_cast(imp_dst.data()), N * sizeof(float), sizeof(float), std::numeric_limits::lowest(), std::numeric_limits::max()); @@ -257,80 +121,11 @@ TEST(matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, EndTOEnd) { } } } -TEST(matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, EndTOEnd) { - const uint64_t seed = 0; - - const size_t M = 32; - const size_t N = 64; - const size_t K = 64; - - const auto mr = kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(); - const auto nr = kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(); - const auto kr = kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(); - const auto sr = kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(); - const size_t bl = 32; - - // Generates input data. - const auto ref_lhs = fill_random(M * K, seed + 0); - const auto ref_rhs = fill_random(N * K, seed + 1); - - // Runs the reference implementation. - // * Quantizes the LHS matrix using 8-bit asymmetric quantization. - // * Quantizes the RHS matrix using 4-bit symmetric quantization. - // * Performs GEMM. - const auto [ref_lhs_qvalues, ref_lhs_scales, ref_lhs_zero_points] = - quantize_asymmetric_per_block(ref_lhs.data(), M, K, K); - const auto [ref_rhs_qsi4, ref_rhs_scales] = - quantize_symmetric_per_block(ref_rhs.data(), N, K, bl); - - const auto ref_dst = matmul_clamp_nt_t( - M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), K, ref_rhs_qsi4.data(), - ref_rhs_scales.data(), nullptr, bl, nullptr, std::numeric_limits::lowest(), - std::numeric_limits::max()); - - // Runs the LHS packing micro-kernel. 
- const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr); - std::vector imp_packed_lhs(imp_packed_lhs_size); - kai_run_lhs_quant_pack_qai8dxp_f32( - M, K, mr, kr, sr, 0, reinterpret_cast(ref_lhs.data()), K * sizeof(float), imp_packed_lhs.data()); - - // Runs the RHS packing micro-kernel. - // * Generates the 4-bit unsigned symmetric quantized input for the micro-kernel. - // * Packs the RHS matrix. - const auto ref_rhs_qsu4 = cast_qsu4_qsi4(ref_rhs_qsi4.data(), N * K); - - const size_t ref_rhs_qsu4_stride = round_up_division(K, 2); - const size_t ref_rhs_scales_stride = round_up_division(K, bl) * sizeof(uint16_t); - - const auto imp_packed_rhs_size = - kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, bl, kai_datatype::Bf16); - std::vector imp_packed_rhs(imp_packed_rhs_size); - const kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params{ - .lhs_zero_point = 1, .rhs_zero_point = 8, .scale_dt = kai_datatype::Bf16}; - kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( - 1, N, K, nr, kr, sr, bl, ref_rhs_qsu4.data(), ref_rhs_qsu4_stride, nullptr, - reinterpret_cast(ref_rhs_scales.data()), ref_rhs_scales_stride, imp_packed_rhs.data(), 0, - ¶ms); - - // Runs the GEMM micro-kernel. - const auto imp_dst_size = kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(M, N); - ASSERT_EQ(imp_dst_size, ref_dst.size()); - std::vector imp_dst(imp_dst_size); - kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm( - M, N, K, bl, imp_packed_lhs.data(), imp_packed_rhs.data(), reinterpret_cast(imp_dst.data()), - N * sizeof(float), sizeof(float), std::numeric_limits::lowest(), std::numeric_limits::max()); - // Compares the output of the micro-kernels against the output of the reference implementation. - for (size_t y = 0; y < M; ++y) { - for (size_t x = 0; x < N; ++x) { - const auto imp_value = read_array(imp_dst.data(), y * N + x); - const auto ref_value = read_array(ref_dst.data(), y * N + x); - const auto rel_error = ref_value != 0 ? 
std::abs((imp_value - ref_value) / ref_value) : std::abs(imp_value); +INSTANTIATE_TEST_SUITE_P( + MatMul, MatMulTest_f32_qmatmul_clamp_f32_qai8dxp_qsi4c32p, + testing::Combine( + testing::Range(0, variants_kai_matmul_clamp_f32_qai8dxp_qsi4c32p.size()), + testing::Values(MatMulShape{16, 32, 64}))); - if (rel_error > 0.0001F) { - ASSERT_EQ(imp_value, ref_value); - } - } - } -} } // namespace kai::test -- GitLab From 79b5ebacd72afbe1155bc7cae894326e32c7dc42 Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Mon, 2 Sep 2024 10:27:36 +0100 Subject: [PATCH 24/29] Add non-transposed RHS packing kernels to Bazel builds Signed-off-by: Anitha Raj --- kai/ukernels/matmul/BUILD.bazel | 8 ++++++++ ...clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel index b699af56..7dbec375 100644 --- a/kai/ukernels/matmul/BUILD.bazel +++ b/kai/ukernels/matmul/BUILD.bazel @@ -260,6 +260,13 @@ kai_c_library( cpu_uarch = kai_cpu_neon(), ) +kai_c_library( + name = "rhs_pack_kxn_qsi4c32p_qsu4c32s1s0", + srcs = ["pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c"], + hdrs = ["pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h"], + cpu_uarch = kai_cpu_neon(), +) + kai_c_library( name = "matmul", deps = [ @@ -286,6 +293,7 @@ kai_c_library( ":rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon", ":rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme", ":rhs_pack_kxn_f32pbiasf32_f32_f32_neon", + ":rhs_pack_kxn_qsi4c32p_qsu4c32s1s0", ":rhs_pack_kxn_qsi4cxp_qsu4cxs1s0", ":rhs_pack_nxk_qsi4c32p_qsu4c32s1s0", ":rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0", diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h index a6c2ddb5..5c549ee5 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h @@ -15,7 +15,7 @@ extern "C" { /// Micro-kernel dependencies /// /// -# kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix -/// -# kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 to pack the RHS matrix +/// -# kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 OR kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0 to pack the RHS matrix /// -------------------------------------------------- -- GitLab From be0a763ab951d506d2098a92d668ce3e3710bc8e Mon Sep 17 00:00:00 2001 From: Anitha Raj Date: Mon, 2 Sep 2024 14:53:45 +0100 Subject: [PATCH 25/29] Add matmul unit tests with non-transposed RHS packing ukernels Signed-off-by: Anitha Raj --- test/reference/matmul.cpp | 79 +++++++++++++++++ test/reference/matmul.hpp | 40 +++++++++ test/reference/quantize.cpp | 16 ++-- test/reference/quantize.hpp | 2 +- ...matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp | 84 ++++++++++++++++++- 5 files changed, 213 insertions(+), 8 deletions(-) diff --git a/test/reference/matmul.cpp b/test/reference/matmul.cpp index 976ff049..381cf1d5 100644 --- a/test/reference/matmul.cpp +++ b/test/reference/matmul.cpp @@ -263,4 +263,83 @@ template std::vector matmul_clamp_nt_t +std::vector matmul_clamp_nt_nt( + size_t m, size_t n, size_t k, // + const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, // + const void* rhs_data, const void* rhs_scales, const 
void* rhs_zero_points, size_t rhs_quant_width, // + const void* biases, // + DstData min_value, DstData max_value) { + const auto lhs_num_quant_per_row = round_up_division(k, lhs_quant_width); + const auto rhs_num_quant_per_row = round_up_division(k, rhs_quant_width); + + std::vector dst(m * n * sizeof(DstData)); + + const auto* lhs_scales_ptr = reinterpret_cast(lhs_scales); + const auto* rhs_scales_ptr = reinterpret_cast(rhs_scales); + const auto* lhs_zero_points_ptr = reinterpret_cast(lhs_zero_points); + const auto* rhs_zero_points_ptr = reinterpret_cast(rhs_zero_points); + const auto* biases_ptr = reinterpret_cast(biases); + auto* dst_ptr = reinterpret_cast(dst.data()); + + for (size_t y = 0; y < m; ++y) { + for (size_t x = 0; x < n; ++x) { + DstData acc = 0; + + for (size_t i = 0; i < k; ++i) { + const auto lhs_value = read_array(lhs_data, y * k + i); + const auto lhs_scale = lhs_scales_ptr[y * lhs_num_quant_per_row + i / lhs_quant_width]; + const auto lhs_zero_point = lhs_zero_points_ptr != nullptr + ? lhs_zero_points_ptr[y * lhs_num_quant_per_row + i / lhs_quant_width] + : 0; + + const auto rhs_value = read_array(rhs_data, x + i * n); + const auto rhs_scale = rhs_scales_ptr[x * rhs_num_quant_per_row + i / rhs_quant_width]; + const auto rhs_zero_point = rhs_zero_points_ptr != nullptr + ? rhs_zero_points_ptr[x * rhs_num_quant_per_row + i / rhs_quant_width] + : 0; + + acc += static_cast( + (static_cast(lhs_value) + static_cast(lhs_zero_point)) * + (static_cast(rhs_value) + static_cast(rhs_zero_point))) * + static_cast(lhs_scale) * static_cast(rhs_scale); + } + + if (biases_ptr != nullptr) { + acc += static_cast(biases_ptr[x]); + } + + acc = std::clamp(acc, min_value, max_value); + dst_ptr[y * n + x] = acc; + } + } + + return dst; +} + +template std::vector matmul_clamp_nt_nt( + size_t m, size_t n, size_t k, // + const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, // + const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, // + const void* biases, // + float min_value, float max_value); + +template std::vector +matmul_clamp_nt_nt( + size_t m, size_t n, size_t k, // + const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, // + const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, // + const void* biases, // + float min_value, float max_value); + +template std::vector +matmul_clamp_nt_nt( + size_t m, size_t n, size_t k, // + const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, // + const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, // + const void* biases, // + float min_value, float max_value); + } // namespace kai::test
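[Reader's note] The only substantive difference from the existing transposed reference is the RHS access pattern: with a K x N (non-transposed) RHS, the element for output column x at depth i sits at x + i * n, while the scales and zero points stay indexed per column. A minimal C sketch of the two index computations (helper names are illustrative):

#include <stddef.h>

/* RHS stored transposed (N x K), as in matmul_clamp_nt_t. */
static size_t rhs_index_t(size_t x, size_t i, size_t k) {
    return x * k + i;
}

/* RHS stored non-transposed (K x N), as in matmul_clamp_nt_nt above. */
static size_t rhs_index_nt(size_t x, size_t i, size_t n) {
    return x + i * n;
}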
diff --git a/test/reference/matmul.hpp b/test/reference/matmul.hpp index 40fb684d..88a0729f 100644 --- a/test/reference/matmul.hpp +++ b/test/reference/matmul.hpp @@ -105,4 +105,44 @@ std::vector matmul_clamp_nt_t( const void* biases, // DstData min_value, DstData max_value); +/// Matrix multiplication with quantized input and floating-point output. +/// +/// The LHS matrix is non-transposed and the RHS matrix is non-transposed. +/// +/// @tparam LhsData The data type of the LHS matrix. +/// @tparam LhsScale The data type of the quantization scales of the LHS matrix. +/// @tparam LhsZeroPoint The data type of the quantization zero points of the LHS matrix. +/// @tparam RhsData The data type of the RHS matrix. +/// @tparam RhsScale The data type of the quantization scales of the RHS matrix. +/// @tparam RhsZeroPoint The data type of the quantization zero points of the RHS matrix. +/// @tparam Bias The data type of the bias vector. +/// @tparam IntAcc The data type of the intermediate integer accumulator. +/// @tparam DstData The data type of the floating-point accumulator and the output matrix. +/// +/// @param[in] m The LHS and output height. +/// @param[in] n The RHS width and output width. +/// @param[in] k The LHS width and RHS height. +/// @param[in] lhs_data The LHS data matrix. +/// @param[in] lhs_scales The LHS quantization scales matrix. +/// @param[in] lhs_zero_points The LHS quantization zero points matrix. +/// @param[in] lhs_quant_width The LHS quantization block width. +/// @param[in] rhs_data The RHS data matrix. +/// @param[in] rhs_scales The RHS quantization scales matrix. +/// @param[in] rhs_zero_points The RHS quantization zero points matrix. +/// @param[in] rhs_quant_width The RHS quantization block width. +/// @param[in] biases The biases vector. +/// @param[in] min_value The minimum output value. +/// @param[in] max_value The maximum output value. +/// +/// @return The output matrix. +template < + typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale, + typename RhsZeroPoint, typename Bias, typename IntAcc, typename DstData> +std::vector matmul_clamp_nt_nt( + size_t m, size_t n, size_t k, // + const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width, // + const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width, // + const void* biases, // + DstData min_value, DstData max_value); + } // namespace kai::test diff --git a/test/reference/quantize.cpp b/test/reference/quantize.cpp index 23db28dc..7d7a012a 100644 --- a/test/reference/quantize.cpp +++ b/test/reference/quantize.cpp @@ -75,7 +75,7 @@ IntType quantize_asymmetric(FloatType value, FloatType scale, ZeroPointType zero template std::tuple, std::vector> quantize_symmetric_per_block( - const void* src, size_t height, size_t width, size_t quant_width) { + const void* src, size_t height, size_t width, size_t quant_width, bool is_transposed) { static_assert(is_floating_point); static_assert(is_integral); static_assert(is_floating_point); @@ -114,7 +114,11 @@ std::tuple, std::vector> quantize_symmetric_per_bl if (x < width) { const auto quantized = quantize_symmetric(src_ptr[y * width + x], scale); - write_array(data.data(), y * width + x, quantized); + if (is_transposed) { + write_array(data.data(), y * width + x, quantized); + } else { + write_array(data.data(), x * height + y, quantized); + } } } } @@ -124,13 +128,13 @@ std::tuple, std::vector> quantize_symmetric_per_bl } template std::tuple, std::vector> quantize_symmetric_per_block( - const void* src, size_t height, size_t width, size_t quant_width); + const void* src, size_t height, size_t width, size_t quant_width, bool is_transposed); template std::tuple, std::vector> quantize_symmetric_per_block( - const void* src, size_t height, size_t width, size_t quant_width); + const void* src, size_t height, size_t width, size_t quant_width, bool is_transposed); template std::tuple, std::vector> quantize_symmetric_per_block( - const void* src, size_t height, size_t width, size_t quant_width); + const void* src, size_t height, size_t width, size_t quant_width, bool is_transposed); template std::tuple,
std::vector> quantize_symmetric_per_block( - const void* src, size_t height, size_t width, size_t quant_width); + const void* src, size_t height, size_t width, size_t quant_width, bool is_transposed); template std::tuple, std::vector, std::vector> quantize_asymmetric_per_block( diff --git a/test/reference/quantize.hpp b/test/reference/quantize.hpp index 58eb88bb..77bbfc84 100644 --- a/test/reference/quantize.hpp +++ b/test/reference/quantize.hpp @@ -89,7 +89,7 @@ enum class QuantizationMethod : uint32_t { /// @return The quantized data matrix and the quantization scale matrix. template std::tuple, std::vector> quantize_symmetric_per_block( - const void* src, size_t height, size_t width, size_t quant_width); + const void* src, size_t height, size_t width, size_t quant_width, bool is_transposed = true); /// Quantizes each subblock of the matrix using asymmetric quantization method. /// diff --git a/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp b/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp index abd3ef17..7040418b 100644 --- a/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp +++ b/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp @@ -17,6 +17,7 @@ #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h" #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h" #include "test/common/bfloat16.hpp" #include "test/common/int4.hpp" @@ -41,7 +42,7 @@ static const std::array(M * K, seed + 0); + const auto ref_rhs = fill_random(N * K, seed + 1); + + // Runs the reference implementation. + // * Quantizes the LHS matrix using 8-bit asymmetric quantization. + // * Quantizes the RHS matrix using 4-bit symmetric quantization. + // * Performs GEMM. + const auto [ref_lhs_qvalues, ref_lhs_scales, ref_lhs_zero_points] = + quantize_asymmetric_per_block(ref_lhs.data(), M, K, K); + const auto [ref_rhs_qsi4, ref_rhs_scales] = + quantize_symmetric_per_block(ref_rhs.data(), N, K, bl, false /* is_transposed */); + + const auto ref_dst = matmul_clamp_nt_nt( + M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), K, ref_rhs_qsi4.data(), + ref_rhs_scales.data(), nullptr, bl, nullptr, std::numeric_limits::lowest(), + std::numeric_limits::max()); + + // Runs the LHS packing micro-kernel. + const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr); + std::vector imp_packed_lhs(imp_packed_lhs_size); + kai_run_lhs_quant_pack_qai8dxp_f32( + M, K, mr, kr, sr, 0, reinterpret_cast(ref_lhs.data()), K * sizeof(float), imp_packed_lhs.data()); + + // Runs the RHS packing micro-kernel. + // * Generates the 4-bit unsigned symmetric quantized input for the micro-kernel. + // * Packs the RHS matrix. 
+    const auto ref_rhs_qsu4 = cast_qsu4_qsi4(ref_rhs_qsi4.data(), N * K);
+
+    const size_t ref_rhs_qsu4_stride = round_up_division(N, 2);
+    const size_t ref_rhs_scales_stride = round_up_division(K, bl) * sizeof(uint16_t);
+
+    const auto imp_packed_rhs_size =
+        kai_get_rhs_packed_size_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, bl, kai_datatype::Bf16);
+    std::vector<uint8_t> imp_packed_rhs(imp_packed_rhs_size);
+    const kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_params params{
+        .lhs_zero_point = 1, .rhs_zero_point = 8, .scale_dt = kai_datatype::Bf16};
+    kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(
+        1, N, K, nr, kr, sr, bl, ref_rhs_qsu4.data(), ref_rhs_qsu4_stride, nullptr,
+        reinterpret_cast(ref_rhs_scales.data()), ref_rhs_scales_stride, imp_packed_rhs.data(), 0,
+        &params);
+
+    // Runs the GEMM micro-kernel.
+    const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N);
+    ASSERT_EQ(imp_dst_size, ref_dst.size());
+    std::vector<uint8_t> imp_dst(imp_dst_size);
+    ukernel_variant.interface.run_matmul(
+        M, N, K, bl, imp_packed_lhs.data(), imp_packed_rhs.data(), reinterpret_cast<float*>(imp_dst.data()),
+        N * sizeof(float), sizeof(float), std::numeric_limits<float>::lowest(), std::numeric_limits<float>::max());
+
+    // Compares the output of the micro-kernels against the output of the reference implementation.
+    for (size_t y = 0; y < M; ++y) {
+        for (size_t x = 0; x < N; ++x) {
+            const auto imp_value = read_array<float>(imp_dst.data(), y * N + x);
+            const auto ref_value = read_array<float>(ref_dst.data(), y * N + x);
+            const auto rel_error = ref_value != 0 ? std::abs((imp_value - ref_value) / ref_value) : std::abs(imp_value);
+
+            if (rel_error > 0.0001F) {
+                ASSERT_EQ(imp_value, ref_value);
+            }
+        }
+    }
+}
+
 INSTANTIATE_TEST_SUITE_P(
     MatMul, MatMulTest_f32_qmatmul_clamp_f32_qai8dxp_qsi4c32p,
     testing::Combine(
-- 
GitLab
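One detail worth calling out in the test above: with .rhs_zero_point = 8, converting the reference QSI4 values into the QSU4 input expected by the packing micro-kernel is a plain offset applied to every 4-bit value. A scalar sketch of what cast_qsu4_qsi4 does per nibble (illustrative only; the actual helper works on whole buffers with two nibbles per byte):

    static inline uint8_t qsu4_from_qsi4(int8_t qsi4) {
        return (uint8_t)(qsi4 + 8);  // maps the signed range [-8, 7] onto [0, 15]
    }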
From 73879f23674db81b429657ec7c52049aae86f911 Mon Sep 17 00:00:00 2001
From: Anitha Raj
Date: Tue, 10 Sep 2024 23:27:18 +0100
Subject: [PATCH 26/29] Address review comments

* Minor fixes to cast function names
* Header includes

Signed-off-by: Anitha Raj
---
 kai/kai_common.h                                          | 2 +-
 .../kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h     | 2 +-
 .../matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c   | 5 +++--
 .../matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h   | 3 +--
 .../matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c   | 5 +++--
 .../matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h   | 3 +--
 6 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/kai/kai_common.h b/kai/kai_common.h
index f1cbdbb4..b3b6f7ee 100644
--- a/kai/kai_common.h
+++ b/kai/kai_common.h
@@ -90,7 +90,7 @@ inline static float kai_cast_f32_f16(uint16_t f16) {
 /// @param[in] bf16 The f16 value
 ///
 /// @return the f32 value
-inline static float kai_cast_bf16_f32(uint16_t bf16) {
+inline static float kai_cast_f32_bf16(uint16_t bf16) {
     const uint32_t i32 = (bf16 << 16);
     float f32;
     memcpy(&f32, &i32, sizeof(i32));
diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h
index fdc7b16e..f32d677c 100644
--- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h
+++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h
@@ -5,7 +5,7 @@
 //
 #pragma once
 
-#include
+#include
 
 #ifdef __cplusplus
 extern "C" {
diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c
index d60ef3fd..2ea36b71 100644
--- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c
+++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c
@@ -133,6 +133,7 @@ void kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(
     KAI_ASSERT(params->lhs_zero_point == 1);
     KAI_ASSERT((bl % kr) == 0);
+    KAI_ASSERT((kr % sr) == 0);
     KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
     KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
     KAI_ASSERT(params->scale_dt == F32 || params->scale_dt == F16 || params->scale_dt == Bf16);
@@ -217,10 +218,10 @@ void kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(
             d = ((float*)rhs_packed_scale)[nr_idx];
             break;
         case F16:
-            d = kai_cast_f16_f32(((uint16_t*)rhs_packed_scale)[nr_idx]);
+            d = kai_cast_f32_f16(((uint16_t*)rhs_packed_scale)[nr_idx]);
             break;
         case Bf16:
-            d = kai_cast_bf16_f32(((uint16_t*)rhs_packed_scale)[nr_idx]);
+            d = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx]);
             break;
         default:
             KAI_ERROR("Unsupported scale data type");
diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h
index e8874f98..856639bc 100644
--- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h
+++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h
@@ -5,8 +5,7 @@
 //
 #pragma once
 
-#include
-#include
+#include
 
 #include "kai/kai_common.h"
diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c
index 1f1416d1..76b9b533 100644
--- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c
+++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c
@@ -131,6 +131,7 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
     KAI_ASSERT(params->lhs_zero_point == 1);
     KAI_ASSERT((bl % kr) == 0);
+    KAI_ASSERT((kr % sr) == 0);
     KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
     KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
     KAI_ASSERT(params->scale_dt == F32 || params->scale_dt == F16 || params->scale_dt == Bf16);
@@ -239,10 +240,10 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
             d = ((float*)rhs_packed_scale)[nr_idx];
             break;
         case F16:
-            d = kai_cast_f16_f32(((uint16_t*)rhs_packed_scale)[nr_idx]);
+            d = kai_cast_f32_f16(((uint16_t*)rhs_packed_scale)[nr_idx]);
             break;
         case Bf16:
-            d = kai_cast_bf16_f32(((uint16_t*)rhs_packed_scale)[nr_idx]);
+            d = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx]);
             break;
         default:
             KAI_ERROR("Unsupported scale data type");
diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h
index 4d073878..2b411f43 100644
--- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h
+++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h
@@ -5,8 +5,7 @@
 //
 #pragma once
 
-#include
-#include
+#include
 
 #include "kai/kai_common.h"
-- 
GitLab
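The renames above bring the converters in line with the kai_cast_<dst>_<src> convention already used by kai_cast_f32_f16: the first suffix names the destination type, the second the source. A quick usage sketch, with an illustrative value:

    const uint16_t one_bf16 = 0x3F80U;                  // bfloat16 bit pattern of 1.0f
    const float one_f32 = kai_cast_f32_bf16(one_bf16);  // (0x3F80 << 16) == 0x3F800000 -> 1.0f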
From f187cd9389118f927b7e9177496fd99fb50a79f4 Mon Sep 17 00:00:00 2001
From: Anitha Raj
Date: Thu, 12 Sep 2024 16:30:42 +0100
Subject: [PATCH 27/29] Move convert bf16/f32 functions to kai_common.h

Signed-off-by: Anitha Raj
---
 .../matmul_clamp_f32_qai8dxp_qsi4c32p.cpp | 21 ++++---------------
 kai/kai_common.h                          | 10 +++++++++
 2 files changed, 14 insertions(+), 17 deletions(-)

diff --git a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp
index 65ba5f53..97d60c3a 100644
--- a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp
+++ b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp
@@ -110,19 +110,6 @@ static inline size_t get_rhs_scale_stride(size_t k, size_t bl) {
     return num_blocks_per_row * sizeof(uint16_t);
 }
 
-static inline uint16_t convert_f32_to_bf16(float f32) {
-    const uint32_t* i32 = reinterpret_cast<const uint32_t*>(&f32);
-    uint16_t bf16 = (*i32 >> 16);
-    return bf16;
-}
-
-static inline float convert_bf16_to_f32(uint16_t bf16) {
-    const uint32_t i32 = (bf16 << 16);
-    float f32;
-    memcpy(&f32, &i32, sizeof(i32));
-    return f32;
-}
-
 static void fill_uniform_random(size_t num_rows, size_t num_cols, float* dst, size_t seed) {
     std::srand(seed);
 
@@ -167,7 +154,7 @@ static void quant_nxk_qs4c32_f32(
         const float recip_scale = scale ? 1.0f / scale : 0.0f;
 
         // Store the scale in the dedicated buffer
-        *rhs_scales_bf16 = convert_f32_to_bf16(scale);
+        *rhs_scales_bf16 = kai_cast_bf16_f32(scale);
 
         rhs_scales_bf16 += 1;
 
@@ -239,7 +226,7 @@ static void quant_kxn_qs4c32_f32(
         const float recip_scale = scale ? 1.0f / scale : 0.0f;
 
         // Store the scale in the dedicated buffer
-        *rhs_scales_bf16 = convert_f32_to_bf16(scale);
+        *rhs_scales_bf16 = kai_cast_bf16_f32(scale);
 
         rhs_scales_bf16 += 1;
 
@@ -382,7 +369,7 @@ static void ref_matmul_mxn_mxk_nxk_f32_qa8dx_qs4c32(
             for (size_t block_idx = 0; block_idx < num_blocks_row; ++block_idx) {
                 const uint16_t rhs_scale_bf16 = scale_bf16[block_idx + col_idx * num_blocks_row];
-                const float rhs_scale = convert_bf16_to_f32(rhs_scale_bf16);
+                const float rhs_scale = kai_cast_f32_bf16(rhs_scale_bf16);
 
                 int32_t iacc = 0;
 
@@ -459,7 +446,7 @@ static void ref_matmul_mxn_mxk_kxn_f32_qa8dx_qs4c32(
             for (size_t block_idx = 0; block_idx < num_blocks_row; ++block_idx) {
                 const uint16_t rhs_scale_bf16 = scale_bf16[block_idx + col_idx * num_blocks_row];
-                const float rhs_scale = convert_bf16_to_f32(rhs_scale_bf16);
+                const float rhs_scale = kai_cast_f32_bf16(rhs_scale_bf16);
 
                 int32_t iacc = 0;
 
diff --git a/kai/kai_common.h b/kai/kai_common.h
index b3b6f7ee..92faf2fe 100644
--- a/kai/kai_common.h
+++ b/kai/kai_common.h
@@ -97,6 +97,16 @@ inline static float kai_cast_f32_bf16(uint16_t bf16) {
     return f32;
 }
 
+/// Converts a f32 value to bf16
+/// @param[in] f32 The f32 value
+///
+/// @return the bf16 value
+inline static uint16_t kai_cast_bf16_f32(float f32) {
+    const uint32_t* i32 = (uint32_t*)(&f32);
+    uint16_t bf16 = (*i32 >> 16);
+    return bf16;
+}
+
 /// Converts a scalar f32 value to f16
 /// @param[in] f32 The f32 value
 ///
-- 
GitLab
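One behavioural note on the kai_cast_bf16_f32 helper moved into kai_common.h above: it truncates the low 16 bits of the f32 bit pattern, which is usually acceptable for quantization scales. A round-to-nearest-even variant, if ever needed, could look like the following sketch (not part of this series; NaN handling omitted):

    inline static uint16_t cast_bf16_f32_rne(float f32) {
        uint32_t bits;
        memcpy(&bits, &f32, sizeof(bits));
        bits += 0x7FFFU + ((bits >> 16) & 1U);  // round to nearest, ties to even
        return (uint16_t)(bits >> 16);
    }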
From baf27efcb24f567b2902216a70c97e859f8f4372 Mon Sep 17 00:00:00 2001
From: Anitha Raj
Date: Thu, 12 Sep 2024 16:51:33 +0100
Subject: [PATCH 28/29] Rename enum kai_datatype

Signed-off-by: Anitha Raj
---
 CHANGELOG.md                                  |  4 ++--
 kai/kai_common.h                              | 22 +++++++++----------
 .../kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c   | 14 ++++++------
 .../kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c   | 14 ++++++------
 ...matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp |  8 +++----
 5 files changed, 31 insertions(+), 31 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9c7a7c3b..5428c542 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,8 +11,8 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo
 ## v0.3.0 - Upcoming Release
 
 - Advanced SIMD FP32 GEMM and GEMV micro kernels
-- Micro-kernels to compute the matrix multiplication of dynamically quantized asymmetric signed 8-bit integer with per-row quantization (QAI8DX) LHS and quantized symmetric 4-bit unsigned integer with per-block quantization (QSU4C32) RHS. The destination matrix data type is single-precision floating-point (F32). The micro-kernels have been optimized using the Arm® CPU feature FEAT_I8MM for the matrix-by-matrix cases and the FEAT_DotProd for the vector-by-matrix cases.
-- RHS matrix packing micro-kernels to pack the RHS matrix holding the QSU4C32 values.
+- Micro-kernels to compute the matrix multiplication of dynamically quantized asymmetric signed 8-bit integer with per-row quantization (QAI8DX) LHS and quantized symmetric 4-bit signed integer with per-block quantization (QSI4C32) RHS. The destination matrix data type is single-precision floating-point (F32). The micro-kernels have been optimized using the Arm® CPU feature FEAT_I8MM for the matrix-by-matrix cases and the FEAT_DotProd for the vector-by-matrix cases.
+- RHS matrix packing micro-kernels to pack the RHS matrix holding the QSI4C32 values.
 - Unit test and example for integer micro-kernels.
 
 ## v0.2.0
diff --git a/kai/kai_common.h b/kai/kai_common.h
index 92faf2fe..47831cd6 100644
--- a/kai/kai_common.h
+++ b/kai/kai_common.h
@@ -53,17 +53,17 @@ extern "C" {
 /// KleidiAI data types
 /// Format: (reserved)|(num-bytes)|(type)|(variant-type)
 enum kai_datatype {
-    Unknown = 0x0000,
-    F32 = 0x0411,
-    F16 = 0x0212,
-    Bf16 = 0x0213,
-    Int32 = 0x0421,
-    Int16 = 0x0222,
-    Int8 = 0x0124,
-    Uint32 = 0x0431,
-    Uint16 = 0x0232,
-    Uint8 = 0x0134,
-    Bool = 0x0441
+    kai_dt_unknown = 0x0000,
+    kai_dt_f32 = 0x0411,
+    kai_dt_f16 = 0x0212,
+    kai_dt_bf16 = 0x0213,
+    kai_dt_int32 = 0x0421,
+    kai_dt_int16 = 0x0222,
+    kai_dt_int8 = 0x0124,
+    kai_dt_uint32 = 0x0431,
+    kai_dt_uint16 = 0x0232,
+    kai_dt_uint8 = 0x0134,
+    kai_dt_bool = 0x0441
 };
 
 /// Gets number of bytes for a given data type
diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c
index 2ea36b71..1ed35a16 100644
--- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c
+++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c
@@ -60,7 +60,7 @@ size_t kai_get_rhs_packed_stride_pack_kxn_qsi4c32p_qsu4c32s1s0(
     KAI_ASSERT((bl % kr) == 0);
     KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
     KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
-    KAI_ASSERT(scale_dt == F32 || scale_dt == F16 || scale_dt == Bf16);
+    KAI_ASSERT(scale_dt == kai_dt_f32 || scale_dt == kai_dt_f16 || scale_dt == kai_dt_bf16);
 
     KAI_UNUSED(kr);
     KAI_UNUSED(sr);
@@ -84,7 +84,7 @@ size_t kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(
     KAI_ASSERT((bl % kr) == 0);
     KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
     KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
-    KAI_ASSERT(scale_dt == F32 || scale_dt == F16 || scale_dt == Bf16);
+    KAI_ASSERT(scale_dt == kai_dt_f32 || scale_dt == kai_dt_f16 || scale_dt == kai_dt_bf16);
 
     return (n_idx / nr) * kai_get_rhs_packed_stride_pack_kxn_qsi4c32p_qsu4c32s1s0(k, nr, kr, sr, bl, scale_dt);
 }
@@ -100,7 +100,7 @@ size_t kai_get_rhs_packed_size_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(
     KAI_ASSERT((bl % kr) == 0);
     KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
     KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
-    KAI_ASSERT(scale_dt == F32 || scale_dt == F16 || scale_dt == Bf16);
+    KAI_ASSERT(scale_dt == kai_dt_f32 || scale_dt == kai_dt_f16 || scale_dt == kai_dt_bf16);
 
     const size_t num_rows = kai_roundup(n, nr) / nr;
 
@@ -136,7 +136,7 @@ void kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(
     KAI_ASSERT((kr % sr) == 0);
     KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
     KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
-    KAI_ASSERT(params->scale_dt == F32 || params->scale_dt == F16 || params->scale_dt == Bf16);
+    KAI_ASSERT(params->scale_dt == kai_dt_f32 || params->scale_dt == kai_dt_f16 || params->scale_dt == kai_dt_bf16);
 
     // Note: The input matrix (rhs) is expected with:
     // "n" columns and "k" rows (kxn)
@@ -214,13 +214,13 @@ void kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(
         float d = 0.0F;
 
         switch (scale_dt) {
-            case F32:
+            case kai_dt_f32:
                 d = ((float*)rhs_packed_scale)[nr_idx];
                 break;
-            case F16:
+            case kai_dt_f16:
                 d = kai_cast_f32_f16(((uint16_t*)rhs_packed_scale)[nr_idx]);
                 break;
-            case Bf16:
+            case kai_dt_bf16:
                 d = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx]);
                 break;
             default:
diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c
index 76b9b533..c6d897e3 100644
--- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c
+++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c
@@ -58,7 +58,7 @@ size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
     KAI_ASSERT((bl % kr) == 0);
     KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
     KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
-    KAI_ASSERT(scale_dt == F32 || scale_dt == F16 || scale_dt == Bf16);
+    KAI_ASSERT(scale_dt == kai_dt_f32 || scale_dt == kai_dt_f16 || scale_dt == kai_dt_bf16);
 
     KAI_UNUSED(kr);
     KAI_UNUSED(sr);
@@ -82,7 +82,7 @@ size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
     KAI_ASSERT((bl % kr) == 0);
     KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
     KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
-    KAI_ASSERT(scale_dt == F32 || scale_dt == F16 || scale_dt == Bf16);
+    KAI_ASSERT(scale_dt == kai_dt_f32 || scale_dt == kai_dt_f16 || scale_dt == kai_dt_bf16);
 
     return (n_idx / nr) * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(k, nr, kr, sr, bl, scale_dt);
 }
@@ -98,7 +98,7 @@ size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
     KAI_ASSERT((bl % kr) == 0);
     KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
     KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
-    KAI_ASSERT(scale_dt == F32 || scale_dt == F16 || scale_dt == Bf16);
+    KAI_ASSERT(scale_dt == kai_dt_f32 || scale_dt == kai_dt_f16 || scale_dt == kai_dt_bf16);
 
     const size_t num_rows = kai_roundup(n, nr) / nr;
 
@@ -134,7 +134,7 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
     KAI_ASSERT((kr % sr) == 0);
     KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
     KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
-    KAI_ASSERT(params->scale_dt == F32 || params->scale_dt == F16 || params->scale_dt == Bf16);
+    KAI_ASSERT(params->scale_dt == kai_dt_f32 || params->scale_dt == kai_dt_f16 || params->scale_dt == kai_dt_bf16);
 
     // Note: The input matrix (rhs) is expected with:
     // "k" columns and "n" rows (NxK)
@@ -236,13 +236,13 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
         float d = 0.0F;
 
         switch (scale_dt) {
-            case F32:
+            case kai_dt_f32:
                 d = ((float*)rhs_packed_scale)[nr_idx];
                 break;
-            case F16:
+            case kai_dt_f16:
                 d = kai_cast_f32_f16(((uint16_t*)rhs_packed_scale)[nr_idx]);
                 break;
-            case Bf16:
+            case kai_dt_bf16:
                 d = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx]);
                 break;
             default:
diff --git a/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp b/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp
index 7040418b..15acde4e 100644
--- a/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp
+++ b/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp
@@ -92,10 +92,10 @@ TEST_P(MatMulTest_f32_qmatmul_clamp_f32_qai8dxp_qsi4c32p, EndToEnd_RHS_Transpose
     const size_t ref_rhs_scales_stride = round_up_division(K, bl) * sizeof(uint16_t);
 
     const auto imp_packed_rhs_size =
-        kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, bl, kai_datatype::Bf16);
+        kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, bl, kai_datatype::kai_dt_bf16);
     std::vector<uint8_t> imp_packed_rhs(imp_packed_rhs_size);
     const kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params{
-        .lhs_zero_point = 1, .rhs_zero_point = 8, .scale_dt = kai_datatype::Bf16};
+        .lhs_zero_point = 1, .rhs_zero_point = 8, .scale_dt = kai_datatype::kai_dt_bf16};
     kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
         1, N, K, nr, kr, sr, bl, ref_rhs_qsu4.data(), ref_rhs_qsu4_stride, nullptr,
         reinterpret_cast(ref_rhs_scales.data()), ref_rhs_scales_stride, imp_packed_rhs.data(), 0,
@@ -173,10 +173,10 @@ TEST_P(MatMulTest_f32_qmatmul_clamp_f32_qai8dxp_qsi4c32p, EndToEnd_RHS_NonTransp
     const size_t ref_rhs_scales_stride = round_up_division(K, bl) * sizeof(uint16_t);
 
     const auto imp_packed_rhs_size =
-        kai_get_rhs_packed_size_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, bl, kai_datatype::Bf16);
+        kai_get_rhs_packed_size_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, bl, kai_datatype::kai_dt_bf16);
     std::vector<uint8_t> imp_packed_rhs(imp_packed_rhs_size);
     const kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_params params{
-        .lhs_zero_point = 1, .rhs_zero_point = 8, .scale_dt = kai_datatype::Bf16};
+        .lhs_zero_point = 1, .rhs_zero_point = 8, .scale_dt = kai_datatype::kai_dt_bf16};
     kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(
         1, N, K, nr, kr, sr, bl, ref_rhs_qsu4.data(), ref_rhs_qsu4_stride, nullptr,
         reinterpret_cast(ref_rhs_scales.data()), ref_rhs_scales_stride, imp_packed_rhs.data(), 0,
-- 
GitLab
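The renamed enumerators keep their original values, which encode metadata per the comment in kai_common.h: (reserved)|(num-bytes)|(type)|(variant-type). Assuming that documented layout, the element size can be read straight out of the value (the library's "Gets number of bytes" helper in kai_common.h is the supported accessor):

    // kai_dt_f32  == 0x0411 -> (0x0411 >> 8) & 0xFF == 4 bytes per element
    // kai_dt_bf16 == 0x0213 -> (0x0213 >> 8) & 0xFF == 2 bytes per element
    const size_t bf16_size = ((size_t)kai_dt_bf16 >> 8) & 0xFF;  // == 2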
From 294ef19e1df4921e5c1fb9cdfd635e3ccfab79af Mon Sep 17 00:00:00 2001
From: Anitha Raj
Date: Thu, 12 Sep 2024 17:15:50 +0100
Subject: [PATCH 29/29] Fix kai_datatype used in examples

Signed-off-by: Anitha Raj
---
 .../matmul_clamp_f32_qai8dxp_qsi4c32p.cpp | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp
index 97d60c3a..87bdad8f 100644
--- a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp
+++ b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp
@@ -604,10 +604,12 @@ int main() {
         size_t rhs_packed_size = 0;
 
         if (format == rhs_format::nxk) {
-            rhs_packed_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(n, k, nr, kr, sr, bl, Bf16);
+            rhs_packed_size =
+                kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(n, k, nr, kr, sr, bl, kai_dt_bf16);
 
         } else {
-            rhs_packed_size = kai_get_rhs_packed_size_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(n, k, nr, kr, sr, bl, Bf16);
+            rhs_packed_size =
+                kai_get_rhs_packed_size_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(n, k, nr, kr, sr, bl, kai_dt_bf16);
         }
 
         const size_t dst_size = ukernel_variants[idx_variant].ukernel.get_dst_size(m, n);
@@ -625,7 +627,7 @@ int main() {
             struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params;
             params.lhs_zero_point = 1;
             params.rhs_zero_point = 8;
-            params.scale_dt = Bf16;
+            params.scale_dt = kai_dt_bf16;
 
             kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
                 1, n, k,  // Dimensions
@@ -643,7 +645,7 @@ int main() {
             struct kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_params params;
             params.lhs_zero_point = 1;
             params.rhs_zero_point = 8;
-            params.scale_dt = Bf16;
+            params.scale_dt = kai_dt_bf16;
 
             kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(
                 1, n, k,  // Dimensions
-- 
GitLab
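Taken together, the series computes, per output element, the arithmetic summarized by this scalar model (illustrative variable names, not library API; per-row asymmetric 8-bit LHS quantization, per-block symmetric 4-bit RHS quantization with block length bl):

    float acc = 0.0F;
    for (size_t b = 0; b < k / bl; ++b) {
        int32_t iacc = 0;
        for (size_t i = 0; i < bl; ++i) {
            const size_t k_idx = b * bl + i;
            iacc += ((int32_t)lhs_q[k_idx] - lhs_zero_point) * (int32_t)rhs_q[k_idx];
        }
        acc += (float)iacc * lhs_scale * rhs_scale[b];  // per-block RHS scale, stored as bf16
    }
    dst[dst_idx] = fminf(fmaxf(acc + bias, scalar_min), scalar_max);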