From 380c4f94f34b307285dff4b975aeb761a34f0e9f Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Tue, 9 Apr 2024 15:30:58 +0100 Subject: [PATCH 01/14] Matmul int4 micro-kernels for QA8DX (LHS) x QS8CX (RHS) -> F32 - The LHS matrix is quantized (Q) Asymmetric (A) 8-bit (8) with per-row (DX) quantization parameters - The RHS matrix is quantized (Q) Symmetric (S) 4-bit (4) with per-channel (cx) quantization parameters - The destination is F32 - Implement matmul int4 micro-kernels with intrinsics by using the dotprod and i8mm extensions - Implement a micro-kernel to pack the RHS matrix - Implement two micro-kernels to dynamically quantize and pack the LHS matrix - Add README.md - No test added into this PR. Test will be added in a separate PR Signed-off-by: Gian Marco Iodice --- README.md | 88 ++++- src/kai_common.h | 23 +- src/matmul/kai_lhs_quant_pack_qa8dxP1X8_f32.c | 142 ++++++++ src/matmul/kai_lhs_quant_pack_qa8dxP1X8_f32.h | 65 ++++ src/matmul/kai_lhs_quant_pack_qa8dxP4X8_f32.c | 315 ++++++++++++++++++ src/matmul/kai_lhs_quant_pack_qa8dxP4X8_f32.h | 65 ++++ .../kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0.c | 158 +++++++++ .../kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0.h | 88 +++++ ..._qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod.c | 214 ++++++++++++ ..._qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod.h | 115 +++++++ ...f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm.c | 240 +++++++++++++ ...f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm.h | 115 +++++++ 12 files changed, 1625 insertions(+), 3 deletions(-) create mode 100644 src/matmul/kai_lhs_quant_pack_qa8dxP1X8_f32.c create mode 100644 src/matmul/kai_lhs_quant_pack_qa8dxP1X8_f32.h create mode 100644 src/matmul/kai_lhs_quant_pack_qa8dxP4X8_f32.c create mode 100644 src/matmul/kai_lhs_quant_pack_qa8dxP4X8_f32.h create mode 100644 src/matmul/kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0.c create mode 100644 src/matmul/kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0.h create mode 100644 src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod.c create mode 100644 src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod.h create mode 100644 src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm.c create mode 100644 src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm.h diff --git a/README.md b/README.md index cc139c11..48d62e3c 100644 --- a/README.md +++ b/README.md @@ -29,9 +29,9 @@ For example, consider the convolution 2d operator performed through the Winograd - Matrix multiplication - Winograd output transform -Each of the preceding operations is a micro-kernel. For an example, please refer to the [first micro kernel PR](https://gitlab.arm.com/kleidi/kleidiai/-/merge_requests/2) +Each of the preceding operations is a micro-kernel. -However, why are the preceding operations not called kernels or functions? +However, why the preceding operations are not called kernels or functions instead? Because the micro-kernels are designed to give the flexibility to process also a portion of the output tensor, which is the reason why we call it micro-kernel. @@ -54,6 +54,90 @@ Some of the key features of KleidiAI are the following: > ℹ️ The micro-kernel API is designed to be as generic as possible for integration into third-party runtimes. +
+## Currently supported Arm® CPU technologies and features
+ +Arm® Neon™ + +- dotprod (Armv8.2-A onwards) +- i8mm (Armv8.6-A onwards) + +
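+The micro-kernel implementations in this patch are guarded by the corresponding compile-time architecture feature macros, so they are only built when the target supports them. A minimal sketch of the guards (the build flags in the comments are only examples):
+
+```c
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+// The i8mm micro-kernel variants can be compiled, e.g. with -march=armv8.6-a+i8mm
+#elif defined(__ARM_FEATURE_DOTPROD)
+// The dotprod micro-kernel variants can be compiled, e.g. with -march=armv8.2-a+dotprod
+#endif
+```
+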
+## Filename convention
+
+The `src/` directory is the home for all micro-kernels. The micro-kernels are grouped in separate directories based on the performed operation. For example, all the matrix-multiplication micro-kernels are held in the `matmul/` operator directory.
+
+Inside the operator directory, you can find:
+
+- *The common micro-kernels*, which are helper micro-kernels necessary for the correct functioning of the main micro-kernels. For example, some of these may be required for packing the input tensors.
+- *The micro-kernel* files, which are held in separate sub-directories.
+
+The name of the micro-kernel folder describes the operation performed and the data types of the destination and source tensors. The general syntax for the micro-kernel folder is as follows:
+
+`<op>_<dst-data-type>_<src0-data-type>_<src1-data-type>_...`
+
+All .c and .h file pairs in those folders are micro-kernel variants. The variants are differentiated by the computational parameters (for example, the block size), the Arm® technology, and the Arm® architecture feature exploited. The general syntax for the micro-kernel variant is as follows:
+
+`kai_<micro-kernel-folder-name>_<computational-parameters>_<technology>_<arch-feature>.c/.h`
+
+> ℹ️ These files only depend on the `kai_common.h` file.
+
+All functions defined in the .h header file of a micro-kernel variant have the following syntax:
+
+`kai_<function-descriptor>_<micro-kernel-variant-filename>`
+
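+For example, the dotprod matmul variant added in this patch, `kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod.c/.h`, decomposes as follows:
+
+- `matmul_clip`: the operation (matrix multiplication fused with a clip)
+- `f32`: the destination data type
+- `qa8dxP1X8`: the LHS (src0) data type, packed (`P`) together with its packing attributes
+- `qs4cxP4X8`: the RHS (src1) data type, packed (`P`) together with its packing attributes
+- `1x4x64`: the computational parameters (a 1 x 4 output tile with 64 accumulations per inner loop)
+- `neon`: the Arm® technology
+- `dotprod`: the Arm® architecture feature exploited
+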
+## Data types
+
+Some of the data types currently supported in the KleidiAI library are the following:
+
+| Data type | Abbreviation | Notes |
+| ----------- | ----------- | ----------- |
+| Floating-point 32-bit | f32 | |
+| Quantized (q) Symmetric (s) 4-bit (4) Per-Channel (cx) quantization parameters | qs4cx | An fp32 multiplier shared among all values of the same channel. `x` denotes the entirety of the channel. |
+| Quantized (q) Asymmetric (a) 8-bit (8) Per-Dimension (dx) (for example, Per-Row) quantization parameters | qa8dx | An fp32 multiplier and an int32 zero offset shared among all values of the same dimension. |
+
+> ℹ️ In some cases, we may append the letter `P` to the data type to specify that the tensor is expected to be packed. A packed tensor is a tensor that has been rearranged in our preferred data layout from the original data layout to improve the performance of the micro-kernel. In addition to the letter `P`, we may append other upper-case alphanumerical values to specify the attributes of the data packing (for example, the block packing size).
+
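+The following scalar sketch shows how the qa8dx quantization parameters (scale and zero offset) are derived for one row; it mirrors the Neon implementation added by this patch, and the function name is purely illustrative:
+
+```c
+#include <math.h>
+#include <stddef.h>
+#include <stdint.h>
+
+// Illustrative scalar version of the per-row (qa8dx) parameter derivation.
+static void example_qa8dx_params(const float* row, size_t k, float* recip_scale, int32_t* zero_point) {
+    float rmin = 0.0f;  // the quantization range always includes 0
+    float rmax = 0.0f;
+    for (size_t i = 0; i < k; ++i) {
+        rmin = fminf(rmin, row[i]);
+        rmax = fmaxf(rmax, row[i]);
+    }
+    const float qmin = (float)INT8_MIN;
+    const float qmax = (float)INT8_MAX;
+    const float scale = (rmin == rmax) ? 1.0f : (qmax - qmin) / (rmax - rmin);
+
+    // Nudge the zero point so that it is exactly representable in int8
+    const float zp_from_min_error = qmin + rmin * scale;
+    const float zp_from_max_error = qmax + rmax * scale;
+    float zp = (zp_from_min_error + zp_from_max_error > 0) ? qmin - rmin * scale : qmax - rmax * scale;
+    zp = fmaxf(qmin, fminf(zp, qmax));
+
+    *zero_point = (int32_t)lrintf(zp);
+    *recip_scale = (scale != 0.0f) ? 1.0f / scale : 0.0f;
+}
+```
+
+Each f32 value in the row is then quantized as `round(v * scale) + zero_point`, clamped to the int8 range, while the stored per-row multiplier is `recip_scale`.
+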
+## Supported micro-kernels
+
+| Micro-kernel | Abbreviation | Data type | Reference framework | Notes |
+| ----------- | ----------- | ----------- | ----------- | ----------- |
+| Matrix-multiplication with LHS packed and RHS packed matrices | `matmul_clip_f32_qa8dxP_qs4cxP` | LHS: qa8dxP<br>RHS: qs4cxP<br>DST: f32 | TensorFlow Lite | The packing function for the RHS matrix is available in the `kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0.c/.h` files.<br>Since the RHS matrix often contains constant values, we recommend packing the RHS matrix only once and freeing the content of the original RHS matrix. |
+| Dynamic quantization and LHS matrix packing | `kai_lhs_quant_pack_qa8dxP1X8_f32`, `kai_lhs_quant_pack_qa8dxP4X8_f32` | SRC: f32<br>DST: qa8dx | TensorFlow Lite | |
+
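+A minimal end-to-end sketch for the dotprod variant is shown below, assuming the build enables the dotprod extension. It is illustrative only: the matrix shapes, buffer contents, and include paths are hypothetical, and error handling is omitted.
+
+```c
+#include <float.h>
+#include <stdint.h>
+#include <stdlib.h>
+
+#include "kai_lhs_quant_pack_qa8dxP1X8_f32.h"  // include paths are illustrative
+#include "kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod.h"
+#include "kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0.h"
+
+int main(void) {
+    const size_t m = 4, n = 8, k = 64;  // n must be a multiple of 4, k a multiple of 64
+
+    // Packing parameters required by this matmul variant
+    const size_t nr = kai_get_nr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod();
+    const size_t kr = kai_get_kr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod();
+    const size_t sr = kai_get_sr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod();
+
+    float* lhs = calloc(m * k, sizeof(float));            // f32 activations
+    uint8_t* rhs = calloc(n * (k / 2), sizeof(uint8_t));  // two int4 values per byte
+    float* rhs_scales = calloc(n, sizeof(float));         // one scale per output channel
+    float* dst = calloc(m * n, sizeof(float));
+
+    // 1. Pack the RHS matrix once; it usually holds constant weights.
+    const struct kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0_params params = {
+        .lhs_zero_point = 1, .rhs_zero_point = 8};
+    void* rhs_p = malloc(kai_get_rhs_packed_size_rhs_pack_nxk_qs4cxP_qs4cxS1S0(n, k, nr, kr));
+    kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0(1, n, k, nr, kr, sr, rhs, NULL, rhs_scales, rhs_p, 0, &params);
+
+    // 2. Dynamically quantize and pack the LHS matrix (done per inference).
+    void* lhs_p = malloc(kai_get_lhs_packed_size_lhs_quant_pack_qa8dxP1X8_f32(m, k));
+    kai_run_lhs_quant_pack_qa8dxP1X8_f32(m, k, lhs, k * sizeof(float), lhs_p);
+
+    // 3. Run the matmul micro-kernel; the clip is a no-op with (-FLT_MAX, FLT_MAX).
+    kai_run_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod(
+        m, n, k, lhs_p, rhs_p, dst, n * sizeof(float), sizeof(float), -FLT_MAX, FLT_MAX);
+
+    free(lhs); free(rhs); free(rhs_scales); free(dst); free(rhs_p); free(lhs_p);
+    return 0;
+}
+```
+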
+## Frequently Asked Questions (FAQ)
+
+### What is the difference between the Compute Library for the Arm® Architecture (ACL) and KleidiAI?
+
diff --git a/src/kai_common.h b/src/kai_common.h index b44dad5b..01d70824 100644 --- a/src/kai_common.h +++ b/src/kai_common.h @@ -3,12 +3,15 @@ // // SPDX-License-Identifier: Apache-2.0 // - #pragma once #include #include +#ifdef __cplusplus +extern "C" { +#endif + // NOLINTBEGIN(cppcoreguidelines-avoid-do-while,cppcoreguidelines-pro-type-vararg,cert-err33-c) // // * cppcoreguidelines-avoid-do-while: do-while is necessary for macros. @@ -40,3 +43,21 @@ #define KAI_ASSUME_IF KAI_ASSERT_IF #define KAI_UNUSED(x) (void)(x) + +#define KAI_UNUSED(x) (void)(x) +#define KAI_MIN(a, b) (((a) < (b)) ? (a) : (b)) +#define KAI_MAX(a, b) (((a) > (b)) ? (a) : (b)) + +inline static size_t kai_roundup(size_t a, size_t b) { + size_t rem = a % b; + if (rem) { + return a + b - rem; + } else { + return a; + } +} + + +#ifdef __cplusplus +} +#endif diff --git a/src/matmul/kai_lhs_quant_pack_qa8dxP1X8_f32.c b/src/matmul/kai_lhs_quant_pack_qa8dxP1X8_f32.c new file mode 100644 index 00000000..f4a4f752 --- /dev/null +++ b/src/matmul/kai_lhs_quant_pack_qa8dxP1X8_f32.c @@ -0,0 +1,142 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#include "kai_lhs_quant_pack_qa8dxP1X8_f32.h" + +#include "../kai_common.h" + +#include +#include +#include +#include + +static const size_t kai_kk0 = 8; +static const size_t kai_km0 = 1; +static const size_t kai_num_bytes_per_multiplier = sizeof(float); +static const size_t kai_num_bytes_per_offset = sizeof(float); + +size_t kai_get_lhs_offset_lhs_quant_pack_qa8dxP1X8_f32(size_t m_idx, size_t lhs_stride) { + return m_idx * lhs_stride; +} + +size_t kai_get_lhs_packed_offset_lhs_quant_pack_qa8dxP1X8_f32(size_t m_idx, size_t k) { + const size_t dst_stride = kai_km0 * (k * sizeof(int8_t) + kai_num_bytes_per_multiplier + kai_num_bytes_per_offset); + + return (m_idx / kai_km0) * dst_stride; +} + +size_t kai_get_lhs_packed_size_lhs_quant_pack_qa8dxP1X8_f32(size_t m, size_t k) { + const size_t m_roundup4 = kai_roundup(m, kai_km0); + + const size_t t_size = m_roundup4 * (k * sizeof(int8_t) + kai_num_bytes_per_multiplier + kai_num_bytes_per_offset); + return t_size; +} + +void kai_run_lhs_quant_pack_qa8dxP1X8_f32( + size_t m, size_t k, const float* restrict lhs, size_t lhs_stride, void* restrict lhs_p) { + KAI_ASSERT(k % kai_kk0 == 0); + + if (m == 0) { + return; + } + + const size_t num_rows = m; + const size_t num_cols = k; + + const float* src_ptr = lhs; + uint8_t* dst_ptr = lhs_p; + + for (size_t row_idx = 0; row_idx < num_rows; row_idx += kai_km0) { + float32x4_t vmax0 = vdupq_n_f32(-FLT_MAX); + float32x4_t vmin0 = vdupq_n_f32(FLT_MAX); + + // Find min/max for each channel + for (size_t j = 0; j < num_cols; j += 8) { + const float32x4_t src0_0 = vld1q_f32(src_ptr + 0 + j); + const float32x4_t src0_1 = vld1q_f32(src_ptr + 4 + j); + + // Calculate the max + vmax0 = vmaxq_f32(src0_0, vmax0); + vmax0 = vmaxq_f32(vmax0, src0_1); + + // Calculate the min + vmin0 = vminq_f32(src0_0, vmin0); + vmin0 = vminq_f32(vmin0, src0_1); + } + + // Get the max/min from each row + const float max0 = vmaxvq_f32(vmax0); + const float min0 = vminvq_f32(vmin0); + + // Maximum/minimum int8 values + const float qmin = (float)INT8_MIN; + const float qmax = (float)INT8_MAX; + + const float rmin0 = KAI_MIN(0.0f, min0); + const float rmax0 = KAI_MAX(0.0f, max0); + + const float scale0 = rmin0 == rmax0 ? 1.f : (qmax - qmin) / (rmax0 - rmin0); + + // Reciprocal to quantize + const float recip_scale0 = scale0 ? 
1.0f / scale0 : 0.0f; + + const float descaled_min0 = rmin0 * scale0; + const float descaled_max0 = rmax0 * scale0; + + const float zero_point_from_min_error0 = qmin + descaled_min0; + const float zero_point_from_max_error0 = qmax + descaled_max0; + + float zero_point0 = + zero_point_from_min_error0 + zero_point_from_max_error0 > 0 ? qmin - descaled_min0 : qmax - descaled_max0; + + zero_point0 = KAI_MAX(zero_point0, qmin); + zero_point0 = KAI_MIN(zero_point0, qmax); + + // Round to nearest integer + const int32_t nudged_zero_point0 = lrintf(zero_point0); + + // LHS offset at the beginning of the row + *((int32_t*)dst_ptr) = -nudged_zero_point0; + dst_ptr += sizeof(int32_t); + + // Quantize the channels + for (size_t j = 0; j < num_cols; j += 8) { + const float32x4_t src0_0 = vld1q_f32(src_ptr + 0 + j); + const float32x4_t src0_1 = vld1q_f32(src_ptr + 4 + j); + + float32x4_t v0_f32; + float32x4_t v1_f32; + int32x4_t v0_s32; + int32x4_t v1_s32; + + // Scale the values + v0_f32 = vmulq_n_f32(src0_0, scale0); + v1_f32 = vmulq_n_f32(src0_1, scale0); + v0_s32 = vcvtnq_s32_f32(v0_f32); + v1_s32 = vcvtnq_s32_f32(v1_f32); + v0_s32 = vaddq_s32(v0_s32, vdupq_n_s32(nudged_zero_point0)); + v1_s32 = vaddq_s32(v1_s32, vdupq_n_s32(nudged_zero_point0)); + v0_s32 = vmaxq_s32(v0_s32, vdupq_n_s32(INT8_MIN)); + v0_s32 = vminq_s32(v0_s32, vdupq_n_s32(INT8_MAX)); + v1_s32 = vmaxq_s32(v1_s32, vdupq_n_s32(INT8_MIN)); + v1_s32 = vminq_s32(v1_s32, vdupq_n_s32(INT8_MAX)); + *((int8_t*)(dst_ptr + 0)) = (int8_t)vgetq_lane_s32(v0_s32, 0); + *((int8_t*)(dst_ptr + 1)) = (int8_t)vgetq_lane_s32(v0_s32, 1); + *((int8_t*)(dst_ptr + 2)) = (int8_t)vgetq_lane_s32(v0_s32, 2); + *((int8_t*)(dst_ptr + 3)) = (int8_t)vgetq_lane_s32(v0_s32, 3); + *((int8_t*)(dst_ptr + 4)) = (int8_t)vgetq_lane_s32(v1_s32, 0); + *((int8_t*)(dst_ptr + 5)) = (int8_t)vgetq_lane_s32(v1_s32, 1); + *((int8_t*)(dst_ptr + 6)) = (int8_t)vgetq_lane_s32(v1_s32, 2); + *((int8_t*)(dst_ptr + 7)) = (int8_t)vgetq_lane_s32(v1_s32, 3); + dst_ptr += sizeof(int8_t) * kai_kk0; + } + + // Store the scale quantization params + *((float*)dst_ptr) = recip_scale0; + dst_ptr += sizeof(float); + + src_ptr += kai_km0 * (lhs_stride / sizeof(float)); + } +} diff --git a/src/matmul/kai_lhs_quant_pack_qa8dxP1X8_f32.h b/src/matmul/kai_lhs_quant_pack_qa8dxP1X8_f32.h new file mode 100644 index 00000000..3c494877 --- /dev/null +++ b/src/matmul/kai_lhs_quant_pack_qa8dxP1X8_f32.h @@ -0,0 +1,65 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#ifdef __cplusplus +extern "C" { +#else +#include +#endif + +#include +#include + +/** + * @brief Function to calculate the offset in bytes for the LHS matrix (not packed) + * + * This function should be called before passing the pointer to the LHS matrix to the micro-kernel. + * + * @param[in] m_idx Row index in the LHS matrix (not packed). + * @param[in] lhs_stride The number of bytes in in each row of the LHS matrix (not packed) + * + * return the offset in bytes to the LHS matrix + */ +size_t kai_get_lhs_offset_lhs_quant_pack_qa8dxP1X8_f32(size_t m_idx, size_t lhs_stride); + +/** + * @brief Function to calculate the offset in bytes for the packed LHS matrix, + * which contains the packed 8-bit quantized asymmetric per-row (qa8dx) values. + * + * This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. + * + * @param[in] m_idx Row index in the LHS matrix (not packed). 
+ * @param[in] k Total number of columns in the LHS matrix (not packed). + * + * return the offset in bytes to the packed LHS matrix + */ +size_t kai_get_lhs_packed_offset_lhs_quant_pack_qa8dxP1X8_f32(size_t m_idx, size_t k); + +/** + * @brief Function to return the memory required for storing the quantized and packed LHS matrix + * + * @param[in] m Total number of rows in the LHS matrix (not packed). + * @param[in] k Total number of columns in the LHS matrix (not packed). + * + * return the size in bytes to the packed LHS matrix + */ +size_t kai_get_lhs_packed_size_lhs_quant_pack_qa8dxP1X8_f32(size_t m, size_t k); + +/** + * @brief Micro-kernel to quantize and pack the LHS matrix + * + * @param[in] m The number of output rows written. + * @param[in] k The number of channels. The common dimension of LHS & RHS. It must be multiple of 8. + * @param[in] lhs LHS of the vector-by-matrix. + * @param[in] lhs_stride Stride in bytes between two rows of LHS. + * @param[out] lhs_p The quantized and packed LHS matrix. + */ +void kai_run_lhs_quant_pack_qa8dxP1X8_f32(size_t m, size_t k, const float* lhs, size_t lhs_stride, void* lhs_p); + +#ifdef __cplusplus +} +#endif diff --git a/src/matmul/kai_lhs_quant_pack_qa8dxP4X8_f32.c b/src/matmul/kai_lhs_quant_pack_qa8dxP4X8_f32.c new file mode 100644 index 00000000..89fc26d9 --- /dev/null +++ b/src/matmul/kai_lhs_quant_pack_qa8dxP4X8_f32.c @@ -0,0 +1,315 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#include "kai_lhs_quant_pack_qa8dxP4X8_f32.h" + +#include "../kai_common.h" + +#include +#include +#include +#include + +static const size_t kai_kk0 = 8; +static const size_t kai_km0 = 4; +static const size_t kai_num_bytes_per_multiplier = sizeof(float); +static const size_t kai_num_bytes_per_offset = sizeof(float); + +size_t kai_get_lhs_offset_lhs_quant_pack_qa8dxP4X8_f32(size_t m_idx, size_t lhs_stride) { + return m_idx * lhs_stride; +} + +size_t kai_get_lhs_packed_offset_lhs_quant_pack_qa8dxP4X8_f32(size_t m_idx, size_t k) { + const size_t dst_stride = kai_km0 * (k * sizeof(int8_t) + kai_num_bytes_per_multiplier + kai_num_bytes_per_offset); + + return (m_idx / kai_km0) * dst_stride; +} + +size_t kai_get_lhs_packed_size_lhs_quant_pack_qa8dxP4X8_f32(size_t m, size_t k) { + const size_t m_roundup4 = kai_roundup(m, kai_km0); + + const size_t t_size = m_roundup4 * (k * sizeof(int8_t) + kai_num_bytes_per_multiplier + kai_num_bytes_per_offset); + return t_size; +} + +void kai_run_lhs_quant_pack_qa8dxP4X8_f32( + size_t m, size_t k, const float* restrict lhs, size_t lhs_stride, void* restrict lhs_p) { + KAI_ASSERT(k % kai_kk0 == 0); + KAI_ASSERT(m <= 3 || (m % kai_km0 == 0)); + + if (m == 0) { + return; + } + + const size_t num_cols = k; + const size_t num_rows = m; + + const float* src_ptr0 = (const float*)((const uint8_t*)lhs + 0 * lhs_stride); + const float* src_ptr1 = (const float*)((const uint8_t*)lhs + 1 * lhs_stride); + const float* src_ptr2 = (const float*)((const uint8_t*)lhs + 2 * lhs_stride); + const float* src_ptr3 = (const float*)((const uint8_t*)lhs + 3 * lhs_stride); + + if (m == 1) { + src_ptr1 = src_ptr0; + src_ptr2 = src_ptr0; + src_ptr3 = src_ptr0; + } else if (m == 2) { + src_ptr2 = src_ptr1; + src_ptr3 = src_ptr1; + } else if (m == 3) { + src_ptr3 = src_ptr2; + } + + uint8_t* dst_ptr = lhs_p; + + for (size_t row_idx = 0; row_idx < num_rows; row_idx += kai_km0) { + float32x4_t vmax0 = vdupq_n_f32(-FLT_MAX); + float32x4_t vmax1 = vdupq_n_f32(-FLT_MAX); + float32x4_t 
vmax2 = vdupq_n_f32(-FLT_MAX); + float32x4_t vmax3 = vdupq_n_f32(-FLT_MAX); + float32x4_t vmin0 = vdupq_n_f32(FLT_MAX); + float32x4_t vmin1 = vdupq_n_f32(FLT_MAX); + float32x4_t vmin2 = vdupq_n_f32(FLT_MAX); + float32x4_t vmin3 = vdupq_n_f32(FLT_MAX); + + // Find min/max for each channel + for (size_t j = 0; j < num_cols; j += 8) { + const float32x4_t src0_0 = vld1q_f32(src_ptr0 + 0 + j); + const float32x4_t src0_1 = vld1q_f32(src_ptr0 + 4 + j); + + const float32x4_t src1_0 = vld1q_f32(src_ptr1 + 0 + j); + const float32x4_t src1_1 = vld1q_f32(src_ptr1 + 4 + j); + + const float32x4_t src2_0 = vld1q_f32(src_ptr2 + 0 + j); + const float32x4_t src2_1 = vld1q_f32(src_ptr2 + 4 + j); + + const float32x4_t src3_0 = vld1q_f32(src_ptr3 + 0 + j); + const float32x4_t src3_1 = vld1q_f32(src_ptr3 + 4 + j); + + // Calculate the max + vmax0 = vmaxq_f32(src0_0, vmax0); + vmax1 = vmaxq_f32(src1_0, vmax1); + vmax2 = vmaxq_f32(src2_0, vmax2); + vmax3 = vmaxq_f32(src3_0, vmax3); + + vmax0 = vmaxq_f32(vmax0, src0_1); + vmax1 = vmaxq_f32(vmax1, src1_1); + vmax2 = vmaxq_f32(vmax2, src2_1); + vmax3 = vmaxq_f32(vmax3, src3_1); + + // Calculate the min + vmin0 = vminq_f32(src0_0, vmin0); + vmin1 = vminq_f32(src1_0, vmin1); + vmin2 = vminq_f32(src2_0, vmin2); + vmin3 = vminq_f32(src3_0, vmin3); + + vmin0 = vminq_f32(vmin0, src0_1); + vmin1 = vminq_f32(vmin1, src1_1); + vmin2 = vminq_f32(vmin2, src2_1); + vmin3 = vminq_f32(vmin3, src3_1); + } + + // Get the max/min from each row + const float max0 = vmaxvq_f32(vmax0); + const float max1 = vmaxvq_f32(vmax1); + const float max2 = vmaxvq_f32(vmax2); + const float max3 = vmaxvq_f32(vmax3); + const float min0 = vminvq_f32(vmin0); + const float min1 = vminvq_f32(vmin1); + const float min2 = vminvq_f32(vmin2); + const float min3 = vminvq_f32(vmin3); + + // Maximum/minimum int8 values + const float qmin = (float)INT8_MIN; + const float qmax = (float)INT8_MAX; + + const float rmin0 = KAI_MIN(0.0f, min0); + const float rmax0 = KAI_MAX(0.0f, max0); + const float rmin1 = KAI_MIN(0.0f, min1); + const float rmax1 = KAI_MAX(0.0f, max1); + const float rmin2 = KAI_MIN(0.0f, min2); + const float rmax2 = KAI_MAX(0.0f, max2); + const float rmin3 = KAI_MIN(0.0f, min3); + const float rmax3 = KAI_MAX(0.0f, max3); + + const float scale0 = rmin0 == rmax0 ? 1.f : (qmax - qmin) / (rmax0 - rmin0); + const float scale1 = rmin1 == rmax1 ? 1.f : (qmax - qmin) / (rmax1 - rmin1); + const float scale2 = rmin2 == rmax2 ? 1.f : (qmax - qmin) / (rmax2 - rmin2); + const float scale3 = rmin3 == rmax3 ? 1.f : (qmax - qmin) / (rmax3 - rmin3); + + // Reciprocal to quantize + const float recip_scale0 = scale0 ? 1.0f / scale0 : 0.0f; + const float recip_scale1 = scale1 ? 1.0f / scale1 : 0.0f; + const float recip_scale2 = scale2 ? 1.0f / scale2 : 0.0f; + const float recip_scale3 = scale3 ? 
1.0f / scale3 : 0.0f; + + const float descaled_min0 = rmin0 * scale0; + const float descaled_min1 = rmin1 * scale1; + const float descaled_min2 = rmin2 * scale2; + const float descaled_min3 = rmin3 * scale3; + const float descaled_max0 = rmax0 * scale0; + const float descaled_max1 = rmax1 * scale1; + const float descaled_max2 = rmax2 * scale2; + const float descaled_max3 = rmax3 * scale3; + + const float zero_point_from_min_error0 = qmin + descaled_min0; + const float zero_point_from_max_error0 = qmax + descaled_max0; + const float zero_point_from_min_error1 = qmin + descaled_min1; + const float zero_point_from_max_error1 = qmax + descaled_max1; + const float zero_point_from_min_error2 = qmin + descaled_min2; + const float zero_point_from_max_error2 = qmax + descaled_max2; + const float zero_point_from_min_error3 = qmin + descaled_min3; + const float zero_point_from_max_error3 = qmax + descaled_max3; + + float zero_point0 = + zero_point_from_min_error0 + zero_point_from_max_error0 > 0 ? qmin - descaled_min0 : qmax - descaled_max0; + float zero_point1 = + zero_point_from_min_error1 + zero_point_from_max_error1 > 0 ? qmin - descaled_min1 : qmax - descaled_max1; + float zero_point2 = + zero_point_from_min_error2 + zero_point_from_max_error2 > 0 ? qmin - descaled_min2 : qmax - descaled_max2; + float zero_point3 = + zero_point_from_min_error3 + zero_point_from_max_error3 > 0 ? qmin - descaled_min3 : qmax - descaled_max3; + + zero_point0 = KAI_MAX(zero_point0, qmin); + zero_point0 = KAI_MIN(zero_point0, qmax); + zero_point1 = KAI_MAX(zero_point1, qmin); + zero_point1 = KAI_MIN(zero_point1, qmax); + zero_point2 = KAI_MAX(zero_point2, qmin); + zero_point2 = KAI_MIN(zero_point2, qmax); + zero_point3 = KAI_MAX(zero_point3, qmin); + zero_point3 = KAI_MIN(zero_point3, qmax); + + // Round to nearest integer + const int32_t nudged_zero_point0 = lrintf(zero_point0); + const int32_t nudged_zero_point1 = lrintf(zero_point1); + const int32_t nudged_zero_point2 = lrintf(zero_point2); + const int32_t nudged_zero_point3 = lrintf(zero_point3); + + // The LHS offsets are stored at the beginning of the row + int32x4_t voffsets_s32 = vdupq_n_s32(0.0f); + voffsets_s32 = vsetq_lane_s32(-nudged_zero_point0, voffsets_s32, 0); + voffsets_s32 = vsetq_lane_s32(-nudged_zero_point1, voffsets_s32, 1); + voffsets_s32 = vsetq_lane_s32(-nudged_zero_point2, voffsets_s32, 2); + voffsets_s32 = vsetq_lane_s32(-nudged_zero_point3, voffsets_s32, 3); + + vst1q_s32((int32_t*)dst_ptr, voffsets_s32); + dst_ptr += sizeof(int32x4_t); + + // Quantize the channels + for (size_t j = 0; j < num_cols; j += 8) { + const float32x4_t src0_0 = vld1q_f32(src_ptr0 + 0 + j); + const float32x4_t src0_1 = vld1q_f32(src_ptr0 + 4 + j); + + const float32x4_t src1_0 = vld1q_f32(src_ptr1 + 0 + j); + const float32x4_t src1_1 = vld1q_f32(src_ptr1 + 4 + j); + + const float32x4_t src2_0 = vld1q_f32(src_ptr2 + 0 + j); + const float32x4_t src2_1 = vld1q_f32(src_ptr2 + 4 + j); + + const float32x4_t src3_0 = vld1q_f32(src_ptr3 + 0 + j); + const float32x4_t src3_1 = vld1q_f32(src_ptr3 + 4 + j); + + float32x4_t v0_f32; + float32x4_t v1_f32; + int32x4_t v0_s32; + int32x4_t v1_s32; + + // Scale the values + v0_f32 = vmulq_n_f32(src0_0, scale0); + v1_f32 = vmulq_n_f32(src0_1, scale0); + v0_s32 = vcvtnq_s32_f32(v0_f32); + v1_s32 = vcvtnq_s32_f32(v1_f32); + v0_s32 = vaddq_s32(v0_s32, vdupq_n_s32(nudged_zero_point0)); + v1_s32 = vaddq_s32(v1_s32, vdupq_n_s32(nudged_zero_point0)); + v0_s32 = vmaxq_s32(v0_s32, vdupq_n_s32(INT8_MIN)); + v0_s32 = vminq_s32(v0_s32, 
vdupq_n_s32(INT8_MAX)); + v1_s32 = vmaxq_s32(v1_s32, vdupq_n_s32(INT8_MIN)); + v1_s32 = vminq_s32(v1_s32, vdupq_n_s32(INT8_MAX)); + *((int8_t*)(dst_ptr + 0)) = (int8_t)vgetq_lane_s32(v0_s32, 0); + *((int8_t*)(dst_ptr + 1)) = (int8_t)vgetq_lane_s32(v0_s32, 1); + *((int8_t*)(dst_ptr + 2)) = (int8_t)vgetq_lane_s32(v0_s32, 2); + *((int8_t*)(dst_ptr + 3)) = (int8_t)vgetq_lane_s32(v0_s32, 3); + *((int8_t*)(dst_ptr + 4)) = (int8_t)vgetq_lane_s32(v1_s32, 0); + *((int8_t*)(dst_ptr + 5)) = (int8_t)vgetq_lane_s32(v1_s32, 1); + *((int8_t*)(dst_ptr + 6)) = (int8_t)vgetq_lane_s32(v1_s32, 2); + *((int8_t*)(dst_ptr + 7)) = (int8_t)vgetq_lane_s32(v1_s32, 3); + dst_ptr += sizeof(int8_t) * 8; + + v0_f32 = vmulq_n_f32(src1_0, scale1); + v1_f32 = vmulq_n_f32(src1_1, scale1); + v0_s32 = vcvtnq_s32_f32(v0_f32); + v1_s32 = vcvtnq_s32_f32(v1_f32); + v0_s32 = vaddq_s32(v0_s32, vdupq_n_s32(nudged_zero_point1)); + v1_s32 = vaddq_s32(v1_s32, vdupq_n_s32(nudged_zero_point1)); + v0_s32 = vmaxq_s32(v0_s32, vdupq_n_s32(INT8_MIN)); + v0_s32 = vminq_s32(v0_s32, vdupq_n_s32(INT8_MAX)); + v1_s32 = vmaxq_s32(v1_s32, vdupq_n_s32(INT8_MIN)); + v1_s32 = vminq_s32(v1_s32, vdupq_n_s32(INT8_MAX)); + *((int8_t*)(dst_ptr + 0)) = (int8_t)vgetq_lane_s32(v0_s32, 0); + *((int8_t*)(dst_ptr + 1)) = (int8_t)vgetq_lane_s32(v0_s32, 1); + *((int8_t*)(dst_ptr + 2)) = (int8_t)vgetq_lane_s32(v0_s32, 2); + *((int8_t*)(dst_ptr + 3)) = (int8_t)vgetq_lane_s32(v0_s32, 3); + *((int8_t*)(dst_ptr + 4)) = (int8_t)vgetq_lane_s32(v1_s32, 0); + *((int8_t*)(dst_ptr + 5)) = (int8_t)vgetq_lane_s32(v1_s32, 1); + *((int8_t*)(dst_ptr + 6)) = (int8_t)vgetq_lane_s32(v1_s32, 2); + *((int8_t*)(dst_ptr + 7)) = (int8_t)vgetq_lane_s32(v1_s32, 3); + dst_ptr += sizeof(int8_t) * 8; + + v0_f32 = vmulq_n_f32(src2_0, scale2); + v1_f32 = vmulq_n_f32(src2_1, scale2); + v0_s32 = vcvtnq_s32_f32(v0_f32); + v1_s32 = vcvtnq_s32_f32(v1_f32); + v0_s32 = vaddq_s32(v0_s32, vdupq_n_s32(nudged_zero_point2)); + v1_s32 = vaddq_s32(v1_s32, vdupq_n_s32(nudged_zero_point2)); + v0_s32 = vmaxq_s32(v0_s32, vdupq_n_s32(INT8_MIN)); + v0_s32 = vminq_s32(v0_s32, vdupq_n_s32(INT8_MAX)); + v1_s32 = vmaxq_s32(v1_s32, vdupq_n_s32(INT8_MIN)); + v1_s32 = vminq_s32(v1_s32, vdupq_n_s32(INT8_MAX)); + *((int8_t*)(dst_ptr + 0)) = (int8_t)vgetq_lane_s32(v0_s32, 0); + *((int8_t*)(dst_ptr + 1)) = (int8_t)vgetq_lane_s32(v0_s32, 1); + *((int8_t*)(dst_ptr + 2)) = (int8_t)vgetq_lane_s32(v0_s32, 2); + *((int8_t*)(dst_ptr + 3)) = (int8_t)vgetq_lane_s32(v0_s32, 3); + *((int8_t*)(dst_ptr + 4)) = (int8_t)vgetq_lane_s32(v1_s32, 0); + *((int8_t*)(dst_ptr + 5)) = (int8_t)vgetq_lane_s32(v1_s32, 1); + *((int8_t*)(dst_ptr + 6)) = (int8_t)vgetq_lane_s32(v1_s32, 2); + *((int8_t*)(dst_ptr + 7)) = (int8_t)vgetq_lane_s32(v1_s32, 3); + dst_ptr += sizeof(int8_t) * 8; + + v0_f32 = vmulq_n_f32(src3_0, scale3); + v1_f32 = vmulq_n_f32(src3_1, scale3); + v0_s32 = vcvtnq_s32_f32(v0_f32); + v1_s32 = vcvtnq_s32_f32(v1_f32); + v0_s32 = vaddq_s32(v0_s32, vdupq_n_s32(nudged_zero_point3)); + v1_s32 = vaddq_s32(v1_s32, vdupq_n_s32(nudged_zero_point3)); + v0_s32 = vmaxq_s32(v0_s32, vdupq_n_s32(INT8_MIN)); + v0_s32 = vminq_s32(v0_s32, vdupq_n_s32(INT8_MAX)); + v1_s32 = vmaxq_s32(v1_s32, vdupq_n_s32(INT8_MIN)); + v1_s32 = vminq_s32(v1_s32, vdupq_n_s32(INT8_MAX)); + *((int8_t*)(dst_ptr + 0)) = (int8_t)vgetq_lane_s32(v0_s32, 0); + *((int8_t*)(dst_ptr + 1)) = (int8_t)vgetq_lane_s32(v0_s32, 1); + *((int8_t*)(dst_ptr + 2)) = (int8_t)vgetq_lane_s32(v0_s32, 2); + *((int8_t*)(dst_ptr + 3)) = (int8_t)vgetq_lane_s32(v0_s32, 3); + *((int8_t*)(dst_ptr + 4)) = 
(int8_t)vgetq_lane_s32(v1_s32, 0); + *((int8_t*)(dst_ptr + 5)) = (int8_t)vgetq_lane_s32(v1_s32, 1); + *((int8_t*)(dst_ptr + 6)) = (int8_t)vgetq_lane_s32(v1_s32, 2); + *((int8_t*)(dst_ptr + 7)) = (int8_t)vgetq_lane_s32(v1_s32, 3); + dst_ptr += sizeof(int8_t) * 8; + } + + // Store the scale quantization params + float32x4_t vscales_f32 = vdupq_n_f32(0.0f); + vscales_f32 = vsetq_lane_f32(recip_scale0, vscales_f32, 0); + vscales_f32 = vsetq_lane_f32(recip_scale1, vscales_f32, 1); + vscales_f32 = vsetq_lane_f32(recip_scale2, vscales_f32, 2); + vscales_f32 = vsetq_lane_f32(recip_scale3, vscales_f32, 3); + vst1q_f32((float*)dst_ptr, vscales_f32); + dst_ptr += sizeof(float32x4_t); + + src_ptr0 += kai_km0 * (lhs_stride / sizeof(float)); + src_ptr1 += kai_km0 * (lhs_stride / sizeof(float)); + src_ptr2 += kai_km0 * (lhs_stride / sizeof(float)); + src_ptr3 += kai_km0 * (lhs_stride / sizeof(float)); + } +} diff --git a/src/matmul/kai_lhs_quant_pack_qa8dxP4X8_f32.h b/src/matmul/kai_lhs_quant_pack_qa8dxP4X8_f32.h new file mode 100644 index 00000000..11bb931e --- /dev/null +++ b/src/matmul/kai_lhs_quant_pack_qa8dxP4X8_f32.h @@ -0,0 +1,65 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#ifdef __cplusplus +extern "C" { +#else +#include +#endif + +#include +#include + +/** + * @brief Function to calculate the offset in bytes for the LHS matrix (not packed) + * + * This function should be called before passing the pointer to the LHS matrix to the micro-kernel. + * + * @param[in] m_idx Row index in the LHS matrix (not packed). + * @param[in] lhs_stride The number of bytes in in each row of the LHS matrix (not packed) + * + * return the offset in bytes to the LHS matrix + */ +size_t kai_get_lhs_offset_lhs_quant_pack_qa8dxP4X8_f32(size_t m_idx, size_t lhs_stride); + +/** + * @brief Function to calculate the offset in bytes for the packed LHS matrix, + * which contains the packed 8-bit quantized asymmetric per-row (qa8dx) values. + * + * This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. + * + * @param[in] m_idx Row index in the LHS matrix (not packed). + * @param[in] k Total number of columns in the LHS matrix (not packed). + * + * return the offset in bytes to the packed LHS matrix + */ +size_t kai_get_lhs_packed_offset_lhs_quant_pack_qa8dxP4X8_f32(size_t m_idx, size_t k); + +/** + * @brief Function to return the memory required for storing the quantized and packed LHS matrix + * + * @param[in] m Total number of rows in the LHS matrix (not packed). + * @param[in] k Total number of columns in the LHS matrix (not packed). + * + * return the size in bytes to the packed LHS matrix + */ +size_t kai_get_lhs_packed_size_lhs_quant_pack_qa8dxP4X8_f32(size_t m, size_t k); + +/** + * @brief Micro-kernel to quantize and pack the LHS matrix + * + * @param[in] m The number of output rows written. It must be 1, 2, 3, 4, or any multiple of 4. + * @param[in] k The number of channels. The common dimension of LHS & RHS. It must be multiple of 8. + * @param[in] lhs LHS of the vector-by-matrix. + * @param[in] lhs_stride Stride in bytes between two rows of LHS. + * @param[in] lhs_p The quantized and packed LHS matrix. 
+ */ +void kai_run_lhs_quant_pack_qa8dxP4X8_f32(size_t m, size_t k, const float* lhs, size_t lhs_stride, void* lhs_p); + +#ifdef __cplusplus +} +#endif diff --git a/src/matmul/kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0.c b/src/matmul/kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0.c new file mode 100644 index 00000000..115e4df0 --- /dev/null +++ b/src/matmul/kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0.c @@ -0,0 +1,158 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#include "kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0.h" + +#include "../kai_common.h" + +#include +#include +#include +#include + +static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); +static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); + +inline static int8_t kai_int4_sign_extend(int8_t x) { + return (x ^ 0x8) - 8; +} + +size_t kai_get_rhs_offset_rhs_pack_nxk_qs4cxP_qs4cxS1S0(size_t n_idx, size_t rhs_stride) { + return n_idx * rhs_stride; +} + +size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qs4cxP_qs4cxS1S0(size_t n_idx, size_t k, size_t nr, size_t kr) { + KAI_ASSERT((k % 2) == 0); + KAI_ASSERT((k % kr) == 0); + KAI_ASSERT((n_idx % nr) == 0); + + KAI_UNUSED(kr); + + const size_t rhs_packed_stride = nr * ((k / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs); + + return (n_idx / nr) * rhs_packed_stride; +} + +size_t kai_get_rhs_packed_size_rhs_pack_nxk_qs4cxP_qs4cxS1S0(size_t n, size_t k, size_t nr, size_t kr) { + KAI_ASSERT((k % 2) == 0); + KAI_ASSERT((k % kr) == 0); + KAI_ASSERT((n % nr) == 0); + + KAI_UNUSED(kr); + + const size_t num_rows = n / nr; + + const size_t rhs_packed_stride = nr * ((k / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs); + + return num_rows * rhs_packed_stride; +} + +void kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0( + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const int32_t* bias, + const float* scale, void* rhs_p, size_t extra_bytes, + const struct kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0_params* params) { + // Temporary asserts + KAI_ASSERT(num_groups == 1); + KAI_ASSERT((k % 2) == 0); + KAI_ASSERT((n % nr) == 0); + KAI_ASSERT((k % kr) == 0); + KAI_ASSERT(bias == NULL); + KAI_ASSERT(extra_bytes == 0); + + KAI_ASSERT(sr == 2); + KAI_ASSERT(kr >= 1 && kr <= 16); + KAI_ASSERT(rhs != NULL); + KAI_ASSERT(scale != NULL); + KAI_ASSERT(rhs_p != NULL); + KAI_ASSERT(params != NULL); + KAI_ASSERT(params->rhs_zero_point == 8); + KAI_ASSERT(params->lhs_zero_point == 1); + + // Note: The input matrix (rhs) is expected with: + // "k" columns and "n" rows (NxK) + + const size_t rhs_zero_point = params->rhs_zero_point; + const size_t rhs_stride = k / 2; + const size_t rhs_packed_stride = nr * ((k / 2) + sizeof(float) + sizeof(int32_t)); + + for (size_t y = 0; y < n; y += nr) { + const uint8_t* src_row = (const uint8_t*)rhs + y * rhs_stride; + uint8_t* dst_row = (uint8_t*)rhs_p + (y / nr) * rhs_packed_stride; + + int32_t* sums = (int32_t*)dst_row; + + // The RHS reduction sums are stored at the beginning of the row + // Initialize to zero the RHS reduction sums + memset(sums, 0, nr * sizeof(int32_t)); + + // Move the pointer after the biases + dst_row += nr * sizeof(int32_t); + + for (size_t x = 0; x < k; x += (kr * 2)) { + for (size_t s = 0; s < sr; ++s) { + for (size_t i = 0; i < nr; ++i) { + for (size_t kr_idx = 0; kr_idx < kr / sr; kr_idx += 2) { + const size_t src_addr_byte0 = i * rhs_stride + (x / 2) + kr_idx / 2 + s * (kr / sr) / 2; + const size_t src_addr_byte1 = src_addr_byte0 + (kr / 2); + 
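+ // Note: byte0 and byte1 are read (kr / 2) bytes apart, so each packed byte
+ // below combines one int4 value from each half of the (kr * 2)-value block
+ // handled per iteration of the x loop. This matches the order in which the
+ // matmul micro-kernels unpack the values (low nibbles first, then high nibbles).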
+ const uint8_t byte0 = src_row[src_addr_byte0]; + const uint8_t byte1 = src_row[src_addr_byte1]; + + if (rhs_zero_point == 0) { + int8_t src_x0_lo = (byte0 & 0x0F); + int8_t src_x1_lo = (byte0 >> 4); + + int8_t src_x0_hi = (byte1 & 0x0F); + int8_t src_x1_hi = (byte1 >> 4); + + const int8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4); + const int8_t dst_qs1 = src_x1_lo | (src_x1_hi << 4); + + src_x0_lo = kai_int4_sign_extend(src_x0_lo); + src_x1_lo = kai_int4_sign_extend(src_x1_lo); + src_x0_hi = kai_int4_sign_extend(src_x0_hi); + src_x1_hi = kai_int4_sign_extend(src_x1_hi); + sums[i] += src_x0_lo + src_x0_hi; + sums[i] += src_x1_lo + src_x1_hi; + + *(int8_t*)dst_row = dst_qs0; + dst_row += sizeof(int8_t); + *(int8_t*)dst_row = dst_qs1; + dst_row += sizeof(int8_t); + + } else { + const uint8_t src_x0_lo = (byte0 & 0x0F); + const uint8_t src_x1_lo = (byte0 >> 4); + + const uint8_t src_x0_hi = (byte1 & 0x0F); + const uint8_t src_x1_hi = (byte1 >> 4); + + sums[i] += src_x0_lo + src_x0_hi - 2 * rhs_zero_point; + sums[i] += src_x1_lo + src_x1_hi - 2 * rhs_zero_point; + + const uint8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4); + const uint8_t dst_qs1 = src_x1_lo | (src_x1_hi << 4); + + *dst_row = dst_qs0 ^ 0x88; + dst_row += sizeof(uint8_t); + *dst_row = dst_qs1 ^ 0x88; + dst_row += sizeof(uint8_t); + } + } + } + } + } + + // Adjust the scale + for (size_t i = 0; i < nr; ++i) { + ((float*)(dst_row))[i] = scale[y + i] * 0.0625f; + } + + // Adjust the scale + for (size_t i = 0; i < nr; ++i) { + sums[i] = sums[i] * 16; + } + } +} diff --git a/src/matmul/kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0.h b/src/matmul/kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0.h new file mode 100644 index 00000000..95b70563 --- /dev/null +++ b/src/matmul/kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0.h @@ -0,0 +1,88 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#ifdef __cplusplus +extern "C" { +#else +#include +#endif + +#include +#include + +struct kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0_params { + int8_t lhs_zero_point; + uint8_t rhs_zero_point; +}; + +/** + * @brief Function to calculate the offset in bytes for the RHS matrix (not packed), which holds + * the int4 values in a N x K matrix, where N is number of rows and K is the number of columns. + * Two int4 values are stored in one byte. The lower order part of the byte (low) holds + * the first nibble (K-index + 0). The higher order of the byte holds the second nibble (K-index + 1). + * + * @param[in] n_idx Row index in the RHS matrix (not packed). + * @param[in] rhs_stride The number of bytes in in each row of the RHS matrix (not packed) + * + * return the offset in bytes to the RHS matrix (not packed) + */ +size_t kai_get_rhs_offset_rhs_pack_nxk_qs4cxP_qs4cxS1S0(size_t n_idx, size_t rhs_stride); + +/** + * @brief Function to calculate the offset in bytes for the packed RHS matrix, + * which contains the packed 4-bit quantized symmetric per-channel (qs4cx) values. + * + * @param[in] n_idx Row index in the RHS matrix (not packed). + * @param[in] k The common dimension between the LHS and RHS matrix (K) + * @param[in] nr The number of columns written by the matmul micro-kernel + * @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. 
+ * + * return the offset in bytes to the packed RHS matrix + */ +size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qs4cxP_qs4cxS1S0(size_t n_idx, size_t k, size_t nr, size_t kr); + +/** + * @brief Function to return the memory required for storing the quantized and packed RHS matrix + * + * @param[in] n The number of rows in the RHS matrix (not packed) + * @param[in] k The number of columns in the RHS matrix (not packed). + * @param[in] nr The number of columns written by the matmul micro-kernel + * @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. + * + * return the size in bytes to the packed RHS matrix + */ +size_t kai_get_rhs_packed_size_rhs_pack_nxk_qs4cxP_qs4cxS1S0(size_t n, size_t k, size_t nr, size_t kr); + +/** + * @brief Micro-kernel to quantize and pack the RHS matrix. + * + * @note The int4 values are stored in a N x K matrix, where N is number of rows and K is the number of columns. + * Two int4 values are stored in one byte. The lower order part of the byte (low) holds + * the first nibble (K-index + 0). The higher order of the byte holds the second nibble (K-index + 1). + * + * @param[in] num_groups The number of groups. It must be 1. + * @param[in] n The number of columns of the output matrix (N). + * @param[in] k The common dimension between the LHS and RHS matrix (K). + * @param[in] nr The number of columns written by the matmul micro-kernel. + * @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. + * @param[in] sr The number of kr splits. It can be 1 (not splits) up to kr. + * However, kr must be multiple of sr. + * @param[in] rhs The RHS matrix containing the 4-bit values. + * Size in bytes is expected to be: n * k * (sizeof(uint8_t) / 2). + * @param[in] bias The biases. + * @param[in] scale The scale for each output channel. + * @param[out] rhs_p The quantized and packed RHS matrix. + * @param[in] extra_bytes Extra bytes to append to the end of each row of the quantized and packed RHS matrix. + * @param[in] params Parameters for the function. 
+ */ +void kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0( + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const int32_t* bias, + const float* scale, void* rhs_p, size_t extra_bytes, const struct kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0_params* params); + +#ifdef __cplusplus +} +#endif diff --git a/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod.c b/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod.c new file mode 100644 index 00000000..d4c61c74 --- /dev/null +++ b/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod.c @@ -0,0 +1,214 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#include "kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod.h" + +#include "../../kai_common.h" + +#include +#include + +static const size_t kai_mr = 1; +static const size_t kai_nr = 4; +static const size_t kai_kr = 16; +static const size_t kai_sr = 2; +static const size_t kai_k0 = 64; +static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); +static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); +static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); +static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); + +size_t kai_get_nr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod(void) { + return kai_nr; +} + +size_t kai_get_kr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod(void) { + return kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod(size_t m_idx, size_t k) { + const size_t lhs_packed_stride = kai_mr * (k + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); + + return (m_idx / kai_mr) * lhs_packed_stride; +} + +size_t kai_get_rhs_packed_offset_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod(size_t n_idx, size_t k) { + KAI_ASSERT((k % kai_k0) == 0); + KAI_ASSERT((n_idx % kai_nr) == 0); + + const size_t rhs_packed_stride = kai_nr * ((k / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs); + + return (n_idx / kai_nr) * rhs_packed_stride; +} + +size_t kai_get_dst_offset_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSERT((n_idx % kai_nr) == 0); + + return (n_idx * sizeof(float)) + m_idx * dst_stride; +} + +size_t kai_get_dst_size_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod(size_t m, size_t n) { + KAI_ASSERT((n % kai_nr) == 0); + + return m * n; +} + +void kai_run_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod( + size_t m, size_t n, size_t k, const void* restrict lhs_p, const void* restrict rhs_p, float* restrict dst, + size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { +#if defined(__ARM_FEATURE_DOTPROD) + KAI_ASSERT(n % kai_nr == 0); + KAI_ASSERT(k % kai_k0 == 0); + KAI_ASSERT(dst_stride_col == sizeof(float)); + + if (m == 0) { + return; + } + + const size_t num_rows = m; + const size_t num_cols = n; + + const size_t lhs_packed_stride = kai_mr * (k + sizeof(float) + sizeof(float)); + + const int8x16_t nibble_mask = vdupq_n_s8(0xF0); + + const uint8_t* lhs_ptr_start = lhs_p; + + for (size_t row_idx = 0; row_idx < num_rows; row_idx += kai_mr) { + const uint8_t* rhs_ptr = rhs_p; + for (size_t 
col_idx = 0; col_idx < num_cols; col_idx += kai_nr) { + // Main f32 accumulator + int32x4_t iacc0011 = vdupq_n_s32(0); + int32x4_t iacc2233 = vdupq_n_s32(0); + + const uint8_t* lhs_ptr = lhs_ptr_start; + const uint8_t* rhs_ptr_start = rhs_ptr; + + // LHS offset is stored at the beginning of the row + lhs_ptr += sizeof(int32_t); + // The RHS reduction sums are stored at the beginning of the row + rhs_ptr += sizeof(int32x4_t); + + for (size_t b = 0; b < k; b += kai_k0) { + // Set up RHS + const int8x16_t rhs_raw_vec_0 = vld1q_s8((const int8_t*)(rhs_ptr + 0)); + const int8x16_t rhs_raw_vec_1 = vld1q_s8((const int8_t*)(rhs_ptr + 16)); + const int8x16_t rhs_raw_vec_2 = vld1q_s8((const int8_t*)(rhs_ptr + 32)); + const int8x16_t rhs_raw_vec_3 = vld1q_s8((const int8_t*)(rhs_ptr + 48)); + const int8x16_t rhs_raw_vec_4 = vld1q_s8((const int8_t*)(rhs_ptr + 64)); + const int8x16_t rhs_raw_vec_5 = vld1q_s8((const int8_t*)(rhs_ptr + 80)); + const int8x16_t rhs_raw_vec_6 = vld1q_s8((const int8_t*)(rhs_ptr + 96)); + const int8x16_t rhs_raw_vec_7 = vld1q_s8((const int8_t*)(rhs_ptr + 112)); + + // Low nibble + const int8x16_t rhs_vec_0_0 = vshlq_n_s8(rhs_raw_vec_0, 4); + const int8x16_t rhs_vec_1_0 = vshlq_n_s8(rhs_raw_vec_1, 4); + const int8x16_t rhs_vec_2_0 = vshlq_n_s8(rhs_raw_vec_2, 4); + const int8x16_t rhs_vec_3_0 = vshlq_n_s8(rhs_raw_vec_3, 4); + const int8x16_t rhs_vec_4_0 = vshlq_n_s8(rhs_raw_vec_4, 4); + const int8x16_t rhs_vec_5_0 = vshlq_n_s8(rhs_raw_vec_5, 4); + const int8x16_t rhs_vec_6_0 = vshlq_n_s8(rhs_raw_vec_6, 4); + const int8x16_t rhs_vec_7_0 = vshlq_n_s8(rhs_raw_vec_7, 4); + + // High nibble + const int8x16_t rhs_vec_0_1 = vandq_s8(rhs_raw_vec_0, nibble_mask); + const int8x16_t rhs_vec_1_1 = vandq_s8(rhs_raw_vec_1, nibble_mask); + const int8x16_t rhs_vec_2_1 = vandq_s8(rhs_raw_vec_2, nibble_mask); + const int8x16_t rhs_vec_3_1 = vandq_s8(rhs_raw_vec_3, nibble_mask); + const int8x16_t rhs_vec_4_1 = vandq_s8(rhs_raw_vec_4, nibble_mask); + const int8x16_t rhs_vec_5_1 = vandq_s8(rhs_raw_vec_5, nibble_mask); + const int8x16_t rhs_vec_6_1 = vandq_s8(rhs_raw_vec_6, nibble_mask); + const int8x16_t rhs_vec_7_1 = vandq_s8(rhs_raw_vec_7, nibble_mask); + + const int8x16_t lhs_vec_0 = vld1q_s8((const int8_t*)(lhs_ptr + 0)); + const int8x16_t lhs_vec_1 = vld1q_s8((const int8_t*)(lhs_ptr + 16)); + const int8x16_t lhs_vec_2 = vld1q_s8((const int8_t*)(lhs_ptr + 32)); + const int8x16_t lhs_vec_3 = vld1q_s8((const int8_t*)(lhs_ptr + 48)); + + lhs_ptr += 64; + rhs_ptr += 128; + + int8x16_t t; + + t = vcombine_s8(vget_low_s8(lhs_vec_0), vget_low_s8(lhs_vec_0)); + iacc0011 = vdotq_s32(iacc0011, rhs_vec_0_0, t); + iacc2233 = vdotq_s32(iacc2233, rhs_vec_1_0, t); + t = vcombine_s8(vget_high_s8(lhs_vec_0), vget_high_s8(lhs_vec_0)); + iacc0011 = vdotq_s32(iacc0011, rhs_vec_2_0, t); + iacc2233 = vdotq_s32(iacc2233, rhs_vec_3_0, t); + t = vcombine_s8(vget_low_s8(lhs_vec_1), vget_low_s8(lhs_vec_1)); + iacc0011 = vdotq_s32(iacc0011, rhs_vec_0_1, t); + iacc2233 = vdotq_s32(iacc2233, rhs_vec_1_1, t); + t = vcombine_s8(vget_high_s8(lhs_vec_1), vget_high_s8(lhs_vec_1)); + iacc0011 = vdotq_s32(iacc0011, rhs_vec_2_1, t); + iacc2233 = vdotq_s32(iacc2233, rhs_vec_3_1, t); + + t = vcombine_s8(vget_low_s8(lhs_vec_2), vget_low_s8(lhs_vec_2)); + iacc0011 = vdotq_s32(iacc0011, rhs_vec_4_0, t); + iacc2233 = vdotq_s32(iacc2233, rhs_vec_5_0, t); + t = vcombine_s8(vget_high_s8(lhs_vec_2), vget_high_s8(lhs_vec_2)); + iacc0011 = vdotq_s32(iacc0011, rhs_vec_6_0, t); + iacc2233 = vdotq_s32(iacc2233, rhs_vec_7_0, t); + t = 
vcombine_s8(vget_low_s8(lhs_vec_3), vget_low_s8(lhs_vec_3)); + iacc0011 = vdotq_s32(iacc0011, rhs_vec_4_1, t); + iacc2233 = vdotq_s32(iacc2233, rhs_vec_5_1, t); + t = vcombine_s8(vget_high_s8(lhs_vec_3), vget_high_s8(lhs_vec_3)); + iacc0011 = vdotq_s32(iacc0011, rhs_vec_6_1, t); + iacc2233 = vdotq_s32(iacc2233, rhs_vec_7_1, t); + } + + int32x4_t iacc = vpaddq_s32(iacc0011, iacc2233); + + // RHS scale + const float32x4_t rhs_scale = vld1q_f32((const float*)rhs_ptr); + rhs_ptr += sizeof(float32x4_t); + + // LHS scale + const float32x4_t lhs_scale = vld1q_dup_f32((const float*)lhs_ptr); + lhs_ptr += sizeof(float); + + // LHS offset + const int32x4_t lhs_offset = vld1q_dup_s32((const int32_t*)lhs_ptr_start); + + // RHS sum values + const int32x4_t sum_n_s32 = vld1q_s32((const int32_t*)(rhs_ptr_start)); + + // Add the reduction sum + iacc = vmlaq_s32(iacc, sum_n_s32, lhs_offset); + + float32x4_t main_acc = vmulq_f32(vcvtq_f32_s32(iacc), rhs_scale); + + main_acc = vmulq_f32(main_acc, lhs_scale); + + // Clip (min-max) operation + const float32x4_t vmin_f32 = vdupq_n_f32(scalar_min); + const float32x4_t vmax_f32 = vdupq_n_f32(scalar_max); + + main_acc = vmaxq_f32(main_acc, vmin_f32); + main_acc = vminq_f32(main_acc, vmax_f32); + + vst1q_f32((float*)((uint8_t*)dst + col_idx * sizeof(float) + row_idx * dst_stride_row), main_acc); + } + lhs_ptr_start += lhs_packed_stride; + } +#else + KAI_ASSERT(false); + KAI_UNUSED(m); + KAI_UNUSED(n); + KAI_UNUSED(k); + KAI_UNUSED(lhs_p); + KAI_UNUSED(rhs_p); + KAI_UNUSED(dst); + KAI_UNUSED(dst_stride_row); + KAI_UNUSED(dst_stride_col); + KAI_UNUSED(scalar_min); + KAI_UNUSED(scalar_max); +#endif +} diff --git a/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod.h b/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod.h new file mode 100644 index 00000000..878beb25 --- /dev/null +++ b/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod.h @@ -0,0 +1,115 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#ifdef __cplusplus +extern "C" { +#else +#include +#endif + +#include +#include + +/** + * @brief Function to get the nr value, which must be used to pack the RHS matrix with + * the @ref kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 function + * + * @return the nr value + */ +size_t kai_get_nr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod(void); + +/** + * @brief Function to get the kr value, which must be used to pack the RHS matrix with + * the @ref kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 function + * + * @return the kr value + */ +size_t kai_get_kr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod(void); + +/** + * @brief Function to get the nr value, which must be used to pack the RHS matrix with + * the @ref kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 function + * + * @return the sr value + */ +size_t kai_get_sr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod(void); + +/** + * @brief Function to calculate the offset in bytes for the packed LHS matrix, + * which contains the packed 8-bit quantized asymmetric per-row (qa8dx) values. + * + * This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. + * + * @param[in] m_idx Row index in the LHS matrix (not packed). + * @param[in] k Total number of columns in the LHS matrix (not packed). It must be a multiple of 64. 
+ * + * return the offset in bytes to the packed LHS matrix + */ +size_t kai_get_lhs_packed_offset_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod(size_t m_idx, size_t k); + +/** + * @brief Function to calculate the offset in bytes for the packed RHS matrix, + * which contains the packed 4-bit quantized symmetric per-channel (qs4cx) values. + * + * @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of 4. + * @param[in] k The common dimension between the LHS and RHS matrix (K). It must be a multiple of 64. + * + * return the offset in bytes to the packed RHS matrix + */ +size_t kai_get_rhs_packed_offset_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod(size_t n_idx, size_t k); + +/** + * @brief Function to calculate the offset in bytes for the DST matrix + * + * @param[in] m_idx Row index in the DST matrix. + * @param[in] n_idx Column index in the DST matrix. It must be multiple of 4. + * @param[in] dst_stride The number of bytes in in each row of the DST matrix + * + * return the DST offset in bytes + */ +size_t kai_get_dst_offset_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod( + size_t m_idx, size_t n_idx, size_t dst_stride); + +/** + * @brief Function to query the size in bytes for the constant workspace. + * + * @param[in] m Number of rows in the destination (DST) matrix + * @param[in] n Number of columns in the destination (DST) matrix + */ +size_t kai_get_dst_size_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod(size_t m, size_t n); + +/** + * @brief Function to run the matrix multiplication (matmul) micro-kernel followed by a clip (min-max) operation. + * + * LHS matrix: 8-bit quantized asymmeitric per-row (qa8dx) and packed + * RHS matrix: 4-bit quantized symmetric per-channel (qs4cx) and packed. + * Output tile: (rows x cols) = 1 x 4 + * Accumulation performed in a single for loop: 64 + * Instruction used: dotprod + * + * @param[in] m The number of output rows written. + * @param[in] n The number of output columns written. It must be a multiple of 4. + * @param[in] k The number of channels. The common dimension of LHS & RHS. It must be multiple of 64. + * @param[in] lhs_p The LHS matrix packed. + * When the activation are dynamically quantized, you can obtain this matrix + * by calling the @ref kai_lhs_quant_pack_qa8dxP1X8_f32 micro-kernel which performs + * both the dynamic quantization to 8-bit and activation packing in a single step. + * @param[in] rhs_p The RHS matrix packed, which is obtained by calling @ref + * kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 + * @param[out] dst Result of the vector-by-matrix + * @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. + * @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(float) + * @param[in] scalar_min Min value used to clip the final result. + * @param[in] scalar_max Max value used to clip the final result. 
+ */ +void kai_run_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod( + size_t m, size_t n, size_t k, const void* lhs_p, const void* rhs_p, float* dst, size_t dst_stride_row, + size_t dst_stride_col, float scalar_min, float scalar_max); + +#ifdef __cplusplus +} +#endif diff --git a/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm.c b/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm.c new file mode 100644 index 00000000..a3a81282 --- /dev/null +++ b/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm.c @@ -0,0 +1,240 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#include "kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm.h" + +#include "../../kai_common.h" + +#include +#include + +static const size_t kai_mr = 4; +static const size_t kai_nr = 4; +static const size_t kai_kr = 16; +static const size_t kai_sr = 2; +static const size_t kai_k0 = 32; +static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); +static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); +static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); +static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); + +size_t kai_get_nr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm(void) { + return kai_nr; +} + +size_t kai_get_kr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm(void) { + return kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm(size_t m_idx, size_t k) { + const size_t lhs_packed_stride = kai_mr * (k + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); + + return (m_idx / kai_mr) * lhs_packed_stride; +} + +size_t kai_get_rhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm(size_t n_idx, size_t k) { + KAI_ASSERT((k % kai_k0) == 0); + KAI_ASSERT((n_idx % kai_nr) == 0); + + const size_t rhs_packed_stride = kai_nr * ((k / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs); + + return (n_idx / kai_nr) * rhs_packed_stride; +} + +size_t +kai_get_dst_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm(size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSERT((n_idx % kai_nr) == 0); + + return (n_idx * sizeof(float)) + m_idx * dst_stride; +} + +size_t kai_get_dst_size_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm(size_t m, size_t n) { + KAI_ASSERT((n % kai_nr) == 0); + + return m * n; +} + +void kai_run_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm( + size_t m, size_t n, size_t k, const void* restrict lhs_p, const void* restrict rhs_p, float* restrict dst, + size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { +#if defined(__ARM_FEATURE_MATMUL_INT8) + KAI_ASSERT(m <= 3 || (m % kai_mr == 0)); + KAI_ASSERT(n % kai_nr == 0); + KAI_ASSERT(k % kai_k0 == 0); + KAI_ASSERT(dst_stride_col == sizeof(float)); + + if (m == 0) { + return; + } + + const size_t num_rows = m; + const size_t num_cols = n; + + const size_t lhs_packed_stride = kai_mr * (k + sizeof(float32_t) + sizeof(float32_t)); + + const int8x16_t nibble_mask = vdupq_n_s8(0xF0); + + const uint8_t* lhs_ptr_start = lhs_p; + + for (size_t row_idx = 0; row_idx < num_rows; row_idx += kai_mr) { + const uint8_t* rhs_ptr = rhs_p; + + for (size_t col_idx = 0; col_idx < 
num_cols; col_idx += kai_nr) { + const uint8_t* lhs_ptr = lhs_ptr_start; + const uint8_t* rhs_ptr_start = rhs_ptr; + + // Main f32 accumulator + int32x4_t iacc_mat_00 = vdupq_n_s32(0); + int32x4_t iacc_mat_01 = vdupq_n_s32(0); + int32x4_t iacc_mat_10 = vdupq_n_s32(0); + int32x4_t iacc_mat_11 = vdupq_n_s32(0); + + // LHS offset is stored at the beginning of the row + lhs_ptr += sizeof(int32x4_t); + // The RHS reduction sums are stored at the beginning of the row + rhs_ptr += sizeof(int32x4_t); + + for (size_t b = 0; b < k; b += kai_k0) { + // Set up RHS + const int8x16_t rhs_raw_mat_01_0 = vld1q_s8((const int8_t*)rhs_ptr + 0); + const int8x16_t rhs_raw_mat_23_0 = vld1q_s8((const int8_t*)rhs_ptr + 16); + const int8x16_t rhs_raw_mat_01_1 = vld1q_s8((const int8_t*)rhs_ptr + 32); + const int8x16_t rhs_raw_mat_23_1 = vld1q_s8((const int8_t*)rhs_ptr + 48); + + // Low nibble + const int8x16_t rhs_mat_01_0 = vshlq_n_s8(rhs_raw_mat_01_0, 4); + const int8x16_t rhs_mat_23_0 = vshlq_n_s8(rhs_raw_mat_23_0, 4); + const int8x16_t rhs_mat_01_1 = vshlq_n_s8(rhs_raw_mat_01_1, 4); + const int8x16_t rhs_mat_23_1 = vshlq_n_s8(rhs_raw_mat_23_1, 4); + + // High nibble + const int8x16_t rhs_mat_01_2 = vandq_s8(rhs_raw_mat_01_0, nibble_mask); + const int8x16_t rhs_mat_23_2 = vandq_s8(rhs_raw_mat_23_0, nibble_mask); + const int8x16_t rhs_mat_01_3 = vandq_s8(rhs_raw_mat_01_1, nibble_mask); + const int8x16_t rhs_mat_23_3 = vandq_s8(rhs_raw_mat_23_1, nibble_mask); + + // Process LHS in pairs of rows + const int8x16_t lhs_mat_01_0 = vld1q_s8((const int8_t*)lhs_ptr + 0); + const int8x16_t lhs_mat_23_0 = vld1q_s8((const int8_t*)lhs_ptr + 16); + const int8x16_t lhs_mat_01_1 = vld1q_s8((const int8_t*)lhs_ptr + 32); + const int8x16_t lhs_mat_23_1 = vld1q_s8((const int8_t*)lhs_ptr + 48); + const int8x16_t lhs_mat_01_2 = vld1q_s8((const int8_t*)lhs_ptr + 64); + const int8x16_t lhs_mat_23_2 = vld1q_s8((const int8_t*)lhs_ptr + 80); + const int8x16_t lhs_mat_01_3 = vld1q_s8((const int8_t*)lhs_ptr + 96); + const int8x16_t lhs_mat_23_3 = vld1q_s8((const int8_t*)lhs_ptr + 112); + + // Do the MMLAs into 2x2 matrices + iacc_mat_00 = vmmlaq_s32( + vmmlaq_s32( + vmmlaq_s32(vmmlaq_s32(iacc_mat_00, lhs_mat_01_0, rhs_mat_01_0), lhs_mat_01_1, rhs_mat_01_1), + lhs_mat_01_2, rhs_mat_01_2), + lhs_mat_01_3, rhs_mat_01_3); + iacc_mat_01 = vmmlaq_s32( + vmmlaq_s32( + vmmlaq_s32(vmmlaq_s32(iacc_mat_01, lhs_mat_01_0, rhs_mat_23_0), lhs_mat_01_1, rhs_mat_23_1), + lhs_mat_01_2, rhs_mat_23_2), + lhs_mat_01_3, rhs_mat_23_3); + iacc_mat_10 = vmmlaq_s32( + vmmlaq_s32( + vmmlaq_s32(vmmlaq_s32(iacc_mat_10, lhs_mat_23_0, rhs_mat_01_0), lhs_mat_23_1, rhs_mat_01_1), + lhs_mat_23_2, rhs_mat_01_2), + lhs_mat_23_3, rhs_mat_01_3); + iacc_mat_11 = vmmlaq_s32( + vmmlaq_s32( + vmmlaq_s32(vmmlaq_s32(iacc_mat_11, lhs_mat_23_0, rhs_mat_23_0), lhs_mat_23_1, rhs_mat_23_1), + lhs_mat_23_2, rhs_mat_23_2), + lhs_mat_23_3, rhs_mat_23_3); + + // Straighten out to make 4 row vectors + lhs_ptr += 128; + rhs_ptr += 64; + } + + // RHS scale + const float32x4_t rhs_scale = vld1q_f32((const float32_t*)rhs_ptr); + rhs_ptr += sizeof(float32x4_t); + + // LHS scale + const float32x4_t lhs_scale = vld1q_f32((const float32_t*)lhs_ptr); + lhs_ptr += sizeof(float32x4_t); + + int32x4_t iacc_row_0 = vreinterpretq_s32_u64( + vtrn1q_u64(vreinterpretq_u64_s32(iacc_mat_00), vreinterpretq_u64_s32(iacc_mat_01))); + int32x4_t iacc_row_1 = vreinterpretq_s32_u64( + vtrn2q_u64(vreinterpretq_u64_s32(iacc_mat_00), vreinterpretq_u64_s32(iacc_mat_01))); + int32x4_t iacc_row_2 = vreinterpretq_s32_u64( + 
+            int32x4_t iacc_row_3 = vreinterpretq_s32_u64(
+                vtrn2q_u64(vreinterpretq_u64_s32(iacc_mat_10), vreinterpretq_u64_s32(iacc_mat_11)));
+
+            // LHS offset
+            const int32x4_t lhs_offset = vld1q_s32((const int32_t*)lhs_ptr_start);
+
+            // RHS sum values
+            const int32x4_t sum_n_s32 = vld1q_s32((const int32_t*)(rhs_ptr_start));
+
+            // Add the RHS reduction sum
+            iacc_row_0 = vmlaq_laneq_s32(iacc_row_0, sum_n_s32, lhs_offset, 0);
+            iacc_row_1 = vmlaq_laneq_s32(iacc_row_1, sum_n_s32, lhs_offset, 1);
+            iacc_row_2 = vmlaq_laneq_s32(iacc_row_2, sum_n_s32, lhs_offset, 2);
+            iacc_row_3 = vmlaq_laneq_s32(iacc_row_3, sum_n_s32, lhs_offset, 3);
+
+            float32x4_t main_acc0 = vmulq_f32(vcvtq_f32_s32(iacc_row_0), rhs_scale);
+            float32x4_t main_acc1 = vmulq_f32(vcvtq_f32_s32(iacc_row_1), rhs_scale);
+            float32x4_t main_acc2 = vmulq_f32(vcvtq_f32_s32(iacc_row_2), rhs_scale);
+            float32x4_t main_acc3 = vmulq_f32(vcvtq_f32_s32(iacc_row_3), rhs_scale);
+
+            main_acc0 = vmulq_laneq_f32(main_acc0, lhs_scale, 0);
+            main_acc1 = vmulq_laneq_f32(main_acc1, lhs_scale, 1);
+            main_acc2 = vmulq_laneq_f32(main_acc2, lhs_scale, 2);
+            main_acc3 = vmulq_laneq_f32(main_acc3, lhs_scale, 3);
+
+            // Clip (min-max) operation
+            const float32x4_t vmin_f32 = vdupq_n_f32(scalar_min);
+            const float32x4_t vmax_f32 = vdupq_n_f32(scalar_max);
+
+            main_acc0 = vmaxq_f32(main_acc0, vmin_f32);
+            main_acc0 = vminq_f32(main_acc0, vmax_f32);
+            main_acc1 = vmaxq_f32(main_acc1, vmin_f32);
+            main_acc1 = vminq_f32(main_acc1, vmax_f32);
+            main_acc2 = vmaxq_f32(main_acc2, vmin_f32);
+            main_acc2 = vminq_f32(main_acc2, vmax_f32);
+            main_acc3 = vmaxq_f32(main_acc3, vmin_f32);
+            main_acc3 = vminq_f32(main_acc3, vmax_f32);
+
+            // Store the rows in reverse order to avoid out-of-bound writes.
+            // Overwrite out-of-bound rows with in-bound values
+            vst1q_f32(
+                (float*)((uint8_t*)dst + col_idx * sizeof(float) + KAI_MIN(row_idx + 3, m - 1) * dst_stride_row),
+                main_acc3);
+            vst1q_f32(
+                (float*)((uint8_t*)dst + col_idx * sizeof(float) + KAI_MIN(row_idx + 2, m - 1) * dst_stride_row),
+                main_acc2);
+            vst1q_f32(
+                (float*)((uint8_t*)dst + col_idx * sizeof(float) + KAI_MIN(row_idx + 1, m - 1) * dst_stride_row),
+                main_acc1);
+            vst1q_f32(
+                (float*)((uint8_t*)dst + col_idx * sizeof(float) + KAI_MIN(row_idx + 0, m - 1) * dst_stride_row),
+                main_acc0);
+        }
+
+        lhs_ptr_start += lhs_packed_stride;
+    }
+#else
+    KAI_ASSERT(false);
+    KAI_UNUSED(m);
+    KAI_UNUSED(n);
+    KAI_UNUSED(k);
+    KAI_UNUSED(lhs_p);
+    KAI_UNUSED(rhs_p);
+    KAI_UNUSED(dst);
+    KAI_UNUSED(dst_stride_row);
+    KAI_UNUSED(dst_stride_col);
+    KAI_UNUSED(scalar_min);
+    KAI_UNUSED(scalar_max);
+#endif
+}
diff --git a/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm.h b/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm.h
new file mode 100644
index 00000000..b89a2a46
--- /dev/null
+++ b/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm.h
@@ -0,0 +1,115 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#pragma once
+
+#ifdef __cplusplus
+extern "C" {
+#else
+#include <stdbool.h>
+#endif
+
+#include <stddef.h>
+#include <stdint.h>
+
+/**
+ * @brief Function to get the nr value, which must be used to pack the RHS matrix with
+ * the @ref kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 function
+ *
+ * @return the nr value
+ */
+size_t kai_get_nr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm(void);
+
+/**
+ * @brief Function to get the kr value, which must be used to pack the RHS matrix with
+ * the @ref kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 function
+ *
+ * @return the kr value
+ */
+size_t kai_get_kr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm(void);
+
+/**
+ * @brief Function to get the sr value, which must be used to pack the RHS matrix with
+ * the @ref kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 function
+ *
+ * @return the sr value
+ */
+size_t kai_get_sr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm(void);
+
+/**
+ * @brief Function to calculate the offset in bytes for the packed LHS matrix,
+ * which contains the packed 8-bit quantized asymmetric per-row (qa8dx) values.
+ *
+ * This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel.
+ *
+ * @param[in] m_idx Row index in the LHS matrix (not packed).
+ * @param[in] k     Total number of columns in the LHS matrix (not packed). It must be a multiple of 32.
+ *
+ * return the offset in bytes to the packed LHS matrix
+ */
+size_t kai_get_lhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm(size_t m_idx, size_t k);
+
+/**
+ * @brief Function to calculate the offset in bytes for the packed RHS matrix,
+ * which contains the packed 4-bit quantized symmetric per-channel (qs4cx) values.
+ *
+ * @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of 4.
+ * @param[in] k The common dimension between the LHS and RHS matrix (K). It must be a multiple of 32.
+ *
+ * return the offset in bytes to the packed RHS matrix
+ */
+size_t kai_get_rhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm(size_t n_idx, size_t k);
+
+/**
+ * @brief Function to calculate the offset in bytes for the DST matrix
+ *
+ * @param[in] m_idx Row index in the DST matrix. It must be either 1, 2, 3, or any multiple of 4.
+ * @param[in] n_idx Column index in the DST matrix. It must be a multiple of 4.
+ * @param[in] dst_stride The number of bytes in each row of the DST matrix
+ *
+ * return the DST offset in bytes
+ */
+size_t
+kai_get_dst_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm(size_t m_idx, size_t n_idx, size_t dst_stride);
+
+/**
+ * @brief Function to query the size in bytes of the destination (DST) matrix.
+ *
+ * @param[in] m Number of rows in the destination (DST) matrix. It must be either 1, 2, 3, or any multiple of 4.
+ * @param[in] n Number of columns in the destination (DST) matrix. It must be a multiple of 4.
+ *
+ * return the DST size in bytes
+ */
+size_t kai_get_dst_size_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm(size_t m, size_t n);
+
+/**
+ * @brief Function to run the matrix multiplication (matmul) micro-kernel followed by a clip (min-max) operation.
+ *
+ * LHS matrix: 8-bit quantized asymmetric per-row (qa8dx) and packed
+ * RHS matrix: 4-bit quantized symmetric per-channel (qs4cx) and packed.
+ * Output tile: (rows x cols) = 4 x 4
+ * Accumulation performed in a single for loop: 32
+ * Instruction used: i8mm
+ *
+ * @param[in] m The number of output rows written. It must be either 1, 2, 3, or any multiple of 4.
+ * @param[in] n The number of output columns written. It must be a multiple of 4.
+ * @param[in] k The number of channels. The common dimension of LHS & RHS. It must be a multiple of 32.
+ * @param[in] lhs_p The LHS matrix packed.
+ *                  When the activations are dynamically quantized, you can obtain this matrix
+ *                  by calling the @ref kai_lhs_quant_pack_qa8dxP4X8_f32 micro-kernel which performs
+ *                  both the dynamic quantization to 8-bit and activation packing in a single step.
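+ *                  (For reference, each packed block of 4 LHS rows is laid out as: four int32
+ *                  zero-point offsets first, then the interleaved int8 data, then four f32
+ *                  scales, which is the order in which this micro-kernel reads it.)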
+ * @param[in] rhs_p The RHS matrix packed, which is obtained by calling @ref
+ *                  kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0
+ * @param[out] dst Result of the matrix multiplication
+ * @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix.
+ * @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(float)
+ * @param[in] scalar_min Min value used to clip the final result.
+ * @param[in] scalar_max Max value used to clip the final result.
+ */
+void kai_run_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm(
+    size_t m, size_t n, size_t k, const void* lhs_p, const void* rhs_p, float* dst, size_t dst_stride_row,
+    size_t dst_stride_col, float scalar_min, float scalar_max);
+
+#ifdef __cplusplus
+}
+#endif
-- 
GitLab


From 2a6b7b49db070c6472a2f6f0d6c80e858700759f Mon Sep 17 00:00:00 2001
From: Gian Marco Iodice
Date: Fri, 26 Apr 2024 15:42:21 +0100
Subject: [PATCH 02/14] Add assembly ukernels for QA8DX (LHS) x QS4CX (RHS) -> F32 (DST)

- Add assembly ukernels for the 4x4, 8x4 and 8x8 variants
- Implement a generic packing function for the LHS matrix that also
  performs the quantization
- Add example to demonstrate how to use the ukernels in the context
  of int4 quantization

Signed-off-by: Gian Marco Iodice
---
 .editorconfig                                 |   4 +
 .../matmul_f32_qa8dx_qs4cx/CMakeLists.txt     |  40 ++
 .../matmul_f32_qa8dx_qs4cx.cpp                | 468 ++++++++++++++
 src/matmul/kai_lhs_quant_pack_qa8dxP1X8_f32.c | 142 -----
 src/matmul/kai_lhs_quant_pack_qa8dxP4X8_f32.c | 315 ----------
 src/matmul/kai_lhs_quant_pack_qa8dxP4X8_f32.h |  65 --
 src/matmul/kai_lhs_quant_pack_qa8dxP_f32.c    | 154 +++++
 ..._f32.h => kai_lhs_quant_pack_qa8dxP_f32.h} |   9 +-
 .../kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0.c       |  16 +-
 ..._qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod.c |  34 +-
 ..._qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod.h |   8 +
 ..._qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod.c | 223 +++++++
 ..._qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod.h | 123 ++++
 ...f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm.c | 291 ++++-----
 ...f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm.h |  14 +-
 ...f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm.c | 330 ++++++++++
 ...f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm.h | 123 ++++
 ...f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm.c | 397 ++++++++++++
 ...f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm.h | 123 ++++
 ...f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm.c | 582 ++++++++++++++++++
 ...f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm.h | 123 ++++
 21 files changed, 2875 insertions(+), 709 deletions(-)
 create mode 100644 examples/matmul_f32_qa8dx_qs4cx/CMakeLists.txt
 create mode 100644 examples/matmul_f32_qa8dx_qs4cx/matmul_f32_qa8dx_qs4cx.cpp
 delete mode 100644 src/matmul/kai_lhs_quant_pack_qa8dxP1X8_f32.c
 delete mode 100644 src/matmul/kai_lhs_quant_pack_qa8dxP4X8_f32.c
 delete mode 100644 src/matmul/kai_lhs_quant_pack_qa8dxP4X8_f32.h
 create mode 100644 src/matmul/kai_lhs_quant_pack_qa8dxP_f32.c
 rename src/matmul/{kai_lhs_quant_pack_qa8dxP1X8_f32.h => kai_lhs_quant_pack_qa8dxP_f32.h} (81%)
 create mode 100644 src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod.c
 create mode 100644 src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod.h
 create mode 100644 src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm.c
 create mode 100644 src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm.h
 create mode 100644 src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm.c
 create mode 100644 src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm.h
 create mode 100644 src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm.c
 create mode 100644 src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm.h

diff --git a/.editorconfig b/.editorconfig
index 6b724e49..d5ad33f9 100644
--- a/.editorconfig
+++ b/.editorconfig
@@ -20,6 +20,10 @@ trim_trailing_whitespace = true
 [*.{json,yml,yaml}]
 indent_size = 2
 
+# Override settings.
+[*.{c,cpp,h,hpp}]
+indent_size = unset
+
 # Override settings.
 [LICENSES/*]
 indent_size = unset
diff --git a/examples/matmul_f32_qa8dx_qs4cx/CMakeLists.txt b/examples/matmul_f32_qa8dx_qs4cx/CMakeLists.txt
new file mode 100644
index 00000000..29a24a14
--- /dev/null
+++ b/examples/matmul_f32_qa8dx_qs4cx/CMakeLists.txt
@@ -0,0 +1,40 @@
+#
+# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+
+cmake_minimum_required(VERSION 3.16)
+
+# KleidiAI include directories
+include_directories(
+    ../../src/
+    ../../src/matmul/
+    ../../src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/)
+
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")
+
+# Project name
+project(matmul_f32_qa8dx_qs4cx)
+
+# Files required to build the executable
+add_executable(matmul_f32_qa8dx_qs4cx
+    matmul_f32_qa8dx_qs4cx.cpp
+    ../../src/kai_common.h
+    ../../src/matmul/kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0.h
+    ../../src/matmul/kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0.c
+    ../../src/matmul/kai_lhs_quant_pack_qa8dxP_f32.h
+    ../../src/matmul/kai_lhs_quant_pack_qa8dxP_f32.c
+    ../../src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod.h
+    ../../src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod.c
+    ../../src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod.h
+    ../../src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod.c
+    ../../src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm.h
+    ../../src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm.c
+    ../../src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm.h
+    ../../src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm.c
+    ../../src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm.h
+    ../../src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm.c
+    ../../src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm.h
+    ../../src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm.c)
+
diff --git a/examples/matmul_f32_qa8dx_qs4cx/matmul_f32_qa8dx_qs4cx.cpp b/examples/matmul_f32_qa8dx_qs4cx/matmul_f32_qa8dx_qs4cx.cpp
new file mode 100644
index 00000000..2110fe46
--- /dev/null
+++ b/examples/matmul_f32_qa8dx_qs4cx/matmul_f32_qa8dx_qs4cx.cpp
@@ -0,0 +1,468 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+// Include micro-kernel variants
+#include "kai_lhs_quant_pack_qa8dxP_f32.h"
"kai_lhs_quant_pack_qa8dxP_f32.h" +#include "kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod.h" +#include "kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod.h" +#include "kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm.h" +#include "kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm.h" +#include "kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm.h" +#include "kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm.h" +#include "kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0.h" + +#include +#include +#include +#include +#include +#include + +#define INT4_MIN (-8) +#define INT4_MAX (7) + +// All micro-kernels variants of the same type share the same interfaces +// In this case, the micro-kernel type is: matmul_clip_f32_qa8dxP_qs4cxP + +// Micro-kernel helper interfaces ("get" methods) +typedef size_t (*kai_get_mr_func_t)(void); +typedef size_t (*kai_get_nr_func_t)(void); +typedef size_t (*kai_get_kr_func_t)(void); +typedef size_t (*kai_get_sr_func_t)(void); +typedef size_t (*kai_get_lhs_packed_offset_func_t)(size_t m_idx, size_t k); +typedef size_t (*kai_get_rhs_packed_offset_func_t)(size_t n_idx, size_t k); +typedef size_t (*kai_get_dst_offset_func_t)(size_t m_idx, size_t n_idx, size_t dst_stride); +typedef size_t (*kai_get_dst_size_func_t)(size_t m, size_t n); + +// Micro-kernel core interface ("run" method) +typedef void (*kai_run_matmul_func_t)( + size_t m, size_t n, size_t k, const void* lhs_p, const void* rhs_p, float* dst, size_t dst_stride_row, + size_t dst_stride_col, float scalar_min, float scalar_max); + +// Micro-kernel interface +struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp { + kai_get_mr_func_t get_mr = nullptr; + kai_get_nr_func_t get_nr = nullptr; + kai_get_nr_func_t get_kr = nullptr; + kai_get_sr_func_t get_sr = nullptr; + kai_get_lhs_packed_offset_func_t get_lhs_packed_offset = nullptr; + kai_get_rhs_packed_offset_func_t get_rhs_packed_offset = nullptr; + kai_get_dst_offset_func_t get_dst_offset = nullptr; + kai_get_dst_size_func_t get_dst_size = nullptr; + kai_run_matmul_func_t run_matmul = nullptr; + std::string name = {}; +}; + +kai_matmul_ukernel_f32_qa8dxp_qs4cxp ukernel_variants[] = { + {kai_get_mr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod, + kai_get_nr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod, + kai_get_kr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod, + kai_get_sr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod, + kai_get_lhs_packed_offset_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod, + kai_get_rhs_packed_offset_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod, + kai_get_dst_offset_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod, + kai_get_dst_size_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod, + kai_run_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod, + "matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod"}, + {kai_get_mr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod, + kai_get_nr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod, + kai_get_kr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod, + kai_get_sr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod, + kai_get_lhs_packed_offset_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod, + kai_get_rhs_packed_offset_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod, + kai_get_dst_offset_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod, + kai_get_dst_size_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod, + 
+     kai_run_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod,
+     "matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod"},
+    {kai_get_mr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm,
+     kai_get_nr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm,
+     kai_get_kr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm,
+     kai_get_sr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm,
+     kai_get_lhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm,
+     kai_get_rhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm,
+     kai_get_dst_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm,
+     kai_get_dst_size_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm,
+     kai_run_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm,
+     "matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm"},
+    {kai_get_mr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm,
+     kai_get_nr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm,
+     kai_get_kr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm,
+     kai_get_sr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm,
+     kai_get_lhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm,
+     kai_get_rhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm,
+     kai_get_dst_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm,
+     kai_get_dst_size_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm,
+     kai_run_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm,
+     "matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm"},
+    {kai_get_mr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm,
+     kai_get_nr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm,
+     kai_get_kr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm,
+     kai_get_sr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm,
+     kai_get_lhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm,
+     kai_get_rhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm,
+     kai_get_dst_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm,
+     kai_get_dst_size_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm,
+     kai_run_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm,
+     "matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm"},
+    {kai_get_mr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm,
+     kai_get_nr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm,
+     kai_get_kr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm,
+     kai_get_sr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm,
+     kai_get_lhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm,
+     kai_get_rhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm,
+     kai_get_dst_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm,
+     kai_get_dst_size_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm,
+     kai_run_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm,
+     "matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm"},
+};
+
+// Number of micro-kernel variants stored in the array
+const size_t num_ukernel_variants = sizeof(ukernel_variants) / sizeof(ukernel_variants[0]);
+
+static void fill_uniform_random(size_t num_rows, size_t num_cols, float* dst, size_t seed) {
+
+    std::srand(seed);
+
+    // Fill the array with random values between -1 and 1
+    for (size_t i = 0; i < num_rows * num_cols; i++) {
+        dst[i] = (float)((double)std::rand() / RAND_MAX) * 2 - 1;
+    }
+}
+
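+// Quantizes the RHS matrix from f32 to 4-bit symmetric per-channel (qs4cx):
+// per row, scale = (qmax - qmin) / (rmax - rmin) with the range forced to include 0,
+// each value is rounded into [-8, 7] and stored with a +8 bias as two nibbles per byte
+// (even k index in the low nibble, odd k index in the high nibble).
+static void quant_qs4cx_f32(size_t n, size_t k, const float* rhs_f32, uint8_t* rhs_qs4cx, float* rhs_scales_f32) {
+
+    const size_t dst_stride = (k / 2) * sizeof(int8_t);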
+
+    for (size_t row_idx = 0; row_idx < n; ++row_idx) {
+
+        const float* src_ptr = rhs_f32 + row_idx * k;
+
+        float max0 = -FLT_MAX;
+        float min0 = FLT_MAX;
+
+        // Find min/max for each channel
+        for (size_t k_idx = 0; k_idx < k; ++k_idx) {
+            const float src0_0 = src_ptr[k_idx];
+
+            max0 = std::max(src0_0, max0);
+            min0 = std::min(src0_0, min0);
+        }
+
+        // Maximum/minimum int4 values
+        const float qmin = (float)INT4_MIN;
+        const float qmax = (float)INT4_MAX;
+
+        const float rmin0 = std::min(0.0f, min0);
+        const float rmax0 = std::max(0.0f, max0);
+
+        const float scale0 = rmin0 == rmax0 ? 1.f : (qmax - qmin) / (rmax0 - rmin0);
+
+        // Reciprocal to quantize
+        const float recip_scale0 = scale0 ? 1.0f / scale0 : 0.0f;
+
+        uint8_t* dst_ptr = (uint8_t*)rhs_qs4cx + row_idx * dst_stride;
+
+        // Quantize the channels
+        for (size_t k_idx = 0; k_idx < k; k_idx += 2) {
+            const float src0_0 = src_ptr[k_idx + 0];
+            const float src0_1 = src_ptr[k_idx + 1];
+
+            // Scale the values
+            int32_t v0_s32 = (int32_t)(round(src0_0 * scale0));
+            int32_t v1_s32 = (int32_t)(round(src0_1 * scale0));
+
+            // Maximum/minimum int4 values
+            v0_s32 = std::max(v0_s32, INT4_MIN);
+            v0_s32 = std::min(v0_s32, INT4_MAX);
+            v1_s32 = std::max(v1_s32, INT4_MIN);
+            v1_s32 = std::min(v1_s32, INT4_MAX);
+
+            const uint8_t v0_u8 = (uint8_t)(v0_s32 + 8);
+            const uint8_t v1_u8 = (uint8_t)(v1_s32 + 8);
+
+            const uint8_t rhs_v0 = (v1_u8 << 4) | v0_u8;
+
+            dst_ptr[0] = rhs_v0;
+            dst_ptr += sizeof(uint8_t);
+        }
+
+        rhs_scales_f32[row_idx] = recip_scale0;
+    }
+};
+
+static void ref_quant_qa8dx_f32(size_t m, size_t k, const float* lhs_f32, int8_t* lhs_qa8dx) {
+
+    const size_t dst_stride = (k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t));
+
+    for (size_t row_idx = 0; row_idx < m; ++row_idx) {
+
+        const float* src_ptr = lhs_f32 + row_idx * k;
+
+        float max0 = -FLT_MAX;
+        float min0 = FLT_MAX;
+
+        // Find min/max for each channel
+        for (size_t k_idx = 0; k_idx < k; ++k_idx) {
+            const float src0_0 = src_ptr[k_idx];
+
+            max0 = std::max(src0_0, max0);
+            min0 = std::min(src0_0, min0);
+        }
+
+        // Maximum/minimum int8 values
+        const float qmin = (float)INT8_MIN;
+        const float qmax = (float)INT8_MAX;
+
+        const float rmin0 = std::min(0.0f, min0);
+        const float rmax0 = std::max(0.0f, max0);
+
+        const float scale0 = rmin0 == rmax0 ? 1.f : (qmax - qmin) / (rmax0 - rmin0);
+
+        // Reciprocal to quantize
+        const float recip_scale0 = scale0 ? 1.0f / scale0 : 0.0f;
+
+        const float descaled_min0 = rmin0 * scale0;
+        const float descaled_max0 = rmax0 * scale0;
+
+        const float zero_point_from_min_error0 = qmin + descaled_min0;
+        const float zero_point_from_max_error0 = qmax + descaled_max0;
+
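+        // Asymmetric-quantization zero point: pick the rounding direction that keeps both range
+        // endpoints representable, then clamp it into [qmin, qmax] (the standard min/max-error
+        // nudging used for dynamic quantization).
+        float zero_point0 =
+            zero_point_from_min_error0 + zero_point_from_max_error0 > 0 ? qmin - descaled_min0 : qmax - descaled_max0;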
+
+        zero_point0 = std::max(zero_point0, qmin);
+        zero_point0 = std::min(zero_point0, qmax);
+
+        // Round to nearest integer
+        const int32_t nudged_zero_point0 = lrintf(zero_point0);
+
+        int8_t* dst_ptr = (int8_t*)lhs_qa8dx + row_idx * dst_stride;
+
+        // The quantization parameters (f32 scale, then int32 offset) are stored at the
+        // beginning of each row
+        *((float*)(dst_ptr)) = recip_scale0;
+        dst_ptr += sizeof(float);
+        *((int32_t*)(dst_ptr)) = -nudged_zero_point0;
+        dst_ptr += sizeof(int32_t);
+
+        // Quantize the channels
+        for (size_t k_idx = 0; k_idx < k; ++k_idx) {
+            const float src0_0 = src_ptr[k_idx];
+
+            // Scale the values
+            int32_t v0_s32 = (int32_t)(round(src0_0 * scale0));
+
+            v0_s32 = v0_s32 + nudged_zero_point0;
+            v0_s32 = std::max(v0_s32, INT8_MIN);
+            v0_s32 = std::min(v0_s32, INT8_MAX);
+            dst_ptr[0] = (int8_t)v0_s32;
+            dst_ptr += sizeof(int8_t);
+        }
+    }
+};
+
+static void ref_matmul_f32_qa8dx_qs4cx(
+    size_t m, size_t n, size_t k, const int8_t* lhs_qa8dx, const uint8_t* rhs_qs4cx, const float* rhs_scales_f32,
+    float* dst_f32, float scalar_min, float scalar_max) {
+
+    const size_t lhs_stride = k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t);
+    const size_t rhs_stride = (k / 2) * sizeof(uint8_t);
+
+    for (size_t row_idx = 0; row_idx < m; ++row_idx) {
+
+        const int8_t* lhs_ptr_start = lhs_qa8dx + row_idx * lhs_stride;
+        for (size_t col_idx = 0; col_idx < n; ++col_idx) {
+
+            // Main integer accumulator
+            int32_t iacc = 0;
+
+            const int8_t* lhs_ptr = lhs_ptr_start;
+            const uint8_t* rhs_ptr = rhs_qs4cx + col_idx * rhs_stride;
+
+            // Get the LHS quantization parameters stored at the
+            // beginning of each row
+            const float lhs_scale = *(const float*)lhs_ptr;
+            lhs_ptr += sizeof(float);
+
+            const int32_t lhs_offset = *(const int32_t*)lhs_ptr;
+            lhs_ptr += sizeof(int32_t);
+
+            for (size_t b = 0; b < k; b += 2) {
+                // Get the LHS values
+                const int32_t lhs_v0 = (int32_t)lhs_ptr[0];
+                const int32_t lhs_v1 = (int32_t)lhs_ptr[1];
+
+                // Get the RHS values
+                const uint8_t rhs_byte = rhs_ptr[0];
+
+                // Unpack the RHS values
+                const int32_t rhs_v0 = (((int32_t)(rhs_byte & 0x0F)) - 8);
+                const int32_t rhs_v1 = (((int32_t)(rhs_byte >> 4)) - 8);
+
+                iacc += lhs_v0 * rhs_v0;
+                iacc += lhs_v1 * rhs_v1;
+                iacc += lhs_offset * rhs_v0;
+                iacc += lhs_offset * rhs_v1;
+
+                lhs_ptr += 2;
+                rhs_ptr += 1;
+            }
+
+            // Get the RHS scale
+            const float rhs_scale = rhs_scales_f32[col_idx];
+
+            float main_acc = iacc * rhs_scale;
+
+            main_acc = main_acc * lhs_scale;
+
+            // Clamp (min-max) operation
+            main_acc = std::max(main_acc, scalar_min);
+            main_acc = std::min(main_acc, scalar_max);
+
+            dst_f32[0] = main_acc;
+            dst_f32 += 1;
+        }
+    }
+};
+
+static bool is_output_correct(size_t num_rows, size_t num_cols, float tolerance, const float* ref, const float* act) {
+
+    bool is_valid = true;
+
+    for (size_t i = 0; i < num_rows * num_cols; ++i) {
+        if (std::fabs(ref[i] - act[i]) > tolerance) {
+            const size_t x = i % num_cols;
+            const size_t y = i / num_cols;
+            printf("ERROR![%zu][%zu]: ref=%.5f vs. act=%.5f\n", y, x, ref[i], act[i]);
+            is_valid = false;
+        }
+    }
+    return is_valid;
+}
+
+int main(int argc, char** argv) {
+
+    const size_t m = 17;
+    const size_t n = 32;  // It must be a multiple of 8
+    const size_t k = 64;  // It must be a multiple of 64
+    const size_t seed_lhs = 4568;
+    const size_t seed_rhs = seed_lhs + 4;
+
+    const size_t lhs_native_size_f32 = m * k * sizeof(float);
+    const size_t rhs_native_size_f32 = n * k * sizeof(float);
+    const size_t rhs_native_size_qs4cx = n * (k / 2) * sizeof(uint8_t);
+    const size_t rhs_scales_size_f32 = n * sizeof(float);
+
+    // Allocate the memory
+    uint8_t* lhs_native_mtx_f32 = new uint8_t[lhs_native_size_f32];
+    uint8_t* rhs_native_mtx_f32 = new uint8_t[rhs_native_size_f32];
+    uint8_t* rhs_native_mtx_qs4cx = new uint8_t[rhs_native_size_qs4cx];
+    uint8_t* rhs_scales_f32 = new uint8_t[rhs_scales_size_f32];
+
+    fill_uniform_random(m, k, (float*)lhs_native_mtx_f32, seed_lhs);
+
+    fill_uniform_random(n, k, (float*)rhs_native_mtx_f32, seed_rhs);
+
+    quant_qs4cx_f32(n, k, (const float*)rhs_native_mtx_f32, (uint8_t*)rhs_native_mtx_qs4cx, (float*)rhs_scales_f32);
+
+    delete[] rhs_native_mtx_f32;
+
+    //----------- REFERENCE IMPLEMENTATION
+    //------------------------------------
+    //------------------------------------
+    // Memory sizes for the reference implementation
+    // After dynamically quantizing the LHS matrix, we have the scale and offset for each
+    // row. The scale (f32) and offset (int32) are stored at the beginning of each row
+    size_t lhs_ref_size_qa8dx = m * (k + sizeof(int32_t) + sizeof(float));
+    size_t dst_ref_size_f32 = m * n * sizeof(float);
+
+    uint8_t* lhs_ref_mtx_qa8dx = new uint8_t[lhs_ref_size_qa8dx];
+    uint8_t* dst_ref_mtx_f32 = new uint8_t[dst_ref_size_f32];
+
+    ref_quant_qa8dx_f32(m, k, (const float*)lhs_native_mtx_f32, (int8_t*)lhs_ref_mtx_qa8dx);
+
+    ref_matmul_f32_qa8dx_qs4cx(
+        m, n, k, (const int8_t*)lhs_ref_mtx_qa8dx, (const uint8_t*)rhs_native_mtx_qs4cx, (const float*)rhs_scales_f32,
+        (float*)dst_ref_mtx_f32, -FLT_MAX, FLT_MAX);
+
+    // Remove the unnecessary buffer
+    delete[] lhs_ref_mtx_qa8dx;
+
+    //----------- END REFERENCE IMPLEMENTATION
+    //------------------------------------
+    //------------------------------------
+
+    //----------- MICRO-KERNELS TESTS
+    //------------------------------------
+    //------------------------------------
+    for (size_t idx_variant = 0; idx_variant < num_ukernel_variants; ++idx_variant) {
+
+        std::cout << "Testing " << ukernel_variants[idx_variant].name << std::endl;
+
+        const size_t mr = ukernel_variants[idx_variant].get_mr();
+        const size_t nr = ukernel_variants[idx_variant].get_nr();
+        const size_t kr = ukernel_variants[idx_variant].get_kr();
+        const size_t sr = ukernel_variants[idx_variant].get_sr();
+
+        const size_t lhs_packed_size = kai_get_lhs_packed_size_lhs_quant_pack_qa8dxP_f32(m, k, mr, kr);
+        const size_t rhs_packed_size = kai_get_rhs_packed_size_rhs_pack_nxk_qs4cxP_qs4cxS1S0(n, k, nr, kr);
+        const size_t dst_size = ukernel_variants[idx_variant].get_dst_size(m, n);
+
+        uint8_t* lhs_packed_mtx_qa8dx = new uint8_t[lhs_packed_size];
+        uint8_t* rhs_packed_mtx_qs4cx = new uint8_t[rhs_packed_size];
+        uint8_t* dst_act_mtx_f32 = new uint8_t[dst_size];
+
+        // LHS packing
+        kai_run_lhs_quant_pack_qa8dxP_f32(
+            m, k, mr, kr, sr, (const float*)lhs_native_mtx_f32, k * sizeof(float), lhs_packed_mtx_qa8dx);
+
+        struct kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0_params params;
+        params.lhs_zero_point = 1;
+        params.rhs_zero_point = 8;
+
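+        // Typical deployment note: the RHS (weights) can be packed once, ahead of time, while
+        // the LHS quant+pack above runs per inference; both packed layouts are parameterized
+        // by the mr/nr/kr/sr values queried from the selected micro-kernel variant.
+        // RHS packing
+        kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0(
+            1, n, k, nr, kr, sr,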
+            (const uint8_t*)(rhs_native_mtx_qs4cx),  // RHS
+            NULL,                                    // Bias
+            (const float*)(rhs_scales_f32),          // Scale
+            rhs_packed_mtx_qs4cx,                    // DST
+            0, &params);
+
+        {
+            const size_t dst_stride = n * sizeof(float);
+
+            const void* lhs_ptr = (const void*)((const char*)lhs_packed_mtx_qa8dx +
+                                                ukernel_variants[idx_variant].get_lhs_packed_offset(0, k));
+            const void* rhs_ptr = (const void*)((const char*)rhs_packed_mtx_qs4cx +
+                                                ukernel_variants[idx_variant].get_rhs_packed_offset(0, k));
+            float* dst_ptr =
+                (float*)((uint8_t*)dst_act_mtx_f32 + ukernel_variants[idx_variant].get_dst_offset(0, 0, dst_stride));
+
+            ukernel_variants[idx_variant].run_matmul(
+                m, n, k, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
+        }
+
+        const bool is_valid =
+            is_output_correct(m, n, 0.0001f, (const float*)dst_ref_mtx_f32, (const float*)dst_act_mtx_f32);
+
+        if (is_valid) {
+            printf("TEST[%zu] = PASSED\n", idx_variant);
+        } else {
+            printf("TEST[%zu] = FAILED\n", idx_variant);
+        }
+        delete[] lhs_packed_mtx_qa8dx;
+        delete[] rhs_packed_mtx_qs4cx;
+        delete[] dst_act_mtx_f32;
+    }
+    delete[] lhs_native_mtx_f32;
+    delete[] rhs_native_mtx_qs4cx;
+    delete[] rhs_scales_f32;
+    delete[] dst_ref_mtx_f32;
+}
+
+//----------- END MICRO-KERNELS TESTS
+//------------------------------------
+//------------------------------------
diff --git a/src/matmul/kai_lhs_quant_pack_qa8dxP1X8_f32.c b/src/matmul/kai_lhs_quant_pack_qa8dxP1X8_f32.c
deleted file mode 100644
index f4a4f752..00000000
--- a/src/matmul/kai_lhs_quant_pack_qa8dxP1X8_f32.c
+++ /dev/null
@@ -1,142 +0,0 @@
-//
-// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
-//
-// SPDX-License-Identifier: Apache-2.0
-//
-#include "kai_lhs_quant_pack_qa8dxP1X8_f32.h"
-
-#include "../kai_common.h"
-
-#include <arm_neon.h>
-#include <float.h>
-#include <math.h>
-#include <stdint.h>
-
-static const size_t kai_kk0 = 8;
-static const size_t kai_km0 = 1;
-static const size_t kai_num_bytes_per_multiplier = sizeof(float);
-static const size_t kai_num_bytes_per_offset = sizeof(float);
-
-size_t kai_get_lhs_offset_lhs_quant_pack_qa8dxP1X8_f32(size_t m_idx, size_t lhs_stride) {
-    return m_idx * lhs_stride;
-}
-
-size_t kai_get_lhs_packed_offset_lhs_quant_pack_qa8dxP1X8_f32(size_t m_idx, size_t k) {
-    const size_t dst_stride = kai_km0 * (k * sizeof(int8_t) + kai_num_bytes_per_multiplier + kai_num_bytes_per_offset);
-
-    return (m_idx / kai_km0) * dst_stride;
-}
-
-size_t kai_get_lhs_packed_size_lhs_quant_pack_qa8dxP1X8_f32(size_t m, size_t k) {
-    const size_t m_roundup4 = kai_roundup(m, kai_km0);
-
-    const size_t t_size = m_roundup4 * (k * sizeof(int8_t) + kai_num_bytes_per_multiplier + kai_num_bytes_per_offset);
-    return t_size;
-}
-
-void kai_run_lhs_quant_pack_qa8dxP1X8_f32(
-    size_t m, size_t k, const float* restrict lhs, size_t lhs_stride, void* restrict lhs_p) {
-    KAI_ASSERT(k % kai_kk0 == 0);
-
-    if (m == 0) {
-        return;
-    }
-
-    const size_t num_rows = m;
-    const size_t num_cols = k;
-
-    const float* src_ptr = lhs;
-    uint8_t* dst_ptr = lhs_p;
-
-    for (size_t row_idx = 0; row_idx < num_rows; row_idx += kai_km0) {
-        float32x4_t vmax0 = vdupq_n_f32(-FLT_MAX);
-        float32x4_t vmin0 = vdupq_n_f32(FLT_MAX);
-
-        // Find min/max for each channel
-        for (size_t j = 0; j < num_cols; j += 8) {
-            const float32x4_t src0_0 = vld1q_f32(src_ptr + 0 + j);
-            const float32x4_t src0_1 = vld1q_f32(src_ptr + 4 + j);
-
-            // Calculate the max
-            vmax0 = vmaxq_f32(src0_0, vmax0);
-            vmax0 = vmaxq_f32(vmax0, src0_1);
-
-            // Calculate the min
-            vmin0 = vminq_f32(src0_0, vmin0);
-            vmin0 =
vminq_f32(vmin0, src0_1); - } - - // Get the max/min from each row - const float max0 = vmaxvq_f32(vmax0); - const float min0 = vminvq_f32(vmin0); - - // Maximum/minimum int8 values - const float qmin = (float)INT8_MIN; - const float qmax = (float)INT8_MAX; - - const float rmin0 = KAI_MIN(0.0f, min0); - const float rmax0 = KAI_MAX(0.0f, max0); - - const float scale0 = rmin0 == rmax0 ? 1.f : (qmax - qmin) / (rmax0 - rmin0); - - // Reciprocal to quantize - const float recip_scale0 = scale0 ? 1.0f / scale0 : 0.0f; - - const float descaled_min0 = rmin0 * scale0; - const float descaled_max0 = rmax0 * scale0; - - const float zero_point_from_min_error0 = qmin + descaled_min0; - const float zero_point_from_max_error0 = qmax + descaled_max0; - - float zero_point0 = - zero_point_from_min_error0 + zero_point_from_max_error0 > 0 ? qmin - descaled_min0 : qmax - descaled_max0; - - zero_point0 = KAI_MAX(zero_point0, qmin); - zero_point0 = KAI_MIN(zero_point0, qmax); - - // Round to nearest integer - const int32_t nudged_zero_point0 = lrintf(zero_point0); - - // LHS offset at the beginning of the row - *((int32_t*)dst_ptr) = -nudged_zero_point0; - dst_ptr += sizeof(int32_t); - - // Quantize the channels - for (size_t j = 0; j < num_cols; j += 8) { - const float32x4_t src0_0 = vld1q_f32(src_ptr + 0 + j); - const float32x4_t src0_1 = vld1q_f32(src_ptr + 4 + j); - - float32x4_t v0_f32; - float32x4_t v1_f32; - int32x4_t v0_s32; - int32x4_t v1_s32; - - // Scale the values - v0_f32 = vmulq_n_f32(src0_0, scale0); - v1_f32 = vmulq_n_f32(src0_1, scale0); - v0_s32 = vcvtnq_s32_f32(v0_f32); - v1_s32 = vcvtnq_s32_f32(v1_f32); - v0_s32 = vaddq_s32(v0_s32, vdupq_n_s32(nudged_zero_point0)); - v1_s32 = vaddq_s32(v1_s32, vdupq_n_s32(nudged_zero_point0)); - v0_s32 = vmaxq_s32(v0_s32, vdupq_n_s32(INT8_MIN)); - v0_s32 = vminq_s32(v0_s32, vdupq_n_s32(INT8_MAX)); - v1_s32 = vmaxq_s32(v1_s32, vdupq_n_s32(INT8_MIN)); - v1_s32 = vminq_s32(v1_s32, vdupq_n_s32(INT8_MAX)); - *((int8_t*)(dst_ptr + 0)) = (int8_t)vgetq_lane_s32(v0_s32, 0); - *((int8_t*)(dst_ptr + 1)) = (int8_t)vgetq_lane_s32(v0_s32, 1); - *((int8_t*)(dst_ptr + 2)) = (int8_t)vgetq_lane_s32(v0_s32, 2); - *((int8_t*)(dst_ptr + 3)) = (int8_t)vgetq_lane_s32(v0_s32, 3); - *((int8_t*)(dst_ptr + 4)) = (int8_t)vgetq_lane_s32(v1_s32, 0); - *((int8_t*)(dst_ptr + 5)) = (int8_t)vgetq_lane_s32(v1_s32, 1); - *((int8_t*)(dst_ptr + 6)) = (int8_t)vgetq_lane_s32(v1_s32, 2); - *((int8_t*)(dst_ptr + 7)) = (int8_t)vgetq_lane_s32(v1_s32, 3); - dst_ptr += sizeof(int8_t) * kai_kk0; - } - - // Store the scale quantization params - *((float*)dst_ptr) = recip_scale0; - dst_ptr += sizeof(float); - - src_ptr += kai_km0 * (lhs_stride / sizeof(float)); - } -} diff --git a/src/matmul/kai_lhs_quant_pack_qa8dxP4X8_f32.c b/src/matmul/kai_lhs_quant_pack_qa8dxP4X8_f32.c deleted file mode 100644 index 89fc26d9..00000000 --- a/src/matmul/kai_lhs_quant_pack_qa8dxP4X8_f32.c +++ /dev/null @@ -1,315 +0,0 @@ -// -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates -// -// SPDX-License-Identifier: Apache-2.0 -// -#include "kai_lhs_quant_pack_qa8dxP4X8_f32.h" - -#include "../kai_common.h" - -#include -#include -#include -#include - -static const size_t kai_kk0 = 8; -static const size_t kai_km0 = 4; -static const size_t kai_num_bytes_per_multiplier = sizeof(float); -static const size_t kai_num_bytes_per_offset = sizeof(float); - -size_t kai_get_lhs_offset_lhs_quant_pack_qa8dxP4X8_f32(size_t m_idx, size_t lhs_stride) { - return m_idx * lhs_stride; -} - -size_t 
kai_get_lhs_packed_offset_lhs_quant_pack_qa8dxP4X8_f32(size_t m_idx, size_t k) { - const size_t dst_stride = kai_km0 * (k * sizeof(int8_t) + kai_num_bytes_per_multiplier + kai_num_bytes_per_offset); - - return (m_idx / kai_km0) * dst_stride; -} - -size_t kai_get_lhs_packed_size_lhs_quant_pack_qa8dxP4X8_f32(size_t m, size_t k) { - const size_t m_roundup4 = kai_roundup(m, kai_km0); - - const size_t t_size = m_roundup4 * (k * sizeof(int8_t) + kai_num_bytes_per_multiplier + kai_num_bytes_per_offset); - return t_size; -} - -void kai_run_lhs_quant_pack_qa8dxP4X8_f32( - size_t m, size_t k, const float* restrict lhs, size_t lhs_stride, void* restrict lhs_p) { - KAI_ASSERT(k % kai_kk0 == 0); - KAI_ASSERT(m <= 3 || (m % kai_km0 == 0)); - - if (m == 0) { - return; - } - - const size_t num_cols = k; - const size_t num_rows = m; - - const float* src_ptr0 = (const float*)((const uint8_t*)lhs + 0 * lhs_stride); - const float* src_ptr1 = (const float*)((const uint8_t*)lhs + 1 * lhs_stride); - const float* src_ptr2 = (const float*)((const uint8_t*)lhs + 2 * lhs_stride); - const float* src_ptr3 = (const float*)((const uint8_t*)lhs + 3 * lhs_stride); - - if (m == 1) { - src_ptr1 = src_ptr0; - src_ptr2 = src_ptr0; - src_ptr3 = src_ptr0; - } else if (m == 2) { - src_ptr2 = src_ptr1; - src_ptr3 = src_ptr1; - } else if (m == 3) { - src_ptr3 = src_ptr2; - } - - uint8_t* dst_ptr = lhs_p; - - for (size_t row_idx = 0; row_idx < num_rows; row_idx += kai_km0) { - float32x4_t vmax0 = vdupq_n_f32(-FLT_MAX); - float32x4_t vmax1 = vdupq_n_f32(-FLT_MAX); - float32x4_t vmax2 = vdupq_n_f32(-FLT_MAX); - float32x4_t vmax3 = vdupq_n_f32(-FLT_MAX); - float32x4_t vmin0 = vdupq_n_f32(FLT_MAX); - float32x4_t vmin1 = vdupq_n_f32(FLT_MAX); - float32x4_t vmin2 = vdupq_n_f32(FLT_MAX); - float32x4_t vmin3 = vdupq_n_f32(FLT_MAX); - - // Find min/max for each channel - for (size_t j = 0; j < num_cols; j += 8) { - const float32x4_t src0_0 = vld1q_f32(src_ptr0 + 0 + j); - const float32x4_t src0_1 = vld1q_f32(src_ptr0 + 4 + j); - - const float32x4_t src1_0 = vld1q_f32(src_ptr1 + 0 + j); - const float32x4_t src1_1 = vld1q_f32(src_ptr1 + 4 + j); - - const float32x4_t src2_0 = vld1q_f32(src_ptr2 + 0 + j); - const float32x4_t src2_1 = vld1q_f32(src_ptr2 + 4 + j); - - const float32x4_t src3_0 = vld1q_f32(src_ptr3 + 0 + j); - const float32x4_t src3_1 = vld1q_f32(src_ptr3 + 4 + j); - - // Calculate the max - vmax0 = vmaxq_f32(src0_0, vmax0); - vmax1 = vmaxq_f32(src1_0, vmax1); - vmax2 = vmaxq_f32(src2_0, vmax2); - vmax3 = vmaxq_f32(src3_0, vmax3); - - vmax0 = vmaxq_f32(vmax0, src0_1); - vmax1 = vmaxq_f32(vmax1, src1_1); - vmax2 = vmaxq_f32(vmax2, src2_1); - vmax3 = vmaxq_f32(vmax3, src3_1); - - // Calculate the min - vmin0 = vminq_f32(src0_0, vmin0); - vmin1 = vminq_f32(src1_0, vmin1); - vmin2 = vminq_f32(src2_0, vmin2); - vmin3 = vminq_f32(src3_0, vmin3); - - vmin0 = vminq_f32(vmin0, src0_1); - vmin1 = vminq_f32(vmin1, src1_1); - vmin2 = vminq_f32(vmin2, src2_1); - vmin3 = vminq_f32(vmin3, src3_1); - } - - // Get the max/min from each row - const float max0 = vmaxvq_f32(vmax0); - const float max1 = vmaxvq_f32(vmax1); - const float max2 = vmaxvq_f32(vmax2); - const float max3 = vmaxvq_f32(vmax3); - const float min0 = vminvq_f32(vmin0); - const float min1 = vminvq_f32(vmin1); - const float min2 = vminvq_f32(vmin2); - const float min3 = vminvq_f32(vmin3); - - // Maximum/minimum int8 values - const float qmin = (float)INT8_MIN; - const float qmax = (float)INT8_MAX; - - const float rmin0 = KAI_MIN(0.0f, min0); - const float rmax0 = KAI_MAX(0.0f, max0); 
- const float rmin1 = KAI_MIN(0.0f, min1); - const float rmax1 = KAI_MAX(0.0f, max1); - const float rmin2 = KAI_MIN(0.0f, min2); - const float rmax2 = KAI_MAX(0.0f, max2); - const float rmin3 = KAI_MIN(0.0f, min3); - const float rmax3 = KAI_MAX(0.0f, max3); - - const float scale0 = rmin0 == rmax0 ? 1.f : (qmax - qmin) / (rmax0 - rmin0); - const float scale1 = rmin1 == rmax1 ? 1.f : (qmax - qmin) / (rmax1 - rmin1); - const float scale2 = rmin2 == rmax2 ? 1.f : (qmax - qmin) / (rmax2 - rmin2); - const float scale3 = rmin3 == rmax3 ? 1.f : (qmax - qmin) / (rmax3 - rmin3); - - // Reciprocal to quantize - const float recip_scale0 = scale0 ? 1.0f / scale0 : 0.0f; - const float recip_scale1 = scale1 ? 1.0f / scale1 : 0.0f; - const float recip_scale2 = scale2 ? 1.0f / scale2 : 0.0f; - const float recip_scale3 = scale3 ? 1.0f / scale3 : 0.0f; - - const float descaled_min0 = rmin0 * scale0; - const float descaled_min1 = rmin1 * scale1; - const float descaled_min2 = rmin2 * scale2; - const float descaled_min3 = rmin3 * scale3; - const float descaled_max0 = rmax0 * scale0; - const float descaled_max1 = rmax1 * scale1; - const float descaled_max2 = rmax2 * scale2; - const float descaled_max3 = rmax3 * scale3; - - const float zero_point_from_min_error0 = qmin + descaled_min0; - const float zero_point_from_max_error0 = qmax + descaled_max0; - const float zero_point_from_min_error1 = qmin + descaled_min1; - const float zero_point_from_max_error1 = qmax + descaled_max1; - const float zero_point_from_min_error2 = qmin + descaled_min2; - const float zero_point_from_max_error2 = qmax + descaled_max2; - const float zero_point_from_min_error3 = qmin + descaled_min3; - const float zero_point_from_max_error3 = qmax + descaled_max3; - - float zero_point0 = - zero_point_from_min_error0 + zero_point_from_max_error0 > 0 ? qmin - descaled_min0 : qmax - descaled_max0; - float zero_point1 = - zero_point_from_min_error1 + zero_point_from_max_error1 > 0 ? qmin - descaled_min1 : qmax - descaled_max1; - float zero_point2 = - zero_point_from_min_error2 + zero_point_from_max_error2 > 0 ? qmin - descaled_min2 : qmax - descaled_max2; - float zero_point3 = - zero_point_from_min_error3 + zero_point_from_max_error3 > 0 ? 
qmin - descaled_min3 : qmax - descaled_max3; - - zero_point0 = KAI_MAX(zero_point0, qmin); - zero_point0 = KAI_MIN(zero_point0, qmax); - zero_point1 = KAI_MAX(zero_point1, qmin); - zero_point1 = KAI_MIN(zero_point1, qmax); - zero_point2 = KAI_MAX(zero_point2, qmin); - zero_point2 = KAI_MIN(zero_point2, qmax); - zero_point3 = KAI_MAX(zero_point3, qmin); - zero_point3 = KAI_MIN(zero_point3, qmax); - - // Round to nearest integer - const int32_t nudged_zero_point0 = lrintf(zero_point0); - const int32_t nudged_zero_point1 = lrintf(zero_point1); - const int32_t nudged_zero_point2 = lrintf(zero_point2); - const int32_t nudged_zero_point3 = lrintf(zero_point3); - - // The LHS offsets are stored at the beginning of the row - int32x4_t voffsets_s32 = vdupq_n_s32(0.0f); - voffsets_s32 = vsetq_lane_s32(-nudged_zero_point0, voffsets_s32, 0); - voffsets_s32 = vsetq_lane_s32(-nudged_zero_point1, voffsets_s32, 1); - voffsets_s32 = vsetq_lane_s32(-nudged_zero_point2, voffsets_s32, 2); - voffsets_s32 = vsetq_lane_s32(-nudged_zero_point3, voffsets_s32, 3); - - vst1q_s32((int32_t*)dst_ptr, voffsets_s32); - dst_ptr += sizeof(int32x4_t); - - // Quantize the channels - for (size_t j = 0; j < num_cols; j += 8) { - const float32x4_t src0_0 = vld1q_f32(src_ptr0 + 0 + j); - const float32x4_t src0_1 = vld1q_f32(src_ptr0 + 4 + j); - - const float32x4_t src1_0 = vld1q_f32(src_ptr1 + 0 + j); - const float32x4_t src1_1 = vld1q_f32(src_ptr1 + 4 + j); - - const float32x4_t src2_0 = vld1q_f32(src_ptr2 + 0 + j); - const float32x4_t src2_1 = vld1q_f32(src_ptr2 + 4 + j); - - const float32x4_t src3_0 = vld1q_f32(src_ptr3 + 0 + j); - const float32x4_t src3_1 = vld1q_f32(src_ptr3 + 4 + j); - - float32x4_t v0_f32; - float32x4_t v1_f32; - int32x4_t v0_s32; - int32x4_t v1_s32; - - // Scale the values - v0_f32 = vmulq_n_f32(src0_0, scale0); - v1_f32 = vmulq_n_f32(src0_1, scale0); - v0_s32 = vcvtnq_s32_f32(v0_f32); - v1_s32 = vcvtnq_s32_f32(v1_f32); - v0_s32 = vaddq_s32(v0_s32, vdupq_n_s32(nudged_zero_point0)); - v1_s32 = vaddq_s32(v1_s32, vdupq_n_s32(nudged_zero_point0)); - v0_s32 = vmaxq_s32(v0_s32, vdupq_n_s32(INT8_MIN)); - v0_s32 = vminq_s32(v0_s32, vdupq_n_s32(INT8_MAX)); - v1_s32 = vmaxq_s32(v1_s32, vdupq_n_s32(INT8_MIN)); - v1_s32 = vminq_s32(v1_s32, vdupq_n_s32(INT8_MAX)); - *((int8_t*)(dst_ptr + 0)) = (int8_t)vgetq_lane_s32(v0_s32, 0); - *((int8_t*)(dst_ptr + 1)) = (int8_t)vgetq_lane_s32(v0_s32, 1); - *((int8_t*)(dst_ptr + 2)) = (int8_t)vgetq_lane_s32(v0_s32, 2); - *((int8_t*)(dst_ptr + 3)) = (int8_t)vgetq_lane_s32(v0_s32, 3); - *((int8_t*)(dst_ptr + 4)) = (int8_t)vgetq_lane_s32(v1_s32, 0); - *((int8_t*)(dst_ptr + 5)) = (int8_t)vgetq_lane_s32(v1_s32, 1); - *((int8_t*)(dst_ptr + 6)) = (int8_t)vgetq_lane_s32(v1_s32, 2); - *((int8_t*)(dst_ptr + 7)) = (int8_t)vgetq_lane_s32(v1_s32, 3); - dst_ptr += sizeof(int8_t) * 8; - - v0_f32 = vmulq_n_f32(src1_0, scale1); - v1_f32 = vmulq_n_f32(src1_1, scale1); - v0_s32 = vcvtnq_s32_f32(v0_f32); - v1_s32 = vcvtnq_s32_f32(v1_f32); - v0_s32 = vaddq_s32(v0_s32, vdupq_n_s32(nudged_zero_point1)); - v1_s32 = vaddq_s32(v1_s32, vdupq_n_s32(nudged_zero_point1)); - v0_s32 = vmaxq_s32(v0_s32, vdupq_n_s32(INT8_MIN)); - v0_s32 = vminq_s32(v0_s32, vdupq_n_s32(INT8_MAX)); - v1_s32 = vmaxq_s32(v1_s32, vdupq_n_s32(INT8_MIN)); - v1_s32 = vminq_s32(v1_s32, vdupq_n_s32(INT8_MAX)); - *((int8_t*)(dst_ptr + 0)) = (int8_t)vgetq_lane_s32(v0_s32, 0); - *((int8_t*)(dst_ptr + 1)) = (int8_t)vgetq_lane_s32(v0_s32, 1); - *((int8_t*)(dst_ptr + 2)) = (int8_t)vgetq_lane_s32(v0_s32, 2); - *((int8_t*)(dst_ptr + 3)) = 
(int8_t)vgetq_lane_s32(v0_s32, 3); - *((int8_t*)(dst_ptr + 4)) = (int8_t)vgetq_lane_s32(v1_s32, 0); - *((int8_t*)(dst_ptr + 5)) = (int8_t)vgetq_lane_s32(v1_s32, 1); - *((int8_t*)(dst_ptr + 6)) = (int8_t)vgetq_lane_s32(v1_s32, 2); - *((int8_t*)(dst_ptr + 7)) = (int8_t)vgetq_lane_s32(v1_s32, 3); - dst_ptr += sizeof(int8_t) * 8; - - v0_f32 = vmulq_n_f32(src2_0, scale2); - v1_f32 = vmulq_n_f32(src2_1, scale2); - v0_s32 = vcvtnq_s32_f32(v0_f32); - v1_s32 = vcvtnq_s32_f32(v1_f32); - v0_s32 = vaddq_s32(v0_s32, vdupq_n_s32(nudged_zero_point2)); - v1_s32 = vaddq_s32(v1_s32, vdupq_n_s32(nudged_zero_point2)); - v0_s32 = vmaxq_s32(v0_s32, vdupq_n_s32(INT8_MIN)); - v0_s32 = vminq_s32(v0_s32, vdupq_n_s32(INT8_MAX)); - v1_s32 = vmaxq_s32(v1_s32, vdupq_n_s32(INT8_MIN)); - v1_s32 = vminq_s32(v1_s32, vdupq_n_s32(INT8_MAX)); - *((int8_t*)(dst_ptr + 0)) = (int8_t)vgetq_lane_s32(v0_s32, 0); - *((int8_t*)(dst_ptr + 1)) = (int8_t)vgetq_lane_s32(v0_s32, 1); - *((int8_t*)(dst_ptr + 2)) = (int8_t)vgetq_lane_s32(v0_s32, 2); - *((int8_t*)(dst_ptr + 3)) = (int8_t)vgetq_lane_s32(v0_s32, 3); - *((int8_t*)(dst_ptr + 4)) = (int8_t)vgetq_lane_s32(v1_s32, 0); - *((int8_t*)(dst_ptr + 5)) = (int8_t)vgetq_lane_s32(v1_s32, 1); - *((int8_t*)(dst_ptr + 6)) = (int8_t)vgetq_lane_s32(v1_s32, 2); - *((int8_t*)(dst_ptr + 7)) = (int8_t)vgetq_lane_s32(v1_s32, 3); - dst_ptr += sizeof(int8_t) * 8; - - v0_f32 = vmulq_n_f32(src3_0, scale3); - v1_f32 = vmulq_n_f32(src3_1, scale3); - v0_s32 = vcvtnq_s32_f32(v0_f32); - v1_s32 = vcvtnq_s32_f32(v1_f32); - v0_s32 = vaddq_s32(v0_s32, vdupq_n_s32(nudged_zero_point3)); - v1_s32 = vaddq_s32(v1_s32, vdupq_n_s32(nudged_zero_point3)); - v0_s32 = vmaxq_s32(v0_s32, vdupq_n_s32(INT8_MIN)); - v0_s32 = vminq_s32(v0_s32, vdupq_n_s32(INT8_MAX)); - v1_s32 = vmaxq_s32(v1_s32, vdupq_n_s32(INT8_MIN)); - v1_s32 = vminq_s32(v1_s32, vdupq_n_s32(INT8_MAX)); - *((int8_t*)(dst_ptr + 0)) = (int8_t)vgetq_lane_s32(v0_s32, 0); - *((int8_t*)(dst_ptr + 1)) = (int8_t)vgetq_lane_s32(v0_s32, 1); - *((int8_t*)(dst_ptr + 2)) = (int8_t)vgetq_lane_s32(v0_s32, 2); - *((int8_t*)(dst_ptr + 3)) = (int8_t)vgetq_lane_s32(v0_s32, 3); - *((int8_t*)(dst_ptr + 4)) = (int8_t)vgetq_lane_s32(v1_s32, 0); - *((int8_t*)(dst_ptr + 5)) = (int8_t)vgetq_lane_s32(v1_s32, 1); - *((int8_t*)(dst_ptr + 6)) = (int8_t)vgetq_lane_s32(v1_s32, 2); - *((int8_t*)(dst_ptr + 7)) = (int8_t)vgetq_lane_s32(v1_s32, 3); - dst_ptr += sizeof(int8_t) * 8; - } - - // Store the scale quantization params - float32x4_t vscales_f32 = vdupq_n_f32(0.0f); - vscales_f32 = vsetq_lane_f32(recip_scale0, vscales_f32, 0); - vscales_f32 = vsetq_lane_f32(recip_scale1, vscales_f32, 1); - vscales_f32 = vsetq_lane_f32(recip_scale2, vscales_f32, 2); - vscales_f32 = vsetq_lane_f32(recip_scale3, vscales_f32, 3); - vst1q_f32((float*)dst_ptr, vscales_f32); - dst_ptr += sizeof(float32x4_t); - - src_ptr0 += kai_km0 * (lhs_stride / sizeof(float)); - src_ptr1 += kai_km0 * (lhs_stride / sizeof(float)); - src_ptr2 += kai_km0 * (lhs_stride / sizeof(float)); - src_ptr3 += kai_km0 * (lhs_stride / sizeof(float)); - } -} diff --git a/src/matmul/kai_lhs_quant_pack_qa8dxP4X8_f32.h b/src/matmul/kai_lhs_quant_pack_qa8dxP4X8_f32.h deleted file mode 100644 index 11bb931e..00000000 --- a/src/matmul/kai_lhs_quant_pack_qa8dxP4X8_f32.h +++ /dev/null @@ -1,65 +0,0 @@ -// -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates -// -// SPDX-License-Identifier: Apache-2.0 -// -#pragma once - -#ifdef __cplusplus -extern "C" { -#else -#include -#endif - -#include -#include - -/** - * @brief Function to 
calculate the offset in bytes for the LHS matrix (not packed) - * - * This function should be called before passing the pointer to the LHS matrix to the micro-kernel. - * - * @param[in] m_idx Row index in the LHS matrix (not packed). - * @param[in] lhs_stride The number of bytes in in each row of the LHS matrix (not packed) - * - * return the offset in bytes to the LHS matrix - */ -size_t kai_get_lhs_offset_lhs_quant_pack_qa8dxP4X8_f32(size_t m_idx, size_t lhs_stride); - -/** - * @brief Function to calculate the offset in bytes for the packed LHS matrix, - * which contains the packed 8-bit quantized asymmetric per-row (qa8dx) values. - * - * This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. - * - * @param[in] m_idx Row index in the LHS matrix (not packed). - * @param[in] k Total number of columns in the LHS matrix (not packed). - * - * return the offset in bytes to the packed LHS matrix - */ -size_t kai_get_lhs_packed_offset_lhs_quant_pack_qa8dxP4X8_f32(size_t m_idx, size_t k); - -/** - * @brief Function to return the memory required for storing the quantized and packed LHS matrix - * - * @param[in] m Total number of rows in the LHS matrix (not packed). - * @param[in] k Total number of columns in the LHS matrix (not packed). - * - * return the size in bytes to the packed LHS matrix - */ -size_t kai_get_lhs_packed_size_lhs_quant_pack_qa8dxP4X8_f32(size_t m, size_t k); - -/** - * @brief Micro-kernel to quantize and pack the LHS matrix - * - * @param[in] m The number of output rows written. It must be 1, 2, 3, 4, or any multiple of 4. - * @param[in] k The number of channels. The common dimension of LHS & RHS. It must be multiple of 8. - * @param[in] lhs LHS of the vector-by-matrix. - * @param[in] lhs_stride Stride in bytes between two rows of LHS. - * @param[in] lhs_p The quantized and packed LHS matrix. 
- */
-void kai_run_lhs_quant_pack_qa8dxP4X8_f32(size_t m, size_t k, const float* lhs, size_t lhs_stride, void* lhs_p);
-
-#ifdef __cplusplus
-}
-#endif
diff --git a/src/matmul/kai_lhs_quant_pack_qa8dxP_f32.c b/src/matmul/kai_lhs_quant_pack_qa8dxP_f32.c
new file mode 100644
index 00000000..3f68d8a4
--- /dev/null
+++ b/src/matmul/kai_lhs_quant_pack_qa8dxP_f32.c
@@ -0,0 +1,154 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#include "kai_lhs_quant_pack_qa8dxP_f32.h"
+
+#include "../kai_common.h"
+
+#include <arm_neon.h>
+#include <float.h>
+#include <math.h>
+#include <stdint.h>
+
+static const size_t kai_num_bytes_per_multiplier = sizeof(float);
+static const size_t kai_num_bytes_per_offset = sizeof(int32_t);
+
+size_t kai_get_lhs_offset_lhs_quant_pack_qa8dxP_f32(size_t m_idx, size_t lhs_stride) {
+    return m_idx * lhs_stride;
+}
+
+size_t kai_get_lhs_packed_offset_lhs_quant_pack_qa8dxP_f32(size_t m_idx, size_t k, size_t mr, size_t kr) {
+    const size_t dst_stride = mr * (k * sizeof(int8_t) + kai_num_bytes_per_multiplier + kai_num_bytes_per_offset);
+
+    return (m_idx / mr) * dst_stride;
+}
+
+size_t kai_get_lhs_packed_size_lhs_quant_pack_qa8dxP_f32(size_t m, size_t k, size_t mr, size_t kr) {
+    const size_t m_roundup = kai_roundup(m, mr);
+
+    const size_t t_size = m_roundup * (k * sizeof(int8_t) + kai_num_bytes_per_multiplier + kai_num_bytes_per_offset);
+    return t_size;
+}
+
+void kai_run_lhs_quant_pack_qa8dxP_f32(
+    size_t m, size_t k, size_t mr, size_t kr, size_t sr, const float* restrict lhs, size_t lhs_stride,
+    void* restrict lhs_p) {
+    KAI_ASSERT(k % kr == 0);
+    KAI_ASSERT((kr % sr) == 0);
+
+    if (m == 0) {
+        return;
+    }
+
+    const size_t num_rows = m;
+    const size_t num_cols = k;
+
+    const float* src_ptr = lhs;
+
+    const size_t dst_stride = mr * (k * sizeof(int8_t) + kai_num_bytes_per_multiplier + kai_num_bytes_per_offset);
+    const size_t k_block_len = kr / sr;
+
+    for (size_t row_idx = 0; row_idx < num_rows; ++row_idx) {
+        float32x4_t vmax0 = vdupq_n_f32(-FLT_MAX);
+        float32x4_t vmin0 = vdupq_n_f32(FLT_MAX);
+
+        // Find min/max for each channel
+        size_t k_idx = 0;
+        for (; k_idx + 8 <= num_cols; k_idx += 8) {
+            const float32x4_t src0_0 = vld1q_f32(src_ptr + 0 + k_idx);
+            const float32x4_t src0_1 = vld1q_f32(src_ptr + 4 + k_idx);
+
+            // Calculate the max
+            vmax0 = vmaxq_f32(src0_0, vmax0);
+            vmax0 = vmaxq_f32(vmax0, src0_1);
+
+            // Calculate the min
+            vmin0 = vminq_f32(src0_0, vmin0);
+            vmin0 = vminq_f32(vmin0, src0_1);
+        }
+
+        for (; k_idx < num_cols; ++k_idx) {
+            const float src0_0 = *(src_ptr + k_idx);
+
+            // Calculate the max
+            vmax0 = vsetq_lane_f32(KAI_MAX(src0_0, vgetq_lane_f32(vmax0, 0)), vmax0, 0);
+            // Calculate the min
+            vmin0 = vsetq_lane_f32(KAI_MIN(src0_0, vgetq_lane_f32(vmin0, 0)), vmin0, 0);
+        }
+
+        // Get the max/min
+        const float max0 = vmaxvq_f32(vmax0);
+        const float min0 = vminvq_f32(vmin0);
+
+        // Maximum/minimum int8 values
+        const float qmin = (float)INT8_MIN;
+        const float qmax = (float)INT8_MAX;
+
+        const float rmin0 = KAI_MIN(0.0f, min0);
+        const float rmax0 = KAI_MAX(0.0f, max0);
+
+        const float scale0 = rmin0 == rmax0 ? 1.f : (qmax - qmin) / (rmax0 - rmin0);
+
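+        // The reciprocal (dequantization) scale is what gets stored in the packed buffer:
+        // the matmul micro-kernels multiply the int32 accumulators by it to return to f32.
+        // Reciprocal to quantize
+        const float recip_scale0 = scale0 ? 1.0f / scale0 : 0.0f;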
+
+        const float descaled_min0 = rmin0 * scale0;
+        const float descaled_max0 = rmax0 * scale0;
+
+        const float zero_point_from_min_error0 = qmin + descaled_min0;
+        const float zero_point_from_max_error0 = qmax + descaled_max0;
+
+        float zero_point0 =
+            zero_point_from_min_error0 + zero_point_from_max_error0 > 0 ? qmin - descaled_min0 : qmax - descaled_max0;
+
+        zero_point0 = KAI_MAX(zero_point0, qmin);
+        zero_point0 = KAI_MIN(zero_point0, qmax);
+
+        // Round to nearest integer
+        const int32_t nudged_zero_point0 = lrintf(zero_point0);
+
+        const size_t dst_x = (row_idx % mr);
+        const size_t dst_y = (row_idx / mr);
+
+        uint8_t* dst_ptr = (uint8_t*)lhs_p + dst_y * dst_stride;
+
+        dst_ptr += dst_x * k_block_len * sizeof(int8_t);
+
+        // Quantize the channels
+        k_idx = 0;
+        for (; k_idx < num_cols; k_idx += k_block_len) {
+            for (size_t k_block_idx = 0; k_block_idx < k_block_len; ++k_block_idx) {
+                const float src0_0 = *(src_ptr + k_idx + k_block_idx);
+
+                // Scale the values
+                int32_t v0_s32 = (int32_t)(round(src0_0 * scale0));
+
+                v0_s32 = v0_s32 + nudged_zero_point0;
+                v0_s32 = KAI_MAX(v0_s32, INT8_MIN);
+                v0_s32 = KAI_MIN(v0_s32, INT8_MAX);
+                *((int8_t*)(dst_ptr)) = (int8_t)v0_s32;
+                dst_ptr += sizeof(int8_t);
+            }
+            dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t);
+        }
+
+        dst_ptr = (uint8_t*)lhs_p + dst_y * dst_stride + mr * (k * sizeof(int8_t));
+
+        dst_ptr += dst_x * kai_num_bytes_per_offset;
+
+        // Store the LHS offset; in this layout, the offsets are packed after the quantized values
+        *((int32_t*)(dst_ptr)) = -nudged_zero_point0;
+
+        // Assuming the same sizeof() for kai_num_bytes_per_offset and kai_num_bytes_per_multiplier
+        KAI_ASSERT(kai_num_bytes_per_offset == kai_num_bytes_per_multiplier);
+
+        dst_ptr += mr * kai_num_bytes_per_offset;
+
+        // Store the scale quantization params
+        *((float*)(dst_ptr)) = recip_scale0;
+
+        src_ptr += (lhs_stride / sizeof(float));
+    }
+}
diff --git a/src/matmul/kai_lhs_quant_pack_qa8dxP1X8_f32.h b/src/matmul/kai_lhs_quant_pack_qa8dxP_f32.h
similarity index 81%
rename from src/matmul/kai_lhs_quant_pack_qa8dxP1X8_f32.h
rename to src/matmul/kai_lhs_quant_pack_qa8dxP_f32.h
index 3c494877..d27b3773 100644
--- a/src/matmul/kai_lhs_quant_pack_qa8dxP1X8_f32.h
+++ b/src/matmul/kai_lhs_quant_pack_qa8dxP_f32.h
@@ -24,7 +24,7 @@ extern "C" {
  *
  * return the offset in bytes to the LHS matrix
  */
-size_t kai_get_lhs_offset_lhs_quant_pack_qa8dxP1X8_f32(size_t m_idx, size_t lhs_stride);
+size_t kai_get_lhs_offset_lhs_quant_pack_qa8dxP_f32(size_t m_idx, size_t lhs_stride);
 
 /**
  * @brief Function to calculate the offset in bytes for the packed LHS matrix,
@@ -37,7 +37,7 @@ size_t kai_get_lhs_offset_lhs_quant_pack_qa8dxP1X8_f32(size_t m_idx, size_t lhs_
  *
  * return the offset in bytes to the packed LHS matrix
  */
-size_t kai_get_lhs_packed_offset_lhs_quant_pack_qa8dxP1X8_f32(size_t m_idx, size_t k);
+size_t kai_get_lhs_packed_offset_lhs_quant_pack_qa8dxP_f32(size_t m_idx, size_t k, size_t mr, size_t kr);
 
 /**
  * @brief Function to return the memory required for storing the quantized and packed LHS matrix
@@ -47,7 +47,7 @@ size_t kai_get_lhs_packed_offset_lhs_quant_pack_qa8dxP1X8_f32(size_t m_idx, size
  *
  * return the size in bytes to the packed LHS matrix
  */
-size_t kai_get_lhs_packed_size_lhs_quant_pack_qa8dxP1X8_f32(size_t m, size_t k);
+size_t kai_get_lhs_packed_size_lhs_quant_pack_qa8dxP_f32(size_t m, size_t k, size_t mr, size_t kr);
 
 /**
  * @brief Micro-kernel to quantize and pack the LHS matrix
@@ -58,7 +58,8 @@ size_t kai_get_lhs_packed_size_lhs_quant_pack_qa8dxP1X8_f32(size_t m, size_t k);
 *
@param[in] lhs_stride Stride in bytes between two rows of LHS. * @param[out] lhs_p The quantized and packed LHS matrix. */ -void kai_run_lhs_quant_pack_qa8dxP1X8_f32(size_t m, size_t k, const float* lhs, size_t lhs_stride, void* lhs_p); +void kai_run_lhs_quant_pack_qa8dxP_f32( + size_t m, size_t k, size_t mr, size_t kr, size_t sr, const float* lhs, size_t lhs_stride, void* lhs_p); #ifdef __cplusplus } diff --git a/src/matmul/kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0.c b/src/matmul/kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0.c index 115e4df0..eb42ca11 100644 --- a/src/matmul/kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0.c +++ b/src/matmul/kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0.c @@ -81,15 +81,11 @@ void kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0( const uint8_t* src_row = (const uint8_t*)rhs + y * rhs_stride; uint8_t* dst_row = (uint8_t*)rhs_p + (y / nr) * rhs_packed_stride; - int32_t* sums = (int32_t*)dst_row; + int32_t* sums = (int32_t*)(dst_row + nr * (k / 2)); - // The RHS reduction sums are stored at the beginning of the row // Initialize to zero the RHS reduction sums memset(sums, 0, nr * sizeof(int32_t)); - // Move the pointer after the biases - dst_row += nr * sizeof(int32_t); - for (size_t x = 0; x < k; x += (kr * 2)) { for (size_t s = 0; s < sr; ++s) { for (size_t i = 0; i < nr; ++i) { @@ -145,14 +141,16 @@ void kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0( } } - // Adjust the scale + // Adjust the reduction sums for (size_t i = 0; i < nr; ++i) { - ((float*)(dst_row))[i] = scale[y + i] * 0.0625f; + *((int32_t*)(dst_row)) = sums[i] * 16; + dst_row += sizeof(int32_t); } - // Adjust the scale + // Adjust the scales for (size_t i = 0; i < nr; ++i) { - sums[i] = sums[i] * 16; + *((float*)(dst_row)) = scale[y + i] * 0.0625f; + dst_row += sizeof(float); } } } diff --git a/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod.c b/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod.c index d4c61c74..257fea30 100644 --- a/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod.c +++ b/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod.c @@ -20,6 +20,10 @@ static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); +size_t kai_get_mr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod(void) { + return kai_mr; +} + size_t kai_get_nr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod(void) { return kai_nr; } @@ -57,7 +61,7 @@ size_t kai_get_dst_offset_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotpro size_t kai_get_dst_size_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod(size_t m, size_t n) { KAI_ASSERT((n % kai_nr) == 0); - return m * n; + return m * n * sizeof(float); } void kai_run_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod( @@ -82,19 +86,15 @@ void kai_run_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod( const uint8_t* lhs_ptr_start = lhs_p; for (size_t row_idx = 0; row_idx < num_rows; row_idx += kai_mr) { + const uint8_t* rhs_ptr = rhs_p; for (size_t col_idx = 0; col_idx < num_cols; col_idx += kai_nr) { - // Main f32 accumulator - int32x4_t iacc0011 = vdupq_n_s32(0); - int32x4_t iacc2233 = vdupq_n_s32(0); const uint8_t* lhs_ptr = lhs_ptr_start; - const uint8_t* rhs_ptr_start = rhs_ptr; - // LHS offset is stored at the beginning of the row - lhs_ptr += sizeof(int32_t); - // The 
RHS reduction sums are stored at the beginning of the row
-            rhs_ptr += sizeof(int32x4_t);
+            // Main integer accumulators (converted to f32 at the end)
+            int32x4_t iacc0011 = vdupq_n_s32(0);
+            int32x4_t iacc2233 = vdupq_n_s32(0);
 
             for (size_t b = 0; b < k; b += kai_k0) {
                 // Set up RHS
@@ -166,19 +166,21 @@ void kai_run_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod(
 
             int32x4_t iacc = vpaddq_s32(iacc0011, iacc2233);
 
-            // RHS scale
-            const float32x4_t rhs_scale = vld1q_f32((const float*)rhs_ptr);
-            rhs_ptr += sizeof(float32x4_t);
+            // LHS offset
+            const int32x4_t lhs_offset = vld1q_dup_s32((const int32_t*)lhs_ptr);
+            lhs_ptr += sizeof(int32_t);
 
             // LHS scale
             const float32x4_t lhs_scale = vld1q_dup_f32((const float*)lhs_ptr);
             lhs_ptr += sizeof(float);
 
-            // LHS offset
-            const int32x4_t lhs_offset = vld1q_dup_s32((const int32_t*)lhs_ptr_start);
-
             // RHS sum values
-            const int32x4_t sum_n_s32 = vld1q_s32((const int32_t*)(rhs_ptr_start));
+            const int32x4_t sum_n_s32 = vld1q_s32((const int32_t*)(rhs_ptr));
+            rhs_ptr += sizeof(int32x4_t);
+
+            // RHS scale
+            const float32x4_t rhs_scale = vld1q_f32((const float*)rhs_ptr);
+            rhs_ptr += sizeof(float32x4_t);
 
             // Add the reduction sum
             iacc = vmlaq_s32(iacc, sum_n_s32, lhs_offset);
diff --git a/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod.h b/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod.h
index 878beb25..d8524a2b 100644
--- a/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod.h
+++ b/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod.h
@@ -14,6 +14,14 @@ extern "C" {
 #include <stddef.h>
 #include <stdint.h>
 
+/**
+ * @brief Function to get the mr value, which must be used to pack the LHS matrix with
+ *        the @ref kai_run_lhs_quant_pack_qa8dxP_f32 function
+ *
+ * @return the mr value
+ */
+size_t kai_get_mr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod(void);
+
 /**
  * @brief Function to get the nr value, which must be used to pack the RHS matrix with
  *        the @ref kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 function
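The mr/kr/sr getters added above exist so that integration code never hard-codes the packing parameters of a particular variant. A minimal usage sketch for the LHS side follows; the wrapper name is illustrative, the include paths depend on the build setup, and error handling is omitted.

```c
#include <stddef.h>
#include <stdlib.h>

// Headers from this patch; adjust the paths to your include directories.
#include "kai_lhs_quant_pack_qa8dxP_f32.h"
#include "kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod.h"

// Quantize and pack the f32 activations as expected by the 1x4x64 dotprod variant.
static void* pack_lhs_for_1x4x64(size_t m, size_t k, const float* lhs, size_t lhs_stride) {
    const size_t mr = kai_get_mr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod();
    const size_t kr = kai_get_kr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod();
    const size_t sr = kai_get_sr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod();

    void* lhs_p = malloc(kai_get_lhs_packed_size_lhs_quant_pack_qa8dxP_f32(m, k, mr, kr));
    if (lhs_p != NULL) {
        kai_run_lhs_quant_pack_qa8dxP_f32(m, k, mr, kr, sr, lhs, lhs_stride, lhs_p);
    }
    return lhs_p;
}
```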
diff --git a/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod.c b/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod.c
new file mode 100644
index 00000000..01b4c377
--- /dev/null
+++ b/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod.c
@@ -0,0 +1,223 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#include "kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod.h"
+
+#include "../../kai_common.h"
+
+#include <arm_neon.h>
+#include <stddef.h>
+
+static const size_t kai_mr = 1;
+static const size_t kai_nr = 8;
+static const size_t kai_kr = 16;
+static const size_t kai_sr = 2;
+static const size_t kai_k0 = 32;
+static const size_t kai_num_bytes_multiplier_lhs = sizeof(float);
+static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);
+static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t);
+static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t);
+
+size_t kai_get_mr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod(void) {
+    return kai_mr;
+}
+
+size_t kai_get_nr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod(void) {
+    return kai_nr;
+}
+
+size_t kai_get_kr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod(void) {
+    return kai_kr;
+}
+
+size_t kai_get_sr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod(void) {
+    return kai_sr;
+}
+
+size_t kai_get_lhs_packed_offset_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod(size_t m_idx, size_t k) {
+    const size_t lhs_packed_stride = kai_mr * (k + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs);
+
+    return (m_idx / kai_mr) * lhs_packed_stride;
+}
+
+size_t kai_get_rhs_packed_offset_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod(size_t n_idx, size_t k) {
+    KAI_ASSERT((k % kai_k0) == 0);
+    KAI_ASSERT((n_idx % kai_nr) == 0);
+
+    const size_t rhs_packed_stride = kai_nr * ((k / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs);
+
+    return (n_idx / kai_nr) * rhs_packed_stride;
+}
+
+size_t kai_get_dst_offset_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod(
+    size_t m_idx, size_t n_idx, size_t dst_stride) {
+    KAI_ASSERT((n_idx % kai_nr) == 0);
+
+    return (n_idx * sizeof(float)) + m_idx * dst_stride;
+}
+
+size_t kai_get_dst_size_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod(size_t m, size_t n) {
+    KAI_ASSERT((n % kai_nr) == 0);
+
+    return m * n * sizeof(float);
+}
+
+void kai_run_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod(
+    size_t m, size_t n, size_t k, const void* restrict lhs_p, const void* restrict rhs_p, float* restrict dst,
+    size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) {
+#if defined(__ARM_FEATURE_DOTPROD)
+    KAI_ASSERT(n % kai_nr == 0);
+    KAI_ASSERT(k % kai_k0 == 0);
+    KAI_ASSERT(dst_stride_col == sizeof(float));
+
+    if (m == 0) {
+        return;
+    }
+
+    const size_t num_rows = m;
+    const size_t num_cols = n;
+
+    const size_t lhs_packed_stride = kai_mr * (k + sizeof(float) + sizeof(float));
+
+    const int8x16_t nibble_mask = vdupq_n_s8(0xF0);
+
+    const uint8_t* lhs_ptr_start = lhs_p;
+
+    for (size_t row_idx = 0; row_idx < num_rows; row_idx += kai_mr) {
+        const uint8_t* rhs_ptr = rhs_p;
+        for (size_t col_idx = 0; col_idx < num_cols; col_idx += kai_nr) {
+
+            const uint8_t* lhs_ptr = lhs_ptr_start;
+
+            // Main integer accumulators (converted to f32 at the end)
+            int32x4_t iacc0011 = vdupq_n_s32(0);
+            int32x4_t iacc2233 = vdupq_n_s32(0);
+            int32x4_t iacc4455 = vdupq_n_s32(0);
+            int32x4_t iacc6677 = vdupq_n_s32(0);
+
+            for (size_t b = 0; b < k; b += kai_k0) {
+                // Set up RHS
+                const int8x16_t rhs_raw_vec_0 = vld1q_s8((const int8_t*)(rhs_ptr + 0));
+                const int8x16_t rhs_raw_vec_1 = vld1q_s8((const int8_t*)(rhs_ptr + 16));
+                const int8x16_t rhs_raw_vec_2 = vld1q_s8((const int8_t*)(rhs_ptr + 32));
+                const int8x16_t rhs_raw_vec_3 = vld1q_s8((const int8_t*)(rhs_ptr + 48));
+                const int8x16_t rhs_raw_vec_4 = vld1q_s8((const int8_t*)(rhs_ptr + 64));
+                const int8x16_t rhs_raw_vec_5 = vld1q_s8((const int8_t*)(rhs_ptr + 80));
+                const int8x16_t rhs_raw_vec_6 = vld1q_s8((const int8_t*)(rhs_ptr + 96));
+                const int8x16_t rhs_raw_vec_7 = vld1q_s8((const int8_t*)(rhs_ptr + 112));
+
+                // Low nibble
+                const int8x16_t rhs_vec_0_0 = vshlq_n_s8(rhs_raw_vec_0, 4);
+                const int8x16_t rhs_vec_1_0 = vshlq_n_s8(rhs_raw_vec_1, 4);
+                const int8x16_t rhs_vec_2_0 = vshlq_n_s8(rhs_raw_vec_2, 4);
+                const int8x16_t rhs_vec_3_0 = vshlq_n_s8(rhs_raw_vec_3, 4);
+                const int8x16_t rhs_vec_4_0 = vshlq_n_s8(rhs_raw_vec_4, 4);
+                const int8x16_t rhs_vec_5_0 = vshlq_n_s8(rhs_raw_vec_5, 4);
+                const int8x16_t rhs_vec_6_0 = vshlq_n_s8(rhs_raw_vec_6, 4);
+                const int8x16_t rhs_vec_7_0 = vshlq_n_s8(rhs_raw_vec_7, 4);
+
+                // High nibble
+                const int8x16_t rhs_vec_0_1 = vandq_s8(rhs_raw_vec_0, nibble_mask);
+                const int8x16_t rhs_vec_1_1 = vandq_s8(rhs_raw_vec_1, 
nibble_mask); + const int8x16_t rhs_vec_2_1 = vandq_s8(rhs_raw_vec_2, nibble_mask); + const int8x16_t rhs_vec_3_1 = vandq_s8(rhs_raw_vec_3, nibble_mask); + const int8x16_t rhs_vec_4_1 = vandq_s8(rhs_raw_vec_4, nibble_mask); + const int8x16_t rhs_vec_5_1 = vandq_s8(rhs_raw_vec_5, nibble_mask); + const int8x16_t rhs_vec_6_1 = vandq_s8(rhs_raw_vec_6, nibble_mask); + const int8x16_t rhs_vec_7_1 = vandq_s8(rhs_raw_vec_7, nibble_mask); + + const int8x16_t lhs_vec_0 = vld1q_s8((const int8_t*)(lhs_ptr + 0)); + const int8x16_t lhs_vec_1 = vld1q_s8((const int8_t*)(lhs_ptr + 16)); + + lhs_ptr += 32; + rhs_ptr += 128; + + int8x16_t t; + + t = vcombine_s8(vget_low_s8(lhs_vec_0), vget_low_s8(lhs_vec_0)); + iacc0011 = vdotq_s32(iacc0011, rhs_vec_0_0, t); + iacc2233 = vdotq_s32(iacc2233, rhs_vec_1_0, t); + iacc4455 = vdotq_s32(iacc4455, rhs_vec_2_0, t); + iacc6677 = vdotq_s32(iacc6677, rhs_vec_3_0, t); + t = vcombine_s8(vget_high_s8(lhs_vec_0), vget_high_s8(lhs_vec_0)); + iacc0011 = vdotq_s32(iacc0011, rhs_vec_4_0, t); + iacc2233 = vdotq_s32(iacc2233, rhs_vec_5_0, t); + iacc4455 = vdotq_s32(iacc4455, rhs_vec_6_0, t); + iacc6677 = vdotq_s32(iacc6677, rhs_vec_7_0, t); + + t = vcombine_s8(vget_low_s8(lhs_vec_1), vget_low_s8(lhs_vec_1)); + iacc0011 = vdotq_s32(iacc0011, rhs_vec_0_1, t); + iacc2233 = vdotq_s32(iacc2233, rhs_vec_1_1, t); + iacc4455 = vdotq_s32(iacc4455, rhs_vec_2_1, t); + iacc6677 = vdotq_s32(iacc6677, rhs_vec_3_1, t); + t = vcombine_s8(vget_high_s8(lhs_vec_1), vget_high_s8(lhs_vec_1)); + iacc0011 = vdotq_s32(iacc0011, rhs_vec_4_1, t); + iacc2233 = vdotq_s32(iacc2233, rhs_vec_5_1, t); + iacc4455 = vdotq_s32(iacc4455, rhs_vec_6_1, t); + iacc6677 = vdotq_s32(iacc6677, rhs_vec_7_1, t); + } + + int32x4_t iacc0 = vpaddq_s32(iacc0011, iacc2233); + int32x4_t iacc1 = vpaddq_s32(iacc4455, iacc6677); + + // LHS offset + const int32x4_t lhs_offset = vld1q_dup_s32((const int32_t*)lhs_ptr); + lhs_ptr += sizeof(int32_t); + + // LHS scale + const float32x4_t lhs_scale = vld1q_dup_f32((const float*)lhs_ptr); + lhs_ptr += sizeof(float); + + // RHS sum values + const int32x4_t sum_n_s32_0 = vld1q_s32((const int32_t*)(rhs_ptr)); + rhs_ptr += sizeof(int32x4_t); + const int32x4_t sum_n_s32_1 = vld1q_s32((const int32_t*)(rhs_ptr)); + rhs_ptr += sizeof(int32x4_t); + + // RHS scale + const float32x4_t rhs_scale0 = vld1q_f32((const float*)rhs_ptr); + rhs_ptr += sizeof(float32x4_t); + const float32x4_t rhs_scale1 = vld1q_f32((const float*)rhs_ptr); + rhs_ptr += sizeof(float32x4_t); + + // Add the reduction sum + iacc0 = vmlaq_s32(iacc0, sum_n_s32_0, lhs_offset); + iacc1 = vmlaq_s32(iacc1, sum_n_s32_1, lhs_offset); + + float32x4_t main_acc0 = vmulq_f32(vcvtq_f32_s32(iacc0), rhs_scale0); + float32x4_t main_acc1 = vmulq_f32(vcvtq_f32_s32(iacc1), rhs_scale1); + + main_acc0 = vmulq_f32(main_acc0, lhs_scale); + main_acc1 = vmulq_f32(main_acc1, lhs_scale); + + // Clip (min-max) operation + const float32x4_t vmin_f32 = vdupq_n_f32(scalar_min); + const float32x4_t vmax_f32 = vdupq_n_f32(scalar_max); + + main_acc0 = vmaxq_f32(main_acc0, vmin_f32); + main_acc0 = vminq_f32(main_acc0, vmax_f32); + + main_acc1 = vmaxq_f32(main_acc1, vmin_f32); + main_acc1 = vminq_f32(main_acc1, vmax_f32); + + vst1q_f32((float*)((uint8_t*)dst + (col_idx + 0) * sizeof(float) + row_idx * dst_stride_row), main_acc0); + vst1q_f32((float*)((uint8_t*)dst + (col_idx + 4) * sizeof(float) + row_idx * dst_stride_row), main_acc1); + } + lhs_ptr_start += lhs_packed_stride; + } +#else + KAI_ASSERT(false); + KAI_UNUSED(m); + KAI_UNUSED(n); + KAI_UNUSED(k); + 
KAI_UNUSED(lhs_p);
+    KAI_UNUSED(rhs_p);
+    KAI_UNUSED(dst);
+    KAI_UNUSED(dst_stride_row);
+    KAI_UNUSED(dst_stride_col);
+    KAI_UNUSED(scalar_min);
+    KAI_UNUSED(scalar_max);
+#endif
+}
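Note that when `__ARM_FEATURE_DOTPROD` is not defined at compile time, the function above compiles to the `KAI_ASSERT(false)` stub, so callers are expected to gate dispatch on a runtime CPU-feature check. A Linux/AArch64-only sketch follows; the helper name is illustrative, and `HWCAP_ASIMDDP` comes from the Linux UAPI headers.

```c
#include <stdbool.h>
#include <sys/auxv.h>

#ifndef HWCAP_ASIMDDP
#define HWCAP_ASIMDDP (1 << 20)  // dotprod bit in AT_HWCAP on AArch64 Linux
#endif

// Returns true if the CPU reports the Armv8.2-A dotprod extension.
static bool cpu_has_dotprod(void) {
    return (getauxval(AT_HWCAP) & HWCAP_ASIMDDP) != 0;
}
```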
diff --git a/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod.h b/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod.h
new file mode 100644
index 00000000..4a2c3e43
--- /dev/null
+++ b/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod.h
@@ -0,0 +1,123 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#pragma once
+
+#ifdef __cplusplus
+extern "C" {
+#else
+#include <stdbool.h>
+#endif
+
+#include <stddef.h>
+#include <stdint.h>
+
+/**
+ * @brief Function to get the mr value, which must be used to pack the LHS matrix with
+ *        the @ref kai_run_lhs_quant_pack_qa8dxP_f32 function
+ *
+ * @return the mr value
+ */
+size_t kai_get_mr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod(void);
+
+/**
+ * @brief Function to get the nr value, which must be used to pack the RHS matrix with
+ *        the @ref kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 function
+ *
+ * @return the nr value
+ */
+size_t kai_get_nr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod(void);
+
+/**
+ * @brief Function to get the kr value, which must be used to pack the RHS matrix with
+ *        the @ref kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 function
+ *
+ * @return the kr value
+ */
+size_t kai_get_kr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod(void);
+
+/**
+ * @brief Function to get the sr value, which must be used to pack the RHS matrix with
+ *        the @ref kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 function
+ *
+ * @return the sr value
+ */
+size_t kai_get_sr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod(void);
+
+/**
+ * @brief Function to calculate the offset in bytes for the packed LHS matrix,
+ *        which contains the packed 8-bit quantized asymmetric per-row (qa8dx) values.
+ *
+ * This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel.
+ *
+ * @param[in] m_idx Row index in the LHS matrix (not packed).
+ * @param[in] k     Total number of columns in the LHS matrix (not packed). It must be a multiple of 32.
+ *
+ * return the offset in bytes to the packed LHS matrix
+ */
+size_t kai_get_lhs_packed_offset_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod(size_t m_idx, size_t k);
+
+/**
+ * @brief Function to calculate the offset in bytes for the packed RHS matrix,
+ *        which contains the packed 4-bit quantized symmetric per-channel (qs4cx) values.
+ *
+ * @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of 8.
+ * @param[in] k     The common dimension between the LHS and RHS matrix (K). It must be a multiple of 32.
+ *
+ * return the offset in bytes to the packed RHS matrix
+ */
+size_t kai_get_rhs_packed_offset_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod(size_t n_idx, size_t k);
+
+/**
+ * @brief Function to calculate the offset in bytes for the DST matrix
+ *
+ * @param[in] m_idx      Row index in the DST matrix.
+ * @param[in] n_idx      Column index in the DST matrix. It must be a multiple of 8.
+ * @param[in] dst_stride The number of bytes in each row of the DST matrix
+ *
+ * return the DST offset in bytes
+ */
+size_t kai_get_dst_offset_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod(
+    size_t m_idx, size_t n_idx, size_t dst_stride);
+
+/**
+ * @brief Function to query the size in bytes of the destination (DST) matrix.
+ *
+ * @param[in] m Number of rows in the destination (DST) matrix
+ * @param[in] n Number of columns in the destination (DST) matrix
+ */
+size_t kai_get_dst_size_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod(size_t m, size_t n);
+
+/**
+ * @brief Function to run the matrix multiplication (matmul) micro-kernel followed by a clip (min-max) operation.
+ *
+ * LHS matrix: 8-bit quantized asymmetric per-row (qa8dx) and packed
+ * RHS matrix: 4-bit quantized symmetric per-channel (qs4cx) and packed.
+ * Output tile: (rows x cols) = 1 x 8
+ * Accumulation performed in a single for loop: 32
+ * Instruction used: dotprod
+ *
+ * @param[in] m The number of output rows written.
+ * @param[in] n The number of output columns written. It must be a multiple of 8.
+ * @param[in] k The number of channels. The common dimension of LHS & RHS. It must be a multiple of 32.
+ * @param[in] lhs_p The LHS matrix packed.
+ *                  When the activations are dynamically quantized, you can obtain this matrix
+ *                  by calling the @ref kai_run_lhs_quant_pack_qa8dxP_f32 micro-kernel which performs
+ *                  both the dynamic quantization to 8-bit and activation packing in a single step.
+ * @param[in] rhs_p The RHS matrix packed, which is obtained by calling @ref
+ *                  kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0
+ * @param[out] dst Result of the vector-by-matrix multiplication
+ * @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix.
+ * @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(float)
+ * @param[in] scalar_min Min value used to clip the final result.
+ * @param[in] scalar_max Max value used to clip the final result. 
+ */ +void kai_run_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod( + size_t m, size_t n, size_t k, const void* lhs_p, const void* rhs_p, float* dst, size_t dst_stride_row, + size_t dst_stride_col, float scalar_min, float scalar_max); + +#ifdef __cplusplus +} +#endif diff --git a/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm.c b/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm.c index a3a81282..b91f0441 100644 --- a/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm.c +++ b/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm.c @@ -20,6 +20,10 @@ static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); +size_t kai_get_mr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm(void) { + return kai_mr; +} + size_t kai_get_nr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm(void) { return kai_nr; } @@ -33,6 +37,7 @@ size_t kai_get_sr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm(void) { } size_t kai_get_lhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm(size_t m_idx, size_t k) { + KAI_ASSERT((m_idx % kai_mr) == 0); const size_t lhs_packed_stride = kai_mr * (k + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); return (m_idx / kai_mr) * lhs_packed_stride; @@ -49,6 +54,7 @@ size_t kai_get_rhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon size_t kai_get_dst_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm(size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSERT((m_idx % kai_mr) == 0); KAI_ASSERT((n_idx % kai_nr) == 0); return (n_idx * sizeof(float)) + m_idx * dst_stride; @@ -57,14 +63,13 @@ kai_get_dst_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm(size_t m size_t kai_get_dst_size_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm(size_t m, size_t n) { KAI_ASSERT((n % kai_nr) == 0); - return m * n; + return m * n * sizeof(float); } void kai_run_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm( - size_t m, size_t n, size_t k, const void* restrict lhs_p, const void* restrict rhs_p, float* restrict dst, - size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { + size_t m, size_t n, size_t k, const void* lhs_p, const void* rhs_p, float* dst, size_t dst_stride_row, + size_t dst_stride_col, float scalar_min, float scalar_max) { #if defined(__ARM_FEATURE_MATMUL_INT8) - KAI_ASSERT(m <= 3 || (m % kai_mr == 0)); KAI_ASSERT(n % kai_nr == 0); KAI_ASSERT(k % kai_k0 == 0); KAI_ASSERT(dst_stride_col == sizeof(float)); @@ -73,157 +78,133 @@ void kai_run_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm( return; } - const size_t num_rows = m; - const size_t num_cols = n; - - const size_t lhs_packed_stride = kai_mr * (k + sizeof(float32_t) + sizeof(float32_t)); - - const int8x16_t nibble_mask = vdupq_n_s8(0xF0); - - const uint8_t* lhs_ptr_start = lhs_p; - - for (size_t row_idx = 0; row_idx < num_rows; row_idx += kai_mr) { - const uint8_t* rhs_ptr = rhs_p; - - for (size_t col_idx = 0; col_idx < num_cols; col_idx += kai_nr) { - const uint8_t* lhs_ptr = lhs_ptr_start; - const uint8_t* rhs_ptr_start = rhs_ptr; - - // Main f32 accumulator - int32x4_t iacc_mat_00 = vdupq_n_s32(0); - int32x4_t iacc_mat_01 = vdupq_n_s32(0); - int32x4_t iacc_mat_10 = vdupq_n_s32(0); - int32x4_t 
iacc_mat_11 = vdupq_n_s32(0); - - // LHS offset is stored at the beginning of the row - lhs_ptr += sizeof(int32x4_t); - // The RHS reduction sums are stored at the beginning of the row - rhs_ptr += sizeof(int32x4_t); - - for (size_t b = 0; b < k; b += kai_k0) { - // Set up RHS - const int8x16_t rhs_raw_mat_01_0 = vld1q_s8((const int8_t*)rhs_ptr + 0); - const int8x16_t rhs_raw_mat_23_0 = vld1q_s8((const int8_t*)rhs_ptr + 16); - const int8x16_t rhs_raw_mat_01_1 = vld1q_s8((const int8_t*)rhs_ptr + 32); - const int8x16_t rhs_raw_mat_23_1 = vld1q_s8((const int8_t*)rhs_ptr + 48); - - // Low nibble - const int8x16_t rhs_mat_01_0 = vshlq_n_s8(rhs_raw_mat_01_0, 4); - const int8x16_t rhs_mat_23_0 = vshlq_n_s8(rhs_raw_mat_23_0, 4); - const int8x16_t rhs_mat_01_1 = vshlq_n_s8(rhs_raw_mat_01_1, 4); - const int8x16_t rhs_mat_23_1 = vshlq_n_s8(rhs_raw_mat_23_1, 4); - - // High nibble - const int8x16_t rhs_mat_01_2 = vandq_s8(rhs_raw_mat_01_0, nibble_mask); - const int8x16_t rhs_mat_23_2 = vandq_s8(rhs_raw_mat_23_0, nibble_mask); - const int8x16_t rhs_mat_01_3 = vandq_s8(rhs_raw_mat_01_1, nibble_mask); - const int8x16_t rhs_mat_23_3 = vandq_s8(rhs_raw_mat_23_1, nibble_mask); - - // Process LHS in pairs of rows - const int8x16_t lhs_mat_01_0 = vld1q_s8((const int8_t*)lhs_ptr + 0); - const int8x16_t lhs_mat_23_0 = vld1q_s8((const int8_t*)lhs_ptr + 16); - const int8x16_t lhs_mat_01_1 = vld1q_s8((const int8_t*)lhs_ptr + 32); - const int8x16_t lhs_mat_23_1 = vld1q_s8((const int8_t*)lhs_ptr + 48); - const int8x16_t lhs_mat_01_2 = vld1q_s8((const int8_t*)lhs_ptr + 64); - const int8x16_t lhs_mat_23_2 = vld1q_s8((const int8_t*)lhs_ptr + 80); - const int8x16_t lhs_mat_01_3 = vld1q_s8((const int8_t*)lhs_ptr + 96); - const int8x16_t lhs_mat_23_3 = vld1q_s8((const int8_t*)lhs_ptr + 112); - - // Do the MMLAs into 2x2 matrices - iacc_mat_00 = vmmlaq_s32( - vmmlaq_s32( - vmmlaq_s32(vmmlaq_s32(iacc_mat_00, lhs_mat_01_0, rhs_mat_01_0), lhs_mat_01_1, rhs_mat_01_1), - lhs_mat_01_2, rhs_mat_01_2), - lhs_mat_01_3, rhs_mat_01_3); - iacc_mat_01 = vmmlaq_s32( - vmmlaq_s32( - vmmlaq_s32(vmmlaq_s32(iacc_mat_01, lhs_mat_01_0, rhs_mat_23_0), lhs_mat_01_1, rhs_mat_23_1), - lhs_mat_01_2, rhs_mat_23_2), - lhs_mat_01_3, rhs_mat_23_3); - iacc_mat_10 = vmmlaq_s32( - vmmlaq_s32( - vmmlaq_s32(vmmlaq_s32(iacc_mat_10, lhs_mat_23_0, rhs_mat_01_0), lhs_mat_23_1, rhs_mat_01_1), - lhs_mat_23_2, rhs_mat_01_2), - lhs_mat_23_3, rhs_mat_01_3); - iacc_mat_11 = vmmlaq_s32( - vmmlaq_s32( - vmmlaq_s32(vmmlaq_s32(iacc_mat_11, lhs_mat_23_0, rhs_mat_23_0), lhs_mat_23_1, rhs_mat_23_1), - lhs_mat_23_2, rhs_mat_23_2), - lhs_mat_23_3, rhs_mat_23_3); - - // Straighten out to make 4 row vectors - lhs_ptr += 128; - rhs_ptr += 64; - } - - // RHS scale - const float32x4_t rhs_scale = vld1q_f32((const float32_t*)rhs_ptr); - rhs_ptr += sizeof(float32x4_t); - - // LHS scale - const float32x4_t lhs_scale = vld1q_f32((const float32_t*)lhs_ptr); - lhs_ptr += sizeof(float32x4_t); - - int32x4_t iacc_row_0 = vreinterpretq_s32_u64( - vtrn1q_u64(vreinterpretq_u64_s32(iacc_mat_00), vreinterpretq_u64_s32(iacc_mat_01))); - int32x4_t iacc_row_1 = vreinterpretq_s32_u64( - vtrn2q_u64(vreinterpretq_u64_s32(iacc_mat_00), vreinterpretq_u64_s32(iacc_mat_01))); - int32x4_t iacc_row_2 = vreinterpretq_s32_u64( - vtrn1q_u64(vreinterpretq_u64_s32(iacc_mat_10), vreinterpretq_u64_s32(iacc_mat_11))); - int32x4_t iacc_row_3 = vreinterpretq_s32_u64( - vtrn2q_u64(vreinterpretq_u64_s32(iacc_mat_10), vreinterpretq_u64_s32(iacc_mat_11))); - - // LHS offset - const int32x4_t lhs_offset = 
vld1q_s32((const int32_t*)lhs_ptr_start); - - // RHS sum values - const int32x4_t sum_n_s32 = vld1q_s32((const int32_t*)(rhs_ptr_start)); - - // Add the RHS reduction sum - iacc_row_0 = vmlaq_laneq_s32(iacc_row_0, sum_n_s32, lhs_offset, 0); - iacc_row_1 = vmlaq_laneq_s32(iacc_row_1, sum_n_s32, lhs_offset, 1); - iacc_row_2 = vmlaq_laneq_s32(iacc_row_2, sum_n_s32, lhs_offset, 2); - iacc_row_3 = vmlaq_laneq_s32(iacc_row_3, sum_n_s32, lhs_offset, 3); - - float32x4_t main_acc0 = vmulq_f32(vcvtq_f32_s32(iacc_row_0), rhs_scale); - float32x4_t main_acc1 = vmulq_f32(vcvtq_f32_s32(iacc_row_1), rhs_scale); - float32x4_t main_acc2 = vmulq_f32(vcvtq_f32_s32(iacc_row_2), rhs_scale); - float32x4_t main_acc3 = vmulq_f32(vcvtq_f32_s32(iacc_row_3), rhs_scale); - - main_acc0 = vmulq_laneq_f32(main_acc0, lhs_scale, 0); - main_acc1 = vmulq_laneq_f32(main_acc1, lhs_scale, 1); - main_acc2 = vmulq_laneq_f32(main_acc2, lhs_scale, 2); - main_acc3 = vmulq_laneq_f32(main_acc3, lhs_scale, 3); - - // Clip (min-max) operation - const float32x4_t vmin_f32 = vdupq_n_f32(scalar_min); - const float32x4_t vmax_f32 = vdupq_n_f32(scalar_max); - - main_acc0 = vmaxq_f32(main_acc0, vmin_f32); - main_acc0 = vminq_f32(main_acc0, vmax_f32); - main_acc1 = vmaxq_f32(main_acc1, vmin_f32); - main_acc1 = vminq_f32(main_acc1, vmax_f32); - main_acc2 = vmaxq_f32(main_acc2, vmin_f32); - main_acc2 = vminq_f32(main_acc2, vmax_f32); - - // Stores the rows in reverse order to avoid out-of-bound writes. - // Override out-of-bound values with in-bound values - vst1q_f32( - (float*)((uint8_t*)dst + col_idx * sizeof(float) + KAI_MIN(row_idx + 3, m - 1) * dst_stride_row), - main_acc3); - vst1q_f32( - (float*)((uint8_t*)dst + col_idx * sizeof(float) + KAI_MIN(row_idx + 2, m - 1) * dst_stride_row), - main_acc2); - vst1q_f32( - (float*)((uint8_t*)dst + col_idx * sizeof(float) + KAI_MIN(row_idx + 1, m - 1) * dst_stride_row), - main_acc1); - vst1q_f32( - (float*)((uint8_t*)dst + col_idx * sizeof(float) + KAI_MIN(row_idx + 0, m - 1) * dst_stride_row), - main_acc0); - } - - lhs_ptr_start += lhs_packed_stride; - } + size_t num_blocks = k / 32; + + float clamp_vals[2] = {scalar_min, scalar_max}; + + __asm__ __volatile__("mov x26, #0x80\n" + "mov x20, #0x20\n" + "movi v4.16b, #0xf0\n" + "mov x25, %x[m]\n" + "madd x26, %x[num_blocks], x26, x20\n" + "cbz x25, 5f\n" + "1:" // Row loop + "mov x24, %x[rhs_p]\n" + "mov x23, %x[n]\n" + "add x22, %x[dst], %x[dst_stride_row], LSL #2\n" + "2:" // Column loop + "movi v3.4s, #0x0\n" + "movi v2.4s, #0x0\n" + "mov x21, %x[lhs_p]\n" + "mov x20, %x[num_blocks]\n" + "movi v1.4s, #0x0\n" + "movi v0.4s, #0x0\n" + "3:" // Block loop + "ldr q31, [x24, #0x0]\n" + "ldr q30, [x24, #0x10]\n" + "subs x20, x20, #0x1\n" + "ldr q29, [x21, #0x0]\n" + "ldr q28, [x21, #0x10]\n" + "ldr q27, [x24, #0x20]\n" + "ldr q26, [x24, #0x30]\n" + "add x24, x24, #0x40\n" + "ldr q25, [x21, #0x20]\n" + "ldr q24, [x21, #0x30]\n" + "shl v23.16b, v31.16b, #0x4\n" + "shl v22.16b, v30.16b, #0x4\n" + "ldr q21, [x21, #0x40]\n" + "ldr q20, [x21, #0x50]\n" + "and v31.16b, v31.16b, v4.16b\n" + "and v30.16b, v30.16b, v4.16b\n" + "ldr q19, [x21, #0x60]\n" + "ldr q18, [x21, #0x70]\n" + "shl v17.16b, v27.16b, #0x4\n" + "shl v16.16b, v26.16b, #0x4\n" + ".inst 0x4e97a7a3 // smmla v3.4s, v29.16b, v23.16b\n" + ".inst 0x4e96a7a2 // smmla v2.4s, v29.16b, v22.16b\n" + "and v27.16b, v27.16b, v4.16b\n" + "add x21, x21, #0x80\n" + ".inst 0x4e97a781 // smmla v1.4s, v28.16b, v23.16b\n" + ".inst 0x4e96a780 // smmla v0.4s, v28.16b, v22.16b\n" + "and v26.16b, v26.16b, v4.16b\n" + 
".inst 0x4e91a723 // smmla v3.4s, v25.16b, v17.16b\n" + ".inst 0x4e90a722 // smmla v2.4s, v25.16b, v16.16b\n" + ".inst 0x4e91a701 // smmla v1.4s, v24.16b, v17.16b\n" + ".inst 0x4e90a700 // smmla v0.4s, v24.16b, v16.16b\n" + ".inst 0x4e9fa6a3 // smmla v3.4s, v21.16b, v31.16b\n" + ".inst 0x4e9ea6a2 // smmla v2.4s, v21.16b, v30.16b\n" + ".inst 0x4e9fa681 // smmla v1.4s, v20.16b, v31.16b\n" + ".inst 0x4e9ea680 // smmla v0.4s, v20.16b, v30.16b\n" + ".inst 0x4e9ba663 // smmla v3.4s, v19.16b, v27.16b\n" + ".inst 0x4e9aa662 // smmla v2.4s, v19.16b, v26.16b\n" + ".inst 0x4e9ba641 // smmla v1.4s, v18.16b, v27.16b\n" + ".inst 0x4e9aa640 // smmla v0.4s, v18.16b, v26.16b\n" + "bgt 3b\n" + "ldr q18, [x24, #0x0]\n" + "ldr q17, [x21, #0x0]\n" + "uzp1 v26.2d, v3.2d, v2.2d\n" + "uzp2 v25.2d, v3.2d, v2.2d\n" + "ldr q24, [x24, #0x10]\n" + "ldr q16, [x21, #0x10]\n" + "uzp1 v23.2d, v1.2d, v0.2d\n" + "uzp2 v22.2d, v1.2d, v0.2d\n" + "ld1r { v21.4s }, [%x[clamp_vals]]\n" + "add x21, %x[clamp_vals], #0x4\n" + "mov x20, %x[dst]\n" + "ld1r { v20.4s }, [x21]\n" + "mla v26.4s, v18.4s, v17.s[0]\n" + "mla v25.4s, v18.4s, v17.s[1]\n" + "cmp x25, #0x1\n" + "add x24, x24, #0x20\n" + "mla v23.4s, v18.4s, v17.s[2]\n" + "mla v22.4s, v18.4s, v17.s[3]\n" + "fmul v19.4s, v24.4s, v16.s[0]\n" + "fmul v18.4s, v24.4s, v16.s[1]\n" + "fmul v17.4s, v24.4s, v16.s[2]\n" + "fmul v16.4s, v24.4s, v16.s[3]\n" + "scvtf v26.4s, v26.4s\n" + "scvtf v25.4s, v25.4s\n" + "scvtf v23.4s, v23.4s\n" + "scvtf v22.4s, v22.4s\n" + "fmul v26.4s, v26.4s, v19.4s\n" + "fmul v25.4s, v25.4s, v18.4s\n" + "fmul v23.4s, v23.4s, v17.4s\n" + "fmul v22.4s, v22.4s, v16.4s\n" + "fmax v26.4s, v26.4s, v21.4s\n" + "fmax v25.4s, v25.4s, v21.4s\n" + "fmax v23.4s, v23.4s, v21.4s\n" + "fmax v22.4s, v22.4s, v21.4s\n" + "fmin v26.4s, v26.4s, v20.4s\n" + "fmin v25.4s, v25.4s, v20.4s\n" + "fmin v23.4s, v23.4s, v20.4s\n" + "fmin v22.4s, v22.4s, v20.4s\n" + "str q26, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 4f\n" + "cmp x25, #0x2\n" + "str q25, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 4f\n" + "cmp x25, #0x3\n" + "str q23, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 4f\n" + "str q22, [x20, #0x0]\n" + "4:" // Accumulator store skip + "subs x23, x23, #0x4\n" + "add %x[dst], %x[dst], #0x10\n" + "bne 2b\n" + "subs x25, x25, #0x4\n" + "add %x[lhs_p], %x[lhs_p], x26\n" + "mov %x[dst], x22\n" + "bgt 1b\n" + "5:" // Row loop skip + : [lhs_p] "+&r"(lhs_p), [dst] "+&r"(dst) + : [rhs_p] "r"(rhs_p), [clamp_vals] "r"(clamp_vals), [m] "r"(m), [num_blocks] "r"(num_blocks), + [dst_stride_row] "r"(dst_stride_row), [n] "r"(n) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v16", "v17", "v18", "v19", "v20", "v21", + "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", + "x23", "x24", "x25", "x26"); #else KAI_ASSERT(false); KAI_UNUSED(m); diff --git a/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm.h b/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm.h index b89a2a46..b4f965fb 100644 --- a/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm.h +++ b/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm.h @@ -14,6 +14,14 @@ extern "C" { #include #include +/** + * @brief Function to get the mr value, which must be used to pack the LHS matrix with + * the @ref kai_run_lhs_quant_pack_qa8dsP_f32 function + * + * @return the mr 
value
+ */
+size_t kai_get_mr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm(void);
+
 /**
  * @brief Function to get the nr value, which must be used to pack the RHS matrix with
  *        the @ref kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 function
@@ -88,15 +96,15 @@ size_t kai_get_dst_size_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm(siz
  * LHS matrix: 8-bit quantized asymmetric per-row (qa8dx) and packed
  * RHS matrix: 4-bit quantized symmetric per-channel (qs4cx) and packed.
  * Output tile: (rows x cols) = 4 x 4
- * Accumulation performed in a single for loop: 64
+ * Accumulation performed in a single for loop: 32
  * Instruction used: i8mm
  *
  * @param[in] m The number of output rows written. It must be either 1, 2, 3, or any multiple of 4.
  * @param[in] n The number of output columns written. It must be a multiple of 4.
- * @param[in] k The number of channels. The common dimension of LHS & RHS. It must be multiple of 64.
+ * @param[in] k The number of channels. The common dimension of LHS & RHS. It must be multiple of 32.
  * @param[in] lhs_p The LHS matrix packed.
  *                  When the activations are dynamically quantized, you can obtain this matrix
- *                  by calling the @ref kai_lhs_quant_pack_qa8dxP1X8_f32 micro-kernel which performs
+ *                  by calling the @ref kai_run_lhs_quant_pack_qa8dxP_f32 micro-kernel which performs
  *                  both the dynamic quantization to 8-bit and activation packing in a single step.
  * @param[in] rhs_p The RHS matrix packed, which is obtained by calling @ref
  *                  kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0
diff --git a/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm.c b/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm.c
new file mode 100644
index 00000000..664e4b88
--- /dev/null
+++ b/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm.c
@@ -0,0 +1,330 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#include "kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm.h"
+
+#include "../../kai_common.h"
+
+#include <arm_neon.h>
+#include <stddef.h>
+
+static const size_t kai_mr = 4;
+static const size_t kai_nr = 8;
+static const size_t kai_kr = 16;
+static const size_t kai_sr = 2;
+static const size_t kai_k0 = 32;
+static const size_t kai_num_bytes_multiplier_lhs = sizeof(float);
+static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);
+static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t);
+static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t);
+
+size_t kai_get_mr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm(void) {
+    return kai_mr;
+}
+
+size_t kai_get_nr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm(void) {
+    return kai_nr;
+}
+
+size_t kai_get_kr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm(void) {
+    return kai_kr;
+}
+
+size_t kai_get_sr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm(void) {
+    return kai_sr;
+}
+
+size_t kai_get_lhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm(size_t m_idx, size_t k) {
+    KAI_ASSERT((m_idx % kai_mr) == 0);
+    const size_t lhs_packed_stride = kai_mr * (k + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs);
+
+    return (m_idx / kai_mr) * lhs_packed_stride;
+}
+
+size_t kai_get_rhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm(size_t n_idx, size_t k) {
+    KAI_ASSERT((k % kai_k0) == 0);
+    KAI_ASSERT((n_idx % kai_nr) == 0);
+
+    const size_t rhs_packed_stride = kai_nr 
* ((k / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs); + + return (n_idx / kai_nr) * rhs_packed_stride; +} + +size_t +kai_get_dst_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm(size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSERT((m_idx % kai_mr) == 0); + KAI_ASSERT((n_idx % kai_nr) == 0); + + return (n_idx * sizeof(float)) + m_idx * dst_stride; +} + +size_t kai_get_dst_size_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm(size_t m, size_t n) { + KAI_ASSERT((n % kai_nr) == 0); + + return m * n * sizeof(float); +} + +void kai_run_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm( + size_t m, size_t n, size_t k, const void* restrict lhs_p, const void* restrict rhs_p, float* restrict dst, + size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { +#if defined(__ARM_FEATURE_MATMUL_INT8) + KAI_ASSERT(n % kai_nr == 0); + KAI_ASSERT(k % kai_k0 == 0); + KAI_ASSERT(dst_stride_col == sizeof(float)); + + if (m == 0) { + return; + } + + const size_t num_rows = m; + const size_t num_cols = n; + + const size_t lhs_packed_stride = kai_mr * (k + sizeof(float32_t) + sizeof(float32_t)); + + const int8x16_t nibble_mask = vdupq_n_s8(0xF0); + + const uint8_t* lhs_ptr_start = lhs_p; + + for (size_t row_idx = 0; row_idx < num_rows; row_idx += kai_mr) { + const uint8_t* rhs_ptr = rhs_p; + + for (size_t col_idx = 0; col_idx < num_cols; col_idx += kai_nr) { + const uint8_t* lhs_ptr = lhs_ptr_start; + + // Main f32 accumulator + int32x4_t iacc_mat_00 = vdupq_n_s32(0); + int32x4_t iacc_mat_01 = vdupq_n_s32(0); + int32x4_t iacc_mat_10 = vdupq_n_s32(0); + int32x4_t iacc_mat_11 = vdupq_n_s32(0); + + int32x4_t iacc_mat_02 = vdupq_n_s32(0); + int32x4_t iacc_mat_03 = vdupq_n_s32(0); + int32x4_t iacc_mat_12 = vdupq_n_s32(0); + int32x4_t iacc_mat_13 = vdupq_n_s32(0); + + for (size_t b = 0; b < k; b += kai_k0) { + // Set up RHS + const int8x16_t rhs_raw_mat_01_0 = vld1q_s8((const int8_t*)rhs_ptr + 0); + const int8x16_t rhs_raw_mat_23_0 = vld1q_s8((const int8_t*)rhs_ptr + 16); + const int8x16_t rhs_raw_mat_45_0 = vld1q_s8((const int8_t*)rhs_ptr + 32); + const int8x16_t rhs_raw_mat_67_0 = vld1q_s8((const int8_t*)rhs_ptr + 48); + const int8x16_t rhs_raw_mat_01_1 = vld1q_s8((const int8_t*)rhs_ptr + 64); + const int8x16_t rhs_raw_mat_23_1 = vld1q_s8((const int8_t*)rhs_ptr + 80); + const int8x16_t rhs_raw_mat_45_1 = vld1q_s8((const int8_t*)rhs_ptr + 96); + const int8x16_t rhs_raw_mat_67_1 = vld1q_s8((const int8_t*)rhs_ptr + 112); + + // Low nibble + const int8x16_t rhs_mat_01_0 = vshlq_n_s8(rhs_raw_mat_01_0, 4); + const int8x16_t rhs_mat_23_0 = vshlq_n_s8(rhs_raw_mat_23_0, 4); + const int8x16_t rhs_mat_45_0 = vshlq_n_s8(rhs_raw_mat_45_0, 4); + const int8x16_t rhs_mat_67_0 = vshlq_n_s8(rhs_raw_mat_67_0, 4); + + const int8x16_t rhs_mat_01_1 = vshlq_n_s8(rhs_raw_mat_01_1, 4); + const int8x16_t rhs_mat_23_1 = vshlq_n_s8(rhs_raw_mat_23_1, 4); + const int8x16_t rhs_mat_45_1 = vshlq_n_s8(rhs_raw_mat_45_1, 4); + const int8x16_t rhs_mat_67_1 = vshlq_n_s8(rhs_raw_mat_67_1, 4); + + // High nibble + const int8x16_t rhs_mat_01_2 = vandq_s8(rhs_raw_mat_01_0, nibble_mask); + const int8x16_t rhs_mat_23_2 = vandq_s8(rhs_raw_mat_23_0, nibble_mask); + const int8x16_t rhs_mat_45_2 = vandq_s8(rhs_raw_mat_45_0, nibble_mask); + const int8x16_t rhs_mat_67_2 = vandq_s8(rhs_raw_mat_67_0, nibble_mask); + + const int8x16_t rhs_mat_01_3 = vandq_s8(rhs_raw_mat_01_1, nibble_mask); + const int8x16_t rhs_mat_23_3 = vandq_s8(rhs_raw_mat_23_1, nibble_mask); + const int8x16_t rhs_mat_45_3 = 
vandq_s8(rhs_raw_mat_45_1, nibble_mask); + const int8x16_t rhs_mat_67_3 = vandq_s8(rhs_raw_mat_67_1, nibble_mask); + + // Process LHS in pairs of rows + const int8x16_t lhs_mat_01_0 = vld1q_s8((const int8_t*)lhs_ptr + 0); + const int8x16_t lhs_mat_23_0 = vld1q_s8((const int8_t*)lhs_ptr + 16); + const int8x16_t lhs_mat_01_1 = vld1q_s8((const int8_t*)lhs_ptr + 32); + const int8x16_t lhs_mat_23_1 = vld1q_s8((const int8_t*)lhs_ptr + 48); + const int8x16_t lhs_mat_01_2 = vld1q_s8((const int8_t*)lhs_ptr + 64); + const int8x16_t lhs_mat_23_2 = vld1q_s8((const int8_t*)lhs_ptr + 80); + const int8x16_t lhs_mat_01_3 = vld1q_s8((const int8_t*)lhs_ptr + 96); + const int8x16_t lhs_mat_23_3 = vld1q_s8((const int8_t*)lhs_ptr + 112); + + // Do the MMLAs into 2x2 matrices + iacc_mat_00 = vmmlaq_s32( + vmmlaq_s32( + vmmlaq_s32(vmmlaq_s32(iacc_mat_00, lhs_mat_01_0, rhs_mat_01_0), lhs_mat_01_1, rhs_mat_01_1), + lhs_mat_01_2, rhs_mat_01_2), + lhs_mat_01_3, rhs_mat_01_3); + iacc_mat_01 = vmmlaq_s32( + vmmlaq_s32( + vmmlaq_s32(vmmlaq_s32(iacc_mat_01, lhs_mat_01_0, rhs_mat_23_0), lhs_mat_01_1, rhs_mat_23_1), + lhs_mat_01_2, rhs_mat_23_2), + lhs_mat_01_3, rhs_mat_23_3); + iacc_mat_10 = vmmlaq_s32( + vmmlaq_s32( + vmmlaq_s32(vmmlaq_s32(iacc_mat_10, lhs_mat_23_0, rhs_mat_01_0), lhs_mat_23_1, rhs_mat_01_1), + lhs_mat_23_2, rhs_mat_01_2), + lhs_mat_23_3, rhs_mat_01_3); + iacc_mat_11 = vmmlaq_s32( + vmmlaq_s32( + vmmlaq_s32(vmmlaq_s32(iacc_mat_11, lhs_mat_23_0, rhs_mat_23_0), lhs_mat_23_1, rhs_mat_23_1), + lhs_mat_23_2, rhs_mat_23_2), + lhs_mat_23_3, rhs_mat_23_3); + + /// + + iacc_mat_02 = vmmlaq_s32( + vmmlaq_s32( + vmmlaq_s32(vmmlaq_s32(iacc_mat_02, lhs_mat_01_0, rhs_mat_45_0), lhs_mat_01_1, rhs_mat_45_1), + lhs_mat_01_2, rhs_mat_45_2), + lhs_mat_01_3, rhs_mat_45_3); + iacc_mat_03 = vmmlaq_s32( + vmmlaq_s32( + vmmlaq_s32(vmmlaq_s32(iacc_mat_03, lhs_mat_01_0, rhs_mat_67_0), lhs_mat_01_1, rhs_mat_67_1), + lhs_mat_01_2, rhs_mat_67_2), + lhs_mat_01_3, rhs_mat_67_3); + iacc_mat_12 = vmmlaq_s32( + vmmlaq_s32( + vmmlaq_s32(vmmlaq_s32(iacc_mat_12, lhs_mat_23_0, rhs_mat_45_0), lhs_mat_23_1, rhs_mat_45_1), + lhs_mat_23_2, rhs_mat_45_2), + lhs_mat_23_3, rhs_mat_45_3); + iacc_mat_13 = vmmlaq_s32( + vmmlaq_s32( + vmmlaq_s32(vmmlaq_s32(iacc_mat_13, lhs_mat_23_0, rhs_mat_67_0), lhs_mat_23_1, rhs_mat_67_1), + lhs_mat_23_2, rhs_mat_67_2), + lhs_mat_23_3, rhs_mat_67_3); + + // Straighten out to make 4 row vectors + lhs_ptr += 128; + rhs_ptr += 128; + } + + int32x4_t iacc_row_0_0123 = vreinterpretq_s32_u64( + vtrn1q_u64(vreinterpretq_u64_s32(iacc_mat_00), vreinterpretq_u64_s32(iacc_mat_01))); + int32x4_t iacc_row_1_0123 = vreinterpretq_s32_u64( + vtrn2q_u64(vreinterpretq_u64_s32(iacc_mat_00), vreinterpretq_u64_s32(iacc_mat_01))); + int32x4_t iacc_row_2_0123 = vreinterpretq_s32_u64( + vtrn1q_u64(vreinterpretq_u64_s32(iacc_mat_10), vreinterpretq_u64_s32(iacc_mat_11))); + int32x4_t iacc_row_3_0123 = vreinterpretq_s32_u64( + vtrn2q_u64(vreinterpretq_u64_s32(iacc_mat_10), vreinterpretq_u64_s32(iacc_mat_11))); + + int32x4_t iacc_row_0_4567 = vreinterpretq_s32_u64( + vtrn1q_u64(vreinterpretq_u64_s32(iacc_mat_02), vreinterpretq_u64_s32(iacc_mat_03))); + int32x4_t iacc_row_1_4567 = vreinterpretq_s32_u64( + vtrn2q_u64(vreinterpretq_u64_s32(iacc_mat_02), vreinterpretq_u64_s32(iacc_mat_03))); + int32x4_t iacc_row_2_4567 = vreinterpretq_s32_u64( + vtrn1q_u64(vreinterpretq_u64_s32(iacc_mat_12), vreinterpretq_u64_s32(iacc_mat_13))); + int32x4_t iacc_row_3_4567 = vreinterpretq_s32_u64( + vtrn2q_u64(vreinterpretq_u64_s32(iacc_mat_12), 
vreinterpretq_u64_s32(iacc_mat_13))); + + // LHS offset + const int32x4_t lhs_offset = vld1q_s32((const int32_t*)lhs_ptr); + lhs_ptr += sizeof(int32x4_t); + + // LHS scale + const float32x4_t lhs_scale = vld1q_f32((const float32_t*)lhs_ptr); + lhs_ptr += sizeof(float32x4_t); + + // RHS sum values + const int32x4_t sum_n_s32_0 = vld1q_s32((const int32_t*)(rhs_ptr)); + rhs_ptr += sizeof(int32x4_t); + const int32x4_t sum_n_s32_1 = vld1q_s32((const int32_t*)(rhs_ptr)); + rhs_ptr += sizeof(int32x4_t); + + // RHS scale + const float32x4_t rhs_scale0 = vld1q_f32((const float*)rhs_ptr); + rhs_ptr += sizeof(float32x4_t); + const float32x4_t rhs_scale1 = vld1q_f32((const float*)rhs_ptr); + rhs_ptr += sizeof(float32x4_t); + + // Add the RHS reduction sum + iacc_row_0_0123 = vmlaq_laneq_s32(iacc_row_0_0123, sum_n_s32_0, lhs_offset, 0); + iacc_row_1_0123 = vmlaq_laneq_s32(iacc_row_1_0123, sum_n_s32_0, lhs_offset, 1); + iacc_row_2_0123 = vmlaq_laneq_s32(iacc_row_2_0123, sum_n_s32_0, lhs_offset, 2); + iacc_row_3_0123 = vmlaq_laneq_s32(iacc_row_3_0123, sum_n_s32_0, lhs_offset, 3); + + iacc_row_0_4567 = vmlaq_laneq_s32(iacc_row_0_4567, sum_n_s32_1, lhs_offset, 0); + iacc_row_1_4567 = vmlaq_laneq_s32(iacc_row_1_4567, sum_n_s32_1, lhs_offset, 1); + iacc_row_2_4567 = vmlaq_laneq_s32(iacc_row_2_4567, sum_n_s32_1, lhs_offset, 2); + iacc_row_3_4567 = vmlaq_laneq_s32(iacc_row_3_4567, sum_n_s32_1, lhs_offset, 3); + + float32x4_t main_acc0_0123 = vmulq_f32(vcvtq_f32_s32(iacc_row_0_0123), rhs_scale0); + float32x4_t main_acc1_0123 = vmulq_f32(vcvtq_f32_s32(iacc_row_1_0123), rhs_scale0); + float32x4_t main_acc2_0123 = vmulq_f32(vcvtq_f32_s32(iacc_row_2_0123), rhs_scale0); + float32x4_t main_acc3_0123 = vmulq_f32(vcvtq_f32_s32(iacc_row_3_0123), rhs_scale0); + + float32x4_t main_acc0_4567 = vmulq_f32(vcvtq_f32_s32(iacc_row_0_4567), rhs_scale1); + float32x4_t main_acc1_4567 = vmulq_f32(vcvtq_f32_s32(iacc_row_1_4567), rhs_scale1); + float32x4_t main_acc2_4567 = vmulq_f32(vcvtq_f32_s32(iacc_row_2_4567), rhs_scale1); + float32x4_t main_acc3_4567 = vmulq_f32(vcvtq_f32_s32(iacc_row_3_4567), rhs_scale1); + + main_acc0_0123 = vmulq_laneq_f32(main_acc0_0123, lhs_scale, 0); + main_acc1_0123 = vmulq_laneq_f32(main_acc1_0123, lhs_scale, 1); + main_acc2_0123 = vmulq_laneq_f32(main_acc2_0123, lhs_scale, 2); + main_acc3_0123 = vmulq_laneq_f32(main_acc3_0123, lhs_scale, 3); + + main_acc0_4567 = vmulq_laneq_f32(main_acc0_4567, lhs_scale, 0); + main_acc1_4567 = vmulq_laneq_f32(main_acc1_4567, lhs_scale, 1); + main_acc2_4567 = vmulq_laneq_f32(main_acc2_4567, lhs_scale, 2); + main_acc3_4567 = vmulq_laneq_f32(main_acc3_4567, lhs_scale, 3); + + // Clip (min-max) operation + const float32x4_t vmin_f32 = vdupq_n_f32(scalar_min); + const float32x4_t vmax_f32 = vdupq_n_f32(scalar_max); + + main_acc0_0123 = vmaxq_f32(main_acc0_0123, vmin_f32); + main_acc0_0123 = vminq_f32(main_acc0_0123, vmax_f32); + main_acc1_0123 = vmaxq_f32(main_acc1_0123, vmin_f32); + main_acc1_0123 = vminq_f32(main_acc1_0123, vmax_f32); + main_acc2_0123 = vmaxq_f32(main_acc2_0123, vmin_f32); + main_acc2_0123 = vminq_f32(main_acc2_0123, vmax_f32); + + main_acc0_4567 = vmaxq_f32(main_acc0_4567, vmin_f32); + main_acc0_4567 = vminq_f32(main_acc0_4567, vmax_f32); + main_acc1_4567 = vmaxq_f32(main_acc1_4567, vmin_f32); + main_acc1_4567 = vminq_f32(main_acc1_4567, vmax_f32); + main_acc2_4567 = vmaxq_f32(main_acc2_4567, vmin_f32); + main_acc2_4567 = vminq_f32(main_acc2_4567, vmax_f32); + + // Stores the rows in reverse order to avoid out-of-bound writes. 
+ // Override out-of-bound values with in-bound values + vst1q_f32( + (float*)((uint8_t*)dst + (col_idx + 0) * sizeof(float) + KAI_MIN(row_idx + 3, m - 1) * dst_stride_row), + main_acc3_0123); + vst1q_f32( + (float*)((uint8_t*)dst + (col_idx + 4) * sizeof(float) + KAI_MIN(row_idx + 3, m - 1) * dst_stride_row), + main_acc3_4567); + vst1q_f32( + (float*)((uint8_t*)dst + (col_idx + 0) * sizeof(float) + KAI_MIN(row_idx + 2, m - 1) * dst_stride_row), + main_acc2_0123); + vst1q_f32( + (float*)((uint8_t*)dst + (col_idx + 4) * sizeof(float) + KAI_MIN(row_idx + 2, m - 1) * dst_stride_row), + main_acc2_4567); + vst1q_f32( + (float*)((uint8_t*)dst + (col_idx + 0) * sizeof(float) + KAI_MIN(row_idx + 1, m - 1) * dst_stride_row), + main_acc1_0123); + vst1q_f32( + (float*)((uint8_t*)dst + (col_idx + 4) * sizeof(float) + KAI_MIN(row_idx + 1, m - 1) * dst_stride_row), + main_acc1_4567); + vst1q_f32( + (float*)((uint8_t*)dst + (col_idx + 0) * sizeof(float) + KAI_MIN(row_idx + 0, m - 1) * dst_stride_row), + main_acc0_0123); + vst1q_f32( + (float*)((uint8_t*)dst + (col_idx + 4) * sizeof(float) + KAI_MIN(row_idx + 0, m - 1) * dst_stride_row), + main_acc0_4567); + } + + lhs_ptr_start += lhs_packed_stride; + } +#else + KAI_ASSERT(false); + KAI_UNUSED(m); + KAI_UNUSED(n); + KAI_UNUSED(k); + KAI_UNUSED(lhs_p); + KAI_UNUSED(rhs_p); + KAI_UNUSED(dst); + KAI_UNUSED(dst_stride_row); + KAI_UNUSED(dst_stride_col); + KAI_UNUSED(scalar_min); + KAI_UNUSED(scalar_max); +#endif +} diff --git a/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm.h b/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm.h new file mode 100644 index 00000000..fac55cf4 --- /dev/null +++ b/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm.h @@ -0,0 +1,123 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#ifdef __cplusplus +extern "C" { +#else +#include +#endif + +#include +#include + +/** + * @brief Function to get the mr value, which must be used to pack the LHS matrix with + * the @ref kai_run_lhs_quant_pack_qa8dsP_f32 function + * + * @return the mr value + */ +size_t kai_get_mr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm(void); + +/** + * @brief Function to get the nr value, which must be used to pack the RHS matrix with + * the @ref kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 function + * + * @return the nr value + */ +size_t kai_get_nr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm(void); + +/** + * @brief Function to get the kr value, which must be used to pack the RHS matrix with + * the @ref kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 function + * + * @return the kr value + */ +size_t kai_get_kr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm(void); + +/** + * @brief Function to get the nr value, which must be used to pack the RHS matrix with + * the @ref kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 function + * + * @return the sr value + */ +size_t kai_get_sr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm(void); + +/** + * @brief Function to calculate the offset in bytes for the packed LHS matrix, + * which contains the packed 8-bit quantized asymmetric per-row (qa8dx) values. + * + * This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. + * + * @param[in] m_idx Row index in the LHS matrix (not packed). 
diff --git a/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm.h b/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm.h
new file mode 100644
index 00000000..fac55cf4
--- /dev/null
+++ b/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm.h
@@ -0,0 +1,123 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#pragma once
+
+#ifdef __cplusplus
+extern "C" {
+#else
+#include <stdbool.h>
+#endif
+
+#include <stddef.h>
+#include <stdint.h>
+
+/**
+ * @brief Function to get the mr value, which must be used to pack the LHS matrix with
+ *        the @ref kai_run_lhs_quant_pack_qa8dxP_f32 function
+ *
+ * @return the mr value
+ */
+size_t kai_get_mr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm(void);
+
+/**
+ * @brief Function to get the nr value, which must be used to pack the RHS matrix with
+ *        the @ref kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 function
+ *
+ * @return the nr value
+ */
+size_t kai_get_nr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm(void);
+
+/**
+ * @brief Function to get the kr value, which must be used to pack the RHS matrix with
+ *        the @ref kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 function
+ *
+ * @return the kr value
+ */
+size_t kai_get_kr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm(void);
+
+/**
+ * @brief Function to get the sr value, which must be used to pack the RHS matrix with
+ *        the @ref kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 function
+ *
+ * @return the sr value
+ */
+size_t kai_get_sr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm(void);
+
+/**
+ * @brief Function to calculate the offset in bytes for the packed LHS matrix,
+ *        which contains the packed 8-bit quantized asymmetric per-row (qa8dx) values.
+ *
+ * This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel.
+ *
+ * @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of 4.
+ * @param[in] k     Total number of columns in the LHS matrix (not packed). It must be a multiple of 32.
+ *
+ * return the offset in bytes to the packed LHS matrix
+ */
+size_t kai_get_lhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm(size_t m_idx, size_t k);
+
+/**
+ * @brief Function to calculate the offset in bytes for the packed RHS matrix,
+ *        which contains the packed 4-bit quantized symmetric per-channel (qs4cx) values.
+ *
+ * @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of 8.
+ * @param[in] k     The common dimension between the LHS and RHS matrix (K). It must be a multiple of 32.
+ *
+ * return the offset in bytes to the packed RHS matrix
+ */
+size_t kai_get_rhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm(size_t n_idx, size_t k);
+
+/**
+ * @brief Function to calculate the offset in bytes for the DST matrix
+ *
+ * @param[in] m_idx      Row index in the DST matrix. It must be a multiple of 4.
+ * @param[in] n_idx      Column index in the DST matrix. It must be a multiple of 8.
+ * @param[in] dst_stride The number of bytes in each row of the DST matrix
+ *
+ * return the DST offset in bytes
+ */
+size_t
+kai_get_dst_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm(size_t m_idx, size_t n_idx, size_t dst_stride);
+
+/**
+ * @brief Function to query the size in bytes of the destination (DST) matrix.
+ *
+ * @param[in] m Number of rows in the destination (DST) matrix. It must be either 1, 2, 3, or any multiple of 4.
+ * @param[in] n Number of columns in the destination (DST) matrix. It must be a multiple of 8.
+ */
+size_t kai_get_dst_size_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm(size_t m, size_t n);
+
+/**
+ * @brief Function to run the matrix multiplication (matmul) micro-kernel followed by a clip (min-max) operation.
+ *
+ * LHS matrix: 8-bit quantized asymmetric per-row (qa8dx) and packed
+ * RHS matrix: 4-bit quantized symmetric per-channel (qs4cx) and packed.
+ * Output tile: (rows x cols) = 4 x 8
+ * Accumulation performed in a single for loop: 32
+ * Instruction used: i8mm
+ *
+ * @param[in] m The number of output rows written. It must be either 1, 2, 3, or any multiple of 4.
+ * @param[in] n The number of output columns written. It must be a multiple of 8.
+ * @param[in] k The number of channels. The common dimension of LHS & RHS. It must be multiple of 32.
+ * @param[in] lhs_p The LHS matrix packed.
+ *                  When the activations are dynamically quantized, you can obtain this matrix
+ *                  by calling the @ref kai_run_lhs_quant_pack_qa8dxP_f32 micro-kernel which performs
+ *                  both the dynamic quantization to 8-bit and activation packing in a single step.
+ * @param[in] rhs_p The RHS matrix packed, which is obtained by calling @ref
+ *                  kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0
+ * @param[out] dst Result of the vector-by-matrix multiplication
+ * @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix.
+ * @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(float)
+ * @param[in] scalar_min Min value used to clip the final result.
+ * @param[in] scalar_max Max value used to clip the final result.
+ */
+void kai_run_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm(
+    size_t m, size_t n, size_t k, const void* lhs_p, const void* rhs_p, float* dst, size_t dst_stride_row,
+    size_t dst_stride_col, float scalar_min, float scalar_max);
+
+#ifdef __cplusplus
+}
+#endif
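The offset helpers declared above are what make it possible to split one matmul across threads without touching the packed buffers. A sketch for a worker that processes the row block [m_start, m_end) with this variant; the wrapper name is illustrative, the include path depends on the build setup, and m_start and the block height are assumed to be multiples of the kernel's m step.

```c
#include <stddef.h>
#include <stdint.h>

#include "kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm.h"

static void run_row_block_4x8x32(size_t m_start, size_t m_end, size_t n, size_t k,
                                 const void* lhs_p, const void* rhs_p,
                                 float* dst, size_t dst_stride_row,
                                 float clip_min, float clip_max) {
    // Jump to the packed LHS rows and DST rows owned by this worker.
    const size_t lhs_off =
        kai_get_lhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm(m_start, k);
    const size_t dst_off =
        kai_get_dst_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm(m_start, 0, dst_stride_row);

    kai_run_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm(
        m_end - m_start, n, k, (const uint8_t*)lhs_p + lhs_off, rhs_p,
        (float*)((uint8_t*)dst + dst_off), dst_stride_row, sizeof(float), clip_min, clip_max);
}
```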
diff --git a/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm.c b/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm.c
new file mode 100644
index 00000000..eea059c9
--- /dev/null
+++ b/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm.c
@@ -0,0 +1,397 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#include "kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm.h"
+
+#include "../../kai_common.h"
+
+#include <stddef.h>
+#include <stdint.h>
+
+static const size_t kai_mr = 4;
+static const size_t kai_nr = 4;
+static const size_t kai_kr = 16;
+static const size_t kai_sr = 2;
+static const size_t kai_k0 = 32;
+static const size_t kai_num_bytes_multiplier_lhs = sizeof(float);
+static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);
+static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t);
+static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t);
+
+size_t kai_get_mr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm(void) {
+    return kai_mr;
+}
+
+size_t kai_get_nr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm(void) {
+    return kai_nr;
+}
+
+size_t kai_get_kr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm(void) {
+    return kai_kr;
+}
+
+size_t kai_get_sr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm(void) {
+    return kai_sr;
+}
+
+size_t kai_get_lhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm(size_t m_idx, size_t k) {
+    KAI_ASSERT((m_idx % kai_mr) == 0);
+    const size_t lhs_packed_stride = kai_mr * (k + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs);
+
+    return (m_idx / kai_mr) * lhs_packed_stride;
+}
+
+size_t kai_get_rhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm(size_t n_idx, size_t k) {
+    KAI_ASSERT((k % kai_k0) == 0);
+    KAI_ASSERT((n_idx % kai_nr) == 0);
+
+    const size_t rhs_packed_stride = kai_nr * ((k / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs);
+
+    return (n_idx / kai_nr) * rhs_packed_stride;
+}
+
+size_t
+kai_get_dst_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm(size_t m_idx, size_t n_idx, size_t dst_stride) {
+    KAI_ASSERT((m_idx % kai_mr) == 0);
+    KAI_ASSERT((n_idx % kai_nr) == 0);
+
+    return (n_idx * sizeof(float)) + m_idx * dst_stride;
+}
+
+size_t kai_get_dst_size_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm(size_t m, size_t n) {
+    KAI_ASSERT((n % kai_nr) == 0);
+
+    return m * n * sizeof(float);
+}
+
+void kai_run_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm(
+    size_t m, size_t n, size_t k, const void* lhs_p, const void* rhs_p, float* dst, size_t dst_stride_row,
+    size_t dst_stride_col, float scalar_min, float scalar_max) {
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    KAI_ASSERT(n % kai_nr == 0);
+    KAI_ASSERT(k % kai_k0 == 0);
+    KAI_ASSERT(dst_stride_col == sizeof(float));
+
+    if (m == 0) {
+        return;
+    }
+
+    size_t num_blocks = k / 32;
+
+    float clamp_vals[2] = {scalar_min, scalar_max};
+
+    __asm__ __volatile__("mov x27, %x[m]\n"
+                         "mov x26, #0x80\n"
+                         "movi v11.16b, #0xf0\n"
+                         "mov x20, #0x20\n"
+                         "cmp x27, #0x8\n"
+                         "madd x26, %x[num_blocks], x26, x20\n"
+                         "blt 4f\n"
+                         
"1:" // Row loop + "mov x24, %x[rhs_p]\n" + "mov x23, %x[n]\n" + "add x22, %x[dst], %x[dst_stride_row], LSL #3\n" + "2:" // Column loop + "mov x25, %x[lhs_p]\n" + "movi v10.4s, #0x0\n" + "movi v9.4s, #0x0\n" + "mov x21, %x[num_blocks]\n" + "movi v8.4s, #0x0\n" + "movi v7.4s, #0x0\n" + "movi v6.4s, #0x0\n" + "movi v5.4s, #0x0\n" + "movi v4.4s, #0x0\n" + "movi v3.4s, #0x0\n" + "add x20, x25, x26\n" + "3:" // Block loop + "ldr q2, [x24, #0x0]\n" + "ldr q1, [x24, #0x10]\n" + "subs x21, x21, #0x1\n" + "ldr q20, [x25, #0x0]\n" + "ldr q19, [x25, #0x10]\n" + "ldr q18, [x20, #0x0]\n" + "ldr q0, [x20, #0x10]\n" + "ldr q31, [x24, #0x20]\n" + "ldr q30, [x24, #0x30]\n" + "shl v17.16b, v2.16b, #0x4\n" + "shl v16.16b, v1.16b, #0x4\n" + "ldr q29, [x25, #0x20]\n" + "ldr q28, [x25, #0x30]\n" + "and v2.16b, v2.16b, v11.16b\n" + "and v1.16b, v1.16b, v11.16b\n" + "ldr q27, [x20, #0x20]\n" + "ldr q26, [x20, #0x30]\n" + "add x24, x24, #0x40\n" + "ldr q25, [x25, #0x40]\n" + "ldr q24, [x25, #0x50]\n" + ".inst 0x4e91a68a // smmla v10.4s, v20.16b, v17.16b\n" + ".inst 0x4e90a689 // smmla v9.4s, v20.16b, v16.16b\n" + "ldr q23, [x20, #0x40]\n" + "ldr q22, [x20, #0x50]\n" + ".inst 0x4e91a668 // smmla v8.4s, v19.16b, v17.16b\n" + ".inst 0x4e90a667 // smmla v7.4s, v19.16b, v16.16b\n" + "ldr q21, [x25, #0x60]\n" + "ldr q20, [x25, #0x70]\n" + ".inst 0x4e91a646 // smmla v6.4s, v18.16b, v17.16b\n" + ".inst 0x4e90a645 // smmla v5.4s, v18.16b, v16.16b\n" + "ldr q19, [x20, #0x60]\n" + "ldr q18, [x20, #0x70]\n" + ".inst 0x4e91a404 // smmla v4.4s, v0.16b, v17.16b\n" + ".inst 0x4e90a403 // smmla v3.4s, v0.16b, v16.16b\n" + "shl v17.16b, v31.16b, #0x4\n" + "shl v16.16b, v30.16b, #0x4\n" + "add x25, x25, #0x80\n" + "add x20, x20, #0x80\n" + "and v31.16b, v31.16b, v11.16b\n" + "and v30.16b, v30.16b, v11.16b\n" + ".inst 0x4e91a7aa // smmla v10.4s, v29.16b, v17.16b\n" + ".inst 0x4e90a7a9 // smmla v9.4s, v29.16b, v16.16b\n" + ".inst 0x4e91a788 // smmla v8.4s, v28.16b, v17.16b\n" + ".inst 0x4e90a787 // smmla v7.4s, v28.16b, v16.16b\n" + ".inst 0x4e91a766 // smmla v6.4s, v27.16b, v17.16b\n" + ".inst 0x4e90a765 // smmla v5.4s, v27.16b, v16.16b\n" + ".inst 0x4e91a744 // smmla v4.4s, v26.16b, v17.16b\n" + ".inst 0x4e90a743 // smmla v3.4s, v26.16b, v16.16b\n" + ".inst 0x4e82a72a // smmla v10.4s, v25.16b, v2.16b\n" + ".inst 0x4e81a729 // smmla v9.4s, v25.16b, v1.16b\n" + ".inst 0x4e82a708 // smmla v8.4s, v24.16b, v2.16b\n" + ".inst 0x4e81a707 // smmla v7.4s, v24.16b, v1.16b\n" + ".inst 0x4e82a6e6 // smmla v6.4s, v23.16b, v2.16b\n" + ".inst 0x4e81a6e5 // smmla v5.4s, v23.16b, v1.16b\n" + ".inst 0x4e82a6c4 // smmla v4.4s, v22.16b, v2.16b\n" + ".inst 0x4e81a6c3 // smmla v3.4s, v22.16b, v1.16b\n" + ".inst 0x4e9fa6aa // smmla v10.4s, v21.16b, v31.16b\n" + ".inst 0x4e9ea6a9 // smmla v9.4s, v21.16b, v30.16b\n" + ".inst 0x4e9fa688 // smmla v8.4s, v20.16b, v31.16b\n" + ".inst 0x4e9ea687 // smmla v7.4s, v20.16b, v30.16b\n" + ".inst 0x4e9fa666 // smmla v6.4s, v19.16b, v31.16b\n" + ".inst 0x4e9ea665 // smmla v5.4s, v19.16b, v30.16b\n" + ".inst 0x4e9fa644 // smmla v4.4s, v18.16b, v31.16b\n" + ".inst 0x4e9ea643 // smmla v3.4s, v18.16b, v30.16b\n" + "bgt 3b\n" + "ldr q1, [x24, #0x0]\n" + "ldr q16, [x25, #0x0]\n" + "uzp1 v0.2d, v10.2d, v9.2d\n" + "uzp2 v31.2d, v10.2d, v9.2d\n" + "ldr q30, [x24, #0x10]\n" + "ldr q29, [x25, #0x10]\n" + "uzp1 v28.2d, v8.2d, v7.2d\n" + "uzp2 v27.2d, v8.2d, v7.2d\n" + "ldr q17, [x20, #0x0]\n" + "ldr q26, [x20, #0x10]\n" + "uzp1 v25.2d, v6.2d, v5.2d\n" + "uzp2 v24.2d, v6.2d, v5.2d\n" + "ld1r { v23.4s }, [%x[clamp_vals]]\n" + "mla 
v0.4s, v1.4s, v16.s[0]\n" + "mla v31.4s, v1.4s, v16.s[1]\n" + "uzp1 v22.2d, v4.2d, v3.2d\n" + "mla v28.4s, v1.4s, v16.s[2]\n" + "mla v27.4s, v1.4s, v16.s[3]\n" + "fmul v21.4s, v30.4s, v29.s[0]\n" + "add x20, %x[clamp_vals], #0x4\n" + "ld1r { v20.4s }, [x20]\n" + "uzp2 v19.2d, v4.2d, v3.2d\n" + "mla v25.4s, v1.4s, v17.s[0]\n" + "mla v24.4s, v1.4s, v17.s[1]\n" + "fmul v16.4s, v30.4s, v29.s[1]\n" + "fmul v18.4s, v30.4s, v29.s[2]\n" + "mla v22.4s, v1.4s, v17.s[2]\n" + "mov x20, %x[dst]\n" + "scvtf v0.4s, v0.4s\n" + "scvtf v31.4s, v31.4s\n" + "subs x23, x23, #0x4\n" + "add x24, x24, #0x20\n" + "scvtf v28.4s, v28.4s\n" + "scvtf v27.4s, v27.4s\n" + "mla v19.4s, v1.4s, v17.s[3]\n" + "add %x[dst], %x[dst], #0x10\n" + "fmul v17.4s, v30.4s, v29.s[3]\n" + "scvtf v25.4s, v25.4s\n" + "fmul v0.4s, v0.4s, v21.4s\n" + "fmul v31.4s, v31.4s, v16.4s\n" + "fmul v16.4s, v30.4s, v26.s[0]\n" + "fmul v28.4s, v28.4s, v18.4s\n" + "scvtf v24.4s, v24.4s\n" + "fmul v18.4s, v30.4s, v26.s[1]\n" + "fmul v27.4s, v27.4s, v17.4s\n" + "scvtf v22.4s, v22.4s\n" + "fmul v17.4s, v30.4s, v26.s[2]\n" + "fmax v0.4s, v0.4s, v23.4s\n" + "fmul v25.4s, v25.4s, v16.4s\n" + "scvtf v19.4s, v19.4s\n" + "fmul v16.4s, v30.4s, v26.s[3]\n" + "fmax v31.4s, v31.4s, v23.4s\n" + "fmul v24.4s, v24.4s, v18.4s\n" + "fmax v28.4s, v28.4s, v23.4s\n" + "fmul v22.4s, v22.4s, v17.4s\n" + "fmin v0.4s, v0.4s, v20.4s\n" + "fmax v27.4s, v27.4s, v23.4s\n" + "fmul v19.4s, v19.4s, v16.4s\n" + "fmin v31.4s, v31.4s, v20.4s\n" + "fmax v25.4s, v25.4s, v23.4s\n" + "str q0, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "fmin v28.4s, v28.4s, v20.4s\n" + "fmax v24.4s, v24.4s, v23.4s\n" + "fmin v27.4s, v27.4s, v20.4s\n" + "fmax v22.4s, v22.4s, v23.4s\n" + "str q31, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "fmin v25.4s, v25.4s, v20.4s\n" + "fmax v19.4s, v19.4s, v23.4s\n" + "str q28, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "fmin v24.4s, v24.4s, v20.4s\n" + "str q27, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "fmin v22.4s, v22.4s, v20.4s\n" + "str q25, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "fmin v19.4s, v19.4s, v20.4s\n" + "str q24, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q22, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q19, [x20, #0x0]\n" + "bne 2b\n" + "mov x20, #0x2\n" + "sub x27, x27, #0x8\n" + "cmp x27, #0x8\n" + "mov %x[dst], x22\n" + "madd %x[lhs_p], x20, x26, %x[lhs_p]\n" + "bge 1b\n" + "4:" // Row loop skip + "cbz x27, 9f\n" + "5:" // Row tail: Row loop + "mov x24, %x[rhs_p]\n" + "mov x23, %x[n]\n" + "add x22, %x[dst], %x[dst_stride_row], LSL #2\n" + "6:" // Row tail: Column loop + "movi v10.4s, #0x0\n" + "movi v9.4s, #0x0\n" + "mov x25, %x[lhs_p]\n" + "mov x20, %x[num_blocks]\n" + "movi v8.4s, #0x0\n" + "movi v7.4s, #0x0\n" + "7:" // Row tail: Block loop + "ldr q31, [x24, #0x0]\n" + "ldr q30, [x24, #0x10]\n" + "subs x20, x20, #0x1\n" + "ldr q29, [x25, #0x0]\n" + "ldr q28, [x25, #0x10]\n" + "ldr q27, [x24, #0x20]\n" + "ldr q26, [x24, #0x30]\n" + "add x24, x24, #0x40\n" + "ldr q25, [x25, #0x20]\n" + "ldr q24, [x25, #0x30]\n" + "shl v23.16b, v31.16b, #0x4\n" + "shl v22.16b, v30.16b, #0x4\n" + "ldr q21, [x25, #0x40]\n" + "ldr q20, [x25, #0x50]\n" + "and v31.16b, v31.16b, v11.16b\n" + "and v30.16b, v30.16b, v11.16b\n" + "ldr q19, [x25, #0x60]\n" + "ldr q18, [x25, #0x70]\n" + "shl v17.16b, v27.16b, #0x4\n" + "shl v16.16b, v26.16b, #0x4\n" + ".inst 0x4e97a7aa // smmla v10.4s, v29.16b, v23.16b\n" + ".inst 0x4e96a7a9 // smmla v9.4s, v29.16b, v22.16b\n" + "and 
v27.16b, v27.16b, v11.16b\n" + "add x25, x25, #0x80\n" + ".inst 0x4e97a788 // smmla v8.4s, v28.16b, v23.16b\n" + ".inst 0x4e96a787 // smmla v7.4s, v28.16b, v22.16b\n" + "and v26.16b, v26.16b, v11.16b\n" + ".inst 0x4e91a72a // smmla v10.4s, v25.16b, v17.16b\n" + ".inst 0x4e90a729 // smmla v9.4s, v25.16b, v16.16b\n" + ".inst 0x4e91a708 // smmla v8.4s, v24.16b, v17.16b\n" + ".inst 0x4e90a707 // smmla v7.4s, v24.16b, v16.16b\n" + ".inst 0x4e9fa6aa // smmla v10.4s, v21.16b, v31.16b\n" + ".inst 0x4e9ea6a9 // smmla v9.4s, v21.16b, v30.16b\n" + ".inst 0x4e9fa688 // smmla v8.4s, v20.16b, v31.16b\n" + ".inst 0x4e9ea687 // smmla v7.4s, v20.16b, v30.16b\n" + ".inst 0x4e9ba66a // smmla v10.4s, v19.16b, v27.16b\n" + ".inst 0x4e9aa669 // smmla v9.4s, v19.16b, v26.16b\n" + ".inst 0x4e9ba648 // smmla v8.4s, v18.16b, v27.16b\n" + ".inst 0x4e9aa647 // smmla v7.4s, v18.16b, v26.16b\n" + "bgt 7b\n" + "ldr q18, [x24, #0x0]\n" + "ldr q17, [x25, #0x0]\n" + "uzp1 v26.2d, v10.2d, v9.2d\n" + "uzp2 v25.2d, v10.2d, v9.2d\n" + "ldr q24, [x24, #0x10]\n" + "ldr q16, [x25, #0x10]\n" + "uzp1 v23.2d, v8.2d, v7.2d\n" + "uzp2 v22.2d, v8.2d, v7.2d\n" + "ld1r { v21.4s }, [%x[clamp_vals]]\n" + "add x21, %x[clamp_vals], #0x4\n" + "mov x20, %x[dst]\n" + "ld1r { v20.4s }, [x21]\n" + "mla v26.4s, v18.4s, v17.s[0]\n" + "mla v25.4s, v18.4s, v17.s[1]\n" + "cmp x27, #0x1\n" + "add x24, x24, #0x20\n" + "mla v23.4s, v18.4s, v17.s[2]\n" + "mla v22.4s, v18.4s, v17.s[3]\n" + "fmul v19.4s, v24.4s, v16.s[0]\n" + "fmul v18.4s, v24.4s, v16.s[1]\n" + "fmul v17.4s, v24.4s, v16.s[2]\n" + "fmul v16.4s, v24.4s, v16.s[3]\n" + "scvtf v26.4s, v26.4s\n" + "scvtf v25.4s, v25.4s\n" + "scvtf v23.4s, v23.4s\n" + "scvtf v22.4s, v22.4s\n" + "fmul v26.4s, v26.4s, v19.4s\n" + "fmul v25.4s, v25.4s, v18.4s\n" + "fmul v23.4s, v23.4s, v17.4s\n" + "fmul v22.4s, v22.4s, v16.4s\n" + "fmax v26.4s, v26.4s, v21.4s\n" + "fmax v25.4s, v25.4s, v21.4s\n" + "fmax v23.4s, v23.4s, v21.4s\n" + "fmax v22.4s, v22.4s, v21.4s\n" + "fmin v26.4s, v26.4s, v20.4s\n" + "fmin v25.4s, v25.4s, v20.4s\n" + "fmin v23.4s, v23.4s, v20.4s\n" + "fmin v22.4s, v22.4s, v20.4s\n" + "str q26, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 8f\n" + "cmp x27, #0x2\n" + "str q25, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 8f\n" + "cmp x27, #0x3\n" + "str q23, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 8f\n" + "str q22, [x20, #0x0]\n" + "8:" // Row tail: Accumulator store skip + "subs x23, x23, #0x4\n" + "add %x[dst], %x[dst], #0x10\n" + "bne 6b\n" + "subs x27, x27, #0x4\n" + "add %x[lhs_p], %x[lhs_p], x26\n" + "mov %x[dst], x22\n" + "bgt 5b\n" + "9:" // Row tail: Row loop skip + : [lhs_p] "+&r"(lhs_p), [dst] "+&r"(dst) + : [rhs_p] "r"(rhs_p), [clamp_vals] "r"(clamp_vals), [m] "r"(m), [num_blocks] "r"(num_blocks), + [dst_stride_row] "r"(dst_stride_row), [n] "r"(n) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", + "v29", "v30", "v31", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27"); +#else + KAI_ASSERT(false); + KAI_UNUSED(m); + KAI_UNUSED(n); + KAI_UNUSED(k); + KAI_UNUSED(lhs_p); + KAI_UNUSED(rhs_p); + KAI_UNUSED(dst); + KAI_UNUSED(dst_stride_row); + KAI_UNUSED(dst_stride_col); + KAI_UNUSED(scalar_min); + KAI_UNUSED(scalar_max); +#endif +} diff --git a/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm.h 
b/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm.h
new file mode 100644
index 00000000..bd07c802
--- /dev/null
+++ b/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm.h
@@ -0,0 +1,123 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#pragma once
+
+#ifdef __cplusplus
+extern "C" {
+#else
+#include <stdbool.h>
+#endif
+
+#include <stddef.h>
+#include <stdint.h>
+
+/**
+ * @brief Function to get the mr value, which must be used to pack the LHS matrix with
+ *        the @ref kai_run_lhs_quant_pack_qa8dxP_f32 function
+ *
+ * @return the mr value
+ */
+size_t kai_get_mr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm(void);
+
+/**
+ * @brief Function to get the nr value, which must be used to pack the RHS matrix with
+ *        the @ref kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 function
+ *
+ * @return the nr value
+ */
+size_t kai_get_nr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm(void);
+
+/**
+ * @brief Function to get the kr value, which must be used to pack the RHS matrix with
+ *        the @ref kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 function
+ *
+ * @return the kr value
+ */
+size_t kai_get_kr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm(void);
+
+/**
+ * @brief Function to get the sr value, which must be used to pack the RHS matrix with
+ *        the @ref kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 function
+ *
+ * @return the sr value
+ */
+size_t kai_get_sr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm(void);
+
+/**
+ * @brief Function to calculate the offset in bytes for the packed LHS matrix,
+ *        which contains the packed 8-bit quantized asymmetric per-row (qa8dx) values.
+ *
+ * This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel.
+ *
+ * @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of 4.
+ * @param[in] k     Total number of columns in the LHS matrix (not packed). It must be a multiple of 32.
+ *
+ * @return the offset in bytes to the packed LHS matrix
+ */
+size_t kai_get_lhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm(size_t m_idx, size_t k);
+
+/**
+ * @brief Function to calculate the offset in bytes for the packed RHS matrix,
+ *        which contains the packed 4-bit quantized symmetric per-channel (qs4cx) values.
+ *
+ * @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of 4.
+ * @param[in] k     The common dimension between the LHS and RHS matrix (K). It must be a multiple of 32.
+ *
+ * @return the offset in bytes to the packed RHS matrix
+ */
+size_t kai_get_rhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm(size_t n_idx, size_t k);
+
+/**
+ * @brief Function to calculate the offset in bytes for the DST matrix
+ *
+ * @param[in] m_idx Row index in the DST matrix. It must be a multiple of 4.
+ * @param[in] n_idx Column index in the DST matrix. It must be a multiple of 4.
+ * @param[in] dst_stride The number of bytes in each row of the DST matrix
+ *
+ * @return the DST offset in bytes
+ */
+size_t
+kai_get_dst_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm(size_t m_idx, size_t n_idx, size_t dst_stride);
+
+/**
+ * @brief Function to query the size in bytes of the destination (DST) matrix.
+ *
+ * @param[in] m Number of rows in the destination (DST) matrix. It must be either 1, 2, 3, or any multiple of 4.
+ * @param[in] n Number of columns in the destination (DST) matrix.
It must be a multiple of 4.
+ */
+size_t kai_get_dst_size_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm(size_t m, size_t n);
+
+/**
+ * @brief Function to run the matrix multiplication (matmul) micro-kernel followed by a clip (min-max) operation.
+ *
+ * LHS matrix: 8-bit quantized asymmetric per-row (qa8dx) and packed.
+ * RHS matrix: 4-bit quantized symmetric per-channel (qs4cx) and packed.
+ * Output tile: (rows x cols) = 8 x 4
+ * Accumulation performed in a single for loop: 32
+ * Instruction used: i8mm
+ *
+ * @param[in] m The number of output rows written. It must be either 1, 2, 3, or any multiple of 4.
+ * @param[in] n The number of output columns written. It must be a multiple of 4.
+ * @param[in] k The number of channels. The common dimension of LHS & RHS. It must be a multiple of 32.
+ * @param[in] lhs_p The LHS matrix packed.
+ *        When the activations are dynamically quantized, you can obtain this matrix
+ *        by calling the @ref kai_run_lhs_quant_pack_qa8dxP_f32 micro-kernel, which performs
+ *        both the dynamic quantization to 8-bit and the activation packing in a single step.
+ * @param[in] rhs_p The RHS matrix packed, which is obtained by calling @ref
+ *        kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0
+ * @param[out] dst Result of the matrix multiplication.
+ * @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix.
+ * @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(float).
+ * @param[in] scalar_min Min value used to clip the final result.
+ * @param[in] scalar_max Max value used to clip the final result.
+ */
+void kai_run_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm(
+    size_t m, size_t n, size_t k, const void* lhs_p, const void* rhs_p, float* dst, size_t dst_stride_row,
+    size_t dst_stride_col, float scalar_min, float scalar_max);
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm.c b/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm.c
new file mode 100644
index 00000000..fd58ab2b
--- /dev/null
+++ b/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm.c
@@ -0,0 +1,582 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#include "kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm.h"
+
+#include "../../kai_common.h"
+
+#include <stddef.h>
+#include <stdint.h>
+
+static const size_t kai_mr = 4;
+static const size_t kai_nr = 8;
+static const size_t kai_kr = 16;
+static const size_t kai_sr = 2;
+static const size_t kai_k0 = 32;
+static const size_t kai_num_bytes_multiplier_lhs = sizeof(float);
+static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);
+static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t);
+static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t);
+
+size_t kai_get_mr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm(void) {
+    return kai_mr;
+}
+
+size_t kai_get_nr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm(void) {
+    return kai_nr;
+}
+
+size_t kai_get_kr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm(void) {
+    return kai_kr;
+}
+
+size_t kai_get_sr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm(void) {
+    return kai_sr;
+}
+
+size_t kai_get_lhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm(size_t m_idx, size_t k) {
+    KAI_ASSERT((m_idx % kai_mr) == 0);
+    const size_t
lhs_packed_stride = kai_mr * (k + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); + + return (m_idx / kai_mr) * lhs_packed_stride; +} + +size_t kai_get_rhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm(size_t n_idx, size_t k) { + KAI_ASSERT((k % kai_k0) == 0); + KAI_ASSERT((n_idx % kai_nr) == 0); + + const size_t rhs_packed_stride = kai_nr * ((k / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs); + + return (n_idx / kai_nr) * rhs_packed_stride; +} + +size_t +kai_get_dst_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm(size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSERT((m_idx % kai_mr) == 0); + KAI_ASSERT((n_idx % kai_nr) == 0); + + return (n_idx * sizeof(float)) + m_idx * dst_stride; +} + +size_t kai_get_dst_size_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm(size_t m, size_t n) { + KAI_ASSERT((n % kai_nr) == 0); + + return m * n * sizeof(float); +} + +void kai_run_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm( + size_t m, size_t n, size_t k, const void* lhs_p, const void* rhs_p, float* dst, size_t dst_stride_row, + size_t dst_stride_col, float scalar_min, float scalar_max) { +#if defined(__ARM_FEATURE_MATMUL_INT8) + KAI_ASSERT(n % kai_nr == 0); + KAI_ASSERT(k % kai_k0 == 0); + KAI_ASSERT(dst_stride_col == sizeof(float)); + + if (m == 0) { + return; + } + + size_t num_blocks = k / 32; + + float clamp_vals[2] = {scalar_min, scalar_max}; + + __asm__ __volatile__("mov x27, %x[m]\n" + "mov x26, #0x80\n" + "movi v5.16b, #0xf0\n" + "mov x20, #0x20\n" + "cmp x27, #0x8\n" + "madd x26, %x[num_blocks], x26, x20\n" + "blt 4f\n" + "1:" // Row loop + "mov x25, %x[rhs_p]\n" + "mov x23, %x[n]\n" + "add x22, %x[dst], %x[dst_stride_row], LSL #3\n" + "2:" // Column loop + "mov x24, %x[lhs_p]\n" + "movi v8.4s, #0x0\n" + "movi v27.4s, #0x0\n" + "mov x21, %x[num_blocks]\n" + "movi v11.4s, #0x0\n" + "movi v6.4s, #0x0\n" + "movi v31.4s, #0x0\n" + "movi v26.4s, #0x0\n" + "movi v13.4s, #0x0\n" + "movi v15.4s, #0x0\n" + "add x20, x24, x26\n" + "movi v0.4s, #0x0\n" + "movi v22.4s, #0x0\n" + "movi v30.4s, #0x0\n" + "movi v29.4s, #0x0\n" + "movi v23.4s, #0x0\n" + "movi v3.4s, #0x0\n" + "movi v16.4s, #0x0\n" + "movi v21.4s, #0x0\n" + "3:" // Block loop + "ldr q12, [x25, #0x0]\n" + "ldr q10, [x25, #0x10]\n" + "subs x21, x21, #0x1\n" + "ldr q7, [x25, #0x20]\n" + "ldr q28, [x25, #0x30]\n" + "ldr q24, [x24, #0x0]\n" + "ldr q20, [x24, #0x10]\n" + "ldr q9, [x20, #0x0]\n" + "ldr q2, [x20, #0x10]\n" + "shl v18.16b, v12.16b, #0x4\n" + "shl v17.16b, v10.16b, #0x4\n" + "ldr q1, [x25, #0x40]\n" + "ldr q4, [x25, #0x50]\n" + "shl v14.16b, v7.16b, #0x4\n" + "shl v19.16b, v28.16b, #0x4\n" + "ldr q25, [x25, #0x60]\n" + "and v12.16b, v12.16b, v5.16b\n" + "and v10.16b, v10.16b, v5.16b\n" + ".inst 0x4e92a708 // smmla v8.4s, v24.16b, v18.16b\n" + ".inst 0x4e91a70b // smmla v11.4s, v24.16b, v17.16b\n" + ".inst 0x4e92a69f // smmla v31.4s, v20.16b, v18.16b\n" + "and v7.16b, v7.16b, v5.16b\n" + ".inst 0x4e8ea71b // smmla v27.4s, v24.16b, v14.16b\n" + ".inst 0x4e93a706 // smmla v6.4s, v24.16b, v19.16b\n" + "ldr q24, [x25, #0x70]\n" + "and v28.16b, v28.16b, v5.16b\n" + ".inst 0x4e91a68d // smmla v13.4s, v20.16b, v17.16b\n" + ".inst 0x4e8ea69a // smmla v26.4s, v20.16b, v14.16b\n" + "add x25, x25, #0x80\n" + ".inst 0x4e93a68f // smmla v15.4s, v20.16b, v19.16b\n" + "ldr q20, [x24, #0x20]\n" + ".inst 0x4e92a520 // smmla v0.4s, v9.16b, v18.16b\n" + ".inst 0x4e91a53e // smmla v30.4s, v9.16b, v17.16b\n" + ".inst 0x4e8ea536 // smmla v22.4s, v9.16b, v14.16b\n" + ".inst 0x4e93a53d 
// smmla v29.4s, v9.16b, v19.16b\n" + "ldr q9, [x24, #0x30]\n" + ".inst 0x4e92a457 // smmla v23.4s, v2.16b, v18.16b\n" + "ldr q18, [x20, #0x20]\n" + ".inst 0x4e91a450 // smmla v16.4s, v2.16b, v17.16b\n" + "ldr q17, [x20, #0x30]\n" + ".inst 0x4e8ea443 // smmla v3.4s, v2.16b, v14.16b\n" + "ldr q14, [x24, #0x40]\n" + ".inst 0x4e93a455 // smmla v21.4s, v2.16b, v19.16b\n" + "ldr q2, [x24, #0x50]\n" + "shl v19.16b, v1.16b, #0x4\n" + "and v1.16b, v1.16b, v5.16b\n" + ".inst 0x4e93a688 // smmla v8.4s, v20.16b, v19.16b\n" + ".inst 0x4e93a53f // smmla v31.4s, v9.16b, v19.16b\n" + ".inst 0x4e93a640 // smmla v0.4s, v18.16b, v19.16b\n" + ".inst 0x4e93a637 // smmla v23.4s, v17.16b, v19.16b\n" + "shl v19.16b, v4.16b, #0x4\n" + "and v4.16b, v4.16b, v5.16b\n" + ".inst 0x4e93a68b // smmla v11.4s, v20.16b, v19.16b\n" + ".inst 0x4e93a52d // smmla v13.4s, v9.16b, v19.16b\n" + ".inst 0x4e93a65e // smmla v30.4s, v18.16b, v19.16b\n" + ".inst 0x4e93a630 // smmla v16.4s, v17.16b, v19.16b\n" + "shl v19.16b, v25.16b, #0x4\n" + ".inst 0x4e8ca5c8 // smmla v8.4s, v14.16b, v12.16b\n" + ".inst 0x4e8ca45f // smmla v31.4s, v2.16b, v12.16b\n" + "and v25.16b, v25.16b, v5.16b\n" + ".inst 0x4e93a69b // smmla v27.4s, v20.16b, v19.16b\n" + ".inst 0x4e93a53a // smmla v26.4s, v9.16b, v19.16b\n" + ".inst 0x4e93a656 // smmla v22.4s, v18.16b, v19.16b\n" + ".inst 0x4e93a623 // smmla v3.4s, v17.16b, v19.16b\n" + "shl v19.16b, v24.16b, #0x4\n" + ".inst 0x4e8aa5cb // smmla v11.4s, v14.16b, v10.16b\n" + ".inst 0x4e8aa44d // smmla v13.4s, v2.16b, v10.16b\n" + "and v24.16b, v24.16b, v5.16b\n" + ".inst 0x4e93a686 // smmla v6.4s, v20.16b, v19.16b\n" + "ldr q20, [x20, #0x40]\n" + ".inst 0x4e93a52f // smmla v15.4s, v9.16b, v19.16b\n" + "ldr q9, [x20, #0x50]\n" + ".inst 0x4e93a65d // smmla v29.4s, v18.16b, v19.16b\n" + "ldr q18, [x24, #0x60]\n" + ".inst 0x4e93a635 // smmla v21.4s, v17.16b, v19.16b\n" + "ldr q19, [x24, #0x70]\n" + "ldr q17, [x20, #0x60]\n" + ".inst 0x4e87a5db // smmla v27.4s, v14.16b, v7.16b\n" + ".inst 0x4e87a45a // smmla v26.4s, v2.16b, v7.16b\n" + "add x24, x24, #0x80\n" + ".inst 0x4e8ca680 // smmla v0.4s, v20.16b, v12.16b\n" + ".inst 0x4e8aa69e // smmla v30.4s, v20.16b, v10.16b\n" + ".inst 0x4e9ca5c6 // smmla v6.4s, v14.16b, v28.16b\n" + "ldr q14, [x20, #0x70]\n" + ".inst 0x4e9ca44f // smmla v15.4s, v2.16b, v28.16b\n" + "add x20, x20, #0x80\n" + ".inst 0x4e87a696 // smmla v22.4s, v20.16b, v7.16b\n" + ".inst 0x4e9ca69d // smmla v29.4s, v20.16b, v28.16b\n" + ".inst 0x4e8ca537 // smmla v23.4s, v9.16b, v12.16b\n" + ".inst 0x4e8aa530 // smmla v16.4s, v9.16b, v10.16b\n" + ".inst 0x4e87a523 // smmla v3.4s, v9.16b, v7.16b\n" + ".inst 0x4e9ca535 // smmla v21.4s, v9.16b, v28.16b\n" + ".inst 0x4e81a648 // smmla v8.4s, v18.16b, v1.16b\n" + ".inst 0x4e84a64b // smmla v11.4s, v18.16b, v4.16b\n" + ".inst 0x4e99a65b // smmla v27.4s, v18.16b, v25.16b\n" + ".inst 0x4e98a646 // smmla v6.4s, v18.16b, v24.16b\n" + ".inst 0x4e81a67f // smmla v31.4s, v19.16b, v1.16b\n" + ".inst 0x4e84a66d // smmla v13.4s, v19.16b, v4.16b\n" + ".inst 0x4e99a67a // smmla v26.4s, v19.16b, v25.16b\n" + ".inst 0x4e98a66f // smmla v15.4s, v19.16b, v24.16b\n" + ".inst 0x4e81a620 // smmla v0.4s, v17.16b, v1.16b\n" + ".inst 0x4e84a63e // smmla v30.4s, v17.16b, v4.16b\n" + ".inst 0x4e99a636 // smmla v22.4s, v17.16b, v25.16b\n" + ".inst 0x4e98a63d // smmla v29.4s, v17.16b, v24.16b\n" + ".inst 0x4e81a5d7 // smmla v23.4s, v14.16b, v1.16b\n" + ".inst 0x4e84a5d0 // smmla v16.4s, v14.16b, v4.16b\n" + ".inst 0x4e99a5c3 // smmla v3.4s, v14.16b, v25.16b\n" + ".inst 0x4e98a5d5 // smmla 
v21.4s, v14.16b, v24.16b\n" + "bgt 3b\n" + "ldr q20, [x25, #0x0]\n" + "ldr q12, [x25, #0x10]\n" + "uzp1 v25.2d, v8.2d, v11.2d\n" + "uzp1 v24.2d, v27.2d, v6.2d\n" + "ldr q19, [x24, #0x0]\n" + "ldr q7, [x25, #0x20]\n" + "uzp2 v9.2d, v8.2d, v11.2d\n" + "uzp2 v6.2d, v27.2d, v6.2d\n" + "ldr q8, [x25, #0x30]\n" + "ldr q10, [x24, #0x10]\n" + "uzp1 v14.2d, v31.2d, v13.2d\n" + "uzp1 v11.2d, v26.2d, v15.2d\n" + "ldr q4, [x20, #0x0]\n" + "ldr q1, [x20, #0x10]\n" + "uzp2 v27.2d, v31.2d, v13.2d\n" + "uzp2 v13.2d, v26.2d, v15.2d\n" + "ld1r { v2.4s }, [%x[clamp_vals]]\n" + "mla v25.4s, v20.4s, v19.s[0]\n" + "mla v24.4s, v12.4s, v19.s[0]\n" + "uzp1 v31.2d, v0.2d, v30.2d\n" + "mla v9.4s, v20.4s, v19.s[1]\n" + "mla v6.4s, v12.4s, v19.s[1]\n" + "uzp1 v15.2d, v22.2d, v29.2d\n" + "add x20, %x[clamp_vals], #0x4\n" + "ld1r { v28.4s }, [x20]\n" + "mla v14.4s, v20.4s, v19.s[2]\n" + "mla v11.4s, v12.4s, v19.s[2]\n" + "uzp2 v0.2d, v0.2d, v30.2d\n" + "uzp2 v29.2d, v22.2d, v29.2d\n" + "mla v27.4s, v20.4s, v19.s[3]\n" + "mla v13.4s, v12.4s, v19.s[3]\n" + "mov x20, %x[dst]\n" + "uzp1 v30.2d, v23.2d, v16.2d\n" + "uzp1 v26.2d, v3.2d, v21.2d\n" + "mla v31.4s, v20.4s, v4.s[0]\n" + "subs x23, x23, #0x8\n" + "scvtf v25.4s, v25.4s\n" + "fmul v19.4s, v7.4s, v10.s[0]\n" + "mla v15.4s, v12.4s, v4.s[0]\n" + "add x25, x25, #0x40\n" + "scvtf v24.4s, v24.4s\n" + "fmul v18.4s, v8.4s, v10.s[0]\n" + "mla v0.4s, v20.4s, v4.s[1]\n" + "add %x[dst], %x[dst], #0x20\n" + "uzp2 v23.2d, v23.2d, v16.2d\n" + "uzp2 v22.2d, v3.2d, v21.2d\n" + "mla v29.4s, v12.4s, v4.s[1]\n" + "scvtf v9.4s, v9.4s\n" + "fmul v17.4s, v7.4s, v10.s[1]\n" + "mla v30.4s, v20.4s, v4.s[2]\n" + "scvtf v6.4s, v6.4s\n" + "fmul v16.4s, v8.4s, v10.s[1]\n" + "mla v26.4s, v12.4s, v4.s[2]\n" + "scvtf v14.4s, v14.4s\n" + "fmul v21.4s, v7.4s, v10.s[2]\n" + "mla v23.4s, v20.4s, v4.s[3]\n" + "scvtf v11.4s, v11.4s\n" + "fmul v20.4s, v8.4s, v10.s[2]\n" + "mla v22.4s, v12.4s, v4.s[3]\n" + "fmul v25.4s, v25.4s, v19.4s\n" + "fmul v24.4s, v24.4s, v18.4s\n" + "scvtf v27.4s, v27.4s\n" + "fmul v19.4s, v7.4s, v10.s[3]\n" + "scvtf v13.4s, v13.4s\n" + "fmul v18.4s, v8.4s, v10.s[3]\n" + "fmul v9.4s, v9.4s, v17.4s\n" + "fmul v6.4s, v6.4s, v16.4s\n" + "scvtf v31.4s, v31.4s\n" + "fmul v17.4s, v7.4s, v1.s[0]\n" + "scvtf v15.4s, v15.4s\n" + "fmul v16.4s, v8.4s, v1.s[0]\n" + "fmul v14.4s, v14.4s, v21.4s\n" + "fmul v11.4s, v11.4s, v20.4s\n" + "scvtf v0.4s, v0.4s\n" + "fmul v21.4s, v7.4s, v1.s[1]\n" + "scvtf v29.4s, v29.4s\n" + "fmul v20.4s, v8.4s, v1.s[1]\n" + "fmul v27.4s, v27.4s, v19.4s\n" + "fmul v13.4s, v13.4s, v18.4s\n" + "scvtf v30.4s, v30.4s\n" + "fmul v19.4s, v7.4s, v1.s[2]\n" + "scvtf v26.4s, v26.4s\n" + "fmul v18.4s, v8.4s, v1.s[2]\n" + "fmax v25.4s, v25.4s, v2.4s\n" + "fmax v24.4s, v24.4s, v2.4s\n" + "fmul v31.4s, v31.4s, v17.4s\n" + "fmul v15.4s, v15.4s, v16.4s\n" + "scvtf v23.4s, v23.4s\n" + "fmul v17.4s, v7.4s, v1.s[3]\n" + "scvtf v22.4s, v22.4s\n" + "fmul v16.4s, v8.4s, v1.s[3]\n" + "fmax v9.4s, v9.4s, v2.4s\n" + "fmax v6.4s, v6.4s, v2.4s\n" + "fmul v0.4s, v0.4s, v21.4s\n" + "fmul v29.4s, v29.4s, v20.4s\n" + "fmax v14.4s, v14.4s, v2.4s\n" + "fmax v11.4s, v11.4s, v2.4s\n" + "fmul v30.4s, v30.4s, v19.4s\n" + "fmul v26.4s, v26.4s, v18.4s\n" + "fmin v25.4s, v25.4s, v28.4s\n" + "fmin v24.4s, v24.4s, v28.4s\n" + "fmax v27.4s, v27.4s, v2.4s\n" + "fmax v13.4s, v13.4s, v2.4s\n" + "fmul v23.4s, v23.4s, v17.4s\n" + "fmul v22.4s, v22.4s, v16.4s\n" + "str q25, [x20, #0x0]\n" + "str q24, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "fmin v9.4s, v9.4s, v28.4s\n" + "fmin v6.4s, v6.4s, v28.4s\n" 
+ "fmax v31.4s, v31.4s, v2.4s\n" + "fmax v15.4s, v15.4s, v2.4s\n" + "fmin v14.4s, v14.4s, v28.4s\n" + "fmin v11.4s, v11.4s, v28.4s\n" + "str q9, [x20, #0x0]\n" + "str q6, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "fmax v0.4s, v0.4s, v2.4s\n" + "fmax v29.4s, v29.4s, v2.4s\n" + "fmin v27.4s, v27.4s, v28.4s\n" + "fmin v13.4s, v13.4s, v28.4s\n" + "str q14, [x20, #0x0]\n" + "str q11, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "fmax v30.4s, v30.4s, v2.4s\n" + "fmax v26.4s, v26.4s, v2.4s\n" + "fmin v31.4s, v31.4s, v28.4s\n" + "fmin v15.4s, v15.4s, v28.4s\n" + "str q27, [x20, #0x0]\n" + "str q13, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "fmax v23.4s, v23.4s, v2.4s\n" + "fmax v22.4s, v22.4s, v2.4s\n" + "fmin v0.4s, v0.4s, v28.4s\n" + "fmin v29.4s, v29.4s, v28.4s\n" + "str q31, [x20, #0x0]\n" + "str q15, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "fmin v30.4s, v30.4s, v28.4s\n" + "fmin v26.4s, v26.4s, v28.4s\n" + "fmin v23.4s, v23.4s, v28.4s\n" + "fmin v22.4s, v22.4s, v28.4s\n" + "str q0, [x20, #0x0]\n" + "str q29, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q30, [x20, #0x0]\n" + "str q26, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q23, [x20, #0x0]\n" + "str q22, [x20, #0x10]\n" + "bne 2b\n" + "mov x20, #0x2\n" + "sub x27, x27, #0x8\n" + "cmp x27, #0x8\n" + "mov %x[dst], x22\n" + "madd %x[lhs_p], x20, x26, %x[lhs_p]\n" + "bge 1b\n" + "4:" // Row loop skip + "cbz x27, 9f\n" + "5:" // Row tail: Row loop + "mov x23, %x[rhs_p]\n" + "mov x22, %x[n]\n" + "add x21, %x[dst], %x[dst_stride_row], LSL #2\n" + "6:" // Row tail: Column loop + "movi v8.4s, #0x0\n" + "movi v27.4s, #0x0\n" + "mov x24, %x[lhs_p]\n" + "mov x20, %x[num_blocks]\n" + "movi v11.4s, #0x0\n" + "movi v6.4s, #0x0\n" + "movi v31.4s, #0x0\n" + "movi v26.4s, #0x0\n" + "movi v13.4s, #0x0\n" + "movi v15.4s, #0x0\n" + "7:" // Row tail: Block loop + "ldr q4, [x23, #0x0]\n" + "ldr q10, [x23, #0x10]\n" + "subs x20, x20, #0x1\n" + "ldr q2, [x23, #0x20]\n" + "ldr q1, [x23, #0x30]\n" + "ldr q0, [x24, #0x0]\n" + "ldr q12, [x24, #0x10]\n" + "ldr q30, [x23, #0x40]\n" + "ldr q29, [x23, #0x50]\n" + "shl v19.16b, v4.16b, #0x4\n" + "shl v18.16b, v10.16b, #0x4\n" + "ldr q3, [x23, #0x60]\n" + "ldr q28, [x23, #0x70]\n" + "shl v17.16b, v2.16b, #0x4\n" + "shl v16.16b, v1.16b, #0x4\n" + "ldr q25, [x24, #0x20]\n" + "ldr q24, [x24, #0x30]\n" + "and v4.16b, v4.16b, v5.16b\n" + "and v10.16b, v10.16b, v5.16b\n" + "ldr q23, [x24, #0x40]\n" + "ldr q22, [x24, #0x50]\n" + ".inst 0x4e93a408 // smmla v8.4s, v0.16b, v19.16b\n" + ".inst 0x4e92a40b // smmla v11.4s, v0.16b, v18.16b\n" + "ldr q21, [x24, #0x60]\n" + "ldr q20, [x24, #0x70]\n" + ".inst 0x4e91a41b // smmla v27.4s, v0.16b, v17.16b\n" + ".inst 0x4e90a406 // smmla v6.4s, v0.16b, v16.16b\n" + ".inst 0x4e93a59f // smmla v31.4s, v12.16b, v19.16b\n" + ".inst 0x4e92a58d // smmla v13.4s, v12.16b, v18.16b\n" + "shl v19.16b, v30.16b, #0x4\n" + "add x24, x24, #0x80\n" + ".inst 0x4e91a59a // smmla v26.4s, v12.16b, v17.16b\n" + ".inst 0x4e90a58f // smmla v15.4s, v12.16b, v16.16b\n" + "shl v18.16b, v29.16b, #0x4\n" + "add x23, x23, #0x80\n" + "shl v17.16b, v3.16b, #0x4\n" + "shl v16.16b, v28.16b, #0x4\n" + ".inst 0x4e93a728 // smmla v8.4s, v25.16b, v19.16b\n" + "and v2.16b, v2.16b, v5.16b\n" + "and v1.16b, v1.16b, v5.16b\n" + ".inst 0x4e92a72b // smmla v11.4s, v25.16b, v18.16b\n" + ".inst 0x4e93a71f // smmla v31.4s, v24.16b, v19.16b\n" + ".inst 0x4e92a70d // smmla v13.4s, v24.16b, v18.16b\n" + "and v30.16b, v30.16b, v5.16b\n" + ".inst 
0x4e91a73b // smmla v27.4s, v25.16b, v17.16b\n" + ".inst 0x4e90a726 // smmla v6.4s, v25.16b, v16.16b\n" + "and v29.16b, v29.16b, v5.16b\n" + ".inst 0x4e91a71a // smmla v26.4s, v24.16b, v17.16b\n" + ".inst 0x4e90a70f // smmla v15.4s, v24.16b, v16.16b\n" + "and v3.16b, v3.16b, v5.16b\n" + ".inst 0x4e84a6e8 // smmla v8.4s, v23.16b, v4.16b\n" + ".inst 0x4e8aa6eb // smmla v11.4s, v23.16b, v10.16b\n" + "and v28.16b, v28.16b, v5.16b\n" + ".inst 0x4e84a6df // smmla v31.4s, v22.16b, v4.16b\n" + ".inst 0x4e8aa6cd // smmla v13.4s, v22.16b, v10.16b\n" + ".inst 0x4e82a6fb // smmla v27.4s, v23.16b, v2.16b\n" + ".inst 0x4e81a6e6 // smmla v6.4s, v23.16b, v1.16b\n" + ".inst 0x4e82a6da // smmla v26.4s, v22.16b, v2.16b\n" + ".inst 0x4e81a6cf // smmla v15.4s, v22.16b, v1.16b\n" + ".inst 0x4e9ea6a8 // smmla v8.4s, v21.16b, v30.16b\n" + ".inst 0x4e9da6ab // smmla v11.4s, v21.16b, v29.16b\n" + ".inst 0x4e9ea69f // smmla v31.4s, v20.16b, v30.16b\n" + ".inst 0x4e9da68d // smmla v13.4s, v20.16b, v29.16b\n" + ".inst 0x4e83a6bb // smmla v27.4s, v21.16b, v3.16b\n" + ".inst 0x4e9ca6a6 // smmla v6.4s, v21.16b, v28.16b\n" + ".inst 0x4e83a69a // smmla v26.4s, v20.16b, v3.16b\n" + ".inst 0x4e9ca68f // smmla v15.4s, v20.16b, v28.16b\n" + "bgt 7b\n" + "ldr q21, [x23, #0x0]\n" + "ldr q19, [x23, #0x10]\n" + "uzp1 v2.2d, v8.2d, v11.2d\n" + "uzp1 v1.2d, v27.2d, v6.2d\n" + "ldr q18, [x24, #0x0]\n" + "ldr q17, [x23, #0x20]\n" + "uzp2 v0.2d, v8.2d, v11.2d\n" + "uzp2 v12.2d, v27.2d, v6.2d\n" + "ldr q30, [x23, #0x30]\n" + "ldr q16, [x24, #0x10]\n" + "uzp1 v29.2d, v31.2d, v13.2d\n" + "uzp1 v28.2d, v26.2d, v15.2d\n" + "ld1r { v27.4s }, [%x[clamp_vals]]\n" + "uzp2 v20.2d, v31.2d, v13.2d\n" + "uzp2 v25.2d, v26.2d, v15.2d\n" + "add x20, %x[clamp_vals], #0x4\n" + "ld1r { v24.4s }, [x20]\n" + "mla v2.4s, v21.4s, v18.s[0]\n" + "mla v1.4s, v19.4s, v18.s[0]\n" + "mov x20, %x[dst]\n" + "mla v0.4s, v21.4s, v18.s[1]\n" + "mla v12.4s, v19.4s, v18.s[1]\n" + "fmul v23.4s, v17.4s, v16.s[0]\n" + "cmp x27, #0x1\n" + "mla v29.4s, v21.4s, v18.s[2]\n" + "mla v28.4s, v19.4s, v18.s[2]\n" + "fmul v22.4s, v30.4s, v16.s[0]\n" + "add x23, x23, #0x40\n" + "mla v20.4s, v21.4s, v18.s[3]\n" + "mla v25.4s, v19.4s, v18.s[3]\n" + "fmul v21.4s, v17.4s, v16.s[1]\n" + "scvtf v2.4s, v2.4s\n" + "scvtf v1.4s, v1.4s\n" + "scvtf v0.4s, v0.4s\n" + "scvtf v12.4s, v12.4s\n" + "fmul v8.4s, v30.4s, v16.s[1]\n" + "scvtf v29.4s, v29.4s\n" + "fmul v19.4s, v17.4s, v16.s[2]\n" + "scvtf v28.4s, v28.4s\n" + "fmul v18.4s, v30.4s, v16.s[2]\n" + "scvtf v20.4s, v20.4s\n" + "fmul v17.4s, v17.4s, v16.s[3]\n" + "scvtf v25.4s, v25.4s\n" + "fmul v16.4s, v30.4s, v16.s[3]\n" + "fmul v2.4s, v2.4s, v23.4s\n" + "fmul v1.4s, v1.4s, v22.4s\n" + "fmul v0.4s, v0.4s, v21.4s\n" + "fmul v12.4s, v12.4s, v8.4s\n" + "fmul v29.4s, v29.4s, v19.4s\n" + "fmul v28.4s, v28.4s, v18.4s\n" + "fmul v20.4s, v20.4s, v17.4s\n" + "fmul v25.4s, v25.4s, v16.4s\n" + "fmax v2.4s, v2.4s, v27.4s\n" + "fmax v1.4s, v1.4s, v27.4s\n" + "fmax v0.4s, v0.4s, v27.4s\n" + "fmax v12.4s, v12.4s, v27.4s\n" + "fmax v29.4s, v29.4s, v27.4s\n" + "fmax v28.4s, v28.4s, v27.4s\n" + "fmax v20.4s, v20.4s, v27.4s\n" + "fmax v25.4s, v25.4s, v27.4s\n" + "fmin v2.4s, v2.4s, v24.4s\n" + "fmin v1.4s, v1.4s, v24.4s\n" + "fmin v0.4s, v0.4s, v24.4s\n" + "fmin v12.4s, v12.4s, v24.4s\n" + "fmin v29.4s, v29.4s, v24.4s\n" + "fmin v28.4s, v28.4s, v24.4s\n" + "fmin v20.4s, v20.4s, v24.4s\n" + "str q2, [x20, #0x0]\n" + "fmin v25.4s, v25.4s, v24.4s\n" + "str q1, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 8f\n" + "cmp x27, #0x2\n" + "str q0, [x20, 
#0x0]\n"
+                         "str q12, [x20, #0x10]\n"
+                         "add x20, x20, %x[dst_stride_row]\n"
+                         "ble 8f\n"
+                         "cmp x27, #0x3\n"
+                         "str q29, [x20, #0x0]\n"
+                         "str q28, [x20, #0x10]\n"
+                         "add x20, x20, %x[dst_stride_row]\n"
+                         "ble 8f\n"
+                         "str q20, [x20, #0x0]\n"
+                         "str q25, [x20, #0x10]\n"
+                         "8:" // Row tail: Accumulator store skip
+                         "subs x22, x22, #0x8\n"
+                         "add %x[dst], %x[dst], #0x20\n"
+                         "bne 6b\n"
+                         "subs x27, x27, #0x4\n"
+                         "add %x[lhs_p], %x[lhs_p], x26\n"
+                         "mov %x[dst], x21\n"
+                         "bgt 5b\n"
+                         "9:" // Row tail: Row loop skip
+                         : [lhs_p] "+&r"(lhs_p), [dst] "+&r"(dst)
+                         : [rhs_p] "r"(rhs_p), [clamp_vals] "r"(clamp_vals), [m] "r"(m), [num_blocks] "r"(num_blocks),
+                           [dst_stride_row] "r"(dst_stride_row), [n] "r"(n)
+                         : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
+                           "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24",
+                           "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23", "x24", "x25",
+                           "x26", "x27");
+#else
+    KAI_ASSERT(false);
+    KAI_UNUSED(m);
+    KAI_UNUSED(n);
+    KAI_UNUSED(k);
+    KAI_UNUSED(lhs_p);
+    KAI_UNUSED(rhs_p);
+    KAI_UNUSED(dst);
+    KAI_UNUSED(dst_stride_row);
+    KAI_UNUSED(dst_stride_col);
+    KAI_UNUSED(scalar_min);
+    KAI_UNUSED(scalar_max);
+#endif
+}
diff --git a/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm.h b/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm.h
new file mode 100644
index 00000000..88629a02
--- /dev/null
+++ b/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm.h
@@ -0,0 +1,123 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#pragma once
+
+#ifdef __cplusplus
+extern "C" {
+#else
+#include <stdbool.h>
+#endif
+
+#include <stddef.h>
+#include <stdint.h>
+
+/**
+ * @brief Function to get the mr value, which must be used to pack the LHS matrix with
+ *        the @ref kai_run_lhs_quant_pack_qa8dxP_f32 function
+ *
+ * @return the mr value
+ */
+size_t kai_get_mr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm(void);
+
+/**
+ * @brief Function to get the nr value, which must be used to pack the RHS matrix with
+ *        the @ref kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 function
+ *
+ * @return the nr value
+ */
+size_t kai_get_nr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm(void);
+
+/**
+ * @brief Function to get the kr value, which must be used to pack the RHS matrix with
+ *        the @ref kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 function
+ *
+ * @return the kr value
+ */
+size_t kai_get_kr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm(void);
+
+/**
+ * @brief Function to get the sr value, which must be used to pack the RHS matrix with
+ *        the @ref kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 function
+ *
+ * @return the sr value
+ */
+size_t kai_get_sr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm(void);
+
+/**
+ * @brief Function to calculate the offset in bytes for the packed LHS matrix,
+ *        which contains the packed 8-bit quantized asymmetric per-row (qa8dx) values.
+ *
+ * This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel.
+ *
+ * @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of 4.
+ * @param[in] k     Total number of columns in the LHS matrix (not packed). It must be a multiple of 32.
+ *
+ * @return the offset in bytes to the packed LHS matrix
+ */
+size_t kai_get_lhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm(size_t m_idx, size_t k);
+
+/**
+ * @brief Function to calculate the offset in bytes for the packed RHS matrix,
+ *        which contains the packed 4-bit quantized symmetric per-channel (qs4cx) values.
+ *
+ * @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of 8.
+ * @param[in] k     The common dimension between the LHS and RHS matrix (K). It must be a multiple of 32.
+ *
+ * @return the offset in bytes to the packed RHS matrix
+ */
+size_t kai_get_rhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm(size_t n_idx, size_t k);
+
+/**
+ * @brief Function to calculate the offset in bytes for the DST matrix
+ *
+ * @param[in] m_idx Row index in the DST matrix. It must be a multiple of 4.
+ * @param[in] n_idx Column index in the DST matrix. It must be a multiple of 8.
+ * @param[in] dst_stride The number of bytes in each row of the DST matrix
+ *
+ * @return the DST offset in bytes
+ */
+size_t
+kai_get_dst_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm(size_t m_idx, size_t n_idx, size_t dst_stride);
+
+/**
+ * @brief Function to query the size in bytes of the destination (DST) matrix.
+ *
+ * @param[in] m Number of rows in the destination (DST) matrix. It must be either 1, 2, 3, or any multiple of 4.
+ * @param[in] n Number of columns in the destination (DST) matrix. It must be a multiple of 8.
+ */
+size_t kai_get_dst_size_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm(size_t m, size_t n);
+
+/**
+ * @brief Function to run the matrix multiplication (matmul) micro-kernel followed by a clip (min-max) operation.
+ *
+ * LHS matrix: 8-bit quantized asymmetric per-row (qa8dx) and packed.
+ * RHS matrix: 4-bit quantized symmetric per-channel (qs4cx) and packed.
+ * Output tile: (rows x cols) = 8 x 8
+ * Accumulation performed in a single for loop: 32
+ * Instruction used: i8mm
+ *
+ * @param[in] m The number of output rows written. It must be either 1, 2, 3, or any multiple of 4.
+ * @param[in] n The number of output columns written. It must be a multiple of 8.
+ * @param[in] k The number of channels. The common dimension of LHS & RHS. It must be a multiple of 32.
+ * @param[in] lhs_p The LHS matrix packed.
+ *        When the activations are dynamically quantized, you can obtain this matrix
+ *        by calling the @ref kai_run_lhs_quant_pack_qa8dxP_f32 micro-kernel, which performs
+ *        both the dynamic quantization to 8-bit and the activation packing in a single step.
+ * @param[in] rhs_p The RHS matrix packed, which is obtained by calling @ref
+ *        kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0
+ * @param[out] dst Result of the matrix multiplication.
+ * @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix.
+ * @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(float).
+ * @param[in] scalar_min Min value used to clip the final result.
+ * @param[in] scalar_max Max value used to clip the final result.
+ */ +void kai_run_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm( + size_t m, size_t n, size_t k, const void* lhs_p, const void* rhs_p, float* dst, size_t dst_stride_row, + size_t dst_stride_col, float scalar_min, float scalar_max); + +#ifdef __cplusplus +} +#endif -- GitLab From 15002daaffd93fd45c7c4ef0e674ba65c0af2649 Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Sun, 12 May 2024 18:53:37 +0100 Subject: [PATCH 03/14] Refactor file/function names - Refactor all file/functions names to have only lower-case alphanumerical values - Add interface file for the micro-kernel - Update the README.md Signed-off-by: Gian Marco Iodice --- README.md | 37 +-- .../CMakeLists.txt | 41 ++++ .../matmul_clamp_f32_qai8dxp_qsu4cxp.cpp} | 225 +++++++++--------- .../matmul_f32_qa8dx_qs4cx/CMakeLists.txt | 40 ---- src/kai_common.h | 14 +- ...f32.c => kai_lhs_quant_pack_qai8dxp_f32.c} | 22 +- ...f32.h => kai_lhs_quant_pack_qai8dxp_f32.h} | 20 +- ... => kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.c} | 19 +- ... => kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.h} | 29 ++- ...i8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod.c} | 58 +++-- ...i8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod.h} | 73 +++--- ...i8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod.c} | 58 +++-- ...i8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod.h} | 75 +++--- ..._qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.c} | 68 +++--- ..._qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.h} | 75 +++--- ..._qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm.c} | 62 +++-- ..._qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm.h} | 75 +++--- ..._qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.c} | 74 +++--- ..._qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.h} | 75 +++--- ..._qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.c} | 74 +++--- ..._qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.h} | 77 +++--- ...tmul_clamp_f32_qai8dxp_qsu4cxp_interface.h | 51 ++++ 22 files changed, 798 insertions(+), 544 deletions(-) create mode 100644 examples/matmul_clamp_f32_qai8dxp_qsu4cxp/CMakeLists.txt rename examples/{matmul_f32_qa8dx_qs4cx/matmul_f32_qa8dx_qs4cx.cpp => matmul_clamp_f32_qai8dxp_qsu4cxp/matmul_clamp_f32_qai8dxp_qsu4cxp.cpp} (62%) delete mode 100644 examples/matmul_f32_qa8dx_qs4cx/CMakeLists.txt rename src/matmul/{kai_lhs_quant_pack_qa8dxP_f32.c => kai_lhs_quant_pack_qai8dxp_f32.c} (87%) rename src/matmul/{kai_lhs_quant_pack_qa8dxP_f32.h => kai_lhs_quant_pack_qai8dxp_f32.h} (80%) rename src/matmul/{kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0.c => kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.c} (88%) rename src/matmul/{kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0.h => kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.h} (79%) rename src/matmul/{matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod.c => matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod.c} (80%) rename src/matmul/{matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod.h => matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod.h} (58%) rename src/matmul/{matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod.c => matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod.c} (81%) rename src/matmul/{matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod.h => matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod.h} (58%) rename src/matmul/{matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm.c => 
matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.c} (78%)
 rename src/matmul/{matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm.h => matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.h} (58%)
 rename src/matmul/{matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm.c => matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm.c} (88%)
 rename src/matmul/{matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm.h => matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm.h} (58%)
 rename src/matmul/{matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm.c => matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.c} (88%)
 rename src/matmul/{matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm.h => matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.h} (58%)
 rename src/matmul/{matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm.c => matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.c} (92%)
 rename src/matmul/{matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm.h => matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.h} (57%)
 create mode 100644 src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp_qsu4cxp_interface.h

diff --git a/README.md b/README.md
index 48d62e3c..be2008d4 100644
--- a/README.md
+++ b/README.md
@@ -44,11 +44,12 @@ A micro-kernel exists for different Arm® architectures, technologies, and compu
 Some of the key features of KleidiAI are the following:
 
 - No dependencies on external libraries
-- No internal memory allocation
-- No internal threading mechanisms
-- Stateless, stable, and consistent API
+- No dynamic memory allocation
+- No memory management
+- No scheduling
+- Stateless, stable, and consistent API
 - Performance-critical compute-bound and memory-bound micro-kernels
-- Specialized micro-kernels for different Arm® CPU architectures and technologies
+- Specialized micro-kernels utilizing different Arm® CPU architectural features (for example, FEAT_DotProd and FEAT_I8MM)
 - Specialized micro-kernels for different fusion patterns
 - Micro-kernel as a standalone library, consisting of only a .c and .h files
 
@@ -58,8 +59,8 @@ Some of the key features of KleidiAI are the following:
 
 Arm® Neon™
 
-- dotprod (Armv8.2-A onwards)
-- i8mm (Armv8.6-A onwards)
+- FEAT_DotProd is optional in Armv8.2-A and mandatory in Armv8.4-A
+- FEAT_I8MM is optional in Armv8.2-A and mandatory in Armv8.6-A
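
Because each micro-kernel variant requires one of the two features above, integrators are expected to select a variant at runtime. The following is a minimal sketch, assuming a Linux target where the kernel advertises CPU features through the standard getauxval() HWCAP interface (the helper names are illustrative and not part of KleidiAI; the HWCAP macros come from <asm/hwcap.h> on AArch64):

    #include <stdbool.h>
    #include <sys/auxv.h>   // getauxval(), AT_HWCAP, AT_HWCAP2
    #include <asm/hwcap.h>  // HWCAP_ASIMDDP, HWCAP2_I8MM

    // True when the CPU implements FEAT_DotProd (needed by the *_neon_dotprod variants).
    static bool cpu_has_dotprod(void) {
        return (getauxval(AT_HWCAP) & HWCAP_ASIMDDP) != 0;
    }

    // True when the CPU implements FEAT_I8MM (needed by the *_neon_i8mm variants).
    static bool cpu_has_i8mm(void) {
        return (getauxval(AT_HWCAP2) & HWCAP2_I8MM) != 0;
    }

On other operating systems the equivalent platform-specific query (for example, sysctlbyname() on macOS) would be needed.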

Filename convention

@@ -67,14 +68,14 @@ The `src/` directory is the home for all micro-kernels. The micro-kernels are gr
 
 Inside the operator directory, you can find:
 
-- *The common micro-kernels*, which are micro-kernels necessary for the correct functioning of the micro-kernels. For example, some of these may be required for packing the input tensors.
+- *The common micro-kernels*, which are helper micro-kernels necessary for the correct functioning of the main ones. For example, some of these may be required for packing the input tensors.
 - *The micro-kernels* files, which are held in separate sub-directories.
 
 The name of the micro-kernel folder provides the description of the operation performed and the data type of the destination and source tensors. The general syntax for the micro-kernel folder is as follows:
 
-`____...`
+`____...`
 
-All .c and .h pair files in that folders are micro-kernel variants. The variants are differentiated by specifying the computational paramaters (for example, the block size), the Arm® technology, and Arm® architecture feature exploited. The general syntax for the micro-kernel variant is as follows:
+All .c and .h pair files in that folder are micro-kernel variants. The variants are differentiated by specifying the computational parameters (for example, the block size), the Arm® technology (for example, Arm® Neon™), and the Arm® architecture feature exploited (for example, FEAT_DotProd). The general syntax for the micro-kernel variant is as follows:
 
 `kai____.c/.h`
 
@@ -91,10 +92,10 @@ Some of the data types currently supported with the KleidiAI library are the fol
 
 | Data type | Abbreviation | Notes |
 | ----------- | ----------- | ----------- |
 | Floating-point 32-bit | f32 | |
-| Quantized (q) Symmetric (s) 4-bit (4) Per-Channel (cx) quantization parameters | qs4cx | An fp32 multiplier shared among all values of the same channel. `x` denotes the entirety of the channel |
-| Quantized (q) Asymmetric (a) 8-bit (8) Per-Dimension (dx) (for example, Per-Row) quantization parameters | qa8dx | An fp32 multiplier and a int32 zero offset shared among all values of the same dimension. |
+| Quantized (q) Symmetric (s) Unsigned (u) 4-bit (4) Per-Channel (cx) quantization parameters | qsu4cx | An fp32 multiplier shared among all values of the same channel. `x` denotes the entirety of the channel |
+| Quantized (q) Asymmetric (a) Signed (i) 8-bit (8) Per-Dimension (dx) (for example, Per-Row) quantization parameters | qai8dx | An fp32 multiplier and an int32 zero offset shared among all values of the same dimension. |
 
-> ℹ️ In some cases, we may append the letter `P` to the data type to specify that the tensor is expected to be packed. A packed tensor is a tensor that has been rearranged in our preferred data layout from the original data layout to improve the performance of the micro-kernel. In addition to the letter `P`, we may append other upper-case alphanumerical values to specify the attributes of the data packing (for example, the block packing size).
+> ℹ️ In some cases, we may append the letter `p` to the data type to specify that the tensor is expected to be packed. A packed tensor is a tensor that has been rearranged in our preferred data layout from the original data layout to improve the performance of the micro-kernel. In addition to the letter `p`, we may append other alphanumerical values to specify the attributes of the data packing (for example, the block packing size).
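
To make the qai8dx row of the table above concrete: every row of the tensor carries one fp32 multiplier (scale) and one int32 zero offset (zero point), and a float value is recovered as x ≈ scale * (q - zero_point). The following sketch assumes a simple min/max-based scheme and is meant only to convey the idea; it does not reproduce the exact rounding used by the KleidiAI packing micro-kernels:

    #include <math.h>
    #include <stddef.h>
    #include <stdint.h>

    // Per-row asymmetric 8-bit quantization: one scale and one zero point per row.
    static void quant_row_qai8dx(const float* x, size_t k, int8_t* q, float* scale, int32_t* zero_point) {
        float min_v = 0.0f;  // keep 0.0f representable, as asymmetric schemes require
        float max_v = 0.0f;
        for (size_t i = 0; i < k; ++i) {
            min_v = fminf(min_v, x[i]);
            max_v = fmaxf(max_v, x[i]);
        }
        *scale = (max_v - min_v) / 255.0f;
        const float inv_scale = (*scale != 0.0f) ? 1.0f / *scale : 0.0f;
        *zero_point = (int32_t)roundf(-128.0f - min_v * inv_scale);
        for (size_t i = 0; i < k; ++i) {
            const int32_t v = (int32_t)roundf(x[i] * inv_scale) + *zero_point;
            q[i] = (int8_t)(v < INT8_MIN ? INT8_MIN : (v > INT8_MAX ? INT8_MAX : v));
        }
    }

The matching dequantization is simply (float)(q[i] - zero_point) * scale.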

Supported micro-kernels

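The micro-kernels listed in the table below are meant to be called in a fixed order: query mr/nr/kr/sr from the chosen matmul variant, pack the RHS once with kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0, then, per inference, quantize-and-pack the LHS with kai_lhs_quant_pack_qai8dxp_f32 and invoke the run function. A sketch of the final call, assuming the run signature is unchanged by the renames in this patch (the packing calls are elided because their signatures are not shown here):

    #include <float.h>
    #include <stddef.h>

    #include "kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.h"

    // lhs_packed/rhs_packed must have been produced by the packing micro-kernels
    // named above, using this variant's mr/nr/kr/sr values.
    static void run_matmul(size_t m, size_t n, size_t k,
                           const void* lhs_packed, const void* rhs_packed, float* dst) {
        kai_run_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(
            m, n, k, lhs_packed, rhs_packed, dst,
            n * sizeof(float),  // dst_stride_row: bytes between two DST rows
            sizeof(float),      // dst_stride_col: currently required to be sizeof(float)
            -FLT_MAX, FLT_MAX); // scalar_min/scalar_max: no effective clamping
    }
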
@@ -108,26 +109,26 @@ Some of the data types currently supported with the KleidiAI library are the fol
 
 Matrix-multiplication with LHS packed and RHS packed matrices
 
- matmul_clip_f32_qa8dxP_qs4cxP
+ matmul_clamp_f32_qai8dxp_qsu4cxp
 
- LHS: qa8dxP
- RHS: qs4cxP
+ LHS: qai8dxp
+ RHS: qsu4cxp
DST: f32
TensorFlow Lite
- The packing function for the RHS matrix is available in the `kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0.c/.h` files.
+ The packing function for the RHS matrix is available in the `kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.c/.h` files.
Since the RHS matrix often contains constant values, we recommend packing the RHS matrix only once and freeing the content of the original RHS matrix.
Dynamic quantization and LHS matrix packing - kai_lhs_quant_pack_qa8dxP1X8_f32, kai_lhs_quant_pack_qa8dxP4X8_f32 + kai_lhs_quant_pack_qai8dxp_f32 SRC: f32
- DST: qa8cx
+ DST: qai8cx
TensorFlow Lite
diff --git a/examples/matmul_clamp_f32_qai8dxp_qsu4cxp/CMakeLists.txt b/examples/matmul_clamp_f32_qai8dxp_qsu4cxp/CMakeLists.txt new file mode 100644 index 00000000..c4a79437 --- /dev/null +++ b/examples/matmul_clamp_f32_qai8dxp_qsu4cxp/CMakeLists.txt @@ -0,0 +1,41 @@ +# +# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +# +# SPDX-License-Identifier: Apache-2.0 +# + +cmake_minimum_required(VERSION 3.16) + +# KleidiAI include directories +include_directories( + ../../src/ + ../../src/matmul/ + ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/) + +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17") + +# Project name +project(matmul_clamp_f32_qai8dxp_qsu4cxp) + +# Files requires to build the executable +add_executable(matmul_clamp_f32_qai8dxp_qsu4cxp + matmul_clamp_f32_qai8dxp_qsu4cxp.cpp + ../../src/kai_common.h + ../../src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.h + ../../src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.c + ../../src/matmul/kai_lhs_quant_pack_qai8dxp_f32.h + ../../src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c + ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp_qsu4cxp_interface.h + ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod.h + ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod.c + ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod.h + ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod.c + ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.h + ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.c + ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.h + ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.c + ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm.h + ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm.c + ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.h + ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.c) + diff --git a/examples/matmul_f32_qa8dx_qs4cx/matmul_f32_qa8dx_qs4cx.cpp b/examples/matmul_clamp_f32_qai8dxp_qsu4cxp/matmul_clamp_f32_qai8dxp_qsu4cxp.cpp similarity index 62% rename from examples/matmul_f32_qa8dx_qs4cx/matmul_f32_qa8dx_qs4cx.cpp rename to examples/matmul_clamp_f32_qai8dxp_qsu4cxp/matmul_clamp_f32_qai8dxp_qsu4cxp.cpp index 2110fe46..586f0296 100644 --- a/examples/matmul_f32_qa8dx_qs4cx/matmul_f32_qa8dx_qs4cx.cpp +++ b/examples/matmul_clamp_f32_qai8dxp_qsu4cxp/matmul_clamp_f32_qai8dxp_qsu4cxp.cpp @@ -5,14 +5,15 @@ // // Include micro-kernel variants -#include "kai_lhs_quant_pack_qa8dxP_f32.h" -#include "kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod.h" -#include "kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod.h" -#include "kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm.h" -#include "kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm.h" -#include "kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm.h" -#include 
"kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm.h" -#include "kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0.h" +#include "kai_lhs_quant_pack_qai8dxp_f32.h" +#include "kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod.h" +#include "kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod.h" +#include "kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.h" +#include "kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm.h" +#include "kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.h" +#include "kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.h" +#include "kai_matmul_clamp_f32_qai8dxp_qsu4cxp_interface.h" +#include "kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.h" #include #include @@ -24,99 +25,85 @@ #define INT4_MIN (-8) #define INT4_MAX (7) -// All micro-kernels variants of the same type share the same interfaces -// In this case, the micro-kernel type is: matmul_clip_f32_qa8dxP_qs4cxP - -// Micro-kernel helper interfaces ("get" methods) -typedef size_t (*kai_get_mr_func_t)(void); -typedef size_t (*kai_get_nr_func_t)(void); -typedef size_t (*kai_get_kr_func_t)(void); -typedef size_t (*kai_get_sr_func_t)(void); -typedef size_t (*kai_get_lhs_packed_offset_func_t)(size_t m_idx, size_t k); -typedef size_t (*kai_get_rhs_packed_offset_func_t)(size_t n_idx, size_t k); -typedef size_t (*kai_get_dst_offset_func_t)(size_t m_idx, size_t n_idx, size_t dst_stride); -typedef size_t (*kai_get_dst_size_func_t)(size_t m, size_t n); - -// Micro-kernel core interface ("run" method) -typedef void (*kai_run_matmul_func_t)( - size_t m, size_t n, size_t k, const void* lhs_p, const void* rhs_p, float* dst, size_t dst_stride_row, - size_t dst_stride_col, float scalar_min, float scalar_max); - // Micro-kernel interface struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp { - kai_get_mr_func_t get_mr = nullptr; - kai_get_nr_func_t get_nr = nullptr; - kai_get_nr_func_t get_kr = nullptr; - kai_get_sr_func_t get_sr = nullptr; - kai_get_lhs_packed_offset_func_t get_lhs_packed_offset = nullptr; - kai_get_rhs_packed_offset_func_t get_rhs_packed_offset = nullptr; - kai_get_dst_offset_func_t get_dst_offset = nullptr; - kai_get_dst_size_func_t get_dst_size = nullptr; - kai_run_matmul_func_t run_matmul = nullptr; + kai_matmul_clamp_f32_qai8dxp_qsu4cxp_ukernel ukernel; std::string name = {}; }; kai_matmul_ukernel_f32_qa8dxp_qs4cxp ukernel_variants[] = { - {kai_get_mr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod, - kai_get_nr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod, - kai_get_kr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod, - kai_get_sr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod, - kai_get_lhs_packed_offset_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod, - kai_get_rhs_packed_offset_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod, - kai_get_dst_offset_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod, - kai_get_dst_size_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod, - kai_run_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod, - "matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod"}, - {kai_get_mr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod, - kai_get_nr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod, - kai_get_kr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod, - kai_get_sr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod, - kai_get_lhs_packed_offset_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod, - 
kai_get_rhs_packed_offset_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod, - kai_get_dst_offset_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod, - kai_get_dst_size_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod, - kai_run_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod, - "matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod"}, - {kai_get_mr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm, - kai_get_nr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm, - kai_get_kr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm, - kai_get_sr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm, - kai_get_lhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm, - kai_get_rhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm, - kai_get_dst_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm, - kai_get_dst_size_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm, - kai_run_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm, - "matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm"}, - {kai_get_mr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm, - kai_get_nr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm, - kai_get_kr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm, - kai_get_sr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm, - kai_get_lhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm, - kai_get_rhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm, - kai_get_dst_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm, - kai_get_dst_size_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm, - kai_run_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm, - "matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm"}, - {kai_get_mr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm, - kai_get_nr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm, - kai_get_kr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm, - kai_get_sr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm, - kai_get_lhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm, - kai_get_rhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm, - kai_get_dst_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm, - kai_get_dst_size_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm, - kai_run_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm, - "matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm"}, - {kai_get_mr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm, - kai_get_nr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm, - kai_get_kr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm, - kai_get_sr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm, - kai_get_lhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm, - kai_get_rhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm, - kai_get_dst_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm, - kai_get_dst_size_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm, - kai_run_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm, - "matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm"}, + {kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod, + kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod, + kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod, + kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod, + 
kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod, + kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod, + kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod, + kai_run_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod, + "matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod"}, + {kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod, + kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod, + kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod, + kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod, + kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod, + kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod, + kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod, + kai_run_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod, + "matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod"}, + {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm, + kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm, + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm, + kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm, + kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm, + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm, + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm, + kai_run_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm, + "matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm"}, + {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm, + kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm, + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm, + kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm, + kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm, + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm, + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm, + kai_run_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm, + "matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm"}, + {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm, + kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm, + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm, + 
kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm, + kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm, + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm, + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm, + kai_run_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm, + "matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm"}, + {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm, + kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm, + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm, + kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm, + kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm, + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm, + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm, + kai_run_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm, + "matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm"}, }; // Number of micro-kernel variants stored in the array @@ -343,7 +330,6 @@ static bool is_output_correct(size_t num_rows, size_t num_cols, float tolerance, } int main(int argc, char** argv) { - const size_t m = 17; const size_t n = 32; // It must be a multiple of 8 const size_t k = 64; // It must be a multiple of 64 @@ -362,7 +348,6 @@ int main(int argc, char** argv) { uint8_t* rhs_scales_f32 = new uint8_t[rhs_scales_size_f32]; fill_uniform_random(m, k, (float*)lhs_native_mtx_f32, seed_lhs); - fill_uniform_random(n, k, (float*)rhs_native_mtx_f32, seed_rhs); quant_qs4cx_f32(n, k, (const float*)rhs_native_mtx_f32, (uint8_t*)rhs_native_mtx_qs4cx, (float*)rhs_scales_f32); @@ -401,29 +386,30 @@ int main(int argc, char** argv) { std::cout << "Testing " << ukernel_variants[idx_variant].name << std::endl; - const size_t mr = ukernel_variants[idx_variant].get_mr(); - const size_t nr = ukernel_variants[idx_variant].get_nr(); - const size_t kr = ukernel_variants[idx_variant].get_kr(); - const size_t sr = ukernel_variants[idx_variant].get_sr(); + // Get the packing parameters + const size_t mr = ukernel_variants[idx_variant].ukernel.get_mr(); + const size_t nr = ukernel_variants[idx_variant].ukernel.get_nr(); + const size_t kr = ukernel_variants[idx_variant].ukernel.get_kr(); + const size_t sr = ukernel_variants[idx_variant].ukernel.get_sr(); - const size_t lhs_packed_size = kai_get_lhs_packed_size_lhs_quant_pack_qa8dxP_f32(m, k, mr, kr); - const size_t rhs_packed_size = kai_get_rhs_packed_size_rhs_pack_nxk_qs4cxP_qs4cxS1S0(n, k, nr, kr); - const size_t dst_size = ukernel_variants[idx_variant].get_dst_size(m, n); + // Get the size in bytes for the packed matrices + const size_t lhs_packed_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(m, k, mr, kr); + const size_t rhs_packed_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0(n, k, nr, kr); + const size_t dst_size = ukernel_variants[idx_variant].ukernel.get_dst_size(m, n); + // Allocate the matrices uint8_t* 
lhs_packed_mtx_qa8dx = new uint8_t[lhs_packed_size]; uint8_t* rhs_packed_mtx_qs4cx = new uint8_t[rhs_packed_size]; uint8_t* dst_act_mtx_f32 = new uint8_t[dst_size]; - // LHS packing - kai_run_lhs_quant_pack_qa8dxP_f32( - m, k, mr, kr, sr, (const float*)lhs_native_mtx_f32, k * sizeof(float), lhs_packed_mtx_qa8dx); - - struct kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0_params params; + // If the RHS matrix contains constant values, the packing can be performed + // only once + struct kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0_params params; params.lhs_zero_point = 1; params.rhs_zero_point = 8; // RHS packing - kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0( + kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0( 1, n, k, nr, kr, sr, (const uint8_t*)(rhs_native_mtx_qs4cx), // RHS NULL, // Bias @@ -431,17 +417,22 @@ int main(int argc, char** argv) { rhs_packed_mtx_qs4cx, // DST 0, &params); + // LHS packing + kai_run_lhs_quant_pack_qai8dxp_f32( + m, k, mr, kr, sr, (const float*)lhs_native_mtx_f32, k * sizeof(float), lhs_packed_mtx_qa8dx); + + // Matmul { const size_t dst_stride = n * sizeof(float); + const size_t lhs_offset = ukernel_variants[idx_variant].ukernel.get_lhs_packed_offset(0, k); + const size_t rhs_offset = ukernel_variants[idx_variant].ukernel.get_rhs_packed_offset(0, k); + const size_t dst_offset = ukernel_variants[idx_variant].ukernel.get_dst_offset(0, 0, dst_stride); - const void* lhs_ptr = (const void*)((const char*)lhs_packed_mtx_qa8dx + - ukernel_variants[idx_variant].get_lhs_packed_offset(0, k)); - const void* rhs_ptr = (const void*)((const char*)rhs_packed_mtx_qs4cx + - ukernel_variants[idx_variant].get_rhs_packed_offset(0, k)); - float* dst_ptr = - (float*)((uint8_t*)dst_act_mtx_f32 + ukernel_variants[idx_variant].get_dst_offset(0, 0, dst_stride)); + const void* lhs_ptr = (const void*)((const char*)lhs_packed_mtx_qa8dx + lhs_offset); + const void* rhs_ptr = (const void*)((const char*)rhs_packed_mtx_qs4cx + rhs_offset); + float* dst_ptr = (float*)((uint8_t*)dst_act_mtx_f32 + dst_offset); - ukernel_variants[idx_variant].run_matmul( + ukernel_variants[idx_variant].ukernel.run_matmul( m, n, k, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX); } diff --git a/examples/matmul_f32_qa8dx_qs4cx/CMakeLists.txt b/examples/matmul_f32_qa8dx_qs4cx/CMakeLists.txt deleted file mode 100644 index 29a24a14..00000000 --- a/examples/matmul_f32_qa8dx_qs4cx/CMakeLists.txt +++ /dev/null @@ -1,40 +0,0 @@ -# -# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates -# -# SPDX-License-Identifier: Apache-2.0 -# - -cmake_minimum_required(VERSION 3.16) - -# KleidiAI include directories -include_directories( - ../../src/ - ../../src/matmul/ - ../../src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/) - -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17") - -# Project name -project(matmul_f32_qa8dx_qs4cx) - -# Files requires to build the executable -add_executable(matmul_f32_qa8dx_qs4cx - matmul_f32_qa8dx_qs4cx.cpp - ../../src/kai_common.h - ../../src/matmul/kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0.h - ../../src/matmul/kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0.c - ../../src/matmul/kai_lhs_quant_pack_qa8dxP_f32.h - ../../src/matmul/kai_lhs_quant_pack_qa8dxP_f32.c - ../../src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod.h - ../../src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod.c - ../../src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod.h -
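Condensed, the calling sequence the example performs for each variant is the following. This is a sketch assembled from the hunks above; the wrapper function `run_qai8dxp_qsu4cxp_once` and its parameter names are illustrative, not library API:

```cpp
#include <cfloat>
#include <cstddef>
#include <cstdint>

#include "kai_lhs_quant_pack_qai8dxp_f32.h"
#include "kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod.h"
#include "kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.h"

// Illustrative wrapper condensing the example's flow; not a library API.
// This variant requires n to be a multiple of 4 and k a multiple of 64
// (see the micro-kernel header documentation).
void run_qai8dxp_qsu4cxp_once(
    size_t m, size_t n, size_t k,
    const float* lhs_f32,      // native activations
    const uint8_t* rhs_qs4cx,  // native 4-bit weights, two values per byte
    const float* rhs_scales,   // one scale per output channel
    float* dst_f32) {
    const size_t mr = kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod();
    const size_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod();
    const size_t kr = kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod();
    const size_t sr = kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod();

    uint8_t* lhs_packed = new uint8_t[kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(m, k, mr, kr)];
    uint8_t* rhs_packed = new uint8_t[kai_get_rhs_packed_size_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0(n, k, nr, kr)];

    // RHS holds constant weights: in a real integration this packing step
    // runs once and the packed buffer is reused across inferences.
    kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0_params params{};
    params.lhs_zero_point = 1;
    params.rhs_zero_point = 8;
    kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0(
        1, n, k, nr, kr, sr, rhs_qs4cx, nullptr /* bias */, rhs_scales, rhs_packed, 0, &params);

    // LHS holds activations: dynamic 8-bit quantization + packing per call.
    kai_run_lhs_quant_pack_qai8dxp_f32(m, k, mr, kr, sr, lhs_f32, k * sizeof(float), lhs_packed);

    // Matmul with the widest possible clamp range, i.e. effectively no clamp.
    kai_run_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod(
        m, n, k, lhs_packed, rhs_packed, dst_f32, n * sizeof(float), sizeof(float), -FLT_MAX, FLT_MAX);

    delete[] lhs_packed;
    delete[] rhs_packed;
}
```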
../../src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod.c - ../../src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm.h - ../../src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm.c - ../../src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm.h - ../../src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm.c - ../../src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm.h - ../../src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm.c - ../../src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm.h - ../../src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm.c) - diff --git a/src/kai_common.h b/src/kai_common.h index 01d70824..30c58afa 100644 --- a/src/kai_common.h +++ b/src/kai_common.h @@ -30,6 +30,12 @@ extern "C" { KAI_ERROR(msg); \ } \ } while (0) +#define KAI_ASSERT(x) \ + do { \ + if (!(x)) { \ + exit(EXIT_FAILURE); \ + } \ + } while (0) // NOLINTEND(cppcoreguidelines-avoid-do-while,cppcoreguidelines-pro-type-vararg,cert-err33-c) #define KAI_ASSERT(cond) KAI_ASSERT_MSG(cond, #cond) @@ -49,15 +55,9 @@ extern "C" { #define KAI_MAX(a, b) (((a) > (b)) ? (a) : (b)) inline static size_t kai_roundup(size_t a, size_t b) { - size_t rem = a % b; - if (rem) { - return a + b - rem; - } else { - return a; - } + return ((a + b - 1) / b) * b; } - #ifdef __cplusplus } #endif diff --git a/src/matmul/kai_lhs_quant_pack_qa8dxP_f32.c b/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c similarity index 87% rename from src/matmul/kai_lhs_quant_pack_qa8dxP_f32.c rename to src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c index 3f68d8a4..ec3abb57 100644 --- a/src/matmul/kai_lhs_quant_pack_qa8dxP_f32.c +++ b/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c @@ -3,9 +3,9 @@ // // SPDX-License-Identifier: Apache-2.0 // -#include "kai_lhs_quant_pack_qa8dxP_f32.h" +#include "kai_lhs_quant_pack_qai8dxp_f32.h" -#include "../kai_common.h" +#include "kai_common.h" #include #include @@ -15,26 +15,30 @@ static const size_t kai_num_bytes_per_multiplier = sizeof(float); static const size_t kai_num_bytes_per_offset = sizeof(int32_t); -size_t kai_get_lhs_offset_lhs_quant_pack_qa8dxP_f32(size_t m_idx, size_t lhs_stride) { +size_t kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(size_t m_idx, size_t lhs_stride) { return m_idx * lhs_stride; } -size_t kai_get_lhs_packed_offset_lhs_quant_pack_qa8dxP_f32(size_t m_idx, size_t k, size_t mr, size_t kr) { +size_t kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(size_t m_idx, size_t k, size_t mr, size_t kr) { + KAI_ASSERT(k % kr == 0); + KAI_UNUSED(kr); const size_t dst_stride = mr * (k * sizeof(int8_t) + kai_num_bytes_per_multiplier + kai_num_bytes_per_offset); return (m_idx / mr) * dst_stride; } -size_t kai_get_lhs_packed_size_lhs_quant_pack_qa8dxP_f32(size_t m, size_t k, size_t mr, size_t kr) { +size_t kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(size_t m, size_t k, size_t mr, size_t kr) { + KAI_ASSERT(k % kr == 0); + KAI_UNUSED(kr); const size_t m_roundup = kai_roundup(m, mr); const size_t t_size = m_roundup * (k * sizeof(int8_t) + kai_num_bytes_per_multiplier + kai_num_bytes_per_offset); return t_size; } -void kai_run_lhs_quant_pack_qa8dxP_f32( +void kai_run_lhs_quant_pack_qai8dxp_f32( 
size_t m, size_t k, size_t mr, size_t kr, size_t sr, const float* restrict lhs, size_t lhs_stride, - void* restrict lhs_p) { + void* restrict lhs_packed) { KAI_ASSERT(k % kr == 0); KAI_ASSERT((kr % sr) == 0); @@ -112,7 +116,7 @@ void kai_run_lhs_quant_pack_qa8dxP_f32( const size_t dst_x = (row_idx % mr); const size_t dst_y = (row_idx / mr); - uint8_t* dst_ptr = (uint8_t*)lhs_p + dst_y * dst_stride; + uint8_t* dst_ptr = (uint8_t*)lhs_packed + dst_y * dst_stride; dst_ptr += dst_x * k_block_len * sizeof(int8_t); @@ -134,7 +138,7 @@ void kai_run_lhs_quant_pack_qa8dxP_f32( dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t); } - dst_ptr = (uint8_t*)lhs_p + dst_y * dst_stride + mr * (k * sizeof(int8_t)); + dst_ptr = (uint8_t*)lhs_packed + dst_y * dst_stride + mr * (k * sizeof(int8_t)); dst_ptr += dst_x * kai_num_bytes_per_offset; diff --git a/src/matmul/kai_lhs_quant_pack_qa8dxP_f32.h b/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.h similarity index 80% rename from src/matmul/kai_lhs_quant_pack_qa8dxP_f32.h rename to src/matmul/kai_lhs_quant_pack_qai8dxp_f32.h index d27b3773..a13f32a9 100644 --- a/src/matmul/kai_lhs_quant_pack_qa8dxP_f32.h +++ b/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.h @@ -5,15 +5,13 @@ // #pragma once +#include +#include + #ifdef __cplusplus extern "C" { -#else -#include #endif -#include -#include - /** * @brief Function to calculate the offset in bytes for the LHS matrix (not packed) * @@ -24,7 +22,7 @@ extern "C" { * * return the offset in bytes to the LHS matrix */ -size_t kai_get_lhs_offset_lhs_quant_pack_qa8dxP_f32(size_t m_idx, size_t lhs_stride); +size_t kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(size_t m_idx, size_t lhs_stride); /** * @brief Function to calculate the offset in bytes for the packed LHS matrix, @@ -37,7 +35,7 @@ size_t kai_get_lhs_offset_lhs_quant_pack_qa8dxP_f32(size_t m_idx, size_t lhs_str * * return the offset in bytes to the packed LHS matrix */ -size_t kai_get_lhs_packed_offset_lhs_quant_pack_qa8dxP_f32(size_t m_idx, size_t k, size_t mr, size_t kr); +size_t kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(size_t m_idx, size_t k, size_t mr, size_t kr); /** * @brief Function to return the memory required for storing the quantized and packed LHS matrix @@ -47,7 +45,7 @@ size_t kai_get_lhs_packed_offset_lhs_quant_pack_qa8dxP_f32(size_t m_idx, size_t * * return the size in bytes to the packed LHS matrix */ -size_t kai_get_lhs_packed_size_lhs_quant_pack_qa8dxP_f32(size_t m, size_t k, size_t mr, size_t kr); +size_t kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(size_t m, size_t k, size_t mr, size_t kr); /** * @brief Micro-kernel to quantize and pack the LHS matrix @@ -56,10 +54,10 @@ size_t kai_get_lhs_packed_size_lhs_quant_pack_qa8dxP_f32(size_t m, size_t k, siz * @param[in] k The number of channels. The common dimension of LHS & RHS. It must be multiple of 8. * @param[in] lhs LHS of the vector-by-matrix. * @param[in] lhs_stride Stride in bytes between two rows of LHS. - * @param[out] lhs_p The quantized and packed LHS matrix. + * @param[out] lhs_packed The quantized and packed LHS matrix. 
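The packed-LHS size returned by `kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32()` follows from the layout used in this micro-kernel: `roundup(m, mr)` rows, each holding `k` int8 values plus one float multiplier and one int32 offset. A quick compile-time check of that arithmetic (`lhs_packed_bytes` is an illustrative helper, not part of the library):

```cpp
#include <cstddef>
#include <cstdint>

// Mirrors the formula in kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32();
// the library function remains the authoritative source of the size.
constexpr size_t lhs_packed_bytes(size_t m, size_t k, size_t mr) {
    const size_t m_roundup = ((m + mr - 1) / mr) * mr;  // kai_roundup(m, mr)
    // Per packed row: k int8 values + one f32 multiplier + one i32 offset.
    return m_roundup * (k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t));
}

// With the example's shapes (m = 17, k = 64) and mr = 4:
// roundup(17, 4) = 20 rows, and 20 * (64 + 4 + 4) = 1440 bytes.
static_assert(lhs_packed_bytes(17, 64, 4) == 1440, "matches the formula above");
```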
*/ -void kai_run_lhs_quant_pack_qa8dxP_f32( - size_t m, size_t k, size_t mr, size_t kr, size_t sr, const float* lhs, size_t lhs_stride, void* lhs_p); +void kai_run_lhs_quant_pack_qai8dxp_f32( + size_t m, size_t k, size_t mr, size_t kr, size_t sr, const float* lhs, size_t lhs_stride, void* lhs_packed); #ifdef __cplusplus } diff --git a/src/matmul/kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0.c b/src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.c similarity index 88% rename from src/matmul/kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0.c rename to src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.c index eb42ca11..000771fd 100644 --- a/src/matmul/kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0.c +++ b/src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.c @@ -3,7 +3,7 @@ // // SPDX-License-Identifier: Apache-2.0 // -#include "kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0.h" +#include "kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.h" #include "../kai_common.h" @@ -19,11 +19,11 @@ inline static int8_t kai_int4_sign_extend(int8_t x) { return (x ^ 0x8) - 8; } -size_t kai_get_rhs_offset_rhs_pack_nxk_qs4cxP_qs4cxS1S0(size_t n_idx, size_t rhs_stride) { +size_t kai_get_rhs_offset_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0(size_t n_idx, size_t rhs_stride) { return n_idx * rhs_stride; } -size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qs4cxP_qs4cxS1S0(size_t n_idx, size_t k, size_t nr, size_t kr) { +size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0(size_t n_idx, size_t k, size_t nr, size_t kr) { KAI_ASSERT((k % 2) == 0); KAI_ASSERT((k % kr) == 0); KAI_ASSERT((n_idx % nr) == 0); @@ -35,7 +35,7 @@ size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qs4cxP_qs4cxS1S0(size_t n_idx, siz return (n_idx / nr) * rhs_packed_stride; } -size_t kai_get_rhs_packed_size_rhs_pack_nxk_qs4cxP_qs4cxS1S0(size_t n, size_t k, size_t nr, size_t kr) { +size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0(size_t n, size_t k, size_t nr, size_t kr) { KAI_ASSERT((k % 2) == 0); KAI_ASSERT((k % kr) == 0); KAI_ASSERT((n % nr) == 0); @@ -49,10 +49,11 @@ size_t kai_get_rhs_packed_size_rhs_pack_nxk_qs4cxP_qs4cxS1S0(size_t n, size_t k, return num_rows * rhs_packed_stride; } -void kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0( +void kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0( size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const int32_t* bias, - const float* scale, void* rhs_p, size_t extra_bytes, - const struct kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0_params* params) { + const float* scale, void* rhs_packed, size_t extra_bytes, + const struct kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0_params* params) { + // Temporary asserts KAI_ASSERT(num_groups == 1); KAI_ASSERT((k % 2) == 0); @@ -65,7 +66,7 @@ void kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0( KAI_ASSERT(kr >= 1 && kr <= 16); KAI_ASSERT(rhs != NULL); KAI_ASSERT(scale != NULL); - KAI_ASSERT(rhs_p != NULL); + KAI_ASSERT(rhs_packed != NULL); KAI_ASSERT(params != NULL); KAI_ASSERT(params->rhs_zero_point == 8); KAI_ASSERT(params->lhs_zero_point == 1); @@ -79,7 +80,7 @@ void kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0( for (size_t y = 0; y < n; y += nr) { const uint8_t* src_row = (const uint8_t*)rhs + y * rhs_stride; - uint8_t* dst_row = (uint8_t*)rhs_p + (y / nr) * rhs_packed_stride; + uint8_t* dst_row = (uint8_t*)rhs_packed + (y / nr) * rhs_packed_stride; int32_t* sums = (int32_t*)(dst_row + nr * (k / 2)); diff --git a/src/matmul/kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0.h b/src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.h similarity index 79% rename from src/matmul/kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0.h rename to 
src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.h index 95b70563..3e3257db 100644 --- a/src/matmul/kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0.h +++ b/src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.h @@ -5,16 +5,14 @@ // #pragma once +#include +#include + #ifdef __cplusplus extern "C" { -#else -#include #endif -#include -#include - -struct kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0_params { +struct kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0_params { int8_t lhs_zero_point; uint8_t rhs_zero_point; }; @@ -30,7 +28,7 @@ struct kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0_params { * * return the offset in bytes to the RHS matrix (not packed) */ -size_t kai_get_rhs_offset_rhs_pack_nxk_qs4cxP_qs4cxS1S0(size_t n_idx, size_t rhs_stride); +size_t kai_get_rhs_offset_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0(size_t n_idx, size_t rhs_stride); /** * @brief Function to calculate the offset in bytes for the packed RHS matrix, @@ -43,7 +41,7 @@ size_t kai_get_rhs_offset_rhs_pack_nxk_qs4cxP_qs4cxS1S0(size_t n_idx, size_t rhs * * return the offset in bytes to the packed RHS matrix */ -size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qs4cxP_qs4cxS1S0(size_t n_idx, size_t k, size_t nr, size_t kr); +size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0(size_t n_idx, size_t k, size_t nr, size_t kr); /** * @brief Function to return the memory required for storing the quantized and packed RHS matrix @@ -55,7 +53,7 @@ size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qs4cxP_qs4cxS1S0(size_t n_idx, siz * * return the size in bytes to the packed RHS matrix */ -size_t kai_get_rhs_packed_size_rhs_pack_nxk_qs4cxP_qs4cxS1S0(size_t n, size_t k, size_t nr, size_t kr); +size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0(size_t n, size_t k, size_t nr, size_t kr); /** * @brief Micro-kernel to quantize and pack the RHS matrix. @@ -69,19 +67,20 @@ size_t kai_get_rhs_packed_size_rhs_pack_nxk_qs4cxP_qs4cxS1S0(size_t n, size_t k, * @param[in] k The common dimension between the LHS and RHS matrix (K). * @param[in] nr The number of columns written by the matmul micro-kernel. * @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. - * @param[in] sr The number of kr splits. It can be 1 (not splits) up to kr. + * @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. * However, kr must be multiple of sr. * @param[in] rhs The RHS matrix containing the 4-bit values. - * Size in bytes is expected to be: n * k * (sizeof(uint8_t) / 2). + * Size in bytes is expected to be at least (n * k) / 2, as each byte stores two 4-bit values. * @param[in] bias The biases. * @param[in] scale The scale for each output channel. - * @param[out] rhs_p The quantized and packed RHS matrix. - * @param[in] extra_bytes Extra bytes to append to the end of each row of the quantized and packed RHS matrix. + * @param[out] rhs_packed The quantized and packed RHS matrix. + * @param[in] extra_bytes Extra bytes to append to the end of each row of the quantized and packed RHS matrix. * @param[in] params Parameters for the function.
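For the packed RHS, the matmul micro-kernels in this patch use the stride `nr * ((k / 2) + sizeof(float) + sizeof(int32_t))`: each block of `nr` output channels stores `k / 2` bytes of packed 4-bit weights per channel plus a per-channel float scale and int32 row sum. A sketch of that arithmetic for the shapes used in the example (`rhs_packed_bytes` is illustrative only; the authoritative value comes from `kai_get_rhs_packed_size_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0()`):

```cpp
#include <cstddef>
#include <cstdint>

constexpr size_t rhs_packed_bytes(size_t n, size_t k, size_t nr) {
    // Per block of nr output channels: k/2 bytes of 4-bit weights per channel,
    // plus one i32 row sum and one f32 scale per channel.
    const size_t stride = nr * ((k / 2) + sizeof(int32_t) + sizeof(float));
    return (n / nr) * stride;  // n must be a multiple of nr (asserted in the .c file)
}

// With the example's shapes (n = 32, k = 64) and nr = 4:
// 8 blocks * 4 * (32 + 8) = 1280 bytes.
static_assert(rhs_packed_bytes(32, 64, 4) == 1280, "matches the stride formula");
```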
*/ -void kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0( +void kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0( size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const int32_t* bias, - const float* scale, void* rhs_p, size_t extra_bytes, const struct kai_rhs_pack_nxk_qs4cxP_qs4cxS1S0_params* params); + const float* scale, void* rhs_packed, size_t extra_bytes, + const struct kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0_params* params); #ifdef __cplusplus } diff --git a/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod.c b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod.c similarity index 80% rename from src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod.c rename to src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod.c index 257fea30..44f16198 100644 --- a/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod.c +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod.c @@ -3,13 +3,15 @@ // // SPDX-License-Identifier: Apache-2.0 // -#include "kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod.h" +#include "kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod.h" -#include "../../kai_common.h" +#include "kai_common.h" #include #include +static const size_t kai_m_step = 1; +static const size_t kai_n_step = 4; static const size_t kai_mr = 1; static const size_t kai_nr = 4; static const size_t kai_kr = 16; @@ -20,54 +22,68 @@ static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); -size_t kai_get_mr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod(void) { +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod(void) { + return kai_m_step; +} + +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod(void) { + return kai_n_step; +} + +size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod(void) { return kai_mr; } -size_t kai_get_nr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod(void) { +size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod(void) { return kai_nr; } -size_t kai_get_kr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod(void) { +size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod(void) { return kai_kr; } -size_t kai_get_sr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod(void) { +size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod(void) { return kai_sr; } -size_t kai_get_lhs_packed_offset_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod(size_t m_idx, size_t k) { +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod(size_t m_idx, size_t k) { + KAI_ASSERT((m_idx % kai_m_step) == 0); const size_t lhs_packed_stride = kai_mr * (k + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); - return (m_idx / kai_mr) * lhs_packed_stride; + return (m_idx / kai_m_step) * lhs_packed_stride; } -size_t kai_get_rhs_packed_offset_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod(size_t n_idx, size_t k) { +size_t 
kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod(size_t n_idx, size_t k) { + KAI_ASSERT((n_idx % kai_n_step) == 0); + + // Temporary assert KAI_ASSERT((k % kai_k0) == 0); - KAI_ASSERT((n_idx % kai_nr) == 0); const size_t rhs_packed_stride = kai_nr * ((k / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs); - return (n_idx / kai_nr) * rhs_packed_stride; + return (n_idx / kai_n_step) * rhs_packed_stride; } -size_t kai_get_dst_offset_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod( +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod( size_t m_idx, size_t n_idx, size_t dst_stride) { - KAI_ASSERT((n_idx % kai_nr) == 0); + KAI_ASSERT((m_idx % kai_m_step) == 0); + KAI_ASSERT((n_idx % kai_n_step) == 0); return (n_idx * sizeof(float)) + m_idx * dst_stride; } -size_t kai_get_dst_size_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod(size_t m, size_t n) { +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod(size_t m, size_t n) { + // Temporary assert KAI_ASSERT((n % kai_nr) == 0); return m * n * sizeof(float); } -void kai_run_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod( - size_t m, size_t n, size_t k, const void* restrict lhs_p, const void* restrict rhs_p, float* restrict dst, +void kai_run_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod( + size_t m, size_t n, size_t k, const void* restrict lhs_packed, const void* restrict rhs_packed, float* restrict dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { #if defined(__ARM_FEATURE_DOTPROD) + // Temporary asserts KAI_ASSERT(n % kai_nr == 0); KAI_ASSERT(k % kai_k0 == 0); KAI_ASSERT(dst_stride_col == sizeof(float)); @@ -83,11 +99,11 @@ void kai_run_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod( const int8x16_t nibble_mask = vdupq_n_s8(0xF0); - const uint8_t* lhs_ptr_start = lhs_p; + const uint8_t* lhs_ptr_start = lhs_packed; for (size_t row_idx = 0; row_idx < num_rows; row_idx += kai_mr) { - const uint8_t* rhs_ptr = rhs_p; + const uint8_t* rhs_ptr = rhs_packed; for (size_t col_idx = 0; col_idx < num_cols; col_idx += kai_nr) { const uint8_t* lhs_ptr = lhs_ptr_start; @@ -189,7 +205,7 @@ void kai_run_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod( main_acc = vmulq_f32(main_acc, lhs_scale); - // Clip (min-max) operation + // clamp (min-max) operation const float32x4_t vmin_f32 = vdupq_n_f32(scalar_min); const float32x4_t vmax_f32 = vdupq_n_f32(scalar_max); @@ -205,8 +221,8 @@ void kai_run_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod( KAI_UNUSED(m); KAI_UNUSED(n); KAI_UNUSED(k); - KAI_UNUSED(lhs_p); - KAI_UNUSED(rhs_p); + KAI_UNUSED(lhs_packed); + KAI_UNUSED(rhs_packed); KAI_UNUSED(dst); KAI_UNUSED(dst_stride_row); KAI_UNUSED(dst_stride_col); diff --git a/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod.h b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod.h similarity index 58% rename from src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod.h rename to src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod.h index 4a2c3e43..291da6a9 100644 --- a/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod.h +++ 
b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod.h @@ -5,46 +5,65 @@ // #pragma once -#ifdef __cplusplus -extern "C" { -#else +#ifndef __cplusplus #include #endif - #include #include +#ifdef __cplusplus +extern "C" { +#endif + +/** + * @brief Function to get the m step value. + * The micro-kernel can process any M values. However, the starting M index to + * be processed must be a multiple of m step. + * + * @return the m step value + */ +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod(void); + +/** + * @brief Function to get the n step value. + * The micro-kernel can process any N values. However, the starting N index to + * be processed must be a multiple of n step. + * + * @return the n step + */ +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod(void); + /** * @brief Function to get the mr value, which must be used to pack the LHS matrix with * the @ref kai_run_lhs_quant_pack_qa8dsP_f32 function * * @return the mr value */ -size_t kai_get_mr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod(void); +size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod(void); /** * @brief Function to get the nr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 function + * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 function * * @return the nr value */ -size_t kai_get_nr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod(void); +size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod(void); /** * @brief Function to get the kr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 function + * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 function * * @return the kr value */ -size_t kai_get_kr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod(void); +size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod(void); /** - * @brief Function to get the nr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 function + * @brief Function to get the sr value, which must be used to pack the RHS matrix with + * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 function * * @return the sr value */ -size_t kai_get_sr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod(void); +size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod(void); /** * @brief Function to calculate the offset in bytes for the packed LHS matrix, @@ -57,7 +76,7 @@ size_t kai_get_sr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod(void); * * return the offset in bytes to the packed LHS matrix */ -size_t kai_get_lhs_packed_offset_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod(size_t m_idx, size_t k); +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod(size_t m_idx, size_t k); /** * @brief Function to calculate the offset in bytes for the packed RHS matrix, @@ -68,7 +87,7 @@ size_t kai_get_lhs_packed_offset_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon * * return the offset in bytes to the packed RHS matrix */ -size_t kai_get_rhs_packed_offset_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod(size_t n_idx, size_t k); +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod(size_t n_idx, size_t k); /** * @brief Function to calculate the offset in bytes for the 
DST matrix @@ -79,7 +98,7 @@ size_t kai_get_rhs_packed_offset_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon * * return the DST offset in bytes */ -size_t kai_get_dst_offset_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod( +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod( size_t m_idx, size_t n_idx, size_t dst_stride); /** @@ -88,13 +107,13 @@ size_t kai_get_dst_offset_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotpro * @param[in] m Number of rows in the destination (DST) matrix * @param[in] n Number of columns in the destination (DST) matrix */ -size_t kai_get_dst_size_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod(size_t m, size_t n); +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod(size_t m, size_t n); /** - * @brief Function to run the matrix multiplication (matmul) micro-kernel followed by a clip (min-max) operation. + * @brief Function to run the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation. * - * LHS matrix: 8-bit quantized asymmeitric per-row (qa8dx) and packed - * RHS matrix: 4-bit quantized symmetric per-channel (qs4cx) and packed. + * LHS matrix: Signed 8-bit quantized asymmetric per-row (qai8dx) and packed + * RHS matrix: Unsigned 4-bit quantized symmetric per-channel (qsu4cx) and packed. * Output tile: (rows x cols) = 1 x 4 * Accumulation performed in a single for loop: 64 * Instruction used: dotprod * * @param[in] m The number of output rows written. * @param[in] n The number of output columns written. It must be a multiple of 4. * @param[in] k The number of channels. The common dimension of LHS & RHS. It must be multiple of 64. - * @param[in] lhs_p The LHS matrix packed. + * @param[in] lhs_packed The LHS matrix packed. * When the activation are dynamically quantized, you can obtain this matrix - * by calling the @ref kai_lhs_quant_pack_qa8dxP1X8_f32 micro-kernel which performs + * by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs * both the dynamic quantization to 8-bit and activation packing in a single step. - * @param[in] rhs_p The RHS matrix packed, which is obtained by calling @ref - * kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 + * @param[in] rhs_packed The RHS matrix packed, which is obtained by calling @ref + * kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 * @param[out] dst Result of the vector-by-matrix * @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. * @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(float) - * @param[in] scalar_min Min value used to clip the final result. - * @param[in] scalar_max Max value used to clip the final result. + * @param[in] scalar_min Min value used to clamp the final result. + * @param[in] scalar_max Max value used to clamp the final result.
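The new `m_step`/`n_step` getters and the offset helpers documented above exist so an integrator can carve the output into independently computed tiles, for example across threads. A sketch of that decomposition under the documented constraints (the loop structure and per-tile invocation are assumptions, not code from this patch):

```cpp
#include <cstddef>
#include <cstdint>

#include "kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod.h"

// Illustrative tiling over the packed operands; assumes lhs_packed/rhs_packed
// were produced with this variant's mr/nr/kr/sr, and that n is a multiple of 4
// as the documentation above requires.
void run_tiled(size_t m, size_t n, size_t k,
               const void* lhs_packed, const void* rhs_packed,
               float* dst, size_t dst_stride_row,
               float scalar_min, float scalar_max) {
    const size_t m_step = kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod();
    const size_t n_step = kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod();

    for (size_t m_idx = 0; m_idx < m; m_idx += m_step) {
        // The kernel accepts any m, but each tile must start at a multiple of m_step.
        const size_t m_tile = (m - m_idx) < m_step ? (m - m_idx) : m_step;
        for (size_t n_idx = 0; n_idx < n; n_idx += n_step) {
            const void* lhs_tile = (const uint8_t*)lhs_packed +
                kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod(m_idx, k);
            const void* rhs_tile = (const uint8_t*)rhs_packed +
                kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod(n_idx, k);
            float* dst_tile = (float*)((uint8_t*)dst +
                kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod(m_idx, n_idx, dst_stride_row));

            kai_run_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod(
                m_tile, n_step, k, lhs_tile, rhs_tile, dst_tile,
                dst_stride_row, sizeof(float), scalar_min, scalar_max);
        }
    }
}
```

Because each tile touches a disjoint region of the destination, the two loops could be distributed over a thread pool without further synchronization.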
*/ -void kai_run_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod( - size_t m, size_t n, size_t k, const void* lhs_p, const void* rhs_p, float* dst, size_t dst_stride_row, +void kai_run_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod( + size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max); #ifdef __cplusplus diff --git a/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod.c b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod.c similarity index 81% rename from src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod.c rename to src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod.c index 01b4c377..9dfaac5e 100644 --- a/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod.c +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod.c @@ -3,13 +3,15 @@ // // SPDX-License-Identifier: Apache-2.0 // -#include "kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod.h" +#include "kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod.h" -#include "../../kai_common.h" +#include "kai_common.h" #include #include +static const size_t kai_m_step = 1; +static const size_t kai_n_step = 8; static const size_t kai_mr = 1; static const size_t kai_nr = 8; static const size_t kai_kr = 16; @@ -20,54 +22,68 @@ static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); -size_t kai_get_mr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod(void) { +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod(void) { + return kai_m_step; +} + +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod(void) { + return kai_n_step; +} + +size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod(void) { return kai_mr; } -size_t kai_get_nr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod(void) { +size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod(void) { return kai_nr; } -size_t kai_get_kr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod(void) { +size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod(void) { return kai_kr; } -size_t kai_get_sr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod(void) { +size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod(void) { return kai_sr; } -size_t kai_get_lhs_packed_offset_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod(size_t m_idx, size_t k) { +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod(size_t m_idx, size_t k) { + KAI_ASSERT((m_idx % kai_m_step) == 0); const size_t lhs_packed_stride = kai_mr * (k + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); - return (m_idx / kai_mr) * lhs_packed_stride; + return (m_idx / kai_m_step) * lhs_packed_stride; } -size_t kai_get_rhs_packed_offset_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod(size_t n_idx, size_t k) { +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod(size_t n_idx, size_t k) { + 
KAI_ASSERT((n_idx % kai_n_step) == 0); + + // Temporary assert KAI_ASSERT((k % kai_k0) == 0); - KAI_ASSERT((n_idx % kai_nr) == 0); const size_t rhs_packed_stride = kai_nr * ((k / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs); - return (n_idx / kai_nr) * rhs_packed_stride; + return (n_idx / kai_n_step) * rhs_packed_stride; } -size_t kai_get_dst_offset_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod( +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod( size_t m_idx, size_t n_idx, size_t dst_stride) { - KAI_ASSERT((n_idx % kai_nr) == 0); + KAI_ASSERT((m_idx % kai_m_step) == 0); + KAI_ASSERT((n_idx % kai_n_step) == 0); return (n_idx * sizeof(float)) + m_idx * dst_stride; } -size_t kai_get_dst_size_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod(size_t m, size_t n) { +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod(size_t m, size_t n) { + // Temporary assert KAI_ASSERT((n % kai_nr) == 0); return m * n * sizeof(float); } -void kai_run_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod( - size_t m, size_t n, size_t k, const void* restrict lhs_p, const void* restrict rhs_p, float* restrict dst, +void kai_run_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod( + size_t m, size_t n, size_t k, const void* restrict lhs_packed, const void* restrict rhs_packed, float* restrict dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { #if defined(__ARM_FEATURE_DOTPROD) + // Temporary asserts KAI_ASSERT(n % kai_nr == 0); KAI_ASSERT(k % kai_k0 == 0); KAI_ASSERT(dst_stride_col == sizeof(float)); @@ -83,10 +99,10 @@ void kai_run_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod( const int8x16_t nibble_mask = vdupq_n_s8(0xF0); - const uint8_t* lhs_ptr_start = lhs_p; + const uint8_t* lhs_ptr_start = lhs_packed; for (size_t row_idx = 0; row_idx < num_rows; row_idx += kai_mr) { - const uint8_t* rhs_ptr = rhs_p; + const uint8_t* rhs_ptr = rhs_packed; for (size_t col_idx = 0; col_idx < num_cols; col_idx += kai_nr) { const uint8_t* lhs_ptr = lhs_ptr_start; @@ -192,7 +208,7 @@ void kai_run_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod( main_acc0 = vmulq_f32(main_acc0, lhs_scale); main_acc1 = vmulq_f32(main_acc1, lhs_scale); - // Clip (min-max) operation + // clamp (min-max) operation const float32x4_t vmin_f32 = vdupq_n_f32(scalar_min); const float32x4_t vmax_f32 = vdupq_n_f32(scalar_max); @@ -212,8 +228,8 @@ void kai_run_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x8x32_neon_dotprod( KAI_UNUSED(m); KAI_UNUSED(n); KAI_UNUSED(k); - KAI_UNUSED(lhs_p); - KAI_UNUSED(rhs_p); + KAI_UNUSED(lhs_packed); + KAI_UNUSED(rhs_packed); KAI_UNUSED(dst); KAI_UNUSED(dst_stride_row); KAI_UNUSED(dst_stride_col); diff --git a/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod.h b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod.h similarity index 58% rename from src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod.h rename to src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod.h index d8524a2b..234a1279 100644 --- a/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod.h +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod.h @@ -5,46 +5,65 @@ // #pragma once -#ifdef 
__cplusplus -extern "C" { -#else +#ifndef __cplusplus #include #endif - #include #include +#ifdef __cplusplus +extern "C" { +#endif + +/** + * @brief Function to get the m step value. + * The micro-kernel can process any M values. However, the starting M index to + * be processed must be a multiple of m step. + * + * @return the m step value + */ +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod(void); + +/** + * @brief Function to get the n step value. + * The micro-kernel can process any N values. However, the starting N index to + * be processed must be a multiple of n step. + * + * @return the n step + */ +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod(void); + /** * @brief Function to get the mr value, which must be used to pack the LHS matrix with * the @ref kai_run_lhs_quant_pack_qa8dsP_f32 function * * @return the mr value */ -size_t kai_get_mr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod(void); +size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod(void); /** * @brief Function to get the nr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 function + * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 function * * @return the nr value */ -size_t kai_get_nr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod(void); +size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod(void); /** * @brief Function to get the kr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 function + * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 function * * @return the kr value */ -size_t kai_get_kr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod(void); +size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod(void); /** - * @brief Function to get the nr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 function + * @brief Function to get the sr value, which must be used to pack the RHS matrix with + * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 function * * @return the sr value */ -size_t kai_get_sr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod(void); +size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod(void); /** * @brief Function to calculate the offset in bytes for the packed LHS matrix, @@ -57,7 +76,7 @@ size_t kai_get_sr_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod(void); * * return the offset in bytes to the packed LHS matrix */ -size_t kai_get_lhs_packed_offset_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod(size_t m_idx, size_t k); +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod(size_t m_idx, size_t k); /** * @brief Function to calculate the offset in bytes for the packed RHS matrix, @@ -68,7 +87,7 @@ size_t kai_get_lhs_packed_offset_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon * * return the offset in bytes to the packed RHS matrix */ -size_t kai_get_rhs_packed_offset_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod(size_t n_idx, size_t k); +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod(size_t n_idx, size_t k); /** * @brief Function to calculate the offset in bytes for the DST matrix @@ -79,7 +98,7 @@ size_t kai_get_rhs_packed_offset_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon * * return the DST offset in bytes */ -size_t 
kai_get_dst_offset_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod( +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod( size_t m_idx, size_t n_idx, size_t dst_stride); /** @@ -88,34 +107,34 @@ size_t kai_get_dst_offset_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotpro * @param[in] m Number of rows in the destination (DST) matrix * @param[in] n Number of columns in the destination (DST) matrix */ -size_t kai_get_dst_size_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod(size_t m, size_t n); +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod(size_t m, size_t n); /** - * @brief Function to run the matrix multiplication (matmul) micro-kernel followed by a clip (min-max) operation. + * @brief Function to run the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation. * - * LHS matrix: 8-bit quantized asymmeitric per-row (qa8dx) and packed - * RHS matrix: 4-bit quantized symmetric per-channel (qs4cx) and packed. - * Output tile: (rows x cols) = 1 x 4 + * LHS matrix: Signed 8-bit quantized asymmetric per-row (qai8dx) and packed + * RHS matrix: Unsigned 4-bit quantized symmetric per-channel (qsu4cx) and packed. + * Output tile: (rows x cols) = 1 x 8 * Accumulation performed in a single for loop: 64 * Instruction used: dotprod * * @param[in] m The number of output rows written. * @param[in] n The number of output columns written. It must be a multiple of 4. * @param[in] k The number of channels. The common dimension of LHS & RHS. It must be multiple of 64. - * @param[in] lhs_p The LHS matrix packed. + * @param[in] lhs_packed The LHS matrix packed. * When the activation are dynamically quantized, you can obtain this matrix - * by calling the @ref kai_lhs_quant_pack_qa8dxP1X8_f32 micro-kernel which performs + * by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs * both the dynamic quantization to 8-bit and activation packing in a single step. - * @param[in] rhs_p The RHS matrix packed, which is obtained by calling @ref - * kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 + * @param[in] rhs_packed The RHS matrix packed, which is obtained by calling @ref + * kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 * @param[out] dst Result of the vector-by-matrix * @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. * @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(float) - * @param[in] scalar_min Min value used to clip the final result. - * @param[in] scalar_max Max value used to clip the final result. + * @param[in] scalar_min Min value used to clamp the final result. + * @param[in] scalar_max Max value used to clamp the final result.
*/ -void kai_run_matmul_clip_f32_qa8dxP1X8_qs4cxP4X8_1x4x64_neon_dotprod( - size_t m, size_t n, size_t k, const void* lhs_p, const void* rhs_p, float* dst, size_t dst_stride_row, +void kai_run_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod( + size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max); #ifdef __cplusplus diff --git a/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm.c b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.c similarity index 78% rename from src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm.c rename to src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.c index b91f0441..1a3359c4 100644 --- a/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm.c +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.c @@ -3,13 +3,15 @@ // // SPDX-License-Identifier: Apache-2.0 // -#include "kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm.h" +#include "kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.h" -#include "../../kai_common.h" +#include "kai_common.h" #include #include +static const size_t kai_m_step = 4; +static const size_t kai_n_step = 4; static const size_t kai_mr = 4; static const size_t kai_nr = 4; static const size_t kai_kr = 16; @@ -20,56 +22,68 @@ static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); -size_t kai_get_mr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm(void) { +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(void) { + return kai_m_step; +} + +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(void) { + return kai_n_step; +} + +size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(void) { return kai_mr; } -size_t kai_get_nr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm(void) { +size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(void) { return kai_nr; } -size_t kai_get_kr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm(void) { +size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(void) { return kai_kr; } -size_t kai_get_sr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm(void) { +size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(void) { return kai_sr; } -size_t kai_get_lhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm(size_t m_idx, size_t k) { - KAI_ASSERT((m_idx % kai_mr) == 0); +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(size_t m_idx, size_t k) { + KAI_ASSERT((m_idx % kai_m_step) == 0); const size_t lhs_packed_stride = kai_mr * (k + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); - return (m_idx / kai_mr) * lhs_packed_stride; + return (m_idx / kai_m_step) * lhs_packed_stride; } -size_t kai_get_rhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm(size_t n_idx, size_t k) { +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(size_t n_idx, size_t k) { + KAI_ASSERT((n_idx % kai_n_step) == 
0); + + // Temporary assert KAI_ASSERT((k % kai_k0) == 0); - KAI_ASSERT((n_idx % kai_nr) == 0); const size_t rhs_packed_stride = kai_nr * ((k / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs); - return (n_idx / kai_nr) * rhs_packed_stride; + return (n_idx / kai_n_step) * rhs_packed_stride; } -size_t -kai_get_dst_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm(size_t m_idx, size_t n_idx, size_t dst_stride) { - KAI_ASSERT((m_idx % kai_mr) == 0); - KAI_ASSERT((n_idx % kai_nr) == 0); +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSERT((m_idx % kai_m_step) == 0); + KAI_ASSERT((n_idx % kai_n_step) == 0); return (n_idx * sizeof(float)) + m_idx * dst_stride; } -size_t kai_get_dst_size_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm(size_t m, size_t n) { +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(size_t m, size_t n) { + // Temporary assert KAI_ASSERT((n % kai_nr) == 0); return m * n * sizeof(float); } -void kai_run_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm( - size_t m, size_t n, size_t k, const void* lhs_p, const void* rhs_p, float* dst, size_t dst_stride_row, +void kai_run_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm( + size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { #if defined(__ARM_FEATURE_MATMUL_INT8) + // Temporary asserts KAI_ASSERT(n % kai_nr == 0); KAI_ASSERT(k % kai_k0 == 0); KAI_ASSERT(dst_stride_col == sizeof(float)); @@ -89,13 +103,13 @@ void kai_run_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm( "madd x26, %x[num_blocks], x26, x20\n" "cbz x25, 5f\n" "1:" // Row loop - "mov x24, %x[rhs_p]\n" + "mov x24, %x[rhs_packed]\n" "mov x23, %x[n]\n" "add x22, %x[dst], %x[dst_stride_row], LSL #2\n" "2:" // Column loop "movi v3.4s, #0x0\n" "movi v2.4s, #0x0\n" - "mov x21, %x[lhs_p]\n" + "mov x21, %x[lhs_packed]\n" "mov x20, %x[num_blocks]\n" "movi v1.4s, #0x0\n" "movi v0.4s, #0x0\n" @@ -195,13 +209,13 @@ void kai_run_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm( "add %x[dst], %x[dst], #0x10\n" "bne 2b\n" "subs x25, x25, #0x4\n" - "add %x[lhs_p], %x[lhs_p], x26\n" + "add %x[lhs_packed], %x[lhs_packed], x26\n" "mov %x[dst], x22\n" "bgt 1b\n" "5:" // Row loop skip - : [lhs_p] "+&r"(lhs_p), [dst] "+&r"(dst) - : [rhs_p] "r"(rhs_p), [clamp_vals] "r"(clamp_vals), [m] "r"(m), [num_blocks] "r"(num_blocks), - [dst_stride_row] "r"(dst_stride_row), [n] "r"(n) + : [lhs_packed] "+&r"(lhs_packed), [dst] "+&r"(dst) + : [rhs_packed] "r"(rhs_packed), [clamp_vals] "r"(clamp_vals), [m] "r"(m), + [num_blocks] "r"(num_blocks), [dst_stride_row] "r"(dst_stride_row), [n] "r"(n) : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23", "x24", "x25", "x26"); @@ -210,8 +224,8 @@ void kai_run_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm( KAI_UNUSED(m); KAI_UNUSED(n); KAI_UNUSED(k); - KAI_UNUSED(lhs_p); - KAI_UNUSED(rhs_p); + KAI_UNUSED(lhs_packed); + KAI_UNUSED(rhs_packed); KAI_UNUSED(dst); KAI_UNUSED(dst_stride_row); KAI_UNUSED(dst_stride_col); diff --git a/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm.h b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.h 
similarity index 58% rename from src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm.h rename to src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.h index b4f965fb..387c25a6 100644 --- a/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm.h +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.h @@ -5,46 +5,65 @@ // #pragma once -#ifdef __cplusplus -extern "C" { -#else +#ifndef __cplusplus #include #endif - #include #include +#ifdef __cplusplus +extern "C" { +#endif + +/** + * @brief Function to get the m step value. + * The micro-kernel can process any M values. However, the starting M index to + * be processed must be a multiple of m step. + * + * @return the m step value + */ +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(void); + +/** + * @brief Function to get the n step value. + * The micro-kernel can process any N values. However, the starting N index to + * be processed must be a multiple of n step. + * + * @return the n step + */ +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(void); + /** * @brief Function to get the mr value, which must be used to pack the LHS matrix with * the @ref kai_run_lhs_quant_pack_qa8dsP_f32 function * * @return the mr value */ -size_t kai_get_mr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm(void); +size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(void); /** * @brief Function to get the nr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 function + * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 function * * @return the nr value */ -size_t kai_get_nr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm(void); +size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(void); /** * @brief Function to get the kr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 function + * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 function * * @return the kr value */ -size_t kai_get_kr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm(void); +size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(void); /** - * @brief Function to get the nr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 function + * @brief Function to get the sr value, which must be used to pack the RHS matrix with + * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 function * * @return the sr value */ -size_t kai_get_sr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm(void); +size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(void); /** * @brief Function to calculate the offset in bytes for the packed LHS matrix, @@ -57,7 +76,7 @@ size_t kai_get_sr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm(void); * * return the offset in bytes to the packed LHS matrix */ -size_t kai_get_lhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm(size_t m_idx, size_t k); +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(size_t m_idx, size_t k); /** * @brief Function to calculate the offset in bytes for the packed RHS matrix, @@ -68,7 +87,7 @@ size_t kai_get_lhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon * * 
return the offset in bytes to the packed RHS matrix */ -size_t kai_get_rhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm(size_t n_idx, size_t k); +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(size_t n_idx, size_t k); /** * @brief Function to calculate the offset in bytes for the DST matrix @@ -79,8 +98,8 @@ size_t kai_get_rhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon * * return the DST offset in bytes */ -size_t -kai_get_dst_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm(size_t m_idx, size_t n_idx, size_t dst_stride); +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm( + size_t m_idx, size_t n_idx, size_t dst_stride); /** * @brief Function to query the size in bytes for the constant workspace. @@ -88,13 +107,13 @@ kai_get_dst_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm(size_t m * @param[in] m Number of rows in the destination (DST) matrix. It must be either 1, 2, 3, or any multiple of 4. * @param[in] n Number of columns in the destination (DST) matrix. It must be a multiple 4. */ -size_t kai_get_dst_size_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm(size_t m, size_t n); +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(size_t m, size_t n); /** - * @brief Function to run the matrix multiplication (matmul) micro-kernel followed by a clip (min-max) operation. + * @brief Function to run the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation. * - * LHS matrix: 8-bit quantized asymmeitric per-row (qa8dx) and packed - * RHS matrix: 4-bit quantized symmetric per-channel (qs4cx) and packed. + * LHS matrix: Signed 8-bit quantized asymmetric per-row (qai8dx) and packed + * RHS matrix: Unsigned 4-bit quantized symmetric per-channel (qsu4cx) and packed. * Output tile: (rows x cols) = 4 x 4 * Accumulation performed in a single for loop: 32 * Instruction used: i8mm @@ -102,20 +121,20 @@ size_t kai_get_dst_size_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm(siz * @param[in] m The number of output rows written. It must be either 1, 2, 3, or any multiple of 4. * @param[in] n The number of output columns written. It must be a multiple of 4. * @param[in] k The number of channels. The common dimension of LHS & RHS. It must be multiple of 32. - * @param[in] lhs_p The LHS matrix packed. + * @param[in] lhs_packed The LHS matrix packed. * When the activation are dynamically quantized, you can obtain this matrix - * by calling the @ref kai_run_lhs_quant_pack_qa8dsP_f32 micro-kernel which performs + * by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs * both the dynamic quantization to 8-bit and activation packing in a single step. - * @param[in] rhs_p The RHS matrix packed, which is obtained by calling @ref - * kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 + * @param[in] rhs_packed The RHS matrix packed, which is obtained by calling @ref + * kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 * @param[out] dst Result of the vector-by-matrix * @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. * @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(float) - * @param[in] scalar_min Min value used to clip the final result. - * @param[in] scalar_max Max value used to clip the final result. + * @param[in] scalar_min Min value used to clamp the final result. 
+ * @param[in] scalar_max Max value used to clamp the final result. */ -void kai_run_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x4x32_neon_i8mm( - size_t m, size_t n, size_t k, const void* lhs_p, const void* rhs_p, float* dst, size_t dst_stride_row, +void kai_run_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm( + size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max); #ifdef __cplusplus diff --git a/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm.c b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm.c similarity index 88% rename from src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm.c rename to src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm.c index 664e4b88..5405b4d1 100644 --- a/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm.c +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm.c @@ -3,13 +3,15 @@ // // SPDX-License-Identifier: Apache-2.0 // -#include "kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm.h" +#include "kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm.h" -#include "../../kai_common.h" +#include "kai_common.h" #include #include +static const size_t kai_m_step = 4; +static const size_t kai_n_step = 8; static const size_t kai_mr = 4; static const size_t kai_nr = 8; static const size_t kai_kr = 16; @@ -20,56 +22,68 @@ static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); -size_t kai_get_mr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm(void) { +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm(void) { + return kai_m_step; +} + +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm(void) { + return kai_n_step; +} + +size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm(void) { return kai_mr; } -size_t kai_get_nr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm(void) { +size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm(void) { return kai_nr; } -size_t kai_get_kr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm(void) { +size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm(void) { return kai_kr; } -size_t kai_get_sr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm(void) { +size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm(void) { return kai_sr; } -size_t kai_get_lhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm(size_t m_idx, size_t k) { - KAI_ASSERT((m_idx % kai_mr) == 0); +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm(size_t m_idx, size_t k) { + KAI_ASSERT((m_idx % kai_m_step) == 0); const size_t lhs_packed_stride = kai_mr * (k + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); - return (m_idx / kai_mr) * lhs_packed_stride; + return (m_idx / kai_m_step) * lhs_packed_stride; } -size_t kai_get_rhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm(size_t n_idx, size_t k) { +size_t 
kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm(size_t n_idx, size_t k) { + KAI_ASSERT((n_idx % kai_n_step) == 0); + + // Temporary assert KAI_ASSERT((k % kai_k0) == 0); - KAI_ASSERT((n_idx % kai_nr) == 0); const size_t rhs_packed_stride = kai_nr * ((k / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs); - return (n_idx / kai_nr) * rhs_packed_stride; + return (n_idx / kai_n_step) * rhs_packed_stride; } -size_t -kai_get_dst_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm(size_t m_idx, size_t n_idx, size_t dst_stride) { - KAI_ASSERT((m_idx % kai_mr) == 0); - KAI_ASSERT((n_idx % kai_nr) == 0); +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSERT((m_idx % kai_m_step) == 0); + KAI_ASSERT((n_idx % kai_n_step) == 0); return (n_idx * sizeof(float)) + m_idx * dst_stride; } -size_t kai_get_dst_size_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm(size_t m, size_t n) { +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm(size_t m, size_t n) { + // Temporary assert KAI_ASSERT((n % kai_nr) == 0); return m * n * sizeof(float); } -void kai_run_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm( - size_t m, size_t n, size_t k, const void* restrict lhs_p, const void* restrict rhs_p, float* restrict dst, +void kai_run_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm( + size_t m, size_t n, size_t k, const void* restrict lhs_packed, const void* restrict rhs_packed, float* restrict dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { #if defined(__ARM_FEATURE_MATMUL_INT8) + // Temporary asserts KAI_ASSERT(n % kai_nr == 0); KAI_ASSERT(k % kai_k0 == 0); KAI_ASSERT(dst_stride_col == sizeof(float)); @@ -85,10 +99,10 @@ void kai_run_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm( const int8x16_t nibble_mask = vdupq_n_s8(0xF0); - const uint8_t* lhs_ptr_start = lhs_p; + const uint8_t* lhs_ptr_start = lhs_packed; for (size_t row_idx = 0; row_idx < num_rows; row_idx += kai_mr) { - const uint8_t* rhs_ptr = rhs_p; + const uint8_t* rhs_ptr = rhs_packed; for (size_t col_idx = 0; col_idx < num_cols; col_idx += kai_nr) { const uint8_t* lhs_ptr = lhs_ptr_start; @@ -266,7 +280,7 @@ void kai_run_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm( main_acc2_4567 = vmulq_laneq_f32(main_acc2_4567, lhs_scale, 2); main_acc3_4567 = vmulq_laneq_f32(main_acc3_4567, lhs_scale, 3); - // Clip (min-max) operation + // clamp (min-max) operation const float32x4_t vmin_f32 = vdupq_n_f32(scalar_min); const float32x4_t vmax_f32 = vdupq_n_f32(scalar_max); @@ -319,8 +333,8 @@ void kai_run_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm( KAI_UNUSED(m); KAI_UNUSED(n); KAI_UNUSED(k); - KAI_UNUSED(lhs_p); - KAI_UNUSED(rhs_p); + KAI_UNUSED(lhs_packed); + KAI_UNUSED(rhs_packed); KAI_UNUSED(dst); KAI_UNUSED(dst_stride_row); KAI_UNUSED(dst_stride_col); diff --git a/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm.h b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm.h similarity index 58% rename from src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm.h rename to src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm.h index fac55cf4..1d82ab86 100644 --- 
a/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm.h +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm.h @@ -5,46 +5,65 @@ // #pragma once -#ifdef __cplusplus -extern "C" { -#else +#ifndef __cplusplus #include #endif - #include #include +#ifdef __cplusplus +extern "C" { +#endif + +/** + * @brief Function to get the m step value. + * The micro-kernel can process any M values. However, the starting M index to + * be processed must be a multiple of m step. + * + * @return the m step value + */ +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm(void); + +/** + * @brief Function to get the n step value. + * The micro-kernel can process any N values. However, the starting N index to + * be processed must be a multiple of n step. + * + * @return the n step + */ +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm(void); + /** * @brief Function to get the mr value, which must be used to pack the LHS matrix with * the @ref kai_run_lhs_quant_pack_qa8dsP_f32 function * * @return the mr value */ -size_t kai_get_mr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm(void); +size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm(void); /** * @brief Function to get the nr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 function + * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 function * * @return the nr value */ -size_t kai_get_nr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm(void); +size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm(void); /** * @brief Function to get the kr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 function + * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 function * * @return the kr value */ -size_t kai_get_kr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm(void); +size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm(void); /** - * @brief Function to get the nr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 function + * @brief Function to get the sr value, which must be used to pack the RHS matrix with + * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 function * * @return the sr value */ -size_t kai_get_sr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm(void); +size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm(void); /** * @brief Function to calculate the offset in bytes for the packed LHS matrix, @@ -57,7 +76,7 @@ size_t kai_get_sr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm(void); * * return the offset in bytes to the packed LHS matrix */ -size_t kai_get_lhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm(size_t m_idx, size_t k); +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm(size_t m_idx, size_t k); /** * @brief Function to calculate the offset in bytes for the packed RHS matrix, @@ -68,7 +87,7 @@ size_t kai_get_lhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon * * return the offset in bytes to the packed RHS matrix */ -size_t kai_get_rhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm(size_t n_idx, size_t k); +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm(size_t n_idx, size_t k); /** * 
@brief Function to calculate the offset in bytes for the DST matrix @@ -79,8 +98,8 @@ size_t kai_get_rhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon * * return the DST offset in bytes */ -size_t -kai_get_dst_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm(size_t m_idx, size_t n_idx, size_t dst_stride); +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm( + size_t m_idx, size_t n_idx, size_t dst_stride); /** * @brief Function to query the size in bytes for the constant workspace. @@ -88,13 +107,13 @@ kai_get_dst_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm(size_t m * @param[in] m Number of rows in the destination (DST) matrix. It must be either 1, 2, 3, or any multiple of 4. * @param[in] n Number of columns in the destination (DST) matrix. It must be a multiple 4. */ -size_t kai_get_dst_size_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm(size_t m, size_t n); +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm(size_t m, size_t n); /** - * @brief Function to run the matrix multiplication (matmul) micro-kernel followed by a clip (min-max) operation. + * @brief Function to run the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation. * - * LHS matrix: 8-bit quantized asymmeitric per-row (qa8dx) and packed - * RHS matrix: 4-bit quantized symmetric per-channel (qs4cx) and packed. + * LHS matrix: Signed 8-bit quantized asymmetric per-row (qai8dx) and packed + * RHS matrix: Unsigned 4-bit quantized symmetric per-channel (qsu4cx) and packed. * Output tile: (rows x cols) = 4 x 8 * Accumulation performed in a single for loop: 32 * Instruction used: i8mm @@ -102,20 +121,20 @@ size_t kai_get_dst_size_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm(siz * @param[in] m The number of output rows written. It must be either 1, 2, 3, or any multiple of 4. * @param[in] n The number of output columns written. It must be a multiple of 4. * @param[in] k The number of channels. The common dimension of LHS & RHS. It must be multiple of 32. - * @param[in] lhs_p The LHS matrix packed. + * @param[in] lhs_packed The LHS matrix packed. * When the activation are dynamically quantized, you can obtain this matrix - * by calling the @ref kai_run_lhs_quant_pack_qa8dsP_f32 micro-kernel which performs + * by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs * both the dynamic quantization to 8-bit and activation packing in a single step. - * @param[in] rhs_p The RHS matrix packed, which is obtained by calling @ref - * kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 + * @param[in] rhs_packed The RHS matrix packed, which is obtained by calling @ref + * kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 * @param[out] dst Result of the vector-by-matrix * @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. * @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(float) - * @param[in] scalar_min Min value used to clip the final result. - * @param[in] scalar_max Max value used to clip the final result. + * @param[in] scalar_min Min value used to clamp the final result. + * @param[in] scalar_max Max value used to clamp the final result. 
*/ -void kai_run_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_4x8x32_neon_i8mm( - size_t m, size_t n, size_t k, const void* lhs_p, const void* rhs_p, float* dst, size_t dst_stride_row, +void kai_run_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm( + size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max); #ifdef __cplusplus diff --git a/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm.c b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.c similarity index 88% rename from src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm.c rename to src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.c index eea059c9..097bd2e9 100644 --- a/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm.c +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.c @@ -3,13 +3,15 @@ // // SPDX-License-Identifier: Apache-2.0 // -#include "kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm.h" +#include "kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.h" -#include "../../kai_common.h" +#include "kai_common.h" #include #include +static const size_t kai_m_step = 8; +static const size_t kai_n_step = 4; static const size_t kai_mr = 4; static const size_t kai_nr = 4; static const size_t kai_kr = 16; @@ -20,56 +22,68 @@ static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); -size_t kai_get_mr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm(void) { +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(void) { + return kai_m_step; +} + +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(void) { + return kai_n_step; +} + +size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(void) { return kai_mr; } -size_t kai_get_nr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm(void) { +size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(void) { return kai_nr; } -size_t kai_get_kr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm(void) { +size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(void) { return kai_kr; } -size_t kai_get_sr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm(void) { +size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(void) { return kai_sr; } -size_t kai_get_lhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm(size_t m_idx, size_t k) { - KAI_ASSERT((m_idx % kai_mr) == 0); +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(size_t m_idx, size_t k) { + KAI_ASSERT((m_idx % kai_m_step) == 0); const size_t lhs_packed_stride = kai_mr * (k + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); - return (m_idx / kai_mr) * lhs_packed_stride; + return (m_idx / kai_m_step) * lhs_packed_stride; } -size_t kai_get_rhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm(size_t n_idx, size_t k) { +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(size_t n_idx, size_t k) { + KAI_ASSERT((n_idx % kai_n_step) == 0); + 
+ // Temporary assert KAI_ASSERT((k % kai_k0) == 0); - KAI_ASSERT((n_idx % kai_nr) == 0); const size_t rhs_packed_stride = kai_nr * ((k / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs); - return (n_idx / kai_nr) * rhs_packed_stride; + return (n_idx / kai_n_step) * rhs_packed_stride; } -size_t -kai_get_dst_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm(size_t m_idx, size_t n_idx, size_t dst_stride) { - KAI_ASSERT((m_idx % kai_mr) == 0); - KAI_ASSERT((n_idx % kai_nr) == 0); +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSERT((m_idx % kai_m_step) == 0); + KAI_ASSERT((n_idx % kai_n_step) == 0); return (n_idx * sizeof(float)) + m_idx * dst_stride; } -size_t kai_get_dst_size_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm(size_t m, size_t n) { +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(size_t m, size_t n) { + // Temporary assert KAI_ASSERT((n % kai_nr) == 0); return m * n * sizeof(float); } -void kai_run_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm( - size_t m, size_t n, size_t k, const void* lhs_p, const void* rhs_p, float* dst, size_t dst_stride_row, +void kai_run_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm( + size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { #if defined(__ARM_FEATURE_MATMUL_INT8) + // Temporary asserts KAI_ASSERT(n % kai_nr == 0); KAI_ASSERT(k % kai_k0 == 0); KAI_ASSERT(dst_stride_col == sizeof(float)); @@ -90,11 +104,11 @@ void kai_run_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm( "madd x26, %x[num_blocks], x26, x20\n" "blt 4f\n" "1:" // Row loop - "mov x24, %x[rhs_p]\n" + "mov x24, %x[rhs_packed]\n" "mov x23, %x[n]\n" "add x22, %x[dst], %x[dst_stride_row], LSL #3\n" "2:" // Column loop - "mov x25, %x[lhs_p]\n" + "mov x25, %x[lhs_packed]\n" "movi v10.4s, #0x0\n" "movi v9.4s, #0x0\n" "mov x21, %x[num_blocks]\n" @@ -260,18 +274,18 @@ void kai_run_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm( "sub x27, x27, #0x8\n" "cmp x27, #0x8\n" "mov %x[dst], x22\n" - "madd %x[lhs_p], x20, x26, %x[lhs_p]\n" + "madd %x[lhs_packed], x20, x26, %x[lhs_packed]\n" "bge 1b\n" "4:" // Row loop skip "cbz x27, 9f\n" "5:" // Row tail: Row loop - "mov x24, %x[rhs_p]\n" + "mov x24, %x[rhs_packed]\n" "mov x23, %x[n]\n" "add x22, %x[dst], %x[dst_stride_row], LSL #2\n" "6:" // Row tail: Column loop "movi v10.4s, #0x0\n" "movi v9.4s, #0x0\n" - "mov x25, %x[lhs_p]\n" + "mov x25, %x[lhs_packed]\n" "mov x20, %x[num_blocks]\n" "movi v8.4s, #0x0\n" "movi v7.4s, #0x0\n" @@ -371,13 +385,13 @@ void kai_run_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm( "add %x[dst], %x[dst], #0x10\n" "bne 6b\n" "subs x27, x27, #0x4\n" - "add %x[lhs_p], %x[lhs_p], x26\n" + "add %x[lhs_packed], %x[lhs_packed], x26\n" "mov %x[dst], x22\n" "bgt 5b\n" "9:" // Row tail: Row loop skip - : [lhs_p] "+&r"(lhs_p), [dst] "+&r"(dst) - : [rhs_p] "r"(rhs_p), [clamp_vals] "r"(clamp_vals), [m] "r"(m), [num_blocks] "r"(num_blocks), - [dst_stride_row] "r"(dst_stride_row), [n] "r"(n) + : [lhs_packed] "+&r"(lhs_packed), [dst] "+&r"(dst) + : [rhs_packed] "r"(rhs_packed), [clamp_vals] "r"(clamp_vals), [m] "r"(m), + [num_blocks] "r"(num_blocks), [dst_stride_row] "r"(dst_stride_row), [n] "r"(n) : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v16", "v17", "v18", "v19", "v20", "v21", 
"v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27"); @@ -386,8 +400,8 @@ void kai_run_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm( KAI_UNUSED(m); KAI_UNUSED(n); KAI_UNUSED(k); - KAI_UNUSED(lhs_p); - KAI_UNUSED(rhs_p); + KAI_UNUSED(lhs_packed); + KAI_UNUSED(rhs_packed); KAI_UNUSED(dst); KAI_UNUSED(dst_stride_row); KAI_UNUSED(dst_stride_col); diff --git a/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm.h b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.h similarity index 58% rename from src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm.h rename to src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.h index bd07c802..4e869247 100644 --- a/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm.h +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.h @@ -5,46 +5,65 @@ // #pragma once -#ifdef __cplusplus -extern "C" { -#else +#ifndef __cplusplus #include #endif - #include #include +#ifdef __cplusplus +extern "C" { +#endif + +/** + * @brief Function to get the m step value. + * The micro-kernel can process any M values. However, the starting M index to + * be processed must be a multiple of m step. + * + * @return the m step value + */ +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(void); + +/** + * @brief Function to get the n step value. + * The micro-kernel can process any N values. However, the starting N index to + * be processed must be a multiple of n step. 
+ * + * @return the n step + */ +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(void); + /** * @brief Function to get the mr value, which must be used to pack the LHS matrix with * the @ref kai_run_lhs_quant_pack_qa8dsP_f32 function * * @return the mr value */ -size_t kai_get_mr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm(void); +size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(void); /** * @brief Function to get the nr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 function + * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 function * * @return the nr value */ -size_t kai_get_nr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm(void); +size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(void); /** * @brief Function to get the kr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 function + * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 function * * @return the kr value */ -size_t kai_get_kr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm(void); +size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(void); /** - * @brief Function to get the nr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 function + * @brief Function to get the sr value, which must be used to pack the RHS matrix with + * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 function * * @return the sr value */ -size_t kai_get_sr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm(void); +size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(void); /** * @brief Function to calculate the offset in bytes for the packed LHS matrix, @@ -57,7 +76,7 @@ size_t kai_get_sr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm(void); * * return the offset in bytes to the packed LHS matrix */ -size_t kai_get_lhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm(size_t m_idx, size_t k); +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(size_t m_idx, size_t k); /** * @brief Function to calculate the offset in bytes for the packed RHS matrix, @@ -68,7 +87,7 @@ size_t kai_get_lhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon * * return the offset in bytes to the packed RHS matrix */ -size_t kai_get_rhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm(size_t n_idx, size_t k); +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(size_t n_idx, size_t k); /** * @brief Function to calculate the offset in bytes for the DST matrix @@ -79,8 +98,8 @@ size_t kai_get_rhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon * * return the DST offset in bytes */ -size_t -kai_get_dst_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm(size_t m_idx, size_t n_idx, size_t dst_stride); +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm( + size_t m_idx, size_t n_idx, size_t dst_stride); /** * @brief Function to query the size in bytes for the constant workspace. @@ -88,13 +107,13 @@ kai_get_dst_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm(size_t m * @param[in] m Number of rows in the destination (DST) matrix. It must be either 1, 2, 3, or any multiple of 4. * @param[in] n Number of columns in the destination (DST) matrix. It must be a multiple 4. 
*/ -size_t kai_get_dst_size_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm(size_t m, size_t n); +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(size_t m, size_t n); /** - * @brief Function to run the matrix multiplication (matmul) micro-kernel followed by a clip (min-max) operation. + * @brief Function to run the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation. * - * LHS matrix: 8-bit quantized asymmeitric per-row (qa8dx) and packed - * RHS matrix: 4-bit quantized symmetric per-channel (qs4cx) and packed. + * LHS matrix: Signed 8-bit quantized asymmetric per-row (qai8dx) and packed + * RHS matrix: Unsigned 4-bit quantized symmetric per-channel (qsu4cx) and packed. * Output tile: (rows x cols) = 8 x 4 * Accumulation performed in a single for loop: 32 * Instruction used: i8mm @@ -102,20 +121,20 @@ size_t kai_get_dst_size_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm(siz * @param[in] m The number of output rows written. It must be either 1, 2, 3, or any multiple of 4. * @param[in] n The number of output columns written. It must be a multiple of 4. * @param[in] k The number of channels. The common dimension of LHS & RHS. It must be multiple of 32. - * @param[in] lhs_p The LHS matrix packed. + * @param[in] lhs_packed The LHS matrix packed. * When the activation are dynamically quantized, you can obtain this matrix - * by calling the @ref kai_run_lhs_quant_pack_qa8dsP_f32 micro-kernel which performs + * by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs * both the dynamic quantization to 8-bit and activation packing in a single step. - * @param[in] rhs_p The RHS matrix packed, which is obtained by calling @ref - * kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 + * @param[in] rhs_packed The RHS matrix packed, which is obtained by calling @ref + * kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 * @param[out] dst Result of the vector-by-matrix * @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. * @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(float) - * @param[in] scalar_min Min value used to clip the final result. - * @param[in] scalar_max Max value used to clip the final result. + * @param[in] scalar_min Min value used to clamp the final result. + * @param[in] scalar_max Max value used to clamp the final result. 
*/ -void kai_run_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x4x32_neon_i8mm( - size_t m, size_t n, size_t k, const void* lhs_p, const void* rhs_p, float* dst, size_t dst_stride_row, +void kai_run_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm( + size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max); #ifdef __cplusplus diff --git a/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm.c b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.c similarity index 92% rename from src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm.c rename to src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.c index fd58ab2b..4ff45dc0 100644 --- a/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm.c +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.c @@ -3,13 +3,15 @@ // // SPDX-License-Identifier: Apache-2.0 // -#include "kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm.h" +#include "kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.h" -#include "../../kai_common.h" +#include "kai_common.h" #include #include +static const size_t kai_m_step = 8; +static const size_t kai_n_step = 8; static const size_t kai_mr = 4; static const size_t kai_nr = 8; static const size_t kai_kr = 16; @@ -20,56 +22,68 @@ static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); -size_t kai_get_mr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm(void) { +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(void) { + return kai_m_step; +} + +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(void) { + return kai_n_step; +} + +size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(void) { return kai_mr; } -size_t kai_get_nr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm(void) { +size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(void) { return kai_nr; } -size_t kai_get_kr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm(void) { +size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(void) { return kai_kr; } -size_t kai_get_sr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm(void) { +size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(void) { return kai_sr; } -size_t kai_get_lhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm(size_t m_idx, size_t k) { - KAI_ASSERT((m_idx % kai_mr) == 0); +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(size_t m_idx, size_t k) { + KAI_ASSERT((m_idx % kai_m_step) == 0); const size_t lhs_packed_stride = kai_mr * (k + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); - return (m_idx / kai_mr) * lhs_packed_stride; + return (m_idx / kai_m_step) * lhs_packed_stride; } -size_t kai_get_rhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm(size_t n_idx, size_t k) { +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(size_t n_idx, size_t k) { + KAI_ASSERT((n_idx % kai_n_step) == 0); + 
+ // Temporary assert KAI_ASSERT((k % kai_k0) == 0); - KAI_ASSERT((n_idx % kai_nr) == 0); const size_t rhs_packed_stride = kai_nr * ((k / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs); - return (n_idx / kai_nr) * rhs_packed_stride; + return (n_idx / kai_n_step) * rhs_packed_stride; } -size_t -kai_get_dst_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm(size_t m_idx, size_t n_idx, size_t dst_stride) { - KAI_ASSERT((m_idx % kai_mr) == 0); - KAI_ASSERT((n_idx % kai_nr) == 0); +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSERT((m_idx % kai_m_step) == 0); + KAI_ASSERT((n_idx % kai_n_step) == 0); return (n_idx * sizeof(float)) + m_idx * dst_stride; } -size_t kai_get_dst_size_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm(size_t m, size_t n) { +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(size_t m, size_t n) { + // Temporary assert KAI_ASSERT((n % kai_nr) == 0); return m * n * sizeof(float); } -void kai_run_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm( - size_t m, size_t n, size_t k, const void* lhs_p, const void* rhs_p, float* dst, size_t dst_stride_row, +void kai_run_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm( + size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { #if defined(__ARM_FEATURE_MATMUL_INT8) + // Temporary asserts KAI_ASSERT(n % kai_nr == 0); KAI_ASSERT(k % kai_k0 == 0); KAI_ASSERT(dst_stride_col == sizeof(float)); @@ -90,11 +104,11 @@ void kai_run_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm( "madd x26, %x[num_blocks], x26, x20\n" "blt 4f\n" "1:" // Row loop - "mov x25, %x[rhs_p]\n" + "mov x25, %x[rhs_packed]\n" "mov x23, %x[n]\n" "add x22, %x[dst], %x[dst_stride_row], LSL #3\n" "2:" // Column loop - "mov x24, %x[lhs_p]\n" + "mov x24, %x[lhs_packed]\n" "movi v8.4s, #0x0\n" "movi v27.4s, #0x0\n" "mov x21, %x[num_blocks]\n" @@ -378,18 +392,18 @@ void kai_run_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm( "sub x27, x27, #0x8\n" "cmp x27, #0x8\n" "mov %x[dst], x22\n" - "madd %x[lhs_p], x20, x26, %x[lhs_p]\n" + "madd %x[lhs_packed], x20, x26, %x[lhs_packed]\n" "bge 1b\n" "4:" // Row loop skip "cbz x27, 9f\n" "5:" // Row tail: Row loop - "mov x23, %x[rhs_p]\n" + "mov x23, %x[rhs_packed]\n" "mov x22, %x[n]\n" "add x21, %x[dst], %x[dst_stride_row], LSL #2\n" "6:" // Row tail: Column loop "movi v8.4s, #0x0\n" "movi v27.4s, #0x0\n" - "mov x24, %x[lhs_p]\n" + "mov x24, %x[lhs_packed]\n" "mov x20, %x[num_blocks]\n" "movi v11.4s, #0x0\n" "movi v6.4s, #0x0\n" @@ -555,13 +569,13 @@ void kai_run_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm( "add %x[dst], %x[dst], #0x20\n" "bne 6b\n" "subs x27, x27, #0x4\n" - "add %x[lhs_p], %x[lhs_p], x26\n" + "add %x[lhs_packed], %x[lhs_packed], x26\n" "mov %x[dst], x21\n" "bgt 5b\n" "9:" // Row tail: Row loop skip - : [lhs_p] "+&r"(lhs_p), [dst] "+&r"(dst) - : [rhs_p] "r"(rhs_p), [clamp_vals] "r"(clamp_vals), [m] "r"(m), [num_blocks] "r"(num_blocks), - [dst_stride_row] "r"(dst_stride_row), [n] "r"(n) + : [lhs_packed] "+&r"(lhs_packed), [dst] "+&r"(dst) + : [rhs_packed] "r"(rhs_packed), [clamp_vals] "r"(clamp_vals), [m] "r"(m), + [num_blocks] "r"(num_blocks), [dst_stride_row] "r"(dst_stride_row), [n] "r"(n) : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", 
"v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23", "x24", "x25", @@ -571,8 +585,8 @@ void kai_run_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm( KAI_UNUSED(m); KAI_UNUSED(n); KAI_UNUSED(k); - KAI_UNUSED(lhs_p); - KAI_UNUSED(rhs_p); + KAI_UNUSED(lhs_packed); + KAI_UNUSED(rhs_packed); KAI_UNUSED(dst); KAI_UNUSED(dst_stride_row); KAI_UNUSED(dst_stride_col); diff --git a/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm.h b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.h similarity index 57% rename from src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm.h rename to src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.h index 88629a02..ffe9ef9c 100644 --- a/src/matmul/matmul_clip_f32_qa8dxP_qs4cxP/kai_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm.h +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.h @@ -5,46 +5,65 @@ // #pragma once -#ifdef __cplusplus -extern "C" { -#else +#ifndef __cplusplus #include #endif - #include #include +#ifdef __cplusplus +extern "C" { +#endif + +/** + * @brief Function to get the m step value. + * The micro-kernel can process any M values. However, the starting M index to + * be processed must be a multiple of m step. + * + * @return the m step value + */ +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(void); + +/** + * @brief Function to get the n step value. + * The micro-kernel can process any N values. However, the starting N index to + * be processed must be a multiple of n step. 
+ * + * @return the n step + */ +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(void); + /** * @brief Function to get the mr value, which must be used to pack the LHS matrix with * the @ref kai_run_lhs_quant_pack_qa8dsP_f32 function * * @return the mr value */ -size_t kai_get_mr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm(void); +size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(void); /** * @brief Function to get the nr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 function + * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 function * * @return the nr value */ -size_t kai_get_nr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm(void); +size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(void); /** * @brief Function to get the kr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 function + * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 function * * @return the kr value */ -size_t kai_get_kr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm(void); +size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(void); /** - * @brief Function to get the nr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0 function + * @brief Function to get the sr value, which must be used to pack the RHS matrix with + * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 function * * @return the sr value */ -size_t kai_get_sr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm(void); +size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(void); /** * @brief Function to calculate the offset in bytes for the packed LHS matrix, @@ -57,7 +76,7 @@ size_t kai_get_sr_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm(void); * * return the offset in bytes to the packed LHS matrix */ -size_t kai_get_lhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm(size_t m_idx, size_t k); +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(size_t m_idx, size_t k); /** * @brief Function to calculate the offset in bytes for the packed RHS matrix, @@ -68,7 +87,7 @@ size_t kai_get_lhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon * * return the offset in bytes to the packed RHS matrix */ -size_t kai_get_rhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm(size_t n_idx, size_t k); +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(size_t n_idx, size_t k); /** * @brief Function to calculate the offset in bytes for the DST matrix @@ -79,8 +98,8 @@ size_t kai_get_rhs_packed_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon * * return the DST offset in bytes */ -size_t -kai_get_dst_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm(size_t m_idx, size_t n_idx, size_t dst_stride); +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm( + size_t m_idx, size_t n_idx, size_t dst_stride); /** * @brief Function to query the size in bytes for the constant workspace. @@ -88,34 +107,34 @@ kai_get_dst_offset_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm(size_t m * @param[in] m Number of rows in the destination (DST) matrix. It must be either 1, 2, 3, or any multiple of 4. * @param[in] n Number of columns in the destination (DST) matrix. It must be a multiple 4. 
 */
-size_t kai_get_dst_size_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm(size_t m, size_t n);
+size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(size_t m, size_t n);
 
 /**
- * @brief Function to run the matrix multiplication (matmul) micro-kernel followed by a clip (min-max) operation.
+ * @brief Function to run the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation.
  *
- * LHS matrix: 8-bit quantized asymmeitric per-row (qa8dx) and packed
- * RHS matrix: 4-bit quantized symmetric per-channel (qs4cx) and packed.
- * Output tile: (rows x cols) = 8 x 4
+ * LHS matrix: Signed 8-bit quantized asymmetric per-row (qai8dx) and packed
+ * RHS matrix: Unsigned 4-bit quantized symmetric per-channel (qsu4cx) and packed.
+ * Output tile: (rows x cols) = 8 x 8
  * Accumulation performed in a single for loop: 32
  * Instruction used: i8mm
  *
  * @param[in] m      The number of output rows written. It must be either 1, 2, 3, or any multiple of 4.
  * @param[in] n      The number of output columns written. It must be a multiple of 4.
  * @param[in] k      The number of channels. The common dimension of LHS & RHS. It must be multiple of 32.
- * @param[in] lhs_p  The LHS matrix packed.
+ * @param[in] lhs_packed The LHS matrix packed.
  *                   When the activation are dynamically quantized, you can obtain this matrix
- *                   by calling the @ref kai_run_lhs_quant_pack_qa8dsP_f32 micro-kernel which performs
+ *                   by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs
  *                   both the dynamic quantization to 8-bit and activation packing in a single step.
- * @param[in] rhs_p  The RHS matrix packed, which is obtained by calling @ref
- *                   kai_run_rhs_pack_nxk_qs4cxP_qs4cxS1S0
+ * @param[in] rhs_packed The RHS matrix packed, which is obtained by calling @ref
+ *                   kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0
  * @param[out] dst   Result of the vector-by-matrix
  * @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix.
  * @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(float)
- * @param[in] scalar_min Min value used to clip the final result.
- * @param[in] scalar_max Max value used to clip the final result.
+ * @param[in] scalar_min Min value used to clamp the final result.
+ * @param[in] scalar_max Max value used to clamp the final result.
 */
-void kai_run_matmul_clip_f32_qa8dxP4X8_qs4cxP4X8_8x8x32_neon_i8mm(
-    size_t m, size_t n, size_t k, const void* lhs_p, const void* rhs_p, float* dst, size_t dst_stride_row,
+void kai_run_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(
+    size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row,
     size_t dst_stride_col, float scalar_min, float scalar_max);
 
 #ifdef __cplusplus
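The m step and n step getters introduced above exist so that a caller can split the output across several micro-kernel invocations, for example one slice of rows per thread. A minimal sketch of that pattern for the 8x8x32 variant follows; it is illustrative only (the helper name `run_row_slices` is not part of the API), and it assumes `m` is a multiple of the m step, `n` and `k` already satisfy the variant's constraints, and the packed buffers were prepared with the matching mr/nr/kr/sr values:

```c
#include <float.h>
#include <stddef.h>
#include <stdint.h>

#include "kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.h"

// Illustrative sketch: run the 8x8x32 i8mm variant one row slice at a time.
// Each slice starts at a multiple of m_step, as required by the offset getters.
static void run_row_slices(
    size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst,
    size_t dst_stride_row) {
    const size_t m_step = kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm();

    for (size_t m_idx = 0; m_idx < m; m_idx += m_step) {
        // Byte offsets into the packed LHS and the destination for this slice.
        const size_t lhs_offset =
            kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(m_idx, k);
        const size_t dst_offset =
            kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(m_idx, 0, dst_stride_row);

        kai_run_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(
            m_step, n, k, (const void*)((const uint8_t*)lhs_packed + lhs_offset), rhs_packed,
            (float*)((uint8_t*)dst + dst_offset), dst_stride_row, sizeof(float), -FLT_MAX, FLT_MAX);
    }
}
```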
diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp_qsu4cxp_interface.h b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp_qsu4cxp_interface.h
new file mode 100644
index 00000000..dbf08a0a
--- /dev/null
+++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp_qsu4cxp_interface.h
@@ -0,0 +1,51 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#pragma once
+
+#include <stddef.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// All micro-kernels variants of the same type share the same interfaces
+// In this case, the micro-kernel type is: matmul_clamp_f32_qai8dxp_qsu4cxp
+
+/** Micro-kernel helper functions ("get" methods) */
+typedef size_t (*kai_get_m_step_func_t)(void);
+typedef size_t (*kai_get_n_step_func_t)(void);
+typedef size_t (*kai_get_mr_func_t)(void);
+typedef size_t (*kai_get_nr_func_t)(void);
+typedef size_t (*kai_get_kr_func_t)(void);
+typedef size_t (*kai_get_sr_func_t)(void);
+typedef size_t (*kai_get_lhs_packed_offset_func_t)(size_t m_idx, size_t k);
+typedef size_t (*kai_get_rhs_packed_offset_func_t)(size_t n_idx, size_t k);
+typedef size_t (*kai_get_dst_offset_func_t)(size_t m_idx, size_t n_idx, size_t dst_stride);
+typedef size_t (*kai_get_dst_size_func_t)(size_t m, size_t n);
+
+/** Micro-kernel core function ("run" method) */
+typedef void (*kai_run_matmul_func_t)(
+    size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row,
+    size_t dst_stride_col, float scalar_min, float scalar_max);
+
+/** Micro-kernel interface */
+struct kai_matmul_clamp_f32_qai8dxp_qsu4cxp_ukernel {
+    kai_get_m_step_func_t get_m_step;
+    kai_get_n_step_func_t get_n_step;
+    kai_get_mr_func_t get_mr;
+    kai_get_nr_func_t get_nr;
+    kai_get_kr_func_t get_kr;
+    kai_get_sr_func_t get_sr;
+    kai_get_lhs_packed_offset_func_t get_lhs_packed_offset;
+    kai_get_rhs_packed_offset_func_t get_rhs_packed_offset;
+    kai_get_dst_offset_func_t get_dst_offset;
+    kai_get_dst_size_func_t get_dst_size;
+    kai_run_matmul_func_t run_matmul;
+};
+
+#ifdef __cplusplus
+}
+#endif
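The interface header above lets a runtime treat every variant of this micro-kernel type uniformly, querying the packing parameters and dispatching without knowing the concrete variant at compile time. A sketch of how one variant could be bound to the struct (the variable name `ukernel_8x8x32_i8mm` is illustrative; this mirrors the `ukernel_variants` table used in the example program later in this series):

```c
#include "kai_matmul_clamp_f32_qai8dxp_qsu4cxp_interface.h"
#include "kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.h"

// Illustrative sketch: bind the 8x8x32 i8mm variant to the common interface.
static const struct kai_matmul_clamp_f32_qai8dxp_qsu4cxp_ukernel ukernel_8x8x32_i8mm = {
    .get_m_step = kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm,
    .get_n_step = kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm,
    .get_mr = kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm,
    .get_nr = kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm,
    .get_kr = kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm,
    .get_sr = kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm,
    .get_lhs_packed_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm,
    .get_rhs_packed_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm,
    .get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm,
    .get_dst_size = kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm,
    .run_matmul = kai_run_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm,
};
```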
ukernel - This argument is required to write in the correct location when the ukernel is called to process only a portion of the output matrix Signed-off-by: Gian Marco Iodice --- .../matmul_clamp_f32_qai8dxp_qsu4cxp.cpp | 38 +++++++------------ src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c | 29 ++++++++------ src/matmul/kai_lhs_quant_pack_qai8dxp_f32.h | 30 ++++++++++----- .../kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.h | 2 +- 4 files changed, 53 insertions(+), 46 deletions(-) diff --git a/examples/matmul_clamp_f32_qai8dxp_qsu4cxp/matmul_clamp_f32_qai8dxp_qsu4cxp.cpp b/examples/matmul_clamp_f32_qai8dxp_qsu4cxp/matmul_clamp_f32_qai8dxp_qsu4cxp.cpp index 586f0296..72469a5d 100644 --- a/examples/matmul_clamp_f32_qai8dxp_qsu4cxp/matmul_clamp_f32_qai8dxp_qsu4cxp.cpp +++ b/examples/matmul_clamp_f32_qai8dxp_qsu4cxp/matmul_clamp_f32_qai8dxp_qsu4cxp.cpp @@ -5,6 +5,13 @@ // // Include micro-kernel variants +#include +#include +#include +#include +#include +#include + #include "kai_lhs_quant_pack_qai8dxp_f32.h" #include "kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod.h" #include "kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod.h" @@ -15,13 +22,6 @@ #include "kai_matmul_clamp_f32_qai8dxp_qsu4cxp_interface.h" #include "kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.h" -#include -#include -#include -#include -#include -#include - #define INT4_MIN (-8) #define INT4_MAX (7) @@ -110,7 +110,6 @@ kai_matmul_ukernel_f32_qa8dxp_qs4cxp ukernel_variants[] = { const size_t num_ukernel_variants = sizeof(ukernel_variants) / sizeof(ukernel_variants[0]); static void fill_uniform_random(size_t num_rows, size_t num_cols, float* dst, size_t seed) { - std::srand(seed); // Fill the array with random values between -1 and 1 @@ -120,11 +119,9 @@ static void fill_uniform_random(size_t num_rows, size_t num_cols, float* dst, si } static void quant_qs4cx_f32(size_t n, size_t k, const float* rhs_f32, uint8_t* rhs_qs4cx, float* rhs_scales_f32) { - const size_t dst_stride = (k / 2) * sizeof(int8_t); for (size_t row_idx = 0; row_idx < n; ++row_idx) { - const float* src_ptr = rhs_f32 + row_idx * k; float max0 = -FLT_MAX; @@ -181,11 +178,9 @@ static void quant_qs4cx_f32(size_t n, size_t k, const float* rhs_f32, uint8_t* r }; static void ref_quant_qa8dx_f32(size_t m, size_t k, const float* lhs_f32, int8_t* lhs_qa8dx) { - const size_t dst_stride = (k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t)); for (size_t row_idx = 0; row_idx < m; ++row_idx) { - const float* src_ptr = lhs_f32 + row_idx * k; float max0 = -FLT_MAX; @@ -253,15 +248,12 @@ static void ref_quant_qa8dx_f32(size_t m, size_t k, const float* lhs_f32, int8_t static void ref_matmul_f32_qa8dx_qs4cx( size_t m, size_t n, size_t k, const int8_t* lhs_qa8dx, const uint8_t* rhs_qs4cx, const float* rhs_scales_f32, float* dst_f32, float scalar_min, float scalar_max) { - const size_t lhs_stride = k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t); const size_t rhs_stride = (k / 2) * sizeof(uint8_t); for (size_t row_idx = 0; row_idx < m; ++row_idx) { - const int8_t* lhs_ptr_start = lhs_qa8dx + row_idx * lhs_stride; for (size_t col_idx = 0; col_idx < n; ++col_idx) { - // Main f32 accumulator int32_t iacc = 0; @@ -315,7 +307,6 @@ static void ref_matmul_f32_qa8dx_qs4cx( }; static bool is_output_correct(size_t num_rows, size_t num_cols, float tolerance, const float* ref, const float* act) { - bool is_valid = true; for (size_t i = 0; i < num_rows * num_cols; ++i) { @@ -331,8 +322,8 @@ static bool is_output_correct(size_t num_rows, size_t num_cols, float tolerance, int 
main(int argc, char** argv) {
     const size_t m = 17;
-    const size_t n = 32; // It must be a multiple of 8
-    const size_t k = 64; // It must be a multiple of 64
+    const size_t n = 32;  // It must be a multiple of 8
+    const size_t k = 64;  // It must be a multiple of 64

     const size_t seed_lhs = 4568;
     const size_t seed_rhs = seed_lhs + 4;
@@ -383,7 +374,6 @@ int main(int argc, char** argv) {
     //------------------------------------
     //------------------------------------

     for (size_t idx_variant = 0; idx_variant < num_ukernel_variants; ++idx_variant) {
-
         std::cout << "Testing " << ukernel_variants[idx_variant].name << std::endl;

         // Get the packing parameters
@@ -411,15 +401,15 @@ int main(int argc, char** argv) {
         // RHS packing
         kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0(
             1, n, k, nr, kr, sr,
-            (const uint8_t*)(rhs_native_mtx_qs4cx), // RHS
-            NULL,                                   // Bias
-            (const float*)(rhs_scales_f32),         // Scale
-            rhs_packed_mtx_qs4cx,                   // DST
+            (const uint8_t*)(rhs_native_mtx_qs4cx),  // RHS
+            NULL,                                    // Bias
+            (const float*)(rhs_scales_f32),          // Scale
+            rhs_packed_mtx_qs4cx,                    // DST
             0, &params);

         // LHS packing
         kai_run_lhs_quant_pack_qai8dxp_f32(
-            m, k, mr, kr, sr, (const float*)lhs_native_mtx_f32, k * sizeof(float), lhs_packed_mtx_qa8dx);
+            m, k, mr, kr, sr, 0, (const float*)lhs_native_mtx_f32, k * sizeof(float), lhs_packed_mtx_qa8dx);

         // Matmul
         {
diff --git a/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c b/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c
index ec3abb57..41642c45 100644
--- a/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c
+++ b/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c
@@ -5,13 +5,13 @@
 //
 #include "kai_lhs_quant_pack_qai8dxp_f32.h"

-#include "kai_common.h"
-
 #include
 #include
 #include
 #include

+#include "kai_common.h"
+
 static const size_t kai_num_bytes_per_multiplier = sizeof(float);
 static const size_t kai_num_bytes_per_offset = sizeof(int32_t);

@@ -19,12 +19,15 @@ size_t kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(size_t m_idx, size_t lhs_st
     return m_idx * lhs_stride;
 }

-size_t kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(size_t m_idx, size_t k, size_t mr, size_t kr) {
+size_t kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) {
     KAI_ASSERT(k % kr == 0);

     KAI_UNUSED(kr);
+
+    const size_t dst_y = (m_idx / mr);
     const size_t dst_stride = mr * (k * sizeof(int8_t) + kai_num_bytes_per_multiplier + kai_num_bytes_per_offset);

-    return (m_idx / mr) * dst_stride;
+    // It always points to the beginning of the row
+    return dst_y * dst_stride;
 }

 size_t kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(size_t m, size_t k, size_t mr, size_t kr) {
@@ -37,8 +40,8 @@ size_t kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(size_t m, size_t k, si
 }

 void kai_run_lhs_quant_pack_qai8dxp_f32(
-    size_t m, size_t k, size_t mr, size_t kr, size_t sr, const float* restrict lhs, size_t lhs_stride,
-    void* restrict lhs_packed) {
+    size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* restrict lhs,
+    size_t lhs_stride, void* restrict lhs_packed) {
     KAI_ASSERT(k % kr == 0);
     KAI_ASSERT((kr % sr) == 0);

@@ -113,12 +116,9 @@ void kai_run_lhs_quant_pack_qai8dxp_f32(
         // Round to nearest integer
         const int32_t nudged_zero_point0 = lrintf(zero_point0);

-        const size_t dst_x = (row_idx % mr);
-        const size_t dst_y = (row_idx / mr);
-
-        uint8_t* dst_ptr = (uint8_t*)lhs_packed + dst_y * dst_stride;
+        const size_t dst_x = ((row_idx + m_idx_start) % mr);

-        dst_ptr += dst_x * k_block_len * sizeof(int8_t);
+        uint8_t* dst_ptr = (uint8_t*)lhs_packed + dst_x * k_block_len * sizeof(int8_t);

         // Quantize the channels
         k_idx = 0;
@@ -138,7 +138,7 @@ void kai_run_lhs_quant_pack_qai8dxp_f32(
             dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t);
         }

-        dst_ptr = (uint8_t*)lhs_packed + dst_y * dst_stride + mr * (k * sizeof(int8_t));
+        dst_ptr = (uint8_t*)lhs_packed + mr * (k * sizeof(int8_t));

         dst_ptr += dst_x * kai_num_bytes_per_offset;

@@ -154,5 +154,10 @@ void kai_run_lhs_quant_pack_qai8dxp_f32(
         *((float*)(dst_ptr)) = recip_scale0;

         src_ptr += (lhs_stride / sizeof(float));
+
+        // Move to the next row if we have interleaved all Mr rows
+        if ((((row_idx + 1) + m_idx_start) % mr) == 0) {
+            lhs_packed = (void*)((uint8_t*)lhs_packed + dst_stride);
+        }
     }
 }
diff --git a/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.h b/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.h
index a13f32a9..464d9905 100644
--- a/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.h
+++ b/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.h
@@ -32,16 +32,22 @@ size_t kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(size_t m_idx, size_t lhs_st
 *
 * @param[in] m_idx Row index in the LHS matrix (not packed).
 * @param[in] k Total number of columns in the LHS matrix (not packed).
+ * @param[in] mr The number of M rows to interleave on the same output row.
+ * @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel.
+ * @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr.
+ *               However, kr must be a multiple of sr.
 *
 * return the offset in bytes to the packed LHS matrix
 */
-size_t kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(size_t m_idx, size_t k, size_t mr, size_t kr);
+size_t kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr);

 /**
 * @brief Function to return the memory required for storing the quantized and packed LHS matrix
 *
- * @param[in] m Total number of rows in the LHS matrix (not packed).
- * @param[in] k Total number of columns in the LHS matrix (not packed).
+ * @param[in] m Total number of rows in the LHS matrix (not packed).
+ * @param[in] k Total number of columns in the LHS matrix (not packed).
+ * @param[in] mr The number of M rows to interleave on the same output row.
+ * @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel.
 *
 * return the size in bytes to the packed LHS matrix
 */
@@ -50,14 +56,20 @@ size_t kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(size_t m, size_t k, si
 /**
 * @brief Micro-kernel to quantize and pack the LHS matrix
 *
- * @param[in] m The number of output rows written.
- * @param[in] k The number of channels. The common dimension of LHS & RHS. It must be multiple of 8.
- * @param[in] lhs LHS of the vector-by-matrix.
- * @param[in] lhs_stride Stride in bytes between two rows of LHS.
- * @param[out] lhs_packed The quantized and packed LHS matrix.
+ * @param[in] m The number of output rows written.
+ * @param[in] k The number of channels. The common dimension of LHS & RHS. It must be a multiple of 8.
+ * @param[in] mr The number of M rows to interleave on the same output row.
+ * @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel.
+ * @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr.
+ *               However, kr must be a multiple of sr.
+ * @param[in] m_idx_start The starting M index.
+ * @param[in] lhs LHS of the vector-by-matrix.
+ * @param[in] lhs_stride Stride in bytes between two rows of LHS.
+ * @param[out] lhs_packed The quantized and packed LHS matrix.
 */
 void kai_run_lhs_quant_pack_qai8dxp_f32(
-    size_t m, size_t k, size_t mr, size_t kr, size_t sr, const float* lhs, size_t lhs_stride, void* lhs_packed);
+    size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs, size_t lhs_stride,
+    void* lhs_packed);
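As a reading aid, the pointer arithmetic above implies the following layout for one packed LHS block of mr source rows. The diagram and its labels are mine, pieced together from the code, not the library's documentation (after the K-padding rework later in this series, k below becomes the padded k_internal):

    // One packed LHS block (mr rows), matching dst_stride in the code above:
    //
    //   int8 quantized data, mr rows interleaved in chunks of
    //   k_block_len = kr / sr values ................ mr * k bytes
    //   int32 zero-point offset, one per row ........ mr * sizeof(int32_t) bytes
    //   float reciprocal scale, one per row ......... mr * sizeof(float) bytes
    //
    const size_t dst_stride = mr * (k * sizeof(int8_t) + sizeof(int32_t) + sizeof(float));

This is also why kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32 always returns a multiple of dst_stride: offsets into the packed buffer are only meaningful at block boundaries.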

 #ifdef __cplusplus
 }
diff --git a/src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.h b/src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.h
index 3e3257db..fb609850 100644
--- a/src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.h
+++ b/src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.h
@@ -65,7 +65,7 @@ size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0(size_t n, size_t
 * @param[in] num_groups The number of groups. It must be 1.
 * @param[in] n The number of columns of the output matrix (N).
 * @param[in] k The common dimension between the LHS and RHS matrix (K).
- * @param[in] nr The number of columns written by the matmul micro-kernel.
+ * @param[in] nr The number of N columns to interleave on the same output row.
 * @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel.
 * @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr.
 *               However, kr must be a multiple of sr.
-- 
GitLab


From 4a667ee71b03923edbe952636a2d05e085d9a61e Mon Sep 17 00:00:00 2001
From: Gian Marco Iodice
Date: Mon, 13 May 2024 19:00:36 +0100
Subject: [PATCH 06/14] Fix kai_common.h inclusion in
 kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0

Signed-off-by: Gian Marco Iodice
---
 src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.c | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.c b/src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.c
index 000771fd..721fbe43 100644
--- a/src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.c
+++ b/src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.c
@@ -5,13 +5,13 @@
 //
 #include "kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.h"

-#include "../kai_common.h"
-
 #include
 #include
 #include
 #include

+#include "kai_common.h"
+
 static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t);
 static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);

@@ -53,7 +53,6 @@ void kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0(
     size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const int32_t* bias,
     const float* scale, void* rhs_packed, size_t extra_bytes,
     const struct kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0_params* params) {
-
     // Temporary asserts
     KAI_ASSERT(num_groups == 1);
     KAI_ASSERT((k % 2) == 0);
-- 
GitLab


From 794825737c8a094fa51c8dd11954c2f7658b5c11 Mon Sep 17 00:00:00 2001
From: Gian Marco Iodice
Date: Tue, 14 May 2024 16:28:59 +0100
Subject: [PATCH 07/14] Update CMakeLists.txt file

- Include the micro-kernels into the CMakeLists.txt file
- Fix file and function names for the micro-kernels computing 8 output columns

Signed-off-by: Gian Marco Iodice
---
 CMakeLists.txt                                | 20 +++++++
 .../CMakeLists.txt                            | 11 ++--
 .../matmul_clamp_f32_qai8dxp_qsu4cxp.cpp      | 52 +++++++++----------
 ...i8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.c} | 29 +++++------
 ...i8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.h} | 22 ++++----
 ..._qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm.c} | 28 +++++-----
 ..._qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm.h} | 22 ++++----
 7 files changed, 100 insertions(+), 84 deletions(-)
 rename src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/{kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod.c => kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.c} (92%)
 rename src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/{kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod.h => kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.h} (88%)
 rename src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/{kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm.c => kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm.c} (95%)
 rename src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/{kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm.h => kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm.h} (88%)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 6f3e82a8..7f44fb5e 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -56,6 +56,26 @@ set(KLEIDIAI_WARNING_FLAGS
     $<$:${KLEIDIAI_WARNING_FLAGS_CXX}>
 )

+add_library(kleidiai)
+
+target_sources(kleidiai PRIVATE
+    src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c
+    src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.c
+    src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod.c
+    src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.c
+    src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.c
+    src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.c
+    src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm.c
+    src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.c)
+
+target_include_directories(kleidiai
+    PRIVATE .
+)
+
+target_compile_options(kleidiai
+    PRIVATE ${KLEIDIAI_WARNING_FLAGS}
+)
+
 if(KLEIDIAI_BUILD_TESTS)
     enable_testing()
     include(GoogleTest)
diff --git a/examples/matmul_clamp_f32_qai8dxp_qsu4cxp/CMakeLists.txt b/examples/matmul_clamp_f32_qai8dxp_qsu4cxp/CMakeLists.txt
index c4a79437..8fce9f07 100644
--- a/examples/matmul_clamp_f32_qai8dxp_qsu4cxp/CMakeLists.txt
+++ b/examples/matmul_clamp_f32_qai8dxp_qsu4cxp/CMakeLists.txt
@@ -14,9 +14,6 @@ include_directories(

 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")

-# Project name
-project(matmul_clamp_f32_qai8dxp_qsu4cxp)
-
 # Files required to build the executable
 add_executable(matmul_clamp_f32_qai8dxp_qsu4cxp
     matmul_clamp_f32_qai8dxp_qsu4cxp.cpp
@@ -28,14 +25,14 @@ add_executable(matmul_clamp_f32_qai8dxp_qsu4cxp
     ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp_qsu4cxp_interface.h
     ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod.h
     ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod.c
-    ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod.h
-    ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod.c
+    ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.h
+    ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.c
     ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.h
     ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.c
     ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.h
../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.c - ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm.h - ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm.c + ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm.h + ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm.c ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.h ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.c) diff --git a/examples/matmul_clamp_f32_qai8dxp_qsu4cxp/matmul_clamp_f32_qai8dxp_qsu4cxp.cpp b/examples/matmul_clamp_f32_qai8dxp_qsu4cxp/matmul_clamp_f32_qai8dxp_qsu4cxp.cpp index 72469a5d..e12d3aca 100644 --- a/examples/matmul_clamp_f32_qai8dxp_qsu4cxp/matmul_clamp_f32_qai8dxp_qsu4cxp.cpp +++ b/examples/matmul_clamp_f32_qai8dxp_qsu4cxp/matmul_clamp_f32_qai8dxp_qsu4cxp.cpp @@ -14,10 +14,10 @@ #include "kai_lhs_quant_pack_qai8dxp_f32.h" #include "kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod.h" -#include "kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod.h" +#include "kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.h" #include "kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.h" -#include "kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm.h" #include "kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.h" +#include "kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm.h" #include "kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.h" #include "kai_matmul_clamp_f32_qai8dxp_qsu4cxp_interface.h" #include "kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.h" @@ -44,18 +44,18 @@ kai_matmul_ukernel_f32_qa8dxp_qs4cxp ukernel_variants[] = { kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod, kai_run_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod, "matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod"}, - {kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod, - kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod, - kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod, - kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod, - kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod, - kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod, - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod, - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod, - kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod, - kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod, - kai_run_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod, - "matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod"}, + {kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod, + kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod, + kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod, + kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod, + kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod, + 
kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod, + kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod, + kai_run_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod, + "matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod"}, {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm, kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm, kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm, @@ -80,18 +80,18 @@ kai_matmul_ukernel_f32_qa8dxp_qs4cxp ukernel_variants[] = { kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm, kai_run_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm, "matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm"}, - {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm, - kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm, - kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm, - kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm, - kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm, - kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm, - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm, - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm, - kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm, - kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm, - kai_run_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm, - "matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm"}, + {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm, + kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm, + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm, + kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm, + kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm, + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm, + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm, + kai_run_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm, + "matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm"}, {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm, kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm, kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm, diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod.c b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.c similarity index 92% rename from src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod.c rename to src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.c index 9dfaac5e..bbf44936 100644 --- 
a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod.c +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.c @@ -3,13 +3,13 @@ // // SPDX-License-Identifier: Apache-2.0 // -#include "kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod.h" - -#include "kai_common.h" +#include "kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.h" #include #include +#include "kai_common.h" + static const size_t kai_m_step = 1; static const size_t kai_n_step = 8; static const size_t kai_mr = 1; @@ -22,38 +22,38 @@ static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); -size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod(void) { +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(void) { return kai_m_step; } -size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod(void) { +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(void) { return kai_n_step; } -size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod(void) { +size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(void) { return kai_mr; } -size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod(void) { +size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(void) { return kai_nr; } -size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod(void) { +size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(void) { return kai_kr; } -size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod(void) { +size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(void) { return kai_sr; } -size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod(size_t m_idx, size_t k) { +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(size_t m_idx, size_t k) { KAI_ASSERT((m_idx % kai_m_step) == 0); const size_t lhs_packed_stride = kai_mr * (k + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); return (m_idx / kai_m_step) * lhs_packed_stride; } -size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod(size_t n_idx, size_t k) { +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(size_t n_idx, size_t k) { KAI_ASSERT((n_idx % kai_n_step) == 0); // Temporary assert @@ -64,7 +64,7 @@ size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_n return (n_idx / kai_n_step) * rhs_packed_stride; } -size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod( +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod( size_t m_idx, size_t n_idx, size_t dst_stride) { KAI_ASSERT((m_idx % kai_m_step) == 0); KAI_ASSERT((n_idx % kai_n_step) == 0); @@ -72,14 +72,14 @@ size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dot return (n_idx * sizeof(float)) + m_idx * dst_stride; } -size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod(size_t m, size_t n) { +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(size_t m, size_t n) { // Temporary 
assert KAI_ASSERT((n % kai_nr) == 0); return m * n * sizeof(float); } -void kai_run_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod( +void kai_run_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod( size_t m, size_t n, size_t k, const void* restrict lhs_packed, const void* restrict rhs_packed, float* restrict dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { #if defined(__ARM_FEATURE_DOTPROD) @@ -104,7 +104,6 @@ void kai_run_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod( for (size_t row_idx = 0; row_idx < num_rows; row_idx += kai_mr) { const uint8_t* rhs_ptr = rhs_packed; for (size_t col_idx = 0; col_idx < num_cols; col_idx += kai_nr) { - const uint8_t* lhs_ptr = lhs_ptr_start; // Main f32 accumulator diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod.h b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.h similarity index 88% rename from src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod.h rename to src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.h index 234a1279..275fcc69 100644 --- a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod.h +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.h @@ -22,7 +22,7 @@ extern "C" { * * @return the m step value */ -size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod(void); +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(void); /** * @brief Function to get the n step value. 
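To make the offset helpers above concrete, a worked example with my own illustrative numbers: for this dotprod variant both kai_mr and kai_m_step are 1, so the packed LHS stride computed in the .c file is kai_mr * (k + sizeof(float) + sizeof(int32_t)) bytes per row. With k = 32 that is 32 + 4 + 4 = 40 bytes, and kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(5, 32) therefore returns 5 * 40 = 200.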
@@ -31,7 +31,7 @@ size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod * * @return the n step */ -size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod(void); +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(void); /** * @brief Function to get the mr value, which must be used to pack the LHS matrix with @@ -39,7 +39,7 @@ size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod * * @return the mr value */ -size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod(void); +size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(void); /** * @brief Function to get the nr value, which must be used to pack the RHS matrix with @@ -47,7 +47,7 @@ size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod(voi * * @return the nr value */ -size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod(void); +size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(void); /** * @brief Function to get the kr value, which must be used to pack the RHS matrix with @@ -55,7 +55,7 @@ size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod(voi * * @return the kr value */ -size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod(void); +size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(void); /** * @brief Function to get the sr value, which must be used to pack the RHS matrix with @@ -63,7 +63,7 @@ size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod(voi * * @return the sr value */ -size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod(void); +size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(void); /** * @brief Function to calculate the offset in bytes for the packed LHS matrix, @@ -76,7 +76,7 @@ size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod(voi * * return the offset in bytes to the packed LHS matrix */ -size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod(size_t m_idx, size_t k); +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(size_t m_idx, size_t k); /** * @brief Function to calculate the offset in bytes for the packed RHS matrix, @@ -87,7 +87,7 @@ size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_n * * return the offset in bytes to the packed RHS matrix */ -size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod(size_t n_idx, size_t k); +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(size_t n_idx, size_t k); /** * @brief Function to calculate the offset in bytes for the DST matrix @@ -98,7 +98,7 @@ size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_n * * return the DST offset in bytes */ -size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod( +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod( size_t m_idx, size_t n_idx, size_t dst_stride); /** @@ -107,7 +107,7 @@ size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dot * @param[in] m Number of rows in the destination (DST) matrix * @param[in] n Number of columns in the destination (DST) matrix */ -size_t 
kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod(size_t m, size_t n); +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(size_t m, size_t n); /** * @brief Function to run the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation. @@ -133,7 +133,7 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotpr * @param[in] scalar_min Min value used to clamp the final result. * @param[in] scalar_max Max value used to clamp the final result. */ -void kai_run_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x8x32_neon_dotprod( +void kai_run_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod( size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max); diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm.c b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm.c similarity index 95% rename from src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm.c rename to src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm.c index 5405b4d1..18768f08 100644 --- a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm.c +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm.c @@ -3,13 +3,13 @@ // // SPDX-License-Identifier: Apache-2.0 // -#include "kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm.h" - -#include "kai_common.h" +#include "kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm.h" #include #include +#include "kai_common.h" + static const size_t kai_m_step = 4; static const size_t kai_n_step = 8; static const size_t kai_mr = 4; @@ -22,38 +22,38 @@ static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); -size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm(void) { +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(void) { return kai_m_step; } -size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm(void) { +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(void) { return kai_n_step; } -size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm(void) { +size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(void) { return kai_mr; } -size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm(void) { +size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(void) { return kai_nr; } -size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm(void) { +size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(void) { return kai_kr; } -size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm(void) { +size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(void) { return kai_sr; } -size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm(size_t m_idx, size_t k) { +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(size_t m_idx, 
size_t k) { KAI_ASSERT((m_idx % kai_m_step) == 0); const size_t lhs_packed_stride = kai_mr * (k + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); return (m_idx / kai_m_step) * lhs_packed_stride; } -size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm(size_t n_idx, size_t k) { +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(size_t n_idx, size_t k) { KAI_ASSERT((n_idx % kai_n_step) == 0); // Temporary assert @@ -64,7 +64,7 @@ size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_n return (n_idx / kai_n_step) * rhs_packed_stride; } -size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm( +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm( size_t m_idx, size_t n_idx, size_t dst_stride) { KAI_ASSERT((m_idx % kai_m_step) == 0); KAI_ASSERT((n_idx % kai_n_step) == 0); @@ -72,14 +72,14 @@ size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8m return (n_idx * sizeof(float)) + m_idx * dst_stride; } -size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm(size_t m, size_t n) { +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(size_t m, size_t n) { // Temporary assert KAI_ASSERT((n % kai_nr) == 0); return m * n * sizeof(float); } -void kai_run_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm( +void kai_run_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm( size_t m, size_t n, size_t k, const void* restrict lhs_packed, const void* restrict rhs_packed, float* restrict dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { #if defined(__ARM_FEATURE_MATMUL_INT8) diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm.h b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm.h similarity index 88% rename from src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm.h rename to src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm.h index 1d82ab86..1a2221cf 100644 --- a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm.h +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm.h @@ -22,7 +22,7 @@ extern "C" { * * @return the m step value */ -size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm(void); +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(void); /** * @brief Function to get the n step value. 
@@ -31,7 +31,7 @@ size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm(vo
 *
 * @return the n step
 */
-size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm(void);
+size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(void);

 /**
 * @brief Function to get the mr value, which must be used to pack the LHS matrix with
@@ -39,7 +39,7 @@ size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm(vo
 *
 * @return the mr value
 */
-size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm(void);
+size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(void);

 /**
 * @brief Function to get the nr value, which must be used to pack the RHS matrix with
@@ -47,7 +47,7 @@ size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm(void);
 *
 * @return the nr value
 */
-size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm(void);
+size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(void);

 /**
 * @brief Function to get the kr value, which must be used to pack the RHS matrix with
@@ -55,7 +55,7 @@ size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm(void);
 *
 * @return the kr value
 */
-size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm(void);
+size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(void);

 /**
 * @brief Function to get the sr value, which must be used to pack the RHS matrix with
@@ -63,7 +63,7 @@ size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm(void);
 *
 * @return the sr value
 */
-size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm(void);
+size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(void);

 /**
 * @brief Function to calculate the offset in bytes for the packed LHS matrix,
@@ -76,7 +76,7 @@ size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm(void);
 *
 * return the offset in bytes to the packed LHS matrix
 */
-size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm(size_t m_idx, size_t k);
+size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(size_t m_idx, size_t k);

 /**
 * @brief Function to calculate the offset in bytes for the packed RHS matrix,
@@ -87,7 +87,7 @@ size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_n
 *
 * return the offset in bytes to the packed RHS matrix
 */
-size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm(size_t n_idx, size_t k);
+size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(size_t n_idx, size_t k);

 /**
 * @brief Function to calculate the offset in bytes for the DST matrix
@@ -98,7 +98,7 @@ size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_n
 *
 * return the DST offset in bytes
 */
-size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm(
+size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(
     size_t m_idx, size_t n_idx, size_t dst_stride);

 /**
@@ -107,7 +107,7 @@ size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8m
 * @param[in] m Number of rows in the destination (DST) matrix. It must be either 1, 2, 3, or any multiple of 4.
 * @param[in] n Number of columns in the destination (DST) matrix. It must be a multiple of 4.
 */
-size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm(size_t m, size_t n);
+size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(size_t m, size_t n);
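These per-variant "get" helpers exist so that a caller can run the micro-kernel on a sub-tile of the output, for example when sharding work across threads. A minimal sketch of that pattern (uk, n_start and n_len are my illustrative names; uk is an instance of the kai_matmul_clamp_f32_qai8dxp_qsu4cxp_ukernel interface introduced earlier, filled with this variant's function pointers):

    // Each worker handles columns [n_start, n_start + n_len); n_start must be
    // a multiple of uk.get_n_step().
    const size_t dst_stride_row = n * sizeof(float);
    const void* rhs_tile = (const uint8_t*)rhs_packed + uk.get_rhs_packed_offset(n_start, k);
    float* dst_tile = (float*)((uint8_t*)dst + uk.get_dst_offset(0, n_start, dst_stride_row));
    uk.run_matmul(
        m, n_len, k, lhs_packed, rhs_tile, dst_tile, dst_stride_row, sizeof(float), scalar_min, scalar_max);

Since every worker reads the same packed LHS, only the RHS and DST pointers need to be offset for a column shard.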

 /**
 * @brief Function to run the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation.
@@ -133,7 +133,7 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(
 * @param[in] scalar_min Min value used to clamp the final result.
 * @param[in] scalar_max Max value used to clamp the final result.
 */
-void kai_run_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x8x32_neon_i8mm(
+void kai_run_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(
     size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row,
     size_t dst_stride_col, float scalar_min, float scalar_max);
-- 
GitLab


From ea11c8b403cb9f16c7a2d96e9fc23655086f3bd7 Mon Sep 17 00:00:00 2001
From: Gian Marco Iodice
Date: Thu, 23 May 2024 09:35:12 +0100
Subject: [PATCH 08/14] Remove M, N and K restrictions in the Int4 matmul
 micro-kernels

- Make the Int4 matmul micro-kernels work with arbitrary M, N and K
- The only restriction is on K, which must be an even value
- Update the doxygen comments
- Update the test case

Signed-off-by: Gian Marco Iodice
---
 CMakeLists.txt                                |    2 +-
 .../CMakeLists.txt                            |    4 +-
 .../matmul_clamp_f32_qai8dxp_qsu4cxp.cpp      |   36 +-
 src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c   |   83 +-
 src/matmul/kai_lhs_quant_pack_qai8dxp_f32.h   |   15 +-
 .../kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.c     |   78 +-
 .../kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.h     |   38 +-
 ...i8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod.c} |  123 +-
 ...i8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod.h} |   48 +-
 ...ai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.c |   82 +-
 ...ai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.h |   28 +-
 ...2_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.c |  325 +++--
 ...2_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.h |   30 +-
 ...2_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.c |  719 +++++-----
 ...2_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.h |   32 +-
 ...2_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm.c |  526 ++++----
 ...2_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm.h |   36 +-
 ...2_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.c | 1152 ++++++++++-------
 ...2_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.h |   36 +-
 ...tmul_clamp_f32_qai8dxp_qsu4cxp_interface.h |   45 +-
 20 files changed, 1920 insertions(+), 1518 deletions(-)
 rename src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/{kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod.c => kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod.c} (64%)
 rename src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/{kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod.h => kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod.h} (77%)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 7f44fb5e..dab8f92d 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -61,7 +61,7 @@ add_library(kleidiai)
 target_sources(kleidiai PRIVATE
     src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c
     src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.c
-    src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod.c
+    src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod.c
     src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.c
     src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.c
src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.c diff --git a/examples/matmul_clamp_f32_qai8dxp_qsu4cxp/CMakeLists.txt b/examples/matmul_clamp_f32_qai8dxp_qsu4cxp/CMakeLists.txt index 8fce9f07..8de22326 100644 --- a/examples/matmul_clamp_f32_qai8dxp_qsu4cxp/CMakeLists.txt +++ b/examples/matmul_clamp_f32_qai8dxp_qsu4cxp/CMakeLists.txt @@ -23,8 +23,8 @@ add_executable(matmul_clamp_f32_qai8dxp_qsu4cxp ../../src/matmul/kai_lhs_quant_pack_qai8dxp_f32.h ../../src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp_qsu4cxp_interface.h - ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod.h - ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod.c + ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod.h + ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod.c ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.h ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.c ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.h diff --git a/examples/matmul_clamp_f32_qai8dxp_qsu4cxp/matmul_clamp_f32_qai8dxp_qsu4cxp.cpp b/examples/matmul_clamp_f32_qai8dxp_qsu4cxp/matmul_clamp_f32_qai8dxp_qsu4cxp.cpp index e12d3aca..80d529ea 100644 --- a/examples/matmul_clamp_f32_qai8dxp_qsu4cxp/matmul_clamp_f32_qai8dxp_qsu4cxp.cpp +++ b/examples/matmul_clamp_f32_qai8dxp_qsu4cxp/matmul_clamp_f32_qai8dxp_qsu4cxp.cpp @@ -13,7 +13,7 @@ #include #include "kai_lhs_quant_pack_qai8dxp_f32.h" -#include "kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod.h" +#include "kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod.h" #include "kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.h" #include "kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.h" #include "kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.h" @@ -32,18 +32,18 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp { }; kai_matmul_ukernel_f32_qa8dxp_qs4cxp ukernel_variants[] = { - {kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod, - kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod, - kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod, - kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod, - kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod, - kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod, - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod, - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod, - kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod, - kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod, - kai_run_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod, - "matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod"}, + {kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod, + kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod, + 
kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod, + kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod, + kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod, + kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod, + kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod, + kai_run_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod, + "matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod"}, {kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod, kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod, kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod, @@ -321,9 +321,9 @@ static bool is_output_correct(size_t num_rows, size_t num_cols, float tolerance, } int main(int argc, char** argv) { - const size_t m = 17; - const size_t n = 32; // It must be a multiple of 8 - const size_t k = 64; // It must be a multiple of 64 + const size_t m = 13; + const size_t n = 17; + const size_t k = 18; const size_t seed_lhs = 4568; const size_t seed_rhs = seed_lhs + 4; @@ -383,8 +383,8 @@ int main(int argc, char** argv) { const size_t sr = ukernel_variants[idx_variant].ukernel.get_sr(); // Get the size in bytes for the packed matrices - const size_t lhs_packed_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(m, k, mr, kr); - const size_t rhs_packed_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0(n, k, nr, kr); + const size_t lhs_packed_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(m, k, mr, kr, sr); + const size_t rhs_packed_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0(n, k, nr, kr, sr); const size_t dst_size = ukernel_variants[idx_variant].ukernel.get_dst_size(m, n); // Allocate the matrices diff --git a/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c b/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c index 41642c45..4a5fa657 100644 --- a/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c +++ b/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c @@ -15,34 +15,44 @@ static const size_t kai_num_bytes_per_multiplier = sizeof(float); static const size_t kai_num_bytes_per_offset = sizeof(int32_t); +inline static size_t kai_k_roundedup(size_t k, size_t kr, size_t sr) { + // Since we pack a float and int32 value at the end of the row, + // we must make sure that k is a multiple of 4 for memory alignment. 
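+    // For example, with kr = 16 and sr = 2 (illustrative values), kr * sr = 32
+    // is already a multiple of 4, so k = 18 (the shape used in the updated
+    // example) becomes k_internal = kai_roundup(18, 32) = 32.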
+ size_t kr_sr_roundedup4 = kai_roundup(kr * sr, 4); + return kai_roundup(k, kr_sr_roundedup4); +} + +inline static size_t kai_lhs_packed_stride(size_t k, size_t mr, size_t kr, size_t sr) { + const size_t k_internal = kai_k_roundedup(k, kr, sr); + + KAI_ASSERT((k_internal % 2) == 0); + + return mr * (k_internal * sizeof(int8_t) + kai_num_bytes_per_multiplier + kai_num_bytes_per_offset); +} + +size_t kai_get_m_step_lhs_quant_pack_qai8dxp_f32(size_t mr) { + KAI_UNUSED(mr); + return 1; +} + size_t kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(size_t m_idx, size_t lhs_stride) { return m_idx * lhs_stride; } size_t kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) { - KAI_ASSERT(k % kr == 0); - KAI_UNUSED(kr); - - const size_t dst_y = (m_idx / mr); - const size_t dst_stride = mr * (k * sizeof(int8_t) + kai_num_bytes_per_multiplier + kai_num_bytes_per_offset); - // It always points to the beginning of the row - return dst_y * dst_stride; + return (m_idx / mr) * kai_lhs_packed_stride(k, mr, kr, sr); } -size_t kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(size_t m, size_t k, size_t mr, size_t kr) { - KAI_ASSERT(k % kr == 0); - KAI_UNUSED(kr); - const size_t m_roundup = kai_roundup(m, mr); +size_t kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(size_t m, size_t k, size_t mr, size_t kr, size_t sr) { + const size_t num_rows = kai_roundup(m, mr) / mr; - const size_t t_size = m_roundup * (k * sizeof(int8_t) + kai_num_bytes_per_multiplier + kai_num_bytes_per_offset); - return t_size; + return num_rows * kai_lhs_packed_stride(k, mr, kr, sr); } void kai_run_lhs_quant_pack_qai8dxp_f32( size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* restrict lhs, size_t lhs_stride, void* restrict lhs_packed) { - KAI_ASSERT(k % kr == 0); KAI_ASSERT((kr % sr) == 0); if (m == 0) { @@ -50,22 +60,26 @@ void kai_run_lhs_quant_pack_qai8dxp_f32( } const size_t num_rows = m; - const size_t num_cols = k; const float* src_ptr = lhs; - const size_t dst_stride = mr * (k * sizeof(int8_t) + kai_num_bytes_per_multiplier + kai_num_bytes_per_offset); + const size_t dst_stride = kai_lhs_packed_stride(k, mr, kr, sr); const size_t k_block_len = kr / sr; + const size_t k_internal = kai_k_roundedup(k, kr, sr); for (size_t row_idx = 0; row_idx < num_rows; ++row_idx) { + float max0 = -FLT_MAX; + float min0 = FLT_MAX; + float32x4_t vmax0 = vdupq_n_f32(-FLT_MAX); float32x4_t vmin0 = vdupq_n_f32(FLT_MAX); // Find min/max for each channel - size_t k_idx = 0; - for (; k_idx <= (num_cols - 8); k_idx += 8) { - const float32x4_t src0_0 = vld1q_f32(src_ptr + 0 + k_idx); - const float32x4_t src0_1 = vld1q_f32(src_ptr + 4 + k_idx); + int32_t k_idx = 0; +#if defined(__aarch64__) + for (; k_idx <= ((int32_t)k - 8); k_idx += 8) { + const float32x4_t src0_0 = vld1q_f32(src_ptr + 0 + (size_t)k_idx); + const float32x4_t src0_1 = vld1q_f32(src_ptr + 4 + (size_t)k_idx); // Calculate the max vmax0 = vmaxq_f32(src0_0, vmax0); @@ -75,19 +89,15 @@ void kai_run_lhs_quant_pack_qai8dxp_f32( vmin0 = vminq_f32(src0_0, vmin0); vmin0 = vminq_f32(vmin0, src0_1); } - - for (; k_idx < num_cols; ++k_idx) { - const float src0_0 = *(src_ptr + k_idx); - - // Calculate the max - vmax0 = vsetq_lane_f32(KAI_MAX(src0_0, vgetq_lane_f32(vmax0, 0)), vmax0, 0); - // Calculate the min - vmin0 = vsetq_lane_f32(KAI_MIN(src0_0, vgetq_lane_f32(vmin0, 0)), vmin0, 0); - } - // Get the max/min - const float max0 = vmaxvq_f32(vmax0); - const float min0 = vminvq_f32(vmin0); + max0 = 
vmaxvq_f32(vmax0);
+        min0 = vminvq_f32(vmin0);
+#endif
+        for (; k_idx < (int32_t)k; ++k_idx) {
+            const float src0_0 = *(src_ptr + (size_t)k_idx);
+            max0 = KAI_MAX(src0_0, max0);
+            min0 = KAI_MIN(src0_0, min0);
+        }

         // Maximum/minimum int8 values
         const float qmin = (float)INT8_MIN;
@@ -122,9 +132,12 @@ void kai_run_lhs_quant_pack_qai8dxp_f32(

         // Quantize the channels
         k_idx = 0;
-        for (; k_idx < num_cols; k_idx += k_block_len) {
+        for (; k_idx < (int32_t)k_internal; k_idx += k_block_len) {
             for (size_t k_block_idx = 0; k_block_idx < k_block_len; ++k_block_idx) {
-                const float src0_0 = *(src_ptr + k_idx + k_block_idx);
+                // Clamp at the last valid k-index
+                const size_t k_idx_start = KAI_MIN((size_t)k_idx + k_block_idx, k - 1);
+
+                const float src0_0 = *(src_ptr + k_idx_start);

                 // Scale the values
                 int32_t v0_s32 = (int32_t)(round(src0_0 * scale0));
@@ -138,7 +151,7 @@ void kai_run_lhs_quant_pack_qai8dxp_f32(
             dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t);
         }

-        dst_ptr = (uint8_t*)lhs_packed + mr * (k * sizeof(int8_t));
+        dst_ptr = (uint8_t*)lhs_packed + mr * (k_internal * sizeof(int8_t));

         dst_ptr += dst_x * kai_num_bytes_per_offset;

diff --git a/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.h b/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.h
index 464d9905..85715dc6 100644
--- a/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.h
+++ b/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.h
@@ -12,6 +12,17 @@
 extern "C" {
 #endif

+/**
+ * @brief Function to get the m step value.
+ *        The micro-kernel can process any M value. However, the starting M index to
+ *        be processed must be a multiple of m step.
+ *
+ * @param[in] mr The number of M rows to interleave on the same output row.
+ *
+ * @return the m step value
+ */
+size_t kai_get_m_step_lhs_quant_pack_qai8dxp_f32(size_t mr);
+
 /**
 * @brief Function to calculate the offset in bytes for the LHS matrix (not packed)
 *
@@ -35,7 +46,6 @@ size_t kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(size_t m_idx, size_t lhs_st
 * @param[in] mr The number of M rows to interleave on the same output row.
 * @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel.
 * @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr.
- *               However, kr must be a multiple of sr.
 *
 * return the offset in bytes to the packed LHS matrix
 */
@@ -48,10 +58,11 @@ size_t kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(size_t m_idx, size_t
 * @param[in] k Total number of columns in the LHS matrix (not packed).
 * @param[in] mr The number of M rows to interleave on the same output row.
 * @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel.
+ * @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr.
diff --git a/src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.c b/src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.c
index 721fbe43..421e4430 100644
--- a/src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.c
+++ b/src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.c
@@ -15,54 +15,56 @@
 static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t);
 static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);
 
+inline static size_t kai_k_roundedup(size_t k, size_t kr, size_t sr) {
+    // Since we pack a float and int32 value at the end of the row,
+    // we must make sure that k is a multiple of 4 for alignment
+    size_t kr_sr_roundedup4 = kai_roundup(kr * sr, 4);
+    return kai_roundup(k, kr_sr_roundedup4);
+}
+
+inline static size_t kai_rhs_packed_stride(size_t k, size_t kr, size_t nr, size_t sr) {
+    const size_t k_internal = kai_k_roundedup(k, kr, sr);
+
+    KAI_ASSERT((k_internal % 2) == 0);
+
+    return nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs);
+}
+
 inline static int8_t kai_int4_sign_extend(int8_t x) {
     return (x ^ 0x8) - 8;
 }
 
+size_t kai_get_n_step_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0(size_t nr) {
+    return nr;
+}
+
 size_t kai_get_rhs_offset_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0(size_t n_idx, size_t rhs_stride) {
     return n_idx * rhs_stride;
 }
 
-size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0(size_t n_idx, size_t k, size_t nr, size_t kr) {
-    KAI_ASSERT((k % 2) == 0);
-    KAI_ASSERT((k % kr) == 0);
+size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0(
+    size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr) {
     KAI_ASSERT((n_idx % nr) == 0);
 
-    KAI_UNUSED(kr);
-
-    const size_t rhs_packed_stride = nr * ((k / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs);
-
-    return (n_idx / nr) * rhs_packed_stride;
+    return (n_idx / nr) * kai_rhs_packed_stride(k, kr, nr, sr);
 }
 
-size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0(size_t n, size_t k, size_t nr, size_t kr) {
-    KAI_ASSERT((k % 2) == 0);
-    KAI_ASSERT((k % kr) == 0);
-    KAI_ASSERT((n % nr) == 0);
-
-    KAI_UNUSED(kr);
-
-    const size_t num_rows = n / nr;
+size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0(size_t n, size_t k, size_t nr, size_t kr, size_t sr) {
+    const size_t num_rows = kai_roundup(n, nr) / nr;
 
-    const size_t rhs_packed_stride = nr * ((k / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs);
-
-    return num_rows * rhs_packed_stride;
+    return num_rows * kai_rhs_packed_stride(k, kr, nr, sr);
 }
 
 void kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0(
     size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const int32_t* bias,
     const float* scale, void* rhs_packed, size_t extra_bytes,
     const struct kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0_params* params) {
-    // Temporary asserts
-    KAI_ASSERT(num_groups == 1);
     KAI_ASSERT((k % 2) == 0);
-    KAI_ASSERT((n % nr) == 0);
-    KAI_ASSERT((k % kr) == 0);
+    KAI_ASSERT(num_groups == 1);
     KAI_ASSERT(bias == NULL);
     KAI_ASSERT(extra_bytes == 0);
+    KAI_ASSERT((kr % sr) == 0);
 
-    KAI_ASSERT(sr == 2);
-    KAI_ASSERT(kr >= 1 && kr <= 16);
     KAI_ASSERT(rhs != NULL);
     KAI_ASSERT(scale != NULL);
     KAI_ASSERT(rhs_packed != NULL);
@@ -75,26 +77,38 @@ void kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0(
     const 
size_t rhs_zero_point = params->rhs_zero_point; const size_t rhs_stride = k / 2; - const size_t rhs_packed_stride = nr * ((k / 2) + sizeof(float) + sizeof(int32_t)); + const size_t rhs_packed_stride = kai_rhs_packed_stride(k, kr, nr, sr); + const size_t k_internal = kai_k_roundedup(k, kr, sr); for (size_t y = 0; y < n; y += nr) { const uint8_t* src_row = (const uint8_t*)rhs + y * rhs_stride; uint8_t* dst_row = (uint8_t*)rhs_packed + (y / nr) * rhs_packed_stride; - int32_t* sums = (int32_t*)(dst_row + nr * (k / 2)); + int32_t* sums = (int32_t*)(dst_row + nr * (k_internal / 2)); // Initialize to zero the RHS reduction sums memset(sums, 0, nr * sizeof(int32_t)); - for (size_t x = 0; x < k; x += (kr * 2)) { + for (size_t x = 0; x < k_internal; x += (kr * sr)) { for (size_t s = 0; s < sr; ++s) { for (size_t i = 0; i < nr; ++i) { for (size_t kr_idx = 0; kr_idx < kr / sr; kr_idx += 2) { - const size_t src_addr_byte0 = i * rhs_stride + (x / 2) + kr_idx / 2 + s * (kr / sr) / 2; - const size_t src_addr_byte1 = src_addr_byte0 + (kr / 2); + const size_t k_idx_start0 = (x / 2) + kr_idx / 2 + s * (kr / sr) / 2; + const size_t k_idx_start1 = k_idx_start0 + (kr / 2); - const uint8_t byte0 = src_row[src_addr_byte0]; - const uint8_t byte1 = src_row[src_addr_byte1]; + const size_t src_addr_byte0 = i * rhs_stride + k_idx_start0; + const size_t src_addr_byte1 = i * rhs_stride + k_idx_start1; + + uint8_t byte0 = rhs_zero_point | rhs_zero_point << 4; + uint8_t byte1 = rhs_zero_point | rhs_zero_point << 4; + + if (k_idx_start0 < (k / 2)) { + byte0 = src_row[src_addr_byte0]; + } + + if (k_idx_start1 < (k / 2)) { + byte1 = src_row[src_addr_byte1]; + } if (rhs_zero_point == 0) { int8_t src_x0_lo = (byte0 & 0x0F); diff --git a/src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.h b/src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.h index fb609850..bc607599 100644 --- a/src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.h +++ b/src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.h @@ -17,13 +17,24 @@ struct kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0_params { uint8_t rhs_zero_point; }; +/** + * @brief Function to get the n step value. + * The micro-kernel can process any N values. However, the starting N index to + * be processed must be a multiple of n step. + * + * @param[in] nr The number of columns written by the matmul micro-kernel + * + * @return the n step value + */ +size_t kai_get_n_step_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0(size_t nr); + /** * @brief Function to calculate the offset in bytes for the RHS matrix (not packed), which holds * the int4 values in a N x K matrix, where N is number of rows and K is the number of columns. * Two int4 values are stored in one byte. The lower order part of the byte (low) holds * the first nibble (K-index + 0). The higher order of the byte holds the second nibble (K-index + 1). * - * @param[in] n_idx Row index in the RHS matrix (not packed). + * @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of n_step. * @param[in] rhs_stride The number of bytes in in each row of the RHS matrix (not packed) * * return the offset in bytes to the RHS matrix (not packed) @@ -32,31 +43,34 @@ size_t kai_get_rhs_offset_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0(size_t n_idx, size_t r /** * @brief Function to calculate the offset in bytes for the packed RHS matrix, - * which contains the packed 4-bit quantized symmetric per-channel (qs4cx) values. + * which contains the packed 4-bit quantized symmetric per-channel (qsu4cx) values. * - * @param[in] n_idx Row index in the RHS matrix (not packed). 
+ * @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of n_step.
 * @param[in] k The common dimension between the LHS and RHS matrix (K)
 * @param[in] nr The number of columns written by the matmul micro-kernel
- * @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel.
+ * @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel.
+ * @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr.
 *
 * return the offset in bytes to the packed RHS matrix
 */
-size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0(size_t n_idx, size_t k, size_t nr, size_t kr);
+size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0(
+    size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr);
 
 /**
- * @brief Function to return the memory required for storing the quantized and packed RHS matrix
+ * @brief Function to return the memory required for storing the packed RHS matrix
 *
 * @param[in] n The number of rows in the RHS matrix (not packed)
 * @param[in] k The number of columns in the RHS matrix (not packed).
 * @param[in] nr The number of columns written by the matmul micro-kernel
 * @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel.
+ * @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr.
 *
 * return the size in bytes to the packed RHS matrix
 */
-size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0(size_t n, size_t k, size_t nr, size_t kr);
+size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0(size_t n, size_t k, size_t nr, size_t kr, size_t sr);
 
 /**
- * @brief Micro-kernel to quantize and pack the RHS matrix.
+ * @brief Micro-kernel to pack the RHS matrix.
 *
 * @note The int4 values are stored in a N x K matrix, where N is number of rows and K is the number of columns.
 *       Two int4 values are stored in one byte. The lower order part of the byte (low) holds
 *       the first nibble (K-index + 0). The higher order of the byte holds the second nibble (K-index + 1).
 *
 * @param[in] num_groups The number of groups. It must be 1.
 * @param[in] n The number of columns of the output matrix (N).
- * @param[in] k The common dimension between the LHS and RHS matrix (K).
+ * @param[in] k The common dimension between the LHS and RHS matrix (K). It must be an even value.
 * @param[in] nr The number of N columns to interleave on the same output row.
 * @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel.
 * @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr.
@@ -73,9 +87,9 @@ size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0(size_t n, size_t
 * @param[in] rhs The RHS matrix containing the 4-bit values.
 *                Size in bytes is expected to be greater than or equal to n * k * (sizeof(uint8_t) / 2).
 * @param[in] bias The biases.
 * @param[in] scale The scale for each output channel.
- * @param[out] rhs_packed The quantized and packed RHS matrix.
- * @param[in] extra_bytes Extra bytes to append to the end of each row of the quantized and packed RHS matrix.
- * @param[in] params Parameters for the function.
+ * @param[out] rhs_packed The packed RHS matrix.
+ * @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix.
+ * @param[in] params Parameters for the micro-kernel.
 */
 void kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0(
     size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const int32_t* bias,
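For reference, the s1s0 nibble layout and sign-extension documented above can be exercised with a minimal standalone example; `int4_sign_extend` mirrors `kai_int4_sign_extend` and the byte value is invented for illustration:

```c
#include <assert.h>
#include <stdint.h>

// Same trick as kai_int4_sign_extend: flip the sign bit of the 4-bit value,
// then subtract the bias to land in the signed range [-8, 7].
static int8_t int4_sign_extend(int8_t x) {
    return (int8_t)((x ^ 0x8) - 8);
}

int main(void) {
    // One byte stores two int4 values: the low nibble holds K-index + 0,
    // the high nibble holds K-index + 1.
    const uint8_t byte = 0x2F;

    const int8_t v0 = int4_sign_extend((int8_t)(byte & 0x0F));  // low:  0xF -> -1
    const int8_t v1 = int4_sign_extend((int8_t)(byte >> 4));    // high: 0x2 ->  2

    assert(v0 == -1);
    assert(v1 == 2);
    return 0;
}
```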
diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod.c b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod.c
similarity index 64%
rename from src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod.c
rename to src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod.c
index 44f16198..a51b2dac 100644
--- a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod.c
+++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod.c
@@ -3,68 +3,85 @@
 //
 // SPDX-License-Identifier: Apache-2.0
 //
-#include "kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod.h"
-
-#include "kai_common.h"
+#include "kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod.h"
 
 #include <arm_neon.h>
 #include <stddef.h>
 
+#include "kai_common.h"
+
 static const size_t kai_m_step = 1;
 static const size_t kai_n_step = 4;
 static const size_t kai_mr = 1;
 static const size_t kai_nr = 4;
 static const size_t kai_kr = 16;
 static const size_t kai_sr = 2;
-static const size_t kai_k0 = 64;
+static const size_t kai_k0 = kai_kr * kai_sr;
 static const size_t kai_num_bytes_multiplier_lhs = sizeof(float);
 static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);
 static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t);
 static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t);
 
-size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod(void) {
+inline static size_t kai_k_roundedup(size_t k) {
+    // Since we pack a float and int32 value at the end of the row,
+    // we must make sure that k is a multiple of 4 for alignment
+    size_t kr_sr_roundedup4 = kai_roundup(kai_kr * kai_sr, 4);
+    return kai_roundup(k, kr_sr_roundedup4);
+}
+
+inline static size_t kai_lhs_packed_stride(size_t k) {
+    const size_t k_internal = kai_k_roundedup(k);
+
+    KAI_ASSERT((k_internal % 2) == 0);
+
+    return kai_mr * (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs);
+}
+
+inline static size_t kai_rhs_packed_stride(size_t k) {
+    const size_t k_internal = kai_k_roundedup(k);
+
+    KAI_ASSERT((k_internal % 2) == 0);
+
+    return kai_nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs);
+}
+
+size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod(void) {
     return kai_m_step;
 }
 
-size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod(void) {
+size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod(void) {
     return kai_n_step;
 }
 
-size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod(void) {
+size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod(void) {
     return kai_mr;
 }
 
-size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod(void) {
+size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod(void) {
     return kai_nr;
 }
 
-size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod(void) {
+size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod(void) {
     return kai_kr;
 }
 
-size_t 
kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod(void) { +size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod(void) { return kai_sr; } -size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod(size_t m_idx, size_t k) { +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod(size_t m_idx, size_t k) { KAI_ASSERT((m_idx % kai_m_step) == 0); - const size_t lhs_packed_stride = kai_mr * (k + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); - return (m_idx / kai_m_step) * lhs_packed_stride; + return (m_idx / kai_m_step) * kai_lhs_packed_stride(k); } -size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod(size_t n_idx, size_t k) { +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod(size_t n_idx, size_t k) { KAI_ASSERT((n_idx % kai_n_step) == 0); - // Temporary assert - KAI_ASSERT((k % kai_k0) == 0); - - const size_t rhs_packed_stride = kai_nr * ((k / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs); - - return (n_idx / kai_n_step) * rhs_packed_stride; + return (n_idx / kai_n_step) * kai_rhs_packed_stride(k); } -size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod( +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod( size_t m_idx, size_t n_idx, size_t dst_stride) { KAI_ASSERT((m_idx % kai_m_step) == 0); KAI_ASSERT((n_idx % kai_n_step) == 0); @@ -72,20 +89,14 @@ size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dot return (n_idx * sizeof(float)) + m_idx * dst_stride; } -size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod(size_t m, size_t n) { - // Temporary assert - KAI_ASSERT((n % kai_nr) == 0); - +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod(size_t m, size_t n) { return m * n * sizeof(float); } -void kai_run_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod( +void kai_run_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod( size_t m, size_t n, size_t k, const void* restrict lhs_packed, const void* restrict rhs_packed, float* restrict dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { #if defined(__ARM_FEATURE_DOTPROD) - // Temporary asserts - KAI_ASSERT(n % kai_nr == 0); - KAI_ASSERT(k % kai_k0 == 0); KAI_ASSERT(dst_stride_col == sizeof(float)); if (m == 0) { @@ -95,61 +106,46 @@ void kai_run_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod( const size_t num_rows = m; const size_t num_cols = n; - const size_t lhs_packed_stride = kai_mr * (k + sizeof(float) + sizeof(float)); + const size_t lhs_packed_stride = kai_lhs_packed_stride(k); + const size_t k_internal = kai_k_roundedup(k); const int8x16_t nibble_mask = vdupq_n_s8(0xF0); const uint8_t* lhs_ptr_start = lhs_packed; for (size_t row_idx = 0; row_idx < num_rows; row_idx += kai_mr) { - const uint8_t* rhs_ptr = rhs_packed; for (size_t col_idx = 0; col_idx < num_cols; col_idx += kai_nr) { - const uint8_t* lhs_ptr = lhs_ptr_start; // Main f32 accumulator int32x4_t iacc0011 = vdupq_n_s32(0); int32x4_t iacc2233 = vdupq_n_s32(0); - for (size_t b = 0; b < k; b += kai_k0) { + for (size_t b = 0; b < k_internal; b += kai_k0) { // Set up RHS const int8x16_t rhs_raw_vec_0 = vld1q_s8((const int8_t*)(rhs_ptr + 0)); const int8x16_t rhs_raw_vec_1 = vld1q_s8((const int8_t*)(rhs_ptr + 16)); const int8x16_t rhs_raw_vec_2 = 
vld1q_s8((const int8_t*)(rhs_ptr + 32)); const int8x16_t rhs_raw_vec_3 = vld1q_s8((const int8_t*)(rhs_ptr + 48)); - const int8x16_t rhs_raw_vec_4 = vld1q_s8((const int8_t*)(rhs_ptr + 64)); - const int8x16_t rhs_raw_vec_5 = vld1q_s8((const int8_t*)(rhs_ptr + 80)); - const int8x16_t rhs_raw_vec_6 = vld1q_s8((const int8_t*)(rhs_ptr + 96)); - const int8x16_t rhs_raw_vec_7 = vld1q_s8((const int8_t*)(rhs_ptr + 112)); // Low nibble const int8x16_t rhs_vec_0_0 = vshlq_n_s8(rhs_raw_vec_0, 4); const int8x16_t rhs_vec_1_0 = vshlq_n_s8(rhs_raw_vec_1, 4); const int8x16_t rhs_vec_2_0 = vshlq_n_s8(rhs_raw_vec_2, 4); const int8x16_t rhs_vec_3_0 = vshlq_n_s8(rhs_raw_vec_3, 4); - const int8x16_t rhs_vec_4_0 = vshlq_n_s8(rhs_raw_vec_4, 4); - const int8x16_t rhs_vec_5_0 = vshlq_n_s8(rhs_raw_vec_5, 4); - const int8x16_t rhs_vec_6_0 = vshlq_n_s8(rhs_raw_vec_6, 4); - const int8x16_t rhs_vec_7_0 = vshlq_n_s8(rhs_raw_vec_7, 4); // High nibble const int8x16_t rhs_vec_0_1 = vandq_s8(rhs_raw_vec_0, nibble_mask); const int8x16_t rhs_vec_1_1 = vandq_s8(rhs_raw_vec_1, nibble_mask); const int8x16_t rhs_vec_2_1 = vandq_s8(rhs_raw_vec_2, nibble_mask); const int8x16_t rhs_vec_3_1 = vandq_s8(rhs_raw_vec_3, nibble_mask); - const int8x16_t rhs_vec_4_1 = vandq_s8(rhs_raw_vec_4, nibble_mask); - const int8x16_t rhs_vec_5_1 = vandq_s8(rhs_raw_vec_5, nibble_mask); - const int8x16_t rhs_vec_6_1 = vandq_s8(rhs_raw_vec_6, nibble_mask); - const int8x16_t rhs_vec_7_1 = vandq_s8(rhs_raw_vec_7, nibble_mask); const int8x16_t lhs_vec_0 = vld1q_s8((const int8_t*)(lhs_ptr + 0)); const int8x16_t lhs_vec_1 = vld1q_s8((const int8_t*)(lhs_ptr + 16)); - const int8x16_t lhs_vec_2 = vld1q_s8((const int8_t*)(lhs_ptr + 32)); - const int8x16_t lhs_vec_3 = vld1q_s8((const int8_t*)(lhs_ptr + 48)); - lhs_ptr += 64; - rhs_ptr += 128; + lhs_ptr += 32; + rhs_ptr += 64; int8x16_t t; @@ -165,19 +161,6 @@ void kai_run_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod( t = vcombine_s8(vget_high_s8(lhs_vec_1), vget_high_s8(lhs_vec_1)); iacc0011 = vdotq_s32(iacc0011, rhs_vec_2_1, t); iacc2233 = vdotq_s32(iacc2233, rhs_vec_3_1, t); - - t = vcombine_s8(vget_low_s8(lhs_vec_2), vget_low_s8(lhs_vec_2)); - iacc0011 = vdotq_s32(iacc0011, rhs_vec_4_0, t); - iacc2233 = vdotq_s32(iacc2233, rhs_vec_5_0, t); - t = vcombine_s8(vget_high_s8(lhs_vec_2), vget_high_s8(lhs_vec_2)); - iacc0011 = vdotq_s32(iacc0011, rhs_vec_6_0, t); - iacc2233 = vdotq_s32(iacc2233, rhs_vec_7_0, t); - t = vcombine_s8(vget_low_s8(lhs_vec_3), vget_low_s8(lhs_vec_3)); - iacc0011 = vdotq_s32(iacc0011, rhs_vec_4_1, t); - iacc2233 = vdotq_s32(iacc2233, rhs_vec_5_1, t); - t = vcombine_s8(vget_high_s8(lhs_vec_3), vget_high_s8(lhs_vec_3)); - iacc0011 = vdotq_s32(iacc0011, rhs_vec_6_1, t); - iacc2233 = vdotq_s32(iacc2233, rhs_vec_7_1, t); } int32x4_t iacc = vpaddq_s32(iacc0011, iacc2233); @@ -212,7 +195,21 @@ void kai_run_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod( main_acc = vmaxq_f32(main_acc, vmin_f32); main_acc = vminq_f32(main_acc, vmax_f32); - vst1q_f32((float*)((uint8_t*)dst + col_idx * sizeof(float) + row_idx * dst_stride_row), main_acc); + if (col_idx + kai_nr <= n) { + vst1q_f32((float*)((uint8_t*)dst + col_idx * sizeof(float) + row_idx * dst_stride_row), main_acc); + } else { + size_t leftover = n % kai_nr; + *(float*)((uint8_t*)dst + (col_idx + 0) * sizeof(float) + row_idx * dst_stride_row) = + vgetq_lane_f32(main_acc, 0); + if (leftover > 1) { + *(float*)((uint8_t*)dst + (col_idx + 1) * sizeof(float) + row_idx * dst_stride_row) = + vgetq_lane_f32(main_acc, 1); + } + if 
(leftover > 2) { + *(float*)((uint8_t*)dst + (col_idx + 2) * sizeof(float) + row_idx * dst_stride_row) = + vgetq_lane_f32(main_acc, 2); + } + } } lhs_ptr_start += lhs_packed_stride; } diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod.h b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod.h similarity index 77% rename from src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod.h rename to src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod.h index 291da6a9..cd9e64c3 100644 --- a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod.h +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod.h @@ -22,7 +22,7 @@ extern "C" { * * @return the m step value */ -size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod(void); +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod(void); /** * @brief Function to get the n step value. @@ -31,39 +31,39 @@ size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod * * @return the n step */ -size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod(void); +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod(void); /** * @brief Function to get the mr value, which must be used to pack the LHS matrix with - * the @ref kai_run_lhs_quant_pack_qa8dsP_f32 function + * the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel * * @return the mr value */ -size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod(void); +size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod(void); /** * @brief Function to get the nr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 function + * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 micro-kernel * * @return the nr value */ -size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod(void); +size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod(void); /** * @brief Function to get the kr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 function + * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 micro-kernel * * @return the kr value */ -size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod(void); +size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod(void); /** * @brief Function to get the sr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 function + * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 micro-kernel * * @return the sr value */ -size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod(void); +size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod(void); /** * @brief Function to calculate the offset in bytes for the packed LHS matrix, @@ -72,55 +72,55 @@ size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod(voi * This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. * * @param[in] m_idx Row index in the LHS matrix (not packed). 
- * @param[in] k Total number of columns in the LHS matrix (not packed). It must be a multiple of 64.
+ * @param[in] k Total number of columns in the LHS matrix (not packed).
 *
 * return the offset in bytes to the packed LHS matrix
 */
-size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod(size_t m_idx, size_t k);
+size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod(size_t m_idx, size_t k);
 
 /**
 * @brief Function to calculate the offset in bytes for the packed RHS matrix,
- *        which contains the packed 4-bit quantized symmetric per-channel (qs4cx) values.
+ *        which contains the packed 4-bit quantized symmetric per-channel (qsu4cx) values.
 *
- * @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of 4.
- * @param[in] k The common dimension between the LHS and RHS matrix (K). It must be a multiple of 64.
+ * @param[in] n_idx Row index in the RHS matrix (not packed).
+ * @param[in] k The common dimension between the LHS and RHS matrix (K).
 *
 * return the offset in bytes to the packed RHS matrix
 */
-size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod(size_t n_idx, size_t k);
+size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod(size_t n_idx, size_t k);
 
 /**
 * @brief Function to calculate the offset in bytes for the DST matrix
 *
 * @param[in] m_idx Row index in the DST matrix.
 * @param[in] n_idx Column index in the DST matrix. It must be multiple of 4.
- * @param[in] dst_stride The number of bytes in in each row of the DST matrix
+ * @param[in] dst_stride The number of bytes in each row of the DST matrix
 *
 * return the DST offset in bytes
 */
-size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod(
+size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod(
     size_t m_idx, size_t n_idx, size_t dst_stride);
 
 /**
- * @brief Function to query the size in bytes for the constant workspace.
+ * @brief Function to query the size in bytes for the destination matrix.
 *
 * @param[in] m Number of rows in the destination (DST) matrix
 * @param[in] n Number of columns in the destination (DST) matrix
 */
-size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod(size_t m, size_t n);
+size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod(size_t m, size_t n);
 
 /**
 * @brief Function to run the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation.
 *
- * LHS matrix: Signed 8-bit quantized asymmitric per-row (qai8dx) and packed
+ * LHS matrix: Signed 8-bit quantized asymmetric per-row (qai8dx) and packed
 * RHS matrix: Unsigned 4-bit quantized symmetric per-channel (qsu4cx) and packed.
 * Output tile: (rows x cols) = 1 x 4
 * Accumulation performed in a single for loop: 64
 * Instruction used: dotprod
 *
 * @param[in] m The number of output rows written.
- * @param[in] n The number of output columns written. It must be a multiple of 4.
- * @param[in] k The number of channels. The common dimension of LHS & RHS. It must be multiple of 64.
+ * @param[in] n The number of output columns written.
+ * @param[in] k The number of channels. The common dimension of LHS & RHS.
 * @param[in] lhs_packed The LHS matrix packed.
* When the activation are dynamically quantized, you can obtain this matrix * by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs @@ -133,7 +133,7 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotpr * @param[in] scalar_min Min value used to clamp the final result. * @param[in] scalar_max Max value used to clamp the final result. */ -void kai_run_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x64_neon_dotprod( +void kai_run_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod( size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max); diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.c b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.c index bbf44936..e11804a3 100644 --- a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.c +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.c @@ -16,12 +16,35 @@ static const size_t kai_mr = 1; static const size_t kai_nr = 8; static const size_t kai_kr = 16; static const size_t kai_sr = 2; -static const size_t kai_k0 = 32; +static const size_t kai_k0 = kai_kr * kai_sr; static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); +inline static size_t kai_k_roundedup(size_t k) { + // Since we pack a float and int32 value at the end of the row, + // we must make sure that k is a multiple of 4 for alignment + size_t kr_sr_roundedup4 = kai_roundup(kai_kr * kai_sr, 4); + return kai_roundup(k, kr_sr_roundedup4); +} + +inline static size_t kai_lhs_packed_stride(size_t k) { + const size_t k_internal = kai_k_roundedup(k); + + KAI_ASSERT((k_internal % 2) == 0); + + return kai_mr * (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); +} + +inline static size_t kai_rhs_packed_stride(size_t k) { + const size_t k_internal = kai_k_roundedup(k); + + KAI_ASSERT((k_internal % 2) == 0); + + return kai_nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs); +} + size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(void) { return kai_m_step; } @@ -48,20 +71,14 @@ size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(voi size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(size_t m_idx, size_t k) { KAI_ASSERT((m_idx % kai_m_step) == 0); - const size_t lhs_packed_stride = kai_mr * (k + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); - return (m_idx / kai_m_step) * lhs_packed_stride; + return (m_idx / kai_m_step) * kai_lhs_packed_stride(k); } size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(size_t n_idx, size_t k) { KAI_ASSERT((n_idx % kai_n_step) == 0); - // Temporary assert - KAI_ASSERT((k % kai_k0) == 0); - - const size_t rhs_packed_stride = kai_nr * ((k / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs); - - return (n_idx / kai_n_step) * rhs_packed_stride; + return (n_idx / kai_n_step) * kai_rhs_packed_stride(k); } size_t 
kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod( @@ -73,9 +90,6 @@ size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dot } size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(size_t m, size_t n) { - // Temporary assert - KAI_ASSERT((n % kai_nr) == 0); - return m * n * sizeof(float); } @@ -83,9 +97,6 @@ void kai_run_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod( size_t m, size_t n, size_t k, const void* restrict lhs_packed, const void* restrict rhs_packed, float* restrict dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { #if defined(__ARM_FEATURE_DOTPROD) - // Temporary asserts - KAI_ASSERT(n % kai_nr == 0); - KAI_ASSERT(k % kai_k0 == 0); KAI_ASSERT(dst_stride_col == sizeof(float)); if (m == 0) { @@ -95,7 +106,8 @@ void kai_run_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod( const size_t num_rows = m; const size_t num_cols = n; - const size_t lhs_packed_stride = kai_mr * (k + sizeof(float) + sizeof(float)); + const size_t lhs_packed_stride = kai_lhs_packed_stride(k); + const size_t k_internal = kai_k_roundedup(k); const int8x16_t nibble_mask = vdupq_n_s8(0xF0); @@ -112,7 +124,7 @@ void kai_run_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod( int32x4_t iacc4455 = vdupq_n_s32(0); int32x4_t iacc6677 = vdupq_n_s32(0); - for (size_t b = 0; b < k; b += kai_k0) { + for (size_t b = 0; b < k_internal; b += kai_k0) { // Set up RHS const int8x16_t rhs_raw_vec_0 = vld1q_s8((const int8_t*)(rhs_ptr + 0)); const int8x16_t rhs_raw_vec_1 = vld1q_s8((const int8_t*)(rhs_ptr + 16)); @@ -217,8 +229,40 @@ void kai_run_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod( main_acc1 = vmaxq_f32(main_acc1, vmin_f32); main_acc1 = vminq_f32(main_acc1, vmax_f32); - vst1q_f32((float*)((uint8_t*)dst + (col_idx + 0) * sizeof(float) + row_idx * dst_stride_row), main_acc0); - vst1q_f32((float*)((uint8_t*)dst + (col_idx + 4) * sizeof(float) + row_idx * dst_stride_row), main_acc1); + if (col_idx + kai_nr <= n) { + vst1q_f32( + (float*)((uint8_t*)dst + (col_idx + 0) * sizeof(float) + row_idx * dst_stride_row), main_acc0); + vst1q_f32( + (float*)((uint8_t*)dst + (col_idx + 4) * sizeof(float) + row_idx * dst_stride_row), main_acc1); + } else { + size_t leftover = n % kai_nr; + *(float*)((uint8_t*)dst + (col_idx + 0) * sizeof(float) + row_idx * dst_stride_row) = + vgetq_lane_f32(main_acc0, 0); + if (leftover > 1) { + *(float*)((uint8_t*)dst + (col_idx + 1) * sizeof(float) + row_idx * dst_stride_row) = + vgetq_lane_f32(main_acc0, 1); + } + if (leftover > 2) { + *(float*)((uint8_t*)dst + (col_idx + 2) * sizeof(float) + row_idx * dst_stride_row) = + vgetq_lane_f32(main_acc0, 2); + } + if (leftover > 3) { + *(float*)((uint8_t*)dst + (col_idx + 3) * sizeof(float) + row_idx * dst_stride_row) = + vgetq_lane_f32(main_acc0, 3); + } + if (leftover > 4) { + *(float*)((uint8_t*)dst + (col_idx + 4) * sizeof(float) + row_idx * dst_stride_row) = + vgetq_lane_f32(main_acc1, 0); + } + if (leftover > 5) { + *(float*)((uint8_t*)dst + (col_idx + 5) * sizeof(float) + row_idx * dst_stride_row) = + vgetq_lane_f32(main_acc1, 1); + } + if (leftover > 6) { + *(float*)((uint8_t*)dst + (col_idx + 6) * sizeof(float) + row_idx * dst_stride_row) = + vgetq_lane_f32(main_acc1, 2); + } + } } lhs_ptr_start += lhs_packed_stride; } diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.h 
b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.h
index 275fcc69..06b58e26 100644
--- a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.h
+++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.h
@@ -35,7 +35,7 @@ size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod
 
 /**
 * @brief Function to get the mr value, which must be used to pack the LHS matrix with
- *        the @ref kai_run_lhs_quant_pack_qa8dsP_f32 function
+ *        the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel
 *
 * @return the mr value
 */
@@ -43,7 +43,7 @@ size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(voi
 
 /**
 * @brief Function to get the nr value, which must be used to pack the RHS matrix with
- *        the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 function
+ *        the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 micro-kernel
 *
 * @return the nr value
 */
@@ -51,7 +51,7 @@ size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(voi
 
 /**
 * @brief Function to get the kr value, which must be used to pack the RHS matrix with
- *        the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 function
+ *        the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 micro-kernel
 *
 * @return the kr value
 */
@@ -59,7 +59,7 @@ size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(voi
 
 /**
 * @brief Function to get the sr value, which must be used to pack the RHS matrix with
- *        the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 function
+ *        the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 micro-kernel
 *
 * @return the sr value
 */
@@ -72,7 +72,7 @@ size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(voi
 * This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel.
 *
 * @param[in] m_idx Row index in the LHS matrix (not packed).
- * @param[in] k Total number of columns in the LHS matrix (not packed). It must be a multiple of 64.
+ * @param[in] k Total number of columns in the LHS matrix (not packed).
 *
 * return the offset in bytes to the packed LHS matrix
 */
@@ -80,10 +80,10 @@ size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_n
 
 /**
 * @brief Function to calculate the offset in bytes for the packed RHS matrix,
- *        which contains the packed 4-bit quantized symmetric per-channel (qs4cx) values.
+ *        which contains the packed 4-bit quantized symmetric per-channel (qsu4cx) values.
 *
- * @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of 4.
- * @param[in] k The common dimension between the LHS and RHS matrix (K). It must be a multiple of 64.
+ * @param[in] n_idx Row index in the RHS matrix (not packed).
+ * @param[in] k The common dimension between the LHS and RHS matrix (K).
 *
 * return the offset in bytes to the packed RHS matrix
 */
@@ -93,8 +93,8 @@ size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_n
 * @brief Function to calculate the offset in bytes for the DST matrix
 *
 * @param[in] m_idx Row index in the DST matrix.
- * @param[in] n_idx Column index in the DST matrix. It must be multiple of 4.
- * @param[in] dst_stride The number of bytes in in each row of the DST matrix
+ * @param[in] n_idx Column index in the DST matrix. It must be a multiple of 8.
+ * @param[in] dst_stride The number of bytes in each row of the DST matrix
 *
 * return the DST offset in bytes
 */
@@ -102,7 +102,7 @@ size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dot
     size_t m_idx, size_t n_idx, size_t dst_stride);
 
 /**
- * @brief Function to query the size in bytes for the constant workspace.
+ * @brief Function to query the size in bytes for the destination matrix.
 *
 * @param[in] m Number of rows in the destination (DST) matrix
 * @param[in] n Number of columns in the destination (DST) matrix
@@ -112,15 +112,15 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotpr
 /**
 * @brief Function to run the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation.
 *
- * LHS matrix: Signed 8-bit quantized asymmitric per-row (qai8dx) and packed
+ * LHS matrix: Signed 8-bit quantized asymmetric per-row (qai8dx) and packed
 * RHS matrix: Unsigned 4-bit quantized symmetric per-channel (qsu4cx) and packed.
 * Output tile: (rows x cols) = 1 x 8
 * Accumulation performed in a single for loop: 64
 * Instruction used: dotprod
 *
 * @param[in] m The number of output rows written.
- * @param[in] n The number of output columns written. It must be a multiple of 4.
- * @param[in] k The number of channels. The common dimension of LHS & RHS. It must be multiple of 64.
+ * @param[in] n The number of output columns written.
+ * @param[in] k The number of channels. The common dimension of LHS & RHS.
 * @param[in] lhs_packed The LHS matrix packed.
 *            When the activation are dynamically quantized, you can obtain this matrix
 *            by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs
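The matmul micro-kernels no longer require n to be a multiple of nr: the store stage writes full vectors while a whole output tile fits and falls back to element-wise stores for the leftover columns. A scalar sketch of that tail logic (the function name, `acc`, and the values in `main` are invented for illustration):

```c
#include <stddef.h>

// Store up to nr clamped results of one output row, writing only the
// columns that actually exist when n is not a multiple of nr.
static void store_tail(float* dst_row, const float* acc, size_t n,
                       size_t col_idx, size_t nr, float lo, float hi) {
    const size_t remaining = n - col_idx;
    const size_t count = remaining < nr ? remaining : nr;

    for (size_t i = 0; i < count; ++i) {
        float v = acc[i];
        v = v < lo ? lo : v;  // clamp to scalar_min
        v = v > hi ? hi : v;  // clamp to scalar_max
        dst_row[col_idx + i] = v;
    }
}

int main(void) {
    float dst[6] = {0};
    const float acc[8] = {9.0f, -9.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};

    // n = 6 with nr = 8: only the first six lanes are written, and the
    // out-of-range accumulators 9.0f and -9.0f are clamped to [-5, 5].
    store_tail(dst, acc, 6, 0, 8, -5.0f, 5.0f);
    return 0;
}
```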
diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.c b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.c
index 1a3359c4..2671a2f8 100644
--- a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.c
+++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.c
@@ -5,11 +5,11 @@
 //
 #include "kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.h"
 
-#include "kai_common.h"
-
 #include <arm_neon.h>
 #include <stddef.h>
 
+#include "kai_common.h"
+
 static const size_t kai_m_step = 4;
 static const size_t kai_n_step = 4;
 static const size_t kai_mr = 4;
@@ -22,6 +22,29 @@ static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);
 static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t);
 static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t);
 
+inline static size_t kai_k_roundedup(size_t k) {
+    // Since we pack a float and int32 value at the end of the row,
+    // we must make sure that k is a multiple of 4 for alignment
+    size_t kr_sr_roundedup4 = kai_roundup(kai_kr * kai_sr, 4);
+    return kai_roundup(k, kr_sr_roundedup4);
+}
+
+inline static size_t kai_lhs_packed_stride(size_t k) {
+    const size_t k_internal = kai_k_roundedup(k);
+
+    KAI_ASSERT((k_internal % 2) == 0);
+
+    return kai_mr * (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs);
+}
+
+inline static size_t kai_rhs_packed_stride(size_t k) {
+    const size_t k_internal = kai_k_roundedup(k);
+
+    KAI_ASSERT((k_internal % 2) == 0);
+
+    return kai_nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs);
+}
+
 size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(void) {
     return kai_m_step;
} @@ -48,20 +71,14 @@ size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(void) size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(size_t m_idx, size_t k) { KAI_ASSERT((m_idx % kai_m_step) == 0); - const size_t lhs_packed_stride = kai_mr * (k + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); - return (m_idx / kai_m_step) * lhs_packed_stride; + return (m_idx / kai_m_step) * kai_lhs_packed_stride(k); } size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(size_t n_idx, size_t k) { KAI_ASSERT((n_idx % kai_n_step) == 0); - // Temporary assert - KAI_ASSERT((k % kai_k0) == 0); - - const size_t rhs_packed_stride = kai_nr * ((k / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs); - - return (n_idx / kai_n_step) * rhs_packed_stride; + return (n_idx / kai_n_step) * kai_rhs_packed_stride(k); } size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm( @@ -73,9 +90,6 @@ size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8m } size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(size_t m, size_t n) { - // Temporary assert - KAI_ASSERT((n % kai_nr) == 0); - return m * n * sizeof(float); } @@ -83,142 +97,173 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm( size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { #if defined(__ARM_FEATURE_MATMUL_INT8) - // Temporary asserts - KAI_ASSERT(n % kai_nr == 0); - KAI_ASSERT(k % kai_k0 == 0); KAI_ASSERT(dst_stride_col == sizeof(float)); if (m == 0) { return; } - size_t num_blocks = k / 32; + const size_t k_internal = kai_k_roundedup(k); + + size_t num_blocks = k_internal / 32; float clamp_vals[2] = {scalar_min, scalar_max}; - __asm__ __volatile__("mov x26, #0x80\n" - "mov x20, #0x20\n" - "movi v4.16b, #0xf0\n" - "mov x25, %x[m]\n" - "madd x26, %x[num_blocks], x26, x20\n" - "cbz x25, 5f\n" - "1:" // Row loop - "mov x24, %x[rhs_packed]\n" - "mov x23, %x[n]\n" - "add x22, %x[dst], %x[dst_stride_row], LSL #2\n" - "2:" // Column loop - "movi v3.4s, #0x0\n" - "movi v2.4s, #0x0\n" - "mov x21, %x[lhs_packed]\n" - "mov x20, %x[num_blocks]\n" - "movi v1.4s, #0x0\n" - "movi v0.4s, #0x0\n" - "3:" // Block loop - "ldr q31, [x24, #0x0]\n" - "ldr q30, [x24, #0x10]\n" - "subs x20, x20, #0x1\n" - "ldr q29, [x21, #0x0]\n" - "ldr q28, [x21, #0x10]\n" - "ldr q27, [x24, #0x20]\n" - "ldr q26, [x24, #0x30]\n" - "add x24, x24, #0x40\n" - "ldr q25, [x21, #0x20]\n" - "ldr q24, [x21, #0x30]\n" - "shl v23.16b, v31.16b, #0x4\n" - "shl v22.16b, v30.16b, #0x4\n" - "ldr q21, [x21, #0x40]\n" - "ldr q20, [x21, #0x50]\n" - "and v31.16b, v31.16b, v4.16b\n" - "and v30.16b, v30.16b, v4.16b\n" - "ldr q19, [x21, #0x60]\n" - "ldr q18, [x21, #0x70]\n" - "shl v17.16b, v27.16b, #0x4\n" - "shl v16.16b, v26.16b, #0x4\n" - ".inst 0x4e97a7a3 // smmla v3.4s, v29.16b, v23.16b\n" - ".inst 0x4e96a7a2 // smmla v2.4s, v29.16b, v22.16b\n" - "and v27.16b, v27.16b, v4.16b\n" - "add x21, x21, #0x80\n" - ".inst 0x4e97a781 // smmla v1.4s, v28.16b, v23.16b\n" - ".inst 0x4e96a780 // smmla v0.4s, v28.16b, v22.16b\n" - "and v26.16b, v26.16b, v4.16b\n" - ".inst 0x4e91a723 // smmla v3.4s, v25.16b, v17.16b\n" - ".inst 0x4e90a722 // smmla v2.4s, v25.16b, v16.16b\n" - ".inst 0x4e91a701 // smmla v1.4s, v24.16b, v17.16b\n" - ".inst 0x4e90a700 // smmla v0.4s, v24.16b, v16.16b\n" - ".inst 0x4e9fa6a3 // smmla 
v3.4s, v21.16b, v31.16b\n" - ".inst 0x4e9ea6a2 // smmla v2.4s, v21.16b, v30.16b\n" - ".inst 0x4e9fa681 // smmla v1.4s, v20.16b, v31.16b\n" - ".inst 0x4e9ea680 // smmla v0.4s, v20.16b, v30.16b\n" - ".inst 0x4e9ba663 // smmla v3.4s, v19.16b, v27.16b\n" - ".inst 0x4e9aa662 // smmla v2.4s, v19.16b, v26.16b\n" - ".inst 0x4e9ba641 // smmla v1.4s, v18.16b, v27.16b\n" - ".inst 0x4e9aa640 // smmla v0.4s, v18.16b, v26.16b\n" - "bgt 3b\n" - "ldr q18, [x24, #0x0]\n" - "ldr q17, [x21, #0x0]\n" - "uzp1 v26.2d, v3.2d, v2.2d\n" - "uzp2 v25.2d, v3.2d, v2.2d\n" - "ldr q24, [x24, #0x10]\n" - "ldr q16, [x21, #0x10]\n" - "uzp1 v23.2d, v1.2d, v0.2d\n" - "uzp2 v22.2d, v1.2d, v0.2d\n" - "ld1r { v21.4s }, [%x[clamp_vals]]\n" - "add x21, %x[clamp_vals], #0x4\n" - "mov x20, %x[dst]\n" - "ld1r { v20.4s }, [x21]\n" - "mla v26.4s, v18.4s, v17.s[0]\n" - "mla v25.4s, v18.4s, v17.s[1]\n" - "cmp x25, #0x1\n" - "add x24, x24, #0x20\n" - "mla v23.4s, v18.4s, v17.s[2]\n" - "mla v22.4s, v18.4s, v17.s[3]\n" - "fmul v19.4s, v24.4s, v16.s[0]\n" - "fmul v18.4s, v24.4s, v16.s[1]\n" - "fmul v17.4s, v24.4s, v16.s[2]\n" - "fmul v16.4s, v24.4s, v16.s[3]\n" - "scvtf v26.4s, v26.4s\n" - "scvtf v25.4s, v25.4s\n" - "scvtf v23.4s, v23.4s\n" - "scvtf v22.4s, v22.4s\n" - "fmul v26.4s, v26.4s, v19.4s\n" - "fmul v25.4s, v25.4s, v18.4s\n" - "fmul v23.4s, v23.4s, v17.4s\n" - "fmul v22.4s, v22.4s, v16.4s\n" - "fmax v26.4s, v26.4s, v21.4s\n" - "fmax v25.4s, v25.4s, v21.4s\n" - "fmax v23.4s, v23.4s, v21.4s\n" - "fmax v22.4s, v22.4s, v21.4s\n" - "fmin v26.4s, v26.4s, v20.4s\n" - "fmin v25.4s, v25.4s, v20.4s\n" - "fmin v23.4s, v23.4s, v20.4s\n" - "fmin v22.4s, v22.4s, v20.4s\n" - "str q26, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "ble 4f\n" - "cmp x25, #0x2\n" - "str q25, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "ble 4f\n" - "cmp x25, #0x3\n" - "str q23, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "ble 4f\n" - "str q22, [x20, #0x0]\n" - "4:" // Accumulator store skip - "subs x23, x23, #0x4\n" - "add %x[dst], %x[dst], #0x10\n" - "bne 2b\n" - "subs x25, x25, #0x4\n" - "add %x[lhs_packed], %x[lhs_packed], x26\n" - "mov %x[dst], x22\n" - "bgt 1b\n" - "5:" // Row loop skip - : [lhs_packed] "+&r"(lhs_packed), [dst] "+&r"(dst) - : [rhs_packed] "r"(rhs_packed), [clamp_vals] "r"(clamp_vals), [m] "r"(m), - [num_blocks] "r"(num_blocks), [dst_stride_row] "r"(dst_stride_row), [n] "r"(n) - : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v16", "v17", "v18", "v19", "v20", "v21", - "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", - "x23", "x24", "x25", "x26"); + __asm__ __volatile__( + "mov x28, #0x80\n" + "mov x20, #0x20\n" + "movi v4.16b, #0xf0\n" + "mov x27, %x[m]\n" + "madd x28, %x[num_blocks], x28, x20\n" + "cbz x27, 8f\n" + "1:" // Row loop + "mov x26, %x[rhs_packed]\n" + "mov x25, %x[n]\n" + "add x24, %x[dst], %x[dst_stride_row], LSL #2\n" + "2:" // Column loop + "movi v3.4s, #0x0\n" + "movi v2.4s, #0x0\n" + "mov x21, %x[lhs_packed]\n" + "mov x20, %x[num_blocks]\n" + "movi v1.4s, #0x0\n" + "movi v0.4s, #0x0\n" + "3:" // Block loop + "ldr q31, [x26, #0x0]\n" + "ldr q30, [x26, #0x10]\n" + "subs x20, x20, #0x1\n" + "ldr q29, [x21, #0x0]\n" + "ldr q28, [x21, #0x10]\n" + "ldr q27, [x26, #0x20]\n" + "ldr q26, [x26, #0x30]\n" + "add x26, x26, #0x40\n" + "ldr q25, [x21, #0x20]\n" + "ldr q24, [x21, #0x30]\n" + "shl v23.16b, v31.16b, #0x4\n" + "shl v22.16b, v30.16b, #0x4\n" + "ldr q21, [x21, #0x40]\n" + "ldr q20, [x21, #0x50]\n" + "and v31.16b, v31.16b, v4.16b\n" + "and v30.16b, v30.16b, 
v4.16b\n" + "ldr q19, [x21, #0x60]\n" + "ldr q18, [x21, #0x70]\n" + "shl v17.16b, v27.16b, #0x4\n" + "shl v16.16b, v26.16b, #0x4\n" + ".inst 0x4e97a7a3 // smmla v3.4s, v29.16b, v23.16b\n" + ".inst 0x4e96a7a2 // smmla v2.4s, v29.16b, v22.16b\n" + "and v27.16b, v27.16b, v4.16b\n" + "add x21, x21, #0x80\n" + ".inst 0x4e97a781 // smmla v1.4s, v28.16b, v23.16b\n" + ".inst 0x4e96a780 // smmla v0.4s, v28.16b, v22.16b\n" + "and v26.16b, v26.16b, v4.16b\n" + ".inst 0x4e91a723 // smmla v3.4s, v25.16b, v17.16b\n" + ".inst 0x4e90a722 // smmla v2.4s, v25.16b, v16.16b\n" + ".inst 0x4e91a701 // smmla v1.4s, v24.16b, v17.16b\n" + ".inst 0x4e90a700 // smmla v0.4s, v24.16b, v16.16b\n" + ".inst 0x4e9fa6a3 // smmla v3.4s, v21.16b, v31.16b\n" + ".inst 0x4e9ea6a2 // smmla v2.4s, v21.16b, v30.16b\n" + ".inst 0x4e9fa681 // smmla v1.4s, v20.16b, v31.16b\n" + ".inst 0x4e9ea680 // smmla v0.4s, v20.16b, v30.16b\n" + ".inst 0x4e9ba663 // smmla v3.4s, v19.16b, v27.16b\n" + ".inst 0x4e9aa662 // smmla v2.4s, v19.16b, v26.16b\n" + ".inst 0x4e9ba641 // smmla v1.4s, v18.16b, v27.16b\n" + ".inst 0x4e9aa640 // smmla v0.4s, v18.16b, v26.16b\n" + "bgt 3b\n" + "ldr q18, [x26, #0x0]\n" + "ldr q17, [x21, #0x0]\n" + "uzp1 v26.2d, v3.2d, v2.2d\n" + "uzp2 v25.2d, v3.2d, v2.2d\n" + "ldr q24, [x26, #0x10]\n" + "ldr q16, [x21, #0x10]\n" + "uzp1 v23.2d, v1.2d, v0.2d\n" + "uzp2 v22.2d, v1.2d, v0.2d\n" + "ld1r { v21.4s }, [%x[clamp_vals]]\n" + "add x20, %x[clamp_vals], #0x4\n" + "cmp x25, #0x4\n" + "ld1r { v20.4s }, [x20]\n" + "mla v26.4s, v18.4s, v17.s[0]\n" + "mla v25.4s, v18.4s, v17.s[1]\n" + "add x26, x26, #0x20\n" + "mla v23.4s, v18.4s, v17.s[2]\n" + "mla v22.4s, v18.4s, v17.s[3]\n" + "fmul v19.4s, v24.4s, v16.s[0]\n" + "fmul v18.4s, v24.4s, v16.s[1]\n" + "fmul v17.4s, v24.4s, v16.s[2]\n" + "fmul v16.4s, v24.4s, v16.s[3]\n" + "scvtf v26.4s, v26.4s\n" + "scvtf v25.4s, v25.4s\n" + "scvtf v23.4s, v23.4s\n" + "scvtf v22.4s, v22.4s\n" + "fmul v26.4s, v26.4s, v19.4s\n" + "fmul v25.4s, v25.4s, v18.4s\n" + "fmul v23.4s, v23.4s, v17.4s\n" + "fmul v22.4s, v22.4s, v16.4s\n" + "fmax v26.4s, v26.4s, v21.4s\n" + "fmax v25.4s, v25.4s, v21.4s\n" + "fmax v23.4s, v23.4s, v21.4s\n" + "fmax v22.4s, v22.4s, v21.4s\n" + "fmin v26.4s, v26.4s, v20.4s\n" + "fmin v25.4s, v25.4s, v20.4s\n" + "fmin v23.4s, v23.4s, v20.4s\n" + "fmin v22.4s, v22.4s, v20.4s\n" + "bge 6f\n" + "mov x23, %x[dst]\n" + "cmp x27, #0x1\n" + "add x22, x23, %x[dst_stride_row]\n" + "csel x22, x22, x23, GE\n" + "cmp x27, #0x2\n" + "add x21, x23, %x[dst_stride_row], LSL #1\n" + "csel x21, x21, x22, GE\n" + "cmp x27, #0x3\n" + "add x20, x21, %x[dst_stride_row]\n" + "csel x20, x20, x21, GE\n" + "tbz x25, #1, 4f\n" + "str d22, [x20], #0x8\n" + "str d23, [x21], #0x8\n" + "str d25, [x22], #0x8\n" + "str d26, [x23], #0x8\n" + "tbz x25, #0, 5f\n" + "st1 { v22.s }[2], [x20]\n" + "st1 { v23.s }[2], [x21]\n" + "st1 { v25.s }[2], [x22]\n" + "st1 { v26.s }[2], [x23]\n" + "b 5f\n" + "4:" // Output block 0: partial_1_0 + "str s22, [x20, #0x0]\n" + "str s23, [x21, #0x0]\n" + "str s25, [x22, #0x0]\n" + "str s26, [x23, #0x0]\n" + "5:" // Output block 0: Done + "b 7f\n" + "6:" // Full output + "mov x20, %x[dst]\n" + "cmp x27, #0x1\n" + "str q26, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 7f\n" + "cmp x27, #0x2\n" + "str q25, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 7f\n" + "cmp x27, #0x3\n" + "str q23, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 7f\n" + "str q22, [x20, #0x0]\n" + "7:" // Output stage exit + "subs x25, x25, #0x4\n" + "add %x[dst], %x[dst], 
#0x10\n" + "bgt 2b\n" + "subs x27, x27, #0x4\n" + "add %x[lhs_packed], %x[lhs_packed], x28\n" + "mov %x[dst], x24\n" + "bgt 1b\n" + "8:" // Row loop skip + : [lhs_packed] "+&r"(lhs_packed), [dst] "+&r"(dst) + : [rhs_packed] "r"(rhs_packed), [clamp_vals] "r"(clamp_vals), [m] "r"(m), [num_blocks] "r"(num_blocks), + [dst_stride_row] "r"(dst_stride_row), [n] "r"(n) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", + "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", + "x28"); #else KAI_ASSERT(false); KAI_UNUSED(m); diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.h b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.h index 387c25a6..0f339b33 100644 --- a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.h +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.h @@ -35,7 +35,7 @@ size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(vo /** * @brief Function to get the mr value, which must be used to pack the LHS matrix with - * the @ref kai_run_lhs_quant_pack_qa8dsP_f32 function + * the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel * * @return the mr value */ @@ -43,7 +43,7 @@ size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(void); /** * @brief Function to get the nr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 function + * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 micro-kernel * * @return the nr value */ @@ -51,7 +51,7 @@ size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(void); /** * @brief Function to get the kr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 function + * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 micro-kernel * * @return the kr value */ @@ -59,7 +59,7 @@ size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(void); /** * @brief Function to get the sr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 function + * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 micro-kernel * * @return the sr value */ @@ -72,7 +72,7 @@ size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(void); * This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. * * @param[in] m_idx Row index in the LHS matrix (not packed). - * @param[in] k Total number of columns in the LHS matrix (not packed). It must be a multiple of 64. + * @param[in] k Total number of columns in the LHS matrix (not packed). * * return the offset in bytes to the packed LHS matrix */ @@ -80,10 +80,10 @@ size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_n /** * @brief Function to calculate the offset in bytes for the packed RHS matrix, - * which contains the packed 4-bit quantized symmetric per-channel (qs4cx) values. + * which contains the packed 4-bit quantized symmetric per-channel (qsu4cx) values. * * @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of 4. - * @param[in] k The common dimension between the LHS and RHS matrix (K). It must be a multiple of 64. 
+ * @param[in] k The common dimension between the LHS and RHS matrix (K).
 *
 * return the offset in bytes to the packed RHS matrix
 */
@@ -92,9 +92,9 @@ size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_n
 
 /**
 * @brief Function to calculate the offset in bytes for the DST matrix
 *
- * @param[in] m_idx Row index in the DST matrix. It must be either 1, 2, 3, or any multiple of 4.
+ * @param[in] m_idx Row index in the DST matrix. It must be a multiple of 4.
 * @param[in] n_idx Column index in the DST matrix. It must be multiple of 4.
- * @param[in] dst_stride The number of bytes in in each row of the DST matrix
+ * @param[in] dst_stride The number of bytes in each row of the DST matrix
 *
 * return the DST offset in bytes
 */
@@ -102,10 +102,10 @@ size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8m
     size_t m_idx, size_t n_idx, size_t dst_stride);
 
 /**
- * @brief Function to query the size in bytes for the constant workspace.
+ * @brief Function to query the size in bytes for the destination matrix.
 *
- * @param[in] m Number of rows in the destination (DST) matrix. It must be either 1, 2, 3, or any multiple of 4.
- * @param[in] n Number of columns in the destination (DST) matrix. It must be a multiple 4.
+ * @param[in] m Number of rows in the destination (DST) matrix.
+ * @param[in] n Number of columns in the destination (DST) matrix.
 */
 size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(size_t m, size_t n);
 
@@ -118,9 +118,9 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(
 * Accumulation performed in a single for loop: 32
 * Instruction used: i8mm
 *
- * @param[in] m The number of output rows written. It must be either 1, 2, 3, or any multiple of 4.
- * @param[in] n The number of output columns written. It must be a multiple of 4.
- * @param[in] k The number of channels. The common dimension of LHS & RHS. It must be multiple of 32.
+ * @param[in] m The number of output rows written.
+ * @param[in] n The number of output columns written.
+ * @param[in] k The number of channels. The common dimension of LHS & RHS.
 * @param[in] lhs_packed The LHS matrix packed.
 *            When the activation are dynamically quantized, you can obtain this matrix
 *            by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs
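For context, a plausible host-side calling sequence for the 4x4x32 i8mm variant under the signatures introduced in this patch; the buffers are zero-filled for brevity (so the output is all zeros) and error handling is omitted:

```c
#include <float.h>
#include <stdint.h>
#include <stdlib.h>

#include "kai_lhs_quant_pack_qai8dxp_f32.h"
#include "kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.h"
#include "kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.h"

int main(void) {
    const size_t m = 4, n = 8, k = 32;  // k must be even; n is kept a multiple of nr here

    const size_t mr = kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm();
    const size_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm();
    const size_t kr = kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm();
    const size_t sr = kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm();

    float* lhs = calloc(m * k, sizeof(float));            // f32 activations
    uint8_t* rhs = calloc(n * (k / 2), sizeof(uint8_t));  // two int4 weights per byte
    float* scale = calloc(n, sizeof(float));              // per-channel scales
    float* dst = calloc(m * n, sizeof(float));

    void* lhs_packed = malloc(kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(m, k, mr, kr, sr));
    void* rhs_packed = malloc(kai_get_rhs_packed_size_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0(n, k, nr, kr, sr));

    // 1. Pack the RHS matrix once (the weights are typically constant).
    struct kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0_params params = {.rhs_zero_point = 8};
    kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0(1, n, k, nr, kr, sr, rhs, NULL, scale, rhs_packed, 0, &params);

    // 2. Dynamically quantize and pack the LHS matrix on every inference.
    kai_run_lhs_quant_pack_qai8dxp_f32(m, k, mr, kr, sr, 0, lhs, k * sizeof(float), lhs_packed);

    // 3. Run the matmul micro-kernel, with the clamp effectively disabled.
    kai_run_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(
        m, n, k, lhs_packed, rhs_packed, dst, n * sizeof(float), sizeof(float), -FLT_MAX, FLT_MAX);

    free(lhs); free(rhs); free(scale); free(dst); free(lhs_packed); free(rhs_packed);
    return 0;
}
```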
 * When the activation are dynamically quantized, you can obtain this matrix
 * by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs
diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.c b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.c
index 097bd2e9..62d1611c 100644
--- a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.c
+++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.c
@@ -5,11 +5,11 @@
 //
 #include "kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.h"

-#include "kai_common.h"
-
 #include <stddef.h>
 #include <stdint.h>

+#include "kai_common.h"
+
 static const size_t kai_m_step = 8;
 static const size_t kai_n_step = 4;
 static const size_t kai_mr = 4;
@@ -22,6 +22,29 @@ static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);
 static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t);
 static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t);

+inline static size_t kai_k_roundedup(size_t k) {
+    // Since we pack a float and int32 value at the end of the row,
+    // we must make sure that k is a multiple of 4 for alignment
+    size_t kr_sr_roundedup4 = kai_roundup(kai_kr * kai_sr, 4);
+    return kai_roundup(k, kr_sr_roundedup4);
+}
+
+inline static size_t kai_lhs_packed_stride(size_t k) {
+    const size_t k_internal = kai_k_roundedup(k);
+
+    KAI_ASSERT((k_internal % 2) == 0);
+
+    return kai_mr * (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs);
+}
+
+inline static size_t kai_rhs_packed_stride(size_t k) {
+    const size_t k_internal = kai_k_roundedup(k);
+
+    KAI_ASSERT((k_internal % 2) == 0);
+
+    return kai_nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs);
+}
+
 size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(void) {
     return kai_m_step;
 }
@@ -48,20 +71,14 @@ size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(void)
 size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(size_t m_idx, size_t k) {
     KAI_ASSERT((m_idx % kai_m_step) == 0);

-    const size_t lhs_packed_stride = kai_mr * (k + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs);
-    return (m_idx / kai_m_step) * lhs_packed_stride;
+    return (m_idx / kai_m_step) * kai_lhs_packed_stride(k);
 }

 size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(size_t n_idx, size_t k) {
     KAI_ASSERT((n_idx % kai_n_step) == 0);

-    // Temporary assert
-    KAI_ASSERT((k % kai_k0) == 0);
-
-    const size_t rhs_packed_stride = kai_nr * ((k / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs);
-
-    return (n_idx / kai_n_step) * rhs_packed_stride;
+    return (n_idx / kai_n_step) * kai_rhs_packed_stride(k);
 }

 size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(
@@ -73,9 +90,6 @@ size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8m
 }

 size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(size_t m, size_t n) {
-    // Temporary assert
-    KAI_ASSERT((n % kai_nr) == 0);
-
     return m * n * sizeof(float);
 }

@@ -83,318 +97,391 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(
     size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row,
     size_t dst_stride_col, float scalar_min, float 
scalar_max) { #if defined(__ARM_FEATURE_MATMUL_INT8) - // Temporary asserts - KAI_ASSERT(n % kai_nr == 0); - KAI_ASSERT(k % kai_k0 == 0); KAI_ASSERT(dst_stride_col == sizeof(float)); if (m == 0) { return; } - size_t num_blocks = k / 32; + const size_t k_internal = kai_k_roundedup(k); + + size_t num_blocks = k_internal / 32; float clamp_vals[2] = {scalar_min, scalar_max}; - __asm__ __volatile__("mov x27, %x[m]\n" - "mov x26, #0x80\n" - "movi v11.16b, #0xf0\n" - "mov x20, #0x20\n" - "cmp x27, #0x8\n" - "madd x26, %x[num_blocks], x26, x20\n" - "blt 4f\n" - "1:" // Row loop - "mov x24, %x[rhs_packed]\n" - "mov x23, %x[n]\n" - "add x22, %x[dst], %x[dst_stride_row], LSL #3\n" - "2:" // Column loop - "mov x25, %x[lhs_packed]\n" - "movi v10.4s, #0x0\n" - "movi v9.4s, #0x0\n" - "mov x21, %x[num_blocks]\n" - "movi v8.4s, #0x0\n" - "movi v7.4s, #0x0\n" - "movi v6.4s, #0x0\n" - "movi v5.4s, #0x0\n" - "movi v4.4s, #0x0\n" - "movi v3.4s, #0x0\n" - "add x20, x25, x26\n" - "3:" // Block loop - "ldr q2, [x24, #0x0]\n" - "ldr q1, [x24, #0x10]\n" - "subs x21, x21, #0x1\n" - "ldr q20, [x25, #0x0]\n" - "ldr q19, [x25, #0x10]\n" - "ldr q18, [x20, #0x0]\n" - "ldr q0, [x20, #0x10]\n" - "ldr q31, [x24, #0x20]\n" - "ldr q30, [x24, #0x30]\n" - "shl v17.16b, v2.16b, #0x4\n" - "shl v16.16b, v1.16b, #0x4\n" - "ldr q29, [x25, #0x20]\n" - "ldr q28, [x25, #0x30]\n" - "and v2.16b, v2.16b, v11.16b\n" - "and v1.16b, v1.16b, v11.16b\n" - "ldr q27, [x20, #0x20]\n" - "ldr q26, [x20, #0x30]\n" - "add x24, x24, #0x40\n" - "ldr q25, [x25, #0x40]\n" - "ldr q24, [x25, #0x50]\n" - ".inst 0x4e91a68a // smmla v10.4s, v20.16b, v17.16b\n" - ".inst 0x4e90a689 // smmla v9.4s, v20.16b, v16.16b\n" - "ldr q23, [x20, #0x40]\n" - "ldr q22, [x20, #0x50]\n" - ".inst 0x4e91a668 // smmla v8.4s, v19.16b, v17.16b\n" - ".inst 0x4e90a667 // smmla v7.4s, v19.16b, v16.16b\n" - "ldr q21, [x25, #0x60]\n" - "ldr q20, [x25, #0x70]\n" - ".inst 0x4e91a646 // smmla v6.4s, v18.16b, v17.16b\n" - ".inst 0x4e90a645 // smmla v5.4s, v18.16b, v16.16b\n" - "ldr q19, [x20, #0x60]\n" - "ldr q18, [x20, #0x70]\n" - ".inst 0x4e91a404 // smmla v4.4s, v0.16b, v17.16b\n" - ".inst 0x4e90a403 // smmla v3.4s, v0.16b, v16.16b\n" - "shl v17.16b, v31.16b, #0x4\n" - "shl v16.16b, v30.16b, #0x4\n" - "add x25, x25, #0x80\n" - "add x20, x20, #0x80\n" - "and v31.16b, v31.16b, v11.16b\n" - "and v30.16b, v30.16b, v11.16b\n" - ".inst 0x4e91a7aa // smmla v10.4s, v29.16b, v17.16b\n" - ".inst 0x4e90a7a9 // smmla v9.4s, v29.16b, v16.16b\n" - ".inst 0x4e91a788 // smmla v8.4s, v28.16b, v17.16b\n" - ".inst 0x4e90a787 // smmla v7.4s, v28.16b, v16.16b\n" - ".inst 0x4e91a766 // smmla v6.4s, v27.16b, v17.16b\n" - ".inst 0x4e90a765 // smmla v5.4s, v27.16b, v16.16b\n" - ".inst 0x4e91a744 // smmla v4.4s, v26.16b, v17.16b\n" - ".inst 0x4e90a743 // smmla v3.4s, v26.16b, v16.16b\n" - ".inst 0x4e82a72a // smmla v10.4s, v25.16b, v2.16b\n" - ".inst 0x4e81a729 // smmla v9.4s, v25.16b, v1.16b\n" - ".inst 0x4e82a708 // smmla v8.4s, v24.16b, v2.16b\n" - ".inst 0x4e81a707 // smmla v7.4s, v24.16b, v1.16b\n" - ".inst 0x4e82a6e6 // smmla v6.4s, v23.16b, v2.16b\n" - ".inst 0x4e81a6e5 // smmla v5.4s, v23.16b, v1.16b\n" - ".inst 0x4e82a6c4 // smmla v4.4s, v22.16b, v2.16b\n" - ".inst 0x4e81a6c3 // smmla v3.4s, v22.16b, v1.16b\n" - ".inst 0x4e9fa6aa // smmla v10.4s, v21.16b, v31.16b\n" - ".inst 0x4e9ea6a9 // smmla v9.4s, v21.16b, v30.16b\n" - ".inst 0x4e9fa688 // smmla v8.4s, v20.16b, v31.16b\n" - ".inst 0x4e9ea687 // smmla v7.4s, v20.16b, v30.16b\n" - ".inst 0x4e9fa666 // smmla v6.4s, v19.16b, v31.16b\n" - ".inst 
0x4e9ea665 // smmla v5.4s, v19.16b, v30.16b\n" - ".inst 0x4e9fa644 // smmla v4.4s, v18.16b, v31.16b\n" - ".inst 0x4e9ea643 // smmla v3.4s, v18.16b, v30.16b\n" - "bgt 3b\n" - "ldr q1, [x24, #0x0]\n" - "ldr q16, [x25, #0x0]\n" - "uzp1 v0.2d, v10.2d, v9.2d\n" - "uzp2 v31.2d, v10.2d, v9.2d\n" - "ldr q30, [x24, #0x10]\n" - "ldr q29, [x25, #0x10]\n" - "uzp1 v28.2d, v8.2d, v7.2d\n" - "uzp2 v27.2d, v8.2d, v7.2d\n" - "ldr q17, [x20, #0x0]\n" - "ldr q26, [x20, #0x10]\n" - "uzp1 v25.2d, v6.2d, v5.2d\n" - "uzp2 v24.2d, v6.2d, v5.2d\n" - "ld1r { v23.4s }, [%x[clamp_vals]]\n" - "mla v0.4s, v1.4s, v16.s[0]\n" - "mla v31.4s, v1.4s, v16.s[1]\n" - "uzp1 v22.2d, v4.2d, v3.2d\n" - "mla v28.4s, v1.4s, v16.s[2]\n" - "mla v27.4s, v1.4s, v16.s[3]\n" - "fmul v21.4s, v30.4s, v29.s[0]\n" - "add x20, %x[clamp_vals], #0x4\n" - "ld1r { v20.4s }, [x20]\n" - "uzp2 v19.2d, v4.2d, v3.2d\n" - "mla v25.4s, v1.4s, v17.s[0]\n" - "mla v24.4s, v1.4s, v17.s[1]\n" - "fmul v16.4s, v30.4s, v29.s[1]\n" - "fmul v18.4s, v30.4s, v29.s[2]\n" - "mla v22.4s, v1.4s, v17.s[2]\n" - "mov x20, %x[dst]\n" - "scvtf v0.4s, v0.4s\n" - "scvtf v31.4s, v31.4s\n" - "subs x23, x23, #0x4\n" - "add x24, x24, #0x20\n" - "scvtf v28.4s, v28.4s\n" - "scvtf v27.4s, v27.4s\n" - "mla v19.4s, v1.4s, v17.s[3]\n" - "add %x[dst], %x[dst], #0x10\n" - "fmul v17.4s, v30.4s, v29.s[3]\n" - "scvtf v25.4s, v25.4s\n" - "fmul v0.4s, v0.4s, v21.4s\n" - "fmul v31.4s, v31.4s, v16.4s\n" - "fmul v16.4s, v30.4s, v26.s[0]\n" - "fmul v28.4s, v28.4s, v18.4s\n" - "scvtf v24.4s, v24.4s\n" - "fmul v18.4s, v30.4s, v26.s[1]\n" - "fmul v27.4s, v27.4s, v17.4s\n" - "scvtf v22.4s, v22.4s\n" - "fmul v17.4s, v30.4s, v26.s[2]\n" - "fmax v0.4s, v0.4s, v23.4s\n" - "fmul v25.4s, v25.4s, v16.4s\n" - "scvtf v19.4s, v19.4s\n" - "fmul v16.4s, v30.4s, v26.s[3]\n" - "fmax v31.4s, v31.4s, v23.4s\n" - "fmul v24.4s, v24.4s, v18.4s\n" - "fmax v28.4s, v28.4s, v23.4s\n" - "fmul v22.4s, v22.4s, v17.4s\n" - "fmin v0.4s, v0.4s, v20.4s\n" - "fmax v27.4s, v27.4s, v23.4s\n" - "fmul v19.4s, v19.4s, v16.4s\n" - "fmin v31.4s, v31.4s, v20.4s\n" - "fmax v25.4s, v25.4s, v23.4s\n" - "str q0, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "fmin v28.4s, v28.4s, v20.4s\n" - "fmax v24.4s, v24.4s, v23.4s\n" - "fmin v27.4s, v27.4s, v20.4s\n" - "fmax v22.4s, v22.4s, v23.4s\n" - "str q31, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "fmin v25.4s, v25.4s, v20.4s\n" - "fmax v19.4s, v19.4s, v23.4s\n" - "str q28, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "fmin v24.4s, v24.4s, v20.4s\n" - "str q27, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "fmin v22.4s, v22.4s, v20.4s\n" - "str q25, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "fmin v19.4s, v19.4s, v20.4s\n" - "str q24, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q22, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q19, [x20, #0x0]\n" - "bne 2b\n" - "mov x20, #0x2\n" - "sub x27, x27, #0x8\n" - "cmp x27, #0x8\n" - "mov %x[dst], x22\n" - "madd %x[lhs_packed], x20, x26, %x[lhs_packed]\n" - "bge 1b\n" - "4:" // Row loop skip - "cbz x27, 9f\n" - "5:" // Row tail: Row loop - "mov x24, %x[rhs_packed]\n" - "mov x23, %x[n]\n" - "add x22, %x[dst], %x[dst_stride_row], LSL #2\n" - "6:" // Row tail: Column loop - "movi v10.4s, #0x0\n" - "movi v9.4s, #0x0\n" - "mov x25, %x[lhs_packed]\n" - "mov x20, %x[num_blocks]\n" - "movi v8.4s, #0x0\n" - "movi v7.4s, #0x0\n" - "7:" // Row tail: Block loop - "ldr q31, [x24, #0x0]\n" - "ldr q30, [x24, #0x10]\n" - "subs x20, x20, #0x1\n" - "ldr q29, [x25, #0x0]\n" - "ldr q28, 
[x25, #0x10]\n" - "ldr q27, [x24, #0x20]\n" - "ldr q26, [x24, #0x30]\n" - "add x24, x24, #0x40\n" - "ldr q25, [x25, #0x20]\n" - "ldr q24, [x25, #0x30]\n" - "shl v23.16b, v31.16b, #0x4\n" - "shl v22.16b, v30.16b, #0x4\n" - "ldr q21, [x25, #0x40]\n" - "ldr q20, [x25, #0x50]\n" - "and v31.16b, v31.16b, v11.16b\n" - "and v30.16b, v30.16b, v11.16b\n" - "ldr q19, [x25, #0x60]\n" - "ldr q18, [x25, #0x70]\n" - "shl v17.16b, v27.16b, #0x4\n" - "shl v16.16b, v26.16b, #0x4\n" - ".inst 0x4e97a7aa // smmla v10.4s, v29.16b, v23.16b\n" - ".inst 0x4e96a7a9 // smmla v9.4s, v29.16b, v22.16b\n" - "and v27.16b, v27.16b, v11.16b\n" - "add x25, x25, #0x80\n" - ".inst 0x4e97a788 // smmla v8.4s, v28.16b, v23.16b\n" - ".inst 0x4e96a787 // smmla v7.4s, v28.16b, v22.16b\n" - "and v26.16b, v26.16b, v11.16b\n" - ".inst 0x4e91a72a // smmla v10.4s, v25.16b, v17.16b\n" - ".inst 0x4e90a729 // smmla v9.4s, v25.16b, v16.16b\n" - ".inst 0x4e91a708 // smmla v8.4s, v24.16b, v17.16b\n" - ".inst 0x4e90a707 // smmla v7.4s, v24.16b, v16.16b\n" - ".inst 0x4e9fa6aa // smmla v10.4s, v21.16b, v31.16b\n" - ".inst 0x4e9ea6a9 // smmla v9.4s, v21.16b, v30.16b\n" - ".inst 0x4e9fa688 // smmla v8.4s, v20.16b, v31.16b\n" - ".inst 0x4e9ea687 // smmla v7.4s, v20.16b, v30.16b\n" - ".inst 0x4e9ba66a // smmla v10.4s, v19.16b, v27.16b\n" - ".inst 0x4e9aa669 // smmla v9.4s, v19.16b, v26.16b\n" - ".inst 0x4e9ba648 // smmla v8.4s, v18.16b, v27.16b\n" - ".inst 0x4e9aa647 // smmla v7.4s, v18.16b, v26.16b\n" - "bgt 7b\n" - "ldr q18, [x24, #0x0]\n" - "ldr q17, [x25, #0x0]\n" - "uzp1 v26.2d, v10.2d, v9.2d\n" - "uzp2 v25.2d, v10.2d, v9.2d\n" - "ldr q24, [x24, #0x10]\n" - "ldr q16, [x25, #0x10]\n" - "uzp1 v23.2d, v8.2d, v7.2d\n" - "uzp2 v22.2d, v8.2d, v7.2d\n" - "ld1r { v21.4s }, [%x[clamp_vals]]\n" - "add x21, %x[clamp_vals], #0x4\n" - "mov x20, %x[dst]\n" - "ld1r { v20.4s }, [x21]\n" - "mla v26.4s, v18.4s, v17.s[0]\n" - "mla v25.4s, v18.4s, v17.s[1]\n" - "cmp x27, #0x1\n" - "add x24, x24, #0x20\n" - "mla v23.4s, v18.4s, v17.s[2]\n" - "mla v22.4s, v18.4s, v17.s[3]\n" - "fmul v19.4s, v24.4s, v16.s[0]\n" - "fmul v18.4s, v24.4s, v16.s[1]\n" - "fmul v17.4s, v24.4s, v16.s[2]\n" - "fmul v16.4s, v24.4s, v16.s[3]\n" - "scvtf v26.4s, v26.4s\n" - "scvtf v25.4s, v25.4s\n" - "scvtf v23.4s, v23.4s\n" - "scvtf v22.4s, v22.4s\n" - "fmul v26.4s, v26.4s, v19.4s\n" - "fmul v25.4s, v25.4s, v18.4s\n" - "fmul v23.4s, v23.4s, v17.4s\n" - "fmul v22.4s, v22.4s, v16.4s\n" - "fmax v26.4s, v26.4s, v21.4s\n" - "fmax v25.4s, v25.4s, v21.4s\n" - "fmax v23.4s, v23.4s, v21.4s\n" - "fmax v22.4s, v22.4s, v21.4s\n" - "fmin v26.4s, v26.4s, v20.4s\n" - "fmin v25.4s, v25.4s, v20.4s\n" - "fmin v23.4s, v23.4s, v20.4s\n" - "fmin v22.4s, v22.4s, v20.4s\n" - "str q26, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "ble 8f\n" - "cmp x27, #0x2\n" - "str q25, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "ble 8f\n" - "cmp x27, #0x3\n" - "str q23, [x20, #0x0]\n" - "add x20, x20, %x[dst_stride_row]\n" - "ble 8f\n" - "str q22, [x20, #0x0]\n" - "8:" // Row tail: Accumulator store skip - "subs x23, x23, #0x4\n" - "add %x[dst], %x[dst], #0x10\n" - "bne 6b\n" - "subs x27, x27, #0x4\n" - "add %x[lhs_packed], %x[lhs_packed], x26\n" - "mov %x[dst], x22\n" - "bgt 5b\n" - "9:" // Row tail: Row loop skip - : [lhs_packed] "+&r"(lhs_packed), [dst] "+&r"(dst) - : [rhs_packed] "r"(rhs_packed), [clamp_vals] "r"(clamp_vals), [m] "r"(m), - [num_blocks] "r"(num_blocks), [dst_stride_row] "r"(dst_stride_row), [n] "r"(n) - : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", 
"v11", - "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", - "v29", "v30", "v31", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27"); + __asm__ __volatile__( + "mov x12, %x[m]\n" + "mov x11, #0x80\n" + "movi v11.16b, #0xf0\n" + "mov x20, #0x20\n" + "cmp x12, #0x8\n" + "madd x11, %x[num_blocks], x11, x20\n" + "blt 8f\n" + "1:" // Row loop + "mov x10, %x[rhs_packed]\n" + "mov x9, %x[n]\n" + "add x28, %x[dst], %x[dst_stride_row], LSL #3\n" + "2:" // Column loop + "mov x22, %x[lhs_packed]\n" + "movi v10.4s, #0x0\n" + "movi v9.4s, #0x0\n" + "mov x21, %x[num_blocks]\n" + "movi v8.4s, #0x0\n" + "movi v7.4s, #0x0\n" + "movi v6.4s, #0x0\n" + "movi v5.4s, #0x0\n" + "movi v4.4s, #0x0\n" + "movi v3.4s, #0x0\n" + "add x20, x22, x11\n" + "3:" // Block loop + "ldr q2, [x10, #0x0]\n" + "ldr q1, [x10, #0x10]\n" + "subs x21, x21, #0x1\n" + "ldr q20, [x22, #0x0]\n" + "ldr q19, [x22, #0x10]\n" + "ldr q18, [x20, #0x0]\n" + "ldr q0, [x20, #0x10]\n" + "ldr q31, [x10, #0x20]\n" + "ldr q30, [x10, #0x30]\n" + "shl v17.16b, v2.16b, #0x4\n" + "shl v16.16b, v1.16b, #0x4\n" + "ldr q29, [x22, #0x20]\n" + "ldr q28, [x22, #0x30]\n" + "and v2.16b, v2.16b, v11.16b\n" + "and v1.16b, v1.16b, v11.16b\n" + "ldr q27, [x20, #0x20]\n" + "ldr q26, [x20, #0x30]\n" + "add x10, x10, #0x40\n" + "ldr q25, [x22, #0x40]\n" + "ldr q24, [x22, #0x50]\n" + ".inst 0x4e91a68a // smmla v10.4s, v20.16b, v17.16b\n" + ".inst 0x4e90a689 // smmla v9.4s, v20.16b, v16.16b\n" + "ldr q23, [x20, #0x40]\n" + "ldr q22, [x20, #0x50]\n" + ".inst 0x4e91a668 // smmla v8.4s, v19.16b, v17.16b\n" + ".inst 0x4e90a667 // smmla v7.4s, v19.16b, v16.16b\n" + "ldr q21, [x22, #0x60]\n" + "ldr q20, [x22, #0x70]\n" + ".inst 0x4e91a646 // smmla v6.4s, v18.16b, v17.16b\n" + ".inst 0x4e90a645 // smmla v5.4s, v18.16b, v16.16b\n" + "ldr q19, [x20, #0x60]\n" + "ldr q18, [x20, #0x70]\n" + ".inst 0x4e91a404 // smmla v4.4s, v0.16b, v17.16b\n" + ".inst 0x4e90a403 // smmla v3.4s, v0.16b, v16.16b\n" + "shl v17.16b, v31.16b, #0x4\n" + "shl v16.16b, v30.16b, #0x4\n" + "add x22, x22, #0x80\n" + "add x20, x20, #0x80\n" + "and v31.16b, v31.16b, v11.16b\n" + "and v30.16b, v30.16b, v11.16b\n" + ".inst 0x4e91a7aa // smmla v10.4s, v29.16b, v17.16b\n" + ".inst 0x4e90a7a9 // smmla v9.4s, v29.16b, v16.16b\n" + ".inst 0x4e91a788 // smmla v8.4s, v28.16b, v17.16b\n" + ".inst 0x4e90a787 // smmla v7.4s, v28.16b, v16.16b\n" + ".inst 0x4e91a766 // smmla v6.4s, v27.16b, v17.16b\n" + ".inst 0x4e90a765 // smmla v5.4s, v27.16b, v16.16b\n" + ".inst 0x4e91a744 // smmla v4.4s, v26.16b, v17.16b\n" + ".inst 0x4e90a743 // smmla v3.4s, v26.16b, v16.16b\n" + ".inst 0x4e82a72a // smmla v10.4s, v25.16b, v2.16b\n" + ".inst 0x4e81a729 // smmla v9.4s, v25.16b, v1.16b\n" + ".inst 0x4e82a708 // smmla v8.4s, v24.16b, v2.16b\n" + ".inst 0x4e81a707 // smmla v7.4s, v24.16b, v1.16b\n" + ".inst 0x4e82a6e6 // smmla v6.4s, v23.16b, v2.16b\n" + ".inst 0x4e81a6e5 // smmla v5.4s, v23.16b, v1.16b\n" + ".inst 0x4e82a6c4 // smmla v4.4s, v22.16b, v2.16b\n" + ".inst 0x4e81a6c3 // smmla v3.4s, v22.16b, v1.16b\n" + ".inst 0x4e9fa6aa // smmla v10.4s, v21.16b, v31.16b\n" + ".inst 0x4e9ea6a9 // smmla v9.4s, v21.16b, v30.16b\n" + ".inst 0x4e9fa688 // smmla v8.4s, v20.16b, v31.16b\n" + ".inst 0x4e9ea687 // smmla v7.4s, v20.16b, v30.16b\n" + ".inst 0x4e9fa666 // smmla v6.4s, v19.16b, v31.16b\n" + ".inst 0x4e9ea665 // smmla v5.4s, v19.16b, v30.16b\n" + ".inst 0x4e9fa644 // smmla v4.4s, v18.16b, v31.16b\n" + ".inst 0x4e9ea643 // smmla v3.4s, v18.16b, v30.16b\n" + "bgt 3b\n" + "ldr q20, [x10, #0x0]\n" + 
"ldr q19, [x22, #0x0]\n" + "uzp1 v2.2d, v10.2d, v9.2d\n" + "uzp2 v1.2d, v10.2d, v9.2d\n" + "ldr q18, [x20, #0x0]\n" + "ldr q0, [x10, #0x10]\n" + "uzp1 v31.2d, v8.2d, v7.2d\n" + "uzp2 v30.2d, v8.2d, v7.2d\n" + "ldr q17, [x22, #0x10]\n" + "ldr q16, [x20, #0x10]\n" + "uzp1 v29.2d, v6.2d, v5.2d\n" + "uzp2 v28.2d, v6.2d, v5.2d\n" + "ld1r { v27.4s }, [%x[clamp_vals]]\n" + "uzp1 v26.2d, v4.2d, v3.2d\n" + "uzp2 v25.2d, v4.2d, v3.2d\n" + "mla v2.4s, v20.4s, v19.s[0]\n" + "mla v1.4s, v20.4s, v19.s[1]\n" + "mla v31.4s, v20.4s, v19.s[2]\n" + "add x20, %x[clamp_vals], #0x4\n" + "cmp x9, #0x4\n" + "ld1r { v24.4s }, [x20]\n" + "mla v30.4s, v20.4s, v19.s[3]\n" + "mla v29.4s, v20.4s, v18.s[0]\n" + "fmul v23.4s, v0.4s, v17.s[0]\n" + "mla v28.4s, v20.4s, v18.s[1]\n" + "mla v26.4s, v20.4s, v18.s[2]\n" + "fmul v22.4s, v0.4s, v17.s[1]\n" + "add x10, x10, #0x20\n" + "mla v25.4s, v20.4s, v18.s[3]\n" + "scvtf v2.4s, v2.4s\n" + "scvtf v1.4s, v1.4s\n" + "scvtf v31.4s, v31.4s\n" + "fmul v21.4s, v0.4s, v17.s[2]\n" + "scvtf v30.4s, v30.4s\n" + "fmul v20.4s, v0.4s, v17.s[3]\n" + "scvtf v29.4s, v29.4s\n" + "fmul v19.4s, v0.4s, v16.s[0]\n" + "scvtf v28.4s, v28.4s\n" + "fmul v18.4s, v0.4s, v16.s[1]\n" + "scvtf v26.4s, v26.4s\n" + "fmul v17.4s, v0.4s, v16.s[2]\n" + "scvtf v25.4s, v25.4s\n" + "fmul v16.4s, v0.4s, v16.s[3]\n" + "fmul v2.4s, v2.4s, v23.4s\n" + "fmul v1.4s, v1.4s, v22.4s\n" + "fmul v31.4s, v31.4s, v21.4s\n" + "fmul v30.4s, v30.4s, v20.4s\n" + "fmul v29.4s, v29.4s, v19.4s\n" + "fmul v28.4s, v28.4s, v18.4s\n" + "fmul v26.4s, v26.4s, v17.4s\n" + "fmul v25.4s, v25.4s, v16.4s\n" + "fmax v2.4s, v2.4s, v27.4s\n" + "fmax v1.4s, v1.4s, v27.4s\n" + "fmax v31.4s, v31.4s, v27.4s\n" + "fmax v30.4s, v30.4s, v27.4s\n" + "fmax v29.4s, v29.4s, v27.4s\n" + "fmax v28.4s, v28.4s, v27.4s\n" + "fmax v26.4s, v26.4s, v27.4s\n" + "fmax v25.4s, v25.4s, v27.4s\n" + "fmin v2.4s, v2.4s, v24.4s\n" + "fmin v1.4s, v1.4s, v24.4s\n" + "fmin v31.4s, v31.4s, v24.4s\n" + "fmin v30.4s, v30.4s, v24.4s\n" + "fmin v29.4s, v29.4s, v24.4s\n" + "fmin v28.4s, v28.4s, v24.4s\n" + "fmin v26.4s, v26.4s, v24.4s\n" + "fmin v25.4s, v25.4s, v24.4s\n" + "bge 6f\n" + "mov x27, %x[dst]\n" + "add x26, x27, %x[dst_stride_row], LSL #2\n" + "add x25, x26, %x[dst_stride_row], LSL #1\n" + "add x24, x26, %x[dst_stride_row]\n" + "add x23, x25, %x[dst_stride_row]\n" + "add x22, x27, %x[dst_stride_row], LSL #1\n" + "add x21, x27, %x[dst_stride_row]\n" + "add x20, x22, %x[dst_stride_row]\n" + "tbz x9, #1, 4f\n" + "str d25, [x23], #0x8\n" + "str d26, [x25], #0x8\n" + "str d28, [x24], #0x8\n" + "str d29, [x26], #0x8\n" + "str d30, [x20], #0x8\n" + "str d31, [x22], #0x8\n" + "str d1, [x21], #0x8\n" + "str d2, [x27], #0x8\n" + "tbz x9, #0, 5f\n" + "st1 { v25.s }[2], [x23]\n" + "st1 { v26.s }[2], [x25]\n" + "st1 { v28.s }[2], [x24]\n" + "st1 { v29.s }[2], [x26]\n" + "st1 { v30.s }[2], [x20]\n" + "st1 { v31.s }[2], [x22]\n" + "st1 { v1.s }[2], [x21]\n" + "st1 { v2.s }[2], [x27]\n" + "b 5f\n" + "4:" // Output block 0: partial_1_0 + "str s25, [x23, #0x0]\n" + "str s26, [x25, #0x0]\n" + "str s28, [x24, #0x0]\n" + "str s29, [x26, #0x0]\n" + "str s30, [x20, #0x0]\n" + "str s31, [x22, #0x0]\n" + "str s1, [x21, #0x0]\n" + "str s2, [x27, #0x0]\n" + "5:" // Output block 0: Done + "b 7f\n" + "6:" // Full output + "mov x20, %x[dst]\n" + "str q2, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q1, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q31, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q30, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" 
+ "str q29, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q28, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q26, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q25, [x20, #0x0]\n" + "7:" // Output stage exit + "subs x9, x9, #0x4\n" + "add %x[dst], %x[dst], #0x10\n" + "bgt 2b\n" + "mov x20, #0x2\n" + "sub x12, x12, #0x8\n" + "cmp x12, #0x8\n" + "mov %x[dst], x28\n" + "madd %x[lhs_packed], x20, x11, %x[lhs_packed]\n" + "bge 1b\n" + "8:" // Row loop skip + "cbz x12, 16f\n" + "9:" // Row tail: Row loop + "mov x26, %x[rhs_packed]\n" + "mov x25, %x[n]\n" + "add x24, %x[dst], %x[dst_stride_row], LSL #2\n" + "10:" // Row tail: Column loop + "movi v10.4s, #0x0\n" + "movi v9.4s, #0x0\n" + "mov x22, %x[lhs_packed]\n" + "mov x20, %x[num_blocks]\n" + "movi v8.4s, #0x0\n" + "movi v7.4s, #0x0\n" + "11:" // Row tail: Block loop + "ldr q31, [x26, #0x0]\n" + "ldr q30, [x26, #0x10]\n" + "subs x20, x20, #0x1\n" + "ldr q29, [x22, #0x0]\n" + "ldr q28, [x22, #0x10]\n" + "ldr q27, [x26, #0x20]\n" + "ldr q26, [x26, #0x30]\n" + "add x26, x26, #0x40\n" + "ldr q25, [x22, #0x20]\n" + "ldr q24, [x22, #0x30]\n" + "shl v23.16b, v31.16b, #0x4\n" + "shl v22.16b, v30.16b, #0x4\n" + "ldr q21, [x22, #0x40]\n" + "ldr q20, [x22, #0x50]\n" + "and v31.16b, v31.16b, v11.16b\n" + "and v30.16b, v30.16b, v11.16b\n" + "ldr q19, [x22, #0x60]\n" + "ldr q18, [x22, #0x70]\n" + "shl v17.16b, v27.16b, #0x4\n" + "shl v16.16b, v26.16b, #0x4\n" + ".inst 0x4e97a7aa // smmla v10.4s, v29.16b, v23.16b\n" + ".inst 0x4e96a7a9 // smmla v9.4s, v29.16b, v22.16b\n" + "and v27.16b, v27.16b, v11.16b\n" + "add x22, x22, #0x80\n" + ".inst 0x4e97a788 // smmla v8.4s, v28.16b, v23.16b\n" + ".inst 0x4e96a787 // smmla v7.4s, v28.16b, v22.16b\n" + "and v26.16b, v26.16b, v11.16b\n" + ".inst 0x4e91a72a // smmla v10.4s, v25.16b, v17.16b\n" + ".inst 0x4e90a729 // smmla v9.4s, v25.16b, v16.16b\n" + ".inst 0x4e91a708 // smmla v8.4s, v24.16b, v17.16b\n" + ".inst 0x4e90a707 // smmla v7.4s, v24.16b, v16.16b\n" + ".inst 0x4e9fa6aa // smmla v10.4s, v21.16b, v31.16b\n" + ".inst 0x4e9ea6a9 // smmla v9.4s, v21.16b, v30.16b\n" + ".inst 0x4e9fa688 // smmla v8.4s, v20.16b, v31.16b\n" + ".inst 0x4e9ea687 // smmla v7.4s, v20.16b, v30.16b\n" + ".inst 0x4e9ba66a // smmla v10.4s, v19.16b, v27.16b\n" + ".inst 0x4e9aa669 // smmla v9.4s, v19.16b, v26.16b\n" + ".inst 0x4e9ba648 // smmla v8.4s, v18.16b, v27.16b\n" + ".inst 0x4e9aa647 // smmla v7.4s, v18.16b, v26.16b\n" + "bgt 11b\n" + "ldr q18, [x26, #0x0]\n" + "ldr q17, [x22, #0x0]\n" + "uzp1 v26.2d, v10.2d, v9.2d\n" + "uzp2 v25.2d, v10.2d, v9.2d\n" + "ldr q24, [x26, #0x10]\n" + "ldr q16, [x22, #0x10]\n" + "uzp1 v23.2d, v8.2d, v7.2d\n" + "uzp2 v22.2d, v8.2d, v7.2d\n" + "ld1r { v21.4s }, [%x[clamp_vals]]\n" + "add x20, %x[clamp_vals], #0x4\n" + "cmp x25, #0x4\n" + "ld1r { v20.4s }, [x20]\n" + "mla v26.4s, v18.4s, v17.s[0]\n" + "mla v25.4s, v18.4s, v17.s[1]\n" + "add x26, x26, #0x20\n" + "mla v23.4s, v18.4s, v17.s[2]\n" + "mla v22.4s, v18.4s, v17.s[3]\n" + "fmul v19.4s, v24.4s, v16.s[0]\n" + "fmul v18.4s, v24.4s, v16.s[1]\n" + "fmul v17.4s, v24.4s, v16.s[2]\n" + "fmul v16.4s, v24.4s, v16.s[3]\n" + "scvtf v26.4s, v26.4s\n" + "scvtf v25.4s, v25.4s\n" + "scvtf v23.4s, v23.4s\n" + "scvtf v22.4s, v22.4s\n" + "fmul v26.4s, v26.4s, v19.4s\n" + "fmul v25.4s, v25.4s, v18.4s\n" + "fmul v23.4s, v23.4s, v17.4s\n" + "fmul v22.4s, v22.4s, v16.4s\n" + "fmax v26.4s, v26.4s, v21.4s\n" + "fmax v25.4s, v25.4s, v21.4s\n" + "fmax v23.4s, v23.4s, v21.4s\n" + "fmax v22.4s, v22.4s, v21.4s\n" + "fmin v26.4s, v26.4s, 
v20.4s\n" + "fmin v25.4s, v25.4s, v20.4s\n" + "fmin v23.4s, v23.4s, v20.4s\n" + "fmin v22.4s, v22.4s, v20.4s\n" + "bge 14f\n" + "mov x23, %x[dst]\n" + "cmp x12, #0x1\n" + "add x22, x23, %x[dst_stride_row]\n" + "csel x22, x22, x23, GE\n" + "cmp x12, #0x2\n" + "add x21, x23, %x[dst_stride_row], LSL #1\n" + "csel x21, x21, x22, GE\n" + "cmp x12, #0x3\n" + "add x20, x21, %x[dst_stride_row]\n" + "csel x20, x20, x21, GE\n" + "tbz x25, #1, 12f\n" + "str d22, [x20], #0x8\n" + "str d23, [x21], #0x8\n" + "str d25, [x22], #0x8\n" + "str d26, [x23], #0x8\n" + "tbz x25, #0, 13f\n" + "st1 { v22.s }[2], [x20]\n" + "st1 { v23.s }[2], [x21]\n" + "st1 { v25.s }[2], [x22]\n" + "st1 { v26.s }[2], [x23]\n" + "b 13f\n" + "12:" // Row tail: Output block 0: partial_1_0 + "str s22, [x20, #0x0]\n" + "str s23, [x21, #0x0]\n" + "str s25, [x22, #0x0]\n" + "str s26, [x23, #0x0]\n" + "13:" // Row tail: Output block 0: Done + "b 15f\n" + "14:" // Row tail: Full output + "mov x20, %x[dst]\n" + "cmp x12, #0x1\n" + "str q26, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 15f\n" + "cmp x12, #0x2\n" + "str q25, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 15f\n" + "cmp x12, #0x3\n" + "str q23, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 15f\n" + "str q22, [x20, #0x0]\n" + "15:" // Row tail: Output stage exit + "subs x25, x25, #0x4\n" + "add %x[dst], %x[dst], #0x10\n" + "bgt 10b\n" + "subs x12, x12, #0x4\n" + "add %x[lhs_packed], %x[lhs_packed], x11\n" + "mov %x[dst], x24\n" + "bgt 9b\n" + "16:" // Row tail: Row loop skip + : [lhs_packed] "+&r"(lhs_packed), [dst] "+&r"(dst) + : [rhs_packed] "r"(rhs_packed), [clamp_vals] "r"(clamp_vals), [m] "r"(m), [num_blocks] "r"(num_blocks), + [dst_stride_row] "r"(dst_stride_row), [n] "r"(n) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v16", "v17", "v18", + "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x11", + "x12", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"); #else KAI_ASSERT(false); KAI_UNUSED(m); diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.h b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.h index 4e869247..e09d5abd 100644 --- a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.h +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.h @@ -35,7 +35,7 @@ size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(vo /** * @brief Function to get the mr value, which must be used to pack the LHS matrix with - * the @ref kai_run_lhs_quant_pack_qa8dsP_f32 function + * the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel * * @return the mr value */ @@ -43,7 +43,7 @@ size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(void); /** * @brief Function to get the nr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 function + * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 micro-kernel * * @return the nr value */ @@ -51,7 +51,7 @@ size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(void); /** * @brief Function to get the kr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 function + * the @ref 
kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 micro-kernel
 *
 * @return the kr value
 */
@@ -59,7 +59,7 @@ size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(void);
 /**
 * @brief Function to get the sr value, which must be used to pack the RHS matrix with
- * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 function
+ * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 micro-kernel
 *
 * @return the sr value
 */
@@ -67,12 +67,12 @@ size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(void);
 /**
 * @brief Function to calculate the offset in bytes for the packed LHS matrix,
- * which contains the packed 8-bit quantized asymmetric per-row (qa8dx) values.
+ * which contains the packed 8-bit quantized asymmetric per-row (qai8dx) values.
 *
 * This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel.
 *
- * @param[in] m_idx Row index in the LHS matrix (not packed).
- * @param[in] k Total number of columns in the LHS matrix (not packed). It must be a multiple of 64.
+ * @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of 8.
+ * @param[in] k Total number of columns in the LHS matrix (not packed).
 *
 * return the offset in bytes to the packed LHS matrix
 */
@@ -80,10 +80,10 @@ size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_n
 /**
 * @brief Function to calculate the offset in bytes for the packed RHS matrix,
- * which contains the packed 4-bit quantized symmetric per-channel (qs4cx) values.
+ * which contains the packed 4-bit quantized symmetric per-channel (qsu4cx) values.
 *
 * @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of 4.
- * @param[in] k The common dimension between the LHS and RHS matrix (K). It must be a multiple of 64.
+ * @param[in] k The common dimension between the LHS and RHS matrix (K).
 *
 * return the offset in bytes to the packed RHS matrix
 */
@@ -92,7 +92,7 @@ size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_n
 /**
 * @brief Function to calculate the offset in bytes for the DST matrix
 *
- * @param[in] m_idx Row index in the DST matrix. It must be either 1, 2, 3, or any multiple of 4.
+ * @param[in] m_idx Row index in the DST matrix. It must be a multiple of 8.
 * @param[in] n_idx Column index in the DST matrix. It must be multiple of 4.
 * @param[in] dst_stride The number of bytes in in each row of the DST matrix
 *
@@ -102,10 +102,10 @@ size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8m
 size_t m_idx, size_t n_idx, size_t dst_stride);

 /**
- * @brief Function to query the size in bytes for the constant workspace.
+ * @brief Function to query the size in bytes for the destination matrix.
 *
- * @param[in] m Number of rows in the destination (DST) matrix. It must be either 1, 2, 3, or any multiple of 4.
- * @param[in] n Number of columns in the destination (DST) matrix. It must be a multiple 4.
+ * @param[in] m Number of rows in the destination (DST) matrix.
+ * @param[in] n Number of columns in the destination (DST) matrix.
 */
 size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(size_t m, size_t n);

@@ -118,9 +118,9 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(
 * Accumulation performed in a single for loop: 32
 * Instruction used: i8mm
 *
- * @param[in] m The number of output rows written. It must be either 1, 2, 3, or any multiple of 4.
- * @param[in] n The number of output columns written. 
It must be a multiple of 4. - * @param[in] k The number of channels. The common dimension of LHS & RHS. It must be multiple of 32. + * @param[in] m The number of output rows written. + * @param[in] n The number of output columns written. + * @param[in] k The number of channels. The common dimension of LHS & RHS. * @param[in] lhs_packed The LHS matrix packed. * When the activation are dynamically quantized, you can obtain this matrix * by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm.c b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm.c index 18768f08..0c4a9225 100644 --- a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm.c +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm.c @@ -22,6 +22,29 @@ static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); +inline static size_t kai_k_roundedup(size_t k) { + // Since we pack a float and int32 value at the end of the row, + // we must make sure that k is a multiple of 4 for alignment + size_t kr_sr_roundedup4 = kai_roundup(kai_kr * kai_sr, 4); + return kai_roundup(k, kr_sr_roundedup4); +} + +inline static size_t kai_lhs_packed_stride(size_t k) { + const size_t k_internal = kai_k_roundedup(k); + + KAI_ASSERT((k_internal % 2) == 0); + + return kai_mr * (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); +} + +inline static size_t kai_rhs_packed_stride(size_t k) { + const size_t k_internal = kai_k_roundedup(k); + + KAI_ASSERT((k_internal % 2) == 0); + + return kai_nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs); +} + size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(void) { return kai_m_step; } @@ -48,20 +71,14 @@ size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(void) size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(size_t m_idx, size_t k) { KAI_ASSERT((m_idx % kai_m_step) == 0); - const size_t lhs_packed_stride = kai_mr * (k + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); - return (m_idx / kai_m_step) * lhs_packed_stride; + return (m_idx / kai_m_step) * kai_lhs_packed_stride(k); } size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(size_t n_idx, size_t k) { KAI_ASSERT((n_idx % kai_n_step) == 0); - // Temporary assert - KAI_ASSERT((k % kai_k0) == 0); - - const size_t rhs_packed_stride = kai_nr * ((k / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs); - - return (n_idx / kai_n_step) * rhs_packed_stride; + return (n_idx / kai_n_step) * kai_rhs_packed_stride(k); } size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm( @@ -73,9 +90,6 @@ size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8m } size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(size_t m, size_t n) { - // Temporary assert - KAI_ASSERT((n % kai_nr) == 0); - return m * n * sizeof(float); } @@ -83,251 +97,263 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm( size_t m, size_t n, size_t k, const void* restrict lhs_packed, const void* 
restrict rhs_packed, float* restrict dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { #if defined(__ARM_FEATURE_MATMUL_INT8) - // Temporary asserts - KAI_ASSERT(n % kai_nr == 0); - KAI_ASSERT(k % kai_k0 == 0); KAI_ASSERT(dst_stride_col == sizeof(float)); if (m == 0) { return; } - const size_t num_rows = m; - const size_t num_cols = n; - - const size_t lhs_packed_stride = kai_mr * (k + sizeof(float32_t) + sizeof(float32_t)); - - const int8x16_t nibble_mask = vdupq_n_s8(0xF0); - - const uint8_t* lhs_ptr_start = lhs_packed; - - for (size_t row_idx = 0; row_idx < num_rows; row_idx += kai_mr) { - const uint8_t* rhs_ptr = rhs_packed; - - for (size_t col_idx = 0; col_idx < num_cols; col_idx += kai_nr) { - const uint8_t* lhs_ptr = lhs_ptr_start; - - // Main f32 accumulator - int32x4_t iacc_mat_00 = vdupq_n_s32(0); - int32x4_t iacc_mat_01 = vdupq_n_s32(0); - int32x4_t iacc_mat_10 = vdupq_n_s32(0); - int32x4_t iacc_mat_11 = vdupq_n_s32(0); - - int32x4_t iacc_mat_02 = vdupq_n_s32(0); - int32x4_t iacc_mat_03 = vdupq_n_s32(0); - int32x4_t iacc_mat_12 = vdupq_n_s32(0); - int32x4_t iacc_mat_13 = vdupq_n_s32(0); - - for (size_t b = 0; b < k; b += kai_k0) { - // Set up RHS - const int8x16_t rhs_raw_mat_01_0 = vld1q_s8((const int8_t*)rhs_ptr + 0); - const int8x16_t rhs_raw_mat_23_0 = vld1q_s8((const int8_t*)rhs_ptr + 16); - const int8x16_t rhs_raw_mat_45_0 = vld1q_s8((const int8_t*)rhs_ptr + 32); - const int8x16_t rhs_raw_mat_67_0 = vld1q_s8((const int8_t*)rhs_ptr + 48); - const int8x16_t rhs_raw_mat_01_1 = vld1q_s8((const int8_t*)rhs_ptr + 64); - const int8x16_t rhs_raw_mat_23_1 = vld1q_s8((const int8_t*)rhs_ptr + 80); - const int8x16_t rhs_raw_mat_45_1 = vld1q_s8((const int8_t*)rhs_ptr + 96); - const int8x16_t rhs_raw_mat_67_1 = vld1q_s8((const int8_t*)rhs_ptr + 112); - - // Low nibble - const int8x16_t rhs_mat_01_0 = vshlq_n_s8(rhs_raw_mat_01_0, 4); - const int8x16_t rhs_mat_23_0 = vshlq_n_s8(rhs_raw_mat_23_0, 4); - const int8x16_t rhs_mat_45_0 = vshlq_n_s8(rhs_raw_mat_45_0, 4); - const int8x16_t rhs_mat_67_0 = vshlq_n_s8(rhs_raw_mat_67_0, 4); - - const int8x16_t rhs_mat_01_1 = vshlq_n_s8(rhs_raw_mat_01_1, 4); - const int8x16_t rhs_mat_23_1 = vshlq_n_s8(rhs_raw_mat_23_1, 4); - const int8x16_t rhs_mat_45_1 = vshlq_n_s8(rhs_raw_mat_45_1, 4); - const int8x16_t rhs_mat_67_1 = vshlq_n_s8(rhs_raw_mat_67_1, 4); - - // High nibble - const int8x16_t rhs_mat_01_2 = vandq_s8(rhs_raw_mat_01_0, nibble_mask); - const int8x16_t rhs_mat_23_2 = vandq_s8(rhs_raw_mat_23_0, nibble_mask); - const int8x16_t rhs_mat_45_2 = vandq_s8(rhs_raw_mat_45_0, nibble_mask); - const int8x16_t rhs_mat_67_2 = vandq_s8(rhs_raw_mat_67_0, nibble_mask); - - const int8x16_t rhs_mat_01_3 = vandq_s8(rhs_raw_mat_01_1, nibble_mask); - const int8x16_t rhs_mat_23_3 = vandq_s8(rhs_raw_mat_23_1, nibble_mask); - const int8x16_t rhs_mat_45_3 = vandq_s8(rhs_raw_mat_45_1, nibble_mask); - const int8x16_t rhs_mat_67_3 = vandq_s8(rhs_raw_mat_67_1, nibble_mask); - - // Process LHS in pairs of rows - const int8x16_t lhs_mat_01_0 = vld1q_s8((const int8_t*)lhs_ptr + 0); - const int8x16_t lhs_mat_23_0 = vld1q_s8((const int8_t*)lhs_ptr + 16); - const int8x16_t lhs_mat_01_1 = vld1q_s8((const int8_t*)lhs_ptr + 32); - const int8x16_t lhs_mat_23_1 = vld1q_s8((const int8_t*)lhs_ptr + 48); - const int8x16_t lhs_mat_01_2 = vld1q_s8((const int8_t*)lhs_ptr + 64); - const int8x16_t lhs_mat_23_2 = vld1q_s8((const int8_t*)lhs_ptr + 80); - const int8x16_t lhs_mat_01_3 = vld1q_s8((const int8_t*)lhs_ptr + 96); - const int8x16_t lhs_mat_23_3 
= vld1q_s8((const int8_t*)lhs_ptr + 112); - - // Do the MMLAs into 2x2 matrices - iacc_mat_00 = vmmlaq_s32( - vmmlaq_s32( - vmmlaq_s32(vmmlaq_s32(iacc_mat_00, lhs_mat_01_0, rhs_mat_01_0), lhs_mat_01_1, rhs_mat_01_1), - lhs_mat_01_2, rhs_mat_01_2), - lhs_mat_01_3, rhs_mat_01_3); - iacc_mat_01 = vmmlaq_s32( - vmmlaq_s32( - vmmlaq_s32(vmmlaq_s32(iacc_mat_01, lhs_mat_01_0, rhs_mat_23_0), lhs_mat_01_1, rhs_mat_23_1), - lhs_mat_01_2, rhs_mat_23_2), - lhs_mat_01_3, rhs_mat_23_3); - iacc_mat_10 = vmmlaq_s32( - vmmlaq_s32( - vmmlaq_s32(vmmlaq_s32(iacc_mat_10, lhs_mat_23_0, rhs_mat_01_0), lhs_mat_23_1, rhs_mat_01_1), - lhs_mat_23_2, rhs_mat_01_2), - lhs_mat_23_3, rhs_mat_01_3); - iacc_mat_11 = vmmlaq_s32( - vmmlaq_s32( - vmmlaq_s32(vmmlaq_s32(iacc_mat_11, lhs_mat_23_0, rhs_mat_23_0), lhs_mat_23_1, rhs_mat_23_1), - lhs_mat_23_2, rhs_mat_23_2), - lhs_mat_23_3, rhs_mat_23_3); - - /// - - iacc_mat_02 = vmmlaq_s32( - vmmlaq_s32( - vmmlaq_s32(vmmlaq_s32(iacc_mat_02, lhs_mat_01_0, rhs_mat_45_0), lhs_mat_01_1, rhs_mat_45_1), - lhs_mat_01_2, rhs_mat_45_2), - lhs_mat_01_3, rhs_mat_45_3); - iacc_mat_03 = vmmlaq_s32( - vmmlaq_s32( - vmmlaq_s32(vmmlaq_s32(iacc_mat_03, lhs_mat_01_0, rhs_mat_67_0), lhs_mat_01_1, rhs_mat_67_1), - lhs_mat_01_2, rhs_mat_67_2), - lhs_mat_01_3, rhs_mat_67_3); - iacc_mat_12 = vmmlaq_s32( - vmmlaq_s32( - vmmlaq_s32(vmmlaq_s32(iacc_mat_12, lhs_mat_23_0, rhs_mat_45_0), lhs_mat_23_1, rhs_mat_45_1), - lhs_mat_23_2, rhs_mat_45_2), - lhs_mat_23_3, rhs_mat_45_3); - iacc_mat_13 = vmmlaq_s32( - vmmlaq_s32( - vmmlaq_s32(vmmlaq_s32(iacc_mat_13, lhs_mat_23_0, rhs_mat_67_0), lhs_mat_23_1, rhs_mat_67_1), - lhs_mat_23_2, rhs_mat_67_2), - lhs_mat_23_3, rhs_mat_67_3); - - // Straighten out to make 4 row vectors - lhs_ptr += 128; - rhs_ptr += 128; - } - - int32x4_t iacc_row_0_0123 = vreinterpretq_s32_u64( - vtrn1q_u64(vreinterpretq_u64_s32(iacc_mat_00), vreinterpretq_u64_s32(iacc_mat_01))); - int32x4_t iacc_row_1_0123 = vreinterpretq_s32_u64( - vtrn2q_u64(vreinterpretq_u64_s32(iacc_mat_00), vreinterpretq_u64_s32(iacc_mat_01))); - int32x4_t iacc_row_2_0123 = vreinterpretq_s32_u64( - vtrn1q_u64(vreinterpretq_u64_s32(iacc_mat_10), vreinterpretq_u64_s32(iacc_mat_11))); - int32x4_t iacc_row_3_0123 = vreinterpretq_s32_u64( - vtrn2q_u64(vreinterpretq_u64_s32(iacc_mat_10), vreinterpretq_u64_s32(iacc_mat_11))); - - int32x4_t iacc_row_0_4567 = vreinterpretq_s32_u64( - vtrn1q_u64(vreinterpretq_u64_s32(iacc_mat_02), vreinterpretq_u64_s32(iacc_mat_03))); - int32x4_t iacc_row_1_4567 = vreinterpretq_s32_u64( - vtrn2q_u64(vreinterpretq_u64_s32(iacc_mat_02), vreinterpretq_u64_s32(iacc_mat_03))); - int32x4_t iacc_row_2_4567 = vreinterpretq_s32_u64( - vtrn1q_u64(vreinterpretq_u64_s32(iacc_mat_12), vreinterpretq_u64_s32(iacc_mat_13))); - int32x4_t iacc_row_3_4567 = vreinterpretq_s32_u64( - vtrn2q_u64(vreinterpretq_u64_s32(iacc_mat_12), vreinterpretq_u64_s32(iacc_mat_13))); - - // LHS offset - const int32x4_t lhs_offset = vld1q_s32((const int32_t*)lhs_ptr); - lhs_ptr += sizeof(int32x4_t); - - // LHS scale - const float32x4_t lhs_scale = vld1q_f32((const float32_t*)lhs_ptr); - lhs_ptr += sizeof(float32x4_t); - - // RHS sum values - const int32x4_t sum_n_s32_0 = vld1q_s32((const int32_t*)(rhs_ptr)); - rhs_ptr += sizeof(int32x4_t); - const int32x4_t sum_n_s32_1 = vld1q_s32((const int32_t*)(rhs_ptr)); - rhs_ptr += sizeof(int32x4_t); - - // RHS scale - const float32x4_t rhs_scale0 = vld1q_f32((const float*)rhs_ptr); - rhs_ptr += sizeof(float32x4_t); - const float32x4_t rhs_scale1 = vld1q_f32((const float*)rhs_ptr); - rhs_ptr += 
sizeof(float32x4_t); - - // Add the RHS reduction sum - iacc_row_0_0123 = vmlaq_laneq_s32(iacc_row_0_0123, sum_n_s32_0, lhs_offset, 0); - iacc_row_1_0123 = vmlaq_laneq_s32(iacc_row_1_0123, sum_n_s32_0, lhs_offset, 1); - iacc_row_2_0123 = vmlaq_laneq_s32(iacc_row_2_0123, sum_n_s32_0, lhs_offset, 2); - iacc_row_3_0123 = vmlaq_laneq_s32(iacc_row_3_0123, sum_n_s32_0, lhs_offset, 3); - - iacc_row_0_4567 = vmlaq_laneq_s32(iacc_row_0_4567, sum_n_s32_1, lhs_offset, 0); - iacc_row_1_4567 = vmlaq_laneq_s32(iacc_row_1_4567, sum_n_s32_1, lhs_offset, 1); - iacc_row_2_4567 = vmlaq_laneq_s32(iacc_row_2_4567, sum_n_s32_1, lhs_offset, 2); - iacc_row_3_4567 = vmlaq_laneq_s32(iacc_row_3_4567, sum_n_s32_1, lhs_offset, 3); - - float32x4_t main_acc0_0123 = vmulq_f32(vcvtq_f32_s32(iacc_row_0_0123), rhs_scale0); - float32x4_t main_acc1_0123 = vmulq_f32(vcvtq_f32_s32(iacc_row_1_0123), rhs_scale0); - float32x4_t main_acc2_0123 = vmulq_f32(vcvtq_f32_s32(iacc_row_2_0123), rhs_scale0); - float32x4_t main_acc3_0123 = vmulq_f32(vcvtq_f32_s32(iacc_row_3_0123), rhs_scale0); - - float32x4_t main_acc0_4567 = vmulq_f32(vcvtq_f32_s32(iacc_row_0_4567), rhs_scale1); - float32x4_t main_acc1_4567 = vmulq_f32(vcvtq_f32_s32(iacc_row_1_4567), rhs_scale1); - float32x4_t main_acc2_4567 = vmulq_f32(vcvtq_f32_s32(iacc_row_2_4567), rhs_scale1); - float32x4_t main_acc3_4567 = vmulq_f32(vcvtq_f32_s32(iacc_row_3_4567), rhs_scale1); - - main_acc0_0123 = vmulq_laneq_f32(main_acc0_0123, lhs_scale, 0); - main_acc1_0123 = vmulq_laneq_f32(main_acc1_0123, lhs_scale, 1); - main_acc2_0123 = vmulq_laneq_f32(main_acc2_0123, lhs_scale, 2); - main_acc3_0123 = vmulq_laneq_f32(main_acc3_0123, lhs_scale, 3); - - main_acc0_4567 = vmulq_laneq_f32(main_acc0_4567, lhs_scale, 0); - main_acc1_4567 = vmulq_laneq_f32(main_acc1_4567, lhs_scale, 1); - main_acc2_4567 = vmulq_laneq_f32(main_acc2_4567, lhs_scale, 2); - main_acc3_4567 = vmulq_laneq_f32(main_acc3_4567, lhs_scale, 3); - - // clamp (min-max) operation - const float32x4_t vmin_f32 = vdupq_n_f32(scalar_min); - const float32x4_t vmax_f32 = vdupq_n_f32(scalar_max); - - main_acc0_0123 = vmaxq_f32(main_acc0_0123, vmin_f32); - main_acc0_0123 = vminq_f32(main_acc0_0123, vmax_f32); - main_acc1_0123 = vmaxq_f32(main_acc1_0123, vmin_f32); - main_acc1_0123 = vminq_f32(main_acc1_0123, vmax_f32); - main_acc2_0123 = vmaxq_f32(main_acc2_0123, vmin_f32); - main_acc2_0123 = vminq_f32(main_acc2_0123, vmax_f32); - - main_acc0_4567 = vmaxq_f32(main_acc0_4567, vmin_f32); - main_acc0_4567 = vminq_f32(main_acc0_4567, vmax_f32); - main_acc1_4567 = vmaxq_f32(main_acc1_4567, vmin_f32); - main_acc1_4567 = vminq_f32(main_acc1_4567, vmax_f32); - main_acc2_4567 = vmaxq_f32(main_acc2_4567, vmin_f32); - main_acc2_4567 = vminq_f32(main_acc2_4567, vmax_f32); - - // Stores the rows in reverse order to avoid out-of-bound writes. 
- // Override out-of-bound values with in-bound values - vst1q_f32( - (float*)((uint8_t*)dst + (col_idx + 0) * sizeof(float) + KAI_MIN(row_idx + 3, m - 1) * dst_stride_row), - main_acc3_0123); - vst1q_f32( - (float*)((uint8_t*)dst + (col_idx + 4) * sizeof(float) + KAI_MIN(row_idx + 3, m - 1) * dst_stride_row), - main_acc3_4567); - vst1q_f32( - (float*)((uint8_t*)dst + (col_idx + 0) * sizeof(float) + KAI_MIN(row_idx + 2, m - 1) * dst_stride_row), - main_acc2_0123); - vst1q_f32( - (float*)((uint8_t*)dst + (col_idx + 4) * sizeof(float) + KAI_MIN(row_idx + 2, m - 1) * dst_stride_row), - main_acc2_4567); - vst1q_f32( - (float*)((uint8_t*)dst + (col_idx + 0) * sizeof(float) + KAI_MIN(row_idx + 1, m - 1) * dst_stride_row), - main_acc1_0123); - vst1q_f32( - (float*)((uint8_t*)dst + (col_idx + 4) * sizeof(float) + KAI_MIN(row_idx + 1, m - 1) * dst_stride_row), - main_acc1_4567); - vst1q_f32( - (float*)((uint8_t*)dst + (col_idx + 0) * sizeof(float) + KAI_MIN(row_idx + 0, m - 1) * dst_stride_row), - main_acc0_0123); - vst1q_f32( - (float*)((uint8_t*)dst + (col_idx + 4) * sizeof(float) + KAI_MIN(row_idx + 0, m - 1) * dst_stride_row), - main_acc0_4567); - } - - lhs_ptr_start += lhs_packed_stride; - } + const size_t k_internal = kai_k_roundedup(k); + + size_t num_blocks = k_internal / 32; + + float clamp_vals[2] = {scalar_min, scalar_max}; + + __asm__ __volatile__( + "mov x28, #0x80\n" + "mov x20, #0x20\n" + "movi v12.16b, #0xf0\n" + "mov x27, %x[m]\n" + "madd x28, %x[num_blocks], x28, x20\n" + "cbz x27, 10f\n" + "1:" // Row loop + "mov x26, %x[rhs_packed]\n" + "mov x25, %x[n]\n" + "add x24, %x[dst], %x[dst_stride_row], LSL #2\n" + "2:" // Column loop + "movi v11.4s, #0x0\n" + "movi v10.4s, #0x0\n" + "mov x21, %x[lhs_packed]\n" + "mov x20, %x[num_blocks]\n" + "movi v9.4s, #0x0\n" + "movi v8.4s, #0x0\n" + "movi v7.4s, #0x0\n" + "movi v6.4s, #0x0\n" + "movi v5.4s, #0x0\n" + "movi v4.4s, #0x0\n" + "3:" // Block loop + "ldr q3, [x26, #0x0]\n" + "ldr q2, [x26, #0x10]\n" + "subs x20, x20, #0x1\n" + "ldr q1, [x26, #0x20]\n" + "ldr q0, [x26, #0x30]\n" + "ldr q31, [x21, #0x0]\n" + "ldr q30, [x21, #0x10]\n" + "ldr q29, [x26, #0x40]\n" + "ldr q28, [x26, #0x50]\n" + "shl v19.16b, v3.16b, #0x4\n" + "shl v18.16b, v2.16b, #0x4\n" + "ldr q27, [x26, #0x60]\n" + "ldr q26, [x26, #0x70]\n" + "shl v17.16b, v1.16b, #0x4\n" + "shl v16.16b, v0.16b, #0x4\n" + "ldr q25, [x21, #0x20]\n" + "ldr q24, [x21, #0x30]\n" + "and v3.16b, v3.16b, v12.16b\n" + "and v2.16b, v2.16b, v12.16b\n" + "ldr q23, [x21, #0x40]\n" + "ldr q22, [x21, #0x50]\n" + ".inst 0x4e93a7eb // smmla v11.4s, v31.16b, v19.16b\n" + ".inst 0x4e92a7e9 // smmla v9.4s, v31.16b, v18.16b\n" + "ldr q21, [x21, #0x60]\n" + "ldr q20, [x21, #0x70]\n" + ".inst 0x4e91a7ea // smmla v10.4s, v31.16b, v17.16b\n" + ".inst 0x4e90a7e8 // smmla v8.4s, v31.16b, v16.16b\n" + ".inst 0x4e93a7c7 // smmla v7.4s, v30.16b, v19.16b\n" + ".inst 0x4e92a7c5 // smmla v5.4s, v30.16b, v18.16b\n" + "shl v19.16b, v29.16b, #0x4\n" + "add x21, x21, #0x80\n" + ".inst 0x4e91a7c6 // smmla v6.4s, v30.16b, v17.16b\n" + ".inst 0x4e90a7c4 // smmla v4.4s, v30.16b, v16.16b\n" + "shl v18.16b, v28.16b, #0x4\n" + "add x26, x26, #0x80\n" + "shl v17.16b, v27.16b, #0x4\n" + "shl v16.16b, v26.16b, #0x4\n" + ".inst 0x4e93a72b // smmla v11.4s, v25.16b, v19.16b\n" + "and v1.16b, v1.16b, v12.16b\n" + "and v0.16b, v0.16b, v12.16b\n" + ".inst 0x4e92a729 // smmla v9.4s, v25.16b, v18.16b\n" + ".inst 0x4e93a707 // smmla v7.4s, v24.16b, v19.16b\n" + ".inst 0x4e92a705 // smmla v5.4s, v24.16b, v18.16b\n" + "and v29.16b, v29.16b, 
v12.16b\n" + ".inst 0x4e91a72a // smmla v10.4s, v25.16b, v17.16b\n" + ".inst 0x4e90a728 // smmla v8.4s, v25.16b, v16.16b\n" + "and v28.16b, v28.16b, v12.16b\n" + ".inst 0x4e91a706 // smmla v6.4s, v24.16b, v17.16b\n" + ".inst 0x4e90a704 // smmla v4.4s, v24.16b, v16.16b\n" + "and v27.16b, v27.16b, v12.16b\n" + ".inst 0x4e83a6eb // smmla v11.4s, v23.16b, v3.16b\n" + ".inst 0x4e82a6e9 // smmla v9.4s, v23.16b, v2.16b\n" + "and v26.16b, v26.16b, v12.16b\n" + ".inst 0x4e83a6c7 // smmla v7.4s, v22.16b, v3.16b\n" + ".inst 0x4e82a6c5 // smmla v5.4s, v22.16b, v2.16b\n" + ".inst 0x4e81a6ea // smmla v10.4s, v23.16b, v1.16b\n" + ".inst 0x4e80a6e8 // smmla v8.4s, v23.16b, v0.16b\n" + ".inst 0x4e81a6c6 // smmla v6.4s, v22.16b, v1.16b\n" + ".inst 0x4e80a6c4 // smmla v4.4s, v22.16b, v0.16b\n" + ".inst 0x4e9da6ab // smmla v11.4s, v21.16b, v29.16b\n" + ".inst 0x4e9ca6a9 // smmla v9.4s, v21.16b, v28.16b\n" + ".inst 0x4e9da687 // smmla v7.4s, v20.16b, v29.16b\n" + ".inst 0x4e9ca685 // smmla v5.4s, v20.16b, v28.16b\n" + ".inst 0x4e9ba6aa // smmla v10.4s, v21.16b, v27.16b\n" + ".inst 0x4e9aa6a8 // smmla v8.4s, v21.16b, v26.16b\n" + ".inst 0x4e9ba686 // smmla v6.4s, v20.16b, v27.16b\n" + ".inst 0x4e9aa684 // smmla v4.4s, v20.16b, v26.16b\n" + "bgt 3b\n" + "ldr q20, [x26, #0x0]\n" + "ldr q19, [x26, #0x10]\n" + "uzp1 v2.2d, v11.2d, v9.2d\n" + "uzp1 v1.2d, v10.2d, v8.2d\n" + "ldr q18, [x21, #0x0]\n" + "ldr q17, [x26, #0x20]\n" + "uzp2 v0.2d, v11.2d, v9.2d\n" + "uzp2 v31.2d, v10.2d, v8.2d\n" + "ldr q30, [x26, #0x30]\n" + "ldr q16, [x21, #0x10]\n" + "uzp1 v29.2d, v7.2d, v5.2d\n" + "uzp1 v28.2d, v6.2d, v4.2d\n" + "ld1r { v27.4s }, [%x[clamp_vals]]\n" + "uzp2 v26.2d, v7.2d, v5.2d\n" + "uzp2 v25.2d, v6.2d, v4.2d\n" + "add x20, %x[clamp_vals], #0x4\n" + "ld1r { v24.4s }, [x20]\n" + "mla v2.4s, v20.4s, v18.s[0]\n" + "mla v1.4s, v19.4s, v18.s[0]\n" + "cmp x25, #0x8\n" + "mla v0.4s, v20.4s, v18.s[1]\n" + "mla v31.4s, v19.4s, v18.s[1]\n" + "fmul v23.4s, v17.4s, v16.s[0]\n" + "add x26, x26, #0x40\n" + "mla v29.4s, v20.4s, v18.s[2]\n" + "mla v28.4s, v19.4s, v18.s[2]\n" + "fmul v22.4s, v30.4s, v16.s[0]\n" + "mla v26.4s, v20.4s, v18.s[3]\n" + "mla v25.4s, v19.4s, v18.s[3]\n" + "fmul v21.4s, v17.4s, v16.s[1]\n" + "scvtf v2.4s, v2.4s\n" + "scvtf v1.4s, v1.4s\n" + "scvtf v0.4s, v0.4s\n" + "scvtf v31.4s, v31.4s\n" + "fmul v20.4s, v30.4s, v16.s[1]\n" + "scvtf v29.4s, v29.4s\n" + "fmul v19.4s, v17.4s, v16.s[2]\n" + "scvtf v28.4s, v28.4s\n" + "fmul v18.4s, v30.4s, v16.s[2]\n" + "scvtf v26.4s, v26.4s\n" + "fmul v17.4s, v17.4s, v16.s[3]\n" + "scvtf v25.4s, v25.4s\n" + "fmul v16.4s, v30.4s, v16.s[3]\n" + "fmul v2.4s, v2.4s, v23.4s\n" + "fmul v1.4s, v1.4s, v22.4s\n" + "fmul v0.4s, v0.4s, v21.4s\n" + "fmul v31.4s, v31.4s, v20.4s\n" + "fmul v29.4s, v29.4s, v19.4s\n" + "fmul v28.4s, v28.4s, v18.4s\n" + "fmul v26.4s, v26.4s, v17.4s\n" + "fmul v25.4s, v25.4s, v16.4s\n" + "fmax v2.4s, v2.4s, v27.4s\n" + "fmax v1.4s, v1.4s, v27.4s\n" + "fmax v0.4s, v0.4s, v27.4s\n" + "fmax v31.4s, v31.4s, v27.4s\n" + "fmax v29.4s, v29.4s, v27.4s\n" + "fmax v28.4s, v28.4s, v27.4s\n" + "fmax v26.4s, v26.4s, v27.4s\n" + "fmax v25.4s, v25.4s, v27.4s\n" + "fmin v2.4s, v2.4s, v24.4s\n" + "fmin v1.4s, v1.4s, v24.4s\n" + "fmin v0.4s, v0.4s, v24.4s\n" + "fmin v31.4s, v31.4s, v24.4s\n" + "fmin v29.4s, v29.4s, v24.4s\n" + "fmin v28.4s, v28.4s, v24.4s\n" + "fmin v26.4s, v26.4s, v24.4s\n" + "fmin v25.4s, v25.4s, v24.4s\n" + "bge 8f\n" + "mov x23, %x[dst]\n" + "cmp x27, #0x1\n" + "add x22, x23, %x[dst_stride_row]\n" + "csel x22, x22, x23, GE\n" + "cmp x27, #0x2\n" + "add x21, 
x23, %x[dst_stride_row], LSL #1\n" + "csel x21, x21, x22, GE\n" + "cmp x27, #0x3\n" + "add x20, x21, %x[dst_stride_row]\n" + "csel x20, x20, x21, GE\n" + "tbz x25, #2, 5f\n" + "st1 { v26.4s }, [x20], #0x10\n" + "st1 { v29.4s }, [x21], #0x10\n" + "st1 { v0.4s }, [x22], #0x10\n" + "st1 { v2.4s }, [x23], #0x10\n" + "tbz x25, #1, 4f\n" + "str d25, [x20], #0x8\n" + "str d28, [x21], #0x8\n" + "str d31, [x22], #0x8\n" + "str d1, [x23], #0x8\n" + "tbz x25, #0, 7f\n" + "st1 { v25.s }[2], [x20]\n" + "st1 { v28.s }[2], [x21]\n" + "st1 { v31.s }[2], [x22]\n" + "st1 { v1.s }[2], [x23]\n" + "b 7f\n" + "4:" // Output block 0: partial_1_4 + "tbz x25, #0, 7f\n" + "str s25, [x20, #0x0]\n" + "str s28, [x21, #0x0]\n" + "str s31, [x22, #0x0]\n" + "str s1, [x23, #0x0]\n" + "b 7f\n" + "5:" // Output block 0: partial_2_0 + "tbz x25, #1, 6f\n" + "str d26, [x20], #0x8\n" + "str d29, [x21], #0x8\n" + "str d0, [x22], #0x8\n" + "str d2, [x23], #0x8\n" + "tbz x25, #0, 7f\n" + "st1 { v26.s }[2], [x20]\n" + "st1 { v29.s }[2], [x21]\n" + "st1 { v0.s }[2], [x22]\n" + "st1 { v2.s }[2], [x23]\n" + "b 7f\n" + "6:" // Output block 0: partial_1_0 + "str s26, [x20, #0x0]\n" + "str s29, [x21, #0x0]\n" + "str s0, [x22, #0x0]\n" + "str s2, [x23, #0x0]\n" + "7:" // Output block 0: Done + "b 9f\n" + "8:" // Full output + "mov x20, %x[dst]\n" + "cmp x27, #0x1\n" + "str q2, [x20, #0x0]\n" + "str q1, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 9f\n" + "cmp x27, #0x2\n" + "str q0, [x20, #0x0]\n" + "str q31, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 9f\n" + "cmp x27, #0x3\n" + "str q29, [x20, #0x0]\n" + "str q28, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 9f\n" + "str q26, [x20, #0x0]\n" + "str q25, [x20, #0x10]\n" + "9:" // Output stage exit + "subs x25, x25, #0x8\n" + "add %x[dst], %x[dst], #0x20\n" + "bgt 2b\n" + "subs x27, x27, #0x4\n" + "add %x[lhs_packed], %x[lhs_packed], x28\n" + "mov %x[dst], x24\n" + "bgt 1b\n" + "10:" // Row loop skip + : [lhs_packed] "+&r"(lhs_packed), [dst] "+&r"(dst) + : [rhs_packed] "r"(rhs_packed), [clamp_vals] "r"(clamp_vals), [m] "r"(m), [num_blocks] "r"(num_blocks), + [dst_stride_row] "r"(dst_stride_row), [n] "r"(n) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v16", "v17", + "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", + "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"); #else KAI_ASSERT(false); KAI_UNUSED(m); diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm.h b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm.h index 1a2221cf..a043cf0e 100644 --- a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm.h +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm.h @@ -35,7 +35,7 @@ size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(vo /** * @brief Function to get the mr value, which must be used to pack the LHS matrix with - * the @ref kai_run_lhs_quant_pack_qa8dsP_f32 function + * the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel * * @return the mr value */ @@ -43,7 +43,7 @@ size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(void); /** * @brief Function to get the nr value, which must be used to pack the RHS matrix with - * the @ref 
kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 function
+ * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 micro-kernel
 *
 * @return the nr value
 */
@@ -51,7 +51,7 @@ size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(void);
 /**
 * @brief Function to get the kr value, which must be used to pack the RHS matrix with
- * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 function
+ * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 micro-kernel
 *
 * @return the kr value
 */
@@ -59,7 +59,7 @@ size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(void);
 /**
 * @brief Function to get the sr value, which must be used to pack the RHS matrix with
- * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 function
+ * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 micro-kernel
 *
 * @return the sr value
 */
@@ -71,8 +71,8 @@ size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(void);
 *
 * This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel.
 *
- * @param[in] m_idx Row index in the LHS matrix (not packed).
- * @param[in] k Total number of columns in the LHS matrix (not packed). It must be a multiple of 64.
+ * @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of 4.
+ * @param[in] k Total number of columns in the LHS matrix (not packed).
 *
 * return the offset in bytes to the packed LHS matrix
 */
@@ -80,10 +80,10 @@ size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_n
 /**
 * @brief Function to calculate the offset in bytes for the packed RHS matrix,
- * which contains the packed 4-bit quantized symmetric per-channel (qs4cx) values.
+ * which contains the packed 4-bit quantized symmetric per-channel (qsu4cx) values.
 *
- * @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of 4.
- * @param[in] k The common dimension between the LHS and RHS matrix (K). It must be a multiple of 64.
+ * @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of 8.
+ * @param[in] k The common dimension between the LHS and RHS matrix (K).
 *
 * return the offset in bytes to the packed RHS matrix
 */
@@ -92,9 +92,9 @@ size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_n
 /**
 * @brief Function to calculate the offset in bytes for the DST matrix
 *
- * @param[in] m_idx Row index in the DST matrix. It must be either 1, 2, 3, or any multiple of 4.
- * @param[in] n_idx Column index in the DST matrix. It must be multiple of 4.
- * @param[in] dst_stride The number of bytes in in each row of the DST matrix
+ * @param[in] m_idx Row index in the DST matrix. It must be a multiple of 4.
+ * @param[in] n_idx Column index in the DST matrix. It must be a multiple of 8.
+ * @param[in] dst_stride The number of bytes in each row of the DST matrix
 *
 * return the DST offset in bytes
 */
@@ -102,10 +102,10 @@ size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8m
 size_t m_idx, size_t n_idx, size_t dst_stride);

 /**
- * @brief Function to query the size in bytes for the constant workspace.
+ * @brief Function to query the size in bytes for the destination matrix.
 *
- * @param[in] m Number of rows in the destination (DST) matrix. It must be either 1, 2, 3, or any multiple of 4.
- * @param[in] n Number of columns in the destination (DST) matrix. It must be a multiple 4.
+ * @param[in] m Number of rows in the destination (DST) matrix.
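+ *                  The previous restriction (1, 2, 3, or any multiple of 4) no longer applies.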
+ * @param[in] n Number of columns in the destination (DST) matrix.
 */
 size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(size_t m, size_t n);
@@ -118,9 +118,9 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(
 * Accumulation performed in a single for loop: 32
 * Instruction used: i8mm
 *
- * @param[in] m The number of output rows written. It must be either 1, 2, 3, or any multiple of 4.
- * @param[in] n The number of output columns written. It must be a multiple of 4.
- * @param[in] k The number of channels. The common dimension of LHS & RHS. It must be multiple of 32.
+ * @param[in] m The number of output rows written.
+ * @param[in] n The number of output columns written.
+ * @param[in] k The number of channels. The common dimension of LHS & RHS.
 * @param[in] lhs_packed The LHS matrix packed.
 * When the activation are dynamically quantized, you can obtain this matrix
 * by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs
diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.c b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.c
index 4ff45dc0..1bac2078 100644
--- a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.c
+++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.c
@@ -5,11 +5,11 @@
 //
 #include "kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.h"
-#include "kai_common.h"
-
 #include <stddef.h>
 #include <stdint.h>
+#include "kai_common.h"
+
 static const size_t kai_m_step = 8;
 static const size_t kai_n_step = 8;
 static const size_t kai_mr = 4;
@@ -22,6 +22,29 @@ static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);
 static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t);
 static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t);
 
+inline static size_t kai_k_roundedup(size_t k) {
+    // Since we pack a float and int32 value at the end of the row,
+    // we must make sure that k is a multiple of 4 for alignment
+    size_t kr_sr_roundedup4 = kai_roundup(kai_kr * kai_sr, 4);
+    return kai_roundup(k, kr_sr_roundedup4);
+}
+
+inline static size_t kai_lhs_packed_stride(size_t k) {
+    const size_t k_internal = kai_k_roundedup(k);
+
+    KAI_ASSERT((k_internal % 2) == 0);
+
+    return kai_mr * (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs);
+}
+
+inline static size_t kai_rhs_packed_stride(size_t k) {
+    const size_t k_internal = kai_k_roundedup(k);
+
+    KAI_ASSERT((k_internal % 2) == 0);
+
+    return kai_nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs);
+}
+
 size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(void) {
     return kai_m_step;
 }
@@ -48,20 +71,14 @@ size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(void)
 size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(size_t m_idx, size_t k) {
     KAI_ASSERT((m_idx % kai_m_step) == 0);
 
-    const size_t lhs_packed_stride = kai_mr * (k + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs);
-    return (m_idx / kai_m_step) * lhs_packed_stride;
+    return (m_idx / kai_m_step) * kai_lhs_packed_stride(k);
 }
 
 size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(size_t n_idx, size_t k) {
     KAI_ASSERT((n_idx % kai_n_step) == 0);
 
-    // Temporary assert
-    KAI_ASSERT((k % 
kai_k0) == 0); - - const size_t rhs_packed_stride = kai_nr * ((k / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs); - - return (n_idx / kai_n_step) * rhs_packed_stride; + return (n_idx / kai_n_step) * kai_rhs_packed_stride(k); } size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm( @@ -73,9 +90,6 @@ size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8m } size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(size_t m, size_t n) { - // Temporary assert - KAI_ASSERT((n % kai_nr) == 0); - return m * n * sizeof(float); } @@ -83,503 +97,639 @@ void kai_run_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm( size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { #if defined(__ARM_FEATURE_MATMUL_INT8) - // Temporary asserts - KAI_ASSERT(n % kai_nr == 0); - KAI_ASSERT(k % kai_k0 == 0); KAI_ASSERT(dst_stride_col == sizeof(float)); if (m == 0) { return; } - size_t num_blocks = k / 32; + const size_t k_internal = kai_k_roundedup(k); + + size_t num_blocks = k_internal / 32; float clamp_vals[2] = {scalar_min, scalar_max}; - __asm__ __volatile__("mov x27, %x[m]\n" - "mov x26, #0x80\n" - "movi v5.16b, #0xf0\n" - "mov x20, #0x20\n" - "cmp x27, #0x8\n" - "madd x26, %x[num_blocks], x26, x20\n" - "blt 4f\n" - "1:" // Row loop - "mov x25, %x[rhs_packed]\n" - "mov x23, %x[n]\n" - "add x22, %x[dst], %x[dst_stride_row], LSL #3\n" - "2:" // Column loop - "mov x24, %x[lhs_packed]\n" - "movi v8.4s, #0x0\n" - "movi v27.4s, #0x0\n" - "mov x21, %x[num_blocks]\n" - "movi v11.4s, #0x0\n" - "movi v6.4s, #0x0\n" - "movi v31.4s, #0x0\n" - "movi v26.4s, #0x0\n" - "movi v13.4s, #0x0\n" - "movi v15.4s, #0x0\n" - "add x20, x24, x26\n" - "movi v0.4s, #0x0\n" - "movi v22.4s, #0x0\n" - "movi v30.4s, #0x0\n" - "movi v29.4s, #0x0\n" - "movi v23.4s, #0x0\n" - "movi v3.4s, #0x0\n" - "movi v16.4s, #0x0\n" - "movi v21.4s, #0x0\n" - "3:" // Block loop - "ldr q12, [x25, #0x0]\n" - "ldr q10, [x25, #0x10]\n" - "subs x21, x21, #0x1\n" - "ldr q7, [x25, #0x20]\n" - "ldr q28, [x25, #0x30]\n" - "ldr q24, [x24, #0x0]\n" - "ldr q20, [x24, #0x10]\n" - "ldr q9, [x20, #0x0]\n" - "ldr q2, [x20, #0x10]\n" - "shl v18.16b, v12.16b, #0x4\n" - "shl v17.16b, v10.16b, #0x4\n" - "ldr q1, [x25, #0x40]\n" - "ldr q4, [x25, #0x50]\n" - "shl v14.16b, v7.16b, #0x4\n" - "shl v19.16b, v28.16b, #0x4\n" - "ldr q25, [x25, #0x60]\n" - "and v12.16b, v12.16b, v5.16b\n" - "and v10.16b, v10.16b, v5.16b\n" - ".inst 0x4e92a708 // smmla v8.4s, v24.16b, v18.16b\n" - ".inst 0x4e91a70b // smmla v11.4s, v24.16b, v17.16b\n" - ".inst 0x4e92a69f // smmla v31.4s, v20.16b, v18.16b\n" - "and v7.16b, v7.16b, v5.16b\n" - ".inst 0x4e8ea71b // smmla v27.4s, v24.16b, v14.16b\n" - ".inst 0x4e93a706 // smmla v6.4s, v24.16b, v19.16b\n" - "ldr q24, [x25, #0x70]\n" - "and v28.16b, v28.16b, v5.16b\n" - ".inst 0x4e91a68d // smmla v13.4s, v20.16b, v17.16b\n" - ".inst 0x4e8ea69a // smmla v26.4s, v20.16b, v14.16b\n" - "add x25, x25, #0x80\n" - ".inst 0x4e93a68f // smmla v15.4s, v20.16b, v19.16b\n" - "ldr q20, [x24, #0x20]\n" - ".inst 0x4e92a520 // smmla v0.4s, v9.16b, v18.16b\n" - ".inst 0x4e91a53e // smmla v30.4s, v9.16b, v17.16b\n" - ".inst 0x4e8ea536 // smmla v22.4s, v9.16b, v14.16b\n" - ".inst 0x4e93a53d // smmla v29.4s, v9.16b, v19.16b\n" - "ldr q9, [x24, #0x30]\n" - ".inst 0x4e92a457 // smmla v23.4s, v2.16b, v18.16b\n" - "ldr q18, [x20, #0x20]\n" - ".inst 0x4e91a450 // smmla v16.4s, 
v2.16b, v17.16b\n" - "ldr q17, [x20, #0x30]\n" - ".inst 0x4e8ea443 // smmla v3.4s, v2.16b, v14.16b\n" - "ldr q14, [x24, #0x40]\n" - ".inst 0x4e93a455 // smmla v21.4s, v2.16b, v19.16b\n" - "ldr q2, [x24, #0x50]\n" - "shl v19.16b, v1.16b, #0x4\n" - "and v1.16b, v1.16b, v5.16b\n" - ".inst 0x4e93a688 // smmla v8.4s, v20.16b, v19.16b\n" - ".inst 0x4e93a53f // smmla v31.4s, v9.16b, v19.16b\n" - ".inst 0x4e93a640 // smmla v0.4s, v18.16b, v19.16b\n" - ".inst 0x4e93a637 // smmla v23.4s, v17.16b, v19.16b\n" - "shl v19.16b, v4.16b, #0x4\n" - "and v4.16b, v4.16b, v5.16b\n" - ".inst 0x4e93a68b // smmla v11.4s, v20.16b, v19.16b\n" - ".inst 0x4e93a52d // smmla v13.4s, v9.16b, v19.16b\n" - ".inst 0x4e93a65e // smmla v30.4s, v18.16b, v19.16b\n" - ".inst 0x4e93a630 // smmla v16.4s, v17.16b, v19.16b\n" - "shl v19.16b, v25.16b, #0x4\n" - ".inst 0x4e8ca5c8 // smmla v8.4s, v14.16b, v12.16b\n" - ".inst 0x4e8ca45f // smmla v31.4s, v2.16b, v12.16b\n" - "and v25.16b, v25.16b, v5.16b\n" - ".inst 0x4e93a69b // smmla v27.4s, v20.16b, v19.16b\n" - ".inst 0x4e93a53a // smmla v26.4s, v9.16b, v19.16b\n" - ".inst 0x4e93a656 // smmla v22.4s, v18.16b, v19.16b\n" - ".inst 0x4e93a623 // smmla v3.4s, v17.16b, v19.16b\n" - "shl v19.16b, v24.16b, #0x4\n" - ".inst 0x4e8aa5cb // smmla v11.4s, v14.16b, v10.16b\n" - ".inst 0x4e8aa44d // smmla v13.4s, v2.16b, v10.16b\n" - "and v24.16b, v24.16b, v5.16b\n" - ".inst 0x4e93a686 // smmla v6.4s, v20.16b, v19.16b\n" - "ldr q20, [x20, #0x40]\n" - ".inst 0x4e93a52f // smmla v15.4s, v9.16b, v19.16b\n" - "ldr q9, [x20, #0x50]\n" - ".inst 0x4e93a65d // smmla v29.4s, v18.16b, v19.16b\n" - "ldr q18, [x24, #0x60]\n" - ".inst 0x4e93a635 // smmla v21.4s, v17.16b, v19.16b\n" - "ldr q19, [x24, #0x70]\n" - "ldr q17, [x20, #0x60]\n" - ".inst 0x4e87a5db // smmla v27.4s, v14.16b, v7.16b\n" - ".inst 0x4e87a45a // smmla v26.4s, v2.16b, v7.16b\n" - "add x24, x24, #0x80\n" - ".inst 0x4e8ca680 // smmla v0.4s, v20.16b, v12.16b\n" - ".inst 0x4e8aa69e // smmla v30.4s, v20.16b, v10.16b\n" - ".inst 0x4e9ca5c6 // smmla v6.4s, v14.16b, v28.16b\n" - "ldr q14, [x20, #0x70]\n" - ".inst 0x4e9ca44f // smmla v15.4s, v2.16b, v28.16b\n" - "add x20, x20, #0x80\n" - ".inst 0x4e87a696 // smmla v22.4s, v20.16b, v7.16b\n" - ".inst 0x4e9ca69d // smmla v29.4s, v20.16b, v28.16b\n" - ".inst 0x4e8ca537 // smmla v23.4s, v9.16b, v12.16b\n" - ".inst 0x4e8aa530 // smmla v16.4s, v9.16b, v10.16b\n" - ".inst 0x4e87a523 // smmla v3.4s, v9.16b, v7.16b\n" - ".inst 0x4e9ca535 // smmla v21.4s, v9.16b, v28.16b\n" - ".inst 0x4e81a648 // smmla v8.4s, v18.16b, v1.16b\n" - ".inst 0x4e84a64b // smmla v11.4s, v18.16b, v4.16b\n" - ".inst 0x4e99a65b // smmla v27.4s, v18.16b, v25.16b\n" - ".inst 0x4e98a646 // smmla v6.4s, v18.16b, v24.16b\n" - ".inst 0x4e81a67f // smmla v31.4s, v19.16b, v1.16b\n" - ".inst 0x4e84a66d // smmla v13.4s, v19.16b, v4.16b\n" - ".inst 0x4e99a67a // smmla v26.4s, v19.16b, v25.16b\n" - ".inst 0x4e98a66f // smmla v15.4s, v19.16b, v24.16b\n" - ".inst 0x4e81a620 // smmla v0.4s, v17.16b, v1.16b\n" - ".inst 0x4e84a63e // smmla v30.4s, v17.16b, v4.16b\n" - ".inst 0x4e99a636 // smmla v22.4s, v17.16b, v25.16b\n" - ".inst 0x4e98a63d // smmla v29.4s, v17.16b, v24.16b\n" - ".inst 0x4e81a5d7 // smmla v23.4s, v14.16b, v1.16b\n" - ".inst 0x4e84a5d0 // smmla v16.4s, v14.16b, v4.16b\n" - ".inst 0x4e99a5c3 // smmla v3.4s, v14.16b, v25.16b\n" - ".inst 0x4e98a5d5 // smmla v21.4s, v14.16b, v24.16b\n" - "bgt 3b\n" - "ldr q20, [x25, #0x0]\n" - "ldr q12, [x25, #0x10]\n" - "uzp1 v25.2d, v8.2d, v11.2d\n" - "uzp1 v24.2d, v27.2d, v6.2d\n" - "ldr q19, [x24, 
#0x0]\n" - "ldr q7, [x25, #0x20]\n" - "uzp2 v9.2d, v8.2d, v11.2d\n" - "uzp2 v6.2d, v27.2d, v6.2d\n" - "ldr q8, [x25, #0x30]\n" - "ldr q10, [x24, #0x10]\n" - "uzp1 v14.2d, v31.2d, v13.2d\n" - "uzp1 v11.2d, v26.2d, v15.2d\n" - "ldr q4, [x20, #0x0]\n" - "ldr q1, [x20, #0x10]\n" - "uzp2 v27.2d, v31.2d, v13.2d\n" - "uzp2 v13.2d, v26.2d, v15.2d\n" - "ld1r { v2.4s }, [%x[clamp_vals]]\n" - "mla v25.4s, v20.4s, v19.s[0]\n" - "mla v24.4s, v12.4s, v19.s[0]\n" - "uzp1 v31.2d, v0.2d, v30.2d\n" - "mla v9.4s, v20.4s, v19.s[1]\n" - "mla v6.4s, v12.4s, v19.s[1]\n" - "uzp1 v15.2d, v22.2d, v29.2d\n" - "add x20, %x[clamp_vals], #0x4\n" - "ld1r { v28.4s }, [x20]\n" - "mla v14.4s, v20.4s, v19.s[2]\n" - "mla v11.4s, v12.4s, v19.s[2]\n" - "uzp2 v0.2d, v0.2d, v30.2d\n" - "uzp2 v29.2d, v22.2d, v29.2d\n" - "mla v27.4s, v20.4s, v19.s[3]\n" - "mla v13.4s, v12.4s, v19.s[3]\n" - "mov x20, %x[dst]\n" - "uzp1 v30.2d, v23.2d, v16.2d\n" - "uzp1 v26.2d, v3.2d, v21.2d\n" - "mla v31.4s, v20.4s, v4.s[0]\n" - "subs x23, x23, #0x8\n" - "scvtf v25.4s, v25.4s\n" - "fmul v19.4s, v7.4s, v10.s[0]\n" - "mla v15.4s, v12.4s, v4.s[0]\n" - "add x25, x25, #0x40\n" - "scvtf v24.4s, v24.4s\n" - "fmul v18.4s, v8.4s, v10.s[0]\n" - "mla v0.4s, v20.4s, v4.s[1]\n" - "add %x[dst], %x[dst], #0x20\n" - "uzp2 v23.2d, v23.2d, v16.2d\n" - "uzp2 v22.2d, v3.2d, v21.2d\n" - "mla v29.4s, v12.4s, v4.s[1]\n" - "scvtf v9.4s, v9.4s\n" - "fmul v17.4s, v7.4s, v10.s[1]\n" - "mla v30.4s, v20.4s, v4.s[2]\n" - "scvtf v6.4s, v6.4s\n" - "fmul v16.4s, v8.4s, v10.s[1]\n" - "mla v26.4s, v12.4s, v4.s[2]\n" - "scvtf v14.4s, v14.4s\n" - "fmul v21.4s, v7.4s, v10.s[2]\n" - "mla v23.4s, v20.4s, v4.s[3]\n" - "scvtf v11.4s, v11.4s\n" - "fmul v20.4s, v8.4s, v10.s[2]\n" - "mla v22.4s, v12.4s, v4.s[3]\n" - "fmul v25.4s, v25.4s, v19.4s\n" - "fmul v24.4s, v24.4s, v18.4s\n" - "scvtf v27.4s, v27.4s\n" - "fmul v19.4s, v7.4s, v10.s[3]\n" - "scvtf v13.4s, v13.4s\n" - "fmul v18.4s, v8.4s, v10.s[3]\n" - "fmul v9.4s, v9.4s, v17.4s\n" - "fmul v6.4s, v6.4s, v16.4s\n" - "scvtf v31.4s, v31.4s\n" - "fmul v17.4s, v7.4s, v1.s[0]\n" - "scvtf v15.4s, v15.4s\n" - "fmul v16.4s, v8.4s, v1.s[0]\n" - "fmul v14.4s, v14.4s, v21.4s\n" - "fmul v11.4s, v11.4s, v20.4s\n" - "scvtf v0.4s, v0.4s\n" - "fmul v21.4s, v7.4s, v1.s[1]\n" - "scvtf v29.4s, v29.4s\n" - "fmul v20.4s, v8.4s, v1.s[1]\n" - "fmul v27.4s, v27.4s, v19.4s\n" - "fmul v13.4s, v13.4s, v18.4s\n" - "scvtf v30.4s, v30.4s\n" - "fmul v19.4s, v7.4s, v1.s[2]\n" - "scvtf v26.4s, v26.4s\n" - "fmul v18.4s, v8.4s, v1.s[2]\n" - "fmax v25.4s, v25.4s, v2.4s\n" - "fmax v24.4s, v24.4s, v2.4s\n" - "fmul v31.4s, v31.4s, v17.4s\n" - "fmul v15.4s, v15.4s, v16.4s\n" - "scvtf v23.4s, v23.4s\n" - "fmul v17.4s, v7.4s, v1.s[3]\n" - "scvtf v22.4s, v22.4s\n" - "fmul v16.4s, v8.4s, v1.s[3]\n" - "fmax v9.4s, v9.4s, v2.4s\n" - "fmax v6.4s, v6.4s, v2.4s\n" - "fmul v0.4s, v0.4s, v21.4s\n" - "fmul v29.4s, v29.4s, v20.4s\n" - "fmax v14.4s, v14.4s, v2.4s\n" - "fmax v11.4s, v11.4s, v2.4s\n" - "fmul v30.4s, v30.4s, v19.4s\n" - "fmul v26.4s, v26.4s, v18.4s\n" - "fmin v25.4s, v25.4s, v28.4s\n" - "fmin v24.4s, v24.4s, v28.4s\n" - "fmax v27.4s, v27.4s, v2.4s\n" - "fmax v13.4s, v13.4s, v2.4s\n" - "fmul v23.4s, v23.4s, v17.4s\n" - "fmul v22.4s, v22.4s, v16.4s\n" - "str q25, [x20, #0x0]\n" - "str q24, [x20, #0x10]\n" - "add x20, x20, %x[dst_stride_row]\n" - "fmin v9.4s, v9.4s, v28.4s\n" - "fmin v6.4s, v6.4s, v28.4s\n" - "fmax v31.4s, v31.4s, v2.4s\n" - "fmax v15.4s, v15.4s, v2.4s\n" - "fmin v14.4s, v14.4s, v28.4s\n" - "fmin v11.4s, v11.4s, v28.4s\n" - "str q9, [x20, #0x0]\n" - "str q6, [x20, 
#0x10]\n" - "add x20, x20, %x[dst_stride_row]\n" - "fmax v0.4s, v0.4s, v2.4s\n" - "fmax v29.4s, v29.4s, v2.4s\n" - "fmin v27.4s, v27.4s, v28.4s\n" - "fmin v13.4s, v13.4s, v28.4s\n" - "str q14, [x20, #0x0]\n" - "str q11, [x20, #0x10]\n" - "add x20, x20, %x[dst_stride_row]\n" - "fmax v30.4s, v30.4s, v2.4s\n" - "fmax v26.4s, v26.4s, v2.4s\n" - "fmin v31.4s, v31.4s, v28.4s\n" - "fmin v15.4s, v15.4s, v28.4s\n" - "str q27, [x20, #0x0]\n" - "str q13, [x20, #0x10]\n" - "add x20, x20, %x[dst_stride_row]\n" - "fmax v23.4s, v23.4s, v2.4s\n" - "fmax v22.4s, v22.4s, v2.4s\n" - "fmin v0.4s, v0.4s, v28.4s\n" - "fmin v29.4s, v29.4s, v28.4s\n" - "str q31, [x20, #0x0]\n" - "str q15, [x20, #0x10]\n" - "add x20, x20, %x[dst_stride_row]\n" - "fmin v30.4s, v30.4s, v28.4s\n" - "fmin v26.4s, v26.4s, v28.4s\n" - "fmin v23.4s, v23.4s, v28.4s\n" - "fmin v22.4s, v22.4s, v28.4s\n" - "str q0, [x20, #0x0]\n" - "str q29, [x20, #0x10]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q30, [x20, #0x0]\n" - "str q26, [x20, #0x10]\n" - "add x20, x20, %x[dst_stride_row]\n" - "str q23, [x20, #0x0]\n" - "str q22, [x20, #0x10]\n" - "bne 2b\n" - "mov x20, #0x2\n" - "sub x27, x27, #0x8\n" - "cmp x27, #0x8\n" - "mov %x[dst], x22\n" - "madd %x[lhs_packed], x20, x26, %x[lhs_packed]\n" - "bge 1b\n" - "4:" // Row loop skip - "cbz x27, 9f\n" - "5:" // Row tail: Row loop - "mov x23, %x[rhs_packed]\n" - "mov x22, %x[n]\n" - "add x21, %x[dst], %x[dst_stride_row], LSL #2\n" - "6:" // Row tail: Column loop - "movi v8.4s, #0x0\n" - "movi v27.4s, #0x0\n" - "mov x24, %x[lhs_packed]\n" - "mov x20, %x[num_blocks]\n" - "movi v11.4s, #0x0\n" - "movi v6.4s, #0x0\n" - "movi v31.4s, #0x0\n" - "movi v26.4s, #0x0\n" - "movi v13.4s, #0x0\n" - "movi v15.4s, #0x0\n" - "7:" // Row tail: Block loop - "ldr q4, [x23, #0x0]\n" - "ldr q10, [x23, #0x10]\n" - "subs x20, x20, #0x1\n" - "ldr q2, [x23, #0x20]\n" - "ldr q1, [x23, #0x30]\n" - "ldr q0, [x24, #0x0]\n" - "ldr q12, [x24, #0x10]\n" - "ldr q30, [x23, #0x40]\n" - "ldr q29, [x23, #0x50]\n" - "shl v19.16b, v4.16b, #0x4\n" - "shl v18.16b, v10.16b, #0x4\n" - "ldr q3, [x23, #0x60]\n" - "ldr q28, [x23, #0x70]\n" - "shl v17.16b, v2.16b, #0x4\n" - "shl v16.16b, v1.16b, #0x4\n" - "ldr q25, [x24, #0x20]\n" - "ldr q24, [x24, #0x30]\n" - "and v4.16b, v4.16b, v5.16b\n" - "and v10.16b, v10.16b, v5.16b\n" - "ldr q23, [x24, #0x40]\n" - "ldr q22, [x24, #0x50]\n" - ".inst 0x4e93a408 // smmla v8.4s, v0.16b, v19.16b\n" - ".inst 0x4e92a40b // smmla v11.4s, v0.16b, v18.16b\n" - "ldr q21, [x24, #0x60]\n" - "ldr q20, [x24, #0x70]\n" - ".inst 0x4e91a41b // smmla v27.4s, v0.16b, v17.16b\n" - ".inst 0x4e90a406 // smmla v6.4s, v0.16b, v16.16b\n" - ".inst 0x4e93a59f // smmla v31.4s, v12.16b, v19.16b\n" - ".inst 0x4e92a58d // smmla v13.4s, v12.16b, v18.16b\n" - "shl v19.16b, v30.16b, #0x4\n" - "add x24, x24, #0x80\n" - ".inst 0x4e91a59a // smmla v26.4s, v12.16b, v17.16b\n" - ".inst 0x4e90a58f // smmla v15.4s, v12.16b, v16.16b\n" - "shl v18.16b, v29.16b, #0x4\n" - "add x23, x23, #0x80\n" - "shl v17.16b, v3.16b, #0x4\n" - "shl v16.16b, v28.16b, #0x4\n" - ".inst 0x4e93a728 // smmla v8.4s, v25.16b, v19.16b\n" - "and v2.16b, v2.16b, v5.16b\n" - "and v1.16b, v1.16b, v5.16b\n" - ".inst 0x4e92a72b // smmla v11.4s, v25.16b, v18.16b\n" - ".inst 0x4e93a71f // smmla v31.4s, v24.16b, v19.16b\n" - ".inst 0x4e92a70d // smmla v13.4s, v24.16b, v18.16b\n" - "and v30.16b, v30.16b, v5.16b\n" - ".inst 0x4e91a73b // smmla v27.4s, v25.16b, v17.16b\n" - ".inst 0x4e90a726 // smmla v6.4s, v25.16b, v16.16b\n" - "and v29.16b, v29.16b, v5.16b\n" - ".inst 0x4e91a71a // 
smmla v26.4s, v24.16b, v17.16b\n" - ".inst 0x4e90a70f // smmla v15.4s, v24.16b, v16.16b\n" - "and v3.16b, v3.16b, v5.16b\n" - ".inst 0x4e84a6e8 // smmla v8.4s, v23.16b, v4.16b\n" - ".inst 0x4e8aa6eb // smmla v11.4s, v23.16b, v10.16b\n" - "and v28.16b, v28.16b, v5.16b\n" - ".inst 0x4e84a6df // smmla v31.4s, v22.16b, v4.16b\n" - ".inst 0x4e8aa6cd // smmla v13.4s, v22.16b, v10.16b\n" - ".inst 0x4e82a6fb // smmla v27.4s, v23.16b, v2.16b\n" - ".inst 0x4e81a6e6 // smmla v6.4s, v23.16b, v1.16b\n" - ".inst 0x4e82a6da // smmla v26.4s, v22.16b, v2.16b\n" - ".inst 0x4e81a6cf // smmla v15.4s, v22.16b, v1.16b\n" - ".inst 0x4e9ea6a8 // smmla v8.4s, v21.16b, v30.16b\n" - ".inst 0x4e9da6ab // smmla v11.4s, v21.16b, v29.16b\n" - ".inst 0x4e9ea69f // smmla v31.4s, v20.16b, v30.16b\n" - ".inst 0x4e9da68d // smmla v13.4s, v20.16b, v29.16b\n" - ".inst 0x4e83a6bb // smmla v27.4s, v21.16b, v3.16b\n" - ".inst 0x4e9ca6a6 // smmla v6.4s, v21.16b, v28.16b\n" - ".inst 0x4e83a69a // smmla v26.4s, v20.16b, v3.16b\n" - ".inst 0x4e9ca68f // smmla v15.4s, v20.16b, v28.16b\n" - "bgt 7b\n" - "ldr q21, [x23, #0x0]\n" - "ldr q19, [x23, #0x10]\n" - "uzp1 v2.2d, v8.2d, v11.2d\n" - "uzp1 v1.2d, v27.2d, v6.2d\n" - "ldr q18, [x24, #0x0]\n" - "ldr q17, [x23, #0x20]\n" - "uzp2 v0.2d, v8.2d, v11.2d\n" - "uzp2 v12.2d, v27.2d, v6.2d\n" - "ldr q30, [x23, #0x30]\n" - "ldr q16, [x24, #0x10]\n" - "uzp1 v29.2d, v31.2d, v13.2d\n" - "uzp1 v28.2d, v26.2d, v15.2d\n" - "ld1r { v27.4s }, [%x[clamp_vals]]\n" - "uzp2 v20.2d, v31.2d, v13.2d\n" - "uzp2 v25.2d, v26.2d, v15.2d\n" - "add x20, %x[clamp_vals], #0x4\n" - "ld1r { v24.4s }, [x20]\n" - "mla v2.4s, v21.4s, v18.s[0]\n" - "mla v1.4s, v19.4s, v18.s[0]\n" - "mov x20, %x[dst]\n" - "mla v0.4s, v21.4s, v18.s[1]\n" - "mla v12.4s, v19.4s, v18.s[1]\n" - "fmul v23.4s, v17.4s, v16.s[0]\n" - "cmp x27, #0x1\n" - "mla v29.4s, v21.4s, v18.s[2]\n" - "mla v28.4s, v19.4s, v18.s[2]\n" - "fmul v22.4s, v30.4s, v16.s[0]\n" - "add x23, x23, #0x40\n" - "mla v20.4s, v21.4s, v18.s[3]\n" - "mla v25.4s, v19.4s, v18.s[3]\n" - "fmul v21.4s, v17.4s, v16.s[1]\n" - "scvtf v2.4s, v2.4s\n" - "scvtf v1.4s, v1.4s\n" - "scvtf v0.4s, v0.4s\n" - "scvtf v12.4s, v12.4s\n" - "fmul v8.4s, v30.4s, v16.s[1]\n" - "scvtf v29.4s, v29.4s\n" - "fmul v19.4s, v17.4s, v16.s[2]\n" - "scvtf v28.4s, v28.4s\n" - "fmul v18.4s, v30.4s, v16.s[2]\n" - "scvtf v20.4s, v20.4s\n" - "fmul v17.4s, v17.4s, v16.s[3]\n" - "scvtf v25.4s, v25.4s\n" - "fmul v16.4s, v30.4s, v16.s[3]\n" - "fmul v2.4s, v2.4s, v23.4s\n" - "fmul v1.4s, v1.4s, v22.4s\n" - "fmul v0.4s, v0.4s, v21.4s\n" - "fmul v12.4s, v12.4s, v8.4s\n" - "fmul v29.4s, v29.4s, v19.4s\n" - "fmul v28.4s, v28.4s, v18.4s\n" - "fmul v20.4s, v20.4s, v17.4s\n" - "fmul v25.4s, v25.4s, v16.4s\n" - "fmax v2.4s, v2.4s, v27.4s\n" - "fmax v1.4s, v1.4s, v27.4s\n" - "fmax v0.4s, v0.4s, v27.4s\n" - "fmax v12.4s, v12.4s, v27.4s\n" - "fmax v29.4s, v29.4s, v27.4s\n" - "fmax v28.4s, v28.4s, v27.4s\n" - "fmax v20.4s, v20.4s, v27.4s\n" - "fmax v25.4s, v25.4s, v27.4s\n" - "fmin v2.4s, v2.4s, v24.4s\n" - "fmin v1.4s, v1.4s, v24.4s\n" - "fmin v0.4s, v0.4s, v24.4s\n" - "fmin v12.4s, v12.4s, v24.4s\n" - "fmin v29.4s, v29.4s, v24.4s\n" - "fmin v28.4s, v28.4s, v24.4s\n" - "fmin v20.4s, v20.4s, v24.4s\n" - "str q2, [x20, #0x0]\n" - "fmin v25.4s, v25.4s, v24.4s\n" - "str q1, [x20, #0x10]\n" - "add x20, x20, %x[dst_stride_row]\n" - "ble 8f\n" - "cmp x27, #0x2\n" - "str q0, [x20, #0x0]\n" - "str q12, [x20, #0x10]\n" - "add x20, x20, %x[dst_stride_row]\n" - "ble 8f\n" - "cmp x27, #0x3\n" - "str q29, [x20, #0x0]\n" - "str q28, [x20, #0x10]\n" 
- "add x20, x20, %x[dst_stride_row]\n" - "ble 8f\n" - "str q20, [x20, #0x0]\n" - "str q25, [x20, #0x10]\n" - "8:" // Row tail: Accumulator store skip - "subs x22, x22, #0x8\n" - "add %x[dst], %x[dst], #0x20\n" - "bne 6b\n" - "subs x27, x27, #0x4\n" - "add %x[lhs_packed], %x[lhs_packed], x26\n" - "mov %x[dst], x21\n" - "bgt 5b\n" - "9:" // Row tail: Row loop skip - : [lhs_packed] "+&r"(lhs_packed), [dst] "+&r"(dst) - : [rhs_packed] "r"(rhs_packed), [clamp_vals] "r"(clamp_vals), [m] "r"(m), - [num_blocks] "r"(num_blocks), [dst_stride_row] "r"(dst_stride_row), [n] "r"(n) - : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", - "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", - "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23", "x24", "x25", - "x26", "x27"); + __asm__ __volatile__( + "mov x12, %x[m]\n" + "mov x11, #0x80\n" + "movi v3.16b, #0xf0\n" + "mov x20, #0x20\n" + "cmp x12, #0x8\n" + "madd x11, %x[num_blocks], x11, x20\n" + "blt 10f\n" + "1:" // Row loop + "mov x10, %x[rhs_packed]\n" + "mov x9, %x[n]\n" + "add x28, %x[dst], %x[dst_stride_row], LSL #3\n" + "2:" // Column loop + "mov x22, %x[lhs_packed]\n" + "movi v13.4s, #0x0\n" + "movi v14.4s, #0x0\n" + "mov x21, %x[num_blocks]\n" + "movi v25.4s, #0x0\n" + "movi v16.4s, #0x0\n" + "movi v26.4s, #0x0\n" + "movi v30.4s, #0x0\n" + "movi v10.4s, #0x0\n" + "movi v19.4s, #0x0\n" + "add x20, x22, x11\n" + "movi v24.4s, #0x0\n" + "movi v0.4s, #0x0\n" + "movi v28.4s, #0x0\n" + "movi v15.4s, #0x0\n" + "movi v29.4s, #0x0\n" + "movi v11.4s, #0x0\n" + "movi v31.4s, #0x0\n" + "movi v7.4s, #0x0\n" + "3:" // Block loop + "ldr q21, [x10, #0x0]\n" + "ldr q20, [x10, #0x10]\n" + "subs x21, x21, #0x1\n" + "ldr q2, [x10, #0x20]\n" + "ldr q23, [x10, #0x30]\n" + "ldr q8, [x22, #0x0]\n" + "ldr q1, [x22, #0x10]\n" + "ldr q12, [x20, #0x0]\n" + "ldr q6, [x20, #0x10]\n" + "shl v17.16b, v21.16b, #0x4\n" + "shl v22.16b, v20.16b, #0x4\n" + "ldr q9, [x10, #0x40]\n" + "ldr q18, [x10, #0x50]\n" + "shl v4.16b, v2.16b, #0x4\n" + "shl v5.16b, v23.16b, #0x4\n" + "ldr q27, [x10, #0x60]\n" + "and v21.16b, v21.16b, v3.16b\n" + "and v20.16b, v20.16b, v3.16b\n" + ".inst 0x4e91a50d // smmla v13.4s, v8.16b, v17.16b\n" + ".inst 0x4e96a519 // smmla v25.4s, v8.16b, v22.16b\n" + ".inst 0x4e91a43a // smmla v26.4s, v1.16b, v17.16b\n" + "and v2.16b, v2.16b, v3.16b\n" + ".inst 0x4e84a50e // smmla v14.4s, v8.16b, v4.16b\n" + ".inst 0x4e85a510 // smmla v16.4s, v8.16b, v5.16b\n" + "ldr q8, [x10, #0x70]\n" + "and v23.16b, v23.16b, v3.16b\n" + ".inst 0x4e96a42a // smmla v10.4s, v1.16b, v22.16b\n" + ".inst 0x4e84a43e // smmla v30.4s, v1.16b, v4.16b\n" + "add x10, x10, #0x80\n" + ".inst 0x4e85a433 // smmla v19.4s, v1.16b, v5.16b\n" + "ldr q1, [x22, #0x20]\n" + ".inst 0x4e91a598 // smmla v24.4s, v12.16b, v17.16b\n" + ".inst 0x4e96a59c // smmla v28.4s, v12.16b, v22.16b\n" + ".inst 0x4e84a580 // smmla v0.4s, v12.16b, v4.16b\n" + ".inst 0x4e85a58f // smmla v15.4s, v12.16b, v5.16b\n" + "ldr q12, [x22, #0x30]\n" + ".inst 0x4e91a4dd // smmla v29.4s, v6.16b, v17.16b\n" + "ldr q17, [x20, #0x20]\n" + ".inst 0x4e96a4df // smmla v31.4s, v6.16b, v22.16b\n" + "ldr q22, [x20, #0x30]\n" + ".inst 0x4e84a4cb // smmla v11.4s, v6.16b, v4.16b\n" + "ldr q4, [x22, #0x40]\n" + ".inst 0x4e85a4c7 // smmla v7.4s, v6.16b, v5.16b\n" + "ldr q5, [x22, #0x50]\n" + "shl v6.16b, v9.16b, #0x4\n" + "and v9.16b, v9.16b, v3.16b\n" + ".inst 0x4e86a42d // smmla v13.4s, v1.16b, v6.16b\n" + ".inst 0x4e86a59a // smmla v26.4s, v12.16b, 
v6.16b\n" + ".inst 0x4e86a638 // smmla v24.4s, v17.16b, v6.16b\n" + ".inst 0x4e86a6dd // smmla v29.4s, v22.16b, v6.16b\n" + "shl v6.16b, v18.16b, #0x4\n" + "and v18.16b, v18.16b, v3.16b\n" + ".inst 0x4e86a439 // smmla v25.4s, v1.16b, v6.16b\n" + ".inst 0x4e86a58a // smmla v10.4s, v12.16b, v6.16b\n" + ".inst 0x4e86a63c // smmla v28.4s, v17.16b, v6.16b\n" + ".inst 0x4e86a6df // smmla v31.4s, v22.16b, v6.16b\n" + "shl v6.16b, v27.16b, #0x4\n" + ".inst 0x4e95a48d // smmla v13.4s, v4.16b, v21.16b\n" + ".inst 0x4e95a4ba // smmla v26.4s, v5.16b, v21.16b\n" + "and v27.16b, v27.16b, v3.16b\n" + ".inst 0x4e86a42e // smmla v14.4s, v1.16b, v6.16b\n" + ".inst 0x4e86a59e // smmla v30.4s, v12.16b, v6.16b\n" + ".inst 0x4e86a620 // smmla v0.4s, v17.16b, v6.16b\n" + ".inst 0x4e86a6cb // smmla v11.4s, v22.16b, v6.16b\n" + "shl v6.16b, v8.16b, #0x4\n" + ".inst 0x4e94a499 // smmla v25.4s, v4.16b, v20.16b\n" + ".inst 0x4e94a4aa // smmla v10.4s, v5.16b, v20.16b\n" + "and v8.16b, v8.16b, v3.16b\n" + ".inst 0x4e86a430 // smmla v16.4s, v1.16b, v6.16b\n" + "ldr q1, [x20, #0x40]\n" + ".inst 0x4e86a593 // smmla v19.4s, v12.16b, v6.16b\n" + "ldr q12, [x20, #0x50]\n" + ".inst 0x4e86a62f // smmla v15.4s, v17.16b, v6.16b\n" + "ldr q17, [x22, #0x60]\n" + ".inst 0x4e86a6c7 // smmla v7.4s, v22.16b, v6.16b\n" + "ldr q22, [x22, #0x70]\n" + "ldr q6, [x20, #0x60]\n" + ".inst 0x4e82a48e // smmla v14.4s, v4.16b, v2.16b\n" + ".inst 0x4e82a4be // smmla v30.4s, v5.16b, v2.16b\n" + "add x22, x22, #0x80\n" + ".inst 0x4e95a438 // smmla v24.4s, v1.16b, v21.16b\n" + ".inst 0x4e94a43c // smmla v28.4s, v1.16b, v20.16b\n" + ".inst 0x4e97a490 // smmla v16.4s, v4.16b, v23.16b\n" + "ldr q4, [x20, #0x70]\n" + ".inst 0x4e97a4b3 // smmla v19.4s, v5.16b, v23.16b\n" + "add x20, x20, #0x80\n" + ".inst 0x4e82a420 // smmla v0.4s, v1.16b, v2.16b\n" + ".inst 0x4e97a42f // smmla v15.4s, v1.16b, v23.16b\n" + ".inst 0x4e95a59d // smmla v29.4s, v12.16b, v21.16b\n" + ".inst 0x4e94a59f // smmla v31.4s, v12.16b, v20.16b\n" + ".inst 0x4e82a58b // smmla v11.4s, v12.16b, v2.16b\n" + ".inst 0x4e97a587 // smmla v7.4s, v12.16b, v23.16b\n" + ".inst 0x4e89a62d // smmla v13.4s, v17.16b, v9.16b\n" + ".inst 0x4e92a639 // smmla v25.4s, v17.16b, v18.16b\n" + ".inst 0x4e9ba62e // smmla v14.4s, v17.16b, v27.16b\n" + ".inst 0x4e88a630 // smmla v16.4s, v17.16b, v8.16b\n" + ".inst 0x4e89a6da // smmla v26.4s, v22.16b, v9.16b\n" + ".inst 0x4e92a6ca // smmla v10.4s, v22.16b, v18.16b\n" + ".inst 0x4e9ba6de // smmla v30.4s, v22.16b, v27.16b\n" + ".inst 0x4e88a6d3 // smmla v19.4s, v22.16b, v8.16b\n" + ".inst 0x4e89a4d8 // smmla v24.4s, v6.16b, v9.16b\n" + ".inst 0x4e92a4dc // smmla v28.4s, v6.16b, v18.16b\n" + ".inst 0x4e9ba4c0 // smmla v0.4s, v6.16b, v27.16b\n" + ".inst 0x4e88a4cf // smmla v15.4s, v6.16b, v8.16b\n" + ".inst 0x4e89a49d // smmla v29.4s, v4.16b, v9.16b\n" + ".inst 0x4e92a49f // smmla v31.4s, v4.16b, v18.16b\n" + ".inst 0x4e9ba48b // smmla v11.4s, v4.16b, v27.16b\n" + ".inst 0x4e88a487 // smmla v7.4s, v4.16b, v8.16b\n" + "bgt 3b\n" + "ldr q18, [x10, #0x0]\n" + "ldr q2, [x10, #0x10]\n" + "uzp1 v4.2d, v13.2d, v25.2d\n" + "uzp1 v5.2d, v14.2d, v16.2d\n" + "ldr q22, [x22, #0x0]\n" + "ldr q27, [x20, #0x0]\n" + "uzp2 v1.2d, v13.2d, v25.2d\n" + "uzp2 v20.2d, v14.2d, v16.2d\n" + "ldr q17, [x10, #0x20]\n" + "ldr q6, [x10, #0x30]\n" + "uzp1 v9.2d, v26.2d, v10.2d\n" + "uzp1 v13.2d, v30.2d, v19.2d\n" + "ldr q23, [x22, #0x10]\n" + "ldr q12, [x20, #0x10]\n" + "uzp2 v21.2d, v26.2d, v10.2d\n" + "uzp2 v25.2d, v30.2d, v19.2d\n" + "ld1r { v8.4s }, [%x[clamp_vals]]\n" + "uzp1 v16.2d, v24.2d, 
v28.2d\n" + "uzp1 v10.2d, v0.2d, v15.2d\n" + "mla v4.4s, v18.4s, v22.s[0]\n" + "uzp2 v30.2d, v24.2d, v28.2d\n" + "uzp2 v28.2d, v0.2d, v15.2d\n" + "mla v5.4s, v2.4s, v22.s[0]\n" + "add x20, %x[clamp_vals], #0x4\n" + "ld1r { v24.4s }, [x20]\n" + "uzp1 v14.2d, v29.2d, v31.2d\n" + "uzp1 v26.2d, v11.2d, v7.2d\n" + "mla v1.4s, v18.4s, v22.s[1]\n" + "uzp2 v0.2d, v29.2d, v31.2d\n" + "uzp2 v11.2d, v11.2d, v7.2d\n" + "mla v20.4s, v2.4s, v22.s[1]\n" + "cmp x9, #0x8\n" + "mla v9.4s, v18.4s, v22.s[2]\n" + "mla v13.4s, v2.4s, v22.s[2]\n" + "scvtf v4.4s, v4.4s\n" + "add x10, x10, #0x40\n" + "mla v21.4s, v18.4s, v22.s[3]\n" + "mla v25.4s, v2.4s, v22.s[3]\n" + "fmul v19.4s, v17.4s, v23.s[0]\n" + "mla v16.4s, v18.4s, v27.s[0]\n" + "mla v10.4s, v2.4s, v27.s[0]\n" + "scvtf v5.4s, v5.4s\n" + "mla v30.4s, v18.4s, v27.s[1]\n" + "mla v28.4s, v2.4s, v27.s[1]\n" + "fmul v15.4s, v6.4s, v23.s[0]\n" + "mla v14.4s, v18.4s, v27.s[2]\n" + "mla v26.4s, v2.4s, v27.s[2]\n" + "scvtf v1.4s, v1.4s\n" + "mla v0.4s, v18.4s, v27.s[3]\n" + "mla v11.4s, v2.4s, v27.s[3]\n" + "fmul v22.4s, v17.4s, v23.s[1]\n" + "scvtf v20.4s, v20.4s\n" + "fmul v29.4s, v6.4s, v23.s[1]\n" + "scvtf v9.4s, v9.4s\n" + "fmul v2.4s, v17.4s, v23.s[2]\n" + "scvtf v13.4s, v13.4s\n" + "fmul v18.4s, v6.4s, v23.s[2]\n" + "scvtf v21.4s, v21.4s\n" + "fmul v31.4s, v17.4s, v23.s[3]\n" + "scvtf v25.4s, v25.4s\n" + "fmul v7.4s, v6.4s, v23.s[3]\n" + "scvtf v16.4s, v16.4s\n" + "fmul v27.4s, v17.4s, v12.s[0]\n" + "scvtf v10.4s, v10.4s\n" + "fmul v23.4s, v6.4s, v12.s[0]\n" + "scvtf v30.4s, v30.4s\n" + "scvtf v28.4s, v28.4s\n" + "scvtf v14.4s, v14.4s\n" + "scvtf v26.4s, v26.4s\n" + "scvtf v0.4s, v0.4s\n" + "scvtf v11.4s, v11.4s\n" + "fmul v4.4s, v4.4s, v19.4s\n" + "fmul v19.4s, v17.4s, v12.s[1]\n" + "fmul v5.4s, v5.4s, v15.4s\n" + "fmul v15.4s, v6.4s, v12.s[1]\n" + "fmul v1.4s, v1.4s, v22.4s\n" + "fmul v22.4s, v17.4s, v12.s[2]\n" + "fmul v17.4s, v17.4s, v12.s[3]\n" + "fmul v20.4s, v20.4s, v29.4s\n" + "fmul v29.4s, v6.4s, v12.s[2]\n" + "fmul v12.4s, v6.4s, v12.s[3]\n" + "fmul v9.4s, v9.4s, v2.4s\n" + "fmul v13.4s, v13.4s, v18.4s\n" + "fmul v21.4s, v21.4s, v31.4s\n" + "fmul v25.4s, v25.4s, v7.4s\n" + "fmul v16.4s, v16.4s, v27.4s\n" + "fmul v10.4s, v10.4s, v23.4s\n" + "fmul v30.4s, v30.4s, v19.4s\n" + "fmul v28.4s, v28.4s, v15.4s\n" + "fmul v14.4s, v14.4s, v22.4s\n" + "fmul v26.4s, v26.4s, v29.4s\n" + "fmul v0.4s, v0.4s, v17.4s\n" + "fmul v11.4s, v11.4s, v12.4s\n" + "fmax v4.4s, v4.4s, v8.4s\n" + "fmax v5.4s, v5.4s, v8.4s\n" + "fmax v1.4s, v1.4s, v8.4s\n" + "fmax v20.4s, v20.4s, v8.4s\n" + "fmax v9.4s, v9.4s, v8.4s\n" + "fmax v13.4s, v13.4s, v8.4s\n" + "fmax v21.4s, v21.4s, v8.4s\n" + "fmax v25.4s, v25.4s, v8.4s\n" + "fmax v16.4s, v16.4s, v8.4s\n" + "fmax v10.4s, v10.4s, v8.4s\n" + "fmax v30.4s, v30.4s, v8.4s\n" + "fmax v28.4s, v28.4s, v8.4s\n" + "fmax v14.4s, v14.4s, v8.4s\n" + "fmax v26.4s, v26.4s, v8.4s\n" + "fmax v0.4s, v0.4s, v8.4s\n" + "fmax v11.4s, v11.4s, v8.4s\n" + "fmin v4.4s, v4.4s, v24.4s\n" + "fmin v5.4s, v5.4s, v24.4s\n" + "fmin v1.4s, v1.4s, v24.4s\n" + "fmin v20.4s, v20.4s, v24.4s\n" + "fmin v9.4s, v9.4s, v24.4s\n" + "fmin v13.4s, v13.4s, v24.4s\n" + "fmin v21.4s, v21.4s, v24.4s\n" + "fmin v25.4s, v25.4s, v24.4s\n" + "fmin v16.4s, v16.4s, v24.4s\n" + "fmin v10.4s, v10.4s, v24.4s\n" + "fmin v30.4s, v30.4s, v24.4s\n" + "fmin v28.4s, v28.4s, v24.4s\n" + "fmin v14.4s, v14.4s, v24.4s\n" + "fmin v26.4s, v26.4s, v24.4s\n" + "fmin v0.4s, v0.4s, v24.4s\n" + "fmin v11.4s, v11.4s, v24.4s\n" + "bge 8f\n" + "mov x27, %x[dst]\n" + "add x26, x27, %x[dst_stride_row], LSL #2\n" 
+ "add x25, x26, %x[dst_stride_row], LSL #1\n" + "add x24, x26, %x[dst_stride_row]\n" + "add x23, x25, %x[dst_stride_row]\n" + "add x22, x27, %x[dst_stride_row], LSL #1\n" + "add x21, x27, %x[dst_stride_row]\n" + "add x20, x22, %x[dst_stride_row]\n" + "tbz x9, #2, 5f\n" + "st1 { v0.4s }, [x23], #0x10\n" + "st1 { v14.4s }, [x25], #0x10\n" + "st1 { v30.4s }, [x24], #0x10\n" + "st1 { v16.4s }, [x26], #0x10\n" + "st1 { v21.4s }, [x20], #0x10\n" + "st1 { v9.4s }, [x22], #0x10\n" + "st1 { v1.4s }, [x21], #0x10\n" + "st1 { v4.4s }, [x27], #0x10\n" + "tbz x9, #1, 4f\n" + "str d11, [x23], #0x8\n" + "str d26, [x25], #0x8\n" + "str d28, [x24], #0x8\n" + "str d10, [x26], #0x8\n" + "str d25, [x20], #0x8\n" + "str d13, [x22], #0x8\n" + "str d20, [x21], #0x8\n" + "str d5, [x27], #0x8\n" + "tbz x9, #0, 7f\n" + "st1 { v11.s }[2], [x23]\n" + "st1 { v26.s }[2], [x25]\n" + "st1 { v28.s }[2], [x24]\n" + "st1 { v10.s }[2], [x26]\n" + "st1 { v25.s }[2], [x20]\n" + "st1 { v13.s }[2], [x22]\n" + "st1 { v20.s }[2], [x21]\n" + "st1 { v5.s }[2], [x27]\n" + "b 7f\n" + "4:" // Output block 0: partial_1_4 + "tbz x9, #0, 7f\n" + "str s11, [x23, #0x0]\n" + "str s26, [x25, #0x0]\n" + "str s28, [x24, #0x0]\n" + "str s10, [x26, #0x0]\n" + "str s25, [x20, #0x0]\n" + "str s13, [x22, #0x0]\n" + "str s20, [x21, #0x0]\n" + "str s5, [x27, #0x0]\n" + "b 7f\n" + "5:" // Output block 0: partial_2_0 + "tbz x9, #1, 6f\n" + "str d0, [x23], #0x8\n" + "str d14, [x25], #0x8\n" + "str d30, [x24], #0x8\n" + "str d16, [x26], #0x8\n" + "str d21, [x20], #0x8\n" + "str d9, [x22], #0x8\n" + "str d1, [x21], #0x8\n" + "str d4, [x27], #0x8\n" + "tbz x9, #0, 7f\n" + "st1 { v0.s }[2], [x23]\n" + "st1 { v14.s }[2], [x25]\n" + "st1 { v30.s }[2], [x24]\n" + "st1 { v16.s }[2], [x26]\n" + "st1 { v21.s }[2], [x20]\n" + "st1 { v9.s }[2], [x22]\n" + "st1 { v1.s }[2], [x21]\n" + "st1 { v4.s }[2], [x27]\n" + "b 7f\n" + "6:" // Output block 0: partial_1_0 + "str s0, [x23, #0x0]\n" + "str s14, [x25, #0x0]\n" + "str s30, [x24, #0x0]\n" + "str s16, [x26, #0x0]\n" + "str s21, [x20, #0x0]\n" + "str s9, [x22, #0x0]\n" + "str s1, [x21, #0x0]\n" + "str s4, [x27, #0x0]\n" + "7:" // Output block 0: Done + "b 9f\n" + "8:" // Full output + "mov x20, %x[dst]\n" + "str q4, [x20, #0x0]\n" + "str q5, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q1, [x20, #0x0]\n" + "str q20, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q9, [x20, #0x0]\n" + "str q13, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q21, [x20, #0x0]\n" + "str q25, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q16, [x20, #0x0]\n" + "str q10, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q30, [x20, #0x0]\n" + "str q28, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q14, [x20, #0x0]\n" + "str q26, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q0, [x20, #0x0]\n" + "str q11, [x20, #0x10]\n" + "9:" // Output stage exit + "subs x9, x9, #0x8\n" + "add %x[dst], %x[dst], #0x20\n" + "bgt 2b\n" + "mov x20, #0x2\n" + "sub x12, x12, #0x8\n" + "cmp x12, #0x8\n" + "mov %x[dst], x28\n" + "madd %x[lhs_packed], x20, x11, %x[lhs_packed]\n" + "bge 1b\n" + "10:" // Row loop skip + "cbz x12, 20f\n" + "11:" // Row tail: Row loop + "mov x26, %x[rhs_packed]\n" + "mov x25, %x[n]\n" + "add x24, %x[dst], %x[dst_stride_row], LSL #2\n" + "12:" // Row tail: Column loop + "movi v13.4s, #0x0\n" + "movi v14.4s, #0x0\n" + "mov x22, %x[lhs_packed]\n" + "mov x20, %x[num_blocks]\n" + "movi v25.4s, #0x0\n" + "movi v16.4s, #0x0\n" + "movi 
v26.4s, #0x0\n" + "movi v30.4s, #0x0\n" + "movi v10.4s, #0x0\n" + "movi v19.4s, #0x0\n" + "13:" // Row tail: Block loop + "ldr q4, [x26, #0x0]\n" + "ldr q8, [x26, #0x10]\n" + "subs x20, x20, #0x1\n" + "ldr q2, [x26, #0x20]\n" + "ldr q11, [x26, #0x30]\n" + "ldr q18, [x22, #0x0]\n" + "ldr q15, [x22, #0x10]\n" + "ldr q12, [x26, #0x40]\n" + "ldr q6, [x26, #0x50]\n" + "shl v9.16b, v4.16b, #0x4\n" + "shl v22.16b, v8.16b, #0x4\n" + "ldr q28, [x26, #0x60]\n" + "ldr q27, [x26, #0x70]\n" + "shl v17.16b, v2.16b, #0x4\n" + "shl v23.16b, v11.16b, #0x4\n" + "ldr q31, [x22, #0x20]\n" + "ldr q7, [x22, #0x30]\n" + "and v4.16b, v4.16b, v3.16b\n" + "and v8.16b, v8.16b, v3.16b\n" + "ldr q24, [x22, #0x40]\n" + "ldr q1, [x22, #0x50]\n" + ".inst 0x4e89a64d // smmla v13.4s, v18.16b, v9.16b\n" + ".inst 0x4e96a659 // smmla v25.4s, v18.16b, v22.16b\n" + "ldr q21, [x22, #0x60]\n" + "ldr q20, [x22, #0x70]\n" + ".inst 0x4e91a64e // smmla v14.4s, v18.16b, v17.16b\n" + ".inst 0x4e97a650 // smmla v16.4s, v18.16b, v23.16b\n" + ".inst 0x4e89a5fa // smmla v26.4s, v15.16b, v9.16b\n" + ".inst 0x4e96a5ea // smmla v10.4s, v15.16b, v22.16b\n" + "shl v22.16b, v12.16b, #0x4\n" + "add x22, x22, #0x80\n" + ".inst 0x4e91a5fe // smmla v30.4s, v15.16b, v17.16b\n" + ".inst 0x4e97a5f3 // smmla v19.4s, v15.16b, v23.16b\n" + "shl v17.16b, v6.16b, #0x4\n" + "add x26, x26, #0x80\n" + "shl v23.16b, v28.16b, #0x4\n" + "shl v5.16b, v27.16b, #0x4\n" + ".inst 0x4e96a7ed // smmla v13.4s, v31.16b, v22.16b\n" + "and v2.16b, v2.16b, v3.16b\n" + "and v11.16b, v11.16b, v3.16b\n" + ".inst 0x4e91a7f9 // smmla v25.4s, v31.16b, v17.16b\n" + ".inst 0x4e96a4fa // smmla v26.4s, v7.16b, v22.16b\n" + ".inst 0x4e91a4ea // smmla v10.4s, v7.16b, v17.16b\n" + "and v12.16b, v12.16b, v3.16b\n" + ".inst 0x4e97a7ee // smmla v14.4s, v31.16b, v23.16b\n" + ".inst 0x4e85a7f0 // smmla v16.4s, v31.16b, v5.16b\n" + "and v6.16b, v6.16b, v3.16b\n" + ".inst 0x4e97a4fe // smmla v30.4s, v7.16b, v23.16b\n" + ".inst 0x4e85a4f3 // smmla v19.4s, v7.16b, v5.16b\n" + "and v28.16b, v28.16b, v3.16b\n" + ".inst 0x4e84a70d // smmla v13.4s, v24.16b, v4.16b\n" + ".inst 0x4e88a719 // smmla v25.4s, v24.16b, v8.16b\n" + "and v27.16b, v27.16b, v3.16b\n" + ".inst 0x4e84a43a // smmla v26.4s, v1.16b, v4.16b\n" + ".inst 0x4e88a42a // smmla v10.4s, v1.16b, v8.16b\n" + ".inst 0x4e82a70e // smmla v14.4s, v24.16b, v2.16b\n" + ".inst 0x4e8ba710 // smmla v16.4s, v24.16b, v11.16b\n" + ".inst 0x4e82a43e // smmla v30.4s, v1.16b, v2.16b\n" + ".inst 0x4e8ba433 // smmla v19.4s, v1.16b, v11.16b\n" + ".inst 0x4e8ca6ad // smmla v13.4s, v21.16b, v12.16b\n" + ".inst 0x4e86a6b9 // smmla v25.4s, v21.16b, v6.16b\n" + ".inst 0x4e8ca69a // smmla v26.4s, v20.16b, v12.16b\n" + ".inst 0x4e86a68a // smmla v10.4s, v20.16b, v6.16b\n" + ".inst 0x4e9ca6ae // smmla v14.4s, v21.16b, v28.16b\n" + ".inst 0x4e9ba6b0 // smmla v16.4s, v21.16b, v27.16b\n" + ".inst 0x4e9ca69e // smmla v30.4s, v20.16b, v28.16b\n" + ".inst 0x4e9ba693 // smmla v19.4s, v20.16b, v27.16b\n" + "bgt 13b\n" + "ldr q5, [x26, #0x0]\n" + "ldr q20, [x26, #0x10]\n" + "uzp1 v2.2d, v13.2d, v25.2d\n" + "uzp1 v21.2d, v14.2d, v16.2d\n" + "ldr q6, [x22, #0x0]\n" + "ldr q1, [x26, #0x20]\n" + "uzp2 v4.2d, v13.2d, v25.2d\n" + "uzp2 v28.2d, v14.2d, v16.2d\n" + "ldr q7, [x26, #0x30]\n" + "ldr q17, [x22, #0x10]\n" + "uzp1 v29.2d, v26.2d, v10.2d\n" + "uzp1 v15.2d, v30.2d, v19.2d\n" + "ld1r { v27.4s }, [%x[clamp_vals]]\n" + "uzp2 v26.2d, v26.2d, v10.2d\n" + "uzp2 v25.2d, v30.2d, v19.2d\n" + "add x20, %x[clamp_vals], #0x4\n" + "ld1r { v19.4s }, [x20]\n" + "mla v2.4s, v5.4s, v6.s[0]\n" 
+ "mla v21.4s, v20.4s, v6.s[0]\n" + "cmp x25, #0x8\n" + "mla v4.4s, v5.4s, v6.s[1]\n" + "mla v28.4s, v20.4s, v6.s[1]\n" + "fmul v23.4s, v1.4s, v17.s[0]\n" + "add x26, x26, #0x40\n" + "mla v29.4s, v5.4s, v6.s[2]\n" + "mla v15.4s, v20.4s, v6.s[2]\n" + "fmul v31.4s, v7.4s, v17.s[0]\n" + "mla v26.4s, v5.4s, v6.s[3]\n" + "mla v25.4s, v20.4s, v6.s[3]\n" + "fmul v22.4s, v1.4s, v17.s[1]\n" + "scvtf v2.4s, v2.4s\n" + "scvtf v21.4s, v21.4s\n" + "scvtf v4.4s, v4.4s\n" + "scvtf v28.4s, v28.4s\n" + "fmul v20.4s, v7.4s, v17.s[1]\n" + "scvtf v29.4s, v29.4s\n" + "fmul v24.4s, v1.4s, v17.s[2]\n" + "scvtf v15.4s, v15.4s\n" + "fmul v10.4s, v7.4s, v17.s[2]\n" + "scvtf v26.4s, v26.4s\n" + "fmul v0.4s, v1.4s, v17.s[3]\n" + "scvtf v25.4s, v25.4s\n" + "fmul v8.4s, v7.4s, v17.s[3]\n" + "fmul v2.4s, v2.4s, v23.4s\n" + "fmul v21.4s, v21.4s, v31.4s\n" + "fmul v4.4s, v4.4s, v22.4s\n" + "fmul v28.4s, v28.4s, v20.4s\n" + "fmul v29.4s, v29.4s, v24.4s\n" + "fmul v15.4s, v15.4s, v10.4s\n" + "fmul v26.4s, v26.4s, v0.4s\n" + "fmul v25.4s, v25.4s, v8.4s\n" + "fmax v2.4s, v2.4s, v27.4s\n" + "fmax v21.4s, v21.4s, v27.4s\n" + "fmax v4.4s, v4.4s, v27.4s\n" + "fmax v28.4s, v28.4s, v27.4s\n" + "fmax v29.4s, v29.4s, v27.4s\n" + "fmax v15.4s, v15.4s, v27.4s\n" + "fmax v26.4s, v26.4s, v27.4s\n" + "fmax v25.4s, v25.4s, v27.4s\n" + "fmin v2.4s, v2.4s, v19.4s\n" + "fmin v21.4s, v21.4s, v19.4s\n" + "fmin v4.4s, v4.4s, v19.4s\n" + "fmin v28.4s, v28.4s, v19.4s\n" + "fmin v29.4s, v29.4s, v19.4s\n" + "fmin v15.4s, v15.4s, v19.4s\n" + "fmin v26.4s, v26.4s, v19.4s\n" + "fmin v25.4s, v25.4s, v19.4s\n" + "bge 18f\n" + "mov x23, %x[dst]\n" + "cmp x12, #0x1\n" + "add x22, x23, %x[dst_stride_row]\n" + "csel x22, x22, x23, GE\n" + "cmp x12, #0x2\n" + "add x21, x23, %x[dst_stride_row], LSL #1\n" + "csel x21, x21, x22, GE\n" + "cmp x12, #0x3\n" + "add x20, x21, %x[dst_stride_row]\n" + "csel x20, x20, x21, GE\n" + "tbz x25, #2, 15f\n" + "st1 { v26.4s }, [x20], #0x10\n" + "st1 { v29.4s }, [x21], #0x10\n" + "st1 { v4.4s }, [x22], #0x10\n" + "st1 { v2.4s }, [x23], #0x10\n" + "tbz x25, #1, 14f\n" + "str d25, [x20], #0x8\n" + "str d15, [x21], #0x8\n" + "str d28, [x22], #0x8\n" + "str d21, [x23], #0x8\n" + "tbz x25, #0, 17f\n" + "st1 { v25.s }[2], [x20]\n" + "st1 { v15.s }[2], [x21]\n" + "st1 { v28.s }[2], [x22]\n" + "st1 { v21.s }[2], [x23]\n" + "b 17f\n" + "14:" // Row tail: Output block 0: partial_1_4 + "tbz x25, #0, 17f\n" + "str s25, [x20, #0x0]\n" + "str s15, [x21, #0x0]\n" + "str s28, [x22, #0x0]\n" + "str s21, [x23, #0x0]\n" + "b 17f\n" + "15:" // Row tail: Output block 0: partial_2_0 + "tbz x25, #1, 16f\n" + "str d26, [x20], #0x8\n" + "str d29, [x21], #0x8\n" + "str d4, [x22], #0x8\n" + "str d2, [x23], #0x8\n" + "tbz x25, #0, 17f\n" + "st1 { v26.s }[2], [x20]\n" + "st1 { v29.s }[2], [x21]\n" + "st1 { v4.s }[2], [x22]\n" + "st1 { v2.s }[2], [x23]\n" + "b 17f\n" + "16:" // Row tail: Output block 0: partial_1_0 + "str s26, [x20, #0x0]\n" + "str s29, [x21, #0x0]\n" + "str s4, [x22, #0x0]\n" + "str s2, [x23, #0x0]\n" + "17:" // Row tail: Output block 0: Done + "b 19f\n" + "18:" // Row tail: Full output + "mov x20, %x[dst]\n" + "cmp x12, #0x1\n" + "str q2, [x20, #0x0]\n" + "str q21, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 19f\n" + "cmp x12, #0x2\n" + "str q4, [x20, #0x0]\n" + "str q28, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 19f\n" + "cmp x12, #0x3\n" + "str q29, [x20, #0x0]\n" + "str q15, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 19f\n" + "str q26, [x20, #0x0]\n" + "str q25, [x20, #0x10]\n" + 
"19:" // Row tail: Output stage exit + "subs x25, x25, #0x8\n" + "add %x[dst], %x[dst], #0x20\n" + "bgt 12b\n" + "subs x12, x12, #0x4\n" + "add %x[lhs_packed], %x[lhs_packed], x11\n" + "mov %x[dst], x24\n" + "bgt 11b\n" + "20:" // Row tail: Row loop skip + : [lhs_packed] "+&r"(lhs_packed), [dst] "+&r"(dst) + : [rhs_packed] "r"(rhs_packed), [clamp_vals] "r"(clamp_vals), [m] "r"(m), [num_blocks] "r"(num_blocks), + [dst_stride_row] "r"(dst_stride_row), [n] "r"(n) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", + "v30", "v31", "x9", "x10", "x11", "x12", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"); #else KAI_ASSERT(false); KAI_UNUSED(m); diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.h b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.h index ffe9ef9c..12c5aca2 100644 --- a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.h +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.h @@ -35,7 +35,7 @@ size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(vo /** * @brief Function to get the mr value, which must be used to pack the LHS matrix with - * the @ref kai_run_lhs_quant_pack_qa8dsP_f32 function + * the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel * * @return the mr value */ @@ -43,7 +43,7 @@ size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(void); /** * @brief Function to get the nr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 function + * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 micro-kernel * * @return the nr value */ @@ -51,7 +51,7 @@ size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(void); /** * @brief Function to get the kr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 function + * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 micro-kernel * * @return the kr value */ @@ -59,7 +59,7 @@ size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(void); /** * @brief Function to get the sr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 function + * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 micro-kernel * * @return the sr value */ @@ -71,8 +71,8 @@ size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(void); * * This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. * - * @param[in] m_idx Row index in the LHS matrix (not packed). - * @param[in] k Total number of columns in the LHS matrix (not packed). It must be a multiple of 64. + * @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of 8 + * @param[in] k Total number of columns in the LHS matrix (not packed). * * return the offset in bytes to the packed LHS matrix */ @@ -80,10 +80,10 @@ size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_n /** * @brief Function to calculate the offset in bytes for the packed RHS matrix, - * which contains the packed 4-bit quantized symmetric per-channel (qs4cx) values. 
+ * which contains the packed 4-bit quantized symmetric per-channel (qsu4cx) values.
 *
- * @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of 4.
- * @param[in] k The common dimension between the LHS and RHS matrix (K). It must be a multiple of 64.
+ * @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of 8.
+ * @param[in] k The common dimension between the LHS and RHS matrix (K).
 *
 * return the offset in bytes to the packed RHS matrix
 */
@@ -92,9 +92,9 @@ size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_n
 
 /**
 * @brief Function to calculate the offset in bytes for the DST matrix
 *
- * @param[in] m_idx Row index in the DST matrix. It must be either 1, 2, 3, or any multiple of 4.
- * @param[in] n_idx Column index in the DST matrix. It must be multiple of 4.
- * @param[in] dst_stride The number of bytes in in each row of the DST matrix
+ * @param[in] m_idx Row index in the DST matrix. It must be a multiple of 8.
+ * @param[in] n_idx Column index in the DST matrix. It must be a multiple of 8.
+ * @param[in] dst_stride The number of bytes in each row of the DST matrix
 *
 * return the DST offset in bytes
 */
@@ -102,10 +102,10 @@ size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8m
     size_t m_idx, size_t n_idx, size_t dst_stride);
 
 /**
- * @brief Function to query the size in bytes for the constant workspace.
+ * @brief Function to query the size in bytes for the destination matrix.
 *
- * @param[in] m Number of rows in the destination (DST) matrix. It must be either 1, 2, 3, or any multiple of 4.
- * @param[in] n Number of columns in the destination (DST) matrix. It must be a multiple 4.
+ * @param[in] m Number of rows in the destination (DST) matrix.
+ * @param[in] n Number of columns in the destination (DST) matrix.
 */
 size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(size_t m, size_t n);
@@ -118,9 +118,9 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(
 * Accumulation performed in a single for loop: 32
 * Instruction used: i8mm
 *
- * @param[in] m The number of output rows written. It must be either 1, 2, 3, or any multiple of 4.
- * @param[in] n The number of output columns written. It must be a multiple of 4.
- * @param[in] k The number of channels. The common dimension of LHS & RHS. It must be multiple of 32.
+ * @param[in] m The number of output rows written.
+ * @param[in] n The number of output columns written.
+ * @param[in] k The number of channels. The common dimension of LHS & RHS.
 * @param[in] lhs_packed The LHS matrix packed.
* When the activation are dynamically quantized, you can obtain this matrix
 * by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs
diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp_qsu4cxp_interface.h b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp_qsu4cxp_interface.h
index dbf08a0a..0365a34f 100644
--- a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp_qsu4cxp_interface.h
+++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp_qsu4cxp_interface.h
@@ -15,35 +15,36 @@ extern "C" {
 // In this case, the micro-kernel type is: matmul_clamp_f32_qai8dxp_qsu4cxp
 
 /** Micro-kernel helper functions ("get" methods) */
-typedef size_t (*kai_get_m_step_func_t)(void);
-typedef size_t (*kai_get_n_step_func_t)(void);
-typedef size_t (*kai_get_mr_func_t)(void);
-typedef size_t (*kai_get_nr_func_t)(void);
-typedef size_t (*kai_get_kr_func_t)(void);
-typedef size_t (*kai_get_sr_func_t)(void);
-typedef size_t (*kai_get_lhs_packed_offset_func_t)(size_t m_idx, size_t k);
-typedef size_t (*kai_get_rhs_packed_offset_func_t)(size_t n_idx, size_t k);
-typedef size_t (*kai_get_dst_offset_func_t)(size_t m_idx, size_t n_idx, size_t dst_stride);
-typedef size_t (*kai_get_dst_size_func_t)(size_t m, size_t n);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsu4cxp_get_m_step_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsu4cxp_get_n_step_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsu4cxp_get_mr_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsu4cxp_get_nr_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsu4cxp_get_kr_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsu4cxp_get_sr_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsu4cxp_get_lhs_packed_offset_func_t)(size_t m_idx, size_t k);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsu4cxp_get_rhs_packed_offset_func_t)(size_t n_idx, size_t k);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsu4cxp_get_dst_offset_func_t)(
+    size_t m_idx, size_t n_idx, size_t dst_stride);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsu4cxp_get_dst_size_func_t)(size_t m, size_t n);
 
 /** Micro-kernel core function ("run" method) */
-typedef void (*kai_run_matmul_func_t)(
-    size_t m, size_t n, size_t k, const void* lhs_p, const void* rhs_p, float* dst, size_t dst_stride_row,
-    size_t dst_stride_col, float scalar_min, float scalar_max);
+typedef void (*kai_matmul_clamp_f32_qai8dxp_qsu4cxp_run_matmul_func_t)(
+    size_t m, size_t n, size_t k, const void* lhs_p, const void* rhs_p, float* dst, size_t dst_stride_row,
+    size_t dst_stride_col, float scalar_min, float scalar_max);
 
 /** Micro-kernel interface */
 struct kai_matmul_clamp_f32_qai8dxp_qsu4cxp_ukernel {
-    kai_get_m_step_func_t get_m_step;
-    kai_get_n_step_func_t get_n_step;
-    kai_get_mr_func_t get_mr;
-    kai_get_nr_func_t get_nr;
-    kai_get_nr_func_t get_kr;
-    kai_get_sr_func_t get_sr;
-    kai_get_lhs_packed_offset_func_t get_lhs_packed_offset;
-    kai_get_rhs_packed_offset_func_t get_rhs_packed_offset;
-    kai_get_dst_offset_func_t get_dst_offset;
-    kai_get_dst_size_func_t get_dst_size;
-    kai_run_matmul_func_t run_matmul;
+    kai_matmul_clamp_f32_qai8dxp_qsu4cxp_get_m_step_func_t get_m_step;
+    kai_matmul_clamp_f32_qai8dxp_qsu4cxp_get_n_step_func_t get_n_step;
+    kai_matmul_clamp_f32_qai8dxp_qsu4cxp_get_mr_func_t get_mr;
+    kai_matmul_clamp_f32_qai8dxp_qsu4cxp_get_nr_func_t get_nr;
+    kai_matmul_clamp_f32_qai8dxp_qsu4cxp_get_kr_func_t get_kr;
+    kai_matmul_clamp_f32_qai8dxp_qsu4cxp_get_sr_func_t get_sr;
+    kai_matmul_clamp_f32_qai8dxp_qsu4cxp_get_lhs_packed_offset_func_t get_lhs_packed_offset;
+
kai_matmul_clamp_f32_qai8dxp_qsu4cxp_get_rhs_packed_offset_func_t get_rhs_packed_offset; + kai_matmul_clamp_f32_qai8dxp_qsu4cxp_get_dst_offset_func_t get_dst_offset; + kai_matmul_clamp_f32_qai8dxp_qsu4cxp_get_dst_size_func_t get_dst_size; + kai_matmul_clamp_f32_qai8dxp_qsu4cxp_run_matmul_func_t run_matmul; }; #ifdef __cplusplus -- GitLab From cdd8a68aa62705ed6ac7832588548fbfd09b2fc8 Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Thu, 23 May 2024 12:48:27 +0100 Subject: [PATCH 09/14] Fix compilation issue - Fix target project in the CMakeLists.txt - Remove unused variables Signed-off-by: Gian Marco Iodice --- CMakeLists.txt | 4 ++-- src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c | 4 ++-- ..._matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.c | 1 - ..._matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.c | 1 - ..._matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm.c | 1 - ..._matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.c | 1 - 6 files changed, 4 insertions(+), 8 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index dab8f92d..447eb103 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -58,7 +58,7 @@ set(KLEIDIAI_WARNING_FLAGS add_library(kleidiai) -target_sources(Kleidiai PRIVATE +target_sources(kleidiai PRIVATE src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.c src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod.c @@ -69,7 +69,7 @@ target_sources(Kleidiai PRIVATE src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.c) target_include_directories(kleidiai - PRIVATE . + PRIVATE src/ ) target_compile_options(kleidiai diff --git a/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c b/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c index 4a5fa657..19b93c24 100644 --- a/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c +++ b/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c @@ -93,7 +93,7 @@ void kai_run_lhs_quant_pack_qai8dxp_f32( max0 = vmaxvq_f32(vmax0); min0 = vminvq_f32(vmin0); #endif - for (; k_idx < k; ++k_idx) { + for (; k_idx < (int32_t)k; ++k_idx) { const float src0_0 = *(src_ptr + (size_t)k_idx); max0 = KAI_MAX(src0_0, max0); min0 = KAI_MIN(src0_0, min0); @@ -132,7 +132,7 @@ void kai_run_lhs_quant_pack_qai8dxp_f32( // Quantize the channels k_idx = 0; - for (; k_idx < k_internal; k_idx += k_block_len) { + for (; k_idx < (int32_t)k_internal; k_idx += k_block_len) { for (size_t k_block_idx = 0; k_block_idx < k_block_len; ++k_block_idx) { // Clamp at the last valid k-index const size_t k_idx_start = KAI_MIN((size_t)k_idx + k_block_idx, k); diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.c b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.c index 2671a2f8..9fb104d8 100644 --- a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.c +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.c @@ -16,7 +16,6 @@ static const size_t kai_mr = 4; static const size_t kai_nr = 4; static const size_t kai_kr = 16; static const size_t kai_sr = 2; -static const size_t kai_k0 = 32; static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); diff --git 
a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.c b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.c index 62d1611c..1e11887e 100644 --- a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.c +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.c @@ -16,7 +16,6 @@ static const size_t kai_mr = 4; static const size_t kai_nr = 4; static const size_t kai_kr = 16; static const size_t kai_sr = 2; -static const size_t kai_k0 = 32; static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm.c b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm.c index 0c4a9225..00d1ed81 100644 --- a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm.c +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm.c @@ -16,7 +16,6 @@ static const size_t kai_mr = 4; static const size_t kai_nr = 8; static const size_t kai_kr = 16; static const size_t kai_sr = 2; -static const size_t kai_k0 = 32; static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.c b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.c index 1bac2078..391efd96 100644 --- a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.c +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.c @@ -16,7 +16,6 @@ static const size_t kai_mr = 4; static const size_t kai_nr = 8; static const size_t kai_kr = 16; static const size_t kai_sr = 2; -static const size_t kai_k0 = 32; static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); -- GitLab From e6681fdc6c8493a1314cbf524959116ea6333405 Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Thu, 23 May 2024 13:48:19 +0100 Subject: [PATCH 10/14] Remove the rhs_offset = 0 case in the RHS packing Signed-off-by: Gian Marco Iodice --- src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c | 14 ++--- .../kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.c | 54 ++++++------------- ...ai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod.c | 3 +- ...ai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.c | 2 +- 4 files changed, 25 insertions(+), 48 deletions(-) diff --git a/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c b/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c index 19b93c24..03bbecb0 100644 --- a/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c +++ b/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c @@ -64,8 +64,8 @@ void kai_run_lhs_quant_pack_qai8dxp_f32( const float* src_ptr = lhs; const size_t dst_stride = kai_lhs_packed_stride(k, mr, kr, sr); - const size_t k_block_len = kr / sr; const size_t k_internal = 
kai_k_roundedup(k, kr, sr); + const int32_t k_block_len = (int32_t)(kr / sr); for (size_t row_idx = 0; row_idx < num_rows; ++row_idx) { float max0 = -FLT_MAX; @@ -103,13 +103,13 @@ void kai_run_lhs_quant_pack_qai8dxp_f32( const float qmin = (float)INT8_MIN; const float qmax = (float)INT8_MAX; - const float rmin0 = KAI_MIN(0.0f, min0); - const float rmax0 = KAI_MAX(0.0f, max0); + const float rmin0 = KAI_MIN(0.0F, min0); + const float rmax0 = KAI_MAX(0.0F, max0); - const float scale0 = rmin0 == rmax0 ? 1.f : (qmax - qmin) / (rmax0 - rmin0); + const float scale0 = rmin0 == rmax0 ? 1.F : (qmax - qmin) / (rmax0 - rmin0); // Reciprocal to quantize - const float recip_scale0 = scale0 ? 1.0f / scale0 : 0.0f; + const float recip_scale0 = scale0 ? 1.0F / scale0 : 0.0F; const float descaled_min0 = rmin0 * scale0; const float descaled_max0 = rmax0 * scale0; @@ -124,7 +124,7 @@ void kai_run_lhs_quant_pack_qai8dxp_f32( zero_point0 = KAI_MIN(zero_point0, qmax); // Round to nearest integer - const int32_t nudged_zero_point0 = lrintf(zero_point0); + const int32_t nudged_zero_point0 = (int32_t)rintf(zero_point0); const size_t dst_x = ((row_idx + m_idx_start) % mr); @@ -140,7 +140,7 @@ void kai_run_lhs_quant_pack_qai8dxp_f32( const float src0_0 = *(src_ptr + k_idx_start); // Scale the values - int32_t v0_s32 = (int32_t)(round(src0_0 * scale0)); + int32_t v0_s32 = (int32_t)(roundf(src0_0 * scale0)); v0_s32 = v0_s32 + nudged_zero_point0; v0_s32 = KAI_MAX(v0_s32, INT8_MIN); diff --git a/src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.c b/src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.c index 421e4430..6062efe1 100644 --- a/src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.c +++ b/src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.c @@ -31,7 +31,7 @@ inline static size_t kai_rhs_packed_stride(size_t k, size_t kr, size_t nr, size_ } inline static int8_t kai_int4_sign_extend(int8_t x) { - return (x ^ 0x8) - 8; + return (x ^ 0x80) - (int8_t)8; } size_t kai_get_n_step_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0(size_t nr) { @@ -81,7 +81,7 @@ void kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0( const size_t k_internal = kai_k_roundedup(k, kr, sr); for (size_t y = 0; y < n; y += nr) { - const uint8_t* src_row = (const uint8_t*)rhs + y * rhs_stride; + const uint8_t* src_row = rhs + y * rhs_stride; uint8_t* dst_row = (uint8_t*)rhs_packed + (y / nr) * rhs_packed_stride; int32_t* sums = (int32_t*)(dst_row + nr * (k_internal / 2)); @@ -110,46 +110,22 @@ void kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0( byte1 = src_row[src_addr_byte1]; } - if (rhs_zero_point == 0) { - int8_t src_x0_lo = (byte0 & 0x0F); - int8_t src_x1_lo = (byte0 >> 4); + const uint8_t src_x0_lo = (byte0 & 0x0F); + const uint8_t src_x1_lo = (byte0 >> 4); - int8_t src_x0_hi = (byte1 & 0x0F); - int8_t src_x1_hi = (byte1 >> 4); + const uint8_t src_x0_hi = (byte1 & 0x0F); + const uint8_t src_x1_hi = (byte1 >> 4); - const int8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4); - const int8_t dst_qs1 = src_x1_lo | (src_x1_hi << 4); + sums[i] += (int32_t)src_x0_lo + (int32_t)src_x0_hi - 2 * (int32_t)rhs_zero_point; + sums[i] += (int32_t)src_x1_lo + (int32_t)src_x1_hi - 2 * (int32_t)rhs_zero_point; - src_x0_lo = kai_int4_sign_extend(src_x0_lo); - src_x1_lo = kai_int4_sign_extend(src_x1_lo); - src_x0_hi = kai_int4_sign_extend(src_x0_hi); - src_x1_hi = kai_int4_sign_extend(src_x1_hi); - sums[i] += src_x0_lo + src_x0_hi; - sums[i] += src_x1_lo + src_x1_hi; + const uint8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4); + const uint8_t dst_qs1 = src_x1_lo | (src_x1_hi << 4); - *(int8_t*)dst_row = dst_qs0; - dst_row 
+= sizeof(int8_t); - *(int8_t*)dst_row = dst_qs1; - dst_row += sizeof(int8_t); - - } else { - const uint8_t src_x0_lo = (byte0 & 0x0F); - const uint8_t src_x1_lo = (byte0 >> 4); - - const uint8_t src_x0_hi = (byte1 & 0x0F); - const uint8_t src_x1_hi = (byte1 >> 4); - - sums[i] += src_x0_lo + src_x0_hi - 2 * rhs_zero_point; - sums[i] += src_x1_lo + src_x1_hi - 2 * rhs_zero_point; - - const uint8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4); - const uint8_t dst_qs1 = src_x1_lo | (src_x1_hi << 4); - - *dst_row = dst_qs0 ^ 0x88; - dst_row += sizeof(uint8_t); - *dst_row = dst_qs1 ^ 0x88; - dst_row += sizeof(uint8_t); - } + *dst_row = dst_qs0 ^ 0x88; + dst_row += sizeof(uint8_t); + *dst_row = dst_qs1 ^ 0x88; + dst_row += sizeof(uint8_t); } } } @@ -163,7 +139,7 @@ void kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0( // Adjust the scales for (size_t i = 0; i < nr; ++i) { - *((float*)(dst_row)) = scale[y + i] * 0.0625f; + *((float*)(dst_row)) = scale[y + i] * 0.0625F; dst_row += sizeof(float); } } diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod.c b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod.c index a51b2dac..1bedff91 100644 --- a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod.c +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod.c @@ -16,7 +16,6 @@ static const size_t kai_mr = 1; static const size_t kai_nr = 4; static const size_t kai_kr = 16; static const size_t kai_sr = 2; -static const size_t kai_k0 = kai_kr * kai_sr; static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); @@ -103,6 +102,8 @@ void kai_run_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod( return; } + const size_t kai_k0 = kai_kr * kai_sr; + const size_t num_rows = m; const size_t num_cols = n; diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.c b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.c index e11804a3..a66365f6 100644 --- a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.c +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.c @@ -16,7 +16,6 @@ static const size_t kai_mr = 1; static const size_t kai_nr = 8; static const size_t kai_kr = 16; static const size_t kai_sr = 2; -static const size_t kai_k0 = kai_kr * kai_sr; static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); @@ -103,6 +102,7 @@ void kai_run_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod( return; } + const size_t kai_k0 = kai_kr * kai_sr; const size_t num_rows = m; const size_t num_cols = n; -- GitLab From 155c9063b5b57c396f60c884cfbc944c364465b2 Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Thu, 23 May 2024 13:56:24 +0100 Subject: [PATCH 11/14] Add explicit cast to k_block_len Signed-off-by: Gian Marco Iodice --- src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c 
b/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c index 03bbecb0..d93469a0 100644 --- a/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c +++ b/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c @@ -133,7 +133,7 @@ void kai_run_lhs_quant_pack_qai8dxp_f32( // Quantize the channels k_idx = 0; for (; k_idx < (int32_t)k_internal; k_idx += k_block_len) { - for (size_t k_block_idx = 0; k_block_idx < k_block_len; ++k_block_idx) { + for (size_t k_block_idx = 0; k_block_idx < (size_t)k_block_len; ++k_block_idx) { // Clamp at the last valid k-index const size_t k_idx_start = KAI_MIN((size_t)k_idx + k_block_idx, k); -- GitLab From efce9aeeedd368202715bb2f56cf54bd36e61d7a Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Thu, 23 May 2024 14:02:05 +0100 Subject: [PATCH 12/14] Remove unused kai_int4_sign_extend() function Signed-off-by: Gian Marco Iodice --- src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.c | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.c b/src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.c index 6062efe1..8832445b 100644 --- a/src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.c +++ b/src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.c @@ -30,10 +30,6 @@ inline static size_t kai_rhs_packed_stride(size_t k, size_t kr, size_t nr, size_ return nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs); } -inline static int8_t kai_int4_sign_extend(int8_t x) { - return (x ^ 0x80) - (int8_t)8; -} - size_t kai_get_n_step_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0(size_t nr) { return nr; } -- GitLab From 823dcc491e95f3765a84286c4cdad6c28242951d Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Thu, 23 May 2024 14:16:21 +0100 Subject: [PATCH 13/14] Rename qsu4cx to qsi4cx - Rename qsu4cx to qsi4cx as the integer 4-bit is signed Signed-off-by: Gian Marco Iodice --- CMakeLists.txt | 14 +- .../CMakeLists.txt | 38 ++++ .../matmul_clamp_f32_qai8dxp_qsi4cxp.cpp} | 168 +++++++++--------- .../CMakeLists.txt | 38 ---- ... => kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c} | 14 +- ... 
=> kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h} | 14 +- ...i8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c} | 24 +-- ...i8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h} | 32 ++-- ...i8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.c} | 24 +-- ...i8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h} | 32 ++-- ..._qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.c} | 24 +-- ..._qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.h} | 32 ++-- ..._qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c} | 24 +-- ..._qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h} | 32 ++-- ..._qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.c} | 24 +-- ..._qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.h} | 32 ++-- ..._qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.c} | 24 +-- ..._qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h} | 32 ++-- ...tmul_clamp_f32_qai8dxp_qsi4cxp_interface.h | 52 ++++++ ...tmul_clamp_f32_qai8dxp_qsu4cxp_interface.h | 52 ------ 20 files changed, 363 insertions(+), 363 deletions(-) create mode 100644 examples/matmul_clamp_f32_qai8dxp_qsi4cxp/CMakeLists.txt rename examples/{matmul_clamp_f32_qai8dxp_qsu4cxp/matmul_clamp_f32_qai8dxp_qsu4cxp.cpp => matmul_clamp_f32_qai8dxp_qsi4cxp/matmul_clamp_f32_qai8dxp_qsi4cxp.cpp} (73%) delete mode 100644 examples/matmul_clamp_f32_qai8dxp_qsu4cxp/CMakeLists.txt rename src/matmul/{kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.c => kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c} (92%) rename src/matmul/{kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.h => kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h} (91%) rename src/matmul/{matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod.c => matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c} (91%) rename src/matmul/{matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod.h => matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h} (82%) rename src/matmul/{matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.c => matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.c} (93%) rename src/matmul/{matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.h => matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h} (82%) rename src/matmul/{matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.c => matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.c} (92%) rename src/matmul/{matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.h => matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.h} (82%) rename src/matmul/{matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.c => matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c} (95%) rename src/matmul/{matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.h => matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h} (82%) rename src/matmul/{matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm.c => matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.c} (94%) rename src/matmul/{matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm.h => 
matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.h} (82%) rename src/matmul/{matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.c => matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.c} (97%) rename src/matmul/{matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.h => matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h} (82%) create mode 100644 src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp_qsi4cxp_interface.h delete mode 100644 src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp_qsu4cxp_interface.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 447eb103..40ab01af 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -60,13 +60,13 @@ add_library(kleidiai) target_sources(kleidiai PRIVATE src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c - src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.c - src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod.c - src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.c - src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.c - src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.c - src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm.c - src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.c) + src/matmul/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c + src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c + src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.c + src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.c + src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c + src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.c + src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.c) target_include_directories(kleidiai PRIVATE src/ diff --git a/examples/matmul_clamp_f32_qai8dxp_qsi4cxp/CMakeLists.txt b/examples/matmul_clamp_f32_qai8dxp_qsi4cxp/CMakeLists.txt new file mode 100644 index 00000000..ccc8b840 --- /dev/null +++ b/examples/matmul_clamp_f32_qai8dxp_qsi4cxp/CMakeLists.txt @@ -0,0 +1,38 @@ +# +# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +# +# SPDX-License-Identifier: Apache-2.0 +# + +cmake_minimum_required(VERSION 3.16) + +# KleidiAI include directories +include_directories( + ../../src/ + ../../src/matmul/ + ../../src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/) + +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17") + +# Files required to build the executable +add_executable(matmul_clamp_f32_qai8dxp_qsi4cxp + matmul_clamp_f32_qai8dxp_qsi4cxp.cpp + ../../src/kai_common.h + ../../src/matmul/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h + ../../src/matmul/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c + ../../src/matmul/kai_lhs_quant_pack_qai8dxp_f32.h + ../../src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c 
../../src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp_qsi4cxp_interface.h + ../../src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h + ../../src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c + ../../src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h + ../../src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.c + ../../src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.h + ../../src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.c + ../../src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h + ../../src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c + ../../src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.h + ../../src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.c + ../../src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h + ../../src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.c) + diff --git a/examples/matmul_clamp_f32_qai8dxp_qsu4cxp/matmul_clamp_f32_qai8dxp_qsu4cxp.cpp b/examples/matmul_clamp_f32_qai8dxp_qsi4cxp/matmul_clamp_f32_qai8dxp_qsi4cxp.cpp similarity index 73% rename from examples/matmul_clamp_f32_qai8dxp_qsu4cxp/matmul_clamp_f32_qai8dxp_qsu4cxp.cpp rename to examples/matmul_clamp_f32_qai8dxp_qsi4cxp/matmul_clamp_f32_qai8dxp_qsi4cxp.cpp index 80d529ea..0f404cf0 100644 --- a/examples/matmul_clamp_f32_qai8dxp_qsu4cxp/matmul_clamp_f32_qai8dxp_qsu4cxp.cpp +++ b/examples/matmul_clamp_f32_qai8dxp_qsi4cxp/matmul_clamp_f32_qai8dxp_qsi4cxp.cpp @@ -13,97 +13,97 @@ #include #include "kai_lhs_quant_pack_qai8dxp_f32.h" -#include "kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod.h" -#include "kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.h" -#include "kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.h" -#include "kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.h" -#include "kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm.h" -#include "kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.h" -#include "kai_matmul_clamp_f32_qai8dxp_qsu4cxp_interface.h" -#include "kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.h" +#include "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h" +#include "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h" +#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.h" +#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h" +#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.h" +#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h" +#include "kai_matmul_clamp_f32_qai8dxp_qsi4cxp_interface.h" +#include "kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h" #define INT4_MIN (-8) #define INT4_MAX (7) // Micro-kernel interface struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp { - kai_matmul_clamp_f32_qai8dxp_qsu4cxp_ukernel ukernel; + kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel ukernel; std::string name = {}; }; kai_matmul_ukernel_f32_qa8dxp_qs4cxp 
ukernel_variants[] = { - {kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod, - kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod, - kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod, - kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod, - kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod, - kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod, - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod, - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod, - kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod, - kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod, - kai_run_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod, - "matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod"}, - {kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod, - kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod, - kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod, - kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod, - kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod, - kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod, - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod, - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod, - kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod, - kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod, - kai_run_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod, - "matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod"}, - {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm, - kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm, - kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm, - kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm, - kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm, - kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm, - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm, - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm, - kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm, - kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm, - kai_run_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm, - "matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm"}, - {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm, - kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm, - kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm, - kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm, - kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm, - kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm, - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm, - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm, - kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm, - kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm, - 
kai_run_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm, - "matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm"}, - {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm, - kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm, - kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm, - kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm, - kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm, - kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm, - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm, - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm, - kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm, - kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm, - kai_run_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm, - "matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm"}, - {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm, - kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm, - kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm, - kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm, - kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm, - kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm, - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm, - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm, - kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm, - kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm, - kai_run_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm, - "matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm"}, + {kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + "matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod"}, + {kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + 
kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + "matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod"}, + {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + "matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm"}, + {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + "matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm"}, + {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + "matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm"}, + {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + 
kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + "matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm"}, }; // Number of micro-kernel variants stored in the array @@ -384,7 +384,7 @@ int main(int argc, char** argv) { // Get the size in bytes for the packed matrices const size_t lhs_packed_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(m, k, mr, kr, sr); - const size_t rhs_packed_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0(n, k, nr, kr, sr); + const size_t rhs_packed_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(n, k, nr, kr, sr); const size_t dst_size = ukernel_variants[idx_variant].ukernel.get_dst_size(m, n); // Allocate the matrices @@ -394,12 +394,12 @@ int main(int argc, char** argv) { // If the RHS matrix contains constant values, the packing can be performed // only once - struct kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0_params params; + struct kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0_params params; params.lhs_zero_point = 1; params.rhs_zero_point = 8; // RHS packing - kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0( + kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( 1, n, k, nr, kr, sr, (const uint8_t*)(rhs_native_mtx_qs4cx), // RHS NULL, // Bias diff --git a/examples/matmul_clamp_f32_qai8dxp_qsu4cxp/CMakeLists.txt b/examples/matmul_clamp_f32_qai8dxp_qsu4cxp/CMakeLists.txt deleted file mode 100644 index 8de22326..00000000 --- a/examples/matmul_clamp_f32_qai8dxp_qsu4cxp/CMakeLists.txt +++ /dev/null @@ -1,38 +0,0 @@ -# -# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates -# -# SPDX-License-Identifier: Apache-2.0 -# - -cmake_minimum_required(VERSION 3.16) - -# KleidiAI include directories -include_directories( - ../../src/ - ../../src/matmul/ - ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/) - -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17") - -# Files requires to build the executable -add_executable(matmul_clamp_f32_qai8dxp_qsu4cxp - matmul_clamp_f32_qai8dxp_qsu4cxp.cpp - ../../src/kai_common.h - ../../src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.h - ../../src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.c - ../../src/matmul/kai_lhs_quant_pack_qai8dxp_f32.h - ../../src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c - ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp_qsu4cxp_interface.h - ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod.h - ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod.c - ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.h - ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.c - ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.h - ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.c - ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.h - ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.c - ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm.h - ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm.c 
- ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.h - ../../src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.c) - diff --git a/src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.c b/src/matmul/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c similarity index 92% rename from src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.c rename to src/matmul/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c index 8832445b..321ae08c 100644 --- a/src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.c +++ b/src/matmul/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c @@ -3,7 +3,7 @@ // // SPDX-License-Identifier: Apache-2.0 // -#include "kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.h" +#include "kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h" #include #include @@ -30,31 +30,31 @@ inline static size_t kai_rhs_packed_stride(size_t k, size_t kr, size_t nr, size_ return nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs); } -size_t kai_get_n_step_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0(size_t nr) { +size_t kai_get_n_step_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t nr) { return nr; } -size_t kai_get_rhs_offset_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0(size_t n_idx, size_t rhs_stride) { +size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t n_idx, size_t rhs_stride) { return n_idx * rhs_stride; } -size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0( +size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr) { KAI_ASSERT((n_idx % nr) == 0); return (n_idx / nr) * kai_rhs_packed_stride(k, kr, nr, sr); } -size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0(size_t n, size_t k, size_t nr, size_t kr, size_t sr) { +size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t n, size_t k, size_t nr, size_t kr, size_t sr) { const size_t num_rows = kai_roundup(n, nr) / nr; return num_rows * kai_rhs_packed_stride(k, kr, nr, sr); } -void kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0( +void kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const int32_t* bias, const float* scale, void* rhs_packed, size_t extra_bytes, - const struct kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0_params* params) { + const struct kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0_params* params) { KAI_ASSERT((k % 2) == 0); KAI_ASSERT(num_groups == 1); KAI_ASSERT(bias == NULL); diff --git a/src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.h b/src/matmul/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h similarity index 91% rename from src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.h rename to src/matmul/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h index bc607599..8340c2bc 100644 --- a/src/matmul/kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.h +++ b/src/matmul/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h @@ -12,7 +12,7 @@ extern "C" { #endif -struct kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0_params { +struct kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0_params { int8_t lhs_zero_point; uint8_t rhs_zero_point; }; @@ -26,7 +26,7 @@ struct kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0_params { * * @return the n step value */ -size_t kai_get_n_step_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0(size_t nr); +size_t kai_get_n_step_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t nr); /** * @brief Function to calculate the offset in bytes for the RHS matrix (not packed), which holds @@ -39,7 +39,7 @@ size_t kai_get_n_step_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0(size_t nr); * * return the offset in bytes to the RHS matrix (not packed) */ 
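The single packing path left by patch 10 leans on one bit trick that the qsi4cx rename makes explicit. A minimal scalar sketch (illustrative only, not part of the patch; the helper and program names are hypothetical) of how XOR with 0x88 re-biases both 4-bit nibbles of a packed byte between the unsigned encoding (zero point 8) and the signed two's-complement encoding:

    #include <stdint.h>
    #include <stdio.h>

    /* XOR with 0x88 flips the top bit of each nibble in the byte. For a
     * nibble u in [0, 15] stored with zero point 8, u ^ 0x8 equals the
     * two's-complement int4 value u - 8, and the mapping is an involution
     * (applying it twice returns the original byte). This is why the
     * packing loop can convert two values at once with "dst_qs0 ^ 0x88". */
    static uint8_t example_rebias_nibbles(uint8_t packed) {
        return packed ^ 0x88;
    }

    int main(void) {
        for (uint8_t u = 0; u < 16; ++u) {
            const uint8_t nibble = example_rebias_nibbles(u) & 0x0F;
            const int s = (nibble < 8) ? (int)nibble : (int)nibble - 16;
            printf("u=%2u -> s=%3d\n", u, s); /* prints s == u - 8 */
        }
        return 0;
    }

The same re-biasing explains the row-sum adjustment in the packing code: each of the two nibbles contributes -rhs_zero_point when converted, hence the "- 2 * (int32_t)rhs_zero_point" term per packed byte.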
-size_t kai_get_rhs_offset_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0(size_t n_idx, size_t rhs_stride); +size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t n_idx, size_t rhs_stride); /** * @brief Function to calculate the offset in bytes for the packed RHS matrix, @@ -53,7 +53,7 @@ size_t kai_get_rhs_offset_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0(size_t n_idx, size_t r * * return the offset in bytes to the packed RHS matrix */ -size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0( +size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr); /** @@ -67,7 +67,7 @@ size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0( * * return the size in bytes to the packed RHS matrix */ -size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0(size_t n, size_t k, size_t nr, size_t kr, size_t sr); +size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t n, size_t k, size_t nr, size_t kr, size_t sr); /** * @brief Micro-kernel to pack the RHS matrix. @@ -91,10 +91,10 @@ size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0(size_t n, size_t * @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix. * @param[in] params Parameters for the micro-kernel. */ -void kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0( +void kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const int32_t* bias, const float* scale, void* rhs_packed, size_t extra_bytes, - const struct kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0_params* params); + const struct kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0_params* params); #ifdef __cplusplus } diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod.c b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c similarity index 91% rename from src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod.c rename to src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c index 1bedff91..6fee9b51 100644 --- a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod.c +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c @@ -3,7 +3,7 @@ // // SPDX-License-Identifier: Apache-2.0 // -#include "kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod.h" +#include "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h" #include #include @@ -44,43 +44,43 @@ inline static size_t kai_rhs_packed_stride(size_t k) { return kai_nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs); } -size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod(void) { +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void) { return kai_m_step; } -size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod(void) { +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void) { return kai_n_step; } -size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod(void) { +size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void) { return kai_mr; } -size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod(void) 
{ +size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void) { return kai_nr; } -size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod(void) { +size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void) { return kai_kr; } -size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod(void) { +size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void) { return kai_sr; } -size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod(size_t m_idx, size_t k) { +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(size_t m_idx, size_t k) { KAI_ASSERT((m_idx % kai_m_step) == 0); return (m_idx / kai_m_step) * kai_lhs_packed_stride(k); } -size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod(size_t n_idx, size_t k) { +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(size_t n_idx, size_t k) { KAI_ASSERT((n_idx % kai_n_step) == 0); return (n_idx / kai_n_step) * kai_rhs_packed_stride(k); } -size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod( +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod( size_t m_idx, size_t n_idx, size_t dst_stride) { KAI_ASSERT((m_idx % kai_m_step) == 0); KAI_ASSERT((n_idx % kai_n_step) == 0); @@ -88,11 +88,11 @@ size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dot return (n_idx * sizeof(float)) + m_idx * dst_stride; } -size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod(size_t m, size_t n) { +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(size_t m, size_t n) { return m * n * sizeof(float); } -void kai_run_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod( +void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod( size_t m, size_t n, size_t k, const void* restrict lhs_packed, const void* restrict rhs_packed, float* restrict dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { #if defined(__ARM_FEATURE_DOTPROD) diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod.h b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h similarity index 82% rename from src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod.h rename to src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h index cd9e64c3..558f0f4c 100644 --- a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod.h +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h @@ -22,7 +22,7 @@ extern "C" { * * @return the m step value */ -size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod(void); +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void); /** * @brief Function to get the n step value. 
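Taken together, the getters declared in this header are meant to be composed with the run function. A usage sketch (illustrative only, not part of the patch; the wrapper function is hypothetical) that splits the work into column blocks of the destination, as a multi-threaded caller might:

    #include <float.h>
    #include <stddef.h>
    #include "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h"

    static void example_run_by_column_blocks(
        size_t m, size_t n, size_t k,
        const void* lhs_packed, const void* rhs_packed,
        float* dst, size_t dst_stride_row) {
        const size_t n_step = kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod();

        for (size_t n_idx = 0; n_idx < n; n_idx += n_step) {
            const size_t n_block = (n - n_idx) < n_step ? (n - n_idx) : n_step;

            /* Per-block base pointers derived from the offset getters. */
            const void* rhs_block = (const char*)rhs_packed +
                kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(n_idx, k);
            float* dst_block = (float*)((char*)dst +
                kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(0, n_idx, dst_stride_row));

            /* dst_stride_col must be sizeof(float) for now; passing
             * -FLT_MAX/FLT_MAX effectively disables the final clamp. */
            kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(
                m, n_block, k, lhs_packed, rhs_block, dst_block,
                dst_stride_row, sizeof(float), -FLT_MAX, FLT_MAX);
        }
    }

Note that the offset getters assert that the index is a multiple of the corresponding step, so the loop must advance by n_step even when the last block is partial.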
@@ -31,7 +31,7 @@ size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod * * @return the n step */ -size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod(void); +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void); /** * @brief Function to get the mr value, which must be used to pack the LHS matrix with @@ -39,31 +39,31 @@ size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod * * @return the mr value */ -size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod(void); +size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void); /** * @brief Function to get the nr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 micro-kernel + * the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel * * @return the nr value */ -size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod(void); +size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void); /** * @brief Function to get the kr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 micro-kernel + * the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel * * @return the kr value */ -size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod(void); +size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void); /** * @brief Function to get the sr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 micro-kernel + * the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel * * @return the sr value */ -size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod(void); +size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void); /** * @brief Function to calculate the offset in bytes for the packed LHS matrix, @@ -76,7 +76,7 @@ size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod(voi * * return the offset in bytes to the packed LHS matrix */ -size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod(size_t m_idx, size_t k); +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(size_t m_idx, size_t k); /** * @brief Function to calculate the offset in bytes for the packed RHS matrix, @@ -87,7 +87,7 @@ size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_n * * return the offset in bytes to the packed RHS matrix */ -size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod(size_t n_idx, size_t k); +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(size_t n_idx, size_t k); /** * @brief Function to calculate the offset in bytes for the DST matrix @@ -98,7 +98,7 @@ size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_n * * return the DST offset in bytes */ -size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod( +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod( size_t m_idx, size_t n_idx, size_t dst_stride); /** @@ -107,13 +107,13 @@ size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dot * @param[in] m Number of rows in the destination (DST) matrix * @param[in] n Number of columns in the 
destination (DST) matrix */ -size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod(size_t m, size_t n); +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(size_t m, size_t n); /** * @brief Function to run the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation. * * LHS matrix: Signed 8-bit quantized asymmetric per-row (qai8dx) and packed - * RHS matrix: Unsigned 4-bit quantized symmetric per-channel (qsu4cx) and packed. + * RHS matrix: Signed 4-bit quantized symmetric per-channel (qsi4cx) and packed. * Output tile: (rows x cols) = 1 x 4 * Accumulation performed in a single for loop: 64 * Instruction used: dotprod @@ -126,14 +126,14 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotpr * by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs * both the dynamic quantization to 8-bit and activation packing in a single step. * @param[in] rhs_packed The RHS matrix packed, which is obtained by calling @ref - * kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 + * kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 * @param[out] dst Result of the vector-by-matrix * @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. * @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(float) * @param[in] scalar_min Min value used to clamp the final result. * @param[in] scalar_max Max value used to clamp the final result. */ -void kai_run_matmul_clamp_f32_qai8dxp1x8_qsu4cxp4x8_1x4x32_neon_dotprod( +void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod( size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max); diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.c b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.c similarity index 93% rename from src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.c rename to src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.c index a66365f6..f243f3d7 100644 --- a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.c +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.c @@ -3,7 +3,7 @@ // // SPDX-License-Identifier: Apache-2.0 // -#include "kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.h" +#include "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h" #include #include @@ -44,43 +44,43 @@ inline static size_t kai_rhs_packed_stride(size_t k) { return kai_nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs); } -size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(void) { +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(void) { return kai_m_step; } -size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(void) { +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(void) { return kai_n_step; } -size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(void) { +size_t 
kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(void) { return kai_mr; } -size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(void) { +size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(void) { return kai_nr; } -size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(void) { +size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(void) { return kai_kr; } -size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(void) { +size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(void) { return kai_sr; } -size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(size_t m_idx, size_t k) { +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(size_t m_idx, size_t k) { KAI_ASSERT((m_idx % kai_m_step) == 0); return (m_idx / kai_m_step) * kai_lhs_packed_stride(k); } -size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(size_t n_idx, size_t k) { +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(size_t n_idx, size_t k) { KAI_ASSERT((n_idx % kai_n_step) == 0); return (n_idx / kai_n_step) * kai_rhs_packed_stride(k); } -size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod( +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod( size_t m_idx, size_t n_idx, size_t dst_stride) { KAI_ASSERT((m_idx % kai_m_step) == 0); KAI_ASSERT((n_idx % kai_n_step) == 0); @@ -88,11 +88,11 @@ size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dot return (n_idx * sizeof(float)) + m_idx * dst_stride; } -size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(size_t m, size_t n) { +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(size_t m, size_t n) { return m * n * sizeof(float); } -void kai_run_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod( +void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod( size_t m, size_t n, size_t k, const void* restrict lhs_packed, const void* restrict rhs_packed, float* restrict dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { #if defined(__ARM_FEATURE_DOTPROD) diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.h b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h similarity index 82% rename from src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.h rename to src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h index 06b58e26..69bf5c08 100644 --- a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod.h +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h @@ -22,7 +22,7 @@ extern "C" { * * @return the m step value */ -size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(void); +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(void); /** * @brief Function to get the n step value. 
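The rounding functions touched in patch 10 (rintf for the zero point, roundf for the values) belong to the dynamic quantization flow of kai_lhs_quant_pack_qai8dxp_f32 that these headers reference. A scalar reference sketch (illustrative only, not part of the patch; the micro-kernel additionally balances the rounding error from both ends of the range and uses NEON for the min/max search):

    #include <math.h>
    #include <stddef.h>
    #include <stdint.h>

    static void example_quant_row_qai8dx(
        const float* row, size_t k, int8_t* dst,
        float* recip_scale, int32_t* zero_point) {
        /* The range always includes 0.0F so zero stays exactly representable. */
        float rmin = 0.0F;
        float rmax = 0.0F;
        for (size_t i = 0; i < k; ++i) {
            rmin = row[i] < rmin ? row[i] : rmin;
            rmax = row[i] > rmax ? row[i] : rmax;
        }

        const float qmin = (float)INT8_MIN;
        const float qmax = (float)INT8_MAX;
        const float scale = (rmin == rmax) ? 1.0F : (qmax - qmin) / (rmax - rmin);

        /* A simple zero-point choice; clamp it into the int8 range and round
         * to the nearest integer, as the micro-kernel does with rintf(). */
        float zp = qmin - rmin * scale;
        zp = zp < qmin ? qmin : zp;
        zp = zp > qmax ? qmax : zp;
        const int32_t nudged_zp = (int32_t)rintf(zp);

        for (size_t i = 0; i < k; ++i) {
            int32_t v = (int32_t)roundf(row[i] * scale) + nudged_zp;
            v = v < INT8_MIN ? INT8_MIN : v;
            v = v > INT8_MAX ? INT8_MAX : v;
            dst[i] = (int8_t)v;
        }

        *recip_scale = (scale != 0.0F) ? 1.0F / scale : 0.0F;
        *zero_point = nudged_zp;
    }

The reciprocal scale and zero point are the per-row quantization parameters that the packing step stores next to the int8 data, so the matmul micro-kernel can fold them back into the F32 result.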
@@ -31,7 +31,7 @@ size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod * * @return the n step */ -size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(void); +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(void); /** * @brief Function to get the mr value, which must be used to pack the LHS matrix with @@ -39,31 +39,31 @@ size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod * * @return the mr value */ -size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(void); +size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(void); /** * @brief Function to get the nr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 micro-kernel + * the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel * * @return the nr value */ -size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(void); +size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(void); /** * @brief Function to get the kr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 micro-kernel + * the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel * * @return the kr value */ -size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(void); +size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(void); /** * @brief Function to get the sr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 micro-kernel + * the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel * * @return the sr value */ -size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(void); +size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(void); /** * @brief Function to calculate the offset in bytes for the packed LHS matrix, @@ -76,7 +76,7 @@ size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(voi * * return the offset in bytes to the packed LHS matrix */ -size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(size_t m_idx, size_t k); +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(size_t m_idx, size_t k); /** * @brief Function to calculate the offset in bytes for the packed RHS matrix, @@ -87,7 +87,7 @@ size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_n * * return the offset in bytes to the packed RHS matrix */ -size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(size_t n_idx, size_t k); +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(size_t n_idx, size_t k); /** * @brief Function to calculate the offset in bytes for the DST matrix @@ -98,7 +98,7 @@ size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_n * * return the DST offset in bytes */ -size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod( +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod( size_t m_idx, size_t n_idx, size_t dst_stride); /** @@ -107,13 +107,13 @@ size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dot * @param[in] m Number of rows in the destination (DST) matrix * @param[in] n Number of columns in the 
destination (DST) matrix
 */
-size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(size_t m, size_t n);
+size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(size_t m, size_t n);

 /**
 * @brief Function to run the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation.
 *
 * LHS matrix: Signed 8-bit quantized asymmetric per-row (qai8dx) and packed
- * RHS matrix: Unsigned 4-bit quantized symmetric per-channel (qsu4cx) and packed.
+ * RHS matrix: Signed 4-bit quantized symmetric per-channel (qsi4cx) and packed.
 * Output tile: (rows x cols) = 1 x 8
 * Accumulation performed in a single for loop: 64
 * Instruction used: dotprod
 *
@@ -126,14 +126,14 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotpr
 * by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs
 * both the dynamic quantization to 8-bit and activation packing in a single step.
 * @param[in] rhs_packed The RHS matrix packed, which is obtained by calling @ref
- * kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0
+ * kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0
 * @param[out] dst Result of the matrix multiplication
 * @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix.
 * @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(float)
 * @param[in] scalar_min Min value used to clamp the final result.
 * @param[in] scalar_max Max value used to clamp the final result.
 */
-void kai_run_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(
+void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(
 size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row,
 size_t dst_stride_col, float scalar_min, float scalar_max);

diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.c b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.c
similarity index 92%
rename from src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.c
rename to src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.c
index 9fb104d8..0bdd4f9e 100644
--- a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.c
+++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.c
@@ -3,7 +3,7 @@
 //
 // SPDX-License-Identifier: Apache-2.0
 //
-#include "kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.h"
+#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.h"

 #include
 #include

@@ -44,43 +44,43 @@ inline static size_t kai_rhs_packed_stride(size_t k) {
 return kai_nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs);
 }

-size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(void) {
+size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(void) {
 return kai_m_step;
 }

-size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(void) {
+size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(void) {
 return kai_n_step;
 }

-size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(void) {
+size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(void) {
 return
kai_mr; } -size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(void) { +size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(void) { return kai_nr; } -size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(void) { +size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(void) { return kai_kr; } -size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(void) { +size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(void) { return kai_sr; } -size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(size_t m_idx, size_t k) { +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(size_t m_idx, size_t k) { KAI_ASSERT((m_idx % kai_m_step) == 0); return (m_idx / kai_m_step) * kai_lhs_packed_stride(k); } -size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(size_t n_idx, size_t k) { +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(size_t n_idx, size_t k) { KAI_ASSERT((n_idx % kai_n_step) == 0); return (n_idx / kai_n_step) * kai_rhs_packed_stride(k); } -size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm( +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm( size_t m_idx, size_t n_idx, size_t dst_stride) { KAI_ASSERT((m_idx % kai_m_step) == 0); KAI_ASSERT((n_idx % kai_n_step) == 0); @@ -88,11 +88,11 @@ size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8m return (n_idx * sizeof(float)) + m_idx * dst_stride; } -size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(size_t m, size_t n) { +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(size_t m, size_t n) { return m * n * sizeof(float); } -void kai_run_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm( +void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm( size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { #if defined(__ARM_FEATURE_MATMUL_INT8) diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.h b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.h similarity index 82% rename from src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.h rename to src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.h index 0f339b33..782fa28e 100644 --- a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm.h +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.h @@ -22,7 +22,7 @@ extern "C" { * * @return the m step value */ -size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(void); +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(void); /** * @brief Function to get the n step value. 
@@ -31,7 +31,7 @@ size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(vo * * @return the n step */ -size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(void); +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(void); /** * @brief Function to get the mr value, which must be used to pack the LHS matrix with @@ -39,31 +39,31 @@ size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(vo * * @return the mr value */ -size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(void); +size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(void); /** * @brief Function to get the nr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 micro-kernel + * the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel * * @return the nr value */ -size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(void); +size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(void); /** * @brief Function to get the kr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 micro-kernel + * the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel * * @return the kr value */ -size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(void); +size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(void); /** * @brief Function to get the sr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 micro-kernel + * the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel * * @return the sr value */ -size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(void); +size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(void); /** * @brief Function to calculate the offset in bytes for the packed LHS matrix, @@ -76,7 +76,7 @@ size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(void); * * return the offset in bytes to the packed LHS matrix */ -size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(size_t m_idx, size_t k); +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(size_t m_idx, size_t k); /** * @brief Function to calculate the offset in bytes for the packed RHS matrix, @@ -87,7 +87,7 @@ size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_n * * return the offset in bytes to the packed RHS matrix */ -size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(size_t n_idx, size_t k); +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(size_t n_idx, size_t k); /** * @brief Function to calculate the offset in bytes for the DST matrix @@ -98,7 +98,7 @@ size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_n * * return the DST offset in bytes */ -size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm( +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm( size_t m_idx, size_t n_idx, size_t dst_stride); /** @@ -107,13 +107,13 @@ size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8m * @param[in] m Number of rows in the destination (DST) matrix. * @param[in] n Number of columns in the destination (DST) matrix. 
*/
-size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(size_t m, size_t n);
+size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(size_t m, size_t n);

 /**
 * @brief Function to run the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation.
 *
 * LHS matrix: Signed 8-bit quantized asymmetric per-row (qai8dx) and packed
- * RHS matrix: Unsigned 4-bit quantized symmetric per-channel (qsu4cx) and packed.
+ * RHS matrix: Signed 4-bit quantized symmetric per-channel (qsi4cx) and packed.
 * Output tile: (rows x cols) = 4 x 4
 * Accumulation performed in a single for loop: 32
 * Instruction used: i8mm
 *
@@ -126,14 +126,14 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(
 * by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs
 * both the dynamic quantization to 8-bit and activation packing in a single step.
 * @param[in] rhs_packed The RHS matrix packed, which is obtained by calling @ref
- * kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0
+ * kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0
 * @param[out] dst Result of the matrix multiplication
 * @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix.
 * @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(float)
 * @param[in] scalar_min Min value used to clamp the final result.
 * @param[in] scalar_max Max value used to clamp the final result.
 */
-void kai_run_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_4x4x32_neon_i8mm(
+void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(
 size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row,
 size_t dst_stride_col, float scalar_min, float scalar_max);

diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.c b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c
similarity index 95%
rename from src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.c
rename to src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c
index 1e11887e..2c348c37 100644
--- a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.c
+++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c
@@ -3,7 +3,7 @@
 //
 // SPDX-License-Identifier: Apache-2.0
 //
-#include "kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.h"
+#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h"

 #include
 #include

@@ -44,43 +44,43 @@ inline static size_t kai_rhs_packed_stride(size_t k) {
 return kai_nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs);
 }

-size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(void) {
+size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void) {
 return kai_m_step;
 }

-size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(void) {
+size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void) {
 return kai_n_step;
 }

-size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(void) {
+size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void) {
 return kai_mr;
 }

-size_t
kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(void) { +size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void) { return kai_nr; } -size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(void) { +size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void) { return kai_kr; } -size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(void) { +size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void) { return kai_sr; } -size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(size_t m_idx, size_t k) { +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(size_t m_idx, size_t k) { KAI_ASSERT((m_idx % kai_m_step) == 0); return (m_idx / kai_m_step) * kai_lhs_packed_stride(k); } -size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(size_t n_idx, size_t k) { +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(size_t n_idx, size_t k) { KAI_ASSERT((n_idx % kai_n_step) == 0); return (n_idx / kai_n_step) * kai_rhs_packed_stride(k); } -size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm( +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm( size_t m_idx, size_t n_idx, size_t dst_stride) { KAI_ASSERT((m_idx % kai_m_step) == 0); KAI_ASSERT((n_idx % kai_n_step) == 0); @@ -88,11 +88,11 @@ size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8m return (n_idx * sizeof(float)) + m_idx * dst_stride; } -size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(size_t m, size_t n) { +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(size_t m, size_t n) { return m * n * sizeof(float); } -void kai_run_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm( +void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm( size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { #if defined(__ARM_FEATURE_MATMUL_INT8) diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.h b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h similarity index 82% rename from src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.h rename to src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h index e09d5abd..1f350fe9 100644 --- a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm.h +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h @@ -22,7 +22,7 @@ extern "C" { * * @return the m step value */ -size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(void); +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void); /** * @brief Function to get the n step value. 
@@ -31,7 +31,7 @@ size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(vo * * @return the n step */ -size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(void); +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void); /** * @brief Function to get the mr value, which must be used to pack the LHS matrix with @@ -39,31 +39,31 @@ size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(vo * * @return the mr value */ -size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(void); +size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void); /** * @brief Function to get the nr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 micro-kernel + * the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel * * @return the nr value */ -size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(void); +size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void); /** * @brief Function to get the kr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 micro-kernel + * the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel * * @return the kr value */ -size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(void); +size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void); /** * @brief Function to get the sr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 micro-kernel + * the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel * * @return the sr value */ -size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(void); +size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void); /** * @brief Function to calculate the offset in bytes for the packed LHS matrix, @@ -76,7 +76,7 @@ size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(void); * * return the offset in bytes to the packed LHS matrix */ -size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(size_t m_idx, size_t k); +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(size_t m_idx, size_t k); /** * @brief Function to calculate the offset in bytes for the packed RHS matrix, @@ -87,7 +87,7 @@ size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_n * * return the offset in bytes to the packed RHS matrix */ -size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(size_t n_idx, size_t k); +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(size_t n_idx, size_t k); /** * @brief Function to calculate the offset in bytes for the DST matrix @@ -98,7 +98,7 @@ size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_n * * return the DST offset in bytes */ -size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm( +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm( size_t m_idx, size_t n_idx, size_t dst_stride); /** @@ -107,13 +107,13 @@ size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8m * @param[in] m Number of rows in the destination (DST) matrix. * @param[in] n Number of columns in the destination (DST) matrix. 
*/
-size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(size_t m, size_t n);
+size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(size_t m, size_t n);

 /**
 * @brief Function to run the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation.
 *
 * LHS matrix: Signed 8-bit quantized asymmetric per-row (qai8dx) and packed
- * RHS matrix: Unsigned 4-bit quantized symmetric per-channel (qsu4cx) and packed.
+ * RHS matrix: Signed 4-bit quantized symmetric per-channel (qsi4cx) and packed.
 * Output tile: (rows x cols) = 8 x 4
 * Accumulation performed in a single for loop: 32
 * Instruction used: i8mm
 *
@@ -126,14 +126,14 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(
 * by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs
 * both the dynamic quantization to 8-bit and activation packing in a single step.
 * @param[in] rhs_packed The RHS matrix packed, which is obtained by calling @ref
- * kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0
+ * kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0
 * @param[out] dst Result of the matrix multiplication
 * @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix.
 * @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(float)
 * @param[in] scalar_min Min value used to clamp the final result.
 * @param[in] scalar_max Max value used to clamp the final result.
 */
-void kai_run_matmul_clamp_f32_qai8dxp4x8_qsu4cxp4x8_8x4x32_neon_i8mm(
+void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(
 size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row,
 size_t dst_stride_col, float scalar_min, float scalar_max);

diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm.c b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.c
similarity index 94%
rename from src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm.c
rename to src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.c
index 00d1ed81..2704c853 100644
--- a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm.c
+++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.c
@@ -3,7 +3,7 @@
 //
 // SPDX-License-Identifier: Apache-2.0
 //
-#include "kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm.h"
+#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.h"

 #include
 #include

@@ -44,43 +44,43 @@ inline static size_t kai_rhs_packed_stride(size_t k) {
 return kai_nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs);
 }

-size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(void) {
+size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(void) {
 return kai_m_step;
 }

-size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(void) {
+size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(void) {
 return kai_n_step;
 }

-size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(void) {
+size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(void) {
 return kai_mr;
 }

-size_t
kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(void) { +size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(void) { return kai_nr; } -size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(void) { +size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(void) { return kai_kr; } -size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(void) { +size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(void) { return kai_sr; } -size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(size_t m_idx, size_t k) { +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(size_t m_idx, size_t k) { KAI_ASSERT((m_idx % kai_m_step) == 0); return (m_idx / kai_m_step) * kai_lhs_packed_stride(k); } -size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(size_t n_idx, size_t k) { +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(size_t n_idx, size_t k) { KAI_ASSERT((n_idx % kai_n_step) == 0); return (n_idx / kai_n_step) * kai_rhs_packed_stride(k); } -size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm( +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm( size_t m_idx, size_t n_idx, size_t dst_stride) { KAI_ASSERT((m_idx % kai_m_step) == 0); KAI_ASSERT((n_idx % kai_n_step) == 0); @@ -88,11 +88,11 @@ size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8m return (n_idx * sizeof(float)) + m_idx * dst_stride; } -size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(size_t m, size_t n) { +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(size_t m, size_t n) { return m * n * sizeof(float); } -void kai_run_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm( +void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm( size_t m, size_t n, size_t k, const void* restrict lhs_packed, const void* restrict rhs_packed, float* restrict dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { #if defined(__ARM_FEATURE_MATMUL_INT8) diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm.h b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.h similarity index 82% rename from src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm.h rename to src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.h index a043cf0e..6e6363c7 100644 --- a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm.h +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.h @@ -22,7 +22,7 @@ extern "C" { * * @return the m step value */ -size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(void); +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(void); /** * @brief Function to get the n step value. 
@@ -31,7 +31,7 @@ size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(vo * * @return the n step */ -size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(void); +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(void); /** * @brief Function to get the mr value, which must be used to pack the LHS matrix with @@ -39,31 +39,31 @@ size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(vo * * @return the mr value */ -size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(void); +size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(void); /** * @brief Function to get the nr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 micro-kernel + * the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel * * @return the nr value */ -size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(void); +size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(void); /** * @brief Function to get the kr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 micro-kernel + * the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel * * @return the kr value */ -size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(void); +size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(void); /** * @brief Function to get the sr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 micro-kernel + * the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel * * @return the sr value */ -size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(void); +size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(void); /** * @brief Function to calculate the offset in bytes for the packed LHS matrix, @@ -76,7 +76,7 @@ size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(void); * * return the offset in bytes to the packed LHS matrix */ -size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(size_t m_idx, size_t k); +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(size_t m_idx, size_t k); /** * @brief Function to calculate the offset in bytes for the packed RHS matrix, @@ -87,7 +87,7 @@ size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_n * * return the offset in bytes to the packed RHS matrix */ -size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(size_t n_idx, size_t k); +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(size_t n_idx, size_t k); /** * @brief Function to calculate the offset in bytes for the DST matrix @@ -98,7 +98,7 @@ size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_n * * return the DST offset in bytes */ -size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm( +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm( size_t m_idx, size_t n_idx, size_t dst_stride); /** @@ -107,13 +107,13 @@ size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8m * @param[in] m Number of rows in the destination (DST) matrix. * @param[in] n Number of columns in the destination (DST) matrix. 
*/
-size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(size_t m, size_t n);
+size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(size_t m, size_t n);

 /**
 * @brief Function to run the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation.
 *
 * LHS matrix: Signed 8-bit quantized asymmetric per-row (qai8dx) and packed
- * RHS matrix: Unsigned 4-bit quantized symmetric per-channel (qsu4cx) and packed.
+ * RHS matrix: Signed 4-bit quantized symmetric per-channel (qsi4cx) and packed.
 * Output tile: (rows x cols) = 4 x 8
 * Accumulation performed in a single for loop: 32
 * Instruction used: i8mm
 *
@@ -126,14 +126,14 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(
 * by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs
 * both the dynamic quantization to 8-bit and activation packing in a single step.
 * @param[in] rhs_packed The RHS matrix packed, which is obtained by calling @ref
- * kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0
+ * kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0
 * @param[out] dst Result of the matrix multiplication
 * @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix.
 * @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(float)
 * @param[in] scalar_min Min value used to clamp the final result.
 * @param[in] scalar_max Max value used to clamp the final result.
 */
-void kai_run_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_4x8x32_neon_i8mm(
+void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(
 size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row,
 size_t dst_stride_col, float scalar_min, float scalar_max);

diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.c b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.c
similarity index 97%
rename from src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.c
rename to src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.c
index 391efd96..6a40b814 100644
--- a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.c
+++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.c
@@ -3,7 +3,7 @@
 //
 // SPDX-License-Identifier: Apache-2.0
 //
-#include "kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.h"
+#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h"

 #include
 #include

@@ -44,43 +44,43 @@ inline static size_t kai_rhs_packed_stride(size_t k) {
 return kai_nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs);
 }

-size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(void) {
+size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm(void) {
 return kai_m_step;
 }

-size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(void) {
+size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm(void) {
 return kai_n_step;
 }

-size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(void) {
+size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm(void) {
 return kai_mr;
 }

-size_t
kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(void) { +size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm(void) { return kai_nr; } -size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(void) { +size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm(void) { return kai_kr; } -size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(void) { +size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm(void) { return kai_sr; } -size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(size_t m_idx, size_t k) { +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm(size_t m_idx, size_t k) { KAI_ASSERT((m_idx % kai_m_step) == 0); return (m_idx / kai_m_step) * kai_lhs_packed_stride(k); } -size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(size_t n_idx, size_t k) { +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm(size_t n_idx, size_t k) { KAI_ASSERT((n_idx % kai_n_step) == 0); return (n_idx / kai_n_step) * kai_rhs_packed_stride(k); } -size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm( +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm( size_t m_idx, size_t n_idx, size_t dst_stride) { KAI_ASSERT((m_idx % kai_m_step) == 0); KAI_ASSERT((n_idx % kai_n_step) == 0); @@ -88,11 +88,11 @@ size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8m return (n_idx * sizeof(float)) + m_idx * dst_stride; } -size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(size_t m, size_t n) { +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm(size_t m, size_t n) { return m * n * sizeof(float); } -void kai_run_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm( +void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm( size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { #if defined(__ARM_FEATURE_MATMUL_INT8) diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.h b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h similarity index 82% rename from src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.h rename to src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h index 12c5aca2..5d2e2d59 100644 --- a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm.h +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h @@ -22,7 +22,7 @@ extern "C" { * * @return the m step value */ -size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(void); +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm(void); /** * @brief Function to get the n step value. 
@@ -31,7 +31,7 @@ size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(vo * * @return the n step */ -size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(void); +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm(void); /** * @brief Function to get the mr value, which must be used to pack the LHS matrix with @@ -39,31 +39,31 @@ size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(vo * * @return the mr value */ -size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(void); +size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm(void); /** * @brief Function to get the nr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 micro-kernel + * the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel * * @return the nr value */ -size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(void); +size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm(void); /** * @brief Function to get the kr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 micro-kernel + * the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel * * @return the kr value */ -size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(void); +size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm(void); /** * @brief Function to get the sr value, which must be used to pack the RHS matrix with - * the @ref kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0 micro-kernel + * the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel * * @return the sr value */ -size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(void); +size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm(void); /** * @brief Function to calculate the offset in bytes for the packed LHS matrix, @@ -76,7 +76,7 @@ size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(void); * * return the offset in bytes to the packed LHS matrix */ -size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(size_t m_idx, size_t k); +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm(size_t m_idx, size_t k); /** * @brief Function to calculate the offset in bytes for the packed RHS matrix, @@ -87,7 +87,7 @@ size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_n * * return the offset in bytes to the packed RHS matrix */ -size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(size_t n_idx, size_t k); +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm(size_t n_idx, size_t k); /** * @brief Function to calculate the offset in bytes for the DST matrix @@ -98,7 +98,7 @@ size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_n * * return the DST offset in bytes */ -size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm( +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm( size_t m_idx, size_t n_idx, size_t dst_stride); /** @@ -107,13 +107,13 @@ size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8m * @param[in] m Number of rows in the destination (DST) matrix. * @param[in] n Number of columns in the destination (DST) matrix. 
*/
-size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(size_t m, size_t n);
+size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm(size_t m, size_t n);

 /**
 * @brief Function to run the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation.
 *
 * LHS matrix: Signed 8-bit quantized asymmetric per-row (qai8dx) and packed
- * RHS matrix: Unsigned 4-bit quantized symmetric per-channel (qsi4cx) and packed.
+ * RHS matrix: Signed 4-bit quantized symmetric per-channel (qsi4cx) and packed.
 * Output tile: (rows x cols) = 8 x 8
 * Accumulation performed in a single for loop: 32
 * Instruction used: i8mm
 *
@@ -126,14 +126,14 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(
 * by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs
 * both the dynamic quantization to 8-bit and activation packing in a single step.
 * @param[in] rhs_packed The RHS matrix packed, which is obtained by calling @ref
- * kai_run_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0
+ * kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0
 * @param[out] dst Result of the matrix multiplication
 * @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix.
 * @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(float)
 * @param[in] scalar_min Min value used to clamp the final result.
 * @param[in] scalar_max Max value used to clamp the final result.
 */
-void kai_run_matmul_clamp_f32_qai8dxp4x8_qsu4cxp8x8_8x8x32_neon_i8mm(
+void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm(
 size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row,
 size_t dst_stride_col, float scalar_min, float scalar_max);

diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp_qsi4cxp_interface.h b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp_qsi4cxp_interface.h
new file mode 100644
index 00000000..d6f2c5c2
--- /dev/null
+++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp_qsi4cxp_interface.h
@@ -0,0 +1,52 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#pragma once
+
+#include <stddef.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// All micro-kernel variants of the same type share the same interface
+// In this case, the micro-kernel type is: matmul_clamp_f32_qai8dxp_qsi4cxp
+
+/** Micro-kernel helper functions ("get" methods) */
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsi4cxp_get_m_step_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsi4cxp_get_n_step_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsi4cxp_get_mr_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsi4cxp_get_nr_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsi4cxp_get_kr_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsi4cxp_get_sr_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsi4cxp_get_lhs_packed_offset_func_t)(size_t m_idx, size_t k);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsi4cxp_get_rhs_packed_offset_func_t)(size_t n_idx, size_t k);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsi4cxp_get_dst_offset_func_t)(
+ size_t m_idx, size_t n_idx, size_t dst_stride);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsi4cxp_get_dst_size_func_t)(size_t m, size_t n);
+
+/** Micro-kernel core function ("run"
method) */
+typedef void (*kai_matmul_clamp_f32_qai8dxp_qsi4cxp_run_matmul_func_t)(
+ size_t m, size_t n, size_t k, const void* lhs_p, const void* rhs_p, float* dst, size_t dst_stride_row,
+ size_t dst_stride_col, float scalar_min, float scalar_max);
+
+/** Micro-kernel interface */
+struct kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel {
+ kai_matmul_clamp_f32_qai8dxp_qsi4cxp_get_m_step_func_t get_m_step;
+ kai_matmul_clamp_f32_qai8dxp_qsi4cxp_get_n_step_func_t get_n_step;
+ kai_matmul_clamp_f32_qai8dxp_qsi4cxp_get_mr_func_t get_mr;
+ kai_matmul_clamp_f32_qai8dxp_qsi4cxp_get_nr_func_t get_nr;
+ kai_matmul_clamp_f32_qai8dxp_qsi4cxp_get_kr_func_t get_kr;
+ kai_matmul_clamp_f32_qai8dxp_qsi4cxp_get_sr_func_t get_sr;
+ kai_matmul_clamp_f32_qai8dxp_qsi4cxp_get_lhs_packed_offset_func_t get_lhs_packed_offset;
+ kai_matmul_clamp_f32_qai8dxp_qsi4cxp_get_rhs_packed_offset_func_t get_rhs_packed_offset;
+ kai_matmul_clamp_f32_qai8dxp_qsi4cxp_get_dst_offset_func_t get_dst_offset;
+ kai_matmul_clamp_f32_qai8dxp_qsi4cxp_get_dst_size_func_t get_dst_size;
+ kai_matmul_clamp_f32_qai8dxp_qsi4cxp_run_matmul_func_t run_matmul;
+};
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp_qsu4cxp_interface.h b/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp_qsu4cxp_interface.h
deleted file mode 100644
index 0365a34f..00000000
--- a/src/matmul/matmul_clamp_f32_qai8dxp_qsu4cxp/kai_matmul_clamp_f32_qai8dxp_qsu4cxp_interface.h
+++ /dev/null
@@ -1,52 +0,0 @@
-//
-// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
-//
-// SPDX-License-Identifier: Apache-2.0
-//
-#pragma once
-
-#include <stddef.h>
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-// All micro-kernels variants of the same type share the same interfaces
-// In this case, the micro-kernel type is: matmul_clamp_f32_qai8dxp_qsu4cxp
-
-/** Micro-kernel helper functions ("get" methods) */
-typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsu4cxp_get_m_step_func_t)(void);
-typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsu4cxp_get_n_step_func_t)(void);
-typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsu4cxp_get_mr_func_t)(void);
-typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsu4cxp_get_nr_func_t)(void);
-typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsu4cxp_get_kr_func_t)(void);
-typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsu4cxp_get_sr_func_t)(void);
-typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsu4cxp_get_lhs_packed_offset_func_t)(size_t m_idx, size_t k);
-typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsu4cxp_get_rhs_packed_offset_func_t)(size_t n_idx, size_t k);
-typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsu4cxp_get_dst_offset_func_t)(
- size_t m_idx, size_t n_idx, size_t dst_stride);
-typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsu4cxp_get_dst_size_func_t)(size_t m, size_t n);
-
-/** Micro-kernel core function ("run" method) */
-typedef void (*kai_matmul_clamp_f32_qai8dxp_qsu4cxp_run_matmul_func_t)(
- size_t m, size_t n, size_t k, const void* lhs_p, const void* rhs_p, float* dst, size_t dst_stride_row,
- size_t dst_stride_col, float scalar_min, float scalar_max);
-
-/** Micro-kernel interface */
-struct kai_matmul_clamp_f32_qai8dxp_qsu4cxp_ukernel {
- kai_matmul_clamp_f32_qai8dxp_qsu4cxp_get_m_step_func_t get_m_step;
- kai_matmul_clamp_f32_qai8dxp_qsu4cxp_get_n_step_func_t get_n_step;
- kai_matmul_clamp_f32_qai8dxp_qsu4cxp_get_mr_func_t get_mr;
- kai_matmul_clamp_f32_qai8dxp_qsu4cxp_get_nr_func_t get_nr;
-
kai_matmul_clamp_f32_qai8dxp_qsu4cxp_get_nr_func_t get_kr;
- kai_matmul_clamp_f32_qai8dxp_qsu4cxp_get_sr_func_t get_sr;
- kai_matmul_clamp_f32_qai8dxp_qsu4cxp_get_lhs_packed_offset_func_t get_lhs_packed_offset;
- kai_matmul_clamp_f32_qai8dxp_qsu4cxp_get_rhs_packed_offset_func_t get_rhs_packed_offset;
- kai_matmul_clamp_f32_qai8dxp_qsu4cxp_get_dst_offset_func_t get_dst_offset;
- kai_matmul_clamp_f32_qai8dxp_qsu4cxp_get_dst_size_func_t get_dst_size;
- kai_matmul_clamp_f32_qai8dxp_qsu4cxp_run_matmul_func_t run_matmul;
-};
-
-#ifdef __cplusplus
-}
-#endif
--
GitLab


From 095000a5ecdc906962ea45282ca123fa87d8f2b2 Mon Sep 17 00:00:00 2001
From: Gian Marco Iodice
Date: Thu, 23 May 2024 14:39:19 +0100
Subject: [PATCH 14/14] Update the README.md file

Signed-off-by: Gian Marco Iodice
---
 README.md | 22 ++++++++++++----------
 1 file changed, 12 insertions(+), 10 deletions(-)

diff --git a/README.md b/README.md
index be2008d4..3dd80cc4 100644
--- a/README.md
+++ b/README.md
@@ -44,19 +44,21 @@ A micro-kernel exists for different Arm® architectures, technologies, and compu
 Some of the key features of KleidiAI are the following:

 - No dependencies on external libraries
-<<<<<<< HEAD
-- No internal memory allocation
-- No internal threading mechanisms
-- Stateless, stable, and consistent API
-=======
+- No dynamic memory allocation
+- No memory management
+- No scheduling
+- Stateless, stable, and consistent API
->>>>>>> Refactor file/function names
+- Performance-critical compute-bound and memory-bound micro-kernels
+- Specialized micro-kernels utilizing different Arm® CPU architectural features (for example, FEAT_DotProd and FEAT_I8MM)
+- Specialized micro-kernels for different fusion patterns
+- Micro-kernel as a standalone library, consisting of only a .c and .h file pair

 > ℹ️ The micro-kernel API is designed to be as generic as possible for integration into third-party runtimes.

@@ -98,7 +100,7 @@ Some of the data types currently supported with the KleidiAI library are the fol
 | Data type | Abbreviation | Notes |
 | ----------- | ----------- | ----------- |
 | Floating-point 32-bit | f32 | |
-| Quantized (q) Symmetric (s) Unsigned (u) 4-bit (4) Per-Channel (cx) quantization parameters | qsu4cx | An fp32 multiplier shared among all values of the same channel. `x` denotes the entirety of the channel |
+| Quantized (q) Symmetric (s) Signed (i) 4-bit (4) Per-Channel (cx) quantization parameters | qsi4cx | An fp32 multiplier shared among all values of the same channel. `x` denotes the entirety of the channel |
 | Quantized (q) Asymmetric (a) Signed (i) 8-bit (8) Per-Dimension (dx) (for example, Per-Row) quantization parameters | qai8dx | An fp32 multiplier and an int32 zero offset shared among all values of the same dimension. |

 > ℹ️ In some cases, we may append the letter `p` to the data type to specify that the tensor is expected to be packed. A packed tensor is a tensor that has been rearranged in our preferred data layout from the original data layout to improve the performance of the micro-kernel. In addition to the letter `p`, we may append other alphanumerical values to specify the attributes of the data packing (for example, the block packing size).

@@ -115,17 +117,17 @@ Some of the data types currently supported with the KleidiAI library are the fol

 Matrix-multiplication with LHS packed and RHS packed matrices
- matmul_clamp_f32_qai8dxp_qsu4cxp
+ matmul_clamp_f32_qai8dxp_qsi4cxp
 LHS: qai8dxp
- RHS: qsu4cxp
+ RHS: qsi4cxp
DST: f32
TensorFlow Lite
- The packing function for the RHS matrix is available in the `kai_rhs_pack_nxk_qsu4cxp_qsu4cxs1s0.c/.h` files.
+ The packing function for the RHS matrix is available in the `kai_rhs_pack_nxk_qsi4cxp_qsi4cxs1s0.c/.h` files.
Since the RHS matrix often contains constant values, we recommend packing the RHS matrix only once and freeing the content of the original RHS matrix.
-- GitLab