diff --git a/.editorconfig b/.editorconfig index 6b724e49e13c4aa8408f45f4b01e2dba9a0744fc..d5ad33f905dace2c9d0e2cb391060edf08d17eb0 100644 --- a/.editorconfig +++ b/.editorconfig @@ -20,6 +20,10 @@ trim_trailing_whitespace = true [*.{json,yml,yaml}] indent_size = 2 +# Override settings. +[*.{c,cpp,h,hpp}] +indent_size = unset + # Override settings. [LICENSES/*] indent_size = unset diff --git a/CMakeLists.txt b/CMakeLists.txt index 6f3e82a89202a34b133b6cda5e7a000be923090a..40ab01af72f934801ef18f2619c59efd96d9279f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -56,6 +56,26 @@ set(KLEIDIAI_WARNING_FLAGS $<$:${KLEIDIAI_WARNING_FLAGS_CXX}> ) +add_library(kleidiai) + +target_sources(kleidiai PRIVATE + src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c + src/matmul/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c + src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c + src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.c + src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.c + src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c + src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.c + src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.c) + +target_include_directories(kleidiai + PRIVATE src/ +) + +target_compile_options(kleidiai + PRIVATE ${KLEIDIAI_WARNING_FLAGS} +) + if(KLEIDIAI_BUILD_TESTS) enable_testing() include(GoogleTest) diff --git a/README.md b/README.md index cc139c11e836869b3c58ae34991fc20fffbdd630..3dd80cc4a996ae3f03b1787e1437bbbc12b65775 100644 --- a/README.md +++ b/README.md @@ -29,9 +29,9 @@ For example, consider the convolution 2d operator performed through the Winograd - Matrix multiplication - Winograd output transform -Each of the preceding operations is a micro-kernel. For an example, please refer to the [first micro kernel PR](https://gitlab.arm.com/kleidi/kleidiai/-/merge_requests/2) +Each of the preceding operations is a micro-kernel. -However, why are the preceding operations not called kernels or functions? +However, why the preceding operations are not called kernels or functions instead? Because the micro-kernels are designed to give the flexibility to process also a portion of the output tensor, which is the reason why we call it micro-kernel. @@ -44,16 +44,109 @@ A micro-kernel exists for different Arm® architectures, technologies, and compu Some of the key features of KleidiAI are the following: - No dependencies on external libraries -- No internal memory allocation -- No internal threading mechanisms -- Stateless, stable, and consistent API + +- No dynamic memory allocation + +- No memory management​ + +- No scheduling + +- Stateless, stable, and consistent API​ + - Performance-critical compute-bound and memory-bound micro-kernels -- Specialized micro-kernels for different Arm® CPU architectures and technologies + +- Specialized micro-kernels utilizing different Arm® CPU architectural features (for example, FEAT_DotProd and FEAT_I8MM) + - Specialized micro-kernels for different fusion patterns + - Micro-kernel as a standalone library, consisting of only a .c and .h files > ℹ️ The micro-kernel API is designed to be as generic as possible for integration into third-party runtimes. +

Currently supported Arm® CPU technologies and features

Arm® Neon™

- FEAT_DotProd is optional in Armv8.2-A and mandatory in Armv8.4-A
- FEAT_I8MM is optional in Armv8.2-A and mandatory in Armv8.6-A
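
Whether these features are visible to a translation unit can be checked at compile time through the standard ACLE feature macros; the dotprod micro-kernel sources in this patch guard their optimized path with `__ARM_FEATURE_DOTPROD` in the same way. A minimal sketch (the helper names are illustrative, not part of the library):

```c
// Minimal sketch: compile-time feature checks via ACLE macros.
// The helper names below are illustrative, not part of the KleidiAI API.
static inline int cpu_has_dotprod(void) {
#if defined(__ARM_FEATURE_DOTPROD)
    return 1;  // FEAT_DotProd instructions are available to the compiler.
#else
    return 0;
#endif
}

static inline int cpu_has_i8mm(void) {
#if defined(__ARM_FEATURE_MATMUL_INT8)
    return 1;  // FEAT_I8MM instructions are available to the compiler.
#else
    return 0;
#endif
}
```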

Filename convention

The `src/` directory is the home for all micro-kernels. The micro-kernels are grouped in separate directories based on the performed operation. For example, all the matrix-multiplication micro-kernels are held in the `matmul/` operator directory.

Inside the operator directory, you can find:

- *The common micro-kernels*, which are helper micro-kernels necessary for the correct functioning of the main ones. For example, some of these may be required for packing the input tensors.
- *The micro-kernel files*, which are held in separate sub-directories.

The name of the micro-kernel folder describes the operation performed and the data types of the destination and source tensors. The general syntax for the micro-kernel folder is as follows:

`____...`

All .c and .h file pairs in that folder are micro-kernel variants. The variants are differentiated by the computational parameters (for example, the block size), the Arm® technology (for example, Arm® Neon™), and the Arm® architecture feature exploited (for example, FEAT_DotProd). The general syntax for the micro-kernel variant is as follows:

`kai____.c/.h`

> ℹ️ These files only depend on the `kai_common.h` file.

All functions defined in the .h header file of the micro-kernel variant have the following syntax:

`kai__.c/.h`
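
As a concrete example from this patch, the variant `kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c/.h` reads as: the `matmul_clamp` operation, an `f32` destination, a packed `qai8dxp` LHS and a packed `qsi4cxp` RHS (each with its block packing attributes appended), the block sizes used by the variant (`1x4x32`), and the Arm® technology and architecture feature exploited (`neon_dotprod`).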

Data types

Some of the data types currently supported by the KleidiAI library are the following:

| Data type | Abbreviation | Notes |
| ----------- | ----------- | ----------- |
| Floating-point 32-bit | f32 | |
| Quantized (q) Symmetric (s) Signed (i) 4-bit (4) Per-Channel (cx) quantization parameters | qsi4cx | An fp32 multiplier shared among all values of the same channel. `x` denotes the entirety of the channel |
| Quantized (q) Asymmetric (a) Signed (i) 8-bit (8) Per-Dimension (dx) (for example, Per-Row) quantization parameters | qai8dx | An fp32 multiplier and an int32 zero offset shared among all values of the same dimension. |

> ℹ️ In some cases, we may append the letter `p` to the data type to specify that the tensor is expected to be packed. A packed tensor is a tensor that has been rearranged in our preferred data layout from the original data layout to improve the performance of the micro-kernel. In addition to the letter `p`, we may append other alphanumerical values to specify the attributes of the data packing (for example, the block packing size).
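
The following sketch (the helper names are illustrative, not part of the KleidiAI API) shows how a single value is recovered from each of the two quantized formats, mirroring the reference quantization code in the example added under `examples/matmul_clamp_f32_qai8dxp_qsi4cxp/`:

```c
#include <stdint.h>

// Illustrative helpers only (not library API).

// qsi4cx: symmetric 4-bit, one fp32 scale per channel, no zero offset.
static inline float dequant_qsi4cx(int8_t q /* in [-8, 7] */, float channel_scale) {
    return (float)q * channel_scale;
}

// qai8dx: asymmetric 8-bit, one fp32 scale and one int32 zero offset per row.
static inline float dequant_qai8dx(int8_t q, int32_t row_zero_point, float row_scale) {
    return (float)((int32_t)q - row_zero_point) * row_scale;
}
```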

Supported micro-kernels

| Micro-kernel | Abbreviation | Data type | Reference framework | Notes |
| ----------- | ----------- | ----------- | ----------- | ----------- |
| Matrix-multiplication with LHS packed and RHS packed matrices | matmul_clamp_f32_qai8dxp_qsi4cxp | LHS: qai8dxp<br>RHS: qsi4cxp<br>DST: f32 | TensorFlow Lite | The packing function for the RHS matrix is available in the `kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c/.h` files. Since the RHS matrix often contains constant values, we recommend packing the RHS matrix only once and freeing the content of the original RHS matrix. |
| Dynamic quantization and LHS matrix packing | kai_lhs_quant_pack_qai8dxp_f32 | SRC: f32<br>DST: qai8dx | TensorFlow Lite | |
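
The example added under `examples/matmul_clamp_f32_qai8dxp_qsi4cxp/` in this patch exercises every variant listed above. Condensed to its essentials, the call sequence looks roughly like the sketch below (the `run_one_variant` wrapper and its arguments are illustrative; the caller is assumed to provide the f32 LHS, the qsi4cx RHS, and its per-channel scales):

```c
#include <float.h>
#include <stdint.h>
#include <stdlib.h>

#include "kai_lhs_quant_pack_qai8dxp_f32.h"
#include "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h"
#include "kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h"

// Sketch: run one variant end-to-end, condensed from the example in this patch.
static void run_one_variant(
    size_t m, size_t n, size_t k, const float* lhs_f32, const uint8_t* rhs_qs4cx, const float* rhs_scales,
    float* dst_f32) {
    // Packing parameters of the selected micro-kernel variant.
    const size_t mr = kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod();
    const size_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod();
    const size_t kr = kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod();
    const size_t sr = kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod();

    uint8_t* lhs_packed = malloc(kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(m, k, mr, kr, sr));
    uint8_t* rhs_packed = malloc(kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(n, k, nr, kr, sr));

    // Pack the RHS matrix once; it typically holds constant weights.
    struct kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0_params params = {.lhs_zero_point = 1, .rhs_zero_point = 8};
    kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(1, n, k, nr, kr, sr, rhs_qs4cx, NULL, rhs_scales, rhs_packed, 0, &params);

    // Dynamically quantize and pack the LHS matrix for every new input.
    kai_run_lhs_quant_pack_qai8dxp_f32(m, k, mr, kr, sr, 0, lhs_f32, k * sizeof(float), lhs_packed);

    // Matmul micro-kernel followed by the clamp (disabled here via +/-FLT_MAX).
    kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(
        m, n, k, lhs_packed, rhs_packed, dst_f32, n * sizeof(float), sizeof(float), -FLT_MAX, FLT_MAX);

    free(lhs_packed);
    free(rhs_packed);
}
```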

Frequently Asked Questions (FAQ)

What is the difference between the Compute Library for the Arm® Architecture (ACL) and KleidiAI?

diff --git a/examples/matmul_clamp_f32_qai8dxp_qsi4cxp/CMakeLists.txt b/examples/matmul_clamp_f32_qai8dxp_qsi4cxp/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..ccc8b84008a4a2862984e476944c0c9c2e03c84e --- /dev/null +++ b/examples/matmul_clamp_f32_qai8dxp_qsi4cxp/CMakeLists.txt @@ -0,0 +1,38 @@ +# +# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +# +# SPDX-License-Identifier: Apache-2.0 +# + +cmake_minimum_required(VERSION 3.16) + +# KleidiAI include directories +include_directories( + ../../src/ + ../../src/matmul/ + ../../src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/) + +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17") + +# Files requires to build the executable +add_executable(matmul_clamp_f32_qai8dxp_qsi4cxp + matmul_clamp_f32_qai8dxp_qsi4cxp.cpp + ../../src/kai_common.h + ../../src/matmul/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h + ../../src/matmul/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c + ../../src/matmul/kai_lhs_quant_pack_qai8dxp_f32.h + ../../src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c + ../../src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp_qsi4cxp_interface.h + ../../src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h + ../../src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c + ../../src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h + ../../src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.c + ../../src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.h + ../../src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.c + ../../src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h + ../../src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c + ../../src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.h + ../../src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.c + ../../src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h + ../../src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.c) + diff --git a/examples/matmul_clamp_f32_qai8dxp_qsi4cxp/matmul_clamp_f32_qai8dxp_qsi4cxp.cpp b/examples/matmul_clamp_f32_qai8dxp_qsi4cxp/matmul_clamp_f32_qai8dxp_qsi4cxp.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0f404cf09637a62c8012432ad666bfddd668cbfd --- /dev/null +++ b/examples/matmul_clamp_f32_qai8dxp_qsi4cxp/matmul_clamp_f32_qai8dxp_qsi4cxp.cpp @@ -0,0 +1,449 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +// Include micro-kernel variants +#include +#include +#include +#include +#include +#include + +#include "kai_lhs_quant_pack_qai8dxp_f32.h" +#include "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h" +#include "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h" +#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.h" +#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h" +#include 
"kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.h" +#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h" +#include "kai_matmul_clamp_f32_qai8dxp_qsi4cxp_interface.h" +#include "kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h" + +#define INT4_MIN (-8) +#define INT4_MAX (7) + +// Micro-kernel interface +struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp { + kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel ukernel; + std::string name = {}; +}; + +kai_matmul_ukernel_f32_qa8dxp_qs4cxp ukernel_variants[] = { + {kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod, + "matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod"}, + {kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + "matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod"}, + {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm, + "matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm"}, + {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + 
kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm, + "matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm"}, + {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm, + "matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm"}, + {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + "matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm"}, +}; + +// Number of micro-kernel variants stored in the array +const size_t num_ukernel_variants = sizeof(ukernel_variants) / sizeof(ukernel_variants[0]); + +static void fill_uniform_random(size_t num_rows, size_t num_cols, float* dst, size_t seed) { + std::srand(seed); + + // Fill the array with random values between -1 and 1 + for (int i = 0; i < num_rows * num_cols; i++) { + dst[i] = (float)((double)std::rand() / RAND_MAX) * 2 - 1; + } +} + +static void quant_qs4cx_f32(size_t n, size_t k, const float* rhs_f32, uint8_t* rhs_qs4cx, float* rhs_scales_f32) { + const size_t dst_stride = (k / 2) * sizeof(int8_t); + + for (size_t row_idx = 0; row_idx < n; ++row_idx) { + const float* src_ptr = rhs_f32 + row_idx * k; + + float max0 = -FLT_MAX; + float min0 = FLT_MAX; + + // Find min/max for each channel + for (size_t k_idx = 0; k_idx < k; ++k_idx) { + const float src0_0 = src_ptr[k_idx]; + + max0 = std::max(src0_0, max0); + min0 = std::min(src0_0, min0); + } + + // Maximum/minimum int8 values + const float qmin = (float)INT4_MIN; + const float qmax = (float)INT4_MAX; + + const float rmin0 = std::min(0.0f, min0); + const float rmax0 = std::max(0.0f, max0); + + const float scale0 = rmin0 
== rmax0 ? 1.f : (qmax - qmin) / (rmax0 - rmin0); + + // Reciprocal to quantize + const float recip_scale0 = scale0 ? 1.0f / scale0 : 0.0f; + + uint8_t* dst_ptr = (uint8_t*)rhs_qs4cx + row_idx * dst_stride; + + // Quantize the channels + for (size_t k_idx = 0; k_idx < k; k_idx += 2) { + const float src0_0 = src_ptr[k_idx + 0]; + const float src0_1 = src_ptr[k_idx + 1]; + + // Scale the values + int32_t v0_s32 = (int32_t)(round(src0_0 * scale0)); + int32_t v1_s32 = (int32_t)(round(src0_1 * scale0)); + + // Maximum/minimum int4 values + v0_s32 = std::max(v0_s32, INT4_MIN); + v0_s32 = std::min(v0_s32, INT4_MAX); + v1_s32 = std::max(v1_s32, INT4_MIN); + v1_s32 = std::min(v1_s32, INT4_MAX); + + int32_t v0_u8 = (uint8_t)(v0_s32 + 8); + int32_t v1_u8 = (uint8_t)(v1_s32 + 8); + + const uint8_t rhs_v0 = (v1_u8 << 4) | v0_u8; + + dst_ptr[0] = rhs_v0; + dst_ptr += sizeof(uint8_t); + } + + rhs_scales_f32[row_idx] = recip_scale0; + } +}; + +static void ref_quant_qa8dx_f32(size_t m, size_t k, const float* lhs_f32, int8_t* lhs_qa8dx) { + const size_t dst_stride = (k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t)); + + for (size_t row_idx = 0; row_idx < m; ++row_idx) { + const float* src_ptr = lhs_f32 + row_idx * k; + + float max0 = -FLT_MAX; + float min0 = FLT_MAX; + + // Find min/max for each channel + for (size_t k_idx = 0; k_idx < k; ++k_idx) { + const float src0_0 = src_ptr[k_idx]; + + max0 = std::max(src0_0, max0); + min0 = std::min(src0_0, min0); + } + + // Maximum/minimum int8 values + const float qmin = (float)INT8_MIN; + const float qmax = (float)INT8_MAX; + + const float rmin0 = std::min(0.0f, min0); + const float rmax0 = std::max(0.0f, max0); + + const float scale0 = rmin0 == rmax0 ? 1.f : (qmax - qmin) / (rmax0 - rmin0); + + // Reciprocal to quantize + const float recip_scale0 = scale0 ? 1.0f / scale0 : 0.0f; + + const float descaled_min0 = rmin0 * scale0; + const float descaled_max0 = rmax0 * scale0; + + const float zero_point_from_min_error0 = qmin + descaled_min0; + const float zero_point_from_max_error0 = qmax + descaled_max0; + + float zero_point0 = + zero_point_from_min_error0 + zero_point_from_max_error0 > 0 ? 
qmin - descaled_min0 : qmax - descaled_max0; + + zero_point0 = std::max(zero_point0, qmin); + zero_point0 = std::min(zero_point0, qmax); + + // Round to nearest integer + const int32_t nudged_zero_point0 = lrintf(zero_point0); + + int8_t* dst_ptr = (int8_t*)lhs_qa8dx + row_idx * dst_stride; + + // LHS offset at the beginning of the row + *((float*)(dst_ptr)) = recip_scale0; + dst_ptr += sizeof(float); + *((int32_t*)(dst_ptr)) = -nudged_zero_point0; + dst_ptr += sizeof(int32_t); + + // Quantize the channels + for (size_t k_idx = 0; k_idx < k; ++k_idx) { + const float src0_0 = src_ptr[k_idx]; + + // Scale the values + int32_t v0_s32 = (int32_t)(round(src0_0 * scale0)); + + v0_s32 = v0_s32 + nudged_zero_point0; + v0_s32 = std::max(v0_s32, INT8_MIN); + v0_s32 = std::min(v0_s32, INT8_MAX); + dst_ptr[0] = (int8_t)v0_s32; + dst_ptr += sizeof(int8_t); + } + } +}; + +static void ref_matmul_f32_qa8dx_qs4cx( + size_t m, size_t n, size_t k, const int8_t* lhs_qa8dx, const uint8_t* rhs_qs4cx, const float* rhs_scales_f32, + float* dst_f32, float scalar_min, float scalar_max) { + const size_t lhs_stride = k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t); + const size_t rhs_stride = (k / 2) * sizeof(uint8_t); + + for (size_t row_idx = 0; row_idx < m; ++row_idx) { + const int8_t* lhs_ptr_start = lhs_qa8dx + row_idx * lhs_stride; + for (size_t col_idx = 0; col_idx < n; ++col_idx) { + // Main f32 accumulator + int32_t iacc = 0; + + const int8_t* lhs_ptr = lhs_ptr_start; + const uint8_t* rhs_ptr = rhs_qs4cx + col_idx * rhs_stride; + + // Get the LHS quantization parameters stored at the + // beginning of each row + const float lhs_scale = *(const float*)lhs_ptr; + lhs_ptr += sizeof(float); + + const int32_t lhs_offset = *(const int32_t*)lhs_ptr; + lhs_ptr += sizeof(int32_t); + + for (size_t b = 0; b < k; b += 2) { + // Get the LHS values + const int32_t lhs_v0 = (int32_t)lhs_ptr[0]; + const int32_t lhs_v1 = (int32_t)lhs_ptr[1]; + + // Get the RHS values + const uint8_t rhs_byte = rhs_ptr[0]; + + // Unpack the RHS values + const int32_t rhs_v0 = (((int32_t)(rhs_byte & 0x0F)) - 8); + const int32_t rhs_v1 = (((int32_t)(rhs_byte >> 4)) - 8); + + iacc += lhs_v0 * rhs_v0; + iacc += lhs_v1 * rhs_v1; + iacc += lhs_offset * rhs_v0; + iacc += lhs_offset * rhs_v1; + + lhs_ptr += 2; + rhs_ptr += 1; + } + + // Get the RHS scale + const float rhs_scale = rhs_scales_f32[col_idx]; + + float main_acc = iacc * rhs_scale; + + main_acc = main_acc * lhs_scale; + + // Clamp (min-max) operation + main_acc = std::max(main_acc, scalar_min); + main_acc = std::min(main_acc, scalar_max); + + dst_f32[0] = main_acc; + dst_f32 += 1; + } + } +}; + +static bool is_output_correct(size_t num_rows, size_t num_cols, float tolerance, const float* ref, const float* act) { + bool is_valid = true; + + for (size_t i = 0; i < num_rows * num_cols; ++i) { + if (std::fabs(ref[i] - act[i]) > tolerance) { + const size_t x = i % num_cols; + const size_t y = i / num_cols; + printf("ERROR![%ld][%ld]: ref=%.5f vs. 
act=%.5f\n", y, x, ref[i], act[i]); + is_valid = false; + } + } + return is_valid; +} + +int main(int argc, char** argv) { + const size_t m = 13; + const size_t n = 17; + const size_t k = 18; + const size_t seed_lhs = 4568; + const size_t seed_rhs = seed_lhs + 4; + + const size_t lhs_native_size_f32 = m * k * sizeof(float); + const size_t rhs_native_size_f32 = n * k * sizeof(float); + const size_t rhs_native_size_qs4cx = n * (k / 2) * sizeof(uint8_t); + const size_t rhs_scales_size_f32 = n * sizeof(float); + + // Allocate the memory + uint8_t* lhs_native_mtx_f32 = new uint8_t[lhs_native_size_f32]; + uint8_t* rhs_native_mtx_f32 = new uint8_t[rhs_native_size_f32]; + uint8_t* rhs_native_mtx_qs4cx = new uint8_t[rhs_native_size_qs4cx]; + uint8_t* rhs_scales_f32 = new uint8_t[rhs_scales_size_f32]; + + fill_uniform_random(m, k, (float*)lhs_native_mtx_f32, seed_lhs); + fill_uniform_random(n, k, (float*)rhs_native_mtx_f32, seed_rhs); + + quant_qs4cx_f32(n, k, (const float*)rhs_native_mtx_f32, (uint8_t*)rhs_native_mtx_qs4cx, (float*)rhs_scales_f32); + + delete[] rhs_native_mtx_f32; + + //----------- REFERENCE IMPLEMENTATION + //------------------------------------ + //------------------------------------ + // Memory sizes for the reference implementation + // After dynamically quantized the LHS matrix, we have the scale and offset for each + // row. The scale (f32) and offset (int32) are stored at the beginning of each row + size_t lhs_ref_size_qa8dx = m * (k + sizeof(int32_t) + sizeof(float)); + size_t dst_ref_size_f32 = m * n * sizeof(float); + + uint8_t* lhs_ref_mtx_qa8dx = new uint8_t[lhs_ref_size_qa8dx]; + uint8_t* dst_ref_mtx_f32 = new uint8_t[dst_ref_size_f32]; + + ref_quant_qa8dx_f32(m, k, (const float*)lhs_native_mtx_f32, (int8_t*)lhs_ref_mtx_qa8dx); + + ref_matmul_f32_qa8dx_qs4cx( + m, n, k, (const int8_t*)lhs_ref_mtx_qa8dx, (const uint8_t*)rhs_native_mtx_qs4cx, (const float*)rhs_scales_f32, + (float*)dst_ref_mtx_f32, -FLT_MAX, FLT_MAX); + + // Remove the unnecessary buffer + delete[] lhs_ref_mtx_qa8dx; + + //----------- END REFERENCE IMPLEMENTATION + //------------------------------------ + //------------------------------------ + + //----------- MICRO-KERNELS TESTS + //------------------------------------ + //------------------------------------ + for (size_t idx_variant = 0; idx_variant < num_ukernel_variants; ++idx_variant) { + std::cout << "Testing " << ukernel_variants[idx_variant].name << std::endl; + + // Get the packing parameters + const size_t mr = ukernel_variants[idx_variant].ukernel.get_mr(); + const size_t nr = ukernel_variants[idx_variant].ukernel.get_nr(); + const size_t kr = ukernel_variants[idx_variant].ukernel.get_kr(); + const size_t sr = ukernel_variants[idx_variant].ukernel.get_sr(); + + // Get the size in bytes for the packed matrices + const size_t lhs_packed_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(m, k, mr, kr, sr); + const size_t rhs_packed_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(n, k, nr, kr, sr); + const size_t dst_size = ukernel_variants[idx_variant].ukernel.get_dst_size(m, n); + + // Allocate the matrices + uint8_t* lhs_packed_mtx_qa8dx = new uint8_t[lhs_packed_size]; + uint8_t* rhs_packed_mtx_qs4cx = new uint8_t[rhs_packed_size]; + uint8_t* dst_act_mtx_f32 = new uint8_t[dst_size]; + + // If the RHS matrix contains constant values, the packing can be performed + // only once + struct kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0_params params; + params.lhs_zero_point = 1; + params.rhs_zero_point = 8; + + // RHS packing + 
kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( + 1, n, k, nr, kr, sr, + (const uint8_t*)(rhs_native_mtx_qs4cx), // RHS + NULL, // Bias + (const float*)(rhs_scales_f32), // Scale + rhs_packed_mtx_qs4cx, // DST + 0, ¶ms); + + // LHS packing + kai_run_lhs_quant_pack_qai8dxp_f32( + m, k, mr, kr, sr, 0, (const float*)lhs_native_mtx_f32, k * sizeof(float), lhs_packed_mtx_qa8dx); + + // Matmul + { + const size_t dst_stride = n * sizeof(float); + const size_t lhs_offset = ukernel_variants[idx_variant].ukernel.get_lhs_packed_offset(0, k); + const size_t rhs_offset = ukernel_variants[idx_variant].ukernel.get_rhs_packed_offset(0, k); + const size_t dst_offset = ukernel_variants[idx_variant].ukernel.get_dst_offset(0, 0, dst_stride); + + const void* lhs_ptr = (const void*)((const char*)lhs_packed_mtx_qa8dx + lhs_offset); + const void* rhs_ptr = (const void*)((const char*)rhs_packed_mtx_qs4cx + rhs_offset); + float* dst_ptr = (float*)((uint8_t*)dst_act_mtx_f32 + dst_offset); + + ukernel_variants[idx_variant].ukernel.run_matmul( + m, n, k, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX); + } + + const bool is_valid = + is_output_correct(m, n, 0.0001f, (const float*)dst_ref_mtx_f32, (const float*)dst_act_mtx_f32); + + if (is_valid) { + printf("TEST[%ld] = PASSED\n", idx_variant); + } else { + printf("TEST[%ld] = FAILED\n", idx_variant); + } + delete[] lhs_packed_mtx_qa8dx; + delete[] rhs_packed_mtx_qs4cx; + delete[] dst_act_mtx_f32; + } + delete[] lhs_native_mtx_f32; + delete[] rhs_native_mtx_qs4cx; + delete[] rhs_scales_f32; + delete[] dst_ref_mtx_f32; +} + +//----------- END MICRO-KERNELS TESTS +//------------------------------------ +//------------------------------------ diff --git a/src/kai_common.h b/src/kai_common.h index b44dad5b9712fdb11f9ba06e876ac575b140a684..9bb2550af70a53a66df3b28692d8d674b10b8eba 100644 --- a/src/kai_common.h +++ b/src/kai_common.h @@ -3,12 +3,15 @@ // // SPDX-License-Identifier: Apache-2.0 // - #pragma once #include #include +#ifdef __cplusplus +extern "C" { +#endif + // NOLINTBEGIN(cppcoreguidelines-avoid-do-while,cppcoreguidelines-pro-type-vararg,cert-err33-c) // // * cppcoreguidelines-avoid-do-while: do-while is necessary for macros. @@ -27,6 +30,7 @@ KAI_ERROR(msg); \ } \ } while (0) + // NOLINTEND(cppcoreguidelines-avoid-do-while,cppcoreguidelines-pro-type-vararg,cert-err33-c) #define KAI_ASSERT(cond) KAI_ASSERT_MSG(cond, #cond) @@ -40,3 +44,15 @@ #define KAI_ASSUME_IF KAI_ASSERT_IF #define KAI_UNUSED(x) (void)(x) + +#define KAI_UNUSED(x) (void)(x) +#define KAI_MIN(a, b) (((a) < (b)) ? (a) : (b)) +#define KAI_MAX(a, b) (((a) > (b)) ? 
(a) : (b)) + +inline static size_t kai_roundup(size_t a, size_t b) { + return ((a + b - 1) / b) * b; +} + +#ifdef __cplusplus +} +#endif diff --git a/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c b/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c new file mode 100644 index 0000000000000000000000000000000000000000..d93469a0e71ae8f8875fb0fdfe428f505c8990ce --- /dev/null +++ b/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.c @@ -0,0 +1,176 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#include "kai_lhs_quant_pack_qai8dxp_f32.h" + +#include +#include +#include +#include + +#include "kai_common.h" + +static const size_t kai_num_bytes_per_multiplier = sizeof(float); +static const size_t kai_num_bytes_per_offset = sizeof(int32_t); + +inline static size_t kai_k_roundedup(size_t k, size_t kr, size_t sr) { + // Since we pack a float and int32 value at the end of the row, + // we must make sure that k is a multiple of 4 for memory alignment. + size_t kr_sr_roundedup4 = kai_roundup(kr * sr, 4); + return kai_roundup(k, kr_sr_roundedup4); +} + +inline static size_t kai_lhs_packed_stride(size_t k, size_t mr, size_t kr, size_t sr) { + const size_t k_internal = kai_k_roundedup(k, kr, sr); + + KAI_ASSERT((k_internal % 2) == 0); + + return mr * (k_internal * sizeof(int8_t) + kai_num_bytes_per_multiplier + kai_num_bytes_per_offset); +} + +size_t kai_get_m_step_lhs_quant_pack_qai8dxp_f32(size_t mr) { + KAI_UNUSED(mr); + return 1; +} + +size_t kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(size_t m_idx, size_t lhs_stride) { + return m_idx * lhs_stride; +} + +size_t kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) { + // It always points to the beginning of the row + return (m_idx / mr) * kai_lhs_packed_stride(k, mr, kr, sr); +} + +size_t kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(size_t m, size_t k, size_t mr, size_t kr, size_t sr) { + const size_t num_rows = kai_roundup(m, mr) / mr; + + return num_rows * kai_lhs_packed_stride(k, mr, kr, sr); +} + +void kai_run_lhs_quant_pack_qai8dxp_f32( + size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* restrict lhs, + size_t lhs_stride, void* restrict lhs_packed) { + KAI_ASSERT((kr % sr) == 0); + + if (m == 0) { + return; + } + + const size_t num_rows = m; + + const float* src_ptr = lhs; + + const size_t dst_stride = kai_lhs_packed_stride(k, mr, kr, sr); + const size_t k_internal = kai_k_roundedup(k, kr, sr); + const int32_t k_block_len = (int32_t)(kr / sr); + + for (size_t row_idx = 0; row_idx < num_rows; ++row_idx) { + float max0 = -FLT_MAX; + float min0 = FLT_MAX; + + float32x4_t vmax0 = vdupq_n_f32(-FLT_MAX); + float32x4_t vmin0 = vdupq_n_f32(FLT_MAX); + + // Find min/max for each channel + int32_t k_idx = 0; +#if defined(__aarch64__) + for (; k_idx <= ((int32_t)k - 8); k_idx += 8) { + const float32x4_t src0_0 = vld1q_f32(src_ptr + 0 + (size_t)k_idx); + const float32x4_t src0_1 = vld1q_f32(src_ptr + 4 + (size_t)k_idx); + + // Calculate the max + vmax0 = vmaxq_f32(src0_0, vmax0); + vmax0 = vmaxq_f32(vmax0, src0_1); + + // Calculate the min + vmin0 = vminq_f32(src0_0, vmin0); + vmin0 = vminq_f32(vmin0, src0_1); + } + // Get the max/min + max0 = vmaxvq_f32(vmax0); + min0 = vminvq_f32(vmin0); +#endif + for (; k_idx < (int32_t)k; ++k_idx) { + const float src0_0 = *(src_ptr + (size_t)k_idx); + max0 = KAI_MAX(src0_0, max0); + min0 = KAI_MIN(src0_0, min0); + } + + // Maximum/minimum int8 values + const 
float qmin = (float)INT8_MIN; + const float qmax = (float)INT8_MAX; + + const float rmin0 = KAI_MIN(0.0F, min0); + const float rmax0 = KAI_MAX(0.0F, max0); + + const float scale0 = rmin0 == rmax0 ? 1.F : (qmax - qmin) / (rmax0 - rmin0); + + // Reciprocal to quantize + const float recip_scale0 = scale0 ? 1.0F / scale0 : 0.0F; + + const float descaled_min0 = rmin0 * scale0; + const float descaled_max0 = rmax0 * scale0; + + const float zero_point_from_min_error0 = qmin + descaled_min0; + const float zero_point_from_max_error0 = qmax + descaled_max0; + + float zero_point0 = + zero_point_from_min_error0 + zero_point_from_max_error0 > 0 ? qmin - descaled_min0 : qmax - descaled_max0; + + zero_point0 = KAI_MAX(zero_point0, qmin); + zero_point0 = KAI_MIN(zero_point0, qmax); + + // Round to nearest integer + const int32_t nudged_zero_point0 = (int32_t)rintf(zero_point0); + + const size_t dst_x = ((row_idx + m_idx_start) % mr); + + uint8_t* dst_ptr = (uint8_t*)lhs_packed + dst_x * k_block_len * sizeof(int8_t); + + // Quantize the channels + k_idx = 0; + for (; k_idx < (int32_t)k_internal; k_idx += k_block_len) { + for (size_t k_block_idx = 0; k_block_idx < (size_t)k_block_len; ++k_block_idx) { + // Clamp at the last valid k-index + const size_t k_idx_start = KAI_MIN((size_t)k_idx + k_block_idx, k); + + const float src0_0 = *(src_ptr + k_idx_start); + + // Scale the values + int32_t v0_s32 = (int32_t)(roundf(src0_0 * scale0)); + + v0_s32 = v0_s32 + nudged_zero_point0; + v0_s32 = KAI_MAX(v0_s32, INT8_MIN); + v0_s32 = KAI_MIN(v0_s32, INT8_MAX); + *((int8_t*)(dst_ptr)) = (int8_t)v0_s32; + dst_ptr += sizeof(int8_t); + } + dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t); + } + + dst_ptr = (uint8_t*)lhs_packed + mr * (k_internal * sizeof(int8_t)); + + dst_ptr += dst_x * kai_num_bytes_per_offset; + + // LHS offset at the beginning of the row + *((int32_t*)(dst_ptr)) = -nudged_zero_point0; + + // Assuming the same sizeof() for kai_num_bytes_per_offset and kai_num_bytes_per_multiplier + KAI_ASSERT(kai_num_bytes_per_offset == kai_num_bytes_per_multiplier); + + dst_ptr += mr * kai_num_bytes_per_offset; + + // Store the scale quantization params + *((float*)(dst_ptr)) = recip_scale0; + + src_ptr += (lhs_stride / sizeof(float)); + + // Move to the next row if we have interleaved all Mr rows + if ((((row_idx + 1) + m_idx_start) % mr) == 0) { + lhs_packed = (void*)((uint8_t*)lhs_packed + dst_stride); + } + } +} diff --git a/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.h b/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.h new file mode 100644 index 0000000000000000000000000000000000000000..85715dc661f90c5e454f4c9d39d0345b2eb14fa5 --- /dev/null +++ b/src/matmul/kai_lhs_quant_pack_qai8dxp_f32.h @@ -0,0 +1,87 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * @brief Function to get the m step value. + * The micro-kernel can process any M values. However, the starting M index to + * be processed must be a multiple of m step. + * + * @param[in] mr The number of M rows to interleave on the same output row. + * + * @return the m step value + */ +size_t kai_get_m_step_lhs_quant_pack_qai8dxp_f32(size_t mr); + +/** + * @brief Function to calculate the offset in bytes for the LHS matrix (not packed) + * + * This function should be called before passing the pointer to the LHS matrix to the micro-kernel. 
+ * + * @param[in] m_idx Row index in the LHS matrix (not packed). + * @param[in] lhs_stride The number of bytes in in each row of the LHS matrix (not packed) + * + * return the offset in bytes to the LHS matrix + */ +size_t kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(size_t m_idx, size_t lhs_stride); + +/** + * @brief Function to calculate the offset in bytes for the packed LHS matrix, + * which contains the packed 8-bit quantized asymmetric per-row (qa8dx) values. + * + * This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. + * + * @param[in] m_idx Row index in the LHS matrix (not packed). + * @param[in] k Total number of columns in the LHS matrix (not packed). + * @param[in] mr The number of M rows to interleave on the same output row. + * @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. + * @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. + * + * return the offset in bytes to the packed LHS matrix + */ +size_t kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr); + +/** + * @brief Function to return the memory required for storing the quantized and packed LHS matrix + * + * @param[in] m Total number of rows in the LHS matrix (not packed). + * @param[in] k Total number of columns in the LHS matrix (not packed). + * @param[in] mr The number of M rows to interleave on the same output row. + * @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. + * @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. + * + * return the size in bytes to the packed LHS matrix + */ +size_t kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(size_t m, size_t k, size_t mr, size_t kr, size_t sr); + +/** + * @brief Micro-kernel to quantize and pack the LHS matrix + * + * @param[in] m The number of output rows written. + * @param[in] k The number of channels. The common dimension of LHS & RHS. It must be multiple of 8. + * @param[in] mr The number of M rows to interleave on the same output row. + * @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. + * @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. + * However, kr must be multiple of sr. + * @param[in] m_idx_start The starting M index. + * @param[in] lhs LHS of the vector-by-matrix. + * @param[in] lhs_stride Stride in bytes between two rows of LHS. + * @param[out] lhs_packed The quantized and packed LHS matrix. 
+ */ +void kai_run_lhs_quant_pack_qai8dxp_f32( + size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs, size_t lhs_stride, + void* lhs_packed); + +#ifdef __cplusplus +} +#endif diff --git a/src/matmul/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c b/src/matmul/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c new file mode 100644 index 0000000000000000000000000000000000000000..321ae08c58d5cdd97ea95afaf49df739d886ed07 --- /dev/null +++ b/src/matmul/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c @@ -0,0 +1,142 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#include "kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h" + +#include +#include +#include +#include + +#include "kai_common.h" + +static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); +static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); + +inline static size_t kai_k_roundedup(size_t k, size_t kr, size_t sr) { + // Since we pack a float and int32 value at the end of the row, + // we must make sure that k is a multiple of 4 for alignment + size_t kr_sr_roundedup4 = kai_roundup(kr * sr, 4); + return kai_roundup(k, kr_sr_roundedup4); +} + +inline static size_t kai_rhs_packed_stride(size_t k, size_t kr, size_t nr, size_t sr) { + const size_t k_internal = kai_k_roundedup(k, kr, sr); + + KAI_ASSERT((k_internal % 2) == 0); + + return nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs); +} + +size_t kai_get_n_step_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t nr) { + return nr; +} + +size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t n_idx, size_t rhs_stride) { + return n_idx * rhs_stride; +} + +size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( + size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr) { + KAI_ASSERT((n_idx % nr) == 0); + + return (n_idx / nr) * kai_rhs_packed_stride(k, kr, nr, sr); +} + +size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t n, size_t k, size_t nr, size_t kr, size_t sr) { + const size_t num_rows = kai_roundup(n, nr) / nr; + + return num_rows * kai_rhs_packed_stride(k, kr, nr, sr); +} + +void kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const int32_t* bias, + const float* scale, void* rhs_packed, size_t extra_bytes, + const struct kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0_params* params) { + KAI_ASSERT((k % 2) == 0); + KAI_ASSERT(num_groups == 1); + KAI_ASSERT(bias == NULL); + KAI_ASSERT(extra_bytes == 0); + KAI_ASSERT((kr % sr) == 0); + + KAI_ASSERT(rhs != NULL); + KAI_ASSERT(scale != NULL); + KAI_ASSERT(rhs_packed != NULL); + KAI_ASSERT(params != NULL); + KAI_ASSERT(params->rhs_zero_point == 8); + KAI_ASSERT(params->lhs_zero_point == 1); + + // Note: The input matrix (rhs) is expected with: + // "k" columns and "n" rows (NxK) + + const size_t rhs_zero_point = params->rhs_zero_point; + const size_t rhs_stride = k / 2; + const size_t rhs_packed_stride = kai_rhs_packed_stride(k, kr, nr, sr); + const size_t k_internal = kai_k_roundedup(k, kr, sr); + + for (size_t y = 0; y < n; y += nr) { + const uint8_t* src_row = rhs + y * rhs_stride; + uint8_t* dst_row = (uint8_t*)rhs_packed + (y / nr) * rhs_packed_stride; + + int32_t* sums = (int32_t*)(dst_row + nr * (k_internal / 2)); + + // Initialize to zero the RHS reduction sums + memset(sums, 0, nr * sizeof(int32_t)); + + for (size_t x = 0; x < k_internal; x += (kr * sr)) { + for (size_t s = 0; s < sr; ++s) { + for (size_t i = 
0; i < nr; ++i) { + for (size_t kr_idx = 0; kr_idx < kr / sr; kr_idx += 2) { + const size_t k_idx_start0 = (x / 2) + kr_idx / 2 + s * (kr / sr) / 2; + const size_t k_idx_start1 = k_idx_start0 + (kr / 2); + + const size_t src_addr_byte0 = i * rhs_stride + k_idx_start0; + const size_t src_addr_byte1 = i * rhs_stride + k_idx_start1; + + uint8_t byte0 = rhs_zero_point | rhs_zero_point << 4; + uint8_t byte1 = rhs_zero_point | rhs_zero_point << 4; + + if (k_idx_start0 < (k / 2)) { + byte0 = src_row[src_addr_byte0]; + } + + if (k_idx_start1 < (k / 2)) { + byte1 = src_row[src_addr_byte1]; + } + + const uint8_t src_x0_lo = (byte0 & 0x0F); + const uint8_t src_x1_lo = (byte0 >> 4); + + const uint8_t src_x0_hi = (byte1 & 0x0F); + const uint8_t src_x1_hi = (byte1 >> 4); + + sums[i] += (int32_t)src_x0_lo + (int32_t)src_x0_hi - 2 * (int32_t)rhs_zero_point; + sums[i] += (int32_t)src_x1_lo + (int32_t)src_x1_hi - 2 * (int32_t)rhs_zero_point; + + const uint8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4); + const uint8_t dst_qs1 = src_x1_lo | (src_x1_hi << 4); + + *dst_row = dst_qs0 ^ 0x88; + dst_row += sizeof(uint8_t); + *dst_row = dst_qs1 ^ 0x88; + dst_row += sizeof(uint8_t); + } + } + } + } + + // Adjust the reduction sums + for (size_t i = 0; i < nr; ++i) { + *((int32_t*)(dst_row)) = sums[i] * 16; + dst_row += sizeof(int32_t); + } + + // Adjust the scales + for (size_t i = 0; i < nr; ++i) { + *((float*)(dst_row)) = scale[y + i] * 0.0625F; + dst_row += sizeof(float); + } + } +} diff --git a/src/matmul/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h b/src/matmul/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h new file mode 100644 index 0000000000000000000000000000000000000000..8340c2bcc8f0cae464426a07e9a27a69d752e23f --- /dev/null +++ b/src/matmul/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h @@ -0,0 +1,101 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +struct kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0_params { + int8_t lhs_zero_point; + uint8_t rhs_zero_point; +}; + +/** + * @brief Function to get the n step value. + * The micro-kernel can process any N values. However, the starting N index to + * be processed must be a multiple of n step. + * + * @param[in] nr The number of columns written by the matmul micro-kernel + * + * @return the n step value + */ +size_t kai_get_n_step_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t nr); + +/** + * @brief Function to calculate the offset in bytes for the RHS matrix (not packed), which holds + * the int4 values in a N x K matrix, where N is number of rows and K is the number of columns. + * Two int4 values are stored in one byte. The lower order part of the byte (low) holds + * the first nibble (K-index + 0). The higher order of the byte holds the second nibble (K-index + 1). + * + * @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of n_step. + * @param[in] rhs_stride The number of bytes in in each row of the RHS matrix (not packed) + * + * return the offset in bytes to the RHS matrix (not packed) + */ +size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t n_idx, size_t rhs_stride); + +/** + * @brief Function to calculate the offset in bytes for the packed RHS matrix, + * which contains the packed 4-bit quantized symmetric per-channel (qsu4cx) values. + * + * @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of n_step. 
+ * @param[in] k The common dimension between the LHS and RHS matrix (K) + * @param[in] nr The number of columns written by the matmul micro-kernel + * @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. + * @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. + * + * return the offset in bytes to the packed RHS matrix + */ +size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( + size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr); + +/** + * @brief Function to return the memory required for storing the packed RHS matrix + * + * @param[in] n The number of rows in the RHS matrix (not packed) + * @param[in] k The number of columns in the RHS matrix (not packed). + * @param[in] nr The number of columns written by the matmul micro-kernel + * @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. + * @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. + * + * return the size in bytes to the packed RHS matrix + */ +size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t n, size_t k, size_t nr, size_t kr, size_t sr); + +/** + * @brief Micro-kernel to pack the RHS matrix. + * + * @note The int4 values are stored in a N x K matrix, where N is number of rows and K is the number of columns. + * Two int4 values are stored in one byte. The lower order part of the byte (low) holds + * the first nibble (K-index + 0). The higher order of the byte holds the second nibble (K-index + 1). + * + * @param[in] num_groups The number of groups. It must be 1. + * @param[in] n The number of columns of the output matrix (N). + * @param[in] k The common dimension between the LHS and RHS matrix (K). It must be an even value. + * @param[in] nr The number of N columns to interleave on the same output output row. + * @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. + * @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. + * However, kr must be multiple of sr. + * @param[in] rhs The RHS matrix containing the 4-bit values. + * Size in bytes is expected to be greater than or equal to n * k * (sizeof(uint8_t) / 2). + * @param[in] bias The biases. + * @param[in] scale The scale for each output channel. + * @param[out] rhs_packed The packed RHS matrix. + * @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix. + * @param[in] params Parameters for the micro-kernel. 
+ */ +void kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const int32_t* bias, + const float* scale, void* rhs_packed, size_t extra_bytes, + const struct kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0_params* params); + +#ifdef __cplusplus +} +#endif diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c new file mode 100644 index 0000000000000000000000000000000000000000..6fee9b51c7c5071524b5a02f33defb555919f7b5 --- /dev/null +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c @@ -0,0 +1,230 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#include "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h" + +#include +#include + +#include "kai_common.h" + +static const size_t kai_m_step = 1; +static const size_t kai_n_step = 4; +static const size_t kai_mr = 1; +static const size_t kai_nr = 4; +static const size_t kai_kr = 16; +static const size_t kai_sr = 2; +static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); +static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); +static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); +static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); + +inline static size_t kai_k_roundedup(size_t k) { + // Since we pack a float and int32 value at the end of the row, + // we must make sure that k is a multiple of 4 for alignment + size_t kr_sr_roundedup4 = kai_roundup(kai_kr * kai_sr, 4); + return kai_roundup(k, kr_sr_roundedup4); +} + +inline static size_t kai_lhs_packed_stride(size_t k) { + const size_t k_internal = kai_k_roundedup(k); + + KAI_ASSERT((k_internal % 2) == 0); + + return kai_mr * (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); +} + +inline static size_t kai_rhs_packed_stride(size_t k) { + const size_t k_internal = kai_k_roundedup(k); + + KAI_ASSERT((k_internal % 2) == 0); + + return kai_nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs); +} + +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void) { + return kai_m_step; +} + +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void) { + return kai_n_step; +} + +size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void) { + return kai_mr; +} + +size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void) { + return kai_nr; +} + +size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void) { + return kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(size_t m_idx, size_t k) { + KAI_ASSERT((m_idx % kai_m_step) == 0); + + return (m_idx / kai_m_step) * kai_lhs_packed_stride(k); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(size_t n_idx, size_t k) { + KAI_ASSERT((n_idx % kai_n_step) == 0); + + return (n_idx / kai_n_step) * kai_rhs_packed_stride(k); +} + +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod( + size_t m_idx, 
size_t n_idx, size_t dst_stride) { + KAI_ASSERT((m_idx % kai_m_step) == 0); + KAI_ASSERT((n_idx % kai_n_step) == 0); + + return (n_idx * sizeof(float)) + m_idx * dst_stride; +} + +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(size_t m, size_t n) { + return m * n * sizeof(float); +} + +void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod( + size_t m, size_t n, size_t k, const void* restrict lhs_packed, const void* restrict rhs_packed, float* restrict dst, + size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { +#if defined(__ARM_FEATURE_DOTPROD) + KAI_ASSERT(dst_stride_col == sizeof(float)); + + if (m == 0) { + return; + } + + const size_t kai_k0 = kai_kr * kai_sr; + + const size_t num_rows = m; + const size_t num_cols = n; + + const size_t lhs_packed_stride = kai_lhs_packed_stride(k); + const size_t k_internal = kai_k_roundedup(k); + + const int8x16_t nibble_mask = vdupq_n_s8(0xF0); + + const uint8_t* lhs_ptr_start = lhs_packed; + + for (size_t row_idx = 0; row_idx < num_rows; row_idx += kai_mr) { + const uint8_t* rhs_ptr = rhs_packed; + for (size_t col_idx = 0; col_idx < num_cols; col_idx += kai_nr) { + const uint8_t* lhs_ptr = lhs_ptr_start; + + // Main f32 accumulator + int32x4_t iacc0011 = vdupq_n_s32(0); + int32x4_t iacc2233 = vdupq_n_s32(0); + + for (size_t b = 0; b < k_internal; b += kai_k0) { + // Set up RHS + const int8x16_t rhs_raw_vec_0 = vld1q_s8((const int8_t*)(rhs_ptr + 0)); + const int8x16_t rhs_raw_vec_1 = vld1q_s8((const int8_t*)(rhs_ptr + 16)); + const int8x16_t rhs_raw_vec_2 = vld1q_s8((const int8_t*)(rhs_ptr + 32)); + const int8x16_t rhs_raw_vec_3 = vld1q_s8((const int8_t*)(rhs_ptr + 48)); + + // Low nibble + const int8x16_t rhs_vec_0_0 = vshlq_n_s8(rhs_raw_vec_0, 4); + const int8x16_t rhs_vec_1_0 = vshlq_n_s8(rhs_raw_vec_1, 4); + const int8x16_t rhs_vec_2_0 = vshlq_n_s8(rhs_raw_vec_2, 4); + const int8x16_t rhs_vec_3_0 = vshlq_n_s8(rhs_raw_vec_3, 4); + + // High nibble + const int8x16_t rhs_vec_0_1 = vandq_s8(rhs_raw_vec_0, nibble_mask); + const int8x16_t rhs_vec_1_1 = vandq_s8(rhs_raw_vec_1, nibble_mask); + const int8x16_t rhs_vec_2_1 = vandq_s8(rhs_raw_vec_2, nibble_mask); + const int8x16_t rhs_vec_3_1 = vandq_s8(rhs_raw_vec_3, nibble_mask); + + const int8x16_t lhs_vec_0 = vld1q_s8((const int8_t*)(lhs_ptr + 0)); + const int8x16_t lhs_vec_1 = vld1q_s8((const int8_t*)(lhs_ptr + 16)); + + lhs_ptr += 32; + rhs_ptr += 64; + + int8x16_t t; + + t = vcombine_s8(vget_low_s8(lhs_vec_0), vget_low_s8(lhs_vec_0)); + iacc0011 = vdotq_s32(iacc0011, rhs_vec_0_0, t); + iacc2233 = vdotq_s32(iacc2233, rhs_vec_1_0, t); + t = vcombine_s8(vget_high_s8(lhs_vec_0), vget_high_s8(lhs_vec_0)); + iacc0011 = vdotq_s32(iacc0011, rhs_vec_2_0, t); + iacc2233 = vdotq_s32(iacc2233, rhs_vec_3_0, t); + t = vcombine_s8(vget_low_s8(lhs_vec_1), vget_low_s8(lhs_vec_1)); + iacc0011 = vdotq_s32(iacc0011, rhs_vec_0_1, t); + iacc2233 = vdotq_s32(iacc2233, rhs_vec_1_1, t); + t = vcombine_s8(vget_high_s8(lhs_vec_1), vget_high_s8(lhs_vec_1)); + iacc0011 = vdotq_s32(iacc0011, rhs_vec_2_1, t); + iacc2233 = vdotq_s32(iacc2233, rhs_vec_3_1, t); + } + + int32x4_t iacc = vpaddq_s32(iacc0011, iacc2233); + + // LHS offset + const int32x4_t lhs_offset = vld1q_dup_s32((const int32_t*)lhs_ptr); + lhs_ptr += sizeof(int32_t); + + // LHS scale + const float32x4_t lhs_scale = vld1q_dup_f32((const float*)lhs_ptr); + lhs_ptr += sizeof(float); + + // RHS sum values + const int32x4_t sum_n_s32 = vld1q_s32((const int32_t*)(rhs_ptr)); + rhs_ptr += 
sizeof(int32x4_t); + + // RHS scale + const float32x4_t rhs_scale = vld1q_f32((const float*)rhs_ptr); + rhs_ptr += sizeof(float32x4_t); + + // Add the reduction sum + iacc = vmlaq_s32(iacc, sum_n_s32, lhs_offset); + + float32x4_t main_acc = vmulq_f32(vcvtq_f32_s32(iacc), rhs_scale); + + main_acc = vmulq_f32(main_acc, lhs_scale); + + // clamp (min-max) operation + const float32x4_t vmin_f32 = vdupq_n_f32(scalar_min); + const float32x4_t vmax_f32 = vdupq_n_f32(scalar_max); + + main_acc = vmaxq_f32(main_acc, vmin_f32); + main_acc = vminq_f32(main_acc, vmax_f32); + + if (col_idx + kai_nr <= n) { + vst1q_f32((float*)((uint8_t*)dst + col_idx * sizeof(float) + row_idx * dst_stride_row), main_acc); + } else { + size_t leftover = n % kai_nr; + *(float*)((uint8_t*)dst + (col_idx + 0) * sizeof(float) + row_idx * dst_stride_row) = + vgetq_lane_f32(main_acc, 0); + if (leftover > 1) { + *(float*)((uint8_t*)dst + (col_idx + 1) * sizeof(float) + row_idx * dst_stride_row) = + vgetq_lane_f32(main_acc, 1); + } + if (leftover > 2) { + *(float*)((uint8_t*)dst + (col_idx + 2) * sizeof(float) + row_idx * dst_stride_row) = + vgetq_lane_f32(main_acc, 2); + } + } + } + lhs_ptr_start += lhs_packed_stride; + } +#else + KAI_ASSERT(false); + KAI_UNUSED(m); + KAI_UNUSED(n); + KAI_UNUSED(k); + KAI_UNUSED(lhs_packed); + KAI_UNUSED(rhs_packed); + KAI_UNUSED(dst); + KAI_UNUSED(dst_stride_row); + KAI_UNUSED(dst_stride_col); + KAI_UNUSED(scalar_min); + KAI_UNUSED(scalar_max); +#endif +} diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h new file mode 100644 index 0000000000000000000000000000000000000000..558f0f4c2b8674c594412167803fe834adfe7b3f --- /dev/null +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h @@ -0,0 +1,142 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#ifndef __cplusplus +#include +#endif +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * @brief Function to get the m step value. + * The micro-kernel can process any M values. However, the starting M index to + * be processed must be a multiple of m step. + * + * @return the m step value + */ +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void); + +/** + * @brief Function to get the n step value. + * The micro-kernel can process any N values. However, the starting N index to + * be processed must be a multiple of n step. 
+ * + * @return the n step + */ +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void); + +/** + * @brief Function to get the mr value, which must be used to pack the LHS matrix with + * the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel + * + * @return the mr value + */ +size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void); + +/** + * @brief Function to get the nr value, which must be used to pack the RHS matrix with + * the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel + * + * @return the nr value + */ +size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void); + +/** + * @brief Function to get the kr value, which must be used to pack the RHS matrix with + * the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel + * + * @return the kr value + */ +size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void); + +/** + * @brief Function to get the sr value, which must be used to pack the RHS matrix with + * the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel + * + * @return the sr value + */ +size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void); + +/** + * @brief Function to calculate the offset in bytes for the packed LHS matrix, + * which contains the packed 8-bit quantized asymmetric per-row (qa8dx) values. + * + * This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. + * + * @param[in] m_idx Row index in the LHS matrix (not packed). + * @param[in] k Total number of columns in the LHS matrix (not packed). + * + * return the offset in bytes to the packed LHS matrix + */ +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(size_t m_idx, size_t k); + +/** + * @brief Function to calculate the offset in bytes for the packed RHS matrix, + * which contains the packed 4-bit quantized symmetric per-channel (qsu4cx) values. + * + * @param[in] n_idx Row index in the RHS matrix (not packed). + * @param[in] k The common dimension between the LHS and RHS matrix (K). + * + * return the offset in bytes to the packed RHS matrix + */ +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(size_t n_idx, size_t k); + +/** + * @brief Function to calculate the offset in bytes for the DST matrix + * + * @param[in] m_idx Row index in the DST matrix. + * @param[in] n_idx Column index in the DST matrix. It must be multiple of 4. + * @param[in] dst_stride The number of bytes in in each row of the DST matrix + * + * return the DST offset in bytes + */ +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod( + size_t m_idx, size_t n_idx, size_t dst_stride); + +/** + * @brief Function to query the size in bytes for the destination matrix. + * + * @param[in] m Number of rows in the destination (DST) matrix + * @param[in] n Number of columns in the destination (DST) matrix + */ +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(size_t m, size_t n); + +/** + * @brief Function to run the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation. + * + * LHS matrix: Signed 8-bit quantized asymmetric per-row (qai8dx) and packed + * RHS matrix: Signed 4-bit quantized symmetric per-channel (qsu4cx) and packed. 
+ * Output tile: (rows x cols) = 1 x 4 + * Accumulation performed in a single for loop: 64 + * Instruction used: dotprod + * + * @param[in] m The number of output rows written. + * @param[in] n The number of output columns written. + * @param[in] k The number of channels. The common dimension of LHS & RHS. + * @param[in] lhs_packed The LHS matrix packed. + * When the activation are dynamically quantized, you can obtain this matrix + * by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs + * both the dynamic quantization to 8-bit and activation packing in a single step. + * @param[in] rhs_packed The RHS matrix packed, which is obtained by calling @ref + * kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 + * @param[out] dst Result of the vector-by-matrix + * @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. + * @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(float) + * @param[in] scalar_min Min value used to clamp the final result. + * @param[in] scalar_max Max value used to clamp the final result. + */ +void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod( + size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row, + size_t dst_stride_col, float scalar_min, float scalar_max); + +#ifdef __cplusplus +} +#endif diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.c b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.c new file mode 100644 index 0000000000000000000000000000000000000000..f243f3d7244524d4cd843fbf57d5e9631ce6f27c --- /dev/null +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.c @@ -0,0 +1,282 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#include "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h" + +#include +#include + +#include "kai_common.h" + +static const size_t kai_m_step = 1; +static const size_t kai_n_step = 8; +static const size_t kai_mr = 1; +static const size_t kai_nr = 8; +static const size_t kai_kr = 16; +static const size_t kai_sr = 2; +static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); +static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); +static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); +static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); + +inline static size_t kai_k_roundedup(size_t k) { + // Since we pack a float and int32 value at the end of the row, + // we must make sure that k is a multiple of 4 for alignment + size_t kr_sr_roundedup4 = kai_roundup(kai_kr * kai_sr, 4); + return kai_roundup(k, kr_sr_roundedup4); +} + +inline static size_t kai_lhs_packed_stride(size_t k) { + const size_t k_internal = kai_k_roundedup(k); + + KAI_ASSERT((k_internal % 2) == 0); + + return kai_mr * (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); +} + +inline static size_t kai_rhs_packed_stride(size_t k) { + const size_t k_internal = kai_k_roundedup(k); + + KAI_ASSERT((k_internal % 2) == 0); + + return kai_nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs); +} + +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(void) { + return kai_m_step; +} + +size_t 
kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(void) { + return kai_n_step; +} + +size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(void) { + return kai_mr; +} + +size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(void) { + return kai_nr; +} + +size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(void) { + return kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(size_t m_idx, size_t k) { + KAI_ASSERT((m_idx % kai_m_step) == 0); + + return (m_idx / kai_m_step) * kai_lhs_packed_stride(k); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(size_t n_idx, size_t k) { + KAI_ASSERT((n_idx % kai_n_step) == 0); + + return (n_idx / kai_n_step) * kai_rhs_packed_stride(k); +} + +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSERT((m_idx % kai_m_step) == 0); + KAI_ASSERT((n_idx % kai_n_step) == 0); + + return (n_idx * sizeof(float)) + m_idx * dst_stride; +} + +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(size_t m, size_t n) { + return m * n * sizeof(float); +} + +void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod( + size_t m, size_t n, size_t k, const void* restrict lhs_packed, const void* restrict rhs_packed, float* restrict dst, + size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { +#if defined(__ARM_FEATURE_DOTPROD) + KAI_ASSERT(dst_stride_col == sizeof(float)); + + if (m == 0) { + return; + } + + const size_t kai_k0 = kai_kr * kai_sr; + const size_t num_rows = m; + const size_t num_cols = n; + + const size_t lhs_packed_stride = kai_lhs_packed_stride(k); + const size_t k_internal = kai_k_roundedup(k); + + const int8x16_t nibble_mask = vdupq_n_s8(0xF0); + + const uint8_t* lhs_ptr_start = lhs_packed; + + for (size_t row_idx = 0; row_idx < num_rows; row_idx += kai_mr) { + const uint8_t* rhs_ptr = rhs_packed; + for (size_t col_idx = 0; col_idx < num_cols; col_idx += kai_nr) { + const uint8_t* lhs_ptr = lhs_ptr_start; + + // Main f32 accumulator + int32x4_t iacc0011 = vdupq_n_s32(0); + int32x4_t iacc2233 = vdupq_n_s32(0); + int32x4_t iacc4455 = vdupq_n_s32(0); + int32x4_t iacc6677 = vdupq_n_s32(0); + + for (size_t b = 0; b < k_internal; b += kai_k0) { + // Set up RHS + const int8x16_t rhs_raw_vec_0 = vld1q_s8((const int8_t*)(rhs_ptr + 0)); + const int8x16_t rhs_raw_vec_1 = vld1q_s8((const int8_t*)(rhs_ptr + 16)); + const int8x16_t rhs_raw_vec_2 = vld1q_s8((const int8_t*)(rhs_ptr + 32)); + const int8x16_t rhs_raw_vec_3 = vld1q_s8((const int8_t*)(rhs_ptr + 48)); + const int8x16_t rhs_raw_vec_4 = vld1q_s8((const int8_t*)(rhs_ptr + 64)); + const int8x16_t rhs_raw_vec_5 = vld1q_s8((const int8_t*)(rhs_ptr + 80)); + const int8x16_t rhs_raw_vec_6 = vld1q_s8((const int8_t*)(rhs_ptr + 96)); + const int8x16_t rhs_raw_vec_7 = vld1q_s8((const int8_t*)(rhs_ptr + 112)); + + // Low nibble + const int8x16_t rhs_vec_0_0 = vshlq_n_s8(rhs_raw_vec_0, 4); + const int8x16_t rhs_vec_1_0 = vshlq_n_s8(rhs_raw_vec_1, 4); + const int8x16_t rhs_vec_2_0 = vshlq_n_s8(rhs_raw_vec_2, 4); + const int8x16_t rhs_vec_3_0 = vshlq_n_s8(rhs_raw_vec_3, 4); + const int8x16_t rhs_vec_4_0 = vshlq_n_s8(rhs_raw_vec_4, 4); + const int8x16_t rhs_vec_5_0 = 
vshlq_n_s8(rhs_raw_vec_5, 4); + const int8x16_t rhs_vec_6_0 = vshlq_n_s8(rhs_raw_vec_6, 4); + const int8x16_t rhs_vec_7_0 = vshlq_n_s8(rhs_raw_vec_7, 4); + + // High nibble + const int8x16_t rhs_vec_0_1 = vandq_s8(rhs_raw_vec_0, nibble_mask); + const int8x16_t rhs_vec_1_1 = vandq_s8(rhs_raw_vec_1, nibble_mask); + const int8x16_t rhs_vec_2_1 = vandq_s8(rhs_raw_vec_2, nibble_mask); + const int8x16_t rhs_vec_3_1 = vandq_s8(rhs_raw_vec_3, nibble_mask); + const int8x16_t rhs_vec_4_1 = vandq_s8(rhs_raw_vec_4, nibble_mask); + const int8x16_t rhs_vec_5_1 = vandq_s8(rhs_raw_vec_5, nibble_mask); + const int8x16_t rhs_vec_6_1 = vandq_s8(rhs_raw_vec_6, nibble_mask); + const int8x16_t rhs_vec_7_1 = vandq_s8(rhs_raw_vec_7, nibble_mask); + + const int8x16_t lhs_vec_0 = vld1q_s8((const int8_t*)(lhs_ptr + 0)); + const int8x16_t lhs_vec_1 = vld1q_s8((const int8_t*)(lhs_ptr + 16)); + + lhs_ptr += 32; + rhs_ptr += 128; + + int8x16_t t; + + t = vcombine_s8(vget_low_s8(lhs_vec_0), vget_low_s8(lhs_vec_0)); + iacc0011 = vdotq_s32(iacc0011, rhs_vec_0_0, t); + iacc2233 = vdotq_s32(iacc2233, rhs_vec_1_0, t); + iacc4455 = vdotq_s32(iacc4455, rhs_vec_2_0, t); + iacc6677 = vdotq_s32(iacc6677, rhs_vec_3_0, t); + t = vcombine_s8(vget_high_s8(lhs_vec_0), vget_high_s8(lhs_vec_0)); + iacc0011 = vdotq_s32(iacc0011, rhs_vec_4_0, t); + iacc2233 = vdotq_s32(iacc2233, rhs_vec_5_0, t); + iacc4455 = vdotq_s32(iacc4455, rhs_vec_6_0, t); + iacc6677 = vdotq_s32(iacc6677, rhs_vec_7_0, t); + + t = vcombine_s8(vget_low_s8(lhs_vec_1), vget_low_s8(lhs_vec_1)); + iacc0011 = vdotq_s32(iacc0011, rhs_vec_0_1, t); + iacc2233 = vdotq_s32(iacc2233, rhs_vec_1_1, t); + iacc4455 = vdotq_s32(iacc4455, rhs_vec_2_1, t); + iacc6677 = vdotq_s32(iacc6677, rhs_vec_3_1, t); + t = vcombine_s8(vget_high_s8(lhs_vec_1), vget_high_s8(lhs_vec_1)); + iacc0011 = vdotq_s32(iacc0011, rhs_vec_4_1, t); + iacc2233 = vdotq_s32(iacc2233, rhs_vec_5_1, t); + iacc4455 = vdotq_s32(iacc4455, rhs_vec_6_1, t); + iacc6677 = vdotq_s32(iacc6677, rhs_vec_7_1, t); + } + + int32x4_t iacc0 = vpaddq_s32(iacc0011, iacc2233); + int32x4_t iacc1 = vpaddq_s32(iacc4455, iacc6677); + + // LHS offset + const int32x4_t lhs_offset = vld1q_dup_s32((const int32_t*)lhs_ptr); + lhs_ptr += sizeof(int32_t); + + // LHS scale + const float32x4_t lhs_scale = vld1q_dup_f32((const float*)lhs_ptr); + lhs_ptr += sizeof(float); + + // RHS sum values + const int32x4_t sum_n_s32_0 = vld1q_s32((const int32_t*)(rhs_ptr)); + rhs_ptr += sizeof(int32x4_t); + const int32x4_t sum_n_s32_1 = vld1q_s32((const int32_t*)(rhs_ptr)); + rhs_ptr += sizeof(int32x4_t); + + // RHS scale + const float32x4_t rhs_scale0 = vld1q_f32((const float*)rhs_ptr); + rhs_ptr += sizeof(float32x4_t); + const float32x4_t rhs_scale1 = vld1q_f32((const float*)rhs_ptr); + rhs_ptr += sizeof(float32x4_t); + + // Add the reduction sum + iacc0 = vmlaq_s32(iacc0, sum_n_s32_0, lhs_offset); + iacc1 = vmlaq_s32(iacc1, sum_n_s32_1, lhs_offset); + + float32x4_t main_acc0 = vmulq_f32(vcvtq_f32_s32(iacc0), rhs_scale0); + float32x4_t main_acc1 = vmulq_f32(vcvtq_f32_s32(iacc1), rhs_scale1); + + main_acc0 = vmulq_f32(main_acc0, lhs_scale); + main_acc1 = vmulq_f32(main_acc1, lhs_scale); + + // clamp (min-max) operation + const float32x4_t vmin_f32 = vdupq_n_f32(scalar_min); + const float32x4_t vmax_f32 = vdupq_n_f32(scalar_max); + + main_acc0 = vmaxq_f32(main_acc0, vmin_f32); + main_acc0 = vminq_f32(main_acc0, vmax_f32); + + main_acc1 = vmaxq_f32(main_acc1, vmin_f32); + main_acc1 = vminq_f32(main_acc1, vmax_f32); + + if (col_idx + kai_nr <= n) { + vst1q_f32( + 
(float*)((uint8_t*)dst + (col_idx + 0) * sizeof(float) + row_idx * dst_stride_row), main_acc0); + vst1q_f32( + (float*)((uint8_t*)dst + (col_idx + 4) * sizeof(float) + row_idx * dst_stride_row), main_acc1); + } else { + size_t leftover = n % kai_nr; + *(float*)((uint8_t*)dst + (col_idx + 0) * sizeof(float) + row_idx * dst_stride_row) = + vgetq_lane_f32(main_acc0, 0); + if (leftover > 1) { + *(float*)((uint8_t*)dst + (col_idx + 1) * sizeof(float) + row_idx * dst_stride_row) = + vgetq_lane_f32(main_acc0, 1); + } + if (leftover > 2) { + *(float*)((uint8_t*)dst + (col_idx + 2) * sizeof(float) + row_idx * dst_stride_row) = + vgetq_lane_f32(main_acc0, 2); + } + if (leftover > 3) { + *(float*)((uint8_t*)dst + (col_idx + 3) * sizeof(float) + row_idx * dst_stride_row) = + vgetq_lane_f32(main_acc0, 3); + } + if (leftover > 4) { + *(float*)((uint8_t*)dst + (col_idx + 4) * sizeof(float) + row_idx * dst_stride_row) = + vgetq_lane_f32(main_acc1, 0); + } + if (leftover > 5) { + *(float*)((uint8_t*)dst + (col_idx + 5) * sizeof(float) + row_idx * dst_stride_row) = + vgetq_lane_f32(main_acc1, 1); + } + if (leftover > 6) { + *(float*)((uint8_t*)dst + (col_idx + 6) * sizeof(float) + row_idx * dst_stride_row) = + vgetq_lane_f32(main_acc1, 2); + } + } + } + lhs_ptr_start += lhs_packed_stride; + } +#else + KAI_ASSERT(false); + KAI_UNUSED(m); + KAI_UNUSED(n); + KAI_UNUSED(k); + KAI_UNUSED(lhs_packed); + KAI_UNUSED(rhs_packed); + KAI_UNUSED(dst); + KAI_UNUSED(dst_stride_row); + KAI_UNUSED(dst_stride_col); + KAI_UNUSED(scalar_min); + KAI_UNUSED(scalar_max); +#endif +} diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h new file mode 100644 index 0000000000000000000000000000000000000000..69bf5c08f35b5adf531b83e5ac262910c83159cd --- /dev/null +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h @@ -0,0 +1,142 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#ifndef __cplusplus +#include +#endif +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * @brief Function to get the m step value. + * The micro-kernel can process any M values. However, the starting M index to + * be processed must be a multiple of m step. + * + * @return the m step value + */ +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(void); + +/** + * @brief Function to get the n step value. + * The micro-kernel can process any N values. However, the starting N index to + * be processed must be a multiple of n step. 
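+ * For this micro-kernel, n step is 8, matching the width of its 1 x 8 output tile.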
+ * + * @return the n step + */ +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(void); + +/** + * @brief Function to get the mr value, which must be used to pack the LHS matrix with + * the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel + * + * @return the mr value + */ +size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(void); + +/** + * @brief Function to get the nr value, which must be used to pack the RHS matrix with + * the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel + * + * @return the nr value + */ +size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(void); + +/** + * @brief Function to get the kr value, which must be used to pack the RHS matrix with + * the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel + * + * @return the kr value + */ +size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(void); + +/** + * @brief Function to get the sr value, which must be used to pack the RHS matrix with + * the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel + * + * @return the sr value + */ +size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(void); + +/** + * @brief Function to calculate the offset in bytes for the packed LHS matrix, + * which contains the packed 8-bit quantized asymmetric per-row (qa8dx) values. + * + * This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. + * + * @param[in] m_idx Row index in the LHS matrix (not packed). + * @param[in] k Total number of columns in the LHS matrix (not packed). + * + * return the offset in bytes to the packed LHS matrix + */ +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(size_t m_idx, size_t k); + +/** + * @brief Function to calculate the offset in bytes for the packed RHS matrix, + * which contains the packed 4-bit quantized symmetric per-channel (qsu4cx) values. + * + * @param[in] n_idx Row index in the RHS matrix (not packed). + * @param[in] k The common dimension between the LHS and RHS matrix (K). + * + * return the offset in bytes to the packed RHS matrix + */ +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(size_t n_idx, size_t k); + +/** + * @brief Function to calculate the offset in bytes for the DST matrix + * + * @param[in] m_idx Row index in the DST matrix. + * @param[in] n_idx Column index in the DST matrix. It must be multiple of 8. + * @param[in] dst_stride The number of bytes in in each row of the DST matrix + * + * return the DST offset in bytes + */ +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod( + size_t m_idx, size_t n_idx, size_t dst_stride); + +/** + * @brief Function to query the size in bytes for the destination matrix. + * + * @param[in] m Number of rows in the destination (DST) matrix + * @param[in] n Number of columns in the destination (DST) matrix + */ +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod(size_t m, size_t n); + +/** + * @brief Function to run the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation. + * + * LHS matrix: Signed 8-bit quantized asymmetric per-row (qai8dx) and packed + * RHS matrix: Signed 4-bit quantized symmetric per-channel (qsu4cx) and packed. 
+ * Output tile: (rows x cols) = 1 x 8 + * Accumulation performed in a single for loop: 64 + * Instruction used: dotprod + * + * @param[in] m The number of output rows written. + * @param[in] n The number of output columns written. + * @param[in] k The number of channels. The common dimension of LHS & RHS. + * @param[in] lhs_packed The LHS matrix packed. + * When the activation are dynamically quantized, you can obtain this matrix + * by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs + * both the dynamic quantization to 8-bit and activation packing in a single step. + * @param[in] rhs_packed The RHS matrix packed, which is obtained by calling @ref + * kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 + * @param[out] dst Result of the vector-by-matrix + * @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. + * @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(float) + * @param[in] scalar_min Min value used to clamp the final result. + * @param[in] scalar_max Max value used to clamp the final result. + */ +void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod( + size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row, + size_t dst_stride_col, float scalar_min, float scalar_max); + +#ifdef __cplusplus +} +#endif diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.c b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.c new file mode 100644 index 0000000000000000000000000000000000000000..0bdd4f9e4501f0dc5ddca884edd75f0bdf77e87c --- /dev/null +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.c @@ -0,0 +1,279 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.h" + +#include +#include + +#include "kai_common.h" + +static const size_t kai_m_step = 4; +static const size_t kai_n_step = 4; +static const size_t kai_mr = 4; +static const size_t kai_nr = 4; +static const size_t kai_kr = 16; +static const size_t kai_sr = 2; +static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); +static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); +static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); +static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); + +inline static size_t kai_k_roundedup(size_t k) { + // Since we pack a float and int32 value at the end of the row, + // we must make sure that k is a multiple of 4 for alignment + size_t kr_sr_roundedup4 = kai_roundup(kai_kr * kai_sr, 4); + return kai_roundup(k, kr_sr_roundedup4); +} + +inline static size_t kai_lhs_packed_stride(size_t k) { + const size_t k_internal = kai_k_roundedup(k); + + KAI_ASSERT((k_internal % 2) == 0); + + return kai_mr * (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); +} + +inline static size_t kai_rhs_packed_stride(size_t k) { + const size_t k_internal = kai_k_roundedup(k); + + KAI_ASSERT((k_internal % 2) == 0); + + return kai_nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs); +} + +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(void) { + return kai_m_step; +} + +size_t 
kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(void) { + return kai_n_step; +} + +size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(void) { + return kai_mr; +} + +size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(void) { + return kai_nr; +} + +size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(void) { + return kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(size_t m_idx, size_t k) { + KAI_ASSERT((m_idx % kai_m_step) == 0); + + return (m_idx / kai_m_step) * kai_lhs_packed_stride(k); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(size_t n_idx, size_t k) { + KAI_ASSERT((n_idx % kai_n_step) == 0); + + return (n_idx / kai_n_step) * kai_rhs_packed_stride(k); +} + +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSERT((m_idx % kai_m_step) == 0); + KAI_ASSERT((n_idx % kai_n_step) == 0); + + return (n_idx * sizeof(float)) + m_idx * dst_stride; +} + +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(size_t m, size_t n) { + return m * n * sizeof(float); +} + +void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm( + size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row, + size_t dst_stride_col, float scalar_min, float scalar_max) { +#if defined(__ARM_FEATURE_MATMUL_INT8) + KAI_ASSERT(dst_stride_col == sizeof(float)); + + if (m == 0) { + return; + } + + const size_t k_internal = kai_k_roundedup(k); + + size_t num_blocks = k_internal / 32; + + float clamp_vals[2] = {scalar_min, scalar_max}; + + __asm__ __volatile__( + "mov x28, #0x80\n" + "mov x20, #0x20\n" + "movi v4.16b, #0xf0\n" + "mov x27, %x[m]\n" + "madd x28, %x[num_blocks], x28, x20\n" + "cbz x27, 8f\n" + "1:" // Row loop + "mov x26, %x[rhs_packed]\n" + "mov x25, %x[n]\n" + "add x24, %x[dst], %x[dst_stride_row], LSL #2\n" + "2:" // Column loop + "movi v3.4s, #0x0\n" + "movi v2.4s, #0x0\n" + "mov x21, %x[lhs_packed]\n" + "mov x20, %x[num_blocks]\n" + "movi v1.4s, #0x0\n" + "movi v0.4s, #0x0\n" + "3:" // Block loop + "ldr q31, [x26, #0x0]\n" + "ldr q30, [x26, #0x10]\n" + "subs x20, x20, #0x1\n" + "ldr q29, [x21, #0x0]\n" + "ldr q28, [x21, #0x10]\n" + "ldr q27, [x26, #0x20]\n" + "ldr q26, [x26, #0x30]\n" + "add x26, x26, #0x40\n" + "ldr q25, [x21, #0x20]\n" + "ldr q24, [x21, #0x30]\n" + "shl v23.16b, v31.16b, #0x4\n" + "shl v22.16b, v30.16b, #0x4\n" + "ldr q21, [x21, #0x40]\n" + "ldr q20, [x21, #0x50]\n" + "and v31.16b, v31.16b, v4.16b\n" + "and v30.16b, v30.16b, v4.16b\n" + "ldr q19, [x21, #0x60]\n" + "ldr q18, [x21, #0x70]\n" + "shl v17.16b, v27.16b, #0x4\n" + "shl v16.16b, v26.16b, #0x4\n" + ".inst 0x4e97a7a3 // smmla v3.4s, v29.16b, v23.16b\n" + ".inst 0x4e96a7a2 // smmla v2.4s, v29.16b, v22.16b\n" + "and v27.16b, v27.16b, v4.16b\n" + "add x21, x21, #0x80\n" + ".inst 0x4e97a781 // smmla v1.4s, v28.16b, v23.16b\n" + ".inst 0x4e96a780 // smmla v0.4s, v28.16b, v22.16b\n" + "and v26.16b, v26.16b, v4.16b\n" + ".inst 0x4e91a723 // smmla v3.4s, v25.16b, v17.16b\n" + ".inst 0x4e90a722 // smmla v2.4s, v25.16b, v16.16b\n" + ".inst 0x4e91a701 // smmla v1.4s, v24.16b, v17.16b\n" + ".inst 0x4e90a700 // smmla v0.4s, v24.16b, v16.16b\n" + ".inst 0x4e9fa6a3 // smmla 
v3.4s, v21.16b, v31.16b\n" + ".inst 0x4e9ea6a2 // smmla v2.4s, v21.16b, v30.16b\n" + ".inst 0x4e9fa681 // smmla v1.4s, v20.16b, v31.16b\n" + ".inst 0x4e9ea680 // smmla v0.4s, v20.16b, v30.16b\n" + ".inst 0x4e9ba663 // smmla v3.4s, v19.16b, v27.16b\n" + ".inst 0x4e9aa662 // smmla v2.4s, v19.16b, v26.16b\n" + ".inst 0x4e9ba641 // smmla v1.4s, v18.16b, v27.16b\n" + ".inst 0x4e9aa640 // smmla v0.4s, v18.16b, v26.16b\n" + "bgt 3b\n" + "ldr q18, [x26, #0x0]\n" + "ldr q17, [x21, #0x0]\n" + "uzp1 v26.2d, v3.2d, v2.2d\n" + "uzp2 v25.2d, v3.2d, v2.2d\n" + "ldr q24, [x26, #0x10]\n" + "ldr q16, [x21, #0x10]\n" + "uzp1 v23.2d, v1.2d, v0.2d\n" + "uzp2 v22.2d, v1.2d, v0.2d\n" + "ld1r { v21.4s }, [%x[clamp_vals]]\n" + "add x20, %x[clamp_vals], #0x4\n" + "cmp x25, #0x4\n" + "ld1r { v20.4s }, [x20]\n" + "mla v26.4s, v18.4s, v17.s[0]\n" + "mla v25.4s, v18.4s, v17.s[1]\n" + "add x26, x26, #0x20\n" + "mla v23.4s, v18.4s, v17.s[2]\n" + "mla v22.4s, v18.4s, v17.s[3]\n" + "fmul v19.4s, v24.4s, v16.s[0]\n" + "fmul v18.4s, v24.4s, v16.s[1]\n" + "fmul v17.4s, v24.4s, v16.s[2]\n" + "fmul v16.4s, v24.4s, v16.s[3]\n" + "scvtf v26.4s, v26.4s\n" + "scvtf v25.4s, v25.4s\n" + "scvtf v23.4s, v23.4s\n" + "scvtf v22.4s, v22.4s\n" + "fmul v26.4s, v26.4s, v19.4s\n" + "fmul v25.4s, v25.4s, v18.4s\n" + "fmul v23.4s, v23.4s, v17.4s\n" + "fmul v22.4s, v22.4s, v16.4s\n" + "fmax v26.4s, v26.4s, v21.4s\n" + "fmax v25.4s, v25.4s, v21.4s\n" + "fmax v23.4s, v23.4s, v21.4s\n" + "fmax v22.4s, v22.4s, v21.4s\n" + "fmin v26.4s, v26.4s, v20.4s\n" + "fmin v25.4s, v25.4s, v20.4s\n" + "fmin v23.4s, v23.4s, v20.4s\n" + "fmin v22.4s, v22.4s, v20.4s\n" + "bge 6f\n" + "mov x23, %x[dst]\n" + "cmp x27, #0x1\n" + "add x22, x23, %x[dst_stride_row]\n" + "csel x22, x22, x23, GE\n" + "cmp x27, #0x2\n" + "add x21, x23, %x[dst_stride_row], LSL #1\n" + "csel x21, x21, x22, GE\n" + "cmp x27, #0x3\n" + "add x20, x21, %x[dst_stride_row]\n" + "csel x20, x20, x21, GE\n" + "tbz x25, #1, 4f\n" + "str d22, [x20], #0x8\n" + "str d23, [x21], #0x8\n" + "str d25, [x22], #0x8\n" + "str d26, [x23], #0x8\n" + "tbz x25, #0, 5f\n" + "st1 { v22.s }[2], [x20]\n" + "st1 { v23.s }[2], [x21]\n" + "st1 { v25.s }[2], [x22]\n" + "st1 { v26.s }[2], [x23]\n" + "b 5f\n" + "4:" // Output block 0: partial_1_0 + "str s22, [x20, #0x0]\n" + "str s23, [x21, #0x0]\n" + "str s25, [x22, #0x0]\n" + "str s26, [x23, #0x0]\n" + "5:" // Output block 0: Done + "b 7f\n" + "6:" // Full output + "mov x20, %x[dst]\n" + "cmp x27, #0x1\n" + "str q26, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 7f\n" + "cmp x27, #0x2\n" + "str q25, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 7f\n" + "cmp x27, #0x3\n" + "str q23, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 7f\n" + "str q22, [x20, #0x0]\n" + "7:" // Output stage exit + "subs x25, x25, #0x4\n" + "add %x[dst], %x[dst], #0x10\n" + "bgt 2b\n" + "subs x27, x27, #0x4\n" + "add %x[lhs_packed], %x[lhs_packed], x28\n" + "mov %x[dst], x24\n" + "bgt 1b\n" + "8:" // Row loop skip + : [lhs_packed] "+&r"(lhs_packed), [dst] "+&r"(dst) + : [rhs_packed] "r"(rhs_packed), [clamp_vals] "r"(clamp_vals), [m] "r"(m), [num_blocks] "r"(num_blocks), + [dst_stride_row] "r"(dst_stride_row), [n] "r"(n) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", + "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", + "x28"); +#else + KAI_ASSERT(false); + KAI_UNUSED(m); + KAI_UNUSED(n); + KAI_UNUSED(k); + KAI_UNUSED(lhs_packed); + 
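+    // Fallback path for builds without FEAT_I8MM support (__ARM_FEATURE_MATMUL_INT8 is not
+    // defined): the KAI_ASSERT(false) above turns an accidental call into a loud failure,
+    // and the KAI_UNUSED macros merely silence unused-parameter warnings.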
KAI_UNUSED(rhs_packed); + KAI_UNUSED(dst); + KAI_UNUSED(dst_stride_row); + KAI_UNUSED(dst_stride_col); + KAI_UNUSED(scalar_min); + KAI_UNUSED(scalar_max); +#endif +} diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.h b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.h new file mode 100644 index 0000000000000000000000000000000000000000..782fa28ed3207c8318fa057b57cafc55fa969b96 --- /dev/null +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.h @@ -0,0 +1,142 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#ifndef __cplusplus +#include +#endif +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * @brief Function to get the m step value. + * The micro-kernel can process any M values. However, the starting M index to + * be processed must be a multiple of m step. + * + * @return the m step value + */ +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(void); + +/** + * @brief Function to get the n step value. + * The micro-kernel can process any N values. However, the starting N index to + * be processed must be a multiple of n step. + * + * @return the n step + */ +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(void); + +/** + * @brief Function to get the mr value, which must be used to pack the LHS matrix with + * the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel + * + * @return the mr value + */ +size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(void); + +/** + * @brief Function to get the nr value, which must be used to pack the RHS matrix with + * the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel + * + * @return the nr value + */ +size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(void); + +/** + * @brief Function to get the kr value, which must be used to pack the RHS matrix with + * the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel + * + * @return the kr value + */ +size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(void); + +/** + * @brief Function to get the sr value, which must be used to pack the RHS matrix with + * the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel + * + * @return the sr value + */ +size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(void); + +/** + * @brief Function to calculate the offset in bytes for the packed LHS matrix, + * which contains the packed 8-bit quantized asymmetric per-row (qa8dx) values. + * + * This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. + * + * @param[in] m_idx Row index in the LHS matrix (not packed). + * @param[in] k Total number of columns in the LHS matrix (not packed). + * + * return the offset in bytes to the packed LHS matrix + */ +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(size_t m_idx, size_t k); + +/** + * @brief Function to calculate the offset in bytes for the packed RHS matrix, + * which contains the packed 4-bit quantized symmetric per-channel (qsu4cx) values. + * + * @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of 4. + * @param[in] k The common dimension between the LHS and RHS matrix (K). 
+ * + * return the offset in bytes to the packed RHS matrix + */ +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(size_t n_idx, size_t k); + +/** + * @brief Function to calculate the offset in bytes for the DST matrix + * + * @param[in] m_idx Row index in the DST matrix. It must be a multiple of 4. + * @param[in] n_idx Column index in the DST matrix. It must be multiple of 4. + * @param[in] dst_stride The number of bytes in in each row of the DST matrix + * + * return the DST offset in bytes + */ +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm( + size_t m_idx, size_t n_idx, size_t dst_stride); + +/** + * @brief Function to query the size in bytes for the destination matrix. + * + * @param[in] m Number of rows in the destination (DST) matrix. + * @param[in] n Number of columns in the destination (DST) matrix. + */ +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(size_t m, size_t n); + +/** + * @brief Function to run the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation. + * + * LHS matrix: Signed 8-bit quantized asymmetric per-row (qai8dx) and packed + * RHS matrix: Signed 4-bit quantized symmetric per-channel (qsu4cx) and packed. + * Output tile: (rows x cols) = 4 x 4 + * Accumulation performed in a single for loop: 32 + * Instruction used: i8mm + * + * @param[in] m The number of output rows written. + * @param[in] n The number of output columns written. + * @param[in] k The number of channels. The common dimension of LHS & RHS. + * @param[in] lhs_packed The LHS matrix packed. + * When the activation are dynamically quantized, you can obtain this matrix + * by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs + * both the dynamic quantization to 8-bit and activation packing in a single step. + * @param[in] rhs_packed The RHS matrix packed, which is obtained by calling @ref + * kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 + * @param[out] dst Result of the vector-by-matrix + * @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. + * @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(float) + * @param[in] scalar_min Min value used to clamp the final result. + * @param[in] scalar_max Max value used to clamp the final result. 
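+ *
+ * A minimal usage sketch (illustrative only). It assumes lhs_packed and rhs_packed were already
+ * produced by the packing micro-kernels referenced above, and that dst points to a buffer of at
+ * least kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(m, n) bytes:
+ *
+ * @code
+ * // Row-major f32 output; the column stride must be sizeof(float).
+ * const size_t dst_stride_row = n * sizeof(float);
+ * // -FLT_MAX / FLT_MAX (from float.h) effectively disable the final clamp.
+ * kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm(
+ *     m, n, k, lhs_packed, rhs_packed, dst, dst_stride_row, sizeof(float), -FLT_MAX, FLT_MAX);
+ * @endcode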
+ */ +void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm( + size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row, + size_t dst_stride_col, float scalar_min, float scalar_max); + +#ifdef __cplusplus +} +#endif diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c new file mode 100644 index 0000000000000000000000000000000000000000..2c348c3741e74b7c1aa3f22593058ce3b4a1a053 --- /dev/null +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c @@ -0,0 +1,497 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h" + +#include +#include + +#include "kai_common.h" + +static const size_t kai_m_step = 8; +static const size_t kai_n_step = 4; +static const size_t kai_mr = 4; +static const size_t kai_nr = 4; +static const size_t kai_kr = 16; +static const size_t kai_sr = 2; +static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); +static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); +static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); +static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); + +inline static size_t kai_k_roundedup(size_t k) { + // Since we pack a float and int32 value at the end of the row, + // we must make sure that k is a multiple of 4 for alignment + size_t kr_sr_roundedup4 = kai_roundup(kai_kr * kai_sr, 4); + return kai_roundup(k, kr_sr_roundedup4); +} + +inline static size_t kai_lhs_packed_stride(size_t k) { + const size_t k_internal = kai_k_roundedup(k); + + KAI_ASSERT((k_internal % 2) == 0); + + return kai_mr * (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); +} + +inline static size_t kai_rhs_packed_stride(size_t k) { + const size_t k_internal = kai_k_roundedup(k); + + KAI_ASSERT((k_internal % 2) == 0); + + return kai_nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs); +} + +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void) { + return kai_m_step; +} + +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void) { + return kai_n_step; +} + +size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void) { + return kai_mr; +} + +size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void) { + return kai_nr; +} + +size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void) { + return kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(size_t m_idx, size_t k) { + KAI_ASSERT((m_idx % kai_m_step) == 0); + + return (m_idx / kai_m_step) * kai_lhs_packed_stride(k); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(size_t n_idx, size_t k) { + KAI_ASSERT((n_idx % kai_n_step) == 0); + + return (n_idx / kai_n_step) * kai_rhs_packed_stride(k); +} + +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSERT((m_idx % kai_m_step) == 0); + 
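+    // m_idx and n_idx must be aligned to this kernel's 8 x 4 output tile (kai_m_step == 8,
+    // kai_n_step == 4). With a row-major f32 destination where dst_stride == n * sizeof(float),
+    // the tile starting at (m_idx, n_idx) therefore begins at byte offset m_idx * n * 4 + n_idx * 4.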
KAI_ASSERT((n_idx % kai_n_step) == 0); + + return (n_idx * sizeof(float)) + m_idx * dst_stride; +} + +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(size_t m, size_t n) { + return m * n * sizeof(float); +} + +void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm( + size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row, + size_t dst_stride_col, float scalar_min, float scalar_max) { +#if defined(__ARM_FEATURE_MATMUL_INT8) + KAI_ASSERT(dst_stride_col == sizeof(float)); + + if (m == 0) { + return; + } + + const size_t k_internal = kai_k_roundedup(k); + + size_t num_blocks = k_internal / 32; + + float clamp_vals[2] = {scalar_min, scalar_max}; + + __asm__ __volatile__( + "mov x12, %x[m]\n" + "mov x11, #0x80\n" + "movi v11.16b, #0xf0\n" + "mov x20, #0x20\n" + "cmp x12, #0x8\n" + "madd x11, %x[num_blocks], x11, x20\n" + "blt 8f\n" + "1:" // Row loop + "mov x10, %x[rhs_packed]\n" + "mov x9, %x[n]\n" + "add x28, %x[dst], %x[dst_stride_row], LSL #3\n" + "2:" // Column loop + "mov x22, %x[lhs_packed]\n" + "movi v10.4s, #0x0\n" + "movi v9.4s, #0x0\n" + "mov x21, %x[num_blocks]\n" + "movi v8.4s, #0x0\n" + "movi v7.4s, #0x0\n" + "movi v6.4s, #0x0\n" + "movi v5.4s, #0x0\n" + "movi v4.4s, #0x0\n" + "movi v3.4s, #0x0\n" + "add x20, x22, x11\n" + "3:" // Block loop + "ldr q2, [x10, #0x0]\n" + "ldr q1, [x10, #0x10]\n" + "subs x21, x21, #0x1\n" + "ldr q20, [x22, #0x0]\n" + "ldr q19, [x22, #0x10]\n" + "ldr q18, [x20, #0x0]\n" + "ldr q0, [x20, #0x10]\n" + "ldr q31, [x10, #0x20]\n" + "ldr q30, [x10, #0x30]\n" + "shl v17.16b, v2.16b, #0x4\n" + "shl v16.16b, v1.16b, #0x4\n" + "ldr q29, [x22, #0x20]\n" + "ldr q28, [x22, #0x30]\n" + "and v2.16b, v2.16b, v11.16b\n" + "and v1.16b, v1.16b, v11.16b\n" + "ldr q27, [x20, #0x20]\n" + "ldr q26, [x20, #0x30]\n" + "add x10, x10, #0x40\n" + "ldr q25, [x22, #0x40]\n" + "ldr q24, [x22, #0x50]\n" + ".inst 0x4e91a68a // smmla v10.4s, v20.16b, v17.16b\n" + ".inst 0x4e90a689 // smmla v9.4s, v20.16b, v16.16b\n" + "ldr q23, [x20, #0x40]\n" + "ldr q22, [x20, #0x50]\n" + ".inst 0x4e91a668 // smmla v8.4s, v19.16b, v17.16b\n" + ".inst 0x4e90a667 // smmla v7.4s, v19.16b, v16.16b\n" + "ldr q21, [x22, #0x60]\n" + "ldr q20, [x22, #0x70]\n" + ".inst 0x4e91a646 // smmla v6.4s, v18.16b, v17.16b\n" + ".inst 0x4e90a645 // smmla v5.4s, v18.16b, v16.16b\n" + "ldr q19, [x20, #0x60]\n" + "ldr q18, [x20, #0x70]\n" + ".inst 0x4e91a404 // smmla v4.4s, v0.16b, v17.16b\n" + ".inst 0x4e90a403 // smmla v3.4s, v0.16b, v16.16b\n" + "shl v17.16b, v31.16b, #0x4\n" + "shl v16.16b, v30.16b, #0x4\n" + "add x22, x22, #0x80\n" + "add x20, x20, #0x80\n" + "and v31.16b, v31.16b, v11.16b\n" + "and v30.16b, v30.16b, v11.16b\n" + ".inst 0x4e91a7aa // smmla v10.4s, v29.16b, v17.16b\n" + ".inst 0x4e90a7a9 // smmla v9.4s, v29.16b, v16.16b\n" + ".inst 0x4e91a788 // smmla v8.4s, v28.16b, v17.16b\n" + ".inst 0x4e90a787 // smmla v7.4s, v28.16b, v16.16b\n" + ".inst 0x4e91a766 // smmla v6.4s, v27.16b, v17.16b\n" + ".inst 0x4e90a765 // smmla v5.4s, v27.16b, v16.16b\n" + ".inst 0x4e91a744 // smmla v4.4s, v26.16b, v17.16b\n" + ".inst 0x4e90a743 // smmla v3.4s, v26.16b, v16.16b\n" + ".inst 0x4e82a72a // smmla v10.4s, v25.16b, v2.16b\n" + ".inst 0x4e81a729 // smmla v9.4s, v25.16b, v1.16b\n" + ".inst 0x4e82a708 // smmla v8.4s, v24.16b, v2.16b\n" + ".inst 0x4e81a707 // smmla v7.4s, v24.16b, v1.16b\n" + ".inst 0x4e82a6e6 // smmla v6.4s, v23.16b, v2.16b\n" + ".inst 0x4e81a6e5 // smmla v5.4s, v23.16b, v1.16b\n" + ".inst 0x4e82a6c4 // 
smmla v4.4s, v22.16b, v2.16b\n" + ".inst 0x4e81a6c3 // smmla v3.4s, v22.16b, v1.16b\n" + ".inst 0x4e9fa6aa // smmla v10.4s, v21.16b, v31.16b\n" + ".inst 0x4e9ea6a9 // smmla v9.4s, v21.16b, v30.16b\n" + ".inst 0x4e9fa688 // smmla v8.4s, v20.16b, v31.16b\n" + ".inst 0x4e9ea687 // smmla v7.4s, v20.16b, v30.16b\n" + ".inst 0x4e9fa666 // smmla v6.4s, v19.16b, v31.16b\n" + ".inst 0x4e9ea665 // smmla v5.4s, v19.16b, v30.16b\n" + ".inst 0x4e9fa644 // smmla v4.4s, v18.16b, v31.16b\n" + ".inst 0x4e9ea643 // smmla v3.4s, v18.16b, v30.16b\n" + "bgt 3b\n" + "ldr q20, [x10, #0x0]\n" + "ldr q19, [x22, #0x0]\n" + "uzp1 v2.2d, v10.2d, v9.2d\n" + "uzp2 v1.2d, v10.2d, v9.2d\n" + "ldr q18, [x20, #0x0]\n" + "ldr q0, [x10, #0x10]\n" + "uzp1 v31.2d, v8.2d, v7.2d\n" + "uzp2 v30.2d, v8.2d, v7.2d\n" + "ldr q17, [x22, #0x10]\n" + "ldr q16, [x20, #0x10]\n" + "uzp1 v29.2d, v6.2d, v5.2d\n" + "uzp2 v28.2d, v6.2d, v5.2d\n" + "ld1r { v27.4s }, [%x[clamp_vals]]\n" + "uzp1 v26.2d, v4.2d, v3.2d\n" + "uzp2 v25.2d, v4.2d, v3.2d\n" + "mla v2.4s, v20.4s, v19.s[0]\n" + "mla v1.4s, v20.4s, v19.s[1]\n" + "mla v31.4s, v20.4s, v19.s[2]\n" + "add x20, %x[clamp_vals], #0x4\n" + "cmp x9, #0x4\n" + "ld1r { v24.4s }, [x20]\n" + "mla v30.4s, v20.4s, v19.s[3]\n" + "mla v29.4s, v20.4s, v18.s[0]\n" + "fmul v23.4s, v0.4s, v17.s[0]\n" + "mla v28.4s, v20.4s, v18.s[1]\n" + "mla v26.4s, v20.4s, v18.s[2]\n" + "fmul v22.4s, v0.4s, v17.s[1]\n" + "add x10, x10, #0x20\n" + "mla v25.4s, v20.4s, v18.s[3]\n" + "scvtf v2.4s, v2.4s\n" + "scvtf v1.4s, v1.4s\n" + "scvtf v31.4s, v31.4s\n" + "fmul v21.4s, v0.4s, v17.s[2]\n" + "scvtf v30.4s, v30.4s\n" + "fmul v20.4s, v0.4s, v17.s[3]\n" + "scvtf v29.4s, v29.4s\n" + "fmul v19.4s, v0.4s, v16.s[0]\n" + "scvtf v28.4s, v28.4s\n" + "fmul v18.4s, v0.4s, v16.s[1]\n" + "scvtf v26.4s, v26.4s\n" + "fmul v17.4s, v0.4s, v16.s[2]\n" + "scvtf v25.4s, v25.4s\n" + "fmul v16.4s, v0.4s, v16.s[3]\n" + "fmul v2.4s, v2.4s, v23.4s\n" + "fmul v1.4s, v1.4s, v22.4s\n" + "fmul v31.4s, v31.4s, v21.4s\n" + "fmul v30.4s, v30.4s, v20.4s\n" + "fmul v29.4s, v29.4s, v19.4s\n" + "fmul v28.4s, v28.4s, v18.4s\n" + "fmul v26.4s, v26.4s, v17.4s\n" + "fmul v25.4s, v25.4s, v16.4s\n" + "fmax v2.4s, v2.4s, v27.4s\n" + "fmax v1.4s, v1.4s, v27.4s\n" + "fmax v31.4s, v31.4s, v27.4s\n" + "fmax v30.4s, v30.4s, v27.4s\n" + "fmax v29.4s, v29.4s, v27.4s\n" + "fmax v28.4s, v28.4s, v27.4s\n" + "fmax v26.4s, v26.4s, v27.4s\n" + "fmax v25.4s, v25.4s, v27.4s\n" + "fmin v2.4s, v2.4s, v24.4s\n" + "fmin v1.4s, v1.4s, v24.4s\n" + "fmin v31.4s, v31.4s, v24.4s\n" + "fmin v30.4s, v30.4s, v24.4s\n" + "fmin v29.4s, v29.4s, v24.4s\n" + "fmin v28.4s, v28.4s, v24.4s\n" + "fmin v26.4s, v26.4s, v24.4s\n" + "fmin v25.4s, v25.4s, v24.4s\n" + "bge 6f\n" + "mov x27, %x[dst]\n" + "add x26, x27, %x[dst_stride_row], LSL #2\n" + "add x25, x26, %x[dst_stride_row], LSL #1\n" + "add x24, x26, %x[dst_stride_row]\n" + "add x23, x25, %x[dst_stride_row]\n" + "add x22, x27, %x[dst_stride_row], LSL #1\n" + "add x21, x27, %x[dst_stride_row]\n" + "add x20, x22, %x[dst_stride_row]\n" + "tbz x9, #1, 4f\n" + "str d25, [x23], #0x8\n" + "str d26, [x25], #0x8\n" + "str d28, [x24], #0x8\n" + "str d29, [x26], #0x8\n" + "str d30, [x20], #0x8\n" + "str d31, [x22], #0x8\n" + "str d1, [x21], #0x8\n" + "str d2, [x27], #0x8\n" + "tbz x9, #0, 5f\n" + "st1 { v25.s }[2], [x23]\n" + "st1 { v26.s }[2], [x25]\n" + "st1 { v28.s }[2], [x24]\n" + "st1 { v29.s }[2], [x26]\n" + "st1 { v30.s }[2], [x20]\n" + "st1 { v31.s }[2], [x22]\n" + "st1 { v1.s }[2], [x21]\n" + "st1 { v2.s }[2], [x27]\n" + "b 5f\n" + "4:" // Output block 
0: partial_1_0 + "str s25, [x23, #0x0]\n" + "str s26, [x25, #0x0]\n" + "str s28, [x24, #0x0]\n" + "str s29, [x26, #0x0]\n" + "str s30, [x20, #0x0]\n" + "str s31, [x22, #0x0]\n" + "str s1, [x21, #0x0]\n" + "str s2, [x27, #0x0]\n" + "5:" // Output block 0: Done + "b 7f\n" + "6:" // Full output + "mov x20, %x[dst]\n" + "str q2, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q1, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q31, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q30, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q29, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q28, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q26, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q25, [x20, #0x0]\n" + "7:" // Output stage exit + "subs x9, x9, #0x4\n" + "add %x[dst], %x[dst], #0x10\n" + "bgt 2b\n" + "mov x20, #0x2\n" + "sub x12, x12, #0x8\n" + "cmp x12, #0x8\n" + "mov %x[dst], x28\n" + "madd %x[lhs_packed], x20, x11, %x[lhs_packed]\n" + "bge 1b\n" + "8:" // Row loop skip + "cbz x12, 16f\n" + "9:" // Row tail: Row loop + "mov x26, %x[rhs_packed]\n" + "mov x25, %x[n]\n" + "add x24, %x[dst], %x[dst_stride_row], LSL #2\n" + "10:" // Row tail: Column loop + "movi v10.4s, #0x0\n" + "movi v9.4s, #0x0\n" + "mov x22, %x[lhs_packed]\n" + "mov x20, %x[num_blocks]\n" + "movi v8.4s, #0x0\n" + "movi v7.4s, #0x0\n" + "11:" // Row tail: Block loop + "ldr q31, [x26, #0x0]\n" + "ldr q30, [x26, #0x10]\n" + "subs x20, x20, #0x1\n" + "ldr q29, [x22, #0x0]\n" + "ldr q28, [x22, #0x10]\n" + "ldr q27, [x26, #0x20]\n" + "ldr q26, [x26, #0x30]\n" + "add x26, x26, #0x40\n" + "ldr q25, [x22, #0x20]\n" + "ldr q24, [x22, #0x30]\n" + "shl v23.16b, v31.16b, #0x4\n" + "shl v22.16b, v30.16b, #0x4\n" + "ldr q21, [x22, #0x40]\n" + "ldr q20, [x22, #0x50]\n" + "and v31.16b, v31.16b, v11.16b\n" + "and v30.16b, v30.16b, v11.16b\n" + "ldr q19, [x22, #0x60]\n" + "ldr q18, [x22, #0x70]\n" + "shl v17.16b, v27.16b, #0x4\n" + "shl v16.16b, v26.16b, #0x4\n" + ".inst 0x4e97a7aa // smmla v10.4s, v29.16b, v23.16b\n" + ".inst 0x4e96a7a9 // smmla v9.4s, v29.16b, v22.16b\n" + "and v27.16b, v27.16b, v11.16b\n" + "add x22, x22, #0x80\n" + ".inst 0x4e97a788 // smmla v8.4s, v28.16b, v23.16b\n" + ".inst 0x4e96a787 // smmla v7.4s, v28.16b, v22.16b\n" + "and v26.16b, v26.16b, v11.16b\n" + ".inst 0x4e91a72a // smmla v10.4s, v25.16b, v17.16b\n" + ".inst 0x4e90a729 // smmla v9.4s, v25.16b, v16.16b\n" + ".inst 0x4e91a708 // smmla v8.4s, v24.16b, v17.16b\n" + ".inst 0x4e90a707 // smmla v7.4s, v24.16b, v16.16b\n" + ".inst 0x4e9fa6aa // smmla v10.4s, v21.16b, v31.16b\n" + ".inst 0x4e9ea6a9 // smmla v9.4s, v21.16b, v30.16b\n" + ".inst 0x4e9fa688 // smmla v8.4s, v20.16b, v31.16b\n" + ".inst 0x4e9ea687 // smmla v7.4s, v20.16b, v30.16b\n" + ".inst 0x4e9ba66a // smmla v10.4s, v19.16b, v27.16b\n" + ".inst 0x4e9aa669 // smmla v9.4s, v19.16b, v26.16b\n" + ".inst 0x4e9ba648 // smmla v8.4s, v18.16b, v27.16b\n" + ".inst 0x4e9aa647 // smmla v7.4s, v18.16b, v26.16b\n" + "bgt 11b\n" + "ldr q18, [x26, #0x0]\n" + "ldr q17, [x22, #0x0]\n" + "uzp1 v26.2d, v10.2d, v9.2d\n" + "uzp2 v25.2d, v10.2d, v9.2d\n" + "ldr q24, [x26, #0x10]\n" + "ldr q16, [x22, #0x10]\n" + "uzp1 v23.2d, v8.2d, v7.2d\n" + "uzp2 v22.2d, v8.2d, v7.2d\n" + "ld1r { v21.4s }, [%x[clamp_vals]]\n" + "add x20, %x[clamp_vals], #0x4\n" + "cmp x25, #0x4\n" + "ld1r { v20.4s }, [x20]\n" + "mla v26.4s, v18.4s, v17.s[0]\n" + "mla v25.4s, v18.4s, v17.s[1]\n" + "add x26, x26, #0x20\n" + "mla v23.4s, v18.4s, v17.s[2]\n" + "mla 
v22.4s, v18.4s, v17.s[3]\n" + "fmul v19.4s, v24.4s, v16.s[0]\n" + "fmul v18.4s, v24.4s, v16.s[1]\n" + "fmul v17.4s, v24.4s, v16.s[2]\n" + "fmul v16.4s, v24.4s, v16.s[3]\n" + "scvtf v26.4s, v26.4s\n" + "scvtf v25.4s, v25.4s\n" + "scvtf v23.4s, v23.4s\n" + "scvtf v22.4s, v22.4s\n" + "fmul v26.4s, v26.4s, v19.4s\n" + "fmul v25.4s, v25.4s, v18.4s\n" + "fmul v23.4s, v23.4s, v17.4s\n" + "fmul v22.4s, v22.4s, v16.4s\n" + "fmax v26.4s, v26.4s, v21.4s\n" + "fmax v25.4s, v25.4s, v21.4s\n" + "fmax v23.4s, v23.4s, v21.4s\n" + "fmax v22.4s, v22.4s, v21.4s\n" + "fmin v26.4s, v26.4s, v20.4s\n" + "fmin v25.4s, v25.4s, v20.4s\n" + "fmin v23.4s, v23.4s, v20.4s\n" + "fmin v22.4s, v22.4s, v20.4s\n" + "bge 14f\n" + "mov x23, %x[dst]\n" + "cmp x12, #0x1\n" + "add x22, x23, %x[dst_stride_row]\n" + "csel x22, x22, x23, GE\n" + "cmp x12, #0x2\n" + "add x21, x23, %x[dst_stride_row], LSL #1\n" + "csel x21, x21, x22, GE\n" + "cmp x12, #0x3\n" + "add x20, x21, %x[dst_stride_row]\n" + "csel x20, x20, x21, GE\n" + "tbz x25, #1, 12f\n" + "str d22, [x20], #0x8\n" + "str d23, [x21], #0x8\n" + "str d25, [x22], #0x8\n" + "str d26, [x23], #0x8\n" + "tbz x25, #0, 13f\n" + "st1 { v22.s }[2], [x20]\n" + "st1 { v23.s }[2], [x21]\n" + "st1 { v25.s }[2], [x22]\n" + "st1 { v26.s }[2], [x23]\n" + "b 13f\n" + "12:" // Row tail: Output block 0: partial_1_0 + "str s22, [x20, #0x0]\n" + "str s23, [x21, #0x0]\n" + "str s25, [x22, #0x0]\n" + "str s26, [x23, #0x0]\n" + "13:" // Row tail: Output block 0: Done + "b 15f\n" + "14:" // Row tail: Full output + "mov x20, %x[dst]\n" + "cmp x12, #0x1\n" + "str q26, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 15f\n" + "cmp x12, #0x2\n" + "str q25, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 15f\n" + "cmp x12, #0x3\n" + "str q23, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 15f\n" + "str q22, [x20, #0x0]\n" + "15:" // Row tail: Output stage exit + "subs x25, x25, #0x4\n" + "add %x[dst], %x[dst], #0x10\n" + "bgt 10b\n" + "subs x12, x12, #0x4\n" + "add %x[lhs_packed], %x[lhs_packed], x11\n" + "mov %x[dst], x24\n" + "bgt 9b\n" + "16:" // Row tail: Row loop skip + : [lhs_packed] "+&r"(lhs_packed), [dst] "+&r"(dst) + : [rhs_packed] "r"(rhs_packed), [clamp_vals] "r"(clamp_vals), [m] "r"(m), [num_blocks] "r"(num_blocks), + [dst_stride_row] "r"(dst_stride_row), [n] "r"(n) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v16", "v17", "v18", + "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x11", + "x12", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"); +#else + KAI_ASSERT(false); + KAI_UNUSED(m); + KAI_UNUSED(n); + KAI_UNUSED(k); + KAI_UNUSED(lhs_packed); + KAI_UNUSED(rhs_packed); + KAI_UNUSED(dst); + KAI_UNUSED(dst_stride_row); + KAI_UNUSED(dst_stride_col); + KAI_UNUSED(scalar_min); + KAI_UNUSED(scalar_max); +#endif +} diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h new file mode 100644 index 0000000000000000000000000000000000000000..1f350fe97136d4ac949b2761fc0355c0b83ae537 --- /dev/null +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h @@ -0,0 +1,142 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#ifndef 
__cplusplus +#include +#endif +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * @brief Function to get the m step value. + * The micro-kernel can process any M values. However, the starting M index to + * be processed must be a multiple of m step. + * + * @return the m step value + */ +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void); + +/** + * @brief Function to get the n step value. + * The micro-kernel can process any N values. However, the starting N index to + * be processed must be a multiple of n step. + * + * @return the n step + */ +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void); + +/** + * @brief Function to get the mr value, which must be used to pack the LHS matrix with + * the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel + * + * @return the mr value + */ +size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void); + +/** + * @brief Function to get the nr value, which must be used to pack the RHS matrix with + * the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel + * + * @return the nr value + */ +size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void); + +/** + * @brief Function to get the kr value, which must be used to pack the RHS matrix with + * the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel + * + * @return the kr value + */ +size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void); + +/** + * @brief Function to get the sr value, which must be used to pack the RHS matrix with + * the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel + * + * @return the sr value + */ +size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void); + +/** + * @brief Function to calculate the offset in bytes for the packed LHS matrix, + * which contains the packed 8-bit quantized asymmetric per-row (qai8dx) values. + * + * This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. + * + * @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of 8 + * @param[in] k Total number of columns in the LHS matrix (not packed). + * + * return the offset in bytes to the packed LHS matrix + */ +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(size_t m_idx, size_t k); + +/** + * @brief Function to calculate the offset in bytes for the packed RHS matrix, + * which contains the packed 4-bit quantized symmetric per-channel (qsu4cx) values. + * + * @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of 4. + * @param[in] k The common dimension between the LHS and RHS matrix (K). + * + * return the offset in bytes to the packed RHS matrix + */ +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(size_t n_idx, size_t k); + +/** + * @brief Function to calculate the offset in bytes for the DST matrix + * + * @param[in] m_idx Row index in the DST matrix. It must be a multiple of 8. + * @param[in] n_idx Column index in the DST matrix. It must be multiple of 4. + * @param[in] dst_stride The number of bytes in in each row of the DST matrix + * + * return the DST offset in bytes + */ +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm( + size_t m_idx, size_t n_idx, size_t dst_stride); + +/** + * @brief Function to query the size in bytes for the destination matrix. 
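+ * The returned size is m * n * sizeof(float), i.e. a dense row-major f32 matrix.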
+ *
+ * @param[in] m Number of rows in the destination (DST) matrix.
+ * @param[in] n Number of columns in the destination (DST) matrix.
+ */
+size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(size_t m, size_t n);
+
+/**
+ * @brief Function to run the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation.
+ *
+ * LHS matrix: Signed 8-bit quantized asymmetric per-row (qai8dx) and packed.
+ * RHS matrix: Signed 4-bit quantized symmetric per-channel (qsi4cx) and packed.
+ * Output tile: (rows x cols) = 8 x 4
+ * Accumulation performed in a single for loop: 32 K values per iteration
+ * Instruction used: i8mm
+ *
+ * @param[in] m The number of output rows written.
+ * @param[in] n The number of output columns written.
+ * @param[in] k The number of channels. The common dimension of LHS & RHS.
+ * @param[in] lhs_packed The LHS matrix packed.
+ * When the activations are dynamically quantized, you can obtain this matrix
+ * by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs
+ * both the dynamic quantization to 8-bit and activation packing in a single step.
+ * @param[in] rhs_packed The RHS matrix packed, which is obtained by calling @ref
+ * kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0
+ * @param[out] dst Result of the matrix multiplication.
+ * @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix.
+ * @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(float).
+ * @param[in] scalar_min Min value used to clamp the final result.
+ * @param[in] scalar_max Max value used to clamp the final result.
+ */
+void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(
+    size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row,
+    size_t dst_stride_col, float scalar_min, float scalar_max);
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.c b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.c
new file mode 100644
index 0000000000000000000000000000000000000000..2704c853ab3c39510cab5849c952d1e8ba266882
--- /dev/null
+++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.c
@@ -0,0 +1,369 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.h"
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include "kai_common.h"
+
+static const size_t kai_m_step = 4;
+static const size_t kai_n_step = 8;
+static const size_t kai_mr = 4;
+static const size_t kai_nr = 8;
+static const size_t kai_kr = 16;
+static const size_t kai_sr = 2;
+static const size_t kai_num_bytes_multiplier_lhs = sizeof(float);
+static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);
+static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t);
+static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t);
+
+inline static size_t kai_k_roundedup(size_t k) {
+    // Since we pack a float and int32 value at the end of the row,
+    // we must make sure that k is a multiple of 4 for alignment
+    size_t kr_sr_roundedup4 = kai_roundup(kai_kr * kai_sr, 4);
+    return kai_roundup(k, kr_sr_roundedup4);
+}
+
+inline static size_t kai_lhs_packed_stride(size_t k) {
+    const size_t k_internal = kai_k_roundedup(k);
+
+    KAI_ASSERT((k_internal
% 2) == 0); + + return kai_mr * (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); +} + +inline static size_t kai_rhs_packed_stride(size_t k) { + const size_t k_internal = kai_k_roundedup(k); + + KAI_ASSERT((k_internal % 2) == 0); + + return kai_nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs); +} + +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(void) { + return kai_m_step; +} + +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(void) { + return kai_n_step; +} + +size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(void) { + return kai_mr; +} + +size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(void) { + return kai_nr; +} + +size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(void) { + return kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(size_t m_idx, size_t k) { + KAI_ASSERT((m_idx % kai_m_step) == 0); + + return (m_idx / kai_m_step) * kai_lhs_packed_stride(k); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(size_t n_idx, size_t k) { + KAI_ASSERT((n_idx % kai_n_step) == 0); + + return (n_idx / kai_n_step) * kai_rhs_packed_stride(k); +} + +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSERT((m_idx % kai_m_step) == 0); + KAI_ASSERT((n_idx % kai_n_step) == 0); + + return (n_idx * sizeof(float)) + m_idx * dst_stride; +} + +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(size_t m, size_t n) { + return m * n * sizeof(float); +} + +void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm( + size_t m, size_t n, size_t k, const void* restrict lhs_packed, const void* restrict rhs_packed, float* restrict dst, + size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { +#if defined(__ARM_FEATURE_MATMUL_INT8) + KAI_ASSERT(dst_stride_col == sizeof(float)); + + if (m == 0) { + return; + } + + const size_t k_internal = kai_k_roundedup(k); + + size_t num_blocks = k_internal / 32; + + float clamp_vals[2] = {scalar_min, scalar_max}; + + __asm__ __volatile__( + "mov x28, #0x80\n" + "mov x20, #0x20\n" + "movi v12.16b, #0xf0\n" + "mov x27, %x[m]\n" + "madd x28, %x[num_blocks], x28, x20\n" + "cbz x27, 10f\n" + "1:" // Row loop + "mov x26, %x[rhs_packed]\n" + "mov x25, %x[n]\n" + "add x24, %x[dst], %x[dst_stride_row], LSL #2\n" + "2:" // Column loop + "movi v11.4s, #0x0\n" + "movi v10.4s, #0x0\n" + "mov x21, %x[lhs_packed]\n" + "mov x20, %x[num_blocks]\n" + "movi v9.4s, #0x0\n" + "movi v8.4s, #0x0\n" + "movi v7.4s, #0x0\n" + "movi v6.4s, #0x0\n" + "movi v5.4s, #0x0\n" + "movi v4.4s, #0x0\n" + "3:" // Block loop + "ldr q3, [x26, #0x0]\n" + "ldr q2, [x26, #0x10]\n" + "subs x20, x20, #0x1\n" + "ldr q1, [x26, #0x20]\n" + "ldr q0, [x26, #0x30]\n" + "ldr q31, [x21, #0x0]\n" + "ldr q30, [x21, #0x10]\n" + "ldr q29, [x26, #0x40]\n" + "ldr q28, [x26, #0x50]\n" + "shl v19.16b, v3.16b, #0x4\n" + "shl v18.16b, v2.16b, #0x4\n" + "ldr q27, [x26, #0x60]\n" + "ldr q26, [x26, #0x70]\n" + "shl v17.16b, v1.16b, #0x4\n" + "shl v16.16b, v0.16b, #0x4\n" + "ldr q25, [x21, #0x20]\n" + "ldr q24, [x21, #0x30]\n" + "and v3.16b, v3.16b, v12.16b\n" + "and v2.16b, v2.16b, v12.16b\n" + 
"ldr q23, [x21, #0x40]\n" + "ldr q22, [x21, #0x50]\n" + ".inst 0x4e93a7eb // smmla v11.4s, v31.16b, v19.16b\n" + ".inst 0x4e92a7e9 // smmla v9.4s, v31.16b, v18.16b\n" + "ldr q21, [x21, #0x60]\n" + "ldr q20, [x21, #0x70]\n" + ".inst 0x4e91a7ea // smmla v10.4s, v31.16b, v17.16b\n" + ".inst 0x4e90a7e8 // smmla v8.4s, v31.16b, v16.16b\n" + ".inst 0x4e93a7c7 // smmla v7.4s, v30.16b, v19.16b\n" + ".inst 0x4e92a7c5 // smmla v5.4s, v30.16b, v18.16b\n" + "shl v19.16b, v29.16b, #0x4\n" + "add x21, x21, #0x80\n" + ".inst 0x4e91a7c6 // smmla v6.4s, v30.16b, v17.16b\n" + ".inst 0x4e90a7c4 // smmla v4.4s, v30.16b, v16.16b\n" + "shl v18.16b, v28.16b, #0x4\n" + "add x26, x26, #0x80\n" + "shl v17.16b, v27.16b, #0x4\n" + "shl v16.16b, v26.16b, #0x4\n" + ".inst 0x4e93a72b // smmla v11.4s, v25.16b, v19.16b\n" + "and v1.16b, v1.16b, v12.16b\n" + "and v0.16b, v0.16b, v12.16b\n" + ".inst 0x4e92a729 // smmla v9.4s, v25.16b, v18.16b\n" + ".inst 0x4e93a707 // smmla v7.4s, v24.16b, v19.16b\n" + ".inst 0x4e92a705 // smmla v5.4s, v24.16b, v18.16b\n" + "and v29.16b, v29.16b, v12.16b\n" + ".inst 0x4e91a72a // smmla v10.4s, v25.16b, v17.16b\n" + ".inst 0x4e90a728 // smmla v8.4s, v25.16b, v16.16b\n" + "and v28.16b, v28.16b, v12.16b\n" + ".inst 0x4e91a706 // smmla v6.4s, v24.16b, v17.16b\n" + ".inst 0x4e90a704 // smmla v4.4s, v24.16b, v16.16b\n" + "and v27.16b, v27.16b, v12.16b\n" + ".inst 0x4e83a6eb // smmla v11.4s, v23.16b, v3.16b\n" + ".inst 0x4e82a6e9 // smmla v9.4s, v23.16b, v2.16b\n" + "and v26.16b, v26.16b, v12.16b\n" + ".inst 0x4e83a6c7 // smmla v7.4s, v22.16b, v3.16b\n" + ".inst 0x4e82a6c5 // smmla v5.4s, v22.16b, v2.16b\n" + ".inst 0x4e81a6ea // smmla v10.4s, v23.16b, v1.16b\n" + ".inst 0x4e80a6e8 // smmla v8.4s, v23.16b, v0.16b\n" + ".inst 0x4e81a6c6 // smmla v6.4s, v22.16b, v1.16b\n" + ".inst 0x4e80a6c4 // smmla v4.4s, v22.16b, v0.16b\n" + ".inst 0x4e9da6ab // smmla v11.4s, v21.16b, v29.16b\n" + ".inst 0x4e9ca6a9 // smmla v9.4s, v21.16b, v28.16b\n" + ".inst 0x4e9da687 // smmla v7.4s, v20.16b, v29.16b\n" + ".inst 0x4e9ca685 // smmla v5.4s, v20.16b, v28.16b\n" + ".inst 0x4e9ba6aa // smmla v10.4s, v21.16b, v27.16b\n" + ".inst 0x4e9aa6a8 // smmla v8.4s, v21.16b, v26.16b\n" + ".inst 0x4e9ba686 // smmla v6.4s, v20.16b, v27.16b\n" + ".inst 0x4e9aa684 // smmla v4.4s, v20.16b, v26.16b\n" + "bgt 3b\n" + "ldr q20, [x26, #0x0]\n" + "ldr q19, [x26, #0x10]\n" + "uzp1 v2.2d, v11.2d, v9.2d\n" + "uzp1 v1.2d, v10.2d, v8.2d\n" + "ldr q18, [x21, #0x0]\n" + "ldr q17, [x26, #0x20]\n" + "uzp2 v0.2d, v11.2d, v9.2d\n" + "uzp2 v31.2d, v10.2d, v8.2d\n" + "ldr q30, [x26, #0x30]\n" + "ldr q16, [x21, #0x10]\n" + "uzp1 v29.2d, v7.2d, v5.2d\n" + "uzp1 v28.2d, v6.2d, v4.2d\n" + "ld1r { v27.4s }, [%x[clamp_vals]]\n" + "uzp2 v26.2d, v7.2d, v5.2d\n" + "uzp2 v25.2d, v6.2d, v4.2d\n" + "add x20, %x[clamp_vals], #0x4\n" + "ld1r { v24.4s }, [x20]\n" + "mla v2.4s, v20.4s, v18.s[0]\n" + "mla v1.4s, v19.4s, v18.s[0]\n" + "cmp x25, #0x8\n" + "mla v0.4s, v20.4s, v18.s[1]\n" + "mla v31.4s, v19.4s, v18.s[1]\n" + "fmul v23.4s, v17.4s, v16.s[0]\n" + "add x26, x26, #0x40\n" + "mla v29.4s, v20.4s, v18.s[2]\n" + "mla v28.4s, v19.4s, v18.s[2]\n" + "fmul v22.4s, v30.4s, v16.s[0]\n" + "mla v26.4s, v20.4s, v18.s[3]\n" + "mla v25.4s, v19.4s, v18.s[3]\n" + "fmul v21.4s, v17.4s, v16.s[1]\n" + "scvtf v2.4s, v2.4s\n" + "scvtf v1.4s, v1.4s\n" + "scvtf v0.4s, v0.4s\n" + "scvtf v31.4s, v31.4s\n" + "fmul v20.4s, v30.4s, v16.s[1]\n" + "scvtf v29.4s, v29.4s\n" + "fmul v19.4s, v17.4s, v16.s[2]\n" + "scvtf v28.4s, v28.4s\n" + "fmul v18.4s, v30.4s, v16.s[2]\n" + "scvtf v26.4s, 
v26.4s\n" + "fmul v17.4s, v17.4s, v16.s[3]\n" + "scvtf v25.4s, v25.4s\n" + "fmul v16.4s, v30.4s, v16.s[3]\n" + "fmul v2.4s, v2.4s, v23.4s\n" + "fmul v1.4s, v1.4s, v22.4s\n" + "fmul v0.4s, v0.4s, v21.4s\n" + "fmul v31.4s, v31.4s, v20.4s\n" + "fmul v29.4s, v29.4s, v19.4s\n" + "fmul v28.4s, v28.4s, v18.4s\n" + "fmul v26.4s, v26.4s, v17.4s\n" + "fmul v25.4s, v25.4s, v16.4s\n" + "fmax v2.4s, v2.4s, v27.4s\n" + "fmax v1.4s, v1.4s, v27.4s\n" + "fmax v0.4s, v0.4s, v27.4s\n" + "fmax v31.4s, v31.4s, v27.4s\n" + "fmax v29.4s, v29.4s, v27.4s\n" + "fmax v28.4s, v28.4s, v27.4s\n" + "fmax v26.4s, v26.4s, v27.4s\n" + "fmax v25.4s, v25.4s, v27.4s\n" + "fmin v2.4s, v2.4s, v24.4s\n" + "fmin v1.4s, v1.4s, v24.4s\n" + "fmin v0.4s, v0.4s, v24.4s\n" + "fmin v31.4s, v31.4s, v24.4s\n" + "fmin v29.4s, v29.4s, v24.4s\n" + "fmin v28.4s, v28.4s, v24.4s\n" + "fmin v26.4s, v26.4s, v24.4s\n" + "fmin v25.4s, v25.4s, v24.4s\n" + "bge 8f\n" + "mov x23, %x[dst]\n" + "cmp x27, #0x1\n" + "add x22, x23, %x[dst_stride_row]\n" + "csel x22, x22, x23, GE\n" + "cmp x27, #0x2\n" + "add x21, x23, %x[dst_stride_row], LSL #1\n" + "csel x21, x21, x22, GE\n" + "cmp x27, #0x3\n" + "add x20, x21, %x[dst_stride_row]\n" + "csel x20, x20, x21, GE\n" + "tbz x25, #2, 5f\n" + "st1 { v26.4s }, [x20], #0x10\n" + "st1 { v29.4s }, [x21], #0x10\n" + "st1 { v0.4s }, [x22], #0x10\n" + "st1 { v2.4s }, [x23], #0x10\n" + "tbz x25, #1, 4f\n" + "str d25, [x20], #0x8\n" + "str d28, [x21], #0x8\n" + "str d31, [x22], #0x8\n" + "str d1, [x23], #0x8\n" + "tbz x25, #0, 7f\n" + "st1 { v25.s }[2], [x20]\n" + "st1 { v28.s }[2], [x21]\n" + "st1 { v31.s }[2], [x22]\n" + "st1 { v1.s }[2], [x23]\n" + "b 7f\n" + "4:" // Output block 0: partial_1_4 + "tbz x25, #0, 7f\n" + "str s25, [x20, #0x0]\n" + "str s28, [x21, #0x0]\n" + "str s31, [x22, #0x0]\n" + "str s1, [x23, #0x0]\n" + "b 7f\n" + "5:" // Output block 0: partial_2_0 + "tbz x25, #1, 6f\n" + "str d26, [x20], #0x8\n" + "str d29, [x21], #0x8\n" + "str d0, [x22], #0x8\n" + "str d2, [x23], #0x8\n" + "tbz x25, #0, 7f\n" + "st1 { v26.s }[2], [x20]\n" + "st1 { v29.s }[2], [x21]\n" + "st1 { v0.s }[2], [x22]\n" + "st1 { v2.s }[2], [x23]\n" + "b 7f\n" + "6:" // Output block 0: partial_1_0 + "str s26, [x20, #0x0]\n" + "str s29, [x21, #0x0]\n" + "str s0, [x22, #0x0]\n" + "str s2, [x23, #0x0]\n" + "7:" // Output block 0: Done + "b 9f\n" + "8:" // Full output + "mov x20, %x[dst]\n" + "cmp x27, #0x1\n" + "str q2, [x20, #0x0]\n" + "str q1, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 9f\n" + "cmp x27, #0x2\n" + "str q0, [x20, #0x0]\n" + "str q31, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 9f\n" + "cmp x27, #0x3\n" + "str q29, [x20, #0x0]\n" + "str q28, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 9f\n" + "str q26, [x20, #0x0]\n" + "str q25, [x20, #0x10]\n" + "9:" // Output stage exit + "subs x25, x25, #0x8\n" + "add %x[dst], %x[dst], #0x20\n" + "bgt 2b\n" + "subs x27, x27, #0x4\n" + "add %x[lhs_packed], %x[lhs_packed], x28\n" + "mov %x[dst], x24\n" + "bgt 1b\n" + "10:" // Row loop skip + : [lhs_packed] "+&r"(lhs_packed), [dst] "+&r"(dst) + : [rhs_packed] "r"(rhs_packed), [clamp_vals] "r"(clamp_vals), [m] "r"(m), [num_blocks] "r"(num_blocks), + [dst_stride_row] "r"(dst_stride_row), [n] "r"(n) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v16", "v17", + "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", + "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"); +#else + 
KAI_ASSERT(false);
+    KAI_UNUSED(m);
+    KAI_UNUSED(n);
+    KAI_UNUSED(k);
+    KAI_UNUSED(lhs_packed);
+    KAI_UNUSED(rhs_packed);
+    KAI_UNUSED(dst);
+    KAI_UNUSED(dst_stride_row);
+    KAI_UNUSED(dst_stride_col);
+    KAI_UNUSED(scalar_min);
+    KAI_UNUSED(scalar_max);
+#endif
+}
diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.h b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.h
new file mode 100644
index 0000000000000000000000000000000000000000..6e6363c7afb25c1b4c60b753f65ad2388cfd87d0
--- /dev/null
+++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.h
@@ -0,0 +1,142 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#pragma once
+
+#ifndef __cplusplus
+#include <stdbool.h>
+#endif
+#include <stddef.h>
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/**
+ * @brief Function to get the m step value.
+ * The micro-kernel can process any M values. However, the starting M index to
+ * be processed must be a multiple of m step.
+ *
+ * @return the m step value
+ */
+size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(void);
+
+/**
+ * @brief Function to get the n step value.
+ * The micro-kernel can process any N values. However, the starting N index to
+ * be processed must be a multiple of n step.
+ *
+ * @return the n step value
+ */
+size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(void);
+
+/**
+ * @brief Function to get the mr value, which must be used to pack the LHS matrix with
+ * the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel.
+ *
+ * @return the mr value
+ */
+size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(void);
+
+/**
+ * @brief Function to get the nr value, which must be used to pack the RHS matrix with
+ * the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel.
+ *
+ * @return the nr value
+ */
+size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(void);
+
+/**
+ * @brief Function to get the kr value, which must be used to pack the RHS matrix with
+ * the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel.
+ *
+ * @return the kr value
+ */
+size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(void);
+
+/**
+ * @brief Function to get the sr value, which must be used to pack the RHS matrix with
+ * the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel.
+ *
+ * @return the sr value
+ */
+size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(void);
+
+/**
+ * @brief Function to calculate the offset in bytes for the packed LHS matrix,
+ * which contains the packed 8-bit quantized asymmetric per-row (qai8dx) values.
+ *
+ * This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel.
+ *
+ * @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of 4.
+ * @param[in] k Total number of columns in the LHS matrix (not packed).
+ *
+ * @return the offset in bytes to the packed LHS matrix
+ */
+size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(size_t m_idx, size_t k);
+
+/**
+ * @brief Function to calculate the offset in bytes for the packed RHS matrix,
+ * which contains the packed 4-bit quantized symmetric per-channel (qsi4cx) values.
+ *
+ * @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of 8.
+ * @param[in] k The common dimension between the LHS and RHS matrix (K).
+ *
+ * @return the offset in bytes to the packed RHS matrix
+ */
+size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(size_t n_idx, size_t k);
+
+/**
+ * @brief Function to calculate the offset in bytes for the DST matrix.
+ *
+ * @param[in] m_idx Row index in the DST matrix. It must be a multiple of 4.
+ * @param[in] n_idx Column index in the DST matrix. It must be a multiple of 8.
+ * @param[in] dst_stride The number of bytes in each row of the DST matrix.
+ *
+ * @return the DST offset in bytes
+ */
+size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(
+    size_t m_idx, size_t n_idx, size_t dst_stride);
+
+/**
+ * @brief Function to query the size in bytes for the destination matrix.
+ *
+ * @param[in] m Number of rows in the destination (DST) matrix.
+ * @param[in] n Number of columns in the destination (DST) matrix.
+ */
+size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm(size_t m, size_t n);
+
+/**
+ * @brief Function to run the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation.
+ *
+ * LHS matrix: Signed 8-bit quantized asymmetric per-row (qai8dx) and packed.
+ * RHS matrix: Signed 4-bit quantized symmetric per-channel (qsi4cx) and packed.
+ * Output tile: (rows x cols) = 4 x 8
+ * Accumulation performed in a single for loop: 32 K values per iteration
+ * Instruction used: i8mm
+ *
+ * @param[in] m The number of output rows written.
+ * @param[in] n The number of output columns written.
+ * @param[in] k The number of channels. The common dimension of LHS & RHS.
+ * @param[in] lhs_packed The LHS matrix packed.
+ * When the activations are dynamically quantized, you can obtain this matrix
+ * by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs
+ * both the dynamic quantization to 8-bit and activation packing in a single step.
+ * @param[in] rhs_packed The RHS matrix packed, which is obtained by calling @ref
+ * kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0
+ * @param[out] dst Result of the matrix multiplication.
+ * @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix.
+ * @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(float).
+ * @param[in] scalar_min Min value used to clamp the final result.
+ * @param[in] scalar_max Max value used to clamp the final result.
+ */ +void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm( + size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row, + size_t dst_stride_col, float scalar_min, float scalar_max); + +#ifdef __cplusplus +} +#endif diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.c b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.c new file mode 100644 index 0000000000000000000000000000000000000000..6a40b814348e4f472ffcbb8f8b9663529e53d4dc --- /dev/null +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.c @@ -0,0 +1,745 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h" + +#include +#include + +#include "kai_common.h" + +static const size_t kai_m_step = 8; +static const size_t kai_n_step = 8; +static const size_t kai_mr = 4; +static const size_t kai_nr = 8; +static const size_t kai_kr = 16; +static const size_t kai_sr = 2; +static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); +static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); +static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); +static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); + +inline static size_t kai_k_roundedup(size_t k) { + // Since we pack a float and int32 value at the end of the row, + // we must make sure that k is a multiple of 4 for alignment + size_t kr_sr_roundedup4 = kai_roundup(kai_kr * kai_sr, 4); + return kai_roundup(k, kr_sr_roundedup4); +} + +inline static size_t kai_lhs_packed_stride(size_t k) { + const size_t k_internal = kai_k_roundedup(k); + + KAI_ASSERT((k_internal % 2) == 0); + + return kai_mr * (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); +} + +inline static size_t kai_rhs_packed_stride(size_t k) { + const size_t k_internal = kai_k_roundedup(k); + + KAI_ASSERT((k_internal % 2) == 0); + + return kai_nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs); +} + +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm(void) { + return kai_m_step; +} + +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm(void) { + return kai_n_step; +} + +size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm(void) { + return kai_mr; +} + +size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm(void) { + return kai_nr; +} + +size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm(void) { + return kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm(size_t m_idx, size_t k) { + KAI_ASSERT((m_idx % kai_m_step) == 0); + + return (m_idx / kai_m_step) * kai_lhs_packed_stride(k); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm(size_t n_idx, size_t k) { + KAI_ASSERT((n_idx % kai_n_step) == 0); + + return (n_idx / kai_n_step) * kai_rhs_packed_stride(k); +} + +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSERT((m_idx % kai_m_step) == 0); + 
KAI_ASSERT((n_idx % kai_n_step) == 0); + + return (n_idx * sizeof(float)) + m_idx * dst_stride; +} + +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm(size_t m, size_t n) { + return m * n * sizeof(float); +} + +void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm( + size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row, + size_t dst_stride_col, float scalar_min, float scalar_max) { +#if defined(__ARM_FEATURE_MATMUL_INT8) + KAI_ASSERT(dst_stride_col == sizeof(float)); + + if (m == 0) { + return; + } + + const size_t k_internal = kai_k_roundedup(k); + + size_t num_blocks = k_internal / 32; + + float clamp_vals[2] = {scalar_min, scalar_max}; + + __asm__ __volatile__( + "mov x12, %x[m]\n" + "mov x11, #0x80\n" + "movi v3.16b, #0xf0\n" + "mov x20, #0x20\n" + "cmp x12, #0x8\n" + "madd x11, %x[num_blocks], x11, x20\n" + "blt 10f\n" + "1:" // Row loop + "mov x10, %x[rhs_packed]\n" + "mov x9, %x[n]\n" + "add x28, %x[dst], %x[dst_stride_row], LSL #3\n" + "2:" // Column loop + "mov x22, %x[lhs_packed]\n" + "movi v13.4s, #0x0\n" + "movi v14.4s, #0x0\n" + "mov x21, %x[num_blocks]\n" + "movi v25.4s, #0x0\n" + "movi v16.4s, #0x0\n" + "movi v26.4s, #0x0\n" + "movi v30.4s, #0x0\n" + "movi v10.4s, #0x0\n" + "movi v19.4s, #0x0\n" + "add x20, x22, x11\n" + "movi v24.4s, #0x0\n" + "movi v0.4s, #0x0\n" + "movi v28.4s, #0x0\n" + "movi v15.4s, #0x0\n" + "movi v29.4s, #0x0\n" + "movi v11.4s, #0x0\n" + "movi v31.4s, #0x0\n" + "movi v7.4s, #0x0\n" + "3:" // Block loop + "ldr q21, [x10, #0x0]\n" + "ldr q20, [x10, #0x10]\n" + "subs x21, x21, #0x1\n" + "ldr q2, [x10, #0x20]\n" + "ldr q23, [x10, #0x30]\n" + "ldr q8, [x22, #0x0]\n" + "ldr q1, [x22, #0x10]\n" + "ldr q12, [x20, #0x0]\n" + "ldr q6, [x20, #0x10]\n" + "shl v17.16b, v21.16b, #0x4\n" + "shl v22.16b, v20.16b, #0x4\n" + "ldr q9, [x10, #0x40]\n" + "ldr q18, [x10, #0x50]\n" + "shl v4.16b, v2.16b, #0x4\n" + "shl v5.16b, v23.16b, #0x4\n" + "ldr q27, [x10, #0x60]\n" + "and v21.16b, v21.16b, v3.16b\n" + "and v20.16b, v20.16b, v3.16b\n" + ".inst 0x4e91a50d // smmla v13.4s, v8.16b, v17.16b\n" + ".inst 0x4e96a519 // smmla v25.4s, v8.16b, v22.16b\n" + ".inst 0x4e91a43a // smmla v26.4s, v1.16b, v17.16b\n" + "and v2.16b, v2.16b, v3.16b\n" + ".inst 0x4e84a50e // smmla v14.4s, v8.16b, v4.16b\n" + ".inst 0x4e85a510 // smmla v16.4s, v8.16b, v5.16b\n" + "ldr q8, [x10, #0x70]\n" + "and v23.16b, v23.16b, v3.16b\n" + ".inst 0x4e96a42a // smmla v10.4s, v1.16b, v22.16b\n" + ".inst 0x4e84a43e // smmla v30.4s, v1.16b, v4.16b\n" + "add x10, x10, #0x80\n" + ".inst 0x4e85a433 // smmla v19.4s, v1.16b, v5.16b\n" + "ldr q1, [x22, #0x20]\n" + ".inst 0x4e91a598 // smmla v24.4s, v12.16b, v17.16b\n" + ".inst 0x4e96a59c // smmla v28.4s, v12.16b, v22.16b\n" + ".inst 0x4e84a580 // smmla v0.4s, v12.16b, v4.16b\n" + ".inst 0x4e85a58f // smmla v15.4s, v12.16b, v5.16b\n" + "ldr q12, [x22, #0x30]\n" + ".inst 0x4e91a4dd // smmla v29.4s, v6.16b, v17.16b\n" + "ldr q17, [x20, #0x20]\n" + ".inst 0x4e96a4df // smmla v31.4s, v6.16b, v22.16b\n" + "ldr q22, [x20, #0x30]\n" + ".inst 0x4e84a4cb // smmla v11.4s, v6.16b, v4.16b\n" + "ldr q4, [x22, #0x40]\n" + ".inst 0x4e85a4c7 // smmla v7.4s, v6.16b, v5.16b\n" + "ldr q5, [x22, #0x50]\n" + "shl v6.16b, v9.16b, #0x4\n" + "and v9.16b, v9.16b, v3.16b\n" + ".inst 0x4e86a42d // smmla v13.4s, v1.16b, v6.16b\n" + ".inst 0x4e86a59a // smmla v26.4s, v12.16b, v6.16b\n" + ".inst 0x4e86a638 // smmla v24.4s, v17.16b, v6.16b\n" + ".inst 0x4e86a6dd // smmla v29.4s, v22.16b, 
v6.16b\n" + "shl v6.16b, v18.16b, #0x4\n" + "and v18.16b, v18.16b, v3.16b\n" + ".inst 0x4e86a439 // smmla v25.4s, v1.16b, v6.16b\n" + ".inst 0x4e86a58a // smmla v10.4s, v12.16b, v6.16b\n" + ".inst 0x4e86a63c // smmla v28.4s, v17.16b, v6.16b\n" + ".inst 0x4e86a6df // smmla v31.4s, v22.16b, v6.16b\n" + "shl v6.16b, v27.16b, #0x4\n" + ".inst 0x4e95a48d // smmla v13.4s, v4.16b, v21.16b\n" + ".inst 0x4e95a4ba // smmla v26.4s, v5.16b, v21.16b\n" + "and v27.16b, v27.16b, v3.16b\n" + ".inst 0x4e86a42e // smmla v14.4s, v1.16b, v6.16b\n" + ".inst 0x4e86a59e // smmla v30.4s, v12.16b, v6.16b\n" + ".inst 0x4e86a620 // smmla v0.4s, v17.16b, v6.16b\n" + ".inst 0x4e86a6cb // smmla v11.4s, v22.16b, v6.16b\n" + "shl v6.16b, v8.16b, #0x4\n" + ".inst 0x4e94a499 // smmla v25.4s, v4.16b, v20.16b\n" + ".inst 0x4e94a4aa // smmla v10.4s, v5.16b, v20.16b\n" + "and v8.16b, v8.16b, v3.16b\n" + ".inst 0x4e86a430 // smmla v16.4s, v1.16b, v6.16b\n" + "ldr q1, [x20, #0x40]\n" + ".inst 0x4e86a593 // smmla v19.4s, v12.16b, v6.16b\n" + "ldr q12, [x20, #0x50]\n" + ".inst 0x4e86a62f // smmla v15.4s, v17.16b, v6.16b\n" + "ldr q17, [x22, #0x60]\n" + ".inst 0x4e86a6c7 // smmla v7.4s, v22.16b, v6.16b\n" + "ldr q22, [x22, #0x70]\n" + "ldr q6, [x20, #0x60]\n" + ".inst 0x4e82a48e // smmla v14.4s, v4.16b, v2.16b\n" + ".inst 0x4e82a4be // smmla v30.4s, v5.16b, v2.16b\n" + "add x22, x22, #0x80\n" + ".inst 0x4e95a438 // smmla v24.4s, v1.16b, v21.16b\n" + ".inst 0x4e94a43c // smmla v28.4s, v1.16b, v20.16b\n" + ".inst 0x4e97a490 // smmla v16.4s, v4.16b, v23.16b\n" + "ldr q4, [x20, #0x70]\n" + ".inst 0x4e97a4b3 // smmla v19.4s, v5.16b, v23.16b\n" + "add x20, x20, #0x80\n" + ".inst 0x4e82a420 // smmla v0.4s, v1.16b, v2.16b\n" + ".inst 0x4e97a42f // smmla v15.4s, v1.16b, v23.16b\n" + ".inst 0x4e95a59d // smmla v29.4s, v12.16b, v21.16b\n" + ".inst 0x4e94a59f // smmla v31.4s, v12.16b, v20.16b\n" + ".inst 0x4e82a58b // smmla v11.4s, v12.16b, v2.16b\n" + ".inst 0x4e97a587 // smmla v7.4s, v12.16b, v23.16b\n" + ".inst 0x4e89a62d // smmla v13.4s, v17.16b, v9.16b\n" + ".inst 0x4e92a639 // smmla v25.4s, v17.16b, v18.16b\n" + ".inst 0x4e9ba62e // smmla v14.4s, v17.16b, v27.16b\n" + ".inst 0x4e88a630 // smmla v16.4s, v17.16b, v8.16b\n" + ".inst 0x4e89a6da // smmla v26.4s, v22.16b, v9.16b\n" + ".inst 0x4e92a6ca // smmla v10.4s, v22.16b, v18.16b\n" + ".inst 0x4e9ba6de // smmla v30.4s, v22.16b, v27.16b\n" + ".inst 0x4e88a6d3 // smmla v19.4s, v22.16b, v8.16b\n" + ".inst 0x4e89a4d8 // smmla v24.4s, v6.16b, v9.16b\n" + ".inst 0x4e92a4dc // smmla v28.4s, v6.16b, v18.16b\n" + ".inst 0x4e9ba4c0 // smmla v0.4s, v6.16b, v27.16b\n" + ".inst 0x4e88a4cf // smmla v15.4s, v6.16b, v8.16b\n" + ".inst 0x4e89a49d // smmla v29.4s, v4.16b, v9.16b\n" + ".inst 0x4e92a49f // smmla v31.4s, v4.16b, v18.16b\n" + ".inst 0x4e9ba48b // smmla v11.4s, v4.16b, v27.16b\n" + ".inst 0x4e88a487 // smmla v7.4s, v4.16b, v8.16b\n" + "bgt 3b\n" + "ldr q18, [x10, #0x0]\n" + "ldr q2, [x10, #0x10]\n" + "uzp1 v4.2d, v13.2d, v25.2d\n" + "uzp1 v5.2d, v14.2d, v16.2d\n" + "ldr q22, [x22, #0x0]\n" + "ldr q27, [x20, #0x0]\n" + "uzp2 v1.2d, v13.2d, v25.2d\n" + "uzp2 v20.2d, v14.2d, v16.2d\n" + "ldr q17, [x10, #0x20]\n" + "ldr q6, [x10, #0x30]\n" + "uzp1 v9.2d, v26.2d, v10.2d\n" + "uzp1 v13.2d, v30.2d, v19.2d\n" + "ldr q23, [x22, #0x10]\n" + "ldr q12, [x20, #0x10]\n" + "uzp2 v21.2d, v26.2d, v10.2d\n" + "uzp2 v25.2d, v30.2d, v19.2d\n" + "ld1r { v8.4s }, [%x[clamp_vals]]\n" + "uzp1 v16.2d, v24.2d, v28.2d\n" + "uzp1 v10.2d, v0.2d, v15.2d\n" + "mla v4.4s, v18.4s, v22.s[0]\n" + "uzp2 v30.2d, v24.2d, v28.2d\n" 
+ "uzp2 v28.2d, v0.2d, v15.2d\n" + "mla v5.4s, v2.4s, v22.s[0]\n" + "add x20, %x[clamp_vals], #0x4\n" + "ld1r { v24.4s }, [x20]\n" + "uzp1 v14.2d, v29.2d, v31.2d\n" + "uzp1 v26.2d, v11.2d, v7.2d\n" + "mla v1.4s, v18.4s, v22.s[1]\n" + "uzp2 v0.2d, v29.2d, v31.2d\n" + "uzp2 v11.2d, v11.2d, v7.2d\n" + "mla v20.4s, v2.4s, v22.s[1]\n" + "cmp x9, #0x8\n" + "mla v9.4s, v18.4s, v22.s[2]\n" + "mla v13.4s, v2.4s, v22.s[2]\n" + "scvtf v4.4s, v4.4s\n" + "add x10, x10, #0x40\n" + "mla v21.4s, v18.4s, v22.s[3]\n" + "mla v25.4s, v2.4s, v22.s[3]\n" + "fmul v19.4s, v17.4s, v23.s[0]\n" + "mla v16.4s, v18.4s, v27.s[0]\n" + "mla v10.4s, v2.4s, v27.s[0]\n" + "scvtf v5.4s, v5.4s\n" + "mla v30.4s, v18.4s, v27.s[1]\n" + "mla v28.4s, v2.4s, v27.s[1]\n" + "fmul v15.4s, v6.4s, v23.s[0]\n" + "mla v14.4s, v18.4s, v27.s[2]\n" + "mla v26.4s, v2.4s, v27.s[2]\n" + "scvtf v1.4s, v1.4s\n" + "mla v0.4s, v18.4s, v27.s[3]\n" + "mla v11.4s, v2.4s, v27.s[3]\n" + "fmul v22.4s, v17.4s, v23.s[1]\n" + "scvtf v20.4s, v20.4s\n" + "fmul v29.4s, v6.4s, v23.s[1]\n" + "scvtf v9.4s, v9.4s\n" + "fmul v2.4s, v17.4s, v23.s[2]\n" + "scvtf v13.4s, v13.4s\n" + "fmul v18.4s, v6.4s, v23.s[2]\n" + "scvtf v21.4s, v21.4s\n" + "fmul v31.4s, v17.4s, v23.s[3]\n" + "scvtf v25.4s, v25.4s\n" + "fmul v7.4s, v6.4s, v23.s[3]\n" + "scvtf v16.4s, v16.4s\n" + "fmul v27.4s, v17.4s, v12.s[0]\n" + "scvtf v10.4s, v10.4s\n" + "fmul v23.4s, v6.4s, v12.s[0]\n" + "scvtf v30.4s, v30.4s\n" + "scvtf v28.4s, v28.4s\n" + "scvtf v14.4s, v14.4s\n" + "scvtf v26.4s, v26.4s\n" + "scvtf v0.4s, v0.4s\n" + "scvtf v11.4s, v11.4s\n" + "fmul v4.4s, v4.4s, v19.4s\n" + "fmul v19.4s, v17.4s, v12.s[1]\n" + "fmul v5.4s, v5.4s, v15.4s\n" + "fmul v15.4s, v6.4s, v12.s[1]\n" + "fmul v1.4s, v1.4s, v22.4s\n" + "fmul v22.4s, v17.4s, v12.s[2]\n" + "fmul v17.4s, v17.4s, v12.s[3]\n" + "fmul v20.4s, v20.4s, v29.4s\n" + "fmul v29.4s, v6.4s, v12.s[2]\n" + "fmul v12.4s, v6.4s, v12.s[3]\n" + "fmul v9.4s, v9.4s, v2.4s\n" + "fmul v13.4s, v13.4s, v18.4s\n" + "fmul v21.4s, v21.4s, v31.4s\n" + "fmul v25.4s, v25.4s, v7.4s\n" + "fmul v16.4s, v16.4s, v27.4s\n" + "fmul v10.4s, v10.4s, v23.4s\n" + "fmul v30.4s, v30.4s, v19.4s\n" + "fmul v28.4s, v28.4s, v15.4s\n" + "fmul v14.4s, v14.4s, v22.4s\n" + "fmul v26.4s, v26.4s, v29.4s\n" + "fmul v0.4s, v0.4s, v17.4s\n" + "fmul v11.4s, v11.4s, v12.4s\n" + "fmax v4.4s, v4.4s, v8.4s\n" + "fmax v5.4s, v5.4s, v8.4s\n" + "fmax v1.4s, v1.4s, v8.4s\n" + "fmax v20.4s, v20.4s, v8.4s\n" + "fmax v9.4s, v9.4s, v8.4s\n" + "fmax v13.4s, v13.4s, v8.4s\n" + "fmax v21.4s, v21.4s, v8.4s\n" + "fmax v25.4s, v25.4s, v8.4s\n" + "fmax v16.4s, v16.4s, v8.4s\n" + "fmax v10.4s, v10.4s, v8.4s\n" + "fmax v30.4s, v30.4s, v8.4s\n" + "fmax v28.4s, v28.4s, v8.4s\n" + "fmax v14.4s, v14.4s, v8.4s\n" + "fmax v26.4s, v26.4s, v8.4s\n" + "fmax v0.4s, v0.4s, v8.4s\n" + "fmax v11.4s, v11.4s, v8.4s\n" + "fmin v4.4s, v4.4s, v24.4s\n" + "fmin v5.4s, v5.4s, v24.4s\n" + "fmin v1.4s, v1.4s, v24.4s\n" + "fmin v20.4s, v20.4s, v24.4s\n" + "fmin v9.4s, v9.4s, v24.4s\n" + "fmin v13.4s, v13.4s, v24.4s\n" + "fmin v21.4s, v21.4s, v24.4s\n" + "fmin v25.4s, v25.4s, v24.4s\n" + "fmin v16.4s, v16.4s, v24.4s\n" + "fmin v10.4s, v10.4s, v24.4s\n" + "fmin v30.4s, v30.4s, v24.4s\n" + "fmin v28.4s, v28.4s, v24.4s\n" + "fmin v14.4s, v14.4s, v24.4s\n" + "fmin v26.4s, v26.4s, v24.4s\n" + "fmin v0.4s, v0.4s, v24.4s\n" + "fmin v11.4s, v11.4s, v24.4s\n" + "bge 8f\n" + "mov x27, %x[dst]\n" + "add x26, x27, %x[dst_stride_row], LSL #2\n" + "add x25, x26, %x[dst_stride_row], LSL #1\n" + "add x24, x26, %x[dst_stride_row]\n" + "add x23, x25, 
%x[dst_stride_row]\n" + "add x22, x27, %x[dst_stride_row], LSL #1\n" + "add x21, x27, %x[dst_stride_row]\n" + "add x20, x22, %x[dst_stride_row]\n" + "tbz x9, #2, 5f\n" + "st1 { v0.4s }, [x23], #0x10\n" + "st1 { v14.4s }, [x25], #0x10\n" + "st1 { v30.4s }, [x24], #0x10\n" + "st1 { v16.4s }, [x26], #0x10\n" + "st1 { v21.4s }, [x20], #0x10\n" + "st1 { v9.4s }, [x22], #0x10\n" + "st1 { v1.4s }, [x21], #0x10\n" + "st1 { v4.4s }, [x27], #0x10\n" + "tbz x9, #1, 4f\n" + "str d11, [x23], #0x8\n" + "str d26, [x25], #0x8\n" + "str d28, [x24], #0x8\n" + "str d10, [x26], #0x8\n" + "str d25, [x20], #0x8\n" + "str d13, [x22], #0x8\n" + "str d20, [x21], #0x8\n" + "str d5, [x27], #0x8\n" + "tbz x9, #0, 7f\n" + "st1 { v11.s }[2], [x23]\n" + "st1 { v26.s }[2], [x25]\n" + "st1 { v28.s }[2], [x24]\n" + "st1 { v10.s }[2], [x26]\n" + "st1 { v25.s }[2], [x20]\n" + "st1 { v13.s }[2], [x22]\n" + "st1 { v20.s }[2], [x21]\n" + "st1 { v5.s }[2], [x27]\n" + "b 7f\n" + "4:" // Output block 0: partial_1_4 + "tbz x9, #0, 7f\n" + "str s11, [x23, #0x0]\n" + "str s26, [x25, #0x0]\n" + "str s28, [x24, #0x0]\n" + "str s10, [x26, #0x0]\n" + "str s25, [x20, #0x0]\n" + "str s13, [x22, #0x0]\n" + "str s20, [x21, #0x0]\n" + "str s5, [x27, #0x0]\n" + "b 7f\n" + "5:" // Output block 0: partial_2_0 + "tbz x9, #1, 6f\n" + "str d0, [x23], #0x8\n" + "str d14, [x25], #0x8\n" + "str d30, [x24], #0x8\n" + "str d16, [x26], #0x8\n" + "str d21, [x20], #0x8\n" + "str d9, [x22], #0x8\n" + "str d1, [x21], #0x8\n" + "str d4, [x27], #0x8\n" + "tbz x9, #0, 7f\n" + "st1 { v0.s }[2], [x23]\n" + "st1 { v14.s }[2], [x25]\n" + "st1 { v30.s }[2], [x24]\n" + "st1 { v16.s }[2], [x26]\n" + "st1 { v21.s }[2], [x20]\n" + "st1 { v9.s }[2], [x22]\n" + "st1 { v1.s }[2], [x21]\n" + "st1 { v4.s }[2], [x27]\n" + "b 7f\n" + "6:" // Output block 0: partial_1_0 + "str s0, [x23, #0x0]\n" + "str s14, [x25, #0x0]\n" + "str s30, [x24, #0x0]\n" + "str s16, [x26, #0x0]\n" + "str s21, [x20, #0x0]\n" + "str s9, [x22, #0x0]\n" + "str s1, [x21, #0x0]\n" + "str s4, [x27, #0x0]\n" + "7:" // Output block 0: Done + "b 9f\n" + "8:" // Full output + "mov x20, %x[dst]\n" + "str q4, [x20, #0x0]\n" + "str q5, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q1, [x20, #0x0]\n" + "str q20, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q9, [x20, #0x0]\n" + "str q13, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q21, [x20, #0x0]\n" + "str q25, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q16, [x20, #0x0]\n" + "str q10, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q30, [x20, #0x0]\n" + "str q28, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q14, [x20, #0x0]\n" + "str q26, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q0, [x20, #0x0]\n" + "str q11, [x20, #0x10]\n" + "9:" // Output stage exit + "subs x9, x9, #0x8\n" + "add %x[dst], %x[dst], #0x20\n" + "bgt 2b\n" + "mov x20, #0x2\n" + "sub x12, x12, #0x8\n" + "cmp x12, #0x8\n" + "mov %x[dst], x28\n" + "madd %x[lhs_packed], x20, x11, %x[lhs_packed]\n" + "bge 1b\n" + "10:" // Row loop skip + "cbz x12, 20f\n" + "11:" // Row tail: Row loop + "mov x26, %x[rhs_packed]\n" + "mov x25, %x[n]\n" + "add x24, %x[dst], %x[dst_stride_row], LSL #2\n" + "12:" // Row tail: Column loop + "movi v13.4s, #0x0\n" + "movi v14.4s, #0x0\n" + "mov x22, %x[lhs_packed]\n" + "mov x20, %x[num_blocks]\n" + "movi v25.4s, #0x0\n" + "movi v16.4s, #0x0\n" + "movi v26.4s, #0x0\n" + "movi v30.4s, #0x0\n" + "movi v10.4s, #0x0\n" + "movi v19.4s, #0x0\n" + "13:" // Row tail: 
Block loop + "ldr q4, [x26, #0x0]\n" + "ldr q8, [x26, #0x10]\n" + "subs x20, x20, #0x1\n" + "ldr q2, [x26, #0x20]\n" + "ldr q11, [x26, #0x30]\n" + "ldr q18, [x22, #0x0]\n" + "ldr q15, [x22, #0x10]\n" + "ldr q12, [x26, #0x40]\n" + "ldr q6, [x26, #0x50]\n" + "shl v9.16b, v4.16b, #0x4\n" + "shl v22.16b, v8.16b, #0x4\n" + "ldr q28, [x26, #0x60]\n" + "ldr q27, [x26, #0x70]\n" + "shl v17.16b, v2.16b, #0x4\n" + "shl v23.16b, v11.16b, #0x4\n" + "ldr q31, [x22, #0x20]\n" + "ldr q7, [x22, #0x30]\n" + "and v4.16b, v4.16b, v3.16b\n" + "and v8.16b, v8.16b, v3.16b\n" + "ldr q24, [x22, #0x40]\n" + "ldr q1, [x22, #0x50]\n" + ".inst 0x4e89a64d // smmla v13.4s, v18.16b, v9.16b\n" + ".inst 0x4e96a659 // smmla v25.4s, v18.16b, v22.16b\n" + "ldr q21, [x22, #0x60]\n" + "ldr q20, [x22, #0x70]\n" + ".inst 0x4e91a64e // smmla v14.4s, v18.16b, v17.16b\n" + ".inst 0x4e97a650 // smmla v16.4s, v18.16b, v23.16b\n" + ".inst 0x4e89a5fa // smmla v26.4s, v15.16b, v9.16b\n" + ".inst 0x4e96a5ea // smmla v10.4s, v15.16b, v22.16b\n" + "shl v22.16b, v12.16b, #0x4\n" + "add x22, x22, #0x80\n" + ".inst 0x4e91a5fe // smmla v30.4s, v15.16b, v17.16b\n" + ".inst 0x4e97a5f3 // smmla v19.4s, v15.16b, v23.16b\n" + "shl v17.16b, v6.16b, #0x4\n" + "add x26, x26, #0x80\n" + "shl v23.16b, v28.16b, #0x4\n" + "shl v5.16b, v27.16b, #0x4\n" + ".inst 0x4e96a7ed // smmla v13.4s, v31.16b, v22.16b\n" + "and v2.16b, v2.16b, v3.16b\n" + "and v11.16b, v11.16b, v3.16b\n" + ".inst 0x4e91a7f9 // smmla v25.4s, v31.16b, v17.16b\n" + ".inst 0x4e96a4fa // smmla v26.4s, v7.16b, v22.16b\n" + ".inst 0x4e91a4ea // smmla v10.4s, v7.16b, v17.16b\n" + "and v12.16b, v12.16b, v3.16b\n" + ".inst 0x4e97a7ee // smmla v14.4s, v31.16b, v23.16b\n" + ".inst 0x4e85a7f0 // smmla v16.4s, v31.16b, v5.16b\n" + "and v6.16b, v6.16b, v3.16b\n" + ".inst 0x4e97a4fe // smmla v30.4s, v7.16b, v23.16b\n" + ".inst 0x4e85a4f3 // smmla v19.4s, v7.16b, v5.16b\n" + "and v28.16b, v28.16b, v3.16b\n" + ".inst 0x4e84a70d // smmla v13.4s, v24.16b, v4.16b\n" + ".inst 0x4e88a719 // smmla v25.4s, v24.16b, v8.16b\n" + "and v27.16b, v27.16b, v3.16b\n" + ".inst 0x4e84a43a // smmla v26.4s, v1.16b, v4.16b\n" + ".inst 0x4e88a42a // smmla v10.4s, v1.16b, v8.16b\n" + ".inst 0x4e82a70e // smmla v14.4s, v24.16b, v2.16b\n" + ".inst 0x4e8ba710 // smmla v16.4s, v24.16b, v11.16b\n" + ".inst 0x4e82a43e // smmla v30.4s, v1.16b, v2.16b\n" + ".inst 0x4e8ba433 // smmla v19.4s, v1.16b, v11.16b\n" + ".inst 0x4e8ca6ad // smmla v13.4s, v21.16b, v12.16b\n" + ".inst 0x4e86a6b9 // smmla v25.4s, v21.16b, v6.16b\n" + ".inst 0x4e8ca69a // smmla v26.4s, v20.16b, v12.16b\n" + ".inst 0x4e86a68a // smmla v10.4s, v20.16b, v6.16b\n" + ".inst 0x4e9ca6ae // smmla v14.4s, v21.16b, v28.16b\n" + ".inst 0x4e9ba6b0 // smmla v16.4s, v21.16b, v27.16b\n" + ".inst 0x4e9ca69e // smmla v30.4s, v20.16b, v28.16b\n" + ".inst 0x4e9ba693 // smmla v19.4s, v20.16b, v27.16b\n" + "bgt 13b\n" + "ldr q5, [x26, #0x0]\n" + "ldr q20, [x26, #0x10]\n" + "uzp1 v2.2d, v13.2d, v25.2d\n" + "uzp1 v21.2d, v14.2d, v16.2d\n" + "ldr q6, [x22, #0x0]\n" + "ldr q1, [x26, #0x20]\n" + "uzp2 v4.2d, v13.2d, v25.2d\n" + "uzp2 v28.2d, v14.2d, v16.2d\n" + "ldr q7, [x26, #0x30]\n" + "ldr q17, [x22, #0x10]\n" + "uzp1 v29.2d, v26.2d, v10.2d\n" + "uzp1 v15.2d, v30.2d, v19.2d\n" + "ld1r { v27.4s }, [%x[clamp_vals]]\n" + "uzp2 v26.2d, v26.2d, v10.2d\n" + "uzp2 v25.2d, v30.2d, v19.2d\n" + "add x20, %x[clamp_vals], #0x4\n" + "ld1r { v19.4s }, [x20]\n" + "mla v2.4s, v5.4s, v6.s[0]\n" + "mla v21.4s, v20.4s, v6.s[0]\n" + "cmp x25, #0x8\n" + "mla v4.4s, v5.4s, v6.s[1]\n" + "mla v28.4s, v20.4s, 
v6.s[1]\n" + "fmul v23.4s, v1.4s, v17.s[0]\n" + "add x26, x26, #0x40\n" + "mla v29.4s, v5.4s, v6.s[2]\n" + "mla v15.4s, v20.4s, v6.s[2]\n" + "fmul v31.4s, v7.4s, v17.s[0]\n" + "mla v26.4s, v5.4s, v6.s[3]\n" + "mla v25.4s, v20.4s, v6.s[3]\n" + "fmul v22.4s, v1.4s, v17.s[1]\n" + "scvtf v2.4s, v2.4s\n" + "scvtf v21.4s, v21.4s\n" + "scvtf v4.4s, v4.4s\n" + "scvtf v28.4s, v28.4s\n" + "fmul v20.4s, v7.4s, v17.s[1]\n" + "scvtf v29.4s, v29.4s\n" + "fmul v24.4s, v1.4s, v17.s[2]\n" + "scvtf v15.4s, v15.4s\n" + "fmul v10.4s, v7.4s, v17.s[2]\n" + "scvtf v26.4s, v26.4s\n" + "fmul v0.4s, v1.4s, v17.s[3]\n" + "scvtf v25.4s, v25.4s\n" + "fmul v8.4s, v7.4s, v17.s[3]\n" + "fmul v2.4s, v2.4s, v23.4s\n" + "fmul v21.4s, v21.4s, v31.4s\n" + "fmul v4.4s, v4.4s, v22.4s\n" + "fmul v28.4s, v28.4s, v20.4s\n" + "fmul v29.4s, v29.4s, v24.4s\n" + "fmul v15.4s, v15.4s, v10.4s\n" + "fmul v26.4s, v26.4s, v0.4s\n" + "fmul v25.4s, v25.4s, v8.4s\n" + "fmax v2.4s, v2.4s, v27.4s\n" + "fmax v21.4s, v21.4s, v27.4s\n" + "fmax v4.4s, v4.4s, v27.4s\n" + "fmax v28.4s, v28.4s, v27.4s\n" + "fmax v29.4s, v29.4s, v27.4s\n" + "fmax v15.4s, v15.4s, v27.4s\n" + "fmax v26.4s, v26.4s, v27.4s\n" + "fmax v25.4s, v25.4s, v27.4s\n" + "fmin v2.4s, v2.4s, v19.4s\n" + "fmin v21.4s, v21.4s, v19.4s\n" + "fmin v4.4s, v4.4s, v19.4s\n" + "fmin v28.4s, v28.4s, v19.4s\n" + "fmin v29.4s, v29.4s, v19.4s\n" + "fmin v15.4s, v15.4s, v19.4s\n" + "fmin v26.4s, v26.4s, v19.4s\n" + "fmin v25.4s, v25.4s, v19.4s\n" + "bge 18f\n" + "mov x23, %x[dst]\n" + "cmp x12, #0x1\n" + "add x22, x23, %x[dst_stride_row]\n" + "csel x22, x22, x23, GE\n" + "cmp x12, #0x2\n" + "add x21, x23, %x[dst_stride_row], LSL #1\n" + "csel x21, x21, x22, GE\n" + "cmp x12, #0x3\n" + "add x20, x21, %x[dst_stride_row]\n" + "csel x20, x20, x21, GE\n" + "tbz x25, #2, 15f\n" + "st1 { v26.4s }, [x20], #0x10\n" + "st1 { v29.4s }, [x21], #0x10\n" + "st1 { v4.4s }, [x22], #0x10\n" + "st1 { v2.4s }, [x23], #0x10\n" + "tbz x25, #1, 14f\n" + "str d25, [x20], #0x8\n" + "str d15, [x21], #0x8\n" + "str d28, [x22], #0x8\n" + "str d21, [x23], #0x8\n" + "tbz x25, #0, 17f\n" + "st1 { v25.s }[2], [x20]\n" + "st1 { v15.s }[2], [x21]\n" + "st1 { v28.s }[2], [x22]\n" + "st1 { v21.s }[2], [x23]\n" + "b 17f\n" + "14:" // Row tail: Output block 0: partial_1_4 + "tbz x25, #0, 17f\n" + "str s25, [x20, #0x0]\n" + "str s15, [x21, #0x0]\n" + "str s28, [x22, #0x0]\n" + "str s21, [x23, #0x0]\n" + "b 17f\n" + "15:" // Row tail: Output block 0: partial_2_0 + "tbz x25, #1, 16f\n" + "str d26, [x20], #0x8\n" + "str d29, [x21], #0x8\n" + "str d4, [x22], #0x8\n" + "str d2, [x23], #0x8\n" + "tbz x25, #0, 17f\n" + "st1 { v26.s }[2], [x20]\n" + "st1 { v29.s }[2], [x21]\n" + "st1 { v4.s }[2], [x22]\n" + "st1 { v2.s }[2], [x23]\n" + "b 17f\n" + "16:" // Row tail: Output block 0: partial_1_0 + "str s26, [x20, #0x0]\n" + "str s29, [x21, #0x0]\n" + "str s4, [x22, #0x0]\n" + "str s2, [x23, #0x0]\n" + "17:" // Row tail: Output block 0: Done + "b 19f\n" + "18:" // Row tail: Full output + "mov x20, %x[dst]\n" + "cmp x12, #0x1\n" + "str q2, [x20, #0x0]\n" + "str q21, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 19f\n" + "cmp x12, #0x2\n" + "str q4, [x20, #0x0]\n" + "str q28, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 19f\n" + "cmp x12, #0x3\n" + "str q29, [x20, #0x0]\n" + "str q15, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 19f\n" + "str q26, [x20, #0x0]\n" + "str q25, [x20, #0x10]\n" + "19:" // Row tail: Output stage exit + "subs x25, x25, #0x8\n" + "add %x[dst], %x[dst], #0x20\n" + "bgt 
12b\n" + "subs x12, x12, #0x4\n" + "add %x[lhs_packed], %x[lhs_packed], x11\n" + "mov %x[dst], x24\n" + "bgt 11b\n" + "20:" // Row tail: Row loop skip + : [lhs_packed] "+&r"(lhs_packed), [dst] "+&r"(dst) + : [rhs_packed] "r"(rhs_packed), [clamp_vals] "r"(clamp_vals), [m] "r"(m), [num_blocks] "r"(num_blocks), + [dst_stride_row] "r"(dst_stride_row), [n] "r"(n) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", + "v30", "v31", "x9", "x10", "x11", "x12", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"); +#else + KAI_ASSERT(false); + KAI_UNUSED(m); + KAI_UNUSED(n); + KAI_UNUSED(k); + KAI_UNUSED(lhs_packed); + KAI_UNUSED(rhs_packed); + KAI_UNUSED(dst); + KAI_UNUSED(dst_stride_row); + KAI_UNUSED(dst_stride_col); + KAI_UNUSED(scalar_min); + KAI_UNUSED(scalar_max); +#endif +} diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h new file mode 100644 index 0000000000000000000000000000000000000000..5d2e2d59c2f970a74074fe4ac1178587edb22f71 --- /dev/null +++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h @@ -0,0 +1,142 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#ifndef __cplusplus +#include +#endif +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * @brief Function to get the m step value. + * The micro-kernel can process any M values. However, the starting M index to + * be processed must be a multiple of m step. + * + * @return the m step value + */ +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm(void); + +/** + * @brief Function to get the n step value. + * The micro-kernel can process any N values. However, the starting N index to + * be processed must be a multiple of n step. 
+ * + * @return the n step + */ +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm(void); + +/** + * @brief Function to get the mr value, which must be used to pack the LHS matrix with + * the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel + * + * @return the mr value + */ +size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm(void); + +/** + * @brief Function to get the nr value, which must be used to pack the RHS matrix with + * the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel + * + * @return the nr value + */ +size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm(void); + +/** + * @brief Function to get the kr value, which must be used to pack the RHS matrix with + * the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel + * + * @return the kr value + */ +size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm(void); + +/** + * @brief Function to get the sr value, which must be used to pack the RHS matrix with + * the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel + * + * @return the sr value + */ +size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm(void); + +/** + * @brief Function to calculate the offset in bytes for the packed LHS matrix, + * which contains the packed 8-bit quantized asymmetric per-row (qa8dx) values. + * + * This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. + * + * @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of 8 + * @param[in] k Total number of columns in the LHS matrix (not packed). + * + * return the offset in bytes to the packed LHS matrix + */ +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm(size_t m_idx, size_t k); + +/** + * @brief Function to calculate the offset in bytes for the packed RHS matrix, + * which contains the packed 4-bit quantized symmetric per-channel (qsu4cx) values. + * + * @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of 8. + * @param[in] k The common dimension between the LHS and RHS matrix (K). + * + * return the offset in bytes to the packed RHS matrix + */ +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm(size_t n_idx, size_t k); + +/** + * @brief Function to calculate the offset in bytes for the DST matrix + * + * @param[in] m_idx Row index in the DST matrix. It must be a multiple of 8. + * @param[in] n_idx Column index in the DST matrix. It must be multiple of 8. + * @param[in] dst_stride The number of bytes in in each row of the DST matrix + * + * return the DST offset in bytes + */ +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm( + size_t m_idx, size_t n_idx, size_t dst_stride); + +/** + * @brief Function to query the size in bytes for the destination matrix. + * + * @param[in] m Number of rows in the destination (DST) matrix. + * @param[in] n Number of columns in the destination (DST) matrix. + */ +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm(size_t m, size_t n); + +/** + * @brief Function to run the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation. + * + * LHS matrix: Signed 8-bit quantized asymmetric per-row (qau8dx) and packed + * RHS matrix: Signed 4-bit quantized symmetric per-channel (qsi4cx) and packed. 
+ * Output tile: (rows x cols) = 8 x 8
+ * Accumulation performed in a single for loop: 32 K values per iteration
+ * Instruction used: i8mm
+ *
+ * @param[in] m The number of output rows written.
+ * @param[in] n The number of output columns written.
+ * @param[in] k The number of channels. The common dimension of LHS & RHS.
+ * @param[in] lhs_packed The LHS matrix packed.
+ * When the activations are dynamically quantized, you can obtain this matrix
+ * by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs
+ * both the dynamic quantization to 8-bit and activation packing in a single step.
+ * @param[in] rhs_packed The RHS matrix packed, which is obtained by calling @ref
+ * kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0
+ * @param[out] dst Result of the matrix multiplication.
+ * @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix.
+ * @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(float).
+ * @param[in] scalar_min Min value used to clamp the final result.
+ * @param[in] scalar_max Max value used to clamp the final result.
+ */
+void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm(
+    size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row,
+    size_t dst_stride_col, float scalar_min, float scalar_max);
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp_qsi4cxp_interface.h b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp_qsi4cxp_interface.h
new file mode 100644
index 0000000000000000000000000000000000000000..d6f2c5c27176ba55503ed3a3ead0a5a2d2791839
--- /dev/null
+++ b/src/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp_qsi4cxp_interface.h
@@ -0,0 +1,52 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#pragma once
+
+#include <stddef.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// All micro-kernel variants of the same type share the same interface.
+// In this case, the micro-kernel type is: matmul_clamp_f32_qai8dxp_qsi4cxp
+
+/** Micro-kernel helper functions ("get" methods) */
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsi4cxp_get_m_step_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsi4cxp_get_n_step_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsi4cxp_get_mr_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsi4cxp_get_nr_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsi4cxp_get_kr_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsi4cxp_get_sr_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsi4cxp_get_lhs_packed_offset_func_t)(size_t m_idx, size_t k);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsi4cxp_get_rhs_packed_offset_func_t)(size_t n_idx, size_t k);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsi4cxp_get_dst_offset_func_t)(
+    size_t m_idx, size_t n_idx, size_t dst_stride);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsi4cxp_get_dst_size_func_t)(size_t m, size_t n);
+
+/** Micro-kernel core function ("run" method) */
+typedef void (*kai_matmul_clamp_f32_qai8dxp_qsi4cxp_run_matmul_func_t)(
+    size_t m, size_t n, size_t k, const void* lhs_p, const void* rhs_p, float* dst, size_t dst_stride_row,
+    size_t dst_stride_col, float scalar_min, float scalar_max);
+
+/** Micro-kernel interface */
+struct kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel {
+    kai_matmul_clamp_f32_qai8dxp_qsi4cxp_get_m_step_func_t get_m_step;
+    kai_matmul_clamp_f32_qai8dxp_qsi4cxp_get_n_step_func_t get_n_step;
+    kai_matmul_clamp_f32_qai8dxp_qsi4cxp_get_mr_func_t get_mr;
+    kai_matmul_clamp_f32_qai8dxp_qsi4cxp_get_nr_func_t get_nr;
+    kai_matmul_clamp_f32_qai8dxp_qsi4cxp_get_kr_func_t get_kr;
+    kai_matmul_clamp_f32_qai8dxp_qsi4cxp_get_sr_func_t get_sr;
+    kai_matmul_clamp_f32_qai8dxp_qsi4cxp_get_lhs_packed_offset_func_t get_lhs_packed_offset;
+    kai_matmul_clamp_f32_qai8dxp_qsi4cxp_get_rhs_packed_offset_func_t get_rhs_packed_offset;
+    kai_matmul_clamp_f32_qai8dxp_qsi4cxp_get_dst_offset_func_t get_dst_offset;
+    kai_matmul_clamp_f32_qai8dxp_qsi4cxp_get_dst_size_func_t get_dst_size;
+    kai_matmul_clamp_f32_qai8dxp_qsi4cxp_run_matmul_func_t run_matmul;
+};
+
+#ifdef __cplusplus
+}
+#endif
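
The interface header added above gives every `matmul_clamp_f32_qai8dxp_qsi4cxp` variant the same function-pointer signatures, so a caller can select one variant (for example, based on whether FEAT_I8MM is available) and drive it through a single table. The snippet below is a minimal usage sketch, not part of the patch: it wires the 4x8x32 i8mm variant declared earlier into the interface struct and issues one call. The helper `run_f32_qai8dxp_qsi4cxp` is a hypothetical name, and the `lhs_packed`/`rhs_packed` buffers are assumed to have been produced beforehand with the packing micro-kernels referenced in the header documentation.

```c
// Sketch only: wiring the interface struct to the 4x8x32 i8mm variant.
// Assumes lhs_packed/rhs_packed were prepared with the packing micro-kernels
// named in the headers and that dst is at least get_dst_size(m, n) bytes.
#include <float.h>
#include <stddef.h>

#include "kai_matmul_clamp_f32_qai8dxp_qsi4cxp_interface.h"
#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.h"

static const struct kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel ukernel = {
    .get_m_step = kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm,
    .get_n_step = kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm,
    .get_mr = kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm,
    .get_nr = kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm,
    .get_kr = kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm,
    .get_sr = kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm,
    .get_lhs_packed_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm,
    .get_rhs_packed_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm,
    .get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm,
    .get_dst_size = kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm,
    .run_matmul = kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm,
};

// Hypothetical helper: computes the whole DST matrix in a single call,
// with the clamp range left wide open (-FLT_MAX, FLT_MAX).
static void run_f32_qai8dxp_qsi4cxp(
    size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst,
    size_t dst_stride_row) {
    // A runtime could instead split the M/N space into m_step x n_step tiles
    // and use the offset helpers to dispatch each tile independently
    // (for example, one tile per thread).
    ukernel.run_matmul(m, n, k, lhs_packed, rhs_packed, dst, dst_stride_row, sizeof(float), -FLT_MAX, FLT_MAX);
}
```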