diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 431edc88ba5750d1e76ca797d04846b01c60ab15..4e84a67d765f7ab4b3167825251ec092236472d0 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -122,6 +122,7 @@ build-examples:
        - matmul_clamp_f16_f16_f16p
        - matmul_clamp_f32_qai8dxp_qsi4cxp
        - matmul_clamp_f32_qsi8d32p_qsi4c32p
+       - matmul_clamp_f32_qai8dxp_qsi4c32p
  script:
    - mkdir -p build/$EXAMPLE
    - cmake -G Ninja -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_BUILD_TYPE=Release -S examples/$EXAMPLE -B build/$EXAMPLE
@@ -143,6 +144,7 @@ test-examples:
        - matmul_clamp_f16_f16_f16p
        - matmul_clamp_f32_qai8dxp_qsi4cxp
        - matmul_clamp_f32_qsi8d32p_qsi4c32p
+       - matmul_clamp_f32_qai8dxp_qsi4c32p
  script:
    - build/${EXAMPLE}/${EXAMPLE} | tee -a ${EXAMPLE}.log
  artifacts:
diff --git a/CHANGELOG.md b/CHANGELOG.md
index af6db5b78ee84f4647809e0292908ffccb2cd3bc..5428c5421865b159eb9127f033ebdfb9ce4a9143 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,9 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo
 ## v0.3.0 - Upcoming Release
 - Advanced SIMD FP32 GEMM and GEMV micro kernels
+- Micro-kernels to compute the matrix multiplication of a dynamically quantized asymmetric signed 8-bit integer with per-row quantization (QAI8DX) LHS and a quantized symmetric 4-bit signed integer with per-block quantization (QSI4C32) RHS. The destination matrix data type is single-precision floating-point (F32). The micro-kernels have been optimized using the Arm® CPU feature FEAT_I8MM for the matrix-by-matrix cases and FEAT_DotProd for the vector-by-matrix cases.
+- RHS matrix packing micro-kernels to pack the RHS matrix holding the QSI4C32 values.
+- Unit tests and an example for the new integer micro-kernels.
 
 ## v0.2.0
diff --git a/CMakeLists.txt b/CMakeLists.txt
index c0d42de619b5c80eb0abb7c20e9390e1b8bba93a..3e1986655afe15e2c5251f11e672c6c57cd19425 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -76,9 +76,11 @@ endif()
 set(KLEIDIAI_FILES_SCALAR
     kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.c
     kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c
+    kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c
     kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.c
     kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c
     kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c
+    kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c
 )
 
 set(KLEIDIAI_FILES_NEON_FP16
@@ -97,6 +99,8 @@ set(KLEIDIAI_FILES_NEON_DOTPROD
     kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c
     kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.c
     kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c
+    kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c
+    kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c
 )
 
 set(KLEIDIAI_FILES_NEON_I8MM
@@ -105,6 +109,8 @@
     kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.c
     kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.c
     kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c
+    kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c
+    kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c
 )
 
 set(KLEIDIAI_FILES_SME
@@ -194,6 +200,7 @@ if(KLEIDIAI_BUILD_TESTS)
         test/tests/matmul_test.cpp
         test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp
         test/tests/matmul_clamp_f32_qsi8d32p_qsi4c32p_test.cpp
+        test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp
     )
 
     target_link_libraries(kleidiai_test
diff --git a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/CMakeLists.txt b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..47527352650601b1bdd5f9de44e938f29d4879c6
--- /dev/null
+++ b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/CMakeLists.txt
@@ -0,0 +1,33 @@
+#
+# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+
+cmake_minimum_required(VERSION 3.16)
+
+set(CMAKE_CXX_STANDARD 17)
+set(KLEIDIAI_PATH ../../)
+set(MATMUL_PACK_PATH ${KLEIDIAI_PATH}/kai/ukernels/matmul/pack/)
+set(MATMUL_PATH ${KLEIDIAI_PATH}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/)
+
+# KleidiAI include directories
+include_directories(
+    ${KLEIDIAI_PATH}
+    ${MATMUL_PACK_PATH}
+    ${MATMUL_PATH})
+
+# Files required to build the executable
+add_executable(matmul_clamp_f32_qai8dxp_qsi4c32p
+    matmul_clamp_f32_qai8dxp_qsi4c32p.cpp
+    ${KLEIDIAI_PATH}/kai/kai_common.h
+    ${MATMUL_PACK_PATH}/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c
+    ${MATMUL_PACK_PATH}/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c
+    ${MATMUL_PACK_PATH}/kai_lhs_quant_pack_qai8dxp_f32.c
+    ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c
+    ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c
+    ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c
+    ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c)
+
+target_compile_options(matmul_clamp_f32_qai8dxp_qsi4c32p
+    PRIVATE -march=armv8.2-a+dotprod+i8mm)
diff --git a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..87bdad8f8ffc74e7929a0180f8c4dafb70419766
--- /dev/null
+++ b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp
@@ -0,0 +1,727 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#if !defined(__ARM_FEATURE_DOTPROD) && !defined(__ARM_FEATURE_MATMUL_INT8)
+#error "Dotprod and I8mm extensions required to compile this example"
+#else
+#include <algorithm>
+#include <cfloat>
+#include <chrono>
+#include <cmath>
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <iostream>
+#include <string>
+
+// Include micro-kernel variants
+#include "kai_lhs_quant_pack_qai8dxp_f32.h"
+#include "kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
+#include "kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h"
+#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h"
+#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h"
+#include "kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h"
+#include "kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h"
+#include "kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h"
+
+#define INT4_MIN (-8)
+#define INT4_MAX (7)
+
+enum class rhs_format {
+    nxk,
+    kxn,
+};
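+
+// Illustrative helpers (sketch only; the micro-kernels and the code below do
+// not call these). They document the qsu4 s1s0 byte layout used throughout
+// this example: two 4-bit values per byte, with the value at the even index
+// in the low nibble and the value at the odd index in the high nibble, each
+// biased by +8 into the unsigned 0..15 range.
+static inline uint8_t pack_qsu4_s1s0(uint8_t even_u4, uint8_t odd_u4) {
+    // Low nibble <- value at the even index, high nibble <- odd index.
+    return (uint8_t)((even_u4 & 0x0F) | ((odd_u4 & 0x0F) << 4));
+}
+static inline int32_t unpack_qsi4_s1s0(uint8_t byte, size_t idx) {
+    // Select the nibble by index parity and remove the +8 bias.
+    return (int32_t)((idx % 2) == 0 ? (byte & 0x0F) : (byte >> 4)) - 8;
+}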
+ +// Micro-kernel interface +struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p { + kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel ukernel; + std::string name = {}; +}; + +kai_matmul_ukernel_f32_qa8dxp_qs4c32p ukernel_variants[] = { + {{kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod}, + "matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod"}, + {{kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod}, + "matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod"}, + {{kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm}, + "matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm"}, + {{kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + 
kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm}, + "matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm"}, +}; + +// Number of micro-kernel variants stored in the array +const size_t num_ukernel_variants = sizeof(ukernel_variants) / sizeof(ukernel_variants[0]); + +static size_t roundup(size_t a, size_t b) { + return ((a + b - 1) / b) * b; +} + +static inline size_t get_num_blocks_per_row(size_t k, size_t bl) { + return roundup(k, bl) / bl; +} + +static inline size_t get_rhs_native_stride(size_t x) { + return roundup(x, 2) / 2; +} + +static inline size_t get_rhs_scale_stride(size_t k, size_t bl) { + const size_t num_blocks_per_row = get_num_blocks_per_row(k, bl); + return num_blocks_per_row * sizeof(uint16_t); +} + +static void fill_uniform_random(size_t num_rows, size_t num_cols, float* dst, size_t seed) { + std::srand(seed); + + // Fill the array with random values between -1 and 1 + for (size_t i = 0; i < num_rows * num_cols; i++) { + dst[i] = (float)((double)std::rand() / RAND_MAX) * 2 - 1; + } +} + +static void quant_nxk_qs4c32_f32( + size_t n, size_t k, size_t bl, const float* rhs_f32, uint8_t* rhs_qs4c32, uint16_t* rhs_scales_bf16) { + const size_t num_blocks_row = get_num_blocks_per_row(k, bl); + const size_t rhs_qs4c32_stride = get_rhs_native_stride(k); + + // Make sure the output is filled with zeros + std::memset(rhs_qs4c32, 0, n * rhs_qs4c32_stride); + + for (size_t row_idx = 0; row_idx < n; ++row_idx) { + const float* src_ptr = rhs_f32 + row_idx * k; + + for (size_t block_idx = 0; block_idx < num_blocks_row; ++block_idx) { + float amax = 0.0f; + float max = 0.0f; + + for (size_t b = 0; b < bl; ++b) { + const size_t k_idx = block_idx * bl + b; + + if (k_idx >= k) { + break; + } + + const float src0_0 = src_ptr[k_idx]; + const float asrc0_0 = fabsf(src0_0); + + if (amax < asrc0_0) { + amax = asrc0_0; + max = src0_0; + } + } + + const float scale = max / -8.0; + const float recip_scale = scale ? 
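+            // Note: dividing by -8.0 maps the element with the largest magnitude
+            // onto -8, the int4 minimum, so every quantized value in the block
+            // lands in [-8, 7]; an all-zero block yields scale == 0, which this
+            // ternary guards against.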
1.0f / scale : 0.0f; + + // Store the scale in the dedicated buffer + *rhs_scales_bf16 = kai_cast_bf16_f32(scale); + + rhs_scales_bf16 += 1; + + for (size_t i = 0; i < bl; ++i) { + const size_t k_idx = block_idx * bl + i; + + if (k_idx >= k) { + break; + } + + const float src0_0 = src_ptr[k_idx]; + + // Scale the values + int32_t v0_s32 = (int32_t)(round(src0_0 * recip_scale)); + + // Maximum/minimum int4 values + v0_s32 = std::max(v0_s32, INT4_MIN); + v0_s32 = std::min(v0_s32, INT4_MAX); + + const uint8_t v0_u8 = (uint8_t)(v0_s32 + 8); + + const size_t dst_addr = (k_idx / 2) + row_idx * rhs_qs4c32_stride; + uint8_t rhs_v0 = rhs_qs4c32[dst_addr]; + + if ((k_idx % 2) == 0) { + rhs_v0 = v0_u8; + } else { + rhs_v0 |= (v0_u8 << 4); + } + + rhs_qs4c32[dst_addr] = rhs_v0; + } + } + } +} + +static void quant_kxn_qs4c32_f32( + size_t n, size_t k, size_t bl, const float* rhs_f32, uint8_t* rhs_qs4c32, uint16_t* rhs_scales_bf16) { + const size_t num_blocks_row = get_num_blocks_per_row(k, bl); + const size_t rhs_qs4c32_stride = get_rhs_native_stride(n); + + // Make sure the output is filled with zeros + std::memset(rhs_qs4c32, 0, k * rhs_qs4c32_stride); + + for (size_t row_idx = 0; row_idx < n; ++row_idx) { + const float* src_ptr = rhs_f32 + row_idx * k; + + for (size_t block_idx = 0; block_idx < num_blocks_row; ++block_idx) { + float amax = 0.0f; + float max = 0.0f; + + for (size_t b = 0; b < bl; ++b) { + const size_t k_idx = block_idx * bl + b; + + if (k_idx >= k) { + break; + } + + const float src0_0 = src_ptr[k_idx]; + const float asrc0_0 = fabsf(src0_0); + + if (amax < asrc0_0) { + amax = asrc0_0; + max = src0_0; + } + } + + const float scale = max / -8.0; + const float recip_scale = scale ? 1.0f / scale : 0.0f; + + // Store the scale in the dedicated buffer + *rhs_scales_bf16 = kai_cast_bf16_f32(scale); + + rhs_scales_bf16 += 1; + + for (size_t i = 0; i < bl; ++i) { + const size_t k_idx = block_idx * bl + i; + + if (k_idx >= k) { + break; + } + + const float src0_0 = src_ptr[k_idx]; + + // Scale the values + int32_t v0_s32 = (int32_t)(round(src0_0 * recip_scale)); + + // Maximum/minimum int4 values + v0_s32 = std::max(v0_s32, INT4_MIN); + v0_s32 = std::min(v0_s32, INT4_MAX); + + const uint8_t v0_u8 = (uint8_t)(v0_s32 + 8); + + const size_t dst_addr = (row_idx / 2) + k_idx * rhs_qs4c32_stride; + uint8_t rhs_v0 = rhs_qs4c32[dst_addr]; + + if ((row_idx % 2) == 0) { + rhs_v0 = v0_u8; + } else { + rhs_v0 |= (v0_u8 << 4); + } + + rhs_qs4c32[dst_addr] = rhs_v0; + } + } + } +} + +static void quant_qs4cx_f32( + size_t n, size_t k, size_t bl, rhs_format format, const float* rhs_f32, uint8_t* rhs_qs4c32, + uint16_t* rhs_scales_bf16) { + if (rhs_format::nxk == format) { + quant_nxk_qs4c32_f32(n, k, bl, rhs_f32, rhs_qs4c32, rhs_scales_bf16); + } else { + quant_kxn_qs4c32_f32(n, k, bl, rhs_f32, rhs_qs4c32, rhs_scales_bf16); + } +}; + +static void ref_quant_qa8dx_f32(size_t m, size_t k, const float* lhs_f32, int8_t* lhs_qa8dx) { + const size_t dst_stride = (k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t)); + + for (size_t row_idx = 0; row_idx < m; ++row_idx) { + const float* src_ptr = lhs_f32 + row_idx * k; + + float max0 = -FLT_MAX; + float min0 = FLT_MAX; + + // Find min/max for each channel + for (size_t k_idx = 0; k_idx < k; ++k_idx) { + const float src0_0 = src_ptr[k_idx]; + + max0 = std::max(src0_0, max0); + min0 = std::min(src0_0, min0); + } + + // Maximum/minimum int8 values + const float qmin = (float)INT8_MIN; + const float qmax = (float)INT8_MAX; + + const float rmin0 = std::min(0.0f, min0); + 
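+        // The representable range is widened to include 0.0f so that zero stays
+        // exactly representable after quantization. In exact arithmetic the two
+        // zero-point candidates derived from the min and max ends agree; the
+        // code below keeps the one with the smaller rounding error and clamps
+        // it to [qmin, qmax].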
const float rmax0 = std::max(0.0f, max0); + + const float scale0 = rmin0 == rmax0 ? 1.f : (qmax - qmin) / (rmax0 - rmin0); + + // Reciprocal to quantize + const float recip_scale0 = scale0 ? 1.0f / scale0 : 0.0f; + + const float descaled_min0 = rmin0 * scale0; + const float descaled_max0 = rmax0 * scale0; + + const float zero_point_from_min_error0 = qmin + descaled_min0; + const float zero_point_from_max_error0 = qmax + descaled_max0; + + float zero_point0 = + zero_point_from_min_error0 + zero_point_from_max_error0 > 0 ? qmin - descaled_min0 : qmax - descaled_max0; + + zero_point0 = std::max(zero_point0, qmin); + zero_point0 = std::min(zero_point0, qmax); + + // Round to nearest integer + const int32_t nudged_zero_point0 = lrintf(zero_point0); + + int8_t* dst_ptr = (int8_t*)lhs_qa8dx + row_idx * dst_stride; + + // LHS offset at the beginning of the row + *((float*)(dst_ptr)) = recip_scale0; + dst_ptr += sizeof(float); + *((int32_t*)(dst_ptr)) = -nudged_zero_point0; + dst_ptr += sizeof(int32_t); + + // Quantize the channels + for (size_t k_idx = 0; k_idx < k; ++k_idx) { + const float src0_0 = src_ptr[k_idx]; + + // Scale the values + int32_t v0_s32 = (int32_t)(round(src0_0 * scale0)); + + v0_s32 = v0_s32 + nudged_zero_point0; + v0_s32 = std::max(v0_s32, INT8_MIN); + v0_s32 = std::min(v0_s32, INT8_MAX); + dst_ptr[0] = (int8_t)v0_s32; + dst_ptr += sizeof(int8_t); + } + } +} + +static void ref_matmul_mxn_mxk_nxk_f32_qa8dx_qs4c32( + size_t m, size_t n, size_t k, size_t bl, const int8_t* lhs_qa8dx, const uint8_t* rhs_qs4c32, + const uint16_t* scale_bf16, float* dst_f32, float scalar_min, float scalar_max) { + const size_t num_blocks_row = get_num_blocks_per_row(k, bl); + + const size_t lhs_stride = k + sizeof(float) + sizeof(int32_t); + const size_t rhs_stride = get_rhs_native_stride(k); + + for (size_t row_idx = 0; row_idx < m; ++row_idx) { + const int8_t* lhs_ptr_start = lhs_qa8dx + row_idx * lhs_stride; + + for (size_t col_idx = 0; col_idx < n; ++col_idx) { + // Main f32 accumulator + float main_acc = 0.0f; + + const int8_t* lhs_ptr = lhs_ptr_start; + const uint8_t* rhs_ptr = rhs_qs4c32 + col_idx * rhs_stride; + + // Get the LHS quantization parameters stored at the + // beginning of each row + const float lhs_scale = *(const float*)lhs_ptr; + lhs_ptr += sizeof(float); + + const int32_t lhs_offset = *(const int32_t*)lhs_ptr; + lhs_ptr += sizeof(int32_t); + + for (size_t block_idx = 0; block_idx < num_blocks_row; ++block_idx) { + const uint16_t rhs_scale_bf16 = scale_bf16[block_idx + col_idx * num_blocks_row]; + const float rhs_scale = kai_cast_f32_bf16(rhs_scale_bf16); + + int32_t iacc = 0; + + for (size_t i = 0; i < bl; ++i) { + const size_t k_idx = block_idx * bl + i; + + if (k_idx >= k) { + break; + } + + // Get the LHS values + const int32_t lhs_v0 = (int32_t)lhs_ptr[0]; + + // Get the RHS values + const uint8_t rhs_byte = rhs_ptr[0]; + + // Unpack the RHS values + int32_t rhs_v0 = 0; + if ((k_idx % 2) == 0) { + rhs_v0 = (((int32_t)(rhs_byte & 0x0F)) - 8); + } else { + rhs_v0 = (((int32_t)(rhs_byte >> 4)) - 8); + } + + iacc += lhs_v0 * rhs_v0; + iacc += lhs_offset * rhs_v0; + + lhs_ptr += 1; + + // Increment only when k_idx is not a multiple of 2 + rhs_ptr += k_idx % 2; + } + + main_acc += iacc * rhs_scale; + } + + main_acc = main_acc * lhs_scale; + + // Clamp (min-max) operation + main_acc = std::max(main_acc, scalar_min); + main_acc = std::min(main_acc, scalar_max); + + dst_f32[0] = main_acc; + dst_f32 += 1; + } + } +}; + +static void ref_matmul_mxn_mxk_kxn_f32_qa8dx_qs4c32( + size_t m, 
size_t n, size_t k, size_t bl, const int8_t* lhs_qa8dx, const uint8_t* rhs_qs4c32, + const uint16_t* scale_bf16, float* dst_f32, float scalar_min, float scalar_max) { + const size_t num_blocks_row = get_num_blocks_per_row(k, bl); + + const size_t lhs_stride = k + sizeof(float) + sizeof(int32_t); + const size_t rhs_stride = get_rhs_native_stride(n); + + for (size_t row_idx = 0; row_idx < m; ++row_idx) { + const int8_t* lhs_ptr_start = lhs_qa8dx + row_idx * lhs_stride; + + for (size_t col_idx = 0; col_idx < n; ++col_idx) { + // Main f32 accumulator + float main_acc = 0.0f; + + const int8_t* lhs_ptr = lhs_ptr_start; + const uint8_t* rhs_ptr = rhs_qs4c32 + (col_idx / 2); + + // Get the LHS quantization parameters stored at the + // beginning of each row + const float lhs_scale = *(const float*)lhs_ptr; + lhs_ptr += sizeof(float); + + const int32_t lhs_offset = *(const int32_t*)lhs_ptr; + lhs_ptr += sizeof(int32_t); + + for (size_t block_idx = 0; block_idx < num_blocks_row; ++block_idx) { + const uint16_t rhs_scale_bf16 = scale_bf16[block_idx + col_idx * num_blocks_row]; + const float rhs_scale = kai_cast_f32_bf16(rhs_scale_bf16); + + int32_t iacc = 0; + + for (size_t i = 0; i < bl; ++i) { + const size_t k_idx = block_idx * bl + i; + + if (k_idx >= k) { + break; + } + + // Get the LHS values + const int32_t lhs_v0 = (int32_t)lhs_ptr[0]; + + // Get the RHS values + const uint8_t rhs_byte = rhs_ptr[0]; + + // Unpack the RHS values + int32_t rhs_v0 = 0; + if ((col_idx % 2) == 0) { + rhs_v0 = (((int32_t)(rhs_byte & 0x0F)) - 8); + } else { + rhs_v0 = (((int32_t)(rhs_byte >> 4)) - 8); + } + + iacc += lhs_v0 * rhs_v0; + iacc += lhs_offset * rhs_v0; + + lhs_ptr += 1; + rhs_ptr += rhs_stride; + } + + main_acc += iacc * rhs_scale; + } + + main_acc = main_acc * lhs_scale; + + // Clamp (min-max) operation + main_acc = std::max(main_acc, scalar_min); + main_acc = std::min(main_acc, scalar_max); + + dst_f32[0] = main_acc; + dst_f32 += 1; + } + } +}; + +static void ref_matmul_f32_qa8dx_qs4c32( + size_t m, size_t n, size_t k, size_t bl, rhs_format format, const int8_t* lhs_qa8dx, const uint8_t* rhs_qs4c32, + const uint16_t* rhs_scales_bf16, float* dst_f32, float scalar_min, float scalar_max) { + if (rhs_format::nxk == format) { + ref_matmul_mxn_mxk_nxk_f32_qa8dx_qs4c32( + m, n, k, bl, lhs_qa8dx, rhs_qs4c32, rhs_scales_bf16, dst_f32, scalar_min, scalar_max); + } else { + ref_matmul_mxn_mxk_kxn_f32_qa8dx_qs4c32( + m, n, k, bl, lhs_qa8dx, rhs_qs4c32, rhs_scales_bf16, dst_f32, scalar_min, scalar_max); + } +}; + +static bool is_output_correct(size_t num_rows, size_t num_cols, float tolerance, const float* ref, const float* act) { + bool is_valid = true; + + for (size_t i = 0; i < num_rows * num_cols; ++i) { + if (std::fabs(ref[i] - act[i]) > tolerance) { + const size_t x = i % num_cols; + const size_t y = i / num_cols; + printf("ERROR![%ld][%ld]: ref=%.5f vs. act=%.5f\n", y, x, ref[i], act[i]); + is_valid = false; + } + } + return is_valid; +} + +int main() { + const size_t m = 37; + const size_t n = 75; + const size_t k = 256; + const size_t bl = 64; + + const size_t seed_lhs = 4568; + const size_t seed_rhs = seed_lhs + 4; + + std::cout << "------------" << std::endl; + + // Iterate over the RHS format (NxK or KxN) + for (const rhs_format& format : {rhs_format::nxk, rhs_format::kxn}) { + std::cout << "Testing RHS format = " << (format == rhs_format::nxk ? 
"N x K" : "K x N") << std::endl; + + const size_t lhs_native_size_f32 = m * k * sizeof(float); + const size_t rhs_native_size_f32 = n * k * sizeof(float); + const size_t rhs_native_size_qs4c32 = + format == rhs_format::nxk ? n * get_rhs_native_stride(k) : k * get_rhs_native_stride(n); + const size_t rhs_scales_size_bf16 = n * get_rhs_scale_stride(k, bl); + + // Allocate the memory + uint8_t* lhs_native_mtx_f32 = new uint8_t[lhs_native_size_f32]; + uint8_t* rhs_native_mtx_f32 = new uint8_t[rhs_native_size_f32]; + uint8_t* rhs_native_mtx_qs4c32 = new uint8_t[rhs_native_size_qs4c32]; + uint8_t* rhs_scales_mtx_bf16 = new uint8_t[rhs_scales_size_bf16]; + + fill_uniform_random(m, k, (float*)lhs_native_mtx_f32, seed_lhs); + fill_uniform_random(n, k, (float*)rhs_native_mtx_f32, seed_rhs); + + quant_qs4cx_f32( + n, k, bl, // Dimensions + format, // Format (NxK or KxN) + (const float*)rhs_native_mtx_f32, // RHS (F32) + rhs_native_mtx_qs4c32, // RHS (QS4C32) + (uint16_t*)rhs_scales_mtx_bf16); // Scales (Bf16) + + delete[] rhs_native_mtx_f32; + + //----------- REFERENCE IMPLEMENTATION + //------------------------------------ + //------------------------------------ + // Memory sizes for the reference implementation + // After dynamically quantized the LHS matrix, we have the scale and offset for each + // row. The scale (f32) and offset (int32) are stored at the beginning of each row + const size_t lhs_ref_size_qa8dx = m * (k + sizeof(int32_t) + sizeof(float)); + const size_t dst_ref_size_f32 = m * n * sizeof(float); + + uint8_t* lhs_ref_mtx_qa8dx = new uint8_t[lhs_ref_size_qa8dx]; + uint8_t* dst_ref_mtx_f32 = new uint8_t[dst_ref_size_f32]; + + ref_quant_qa8dx_f32(m, k, (const float*)lhs_native_mtx_f32, (int8_t*)lhs_ref_mtx_qa8dx); + + ref_matmul_f32_qa8dx_qs4c32( + m, n, k, // Dimensions + bl, // Block length + format, // Format (NxK or KxN) + (const int8_t*)lhs_ref_mtx_qa8dx, // LHS + (const uint8_t*)rhs_native_mtx_qs4c32, // RHS + (const uint16_t*)rhs_scales_mtx_bf16, // Scale + (float*)dst_ref_mtx_f32, // DST + -FLT_MAX, FLT_MAX); // Min and max for the clamp operation + + // Remove the unnecessary buffer + delete[] lhs_ref_mtx_qa8dx; + + //----------- END REFERENCE IMPLEMENTATION + //------------------------------------ + //------------------------------------ + + //----------- MICRO-KERNELS TESTS + //------------------------------------ + //------------------------------------ + for (size_t idx_variant = 0; idx_variant < num_ukernel_variants; ++idx_variant) { + // Get the packing parameters + const size_t mr = ukernel_variants[idx_variant].ukernel.get_mr(); + const size_t nr = ukernel_variants[idx_variant].ukernel.get_nr(); + const size_t kr = ukernel_variants[idx_variant].ukernel.get_kr(); + const size_t sr = ukernel_variants[idx_variant].ukernel.get_sr(); + + // Get the size in bytes for the packed matrices + const size_t lhs_packed_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(m, k, mr, kr, sr); + size_t rhs_packed_size = 0; + + if (format == rhs_format::nxk) { + rhs_packed_size = + kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(n, k, nr, kr, sr, bl, kai_dt_bf16); + + } else { + rhs_packed_size = + kai_get_rhs_packed_size_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(n, k, nr, kr, sr, bl, kai_dt_bf16); + } + + const size_t dst_size = ukernel_variants[idx_variant].ukernel.get_dst_size(m, n); + + // Allocate the matrices + uint8_t* lhs_packed_mtx_qa8dx = new uint8_t[lhs_packed_size]; + uint8_t* rhs_packed_mtx_qs4c32 = new uint8_t[rhs_packed_size]; + uint8_t* dst_act_mtx_f32 = new 
+
+            memset(dst_act_mtx_f32, 0, dst_size);
+
+            // If the RHS matrix contains constant values, the packing can be performed
+            // only once
+            if (format == rhs_format::nxk) {
+                struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params;
+                params.lhs_zero_point = 1;
+                params.rhs_zero_point = 8;
+                params.scale_dt = kai_dt_bf16;
+
+                kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
+                    1, n, k,                                  // Dimensions
+                    nr, kr, sr,                               // Packing arguments
+                    bl,                                       // Block length
+                    (const uint8_t*)(rhs_native_mtx_qs4c32),  // RHS
+                    get_rhs_native_stride(k),                 // RHS stride
+                    NULL,                                     // Bias
+                    rhs_scales_mtx_bf16,                      // Scale
+                    get_rhs_scale_stride(k, bl),              // Scale stride
+                    rhs_packed_mtx_qs4c32,                    // RHS packed
+                    0, &params);
+
+            } else {
+                struct kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_params params;
+                params.lhs_zero_point = 1;
+                params.rhs_zero_point = 8;
+                params.scale_dt = kai_dt_bf16;
+
+                kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(
+                    1, n, k,                                  // Dimensions
+                    nr, kr, sr,                               // Packing arguments
+                    bl,                                       // Block length
+                    (const uint8_t*)(rhs_native_mtx_qs4c32),  // RHS
+                    get_rhs_native_stride(n),                 // RHS stride
+                    NULL,                                     // Bias
+                    rhs_scales_mtx_bf16,                      // Scale
+                    get_rhs_scale_stride(k, bl),              // Scale stride
+                    rhs_packed_mtx_qs4c32,                    // RHS packed
+                    0, &params);
+            }
+
+            const auto time_s = std::chrono::high_resolution_clock::now();
+
+            // LHS packing
+            kai_run_lhs_quant_pack_qai8dxp_f32(
+                m, k,                              // Dimensions
+                mr, kr, sr, 0,                     // Packing arguments
+                (const float*)lhs_native_mtx_f32,  // LHS
+                k * sizeof(float),                 // LHS stride
+                lhs_packed_mtx_qa8dx);             // LHS packed
+
+            // Matmul
+            {
+                const size_t dst_stride = n * sizeof(float);
+                const size_t lhs_offset = ukernel_variants[idx_variant].ukernel.get_lhs_packed_offset(0, k);
+                const size_t rhs_offset = ukernel_variants[idx_variant].ukernel.get_rhs_packed_offset(0, k, bl);
+                const size_t dst_offset = ukernel_variants[idx_variant].ukernel.get_dst_offset(0, 0, dst_stride);
+
+                const void* lhs_ptr = (const void*)((const char*)lhs_packed_mtx_qa8dx + lhs_offset);
+                const void* rhs_ptr = (const void*)((const char*)rhs_packed_mtx_qs4c32 + rhs_offset);
+                float* dst_ptr = (float*)((uint8_t*)dst_act_mtx_f32 + dst_offset);
+
+                ukernel_variants[idx_variant].ukernel.run_matmul(
+                    m, n, k,           // Dimensions
+                    bl,                // Block length
+                    lhs_ptr,           // LHS packed
+                    rhs_ptr,           // RHS packed
+                    dst_ptr,           // DST
+                    dst_stride,        // DST stride (row)
+                    sizeof(float),     // DST stride (col)
+                    -FLT_MAX, FLT_MAX  // Min and max for the clamp operation
+                );
+            }
+
+            const auto time_e = std::chrono::high_resolution_clock::now();
+
+            const auto elap = std::chrono::duration_cast<std::chrono::microseconds>(time_e - time_s);
+
+            const bool is_valid =
+                is_output_correct(m, n, 0.0001f, (const float*)dst_ref_mtx_f32, (const float*)dst_act_mtx_f32);
+
+            std::cout << "TEST[" << idx_variant << "]: Dynamic quantization + matmul" << std::endl;
+            std::cout << "- ukernel: " << ukernel_variants[idx_variant].name << std::endl;
+            if (is_valid) {
+                std::cout << "- Status: PASSED" << std::endl;
+                std::cout << "- Performance: " << elap.count() << " us" << std::endl;
+            } else {
+                std::cout << "- Status: FAILED" << std::endl;
+            }
+            std::cout << "------------" << std::endl;
+            delete[] lhs_packed_mtx_qa8dx;
+            delete[] rhs_packed_mtx_qs4c32;
+            delete[] dst_act_mtx_f32;
+        }
+        delete[] lhs_native_mtx_f32;
+        delete[] rhs_native_mtx_qs4c32;
+        delete[] rhs_scales_mtx_bf16;
+        delete[] dst_ref_mtx_f32;
+    }
+}
+
+//----------- END MICRO-KERNELS TESTS
+//------------------------------------
+//------------------------------------
+
+#endif // Architectural feature check
diff --git a/kai/kai_common.h b/kai/kai_common.h
index 60df8162a86261222210b402c9c206871ad0140a..47831cd6ba81d4c003e6624e6bd1b9904b47e1e7 100644
--- a/kai/kai_common.h
+++ b/kai/kai_common.h
@@ -50,6 +50,30 @@ extern "C" {
 #define KAI_MIN(a, b) (((a) < (b)) ? (a) : (b))
 #define KAI_MAX(a, b) (((a) > (b)) ? (a) : (b))
 
+/// KleidiAI data types
+/// Format: (reserved)|(num-bytes)|(type)|(variant-type)
+enum kai_datatype {
+    kai_dt_unknown = 0x0000,
+    kai_dt_f32 = 0x0411,
+    kai_dt_f16 = 0x0212,
+    kai_dt_bf16 = 0x0213,
+    kai_dt_int32 = 0x0421,
+    kai_dt_int16 = 0x0222,
+    kai_dt_int8 = 0x0124,
+    kai_dt_uint32 = 0x0431,
+    kai_dt_uint16 = 0x0232,
+    kai_dt_uint8 = 0x0134,
+    kai_dt_bool = 0x0441
+};
+
+/// Gets the number of bytes for a given data type
+/// @param[in] dt KleidiAI data type
+///
+/// @return the number of bytes for the data type
+inline static size_t kai_get_datatype_size_in_bytes(enum kai_datatype dt) {
+    return (size_t)(dt >> 8);
+}
+
 /// Converts a scalar f16 value to f32
 /// @param[in] f16 The f16 value
 ///
@@ -62,6 +86,27 @@ inline static float kai_cast_f32_f16(uint16_t f16) {
 #endif
 }
 
+/// Converts a scalar bf16 value to f32
+/// @param[in] bf16 The bf16 value
+///
+/// @return the f32 value
+inline static float kai_cast_f32_bf16(uint16_t bf16) {
+    const uint32_t i32 = (bf16 << 16);
+    float f32;
+    memcpy(&f32, &i32, sizeof(i32));
+    return f32;
+}
+
+/// Converts a scalar f32 value to bf16
+/// @param[in] f32 The f32 value
+///
+/// @return the bf16 value
+inline static uint16_t kai_cast_bf16_f32(float f32) {
+    const uint32_t* i32 = (uint32_t*)(&f32);
+    uint16_t bf16 = (*i32 >> 16);
+    return bf16;
+}
+
 /// Converts a scalar f32 value to f16
 /// @param[in] f32 The f32 value
 ///
diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel
index 442451752d6c1a51f97a8a9dce989bf31bfb6cc4..7dbec37504b97a42251383a0e509a4b5ed0eeabf 100644
--- a/kai/ukernels/matmul/BUILD.bazel
+++ b/kai/ukernels/matmul/BUILD.bazel
@@ -208,6 +208,65 @@ kai_c_library(
     cpu_uarch = kai_cpu_neon(),
 )
 
+cc_library(
+    name = "clamp_f32_qai8dxp_qsi4c32p_interface",
+    hdrs = ["matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h"],
+)
+
+kai_c_library(
+    name = "clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod",
+    srcs = ["matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c"],
+    hdrs = ["matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"],
+    cpu_uarch = kai_cpu_dotprod(),
+    deps = [
+        ":clamp_f32_qai8dxp_qsi4c32p_interface",
+    ],
+)
+
+kai_c_library(
+    name = "clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod",
+    srcs = ["matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c"],
+    hdrs = ["matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h"],
+    cpu_uarch = kai_cpu_dotprod(),
+    deps = [
+        ":clamp_f32_qai8dxp_qsi4c32p_interface",
+    ],
+)
+
+kai_c_library(
+    name = "clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm",
+    srcs = ["matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c"],
+    hdrs = ["matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h"],
+    cpu_uarch = kai_cpu_i8mm(),
+    deps = [
+        ":clamp_f32_qai8dxp_qsi4c32p_interface",
+    ],
+)
+
+kai_c_library(
+    name = "clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm",
+    srcs = ["matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c"],
["matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c"], + hdrs = ["matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h"], + cpu_uarch = kai_cpu_i8mm(), + deps = [ + ":clamp_f32_qai8dxp_qsi4c32p_interface", + ], +) + +kai_c_library( + name = "rhs_pack_nxk_qsi4c32p_qsu4c32s1s0", + srcs = ["pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c"], + hdrs = ["pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h"], + cpu_uarch = kai_cpu_neon(), +) + +kai_c_library( + name = "rhs_pack_kxn_qsi4c32p_qsu4c32s1s0", + srcs = ["pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c"], + hdrs = ["pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h"], + cpu_uarch = kai_cpu_neon(), +) + kai_c_library( name = "matmul", deps = [ @@ -215,12 +274,17 @@ kai_c_library( ":clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla", ":clamp_f32_f32_f32p", ":clamp_f32_f32p_f32p", + ":clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod", + ":clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod", ":clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod", ":clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod", + ":clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm", + ":clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm", ":clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm", ":clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm", ":clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm", ":clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm", + ":clamp_f32_qai8dxp_qsi4c32p_interface", ":clamp_f32_qsi8d32p_qsi4c32p_dotprod", ":clamp_f32_qsi8d32p_qsi4c32p_i8mm", ":clamp_f32_qsi8d32p_qsi4c32p_interface", @@ -229,7 +293,9 @@ kai_c_library( ":rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon", ":rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme", ":rhs_pack_kxn_f32pbiasf32_f32_f32_neon", + ":rhs_pack_kxn_qsi4c32p_qsu4c32s1s0", ":rhs_pack_kxn_qsi4cxp_qsu4cxs1s0", + ":rhs_pack_nxk_qsi4c32p_qsu4c32s1s0", ":rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0", ":rhs_pack_nxk_qsi4cxp_qsu4cxs1s0", ], diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c new file mode 100644 index 0000000000000000000000000000000000000000..12d9d6ead83cd7a1b6bf24ec535f1c2f2e32b80d --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c @@ -0,0 +1,225 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#if !defined(__ARM_FEATURE_DOTPROD) +#error "Dotprod extension required to compile this micro-kernel" +#else +#include "kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h" + +#include +#include + +#include "kai/kai_common.h" + +static const size_t kai_m_step = 1; +static const size_t kai_n_step = 4; +static const size_t kai_mr = 1; +static const size_t kai_nr = 4; +static const size_t kai_kr = 16; +static const size_t kai_sr = 2; +static const size_t kai_bl_multiple_of = 32; +static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); +static const size_t kai_num_bytes_multiplier_rhs = sizeof(uint16_t); +static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); +static const size_t kai_num_bytes_sum_rhs = sizeof(float); +static const size_t kai_num_bytes_bias = sizeof(float); + +inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) { + KAI_ASSERT((bl % 
kai_bl_multiple_of) == 0); + return kai_roundup(k, bl) / bl; +} + +inline static size_t kai_k_roundedup(size_t k) { + // Since we pack a float and int32 value at the end of the row, + // we must make sure that k is a multiple of 4 for alignment + size_t kr_sr_roundedup4 = kai_roundup(kai_kr * kai_sr, 4); + return kai_roundup(k, kr_sr_roundedup4); +} + +inline static size_t kai_lhs_packed_stride(size_t k) { + const size_t k_internal = kai_k_roundedup(k); + + KAI_ASSERT((k_internal % 2) == 0); + + return kai_mr * (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); +} + +inline static size_t kai_rhs_packed_stride(size_t k, size_t bl) { + KAI_ASSERT((bl % kai_kr) == 0); + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); + + const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); + const size_t num_bytes_per_block = (bl / 2) + kai_num_bytes_multiplier_rhs; + + return kai_nr * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs + kai_num_bytes_bias); +} + +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(void) { + return kai_m_step; +} + +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(void) { + return kai_n_step; +} + +size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(void) { + return kai_mr; +} + +size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(void) { + return kai_nr; +} + +size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(void) { + return kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(size_t m_idx, size_t k) { + KAI_ASSERT((m_idx % kai_m_step) == 0); + + return (m_idx / kai_m_step) * kai_lhs_packed_stride(k); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod( + size_t n_idx, size_t k, size_t bl) { + KAI_ASSERT((n_idx % kai_n_step) == 0); + + return (n_idx / kai_n_step) * kai_rhs_packed_stride(k, bl); +} + +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSERT((m_idx % kai_m_step) == 0); + KAI_ASSERT((n_idx % kai_n_step) == 0); + + return (n_idx * sizeof(float)) + m_idx * dst_stride; +} + +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(size_t m, size_t n) { + return m * n * sizeof(float); +} + +void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod( + size_t m, size_t n, size_t k, size_t bl, const void* restrict lhs_packed, const void* restrict rhs_packed, + float* restrict dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { + KAI_ASSERT((bl % kai_kr) == 0); + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); + KAI_ASSERT(dst_stride_col == sizeof(float)); + + if (m == 0) { + return; + } + + size_t num_subblocks = bl / kai_bl_multiple_of; + size_t num_blocks = kai_num_blocks_per_row(k, bl); + + float clamp_vals[2] = {scalar_min, scalar_max}; + + __asm__ __volatile__( + "mov x27, #0x20\n" + "mov x21, #0x3d800000\n" + "movi v0.16b, #0xf0\n" + "mov x20, #0x8\n" + "mov x26, %x[m]\n" + "mul x27, %x[num_subblocks], x27\n" + "dup v31.4s, w21\n" + "madd x27, %x[num_blocks], x27, x20\n" + "1:" // Row loop + "mov x25, %x[rhs_packed]\n" + "mov x24, %x[n]\n" + "add x23, %x[dst], %x[dst_stride_row]\n" + 
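+        // Rough register map (commentary, not authoritative): x26 = rows left,
+        // x24 = columns left, x25 = packed RHS pointer, x22 = packed LHS
+        // pointer, x27 = packed LHS row stride in bytes, v0 = 0xF0 nibble
+        // mask, and v31 = 1/16 (0x3d800000), compensating for the int4 values
+        // being evaluated 16x too large by the nibble handling around each sdot.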
"2:" // Column loop + "mov x22, %x[lhs_packed]\n" + "movi v30.16b, #0x0\n" + "mov x21, %x[num_blocks]\n" + "3:" // Block loop + "movi v29.4s, #0x0\n" + "movi v28.4s, #0x0\n" + "mov x20, %x[num_subblocks]\n" + "4:" // Sub block loop + "ldr q27, [x25, #0x0]\n" + "ldr q26, [x25, #0x10]\n" + "subs x20, x20, #0x1\n" + "ld1r { v25.2d }, [x22], #0x8\n" + "ldr q24, [x25, #0x20]\n" + "ldr q23, [x25, #0x30]\n" + "add x25, x25, #0x40\n" + "ld1r { v22.2d }, [x22], #0x8\n" + "ld1r { v21.2d }, [x22], #0x8\n" + "shl v20.16b, v27.16b, #0x4\n" + "shl v19.16b, v26.16b, #0x4\n" + "ld1r { v18.2d }, [x22], #0x8\n" + "shl v17.16b, v24.16b, #0x4\n" + "and v27.16b, v27.16b, v0.16b\n" + "shl v16.16b, v23.16b, #0x4\n" + "and v26.16b, v26.16b, v0.16b\n" + ".inst 0x4e99969d // sdot v29.4s, v20.16b, v25.16b\n" + ".inst 0x4e99967c // sdot v28.4s, v19.16b, v25.16b\n" + "and v24.16b, v24.16b, v0.16b\n" + "and v23.16b, v23.16b, v0.16b\n" + ".inst 0x4e96963d // sdot v29.4s, v17.16b, v22.16b\n" + ".inst 0x4e96961c // sdot v28.4s, v16.16b, v22.16b\n" + ".inst 0x4e95977d // sdot v29.4s, v27.16b, v21.16b\n" + ".inst 0x4e95975c // sdot v28.4s, v26.16b, v21.16b\n" + ".inst 0x4e92971d // sdot v29.4s, v24.16b, v18.16b\n" + ".inst 0x4e9296fc // sdot v28.4s, v23.16b, v18.16b\n" + "bgt 4b\n" + "ldr d16, [x25, #0x0]\n" + "addp v29.4s, v29.4s, v28.4s\n" + "sub x21, x21, #0x1\n" + "add x25, x25, #0x8\n" + "shll v16.4s, v16.4h, #0x10\n" + "scvtf v29.4s, v29.4s\n" + "fmul v16.4s, v16.4s, v31.4s\n" + "fmla v30.4s, v29.4s, v16.4s\n" + "cbnz x21, 3b\n" + "ld1r { v21.4s }, [x22]\n" + "ldr q20, [x25, #0x0]\n" + "add x22, x22, #0x4\n" + "add x20, %x[clamp_vals], #0x4\n" + "ld1r { v19.4s }, [x22]\n" + "ldr q18, [x25, #0x10]\n" + "cmp x24, #0x4\n" + "add x25, x25, #0x20\n" + "ld1r { v17.4s }, [%x[clamp_vals]]\n" + "ld1r { v16.4s }, [x20]\n" + "scvtf v21.4s, v21.4s\n" + "fmla v30.4s, v20.4s, v21.s[0]\n" + "fmul v30.4s, v30.4s, v19.4s\n" + "fadd v30.4s, v30.4s, v18.4s\n" + "fmax v30.4s, v30.4s, v17.4s\n" + "fmin v30.4s, v30.4s, v16.4s\n" + "blt 5f\n" + "str q30, [%x[dst], #0x0]\n" + "b 8f\n" + "5:" // Partial output + "mov x20, %x[dst]\n" + "tbz x24, #1, 6f\n" + "st1 { v30.d }[0], [x20], #0x8\n" + "tbz x24, #0, 7f\n" + "st1 { v30.s }[2], [x20]\n" + "b 7f\n" + "6:" // Output block 0: partial_1_0 + "st1 { v30.s }[0], [x20]\n" + "7:" // Output block 0: Done + "8:" // Stores done + "subs x24, x24, #0x4\n" + "add %x[dst], %x[dst], #0x10\n" + "bgt 2b\n" + "subs x26, x26, #0x1\n" + "add %x[lhs_packed], %x[lhs_packed], x27\n" + "mov %x[dst], x23\n" + "bgt 1b\n" + : [dst] "+&r"(dst), [lhs_packed] "+&r"(lhs_packed) + : [clamp_vals] "r"(clamp_vals), [dst_stride_row] "r"(dst_stride_row), [m] "r"(m), [n] "r"(n), + [num_blocks] "r"(num_blocks), [num_subblocks] "r"(num_subblocks), [rhs_packed] "r"(rhs_packed) + : "cc", "memory", "v0", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", + "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27"); +} +#endif // Architectural feature check diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h new file mode 100644 index 0000000000000000000000000000000000000000..6fb4736eba44359e5a49aca0386929963dfa4320 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h @@ -0,0 +1,142 @@ + +// 
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#pragma once
+
+#include <stddef.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/// Micro-kernel dependencies
+///
+/// -# kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix
+/// -# kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 OR kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0 to pack the RHS matrix
+
+/// --------------------------------------------------
+
+/// Gets the m step value.
+/// The micro-kernel can process any M values. However, the starting M index to
+/// be processed must be a multiple of m step.
+///
+/// @return the m step value
+size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(void);
+
+/// Gets the n step value.
+/// The micro-kernel can process any N values. However, the starting N index to
+/// be processed must be a multiple of n step.
+///
+/// @return the n step
+size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(void);
+
+/// Gets the mr value, which must be used to pack the LHS matrix
+///
+/// @return the mr value
+size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(void);
+
+/// Gets the nr value, which must be used to pack the RHS matrix
+///
+/// @return the nr value
+size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(void);
+
+/// Gets the kr value, which must be used to pack the LHS and RHS matrices
+///
+/// @return the kr value
+size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(void);
+
+/// Gets the sr value, which must be used to pack the LHS and RHS matrices
+///
+/// @return the sr value
+size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(void);
+
+/// Gets the offset in bytes for the packed LHS matrix,
+/// which contains the packed Signed 8-bit quantized asymmetric per-row (qai8dx) values.
+///
+/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel.
+///
+/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of 1
+/// @param[in] k Total number of columns in the LHS matrix (not packed).
+///
+/// @return the offset in bytes to the packed LHS matrix
+size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(
+    size_t m_idx,  //
+    size_t k);     //
+
+/// Gets the offset in bytes for the packed RHS matrix,
+/// which contains the packed Signed 4-bit quantized symmetric per-block (qsi4c32) values.
+///
+/// @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of 4.
+/// @param[in] k The common dimension between the LHS and RHS matrix (K).
+/// @param[in] bl Block length. It must be a multiple of 32.
+///
+/// @return the offset in bytes to the packed RHS matrix
+size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(
+    size_t n_idx,  //
+    size_t k,      //
+    size_t bl);    //
+
+/// Gets the offset in bytes for the DST matrix
+///
+/// @param[in] m_idx Row index in the DST matrix. It must be a multiple of 1.
+/// @param[in] n_idx Column index in the DST matrix. It must be a multiple of 4.
+/// @param[in] dst_stride The number of bytes in each row of the DST matrix
+///
+/// @return the DST offset in bytes
+size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(
+    size_t m_idx,        //
+    size_t n_idx,        //
+    size_t dst_stride);  //
+
+/// Gets the size in bytes for the destination (DST) matrix.
+///
+/// @param[in] m Number of rows in the destination (DST) matrix.
+/// @param[in] n Number of columns in the destination (DST) matrix.
+///
+/// @return the destination (DST) matrix size in bytes
+size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(
+    size_t m,   //
+    size_t n);  //
+
+/// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation.
+///
+/// LHS matrix: Signed 8-bit quantized asymmetric per-row (qai8dx) and packed
+/// RHS matrix: Signed 4-bit quantized symmetric per-block (qsi4c32) and packed.
+/// Output tile: (rows x cols) = 1 x 4
+/// Accumulation performed in a single for loop: 32
+/// Extension used: dotprod
+///
+/// @param[in] m The number of output rows written.
+/// @param[in] n The number of output columns written.
+/// @param[in] k The number of channels. The common dimension between the LHS and RHS matrix.
+/// @param[in] bl Block length. It must be a multiple of 32.
+/// @param[in] lhs_packed The LHS packed matrix.
+///                       When the activations are dynamically quantized, you can obtain this matrix
+///                       by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs
+///                       both the dynamic quantization to 8-bit and activation packing in a single step.
+/// @param[in] rhs_packed The RHS packed matrix, which is obtained by calling @ref
+///                       kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 or @ref kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0
+/// @param[out] dst The DST matrix.
+/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix.
+/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(float).
+/// @param[in] scalar_min Min value used to clamp the final result.
+/// @param[in] scalar_max Max value used to clamp the final result.
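+///
+/// Minimal calling sketch (illustrative; the lhs_packed/rhs_packed buffers are
+/// assumed to have been produced by the packing micro-kernels listed above, as
+/// in the bundled example):
+///
+///   kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(
+///       m, n, k, bl, lhs_packed, rhs_packed,
+///       dst, n * sizeof(float), sizeof(float), -FLT_MAX, FLT_MAX);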
+void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(
+    size_t m,                //
+    size_t n,                //
+    size_t k,                //
+    size_t bl,               //
+    const void* lhs_packed,  //
+    const void* rhs_packed,  //
+    float* dst,              //
+    size_t dst_stride_row,   //
+    size_t dst_stride_col,   //
+    float scalar_min,        //
+    float scalar_max);       //
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c
new file mode 100644
index 0000000000000000000000000000000000000000..edd0971131a5624455f7eaee7fa932754c839549
--- /dev/null
+++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c
@@ -0,0 +1,274 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#if !defined(__ARM_FEATURE_DOTPROD)
+#error "Dotprod extension required to compile this micro-kernel"
+#else
+#include "kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h"
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include "kai/kai_common.h"
+
+static const size_t kai_m_step = 1;
+static const size_t kai_n_step = 8;
+static const size_t kai_mr = 1;
+static const size_t kai_nr = 8;
+static const size_t kai_kr = 16;
+static const size_t kai_sr = 2;
+static const size_t kai_bl_multiple_of = 32;
+static const size_t kai_num_bytes_multiplier_lhs = sizeof(float);
+static const size_t kai_num_bytes_multiplier_rhs = sizeof(uint16_t);
+static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t);
+static const size_t kai_num_bytes_sum_rhs = sizeof(float);
+static const size_t kai_num_bytes_bias = sizeof(float);
+
+inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) {
+    KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
+    return kai_roundup(k, bl) / bl;
+}
+
+inline static size_t kai_k_roundedup(size_t k) {
+    // Since we pack a float and int32 value at the end of the row,
+    // we must make sure that k is a multiple of 4 for alignment
+    size_t kr_sr_roundedup4 = kai_roundup(kai_kr * kai_sr, 4);
+    return kai_roundup(k, kr_sr_roundedup4);
+}
+
+inline static size_t kai_lhs_packed_stride(size_t k) {
+    const size_t k_internal = kai_k_roundedup(k);
+
+    KAI_ASSERT((k_internal % 2) == 0);
+
+    return kai_mr * (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs);
+}
+
+inline static size_t kai_rhs_packed_stride(size_t k, size_t bl) {
+    KAI_ASSERT((bl % kai_kr) == 0);
+    KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
+
+    const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl);
+    const size_t num_bytes_per_block = (bl / 2) + kai_num_bytes_multiplier_rhs;
+
+    return kai_nr * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs + kai_num_bytes_bias);
+}
+
+size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(void) {
+    return kai_m_step;
+}
+
+size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(void) {
+    return kai_n_step;
+}
+
+size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(void) {
+    return kai_mr;
+}
+
+size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(void) {
+    return kai_nr;
+}
+
+size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(void) {
+    return kai_kr;
+}
+
+size_t
kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(void) { + return kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(size_t m_idx, size_t k) { + KAI_ASSERT((m_idx % kai_m_step) == 0); + + return (m_idx / kai_m_step) * kai_lhs_packed_stride(k); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod( + size_t n_idx, size_t k, size_t bl) { + KAI_ASSERT((n_idx % kai_n_step) == 0); + + return (n_idx / kai_n_step) * kai_rhs_packed_stride(k, bl); +} + +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSERT((m_idx % kai_m_step) == 0); + KAI_ASSERT((n_idx % kai_n_step) == 0); + + return (n_idx * sizeof(float)) + m_idx * dst_stride; +} + +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(size_t m, size_t n) { + return m * n * sizeof(float); +} + +void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod( + size_t m, size_t n, size_t k, size_t bl, const void* restrict lhs_packed, const void* restrict rhs_packed, + float* restrict dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { + KAI_ASSERT((bl % kai_kr) == 0); + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); + KAI_ASSERT(dst_stride_col == sizeof(float)); + + if (m == 0) { + return; + } + + size_t num_subblocks = bl / kai_bl_multiple_of; + size_t num_blocks = kai_num_blocks_per_row(k, bl); + + float clamp_vals[2] = {scalar_min, scalar_max}; + + __asm__ __volatile__( + "mov x27, #0x20\n" + "mov x21, #0x3d800000\n" + "movi v8.16b, #0xf0\n" + "mov x20, #0x8\n" + "mov x26, %x[m]\n" + "mul x27, %x[num_subblocks], x27\n" + "dup v7.4s, w21\n" + "madd x27, %x[num_blocks], x27, x20\n" + "1:" // Row loop + "mov x25, %x[rhs_packed]\n" + "mov x24, %x[n]\n" + "add x23, %x[dst], %x[dst_stride_row]\n" + "2:" // Column loop + "mov x22, %x[lhs_packed]\n" + "movi v6.16b, #0x0\n" + "movi v5.16b, #0x0\n" + "mov x21, %x[num_blocks]\n" + "3:" // Block loop + "movi v4.4s, #0x0\n" + "movi v3.4s, #0x0\n" + "mov x20, %x[num_subblocks]\n" + "movi v2.4s, #0x0\n" + "movi v1.4s, #0x0\n" + "4:" // Sub block loop + "ldr q0, [x25, #0x0]\n" + "ldr q31, [x25, #0x10]\n" + "subs x20, x20, #0x1\n" + "ldr q30, [x25, #0x20]\n" + "ldr q29, [x25, #0x30]\n" + "ld1r { v28.2d }, [x22], #0x8\n" + "ldr q27, [x25, #0x40]\n" + "ldr q26, [x25, #0x50]\n" + "ldr q25, [x25, #0x60]\n" + "shl v24.16b, v0.16b, #0x4\n" + "shl v18.16b, v31.16b, #0x4\n" + "ldr q23, [x25, #0x70]\n" + "shl v17.16b, v30.16b, #0x4\n" + "shl v16.16b, v29.16b, #0x4\n" + "add x25, x25, #0x80\n" + "ld1r { v22.2d }, [x22], #0x8\n" + "shl v21.16b, v27.16b, #0x4\n" + "and v0.16b, v0.16b, v8.16b\n" + "ld1r { v20.2d }, [x22], #0x8\n" + "ld1r { v19.2d }, [x22], #0x8\n" + ".inst 0x4e9c9704 // sdot v4.4s, v24.16b, v28.16b\n" + ".inst 0x4e9c9643 // sdot v3.4s, v18.16b, v28.16b\n" + "shl v18.16b, v26.16b, #0x4\n" + ".inst 0x4e9c9622 // sdot v2.4s, v17.16b, v28.16b\n" + ".inst 0x4e9c9601 // sdot v1.4s, v16.16b, v28.16b\n" + "shl v17.16b, v25.16b, #0x4\n" + "shl v16.16b, v23.16b, #0x4\n" + "and v31.16b, v31.16b, v8.16b\n" + "and v30.16b, v30.16b, v8.16b\n" + "and v29.16b, v29.16b, v8.16b\n" + ".inst 0x4e9696a4 // sdot v4.4s, v21.16b, v22.16b\n" + ".inst 0x4e969643 // sdot v3.4s, v18.16b, v22.16b\n" + "and v27.16b, v27.16b, v8.16b\n" + ".inst 0x4e969622 // sdot v2.4s, v17.16b, v22.16b\n" + ".inst 0x4e969601 // sdot v1.4s, v16.16b, v22.16b\n" + 
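+            // Commentary (best-effort reading): v4/v3/v2/v1 hold the int32 dot
+            // products for the 8 packed RHS columns; after this sub-block loop
+            // they are pairwise reduced (addp), converted to float, scaled by
+            // the bf16 block scales, and accumulated into v6/v5.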
"and v26.16b, v26.16b, v8.16b\n" + "and v25.16b, v25.16b, v8.16b\n" + "and v23.16b, v23.16b, v8.16b\n" + ".inst 0x4e949404 // sdot v4.4s, v0.16b, v20.16b\n" + ".inst 0x4e9497e3 // sdot v3.4s, v31.16b, v20.16b\n" + ".inst 0x4e9497c2 // sdot v2.4s, v30.16b, v20.16b\n" + ".inst 0x4e9497a1 // sdot v1.4s, v29.16b, v20.16b\n" + ".inst 0x4e939764 // sdot v4.4s, v27.16b, v19.16b\n" + ".inst 0x4e939743 // sdot v3.4s, v26.16b, v19.16b\n" + ".inst 0x4e939722 // sdot v2.4s, v25.16b, v19.16b\n" + ".inst 0x4e9396e1 // sdot v1.4s, v23.16b, v19.16b\n" + "bgt 4b\n" + "ldr q16, [x25, #0x0]\n" + "addp v4.4s, v4.4s, v3.4s\n" + "addp v2.4s, v2.4s, v1.4s\n" + "sub x21, x21, #0x1\n" + "add x25, x25, #0x10\n" + "shll v17.4s, v16.4h, #0x10\n" + "shll2 v16.4s, v16.8h, #0x10\n" + "scvtf v4.4s, v4.4s\n" + "scvtf v2.4s, v2.4s\n" + "fmul v17.4s, v17.4s, v7.4s\n" + "fmul v16.4s, v16.4s, v7.4s\n" + "fmla v6.4s, v4.4s, v17.4s\n" + "fmla v5.4s, v2.4s, v16.4s\n" + "cbnz x21, 3b\n" + "ld1r { v23.4s }, [x22]\n" + "ldr q22, [x25, #0x0]\n" + "add x22, x22, #0x4\n" + "add x20, %x[clamp_vals], #0x4\n" + "ldr q21, [x25, #0x10]\n" + "ld1r { v20.4s }, [x22]\n" + "cmp x24, #0x8\n" + "ldr q19, [x25, #0x20]\n" + "ldr q18, [x25, #0x30]\n" + "add x25, x25, #0x40\n" + "ld1r { v17.4s }, [%x[clamp_vals]]\n" + "ld1r { v16.4s }, [x20]\n" + "scvtf v23.4s, v23.4s\n" + "fmla v6.4s, v22.4s, v23.s[0]\n" + "fmla v5.4s, v21.4s, v23.s[0]\n" + "fmul v6.4s, v6.4s, v20.4s\n" + "fadd v6.4s, v6.4s, v19.4s\n" + "fmul v5.4s, v5.4s, v20.4s\n" + "fadd v5.4s, v5.4s, v18.4s\n" + "fmax v6.4s, v6.4s, v17.4s\n" + "fmax v5.4s, v5.4s, v17.4s\n" + "fmin v6.4s, v6.4s, v16.4s\n" + "fmin v5.4s, v5.4s, v16.4s\n" + "blt 5f\n" + "str q6, [%x[dst], #0x0]\n" + "str q5, [%x[dst], #0x10]\n" + "b 10f\n" + "5:" // Partial output + "mov x20, %x[dst]\n" + "tbz x24, #2, 7f\n" + "st1 { v6.4s }, [x20], #0x10\n" + "tbz x24, #1, 6f\n" + "st1 { v5.d }[0], [x20], #0x8\n" + "tbz x24, #0, 9f\n" + "st1 { v5.s }[2], [x20]\n" + "b 9f\n" + "6:" // Output block 0: partial_1_4 + "tbz x24, #0, 9f\n" + "st1 { v5.s }[0], [x20]\n" + "b 9f\n" + "7:" // Output block 0: partial_2_0 + "tbz x24, #1, 8f\n" + "st1 { v6.d }[0], [x20], #0x8\n" + "tbz x24, #0, 9f\n" + "st1 { v6.s }[2], [x20]\n" + "b 9f\n" + "8:" // Output block 0: partial_1_0 + "st1 { v6.s }[0], [x20]\n" + "9:" // Output block 0: Done + "10:" // Stores done + "subs x24, x24, #0x8\n" + "add %x[dst], %x[dst], #0x20\n" + "bgt 2b\n" + "subs x26, x26, #0x1\n" + "add %x[lhs_packed], %x[lhs_packed], x27\n" + "mov %x[dst], x23\n" + "bgt 1b\n" + : [dst] "+&r"(dst), [lhs_packed] "+&r"(lhs_packed) + : [clamp_vals] "r"(clamp_vals), [dst_stride_row] "r"(dst_stride_row), [m] "r"(m), [n] "r"(n), + [num_blocks] "r"(num_blocks), [num_subblocks] "r"(num_subblocks), [rhs_packed] "r"(rhs_packed) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v16", "v17", "v18", "v19", "v20", + "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23", + "x24", "x25", "x26", "x27"); +} +#endif // Architectural feature check diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h new file mode 100644 index 0000000000000000000000000000000000000000..5c549ee502d954d8dfe42dc5581eee2484b14360 --- /dev/null +++ 
b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h
@@ -0,0 +1,142 @@
+
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#pragma once
+
+#include <stddef.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/// Micro-kernel dependencies
+///
+/// -# kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix
+/// -# kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 OR kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0 to pack the RHS matrix
+
+/// --------------------------------------------------
+
+/// Gets the m step value.
+/// The micro-kernel can process any M values. However, the starting M index to
+/// be processed must be a multiple of m step.
+///
+/// @return the m step value
+size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(void);
+
+/// Gets the n step value.
+/// The micro-kernel can process any N values. However, the starting N index to
+/// be processed must be a multiple of n step.
+///
+/// @return the n step
+size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(void);
+
+/// Gets the mr value, which must be used to pack the LHS matrix
+///
+/// @return the mr value
+size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(void);
+
+/// Gets the nr value, which must be used to pack the RHS matrix
+///
+/// @return the nr value
+size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(void);
+
+/// Gets the kr value, which must be used to pack the LHS and RHS matrices
+///
+/// @return the kr value
+size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(void);
+
+/// Gets the sr value, which must be used to pack the LHS and RHS matrices
+///
+/// @return the sr value
+size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(void);
+
+/// Gets the offset in bytes for the packed LHS matrix,
+/// which contains the packed Signed 8-bit quantized asymmetric per-row (qai8dx) values.
+///
+/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel.
+///
+/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of 1.
+/// @param[in] k     Total number of columns in the LHS matrix (not packed).
+///
+/// @return the offset in bytes to the packed LHS matrix
+size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(
+    size_t m_idx,  //
+    size_t k);     //
+
+/// Gets the offset in bytes for the packed RHS matrix,
+/// which contains the packed Signed 4-bit quantized symmetric per-block (qsi4c32) values.
+///
+/// @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of 8.
+/// @param[in] k     The common dimension between the LHS and RHS matrix (K).
+/// @param[in] bl    Block length. It must be a multiple of 32.
+///
+/// @return the offset in bytes to the packed RHS matrix
+size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(
+    size_t n_idx,  //
+    size_t k,      //
+    size_t bl);    //
+
+/// Gets the offset in bytes for the DST matrix.
+///
+/// @param[in] m_idx      Row index in the DST matrix. It must be a multiple of 1.
+/// @param[in] n_idx      Column index in the DST matrix. It must be a multiple of 8.
+/// @param[in] dst_stride The number of bytes in each row of the DST matrix.
+///
+/// @return the DST offset in bytes
+size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(
+    size_t m_idx,        //
+    size_t n_idx,        //
+    size_t dst_stride);  //
+
+/// Gets the size in bytes for the destination (DST) matrix.
+///
+/// @param[in] m Number of rows in the destination (DST) matrix.
+/// @param[in] n Number of columns in the destination (DST) matrix.
+///
+/// @return the destination (DST) matrix size in bytes
+size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(
+    size_t m,   //
+    size_t n);  //
+
+/// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation.
+///
+/// LHS matrix: Signed 8-bit quantized asymmetric per-row (qai8dx) and packed.
+/// RHS matrix: Signed 4-bit quantized symmetric per-block (qsi4c32) and packed.
+/// Output tile: (rows x cols) = 1 x 8
+/// Accumulation performed in a single for loop: 32
+/// Extension used: dotprod
+///
+/// @param[in]  m              The number of output rows written.
+/// @param[in]  n              The number of output columns written.
+/// @param[in]  k              The number of channels. The common dimension between the LHS and RHS matrix.
+/// @param[in]  bl             Block length. It must be a multiple of 32.
+/// @param[in]  lhs_packed     The LHS packed matrix.
+///                            When the activations are dynamically quantized, you can obtain this matrix
+///                            by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel, which performs
+///                            both the dynamic quantization to 8-bit and the activation packing in a single step.
+/// @param[in]  rhs_packed     The RHS packed matrix, which is obtained by calling @ref
+///                            kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 or @ref kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.
+/// @param[out] dst            The DST matrix.
+/// @param[in]  dst_stride_row Stride in bytes between two rows of the DST matrix.
+/// @param[in]  dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(float).
+/// @param[in]  scalar_min     Min value used to clamp the final result.
+/// @param[in]  scalar_max     Max value used to clamp the final result.
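+///
+/// A minimal usage sketch (illustrative only: the buffer names are hypothetical,
+/// the packing calls are abbreviated, and FLT_MAX requires float.h):
+///
+/// @code{.c}
+/// const size_t mr = kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod();
+/// const size_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod();
+/// const size_t kr = kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod();
+/// const size_t sr = kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod();
+///
+/// // 1. Pack the RHS weights (typically once, at model-load time) with
+/// //    kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 using nr/kr/sr -> rhs_packed.
+/// // 2. Quantize and pack the F32 activations with
+/// //    kai_run_lhs_quant_pack_qai8dxp_f32 using mr/kr/sr -> lhs_packed.
+/// // 3. Run the matmul on the packed operands:
+/// kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(
+///     m, n, k, bl,             // bl must be a multiple of 32
+///     lhs_packed, rhs_packed,  // outputs of steps 1 and 2
+///     dst,
+///     n * sizeof(float),       // DST row stride in bytes
+///     sizeof(float),           // DST column stride (must be sizeof(float))
+///     -FLT_MAX, FLT_MAX);      // effectively no clamping
+/// @endcode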
+void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod( + size_t m, // + size_t n, // + size_t k, // + size_t bl, // + const void* lhs_packed, // + const void* rhs_packed, // + float* dst, // + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max); // + +#ifdef __cplusplus +} +#endif diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c new file mode 100644 index 0000000000000000000000000000000000000000..3b9cb2c14fc213455475eb8992db635f09d48924 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c @@ -0,0 +1,547 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#if !defined(__ARM_FEATURE_MATMUL_INT8) +#error "I8mm extension required to compile this micro-kernel" +#else +#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h" + +#include +#include + +#include "kai/kai_common.h" + +static const size_t kai_m_step = 8; +static const size_t kai_n_step = 4; +static const size_t kai_mr = 4; +static const size_t kai_nr = 4; +static const size_t kai_kr = 16; +static const size_t kai_sr = 2; +static const size_t kai_bl_multiple_of = 32; +static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); +static const size_t kai_num_bytes_multiplier_rhs = sizeof(uint16_t); +static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); +static const size_t kai_num_bytes_sum_rhs = sizeof(float); +static const size_t kai_num_bytes_bias = sizeof(float); + +inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) { + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); + return kai_roundup(k, bl) / bl; +} + +inline static size_t kai_k_roundedup(size_t k) { + // Since we pack a float and int32 value at the end of the row, + // we must make sure that k is a multiple of 4 for alignment + size_t kr_sr_roundedup4 = kai_roundup(kai_kr * kai_sr, 4); + return kai_roundup(k, kr_sr_roundedup4); +} + +inline static size_t kai_lhs_packed_stride(size_t k) { + const size_t k_internal = kai_k_roundedup(k); + + KAI_ASSERT((k_internal % 2) == 0); + + return kai_mr * (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); +} + +inline static size_t kai_rhs_packed_stride(size_t k, size_t bl) { + KAI_ASSERT((bl % kai_kr) == 0); + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); + + const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); + const size_t num_bytes_per_block = (bl / 2) + kai_num_bytes_multiplier_rhs; + + return kai_nr * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs + kai_num_bytes_bias); +} + +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(void) { + return kai_m_step; +} + +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(void) { + return kai_n_step; +} + +size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(void) { + return kai_mr; +} + +size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(void) { + return kai_nr; +} + +size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(void) { + return 
kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(size_t m_idx, size_t k) { + KAI_ASSERT((m_idx % kai_m_step) == 0); + + return (m_idx / kai_m_step) * kai_lhs_packed_stride(k); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( + size_t n_idx, size_t k, size_t bl) { + KAI_ASSERT((n_idx % kai_n_step) == 0); + + return (n_idx / kai_n_step) * kai_rhs_packed_stride(k, bl); +} + +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSERT((m_idx % kai_m_step) == 0); + KAI_ASSERT((n_idx % kai_n_step) == 0); + + return (n_idx * sizeof(float)) + m_idx * dst_stride; +} + +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(size_t m, size_t n) { + return m * n * sizeof(float); +} + +void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( + size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed, float* dst, + size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { + KAI_ASSERT((bl % kai_kr) == 0); + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); + KAI_ASSERT(dst_stride_col == sizeof(float)); + + if (m == 0) { + return; + } + + size_t num_subblocks = bl / kai_bl_multiple_of; + size_t num_blocks = kai_num_blocks_per_row(k, bl); + + float clamp_vals[2] = {scalar_min, scalar_max}; + + __asm__ __volatile__( + "mov x12, #0x80\n" + "mov x11, %x[m]\n" + "movi v15.16b, #0xf0\n" + "mov x21, #0x3d800000\n" + "mov x20, #0x20\n" + "mul x12, %x[num_subblocks], x12\n" + "cmp x11, #0x8\n" + "dup v24.4s, w21\n" + "madd x12, %x[num_blocks], x12, x20\n" + "blt 11f\n" + "1:" // Row loop + "mov x10, %x[rhs_packed]\n" + "mov x9, %x[n]\n" + "add x28, %x[dst], %x[dst_stride_row], LSL #3\n" + "2:" // Column loop + "mov x23, %x[lhs_packed]\n" + "movi v12.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "mov x22, %x[num_blocks]\n" + "movi v22.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "movi v5.16b, #0x0\n" + "movi v0.16b, #0x0\n" + "movi v30.16b, #0x0\n" + "movi v8.16b, #0x0\n" + "add x21, x23, x12\n" + "3:" // Block loop + "movi v6.4s, #0x0\n" + "movi v10.4s, #0x0\n" + "mov x20, %x[num_subblocks]\n" + "movi v4.4s, #0x0\n" + "movi v18.4s, #0x0\n" + "movi v31.4s, #0x0\n" + "movi v3.4s, #0x0\n" + "movi v7.4s, #0x0\n" + "movi v23.4s, #0x0\n" + "4:" // Sub block loop + "ldr q2, [x10, #0x0]\n" + "ldr q20, [x10, #0x10]\n" + "subs x20, x20, #0x1\n" + "ldr q25, [x23, #0x0]\n" + "ldr q11, [x23, #0x10]\n" + "ldr q9, [x21, #0x0]\n" + "ldr q19, [x21, #0x10]\n" + "ldr q1, [x10, #0x20]\n" + "ldr q29, [x10, #0x30]\n" + "shl v27.16b, v2.16b, #0x4\n" + "shl v21.16b, v20.16b, #0x4\n" + "ldr q17, [x23, #0x20]\n" + "ldr q26, [x23, #0x30]\n" + "and v2.16b, v2.16b, v15.16b\n" + "and v20.16b, v20.16b, v15.16b\n" + "ldr q28, [x21, #0x20]\n" + "ldr q16, [x21, #0x30]\n" + "add x10, x10, #0x40\n" + ".inst 0x4e9ba726 // smmla v6.4s, v25.16b, v27.16b\n" + ".inst 0x4e95a72a // smmla v10.4s, v25.16b, v21.16b\n" + "ldr q25, [x23, #0x40]\n" + ".inst 0x4e9ba564 // smmla v4.4s, v11.16b, v27.16b\n" + ".inst 0x4e95a572 // smmla v18.4s, v11.16b, v21.16b\n" + "ldr q11, [x23, #0x50]\n" + ".inst 0x4e9ba53f // smmla v31.4s, v9.16b, v27.16b\n" + ".inst 0x4e95a523 // smmla v3.4s, v9.16b, v21.16b\n" + "ldr q9, [x21, #0x40]\n" + ".inst 0x4e9ba667 // smmla v7.4s, v19.16b, v27.16b\n" + "ldr q27, [x21, #0x50]\n" + ".inst 0x4e95a677 // smmla v23.4s, v19.16b, v21.16b\n" + "ldr q21, [x23, #0x60]\n" + "shl 
v19.16b, v1.16b, #0x4\n" + "and v1.16b, v1.16b, v15.16b\n" + ".inst 0x4e93a626 // smmla v6.4s, v17.16b, v19.16b\n" + ".inst 0x4e93a744 // smmla v4.4s, v26.16b, v19.16b\n" + ".inst 0x4e93a79f // smmla v31.4s, v28.16b, v19.16b\n" + ".inst 0x4e93a607 // smmla v7.4s, v16.16b, v19.16b\n" + "ldr q19, [x23, #0x70]\n" + "add x23, x23, #0x80\n" + ".inst 0x4e82a726 // smmla v6.4s, v25.16b, v2.16b\n" + ".inst 0x4e82a564 // smmla v4.4s, v11.16b, v2.16b\n" + ".inst 0x4e82a53f // smmla v31.4s, v9.16b, v2.16b\n" + ".inst 0x4e82a767 // smmla v7.4s, v27.16b, v2.16b\n" + "shl v2.16b, v29.16b, #0x4\n" + "and v29.16b, v29.16b, v15.16b\n" + ".inst 0x4e82a62a // smmla v10.4s, v17.16b, v2.16b\n" + "ldr q17, [x21, #0x60]\n" + ".inst 0x4e82a752 // smmla v18.4s, v26.16b, v2.16b\n" + "ldr q26, [x21, #0x70]\n" + "add x21, x21, #0x80\n" + ".inst 0x4e82a783 // smmla v3.4s, v28.16b, v2.16b\n" + ".inst 0x4e82a617 // smmla v23.4s, v16.16b, v2.16b\n" + ".inst 0x4e81a6a6 // smmla v6.4s, v21.16b, v1.16b\n" + ".inst 0x4e81a664 // smmla v4.4s, v19.16b, v1.16b\n" + ".inst 0x4e81a63f // smmla v31.4s, v17.16b, v1.16b\n" + ".inst 0x4e94a72a // smmla v10.4s, v25.16b, v20.16b\n" + ".inst 0x4e94a572 // smmla v18.4s, v11.16b, v20.16b\n" + ".inst 0x4e81a747 // smmla v7.4s, v26.16b, v1.16b\n" + ".inst 0x4e94a523 // smmla v3.4s, v9.16b, v20.16b\n" + ".inst 0x4e94a777 // smmla v23.4s, v27.16b, v20.16b\n" + ".inst 0x4e9da6aa // smmla v10.4s, v21.16b, v29.16b\n" + ".inst 0x4e9da672 // smmla v18.4s, v19.16b, v29.16b\n" + ".inst 0x4e9da623 // smmla v3.4s, v17.16b, v29.16b\n" + ".inst 0x4e9da757 // smmla v23.4s, v26.16b, v29.16b\n" + "bgt 4b\n" + "ldr d20, [x10, #0x0]\n" + "uzp1 v21.2d, v6.2d, v10.2d\n" + "uzp2 v19.2d, v6.2d, v10.2d\n" + "add x10, x10, #0x8\n" + "uzp1 v17.2d, v4.2d, v18.2d\n" + "uzp2 v16.2d, v4.2d, v18.2d\n" + "shll v20.4s, v20.4h, #0x10\n" + "scvtf v21.4s, v21.4s\n" + "scvtf v19.4s, v19.4s\n" + "scvtf v17.4s, v17.4s\n" + "scvtf v16.4s, v16.4s\n" + "fmul v20.4s, v20.4s, v24.4s\n" + "fmla v12.4s, v21.4s, v20.4s\n" + "fmla v13.4s, v19.4s, v20.4s\n" + "fmla v22.4s, v17.4s, v20.4s\n" + "fmla v14.4s, v16.4s, v20.4s\n" + "uzp1 v19.2d, v31.2d, v3.2d\n" + "uzp2 v18.2d, v31.2d, v3.2d\n" + "uzp1 v17.2d, v7.2d, v23.2d\n" + "uzp2 v16.2d, v7.2d, v23.2d\n" + "scvtf v19.4s, v19.4s\n" + "scvtf v18.4s, v18.4s\n" + "scvtf v17.4s, v17.4s\n" + "scvtf v16.4s, v16.4s\n" + "fmla v5.4s, v19.4s, v20.4s\n" + "fmla v0.4s, v18.4s, v20.4s\n" + "fmla v30.4s, v17.4s, v20.4s\n" + "fmla v8.4s, v16.4s, v20.4s\n" + "subs x22, x22, #0x1\n" + "bgt 3b\n" + "ld1 { v23.4s }, [x23]\n" + "ld1 { v1.4s }, [x21]\n" + "add x23, x23, #0x10\n" + "add x21, x21, #0x10\n" + "ldr q21, [x10, #0x0]\n" + "ldr q20, [x23, #0x0]\n" + "add x20, %x[clamp_vals], #0x4\n" + "cmp x9, #0x4\n" + "ldr q19, [x21, #0x0]\n" + "ldr q18, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + "ld1r { v17.4s }, [%x[clamp_vals]]\n" + "ld1r { v16.4s }, [x20]\n" + "scvtf v23.4s, v23.4s\n" + "scvtf v1.4s, v1.4s\n" + "fmla v12.4s, v21.4s, v23.s[0]\n" + "fmla v13.4s, v21.4s, v23.s[1]\n" + "fmla v22.4s, v21.4s, v23.s[2]\n" + "fmla v14.4s, v21.4s, v23.s[3]\n" + "fmla v5.4s, v21.4s, v1.s[0]\n" + "fmla v0.4s, v21.4s, v1.s[1]\n" + "fmla v30.4s, v21.4s, v1.s[2]\n" + "fmla v8.4s, v21.4s, v1.s[3]\n" + "fmul v12.4s, v12.4s, v20.s[0]\n" + "fmul v13.4s, v13.4s, v20.s[1]\n" + "fmul v22.4s, v22.4s, v20.s[2]\n" + "fmul v14.4s, v14.4s, v20.s[3]\n" + "fmul v5.4s, v5.4s, v19.s[0]\n" + "fmul v0.4s, v0.4s, v19.s[1]\n" + "fadd v12.4s, v12.4s, v18.4s\n" + "fmul v30.4s, v30.4s, v19.s[2]\n" + "fmul v8.4s, v8.4s, v19.s[3]\n" + "fadd 
v13.4s, v13.4s, v18.4s\n" + "fadd v22.4s, v22.4s, v18.4s\n" + "fadd v14.4s, v14.4s, v18.4s\n" + "fadd v5.4s, v5.4s, v18.4s\n" + "fadd v0.4s, v0.4s, v18.4s\n" + "fadd v30.4s, v30.4s, v18.4s\n" + "fadd v8.4s, v8.4s, v18.4s\n" + "fmax v12.4s, v12.4s, v17.4s\n" + "fmax v13.4s, v13.4s, v17.4s\n" + "fmax v22.4s, v22.4s, v17.4s\n" + "fmax v14.4s, v14.4s, v17.4s\n" + "fmax v5.4s, v5.4s, v17.4s\n" + "fmax v0.4s, v0.4s, v17.4s\n" + "fmax v30.4s, v30.4s, v17.4s\n" + "fmax v8.4s, v8.4s, v17.4s\n" + "fmin v12.4s, v12.4s, v16.4s\n" + "fmin v13.4s, v13.4s, v16.4s\n" + "fmin v22.4s, v22.4s, v16.4s\n" + "fmin v14.4s, v14.4s, v16.4s\n" + "fmin v5.4s, v5.4s, v16.4s\n" + "fmin v0.4s, v0.4s, v16.4s\n" + "fmin v30.4s, v30.4s, v16.4s\n" + "fmin v8.4s, v8.4s, v16.4s\n" + "blt 7f\n" + "mov x20, %x[dst]\n" + "str q12, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q13, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q22, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q14, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q5, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q0, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q30, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q8, [x20, #0x0]\n" + "b 10f\n" + "7:" // Partial output + "mov x27, %x[dst]\n" + "add x26, x27, %x[dst_stride_row], LSL #2\n" + "add x25, x26, %x[dst_stride_row], LSL #1\n" + "add x24, x26, %x[dst_stride_row]\n" + "add x23, x25, %x[dst_stride_row]\n" + "add x22, x27, %x[dst_stride_row], LSL #1\n" + "add x21, x27, %x[dst_stride_row]\n" + "add x20, x22, %x[dst_stride_row]\n" + "tbz x9, #1, 8f\n" + "st1 { v8.d }[0], [x23], #0x8\n" + "st1 { v30.d }[0], [x25], #0x8\n" + "st1 { v0.d }[0], [x24], #0x8\n" + "st1 { v5.d }[0], [x26], #0x8\n" + "st1 { v14.d }[0], [x20], #0x8\n" + "st1 { v22.d }[0], [x22], #0x8\n" + "st1 { v13.d }[0], [x21], #0x8\n" + "st1 { v12.d }[0], [x27], #0x8\n" + "tbz x9, #0, 9f\n" + "st1 { v8.s }[2], [x23]\n" + "st1 { v30.s }[2], [x25]\n" + "st1 { v0.s }[2], [x24]\n" + "st1 { v5.s }[2], [x26]\n" + "st1 { v14.s }[2], [x20]\n" + "st1 { v22.s }[2], [x22]\n" + "st1 { v13.s }[2], [x21]\n" + "st1 { v12.s }[2], [x27]\n" + "b 9f\n" + "8:" // Output block 0: partial_1_0 + "st1 { v8.s }[0], [x23]\n" + "st1 { v30.s }[0], [x25]\n" + "st1 { v0.s }[0], [x24]\n" + "st1 { v5.s }[0], [x26]\n" + "st1 { v14.s }[0], [x20]\n" + "st1 { v22.s }[0], [x22]\n" + "st1 { v13.s }[0], [x21]\n" + "st1 { v12.s }[0], [x27]\n" + "9:" // Output block 0: Done + "10:" // Output stage exit + "subs x9, x9, #0x4\n" + "add %x[dst], %x[dst], #0x10\n" + "bgt 2b\n" + "mov x20, #0x2\n" + "sub x11, x11, #0x8\n" + "cmp x11, #0x8\n" + "mov %x[dst], x28\n" + "madd %x[lhs_packed], x20, x12, %x[lhs_packed]\n" + "bge 1b\n" + "11:" // Row loop skip + "cbz x11, 21f\n" + "12:" // Row tail: Row loop + "mov x26, %x[rhs_packed]\n" + "mov x25, %x[n]\n" + "add x24, %x[dst], %x[dst_stride_row], LSL #2\n" + "13:" // Row tail: Column loop + "movi v12.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "mov x23, %x[lhs_packed]\n" + "mov x21, %x[num_blocks]\n" + "movi v22.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "14:" // Row tail: Block loop + "movi v6.4s, #0x0\n" + "movi v10.4s, #0x0\n" + "mov x20, %x[num_subblocks]\n" + "movi v4.4s, #0x0\n" + "movi v18.4s, #0x0\n" + "15:" // Row tail: Sub block loop + "ldr q0, [x26, #0x0]\n" + "ldr q31, [x26, #0x10]\n" + "subs x20, x20, #0x1\n" + "ldr q11, [x23, #0x0]\n" + "ldr q30, [x23, #0x10]\n" + "ldr q29, [x26, #0x20]\n" + "ldr q28, [x26, #0x30]\n" + "add x26, x26, #0x40\n" + 
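        // Row tail: handles the rows left over from the 8-row main loop (up to
+        // four at a time) with the same int4 unpacking (shl #4 / AND 0xF0) and
+        // SMMLA accumulation scheme as the main loop above.
+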
"ldr q27, [x23, #0x20]\n" + "ldr q26, [x23, #0x30]\n" + "shl v25.16b, v0.16b, #0x4\n" + "shl v23.16b, v31.16b, #0x4\n" + "ldr q1, [x23, #0x40]\n" + "ldr q21, [x23, #0x50]\n" + "and v0.16b, v0.16b, v15.16b\n" + "and v31.16b, v31.16b, v15.16b\n" + "ldr q20, [x23, #0x60]\n" + "ldr q19, [x23, #0x70]\n" + "shl v17.16b, v29.16b, #0x4\n" + "shl v16.16b, v28.16b, #0x4\n" + ".inst 0x4e99a566 // smmla v6.4s, v11.16b, v25.16b\n" + ".inst 0x4e97a56a // smmla v10.4s, v11.16b, v23.16b\n" + "and v29.16b, v29.16b, v15.16b\n" + "add x23, x23, #0x80\n" + ".inst 0x4e99a7c4 // smmla v4.4s, v30.16b, v25.16b\n" + ".inst 0x4e97a7d2 // smmla v18.4s, v30.16b, v23.16b\n" + "and v28.16b, v28.16b, v15.16b\n" + ".inst 0x4e91a766 // smmla v6.4s, v27.16b, v17.16b\n" + ".inst 0x4e90a76a // smmla v10.4s, v27.16b, v16.16b\n" + ".inst 0x4e91a744 // smmla v4.4s, v26.16b, v17.16b\n" + ".inst 0x4e90a752 // smmla v18.4s, v26.16b, v16.16b\n" + ".inst 0x4e80a426 // smmla v6.4s, v1.16b, v0.16b\n" + ".inst 0x4e9fa42a // smmla v10.4s, v1.16b, v31.16b\n" + ".inst 0x4e80a6a4 // smmla v4.4s, v21.16b, v0.16b\n" + ".inst 0x4e9fa6b2 // smmla v18.4s, v21.16b, v31.16b\n" + ".inst 0x4e9da686 // smmla v6.4s, v20.16b, v29.16b\n" + ".inst 0x4e9ca68a // smmla v10.4s, v20.16b, v28.16b\n" + ".inst 0x4e9da664 // smmla v4.4s, v19.16b, v29.16b\n" + ".inst 0x4e9ca672 // smmla v18.4s, v19.16b, v28.16b\n" + "bgt 15b\n" + "ldr d16, [x26, #0x0]\n" + "uzp1 v21.2d, v6.2d, v10.2d\n" + "uzp2 v20.2d, v6.2d, v10.2d\n" + "add x26, x26, #0x8\n" + "uzp1 v19.2d, v4.2d, v18.2d\n" + "uzp2 v17.2d, v4.2d, v18.2d\n" + "shll v16.4s, v16.4h, #0x10\n" + "scvtf v21.4s, v21.4s\n" + "scvtf v20.4s, v20.4s\n" + "scvtf v19.4s, v19.4s\n" + "scvtf v17.4s, v17.4s\n" + "fmul v16.4s, v16.4s, v24.4s\n" + "fmla v12.4s, v21.4s, v16.4s\n" + "fmla v13.4s, v20.4s, v16.4s\n" + "fmla v22.4s, v19.4s, v16.4s\n" + "fmla v14.4s, v17.4s, v16.4s\n" + "subs x21, x21, #0x1\n" + "bgt 14b\n" + "ld1 { v21.4s }, [x23]\n" + "ldr q20, [x26, #0x0]\n" + "add x23, x23, #0x10\n" + "add x20, %x[clamp_vals], #0x4\n" + "ldr q19, [x23, #0x0]\n" + "ldr q18, [x26, #0x10]\n" + "cmp x25, #0x4\n" + "add x26, x26, #0x20\n" + "ld1r { v17.4s }, [%x[clamp_vals]]\n" + "ld1r { v16.4s }, [x20]\n" + "scvtf v21.4s, v21.4s\n" + "fmla v12.4s, v20.4s, v21.s[0]\n" + "fmla v13.4s, v20.4s, v21.s[1]\n" + "fmla v22.4s, v20.4s, v21.s[2]\n" + "fmla v14.4s, v20.4s, v21.s[3]\n" + "fmul v12.4s, v12.4s, v19.s[0]\n" + "fmul v13.4s, v13.4s, v19.s[1]\n" + "fmul v22.4s, v22.4s, v19.s[2]\n" + "fadd v12.4s, v12.4s, v18.4s\n" + "fmul v14.4s, v14.4s, v19.s[3]\n" + "fadd v13.4s, v13.4s, v18.4s\n" + "fadd v22.4s, v22.4s, v18.4s\n" + "fadd v14.4s, v14.4s, v18.4s\n" + "fmax v12.4s, v12.4s, v17.4s\n" + "fmax v13.4s, v13.4s, v17.4s\n" + "fmax v22.4s, v22.4s, v17.4s\n" + "fmax v14.4s, v14.4s, v17.4s\n" + "fmin v12.4s, v12.4s, v16.4s\n" + "fmin v13.4s, v13.4s, v16.4s\n" + "fmin v22.4s, v22.4s, v16.4s\n" + "fmin v14.4s, v14.4s, v16.4s\n" + "blt 17f\n" + "mov x20, %x[dst]\n" + "cmp x11, #0x1\n" + "str q12, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 20f\n" + "cmp x11, #0x2\n" + "str q13, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 20f\n" + "cmp x11, #0x3\n" + "str q22, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 20f\n" + "str q14, [x20, #0x0]\n" + "b 20f\n" + "17:" // Row tail: Partial output + "mov x23, %x[dst]\n" + "cmp x11, #0x1\n" + "add x22, x23, %x[dst_stride_row]\n" + "csel x22, x22, x23, GT\n" + "cmp x11, #0x2\n" + "add x21, x23, %x[dst_stride_row], LSL #1\n" + "csel x21, x21, x22, GT\n" + "cmp x11, 
#0x3\n" + "add x20, x21, %x[dst_stride_row]\n" + "csel x20, x20, x21, GT\n" + "tbz x25, #1, 18f\n" + "st1 { v14.d }[0], [x20], #0x8\n" + "st1 { v22.d }[0], [x21], #0x8\n" + "st1 { v13.d }[0], [x22], #0x8\n" + "st1 { v12.d }[0], [x23], #0x8\n" + "tbz x25, #0, 19f\n" + "st1 { v14.s }[2], [x20]\n" + "st1 { v22.s }[2], [x21]\n" + "st1 { v13.s }[2], [x22]\n" + "st1 { v12.s }[2], [x23]\n" + "b 19f\n" + "18:" // Row tail: Output block 0: partial_1_0 + "st1 { v14.s }[0], [x20]\n" + "st1 { v22.s }[0], [x21]\n" + "st1 { v13.s }[0], [x22]\n" + "st1 { v12.s }[0], [x23]\n" + "19:" // Row tail: Output block 0: Done + "20:" // Row tail: Output stage exit + "subs x25, x25, #0x4\n" + "add %x[dst], %x[dst], #0x10\n" + "bgt 13b\n" + "subs x11, x11, #0x4\n" + "add %x[lhs_packed], %x[lhs_packed], x12\n" + "mov %x[dst], x24\n" + "bgt 12b\n" + "21:" // Row tail: Row loop skip + : [dst] "+&r"(dst), [lhs_packed] "+&r"(lhs_packed) + : [clamp_vals] "r"(clamp_vals), [dst_stride_row] "r"(dst_stride_row), [m] "r"(m), [n] "r"(n), + [num_blocks] "r"(num_blocks), [num_subblocks] "r"(num_subblocks), [rhs_packed] "r"(rhs_packed) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", + "v30", "v31", "x9", "x10", "x11", "x12", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"); +} +#endif // Architectural feature check diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h new file mode 100644 index 0000000000000000000000000000000000000000..765d84abf8c9ba859314f28fa3af84290f988328 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h @@ -0,0 +1,142 @@ + +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/// Micro-kernel dependencies +/// +/// -# kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix +/// -# kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 OR kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0 to pack the RHS matrix + +/// -------------------------------------------------- + +/// Gets the m step value. +/// The micro-kernel can process any M values. However, the starting M index to +/// be processed must be a multiple of m step. +/// +/// @return the m step value +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(void); + +/// Gets the n step value. +/// The micro-kernel can process any N values. However, the starting N index to +/// be processed must be a multiple of n step. 
+/// +/// @return the n step +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(void); + +/// Gets the mr value, which must be used to pack the LHS matrix +/// +/// @return the mr value +size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(void); + +/// Gets the nr value, which must be used to pack the RHS matrix +/// +/// @return the nr value +size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(void); + +/// Gets the kr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the kr value +size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(void); + +/// Gets the sr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the sr value +size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm(void); + +/// Gets the offset in bytes for the packed LHS matrix, +/// which contains the packed Signed 8-bit quantized asymmetric per-row (qai8dx) values. +/// +/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. +/// +/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of 8 +/// @param[in] k Total number of columns in the LHS matrix (not packed). +/// +/// @return the offset in bytes to the packed LHS matrix +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( + size_t m_idx, // + size_t k); // + +/// Gets the offset in bytes for the packed RHS matrix, +/// which contains the packed Signed 4-bit quantized symmetric per-channel (qsi4c32) values. +/// +/// @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of 4. +/// @param[in] k The common dimension between the LHS and RHS matrix (K). +/// @param[in] bl Block length. It must be a multiple of 32. +/// +/// @return the offset in bytes to the packed RHS matrix +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( + size_t n_idx, // + size_t k, // + size_t bl); // + +/// Gets the offset in bytes for the DST matrix +/// +/// @param[in] m_idx Row index in the DST matrix. It must be a multiple of 8. +/// @param[in] n_idx Column index in the DST matrix. It must be multiple of 4. +/// @param[in] dst_stride The number of bytes in in each row of the DST matrix +/// +/// @return the DST offset in bytes +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( + size_t m_idx, // + size_t n_idx, // + size_t dst_stride); // + +/// Gets the size in bytes for the destination (DST) matrix. +/// +/// @param[in] m Number of rows in the destination (DST) matrix. +/// @param[in] n Number of columns in the destination (DST) matrix. +/// +/// @return the destination (DST) matrix size in bytes +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( + size_t m, // + size_t n); // + +/// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation. +/// +/// LHS matrix: Signed 8-bit quantized asymmetric per-row (qai8dx) and packed +/// RHS matrix: Signed 4-bit quantized symmetric per-channel (qsi4c32) and packed. +/// Output tile: (rows x cols) = 8 x 4 +/// Accumulation performed in a single for loop: 32 +/// Extension used: i8mm +/// +/// @param[in] m The number of output rows written. +/// @param[in] n The number of output columns written. +/// @param[in] k The number of channels. The common dimension between the LHS and RHS matrix. 
+/// @param[in] bl Block length. It must be a multiple of 32. +/// @param[in] lhs_packed The LHS packed matrix. +/// When the activation are dynamically quantized, you can obtain this matrix +/// by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs +/// both the dynamic quantization to 8-bit and activation packing in a single step. +/// @param[in] rhs_packed The RHS packed matrix, which is obtained by calling @ref +/// kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 or @ref kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0 +/// @param[out] dst The DST matrix. +/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. +/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(float). +/// @param[in] scalar_min Min value used to clamp the final result. +/// @param[in] scalar_max Max value used to clamp the final result. +void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( + size_t m, // + size_t n, // + size_t k, // + size_t bl, // + const void* lhs_packed, // + const void* rhs_packed, // + float* dst, // + size_t dst_stride_row, // + size_t dst_stride_col, // + float scalar_min, // + float scalar_max); // + +#ifdef __cplusplus +} +#endif diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c new file mode 100644 index 0000000000000000000000000000000000000000..4dcb19555e90b8022499c10ced7ac6b701660816 --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.c @@ -0,0 +1,401 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#if !defined(__ARM_FEATURE_MATMUL_INT8) +#error "I8mm extension required to compile this micro-kernel" +#else +#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h" + +#include +#include + +#include "kai/kai_common.h" + +static const size_t kai_m_step = 4; +static const size_t kai_n_step = 8; +static const size_t kai_mr = 4; +static const size_t kai_nr = 8; +static const size_t kai_kr = 16; +static const size_t kai_sr = 2; +static const size_t kai_bl_multiple_of = 32; +static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); +static const size_t kai_num_bytes_multiplier_rhs = sizeof(uint16_t); +static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); +static const size_t kai_num_bytes_sum_rhs = sizeof(float); +static const size_t kai_num_bytes_bias = sizeof(float); + +inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) { + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); + return kai_roundup(k, bl) / bl; +} + +inline static size_t kai_k_roundedup(size_t k) { + // Since we pack a float and int32 value at the end of the row, + // we must make sure that k is a multiple of 4 for alignment + size_t kr_sr_roundedup4 = kai_roundup(kai_kr * kai_sr, 4); + return kai_roundup(k, kr_sr_roundedup4); +} + +inline static size_t kai_lhs_packed_stride(size_t k) { + const size_t k_internal = kai_k_roundedup(k); + + KAI_ASSERT((k_internal % 2) == 0); + + return kai_mr * (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); +} + +inline static size_t kai_rhs_packed_stride(size_t k, size_t bl) { + KAI_ASSERT((bl % kai_kr) == 0); + KAI_ASSERT((bl % kai_bl_multiple_of) == 
0); + + const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); + const size_t num_bytes_per_block = (bl / 2) + kai_num_bytes_multiplier_rhs; + + return kai_nr * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs + kai_num_bytes_bias); +} + +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(void) { + return kai_m_step; +} + +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(void) { + return kai_n_step; +} + +size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(void) { + return kai_mr; +} + +size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(void) { + return kai_nr; +} + +size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(void) { + return kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(size_t m_idx, size_t k) { + KAI_ASSERT((m_idx % kai_m_step) == 0); + + return (m_idx / kai_m_step) * kai_lhs_packed_stride(k); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm( + size_t n_idx, size_t k, size_t bl) { + KAI_ASSERT((n_idx % kai_n_step) == 0); + + return (n_idx / kai_n_step) * kai_rhs_packed_stride(k, bl); +} + +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSERT((m_idx % kai_m_step) == 0); + KAI_ASSERT((n_idx % kai_n_step) == 0); + + return (n_idx * sizeof(float)) + m_idx * dst_stride; +} + +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(size_t m, size_t n) { + return m * n * sizeof(float); +} + +void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm( + size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed, float* dst, + size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { + KAI_ASSERT((bl % kai_kr) == 0); + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); + KAI_ASSERT(dst_stride_col == sizeof(float)); + + if (m == 0) { + return; + } + + size_t num_subblocks = bl / kai_bl_multiple_of; + size_t num_blocks = kai_num_blocks_per_row(k, bl); + + float clamp_vals[2] = {scalar_min, scalar_max}; + + __asm__ __volatile__( + "mov x28, #0x80\n" + "mov x21, #0x3d800000\n" + "movi v17.16b, #0xf0\n" + "mov x20, #0x20\n" + "mov x27, %x[m]\n" + "mul x28, %x[num_subblocks], x28\n" + "dup v14.4s, w21\n" + "madd x28, %x[num_blocks], x28, x20\n" + "cbz x27, 12f\n" + "1:" // Row loop + "mov x26, %x[rhs_packed]\n" + "mov x25, %x[n]\n" + "add x24, %x[dst], %x[dst_stride_row], LSL #2\n" + "2:" // Column loop + "movi v1.16b, #0x0\n" + "movi v12.16b, #0x0\n" + "mov x22, %x[lhs_packed]\n" + "mov x21, %x[num_blocks]\n" + "movi v11.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "movi v18.16b, #0x0\n" + "movi v27.16b, #0x0\n" + "movi v28.16b, #0x0\n" + "movi v4.16b, #0x0\n" + "3:" // Block loop + "movi v21.4s, #0x0\n" + "movi v30.4s, #0x0\n" + "mov x20, %x[num_subblocks]\n" + "movi v24.4s, #0x0\n" + "movi v23.4s, #0x0\n" + "movi v7.4s, #0x0\n" + "movi v3.4s, #0x0\n" + "movi v2.4s, #0x0\n" + "movi v8.4s, #0x0\n" + "4:" // Sub block loop + "ldr q6, [x26, #0x0]\n" + "ldr q0, [x26, #0x10]\n" + "subs x20, x20, #0x1\n" + "ldr q10, [x26, #0x20]\n" + "ldr q26, [x26, #0x30]\n" + "ldr q22, [x22, #0x0]\n" + "ldr q20, [x22, #0x10]\n" + "ldr q31, [x26, #0x40]\n" + "ldr q15, [x26, 
#0x50]\n" + "shl v29.16b, v6.16b, #0x4\n" + "shl v9.16b, v0.16b, #0x4\n" + "ldr q25, [x26, #0x60]\n" + "ldr q16, [x26, #0x70]\n" + "shl v5.16b, v10.16b, #0x4\n" + "shl v19.16b, v26.16b, #0x4\n" + "and v6.16b, v6.16b, v17.16b\n" + "and v0.16b, v0.16b, v17.16b\n" + "add x26, x26, #0x80\n" + ".inst 0x4e9da6d5 // smmla v21.4s, v22.16b, v29.16b\n" + ".inst 0x4e89a6d8 // smmla v24.4s, v22.16b, v9.16b\n" + ".inst 0x4e9da687 // smmla v7.4s, v20.16b, v29.16b\n" + "ldr q29, [x22, #0x20]\n" + "and v10.16b, v10.16b, v17.16b\n" + ".inst 0x4e85a6de // smmla v30.4s, v22.16b, v5.16b\n" + ".inst 0x4e93a6d7 // smmla v23.4s, v22.16b, v19.16b\n" + "ldr q22, [x22, #0x30]\n" + "and v26.16b, v26.16b, v17.16b\n" + ".inst 0x4e89a682 // smmla v2.4s, v20.16b, v9.16b\n" + "ldr q9, [x22, #0x40]\n" + ".inst 0x4e85a683 // smmla v3.4s, v20.16b, v5.16b\n" + "ldr q5, [x22, #0x50]\n" + ".inst 0x4e93a688 // smmla v8.4s, v20.16b, v19.16b\n" + "ldr q19, [x22, #0x60]\n" + "shl v20.16b, v31.16b, #0x4\n" + "and v31.16b, v31.16b, v17.16b\n" + ".inst 0x4e94a7b5 // smmla v21.4s, v29.16b, v20.16b\n" + ".inst 0x4e94a6c7 // smmla v7.4s, v22.16b, v20.16b\n" + "ldr q20, [x22, #0x70]\n" + "add x22, x22, #0x80\n" + ".inst 0x4e86a535 // smmla v21.4s, v9.16b, v6.16b\n" + ".inst 0x4e86a4a7 // smmla v7.4s, v5.16b, v6.16b\n" + "shl v6.16b, v15.16b, #0x4\n" + "and v15.16b, v15.16b, v17.16b\n" + ".inst 0x4e86a7b8 // smmla v24.4s, v29.16b, v6.16b\n" + ".inst 0x4e86a6c2 // smmla v2.4s, v22.16b, v6.16b\n" + "shl v6.16b, v25.16b, #0x4\n" + "and v25.16b, v25.16b, v17.16b\n" + ".inst 0x4e9fa675 // smmla v21.4s, v19.16b, v31.16b\n" + ".inst 0x4e9fa687 // smmla v7.4s, v20.16b, v31.16b\n" + "shl v31.16b, v16.16b, #0x4\n" + "and v16.16b, v16.16b, v17.16b\n" + ".inst 0x4e86a7be // smmla v30.4s, v29.16b, v6.16b\n" + ".inst 0x4e86a6c3 // smmla v3.4s, v22.16b, v6.16b\n" + ".inst 0x4e80a538 // smmla v24.4s, v9.16b, v0.16b\n" + ".inst 0x4e80a4a2 // smmla v2.4s, v5.16b, v0.16b\n" + ".inst 0x4e9fa7b7 // smmla v23.4s, v29.16b, v31.16b\n" + ".inst 0x4e9fa6c8 // smmla v8.4s, v22.16b, v31.16b\n" + ".inst 0x4e8aa53e // smmla v30.4s, v9.16b, v10.16b\n" + ".inst 0x4e8aa4a3 // smmla v3.4s, v5.16b, v10.16b\n" + ".inst 0x4e8fa678 // smmla v24.4s, v19.16b, v15.16b\n" + ".inst 0x4e8fa682 // smmla v2.4s, v20.16b, v15.16b\n" + ".inst 0x4e9aa537 // smmla v23.4s, v9.16b, v26.16b\n" + ".inst 0x4e9aa4a8 // smmla v8.4s, v5.16b, v26.16b\n" + ".inst 0x4e99a67e // smmla v30.4s, v19.16b, v25.16b\n" + ".inst 0x4e99a683 // smmla v3.4s, v20.16b, v25.16b\n" + ".inst 0x4e90a677 // smmla v23.4s, v19.16b, v16.16b\n" + ".inst 0x4e90a688 // smmla v8.4s, v20.16b, v16.16b\n" + "bgt 4b\n" + "ldr q29, [x26, #0x0]\n" + "uzp1 v26.2d, v21.2d, v24.2d\n" + "uzp2 v25.2d, v21.2d, v24.2d\n" + "add x26, x26, #0x10\n" + "uzp1 v24.2d, v30.2d, v23.2d\n" + "uzp2 v23.2d, v30.2d, v23.2d\n" + "uzp1 v22.2d, v7.2d, v2.2d\n" + "uzp2 v21.2d, v7.2d, v2.2d\n" + "shll v20.4s, v29.4h, #0x10\n" + "shll2 v19.4s, v29.8h, #0x10\n" + "uzp1 v0.2d, v3.2d, v8.2d\n" + "uzp2 v8.2d, v3.2d, v8.2d\n" + "scvtf v26.4s, v26.4s\n" + "scvtf v24.4s, v24.4s\n" + "fmul v20.4s, v20.4s, v14.4s\n" + "fmul v19.4s, v19.4s, v14.4s\n" + "scvtf v25.4s, v25.4s\n" + "scvtf v23.4s, v23.4s\n" + "scvtf v22.4s, v22.4s\n" + "scvtf v0.4s, v0.4s\n" + "scvtf v21.4s, v21.4s\n" + "scvtf v8.4s, v8.4s\n" + "fmla v1.4s, v26.4s, v20.4s\n" + "fmla v12.4s, v24.4s, v19.4s\n" + "fmla v11.4s, v25.4s, v20.4s\n" + "fmla v13.4s, v23.4s, v19.4s\n" + "fmla v18.4s, v22.4s, v20.4s\n" + "fmla v27.4s, v0.4s, v19.4s\n" + "fmla v28.4s, v21.4s, v20.4s\n" + "fmla v4.4s, v8.4s, 
v19.4s\n" + "subs x21, x21, #0x1\n" + "bgt 3b\n" + "ld1 { v23.4s }, [x22]\n" + "ldr q22, [x26, #0x0]\n" + "add x22, x22, #0x10\n" + "add x20, %x[clamp_vals], #0x4\n" + "ldr q9, [x26, #0x10]\n" + "ldr q20, [x22, #0x0]\n" + "cmp x25, #0x8\n" + "ldr q19, [x26, #0x20]\n" + "ldr q21, [x26, #0x30]\n" + "add x26, x26, #0x40\n" + "ld1r { v10.4s }, [%x[clamp_vals]]\n" + "ld1r { v30.4s }, [x20]\n" + "scvtf v23.4s, v23.4s\n" + "fmla v1.4s, v22.4s, v23.s[0]\n" + "fmla v12.4s, v9.4s, v23.s[0]\n" + "fmla v11.4s, v22.4s, v23.s[1]\n" + "fmla v13.4s, v9.4s, v23.s[1]\n" + "fmla v18.4s, v22.4s, v23.s[2]\n" + "fmla v27.4s, v9.4s, v23.s[2]\n" + "fmla v28.4s, v22.4s, v23.s[3]\n" + "fmla v4.4s, v9.4s, v23.s[3]\n" + "fmul v1.4s, v1.4s, v20.s[0]\n" + "fmul v12.4s, v12.4s, v20.s[0]\n" + "fmul v11.4s, v11.4s, v20.s[1]\n" + "fmul v13.4s, v13.4s, v20.s[1]\n" + "fmul v18.4s, v18.4s, v20.s[2]\n" + "fmul v27.4s, v27.4s, v20.s[2]\n" + "fmul v28.4s, v28.4s, v20.s[3]\n" + "fmul v4.4s, v4.4s, v20.s[3]\n" + "fadd v1.4s, v1.4s, v19.4s\n" + "fadd v12.4s, v12.4s, v21.4s\n" + "fadd v11.4s, v11.4s, v19.4s\n" + "fadd v13.4s, v13.4s, v21.4s\n" + "fadd v18.4s, v18.4s, v19.4s\n" + "fadd v27.4s, v27.4s, v21.4s\n" + "fadd v28.4s, v28.4s, v19.4s\n" + "fadd v4.4s, v4.4s, v21.4s\n" + "fmax v1.4s, v1.4s, v10.4s\n" + "fmax v12.4s, v12.4s, v10.4s\n" + "fmax v11.4s, v11.4s, v10.4s\n" + "fmax v13.4s, v13.4s, v10.4s\n" + "fmax v18.4s, v18.4s, v10.4s\n" + "fmax v27.4s, v27.4s, v10.4s\n" + "fmax v28.4s, v28.4s, v10.4s\n" + "fmax v4.4s, v4.4s, v10.4s\n" + "fmin v1.4s, v1.4s, v30.4s\n" + "fmin v12.4s, v12.4s, v30.4s\n" + "fmin v11.4s, v11.4s, v30.4s\n" + "fmin v13.4s, v13.4s, v30.4s\n" + "fmin v18.4s, v18.4s, v30.4s\n" + "fmin v27.4s, v27.4s, v30.4s\n" + "fmin v28.4s, v28.4s, v30.4s\n" + "fmin v4.4s, v4.4s, v30.4s\n" + "blt 6f\n" + "mov x20, %x[dst]\n" + "cmp x27, #0x1\n" + "str q1, [x20, #0x0]\n" + "str q12, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 11f\n" + "cmp x27, #0x2\n" + "str q11, [x20, #0x0]\n" + "str q13, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 11f\n" + "cmp x27, #0x3\n" + "str q18, [x20, #0x0]\n" + "str q27, [x20, #0x10]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 11f\n" + "str q28, [x20, #0x0]\n" + "str q4, [x20, #0x10]\n" + "b 11f\n" + "6:" // Partial output + "mov x23, %x[dst]\n" + "cmp x27, #0x1\n" + "add x22, x23, %x[dst_stride_row]\n" + "csel x22, x22, x23, GT\n" + "cmp x27, #0x2\n" + "add x21, x23, %x[dst_stride_row], LSL #1\n" + "csel x21, x21, x22, GT\n" + "cmp x27, #0x3\n" + "add x20, x21, %x[dst_stride_row]\n" + "csel x20, x20, x21, GT\n" + "tbz x25, #2, 8f\n" + "st1 { v28.4s }, [x20], #0x10\n" + "st1 { v18.4s }, [x21], #0x10\n" + "st1 { v11.4s }, [x22], #0x10\n" + "st1 { v1.4s }, [x23], #0x10\n" + "tbz x25, #1, 7f\n" + "st1 { v4.d }[0], [x20], #0x8\n" + "st1 { v27.d }[0], [x21], #0x8\n" + "st1 { v13.d }[0], [x22], #0x8\n" + "st1 { v12.d }[0], [x23], #0x8\n" + "tbz x25, #0, 10f\n" + "st1 { v4.s }[2], [x20]\n" + "st1 { v27.s }[2], [x21]\n" + "st1 { v13.s }[2], [x22]\n" + "st1 { v12.s }[2], [x23]\n" + "b 10f\n" + "7:" // Output block 0: partial_1_4 + "tbz x25, #0, 10f\n" + "st1 { v4.s }[0], [x20]\n" + "st1 { v27.s }[0], [x21]\n" + "st1 { v13.s }[0], [x22]\n" + "st1 { v12.s }[0], [x23]\n" + "b 10f\n" + "8:" // Output block 0: partial_2_0 + "tbz x25, #1, 9f\n" + "st1 { v28.d }[0], [x20], #0x8\n" + "st1 { v18.d }[0], [x21], #0x8\n" + "st1 { v11.d }[0], [x22], #0x8\n" + "st1 { v1.d }[0], [x23], #0x8\n" + "tbz x25, #0, 10f\n" + "st1 { v28.s }[2], [x20]\n" + "st1 { v18.s }[2], [x21]\n" + 
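        // Partial-width output: TBZ tests on the remaining column count select
+        // 4-, 2- and 1-lane stores so nothing is written past the end of a row.
+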
"st1 { v11.s }[2], [x22]\n" + "st1 { v1.s }[2], [x23]\n" + "b 10f\n" + "9:" // Output block 0: partial_1_0 + "st1 { v28.s }[0], [x20]\n" + "st1 { v18.s }[0], [x21]\n" + "st1 { v11.s }[0], [x22]\n" + "st1 { v1.s }[0], [x23]\n" + "10:" // Output block 0: Done + "11:" // Output stage exit + "subs x25, x25, #0x8\n" + "add %x[dst], %x[dst], #0x20\n" + "bgt 2b\n" + "subs x27, x27, #0x4\n" + "add %x[lhs_packed], %x[lhs_packed], x28\n" + "mov %x[dst], x24\n" + "bgt 1b\n" + "12:" // Row loop skip + : [dst] "+&r"(dst), [lhs_packed] "+&r"(lhs_packed) + : [clamp_vals] "r"(clamp_vals), [dst_stride_row] "r"(dst_stride_row), [m] "r"(m), [n] "r"(n), + [num_blocks] "r"(num_blocks), [num_subblocks] "r"(num_subblocks), [rhs_packed] "r"(rhs_packed) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", + "v30", "v31", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"); +} +#endif // Architectural feature check diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h new file mode 100644 index 0000000000000000000000000000000000000000..b87c0220178527d4c99f9bcd85b91292adc71d0c --- /dev/null +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h @@ -0,0 +1,142 @@ + +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/// Micro-kernel dependencies +/// +/// -# kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix +/// -# kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 OR kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0 to pack the RHS matrix + +/// -------------------------------------------------- + +/// Gets the m step value. +/// The micro-kernel can process any M values. However, the starting M index to +/// be processed must be a multiple of m step. +/// +/// @return the m step value +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(void); + +/// Gets the n step value. +/// The micro-kernel can process any N values. However, the starting N index to +/// be processed must be a multiple of n step. +/// +/// @return the n step +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(void); + +/// Gets the mr value, which must be used to pack the LHS matrix +/// +/// @return the mr value +size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(void); + +/// Gets the nr value, which must be used to pack the RHS matrix +/// +/// @return the nr value +size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(void); + +/// Gets the kr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the kr value +size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(void); + +/// Gets the sr value, which must be used to pack the LHS and RHS matrices +/// +/// @return the sr value +size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(void); + +/// Gets the offset in bytes for the packed LHS matrix, +/// which contains the packed Signed 8-bit quantized asymmetric per-row (qai8dx) values. 
+/// +/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. +/// +/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of 4 +/// @param[in] k Total number of columns in the LHS matrix (not packed). +/// +/// @return the offset in bytes to the packed LHS matrix +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm( + size_t m_idx, // + size_t k); // + +/// Gets the offset in bytes for the packed RHS matrix, +/// which contains the packed Signed 4-bit quantized symmetric per-channel (qsi4c32) values. +/// +/// @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of 8. +/// @param[in] k The common dimension between the LHS and RHS matrix (K). +/// @param[in] bl Block length. It must be a multiple of 32. +/// +/// @return the offset in bytes to the packed RHS matrix +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm( + size_t n_idx, // + size_t k, // + size_t bl); // + +/// Gets the offset in bytes for the DST matrix +/// +/// @param[in] m_idx Row index in the DST matrix. It must be a multiple of 4. +/// @param[in] n_idx Column index in the DST matrix. It must be multiple of 8. +/// @param[in] dst_stride The number of bytes in in each row of the DST matrix +/// +/// @return the DST offset in bytes +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm( + size_t m_idx, // + size_t n_idx, // + size_t dst_stride); // + +/// Gets the size in bytes for the destination (DST) matrix. +/// +/// @param[in] m Number of rows in the destination (DST) matrix. +/// @param[in] n Number of columns in the destination (DST) matrix. +/// +/// @return the destination (DST) matrix size in bytes +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm( + size_t m, // + size_t n); // + +/// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation. +/// +/// LHS matrix: Signed 8-bit quantized asymmetric per-row (qai8dx) and packed +/// RHS matrix: Signed 4-bit quantized symmetric per-channel (qsi4c32) and packed. +/// Output tile: (rows x cols) = 4 x 8 +/// Accumulation performed in a single for loop: 32 +/// Extension used: i8mm +/// +/// @param[in] m The number of output rows written. +/// @param[in] n The number of output columns written. +/// @param[in] k The number of channels. The common dimension between the LHS and RHS matrix. +/// @param[in] bl Block length. It must be a multiple of 32. +/// @param[in] lhs_packed The LHS packed matrix. +/// When the activation are dynamically quantized, you can obtain this matrix +/// by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs +/// both the dynamic quantization to 8-bit and activation packing in a single step. +/// @param[in] rhs_packed The RHS packed matrix, which is obtained by calling @ref +/// kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 or @ref kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0 +/// @param[out] dst The DST matrix. +/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. +/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(float). +/// @param[in] scalar_min Min value used to clamp the final result. +/// @param[in] scalar_max Max value used to clamp the final result. 
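+///
+/// A tiled-dispatch sketch (illustrative only: the names and loop structure are
+/// hypothetical; splitting the N dimension like this is one way to divide the
+/// work across threads):
+///
+/// @code{.c}
+/// const size_t n_step = kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm();  // 8
+///
+/// for (size_t n_idx = 0; n_idx < n; n_idx += n_step) {  // m is processed in full per tile
+///     const size_t rhs_offset =
+///         kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(n_idx, k, bl);
+///     const size_t dst_offset =
+///         kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(0, n_idx, dst_stride_row);
+///
+///     kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(
+///         m, KAI_MIN(n - n_idx, n_step), k, bl,
+///         lhs_packed,
+///         (const void*)((const uint8_t*)rhs_packed + rhs_offset),
+///         (float*)((uint8_t*)dst + dst_offset),
+///         dst_stride_row, sizeof(float), scalar_min, scalar_max);
+/// }
+/// @endcode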
+void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(
+    size_t m,                //
+    size_t n,                //
+    size_t k,                //
+    size_t bl,               //
+    const void* lhs_packed,  //
+    const void* rhs_packed,  //
+    float* dst,              //
+    size_t dst_stride_row,   //
+    size_t dst_stride_col,   //
+    float scalar_min,        //
+    float scalar_max);       //
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h
new file mode 100644
index 0000000000000000000000000000000000000000..f32d677cd9ec97dd6f17a263ee3fdd039157c8be
--- /dev/null
+++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h
@@ -0,0 +1,52 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#pragma once
+
+#include <stddef.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// All micro-kernel variants of the same type share the same interface.
+// In this case, the micro-kernel type is: matmul_clamp_f32_qai8dxp_qsi4c32p
+
+/// Micro-kernel helper functions ("get" methods)
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsi4c32p_get_m_step_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsi4c32p_get_n_step_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsi4c32p_get_mr_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsi4c32p_get_nr_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsi4c32p_get_kr_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsi4c32p_get_sr_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsi4c32p_get_lhs_packed_offset_func_t)(size_t m_idx, size_t k);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsi4c32p_get_rhs_packed_offset_func_t)(size_t n_idx, size_t k, size_t bl);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsi4c32p_get_dst_offset_func_t)(
+    size_t m_idx, size_t n_idx, size_t dst_stride);
+typedef size_t (*kai_matmul_clamp_f32_qai8dxp_qsi4c32p_get_dst_size_func_t)(size_t m, size_t n);
+
+/// Micro-kernel core function ("run" method)
+typedef void (*kai_matmul_clamp_f32_qai8dxp_qsi4c32p_run_matmul_func_t)(
+    size_t m, size_t n, size_t k, size_t bl, const void* lhs_p, const void* rhs_p, float* dst, size_t dst_stride_row,
+    size_t dst_stride_col, float scalar_min, float scalar_max);
+
+/// Micro-kernel interface
+struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel {
+    kai_matmul_clamp_f32_qai8dxp_qsi4c32p_get_m_step_func_t get_m_step;
+    kai_matmul_clamp_f32_qai8dxp_qsi4c32p_get_n_step_func_t get_n_step;
+    kai_matmul_clamp_f32_qai8dxp_qsi4c32p_get_mr_func_t get_mr;
+    kai_matmul_clamp_f32_qai8dxp_qsi4c32p_get_nr_func_t get_nr;
+    kai_matmul_clamp_f32_qai8dxp_qsi4c32p_get_kr_func_t get_kr;
+    kai_matmul_clamp_f32_qai8dxp_qsi4c32p_get_sr_func_t get_sr;
+    kai_matmul_clamp_f32_qai8dxp_qsi4c32p_get_lhs_packed_offset_func_t get_lhs_packed_offset;
+    kai_matmul_clamp_f32_qai8dxp_qsi4c32p_get_rhs_packed_offset_func_t get_rhs_packed_offset;
+    kai_matmul_clamp_f32_qai8dxp_qsi4c32p_get_dst_offset_func_t get_dst_offset;
+    kai_matmul_clamp_f32_qai8dxp_qsi4c32p_get_dst_size_func_t get_dst_size;
+    kai_matmul_clamp_f32_qai8dxp_qsi4c32p_run_matmul_func_t run_matmul;
+};
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c
new file mode 100644
index
index 0000000000000000000000000000000000000000..1ed35a163e848e93383fbbdaf211af00a182eedf
--- /dev/null
+++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c
@@ -0,0 +1,274 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#include "kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h"
+
+#include <stddef.h>
+#include <stdint.h>
+#include <string.h>
+
+#include "kai/kai_common.h"
+
+static const size_t kai_num_bytes_sum_rhs = sizeof(float);
+static const size_t kai_num_bytes_bias = sizeof(float);
+static const size_t kai_nr_multiple_of = 4;
+static const size_t kai_bl_multiple_of = 32;
+
+inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) {
+    KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
+    return kai_roundup(k, bl) / bl;
+}
+
+inline static size_t kai_num_bytes_per_block(size_t bl, size_t num_bytes_multiplier_rhs) {
+    KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
+    return (bl / 2) + num_bytes_multiplier_rhs;
+}
+
+inline static size_t kai_rhs_packed_offset_end_of_all_blocks(
+    size_t k, size_t nr, size_t kr, size_t bl, size_t num_bytes_multiplier_rhs) {
+    KAI_ASSERT((bl % kr) == 0);
+    KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
+    KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
+
+    const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl);
+    const size_t num_bytes_per_block = kai_num_bytes_per_block(bl, num_bytes_multiplier_rhs);
+
+    return (nr * num_bytes_per_block * num_blocks_per_row);
+}
+
+size_t kai_get_n_step_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(size_t nr) {
+    return nr;
+}
+
+size_t kai_get_rhs_offset_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(
+    size_t n_idx,  //
+    size_t rhs_stride) {
+    KAI_UNUSED(rhs_stride);
+    KAI_ASSERT((n_idx % 2) == 0);
+    return (n_idx / 2) * sizeof(int8_t);
+}
+
+size_t kai_get_rhs_packed_stride_pack_kxn_qsi4c32p_qsu4c32s1s0(
+    size_t k,   //
+    size_t nr,  //
+    size_t kr,  //
+    size_t sr,  //
+    size_t bl,  //
+    enum kai_datatype scale_dt) {
+    KAI_ASSERT((bl % kr) == 0);
+    KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
+    KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
+    KAI_ASSERT(scale_dt == kai_dt_f32 || scale_dt == kai_dt_f16 || scale_dt == kai_dt_bf16);
+
+    KAI_UNUSED(kr);
+    KAI_UNUSED(sr);
+
+    const size_t num_bytes_multiplier_rhs = kai_get_datatype_size_in_bytes(scale_dt);
+    const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl);
+    const size_t num_bytes_per_block = kai_num_bytes_per_block(bl, num_bytes_multiplier_rhs);
+
+    return nr * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs + kai_num_bytes_bias);
+}
+
+size_t kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(
+    size_t n_idx,  //
+    size_t k,      //
+    size_t nr,     //
+    size_t kr,     //
+    size_t sr,     //
+    size_t bl,     //
+    enum kai_datatype scale_dt) {
+    KAI_ASSERT((n_idx % nr) == 0);
+    KAI_ASSERT((bl % kr) == 0);
+    KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
+    KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
+    KAI_ASSERT(scale_dt == kai_dt_f32 || scale_dt == kai_dt_f16 || scale_dt == kai_dt_bf16);
+
+    return (n_idx / nr) * kai_get_rhs_packed_stride_pack_kxn_qsi4c32p_qsu4c32s1s0(k, nr, kr, sr, bl, scale_dt);
+}
+
+size_t kai_get_rhs_packed_size_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(
+    size_t n,   //
+    size_t k,   //
+    size_t nr,  //
+    size_t kr,  //
+    size_t sr,  //
+    size_t bl,  //
+    enum kai_datatype scale_dt) {
+    KAI_ASSERT((bl % kr) == 0);
+    KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
+    KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
+    KAI_ASSERT(scale_dt == kai_dt_f32 || scale_dt == kai_dt_f16 || scale_dt == kai_dt_bf16);
+
+    const size_t num_rows =
kai_roundup(n, nr) / nr; + + return num_rows * kai_get_rhs_packed_stride_pack_kxn_qsi4c32p_qsu4c32s1s0(k, nr, kr, sr, bl, scale_dt); +} + +void kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0( + size_t num_groups, // + size_t n, // + size_t k, // + size_t nr, // + size_t kr, // + size_t sr, // + size_t bl, // + const uint8_t* rhs, // + size_t rhs_stride, // + const float* bias, // + const void* scale, // + size_t scale_stride, // + void* rhs_packed, // + size_t extra_bytes, // + const struct kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_params* params) { + KAI_ASSERT(num_groups == 1); + KAI_ASSERT(extra_bytes == 0); + KAI_ASSERT(rhs != NULL); + KAI_ASSERT(scale != NULL); + KAI_ASSERT(rhs_packed != NULL); + KAI_ASSERT(params != NULL); + KAI_ASSERT(params->rhs_zero_point == 8); + KAI_ASSERT(params->lhs_zero_point == 1); + + KAI_ASSERT((bl % kr) == 0); + KAI_ASSERT((kr % sr) == 0); + KAI_ASSERT((nr % kai_nr_multiple_of) == 0); + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); + KAI_ASSERT(params->scale_dt == kai_dt_f32 || params->scale_dt == kai_dt_f16 || params->scale_dt == kai_dt_bf16); + + // Note: The input matrix (rhs) is expected with: + // "n" columns and "k" rows (kxn) + + const size_t num_bytes_multiplier_rhs = kai_get_datatype_size_in_bytes(params->scale_dt); + const size_t rhs_packed_stride = + kai_get_rhs_packed_stride_pack_kxn_qsi4c32p_qsu4c32s1s0(k, nr, kr, sr, bl, params->scale_dt); + const size_t rhs_packed_offset_end_of_all_blocks = + kai_rhs_packed_offset_end_of_all_blocks(k, nr, kr, bl, num_bytes_multiplier_rhs); + const size_t num_qblocks_per_row = kai_num_blocks_per_row(k, bl); + const size_t num_bytes_per_block = kai_num_bytes_per_block(bl, num_bytes_multiplier_rhs); + const size_t num_bytes_per_block_k = bl / 2; + const size_t dst_num_rows = kai_roundup(n, nr) / nr; + const size_t k_interleaved_v = 16U; + const size_t block_length_in_bytes = kr / sr; + + const int32_t rhs_zero_point = params->rhs_zero_point; + const enum kai_datatype scale_dt = params->scale_dt; + + for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; ++dst_row_idx) { + // Before packing, it keeps the pointer to the first quantized block + uint8_t* dst_row = (uint8_t*)rhs_packed + dst_row_idx * rhs_packed_stride; + + float* sums = (float*)(dst_row + rhs_packed_offset_end_of_all_blocks); + + // Initialize the RHS reduction sums to zero + memset(sums, 0, nr * kai_num_bytes_sum_rhs); + + // Iterate over the quantized blocks + for (size_t dst_qblock_idx = 0; dst_qblock_idx < num_qblocks_per_row; ++dst_qblock_idx) { + // Store the scales after packing all K values + void* rhs_packed_scale = dst_row + num_bytes_per_block_k * nr; + + for (size_t i = 0; i < nr; ++i) { + const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); + void* dst_scales_ptr = rhs_packed_scale + i * num_bytes_multiplier_rhs; + void* src_scales_ptr = (void*)(scale + dst_qblock_idx * num_bytes_multiplier_rhs + // + (src_row_idx * scale_stride)); // + + memcpy( + dst_scales_ptr, // + src_scales_ptr, // + num_bytes_multiplier_rhs); // + } + + for (size_t dst_byte_idx = 0; dst_byte_idx < nr * num_bytes_per_block_k; ++dst_byte_idx) { + const size_t block_idx = dst_byte_idx / block_length_in_bytes; + const size_t block_byte_idx = dst_byte_idx % block_length_in_bytes; + const size_t super_block_idx = block_idx / nr; + const size_t nr_idx = block_idx % nr; + + const size_t k_adjustment = + ((block_byte_idx + super_block_idx * block_length_in_bytes) / k_interleaved_v) * k_interleaved_v; + const size_t k0_idx = + dst_qblock_idx * bl + block_byte_idx + 
super_block_idx * block_length_in_bytes + k_adjustment;
+                const size_t k1_idx = k0_idx + k_interleaved_v;
+                const size_t n0_idx = dst_row_idx * nr + nr_idx;
+
+                // Clamp the index to avoid out-of-bound reads
+                const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1);
+
+                const size_t src_addr_byte0 = (n0_valid_idx / 2) + k0_idx * rhs_stride;
+                const size_t src_addr_byte1 = (n0_valid_idx / 2) + k1_idx * rhs_stride;
+
+                uint8_t byte0 = rhs_zero_point | rhs_zero_point << 4;
+                uint8_t byte1 = rhs_zero_point | rhs_zero_point << 4;
+
+                if (k0_idx < k) {
+                    byte0 = rhs[src_addr_byte0];
+                }
+
+                if (k1_idx < k) {
+                    byte1 = rhs[src_addr_byte1];
+                }
+
+                float d = 0.0F;
+                switch (scale_dt) {
+                    case kai_dt_f32:
+                        d = ((float*)rhs_packed_scale)[nr_idx];
+                        break;
+                    case kai_dt_f16:
+                        d = kai_cast_f32_f16(((uint16_t*)rhs_packed_scale)[nr_idx]);
+                        break;
+                    case kai_dt_bf16:
+                        d = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx]);
+                        break;
+                    default:
+                        KAI_ERROR("Unsupported scale data type");
+                        break;
+                }
+
+                if ((n0_idx % 2) == 0) {
+                    const uint8_t src_x0_lo = (byte0 & 0x0F);
+                    const uint8_t src_x0_hi = (byte1 & 0x0F);
+
+                    sums[nr_idx] += ((int32_t)src_x0_lo - rhs_zero_point) * d;
+                    sums[nr_idx] += ((int32_t)src_x0_hi - rhs_zero_point) * d;
+
+                    const uint8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4);
+
+                    dst_row[dst_byte_idx] = dst_qs0 ^ 0x88;
+                } else {
+                    const uint8_t src_x1_lo = (byte0 >> 4);
+                    const uint8_t src_x1_hi = (byte1 >> 4);
+
+                    sums[nr_idx] += ((int32_t)src_x1_lo - rhs_zero_point) * d;
+                    sums[nr_idx] += ((int32_t)src_x1_hi - rhs_zero_point) * d;
+
+                    const uint8_t dst_qs1 = src_x1_lo | (src_x1_hi << 4);
+
+                    dst_row[dst_byte_idx] = dst_qs1 ^ 0x88;
+                }
+            }
+            // Move the pointer after K values
+            dst_row += num_bytes_per_block * nr;
+        }
+
+        // Move the pointer after the row sum
+        dst_row += kai_num_bytes_sum_rhs * nr;
+
+        // Set the bias
+        if (bias == NULL) {
+            memset(dst_row, 0, nr * kai_num_bytes_bias);
+        } else {
+            for (size_t i = 0; i < nr; ++i) {
+                // Clamp the row index to avoid out-of-bound reads
+                const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1);
+                ((float*)dst_row)[i] = bias[src_row_idx];
+            }
+        }
+
+        // Move the pointer after the biases
+        dst_row += kai_num_bytes_bias * nr;
+    }
+}
diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h
new file mode 100644
index 0000000000000000000000000000000000000000..856639bc6947f4c04af264d22d5bc8cc13c71e6a
--- /dev/null
+++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h
@@ -0,0 +1,152 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#pragma once
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include "kai/kai_common.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+struct kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_params {
+    int8_t lhs_zero_point;
+    uint8_t rhs_zero_point;
+    enum kai_datatype scale_dt;
+};
+
+/// Get the n step value.
+/// The micro-kernel can process any N values. However, the starting N index to
+/// be processed must be a multiple of n step.
+///
+/// @param[in] nr The number of columns written by the matmul micro-kernel
+///
+/// @return the n step value
+size_t kai_get_n_step_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(size_t nr);
+
+/// Gets the offset in bytes for the RHS matrix (not packed), which holds
+/// the int4 values in a K x N matrix, where N is the number of columns and K is the number of rows.
+///
+/// Two int4 values are stored in one byte.
+/// The lower order part of the byte (low) holds the first nibble (N-index + 0).
+/// The higher order part of the byte holds the second nibble (N-index + 1).
+///
+/// @param[in] n_idx      Column index in the RHS matrix (not packed).
+/// @param[in] rhs_stride The number of bytes in each row of the RHS matrix (not packed)
+///
+/// @return the offset in bytes to the RHS matrix (not packed)
+size_t kai_get_rhs_offset_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(
+    size_t n_idx,        //
+    size_t rhs_stride);  //
+
+/// Get the row stride in bytes to the packed RHS matrix
+///
+/// @param[in] k        The number of rows in the RHS matrix (not packed).
+/// @param[in] nr       The number of columns written by the matmul micro-kernel.
+/// @param[in] kr       The number of columns loaded in the single inner most loop of the matmul micro-kernel.
+/// @param[in] sr       The number of kr splits. It can be 1 (no splits) up to kr.
+/// @param[in] bl       The block length, which defines the number of K values stored in a single block. It must be a
+///                     multiple of 32.
+/// @param[in] scale_dt Block scale data type
+///
+/// @return the stride in bytes to the packed RHS matrix
+size_t kai_get_rhs_packed_stride_pack_kxn_qsi4c32p_qsu4c32s1s0(
+    size_t k,   //
+    size_t nr,  //
+    size_t kr,  //
+    size_t sr,  //
+    size_t bl,  //
+    enum kai_datatype scale_dt);  //
+
+/// Gets the offset in bytes for the packed RHS matrix.
+///
+/// @param[in] n_idx    Column index in the RHS matrix (not packed).
+/// @param[in] k        The number of rows in the RHS matrix (not packed).
+/// @param[in] nr       The number of columns written by the matmul micro-kernel
+/// @param[in] kr       The number of columns loaded in the single inner most loop of the matmul micro-kernel.
+/// @param[in] sr       The number of kr splits. It can be 1 (no splits) up to kr.
+/// @param[in] bl       The block length, which defines the number of K values stored in a single block. It must be a
+///                     multiple of 32.
+/// @param[in] scale_dt Block scale data type
+///
+/// @return the offset in bytes to the packed RHS matrix
+size_t kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(
+    size_t n_idx,  //
+    size_t k,      //
+    size_t nr,     //
+    size_t kr,     //
+    size_t sr,     //
+    size_t bl,     //
+    enum kai_datatype scale_dt);  //
+
+/// Gets the size in bytes for the quantized and packed RHS matrix.
+///
+/// @param[in] n        The number of columns in the RHS matrix (not packed).
+/// @param[in] k        The number of rows in the RHS matrix (not packed).
+/// @param[in] nr       The number of columns written by the matmul micro-kernel.
+/// @param[in] kr       The number of columns loaded in the single inner most loop of the matmul micro-kernel.
+/// @param[in] sr       The number of kr splits. It can be 1 (no splits) up to kr.
+/// @param[in] bl       The block length, which defines the number of K values stored in a single block. It must be a
+///                     multiple of 32.
+/// @param[in] scale_dt Block scale data type
+///
+/// @return the packed RHS matrix size in bytes
+size_t kai_get_rhs_packed_size_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(
+    size_t n,   //
+    size_t k,   //
+    size_t nr,  //
+    size_t kr,  //
+    size_t sr,  //
+    size_t bl,  //
+    enum kai_datatype scale_dt);  //
+
+/// Runs the RHS packing micro-kernel.
+///
+/// The int4 values are stored in a K x N matrix, where N is the number of columns and K is the number of rows.
+/// Two int4 values are stored in one byte. The lower order part of the byte (low) holds
+/// the first nibble (N-index + 0). The higher order part of the byte holds the second nibble (N-index + 1).
+///
+/// @param[in] num_groups The number of groups. It must be 1.
+/// @param[in] n            The number of columns in the RHS matrix (not packed).
+/// @param[in] k            The number of rows in the RHS matrix (not packed).
+/// @param[in] nr           The number of columns written by the matmul micro-kernel. It must be a multiple of 4.
+/// @param[in] kr           The number of columns loaded in the single inner most loop of the matmul micro-kernel.
+/// @param[in] sr           The number of kr splits. It can be 1 (no splits) up to kr.
+///                         However, kr must be a multiple of sr.
+/// @param[in] bl           The block length, which defines the number of
+///                         K values stored in a single block. It must be a multiple of 32.
+/// @param[in] rhs          The RHS matrix containing the 4-bit values.
+///                         Size in bytes is expected to be greater than or equal to n * k * sizeof(uint8_t) / 2.
+/// @param[in] rhs_stride   The number of bytes per row of the RHS matrix.
+/// @param[in] bias         The biases.
+/// @param[in] scale        The per-block quantization scales.
+///                         The scale data type must be provided with the params object.
+///                         Supported scale data types are FP32, FP16 and BF16.
+/// @param[in] scale_stride The number of bytes per row of the scale matrix.
+/// @param[out] rhs_packed  The packed RHS matrix.
+/// @param[in] extra_bytes  Extra bytes to append to the end of each row of the packed RHS matrix.
+/// @param[in] params       Parameters for the micro-kernel.
+void kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(
+    size_t num_groups,    //
+    size_t n,             //
+    size_t k,             //
+    size_t nr,            //
+    size_t kr,            //
+    size_t sr,            //
+    size_t bl,            //
+    const uint8_t* rhs,   //
+    size_t rhs_stride,    //
+    const float* bias,    //
+    const void* scale,    //
+    size_t scale_stride,  //
+    void* rhs_packed,     //
+    size_t extra_bytes,   //
+    const struct kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_params* params);  //
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c
new file mode 100644
index 0000000000000000000000000000000000000000..c6d897e3a8e320fa5f83e1abc3848f35cc89a668
--- /dev/null
+++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c
@@ -0,0 +1,281 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#include "kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h"
+
+#include <stddef.h>
+#include <stdint.h>
+#include <string.h>
+
+#include "kai/kai_common.h"
+
+static const size_t kai_num_bytes_sum_rhs = sizeof(float);
+static const size_t kai_num_bytes_bias = sizeof(float);
+static const size_t kai_nr_multiple_of = 4;
+static const size_t kai_bl_multiple_of = 32;
+
+inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) {
+    KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
+    return kai_roundup(k, bl) / bl;
+}
+
+inline static size_t kai_num_bytes_per_block(size_t bl, size_t num_bytes_multiplier_rhs) {
+    KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
+    return (bl / 2) + num_bytes_multiplier_rhs;
+}
+
+inline static size_t kai_rhs_packed_offset_end_of_all_blocks(
+    size_t k, size_t nr, size_t kr, size_t bl, size_t num_bytes_multiplier_rhs) {
+    KAI_ASSERT((bl % kr) == 0);
+    KAI_ASSERT((nr % kai_nr_multiple_of) == 0);
+    KAI_ASSERT((bl % kai_bl_multiple_of) == 0);
+
+    const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl);
+    const size_t num_bytes_per_block = kai_num_bytes_per_block(bl, num_bytes_multiplier_rhs);
+
+    return (nr * num_bytes_per_block * num_blocks_per_row);
+}
+
+size_t kai_get_n_step_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(size_t nr) {
+    return nr;
+}
+
+size_t
kai_get_rhs_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( + size_t n_idx, // + size_t rhs_stride) { + return n_idx * rhs_stride; +} + +size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( + size_t k, // + size_t nr, // + size_t kr, // + size_t sr, // + size_t bl, // + enum kai_datatype scale_dt) { + KAI_ASSERT((bl % kr) == 0); + KAI_ASSERT((nr % kai_nr_multiple_of) == 0); + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); + KAI_ASSERT(scale_dt == kai_dt_f32 || scale_dt == kai_dt_f16 || scale_dt == kai_dt_bf16); + + KAI_UNUSED(kr); + KAI_UNUSED(sr); + + const size_t num_bytes_multiplier_rhs = kai_get_datatype_size_in_bytes(scale_dt); + const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); + const size_t num_bytes_per_block = kai_num_bytes_per_block(bl, num_bytes_multiplier_rhs); + + return nr * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs + kai_num_bytes_bias); +} + +size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( + size_t n_idx, // + size_t k, // + size_t nr, // + size_t kr, // + size_t sr, // + size_t bl, // + enum kai_datatype scale_dt) { + KAI_ASSERT((n_idx % nr) == 0); + KAI_ASSERT((bl % kr) == 0); + KAI_ASSERT((nr % kai_nr_multiple_of) == 0); + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); + KAI_ASSERT(scale_dt == kai_dt_f32 || scale_dt == kai_dt_f16 || scale_dt == kai_dt_bf16); + + return (n_idx / nr) * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(k, nr, kr, sr, bl, scale_dt); +} + +size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( + size_t n, // + size_t k, // + size_t nr, // + size_t kr, // + size_t sr, // + size_t bl, // + enum kai_datatype scale_dt) { + KAI_ASSERT((bl % kr) == 0); + KAI_ASSERT((nr % kai_nr_multiple_of) == 0); + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); + KAI_ASSERT(scale_dt == kai_dt_f32 || scale_dt == kai_dt_f16 || scale_dt == kai_dt_bf16); + + const size_t num_rows = kai_roundup(n, nr) / nr; + + return num_rows * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(k, nr, kr, sr, bl, scale_dt); +} + +void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( + size_t num_groups, // + size_t n, // + size_t k, // + size_t nr, // + size_t kr, // + size_t sr, // + size_t bl, // + const uint8_t* rhs, // + size_t rhs_stride, // + const float* bias, // + const void* scale, // + size_t scale_stride, // + void* rhs_packed, // + size_t extra_bytes, // + const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params* params) { + KAI_ASSERT(num_groups == 1); + KAI_ASSERT(extra_bytes == 0); + KAI_ASSERT(rhs != NULL); + KAI_ASSERT(scale != NULL); + KAI_ASSERT(rhs_packed != NULL); + KAI_ASSERT(params != NULL); + KAI_ASSERT(params->rhs_zero_point == 8); + KAI_ASSERT(params->lhs_zero_point == 1); + + KAI_ASSERT((bl % kr) == 0); + KAI_ASSERT((kr % sr) == 0); + KAI_ASSERT((nr % kai_nr_multiple_of) == 0); + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); + KAI_ASSERT(params->scale_dt == kai_dt_f32 || params->scale_dt == kai_dt_f16 || params->scale_dt == kai_dt_bf16); + + // Note: The input matrix (rhs) is expected with: + // "k" columns and "n" rows (NxK) + + const size_t num_bytes_multiplier_rhs = kai_get_datatype_size_in_bytes(params->scale_dt); + const size_t rhs_packed_stride = + kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(k, nr, kr, sr, bl, params->scale_dt); + const size_t rhs_packed_offset_end_of_all_blocks = + kai_rhs_packed_offset_end_of_all_blocks(k, nr, kr, bl, num_bytes_multiplier_rhs); + const size_t num_qblocks_per_row = kai_num_blocks_per_row(k, bl); + const size_t 
num_bytes_per_block = kai_num_bytes_per_block(bl, num_bytes_multiplier_rhs);
+    const size_t num_bytes_per_block_k = bl / 2;
+    const size_t dst_num_rows = kai_roundup(n, nr) / nr;
+    const size_t k_interleaved_v = 16U;
+    const size_t block_length_in_bytes = kr / sr;
+
+    const int32_t rhs_zero_point = params->rhs_zero_point;
+    const enum kai_datatype scale_dt = params->scale_dt;
+
+    for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; ++dst_row_idx) {
+        // Before packing, keep the pointer to the first quantized block
+        uint8_t* dst_row = (uint8_t*)rhs_packed + dst_row_idx * rhs_packed_stride;
+
+        float* sums = (float*)(dst_row + rhs_packed_offset_end_of_all_blocks);
+
+        // Initialize the RHS reduction sums to zero
+        memset(sums, 0, nr * kai_num_bytes_sum_rhs);
+
+        // Iterate over the quantized blocks
+        for (size_t dst_qblock_idx = 0; dst_qblock_idx < num_qblocks_per_row; ++dst_qblock_idx) {
+            // Store the scales after packing all K values
+            void* rhs_packed_scale = dst_row + num_bytes_per_block_k * nr;
+
+            for (size_t i = 0; i < nr; ++i) {
+                const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1);
+                void* dst_scales_ptr = rhs_packed_scale + i * num_bytes_multiplier_rhs;
+                void* src_scales_ptr = (void*)(scale + dst_qblock_idx * num_bytes_multiplier_rhs +  //
+                                               (src_row_idx * scale_stride));                       //
+
+                memcpy(
+                    dst_scales_ptr,             //
+                    src_scales_ptr,             //
+                    num_bytes_multiplier_rhs);  //
+            }
+
+            for (size_t dst_byte_idx = 0; dst_byte_idx < nr * num_bytes_per_block_k; ++dst_byte_idx) {
+                const size_t block_idx = dst_byte_idx / block_length_in_bytes;
+                const size_t block_byte_idx = dst_byte_idx % block_length_in_bytes;
+                const size_t super_block_idx = block_idx / nr;
+                const size_t nr_idx = block_idx % nr;
+
+                const size_t k_adjustment =
+                    ((block_byte_idx + super_block_idx * block_length_in_bytes) / k_interleaved_v) * k_interleaved_v;
+                const size_t k0_idx =
+                    dst_qblock_idx * bl + block_byte_idx + super_block_idx * block_length_in_bytes + k_adjustment;
+                const size_t k1_idx = k0_idx + k_interleaved_v;
+                const size_t n0_idx = dst_row_idx * nr + nr_idx;
+
+                // Clamp the index to avoid out-of-bound reads
+                const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1);
+
+                const size_t src_addr_byte0 = (k0_idx / 2) + n0_valid_idx * rhs_stride;
+                const size_t src_addr_byte1 = (k1_idx / 2) + n0_valid_idx * rhs_stride;
+
+                uint8_t byte0 = rhs_zero_point | rhs_zero_point << 4;
+                uint8_t byte1 = rhs_zero_point | rhs_zero_point << 4;
+
+                if (k0_idx < k) {
+                    byte0 = rhs[src_addr_byte0];
+                }
+
+                if (k1_idx < k) {
+                    byte1 = rhs[src_addr_byte1];
+                }
+
+                // The following operations, in which we extract the values from the bytes,
+                // can also be written in the following, less efficient, manner:
+                /*
+                uint8_t src_x0_lo = 0;
+                uint8_t src_x0_hi = 0;
+
+                if ((k0_idx % 2) == 0) {
+                    src_x0_lo = (byte0 & 0x0F);
+                } else {
+                    src_x0_lo = (byte0 >> 4);
+                }
+
+                if ((k1_idx % 2) == 0) {
+                    src_x0_hi = (byte1 & 0x0F);
+                } else {
+                    src_x0_hi = (byte1 >> 4);
+                }
+                */
+                const size_t shift_right_x0 = (k0_idx % 2) * 4;
+                const size_t shift_right_x1 = (k1_idx % 2) * 4;
+
+                const uint8_t src_x0_lo = (byte0 >> shift_right_x0) & 0x0F;
+                const uint8_t src_x0_hi = (byte1 >> shift_right_x1) & 0x0F;
+
+                float d = 0.0F;
+                switch (scale_dt) {
+                    case kai_dt_f32:
+                        d = ((float*)rhs_packed_scale)[nr_idx];
+                        break;
+                    case kai_dt_f16:
+                        d = kai_cast_f32_f16(((uint16_t*)rhs_packed_scale)[nr_idx]);
+                        break;
+                    case kai_dt_bf16:
+                        d = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx]);
+                        break;
+                    default:
+                        KAI_ERROR("Unsupported scale data type");
type"); + break; + } + + sums[nr_idx] += ((int32_t)src_x0_lo - rhs_zero_point) * d; + sums[nr_idx] += ((int32_t)src_x0_hi - rhs_zero_point) * d; + + const uint8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4); + + dst_row[dst_byte_idx] = dst_qs0 ^ 0x88; + } + // Move the pointer after K values + dst_row += num_bytes_per_block * nr; + } + + // Move the pointer after the row sum + dst_row += kai_num_bytes_sum_rhs * nr; + + // Set the bias + if (bias == NULL) { + memset(dst_row, 0, nr * kai_num_bytes_bias); + } else { + for (size_t i = 0; i < nr; ++i) { + // Clamp the row index to avoid out-of-bound reads + const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); + ((float*)dst_row)[i] = bias[src_row_idx]; + } + } + + // Move the pointer after the biases + dst_row += kai_num_bytes_bias * nr; + } +} diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h new file mode 100644 index 0000000000000000000000000000000000000000..2b411f43eee048968c364dec929f1b4b1560db86 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h @@ -0,0 +1,152 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include + +#include "kai/kai_common.h" + +#ifdef __cplusplus +extern "C" { +#endif + +struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params { + int8_t lhs_zero_point; + uint8_t rhs_zero_point; + enum kai_datatype scale_dt; +}; + +/// Get the n step value. +/// The micro-kernel can process any N values. However, the starting N index to +/// be processed must be a multiple of n step. +/// +/// @param[in] nr The number of columns written by the matmul micro-kernel +/// +/// @return the n step value +size_t kai_get_n_step_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(size_t nr); + +/// Gets the offset in bytes for the RHS matrix (not packed), which holds +/// the int4 values in a N x K matrix, where N is number of rows and K is the number of columns. +/// +/// Two int4 values are stored in one byte. +/// The lower order part of the byte (low) holds the first nibble (K-index + 0). +/// The higher order of the byte holds the second nibble (K-index + 1). +/// +/// @param[in] n_idx Row index in the RHS matrix (not packed). +/// @param[in] rhs_stride The number of bytes in in each row of the RHS matrix (not packed) +/// +/// @return the offset in bytes to the RHS matrix (not packed) +size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( + size_t n_idx, // + size_t rhs_stride); // + +/// Get the row stride in bytes to the packed RHS matrix +/// +/// @param[in] k The number of columns in the RHS matrix (not packed). +/// @param[in] nr The number of columns written by the matmul micro-kernel. +/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. +/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. +/// @param[in] bl The block length, which defines the number of K values stored in a single block. It must be a +/// multiple of 32. +/// @param[in] scale_dt Block scale data type +/// +/// @return the stride in bytes to the packed RHS matrix +size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( + size_t k, // + size_t nr, // + size_t kr, // + size_t sr, // + size_t bl, // + enum kai_datatype scale_dt); // + +/// Gets the offset in bytes for the packed RHS matrix. +/// +/// @param[in] n_idx Row index in the RHS matrix (not packed). 
+/// @param[in] k        The number of columns in the RHS matrix (not packed).
+/// @param[in] nr       The number of columns written by the matmul micro-kernel
+/// @param[in] kr       The number of columns loaded in the single inner most loop of the matmul micro-kernel.
+/// @param[in] sr       The number of kr splits. It can be 1 (no splits) up to kr.
+/// @param[in] bl       The block length, which defines the number of K values stored in a single block. It must be a
+///                     multiple of 32.
+/// @param[in] scale_dt Block scale data type
+///
+/// @return the offset in bytes to the packed RHS matrix
+size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
+    size_t n_idx,  //
+    size_t k,      //
+    size_t nr,     //
+    size_t kr,     //
+    size_t sr,     //
+    size_t bl,     //
+    enum kai_datatype scale_dt);  //
+
+/// Gets the size in bytes for the quantized and packed RHS matrix.
+///
+/// @param[in] n        The number of rows in the RHS matrix (not packed)
+/// @param[in] k        The number of columns in the RHS matrix (not packed).
+/// @param[in] nr       The number of columns written by the matmul micro-kernel
+/// @param[in] kr       The number of columns loaded in the single inner most loop of the matmul micro-kernel.
+/// @param[in] sr       The number of kr splits. It can be 1 (no splits) up to kr.
+/// @param[in] bl       The block length, which defines the number of K values stored in a single block. It must be a
+///                     multiple of 32.
+/// @param[in] scale_dt Block scale data type
+///
+/// @return the packed RHS matrix size in bytes
+size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
+    size_t n,   //
+    size_t k,   //
+    size_t nr,  //
+    size_t kr,  //
+    size_t sr,  //
+    size_t bl,  //
+    enum kai_datatype scale_dt);  //
+
+/// Runs the RHS packing micro-kernel.
+///
+/// The int4 values are stored in a N x K matrix, where N is the number of rows and K is the number of columns.
+/// Two int4 values are stored in one byte. The lower order part of the byte (low) holds
+/// the first nibble (K-index + 0). The higher order part of the byte holds the second nibble (K-index + 1).
+///
+/// @param[in] num_groups   The number of groups. It must be 1.
+/// @param[in] n            The number of rows in the RHS matrix (not packed).
+/// @param[in] k            The number of columns in the RHS matrix (not packed).
+/// @param[in] nr           The number of columns written by the matmul micro-kernel. It must be a multiple of 4.
+/// @param[in] kr           The number of columns loaded in the single inner most loop of the matmul micro-kernel.
+/// @param[in] sr           The number of kr splits. It can be 1 (no splits) up to kr.
+///                         However, kr must be a multiple of sr.
+/// @param[in] bl           The block length, which defines the number of
+///                         K values stored in a single block. It must be a multiple of 32.
+/// @param[in] rhs          The RHS matrix containing the 4-bit values.
+///                         Size in bytes is expected to be greater than or equal to n * k * sizeof(uint8_t) / 2.
+/// @param[in] rhs_stride   The number of bytes per row of the RHS matrix.
+/// @param[in] bias         The biases.
+/// @param[in] scale        The per-block quantization scales.
+///                         The scale data type must be provided with the params object.
+///                         Supported scale data types are FP32, FP16 and BF16.
+/// @param[in] scale_stride The number of bytes per row of the scale matrix.
+/// @param[out] rhs_packed  The packed RHS matrix.
+/// @param[in] extra_bytes  Extra bytes to append to the end of each row of the packed RHS matrix.
+/// @param[in] params       Parameters for the micro-kernel.
+void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( + size_t num_groups, // + size_t n, // + size_t k, // + size_t nr, // + size_t kr, // + size_t sr, // + size_t bl, // + const uint8_t* rhs, // + size_t rhs_stride, // + const float* bias, // + const void* scale, // + size_t scale_stride, // + void* rhs_packed, // + size_t extra_bytes, // + const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params* params); // + +#ifdef __cplusplus +} +#endif diff --git a/test/BUILD.bazel b/test/BUILD.bazel index 2ad1121582fe0c7b3ff188e369d164ea5952364d..c806b5d54fdd54ddd92082dfcf502ffbb2c3d29a 100644 --- a/test/BUILD.bazel +++ b/test/BUILD.bazel @@ -53,6 +53,7 @@ kai_cxx_library( cc_test( name = "kleidiai_test", srcs = [ + "tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp", "tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp", "tests/matmul_test.cpp", ], diff --git a/test/common/bfloat16.hpp b/test/common/bfloat16.hpp index 8b61d92eae8ad0c7b894411c495ee2b3f716e4a5..4fa8026ceccdc190944a9556e64a85694dcedff9 100644 --- a/test/common/bfloat16.hpp +++ b/test/common/bfloat16.hpp @@ -38,10 +38,8 @@ public: BFloat16& operator=(BFloat16&&) = default; /// Creates a new object from the specified numeric value. - template , bool> = true> - explicit BFloat16(T value) : _data(0) { - const auto value_f32 = static_cast(value); - asm("bfcvt %h[output], %s[input]" : [output] "=w"(_data) : [input] "w"(value_f32)); + BFloat16(float value) : _data(0) { + asm("bfcvt %h[output], %s[input]" : [output] "=w"(_data) : [input] "w"(value)); } /// Assigns to the specified numeric value which will be converted to `bfloat16_t`. @@ -52,9 +50,8 @@ public: return *this; } - /// Converts to numeric type `T`. - template , bool> = true> - explicit operator T() const { + /// Converts to floating-point. + operator float() const { union { float f32; uint32_t u32; @@ -62,7 +59,7 @@ public: data.u32 = static_cast(_data) << 16; - return static_cast(data.f32); + return data.f32; } /// Equality operator. 
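[Editorial note, not part of the patch] The BFloat16 change above replaces the templated constructor and conversion operator with plain `float` ones, relying on the fact that a BF16 value upcasts to F32 by a simple 16-bit shift. The following is a minimal, portable sketch of that bit-level mapping; it is illustrative only. Note that the class itself converts F32 to BF16 with the BFCVT instruction (round-to-nearest-even), whereas this stand-in truncates:

```cpp
#include <cstdint>
#include <cstring>
#include <iostream>

// F32 -> BF16 by truncation: keep the sign, exponent and top 7 mantissa bits.
// (BFCVT additionally rounds to nearest-even before dropping the low 16 bits.)
static uint16_t f32_to_bf16_truncate(float f) {
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    return static_cast<uint16_t>(u >> 16);
}

// BF16 -> F32 is exact: the same shift used by BFloat16::operator float above.
static float bf16_to_f32(uint16_t h) {
    const uint32_t u = static_cast<uint32_t>(h) << 16;
    float f;
    std::memcpy(&f, &u, sizeof(f));
    return f;
}

int main() {
    const float x = 1.5F;  // exactly representable in BF16
    std::cout << bf16_to_f32(f32_to_bf16_truncate(x)) << "\n";  // prints 1.5
}
```

Because the BF16 -> F32 direction is exact, the test code can widen packed BF16 scales to `float` without any loss, which is why the reference implementation below reads them through this conversion.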
diff --git a/test/reference/matmul.cpp b/test/reference/matmul.cpp
index 72d03bb2cac0cc4b25a931c44bea4bff1503e397..381cf1d5ebfd0b6575e9842ea11720cff86475eb 100644
--- a/test/reference/matmul.cpp
+++ b/test/reference/matmul.cpp
@@ -12,6 +12,7 @@
 #include
 
 #include "kai/kai_common.h"
+#include "test/common/bfloat16.hpp"
 #include "test/common/data_format.hpp"
 #include "test/common/data_type.hpp"
 #include "test/common/float16.hpp"
@@ -255,4 +256,90 @@ matmul_clamp_nt_t
matmul_clamp_nt_t(
+    size_t m, size_t n, size_t k,  //
+    const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width,  //
+    const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width,  //
+    const void* biases,  //
+    float min_value, float max_value);
+
+template <
+    typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale,
+    typename RhsZeroPoint, typename Bias, typename IntAcc, typename DstData>
+std::vector<uint8_t> matmul_clamp_nt_nt(
+    size_t m, size_t n, size_t k,  //
+    const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width,  //
+    const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width,  //
+    const void* biases,  //
+    DstData min_value, DstData max_value) {
+    const auto lhs_num_quant_per_row = round_up_division(k, lhs_quant_width);
+    const auto rhs_num_quant_per_row = round_up_division(k, rhs_quant_width);
+
+    std::vector<uint8_t> dst(m * n * sizeof(DstData));
+
+    const auto* lhs_scales_ptr = reinterpret_cast<const LhsScale*>(lhs_scales);
+    const auto* rhs_scales_ptr = reinterpret_cast<const RhsScale*>(rhs_scales);
+    const auto* lhs_zero_points_ptr = reinterpret_cast<const LhsZeroPoint*>(lhs_zero_points);
+    const auto* rhs_zero_points_ptr = reinterpret_cast<const RhsZeroPoint*>(rhs_zero_points);
+    const auto* biases_ptr = reinterpret_cast<const Bias*>(biases);
+    auto* dst_ptr = reinterpret_cast<DstData*>(dst.data());
+
+    for (size_t y = 0; y < m; ++y) {
+        for (size_t x = 0; x < n; ++x) {
+            DstData acc = 0;
+
+            for (size_t i = 0; i < k; ++i) {
+                const auto lhs_value = read_array<LhsData>(lhs_data, y * k + i);
+                const auto lhs_scale = lhs_scales_ptr[y * lhs_num_quant_per_row + i / lhs_quant_width];
+                const auto lhs_zero_point = lhs_zero_points_ptr != nullptr
+                    ? lhs_zero_points_ptr[y * lhs_num_quant_per_row + i / lhs_quant_width]
+                    : 0;
+
+                const auto rhs_value = read_array<RhsData>(rhs_data, x + i * n);
+                const auto rhs_scale = rhs_scales_ptr[x * rhs_num_quant_per_row + i / rhs_quant_width];
+                const auto rhs_zero_point = rhs_zero_points_ptr != nullptr
+                    ?
rhs_zero_points_ptr[x * rhs_num_quant_per_row + i / rhs_quant_width]
+                    : 0;
+
+                acc += static_cast<DstData>(
+                           (static_cast<IntAcc>(lhs_value) + static_cast<IntAcc>(lhs_zero_point)) *
+                           (static_cast<IntAcc>(rhs_value) + static_cast<IntAcc>(rhs_zero_point))) *
+                    static_cast<DstData>(lhs_scale) * static_cast<DstData>(rhs_scale);
+            }
+
+            if (biases_ptr != nullptr) {
+                acc += static_cast<DstData>(biases_ptr[x]);
+            }
+
+            acc = std::clamp(acc, min_value, max_value);
+            dst_ptr[y * n + x] = acc;
+        }
+    }
+
+    return dst;
+}
+
+template std::vector<uint8_t> matmul_clamp_nt_nt(
+    size_t m, size_t n, size_t k,  //
+    const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width,  //
+    const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width,  //
+    const void* biases,  //
+    float min_value, float max_value);
+
+template std::vector<uint8_t>
+matmul_clamp_nt_nt(
+    size_t m, size_t n, size_t k,  //
+    const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width,  //
+    const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width,  //
+    const void* biases,  //
+    float min_value, float max_value);
+
+template std::vector<uint8_t>
+matmul_clamp_nt_nt(
+    size_t m, size_t n, size_t k,  //
+    const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width,  //
+    const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width,  //
+    const void* biases,  //
+    float min_value, float max_value);
+
 }  // namespace kai::test
diff --git a/test/reference/matmul.hpp b/test/reference/matmul.hpp
index 40fb684dced6fde78e24c4c17527dd3ac999d640..88a0729ff662f67b25321034a6981818c13edfe7 100644
--- a/test/reference/matmul.hpp
+++ b/test/reference/matmul.hpp
@@ -105,4 +105,44 @@ std::vector matmul_clamp_nt_t(
     const void* biases,  //
     DstData min_value, DstData max_value);
 
+/// Matrix multiplication with quantized input and floating-point output.
+///
+/// The LHS matrix is non-transposed and the RHS matrix is non-transposed.
+///
+/// @tparam LhsData The data type of the LHS matrix.
+/// @tparam LhsScale The data type of the quantization scales of the LHS matrix.
+/// @tparam LhsZeroPoint The data type of the quantization zero points of the LHS matrix.
+/// @tparam RhsData The data type of the RHS matrix.
+/// @tparam RhsScale The data type of the quantization scales of the RHS matrix.
+/// @tparam RhsZeroPoint The data type of the quantization zero points of the RHS matrix.
+/// @tparam Bias The data type of the bias vector.
+/// @tparam IntAcc The data type of the intermediate integer accumulator.
+/// @tparam DstData The data type of the floating-point accumulator and the output matrix.
+///
+/// @param[in] m The LHS and output height.
+/// @param[in] n The RHS width and output width.
+/// @param[in] k The LHS width and RHS height.
+/// @param[in] lhs_data The LHS data matrix.
+/// @param[in] lhs_scales The LHS quantization scales matrix.
+/// @param[in] lhs_zero_points The LHS quantization zero points matrix.
+/// @param[in] lhs_quant_width The LHS quantization block width.
+/// @param[in] rhs_data The RHS data matrix.
+/// @param[in] rhs_scales The RHS quantization scales matrix.
+/// @param[in] rhs_zero_points The RHS quantization zero points matrix.
+/// @param[in] rhs_quant_width The RHS quantization block width.
+/// @param[in] biases The biases vector.
+/// @param[in] min_value The minimum output value.
+/// @param[in] max_value The maximum output value.
+///
+/// @return The output matrix.
+template <
+    typename LhsData, typename LhsScale, typename LhsZeroPoint, typename RhsData, typename RhsScale,
+    typename RhsZeroPoint, typename Bias, typename IntAcc, typename DstData>
+std::vector<uint8_t> matmul_clamp_nt_nt(
+    size_t m, size_t n, size_t k,  //
+    const void* lhs_data, const void* lhs_scales, const void* lhs_zero_points, size_t lhs_quant_width,  //
+    const void* rhs_data, const void* rhs_scales, const void* rhs_zero_points, size_t rhs_quant_width,  //
+    const void* biases,  //
+    DstData min_value, DstData max_value);
+
 }  // namespace kai::test
diff --git a/test/reference/pack.cpp b/test/reference/pack.cpp
index 0ff3904a5620f70f984d65d93fb2efc46b3caab0..221ba36079ec8d6c6bf59289826a37ec2525b300 100644
--- a/test/reference/pack.cpp
+++ b/test/reference/pack.cpp
@@ -206,7 +206,7 @@ std::vector pack_data_scales_interleave_block(
     const auto num_quant_packets_x = round_up_multiple(width, quant_width) / quant_width;
     const auto data_bytes = height * width * size_in_bits / 8;
-    const auto scales_bytes = height * num_quant_packets_x * sizeof(Scale);
+    const auto scales_bytes = scales != nullptr ? height * num_quant_packets_x * sizeof(Scale) : 0;
 
     std::vector dst(data_bytes + scales_bytes);
 
@@ -215,9 +215,11 @@
     for (size_t y = 0; y < height; ++y) {
         for (size_t x_quant = 0; x_quant < width; x_quant += quant_width) {
-            write_array(dst_ptr, 0, *scales_ptr);
-            dst_ptr += sizeof(Scale);
-            ++scales_ptr;
+            if (scales_ptr != nullptr) {
+                write_array(dst_ptr, 0, *scales_ptr);
+                dst_ptr += sizeof(Scale);
+                ++scales_ptr;
+            }
 
             for (size_t x_element = 0; x_element < quant_width; ++x_element) {
                 const auto x = x_quant + x_element / 2 + (x_element % 2 != 0 ? quant_width / 2 : 0);
@@ -235,6 +237,8 @@
 template std::vector pack_data_scales_interleave_block(
     const void* data, const void* scales, size_t height, size_t width, size_t quant_width);
+template std::vector pack_data_scales_interleave_block(
+    const void* data, const void* scales, size_t height, size_t width, size_t quant_width);
 
 template std::vector pack_block_data_zero_points_scale_bias(
diff --git a/test/reference/pack.hpp b/test/reference/pack.hpp
index 8564c810b78facfe012df623346e9b953dab7d17..10d76a7f6a289fea0a5217c391cea77b703cedaa 100644
--- a/test/reference/pack.hpp
+++ b/test/reference/pack.hpp
@@ -136,4 +136,43 @@
 template std::vector pack_data_scales_interleave_block(
     const void* data, const void* scales, size_t height, size_t width, size_t quant_width);
 
+/// Packs the quantized data with two halves of a block interleaved.
+///
+/// ```
+/// Quantized data matrix:
+///
+/// --->|-----------------|<--- Block width
+///     |                 |
+///     +-----------------+-----------------+----- ...
+///     | q00 q01 q02 q03 | q04 q05 q06 q07 | ........
+///     | q10 q11 q12 q13 | q14 q15 q16 q17 | ........
+///     | q20 q21 q22 q23 | q24 q25 q26 q27 | ........
+///     | q30 q31 q32 q33 | q34 q35 q36 q37 | ........
+///     | ............... | ............... | ........
+///     : ............... : ............... : ........
+///
+/// Packed data:
+///
+///     +-----------------+-----------------+----- ...
+///     | q00 q02 q01 q03 | q04 q06 q05 q07 | ........
+///     | q10 q12 q11 q13 | q14 q16 q15 q17 | ........
+///     | q20 q22 q21 q23 | q24 q26 q25 q27 | ........
+///     | q30 q32 q31 q33 | q34 q36 q35 q37 | ........
+///     | ............... | ............... | ........
+///     : ............... : ............... : ........
+/// ```
+///
+/// @tparam Data The data type of the quantized value.
+/// +/// @param[in] data The quantized data. +/// @param[in] height The number of rows. +/// @param[in] width The number of columns. +/// @param[in] block_width The number of columns in a block. +/// +/// @return The packed data buffer. +template +std::vector pack_data_interleave_block(const void* data, size_t height, size_t width, size_t block_width) { + return pack_data_scales_interleave_block(data, nullptr, height, width, block_width); +} + } // namespace kai::test diff --git a/test/reference/quantize.cpp b/test/reference/quantize.cpp index ad4a450f2b06599c58130268e572b0ae88b3daaf..7d7a012a8a6c234d5b330870eda090ef494ecb10 100644 --- a/test/reference/quantize.cpp +++ b/test/reference/quantize.cpp @@ -13,6 +13,7 @@ #include #include +#include "test/common/bfloat16.hpp" #include "test/common/int4.hpp" #include "test/common/memory.hpp" #include "test/common/numeric_limits.hpp" @@ -74,7 +75,7 @@ IntType quantize_asymmetric(FloatType value, FloatType scale, ZeroPointType zero template std::tuple, std::vector> quantize_symmetric_per_block( - const void* src, size_t height, size_t width, size_t quant_width) { + const void* src, size_t height, size_t width, size_t quant_width, bool is_transposed) { static_assert(is_floating_point); static_assert(is_integral); static_assert(is_floating_point); @@ -113,7 +114,11 @@ std::tuple, std::vector> quantize_symmetric_per_bl if (x < width) { const auto quantized = quantize_symmetric(src_ptr[y * width + x], scale); - write_array(data.data(), y * width + x, quantized); + if (is_transposed) { + write_array(data.data(), y * width + x, quantized); + } else { + write_array(data.data(), x * height + y, quantized); + } } } } @@ -123,11 +128,13 @@ std::tuple, std::vector> quantize_symmetric_per_bl } template std::tuple, std::vector> quantize_symmetric_per_block( - const void* src, size_t height, size_t width, size_t quant_width); + const void* src, size_t height, size_t width, size_t quant_width, bool is_transposed); template std::tuple, std::vector> quantize_symmetric_per_block( - const void* src, size_t height, size_t width, size_t quant_width); + const void* src, size_t height, size_t width, size_t quant_width, bool is_transposed); +template std::tuple, std::vector> quantize_symmetric_per_block( + const void* src, size_t height, size_t width, size_t quant_width, bool is_transposed); template std::tuple, std::vector> quantize_symmetric_per_block( - const void* src, size_t height, size_t width, size_t quant_width); + const void* src, size_t height, size_t width, size_t quant_width, bool is_transposed); template std::tuple, std::vector, std::vector> quantize_asymmetric_per_block( @@ -192,5 +199,7 @@ std::tuple, std::vector, std::vector> qua template std::tuple, std::vector, std::vector> quantize_asymmetric_per_block< float, int8_t, float, int32_t>(const void* src, size_t height, size_t width, size_t quant_width); +template std::tuple, std::vector, std::vector> quantize_asymmetric_per_block< + float, int8_t, BFloat16, int32_t>(const void* src, size_t height, size_t width, size_t quant_width); } // namespace kai::test diff --git a/test/reference/quantize.hpp b/test/reference/quantize.hpp index 58eb88bbd203e8fd1dc5a6e3e2622738b6bd9d01..77bbfc84b3064a9709a6b9157c0dd395b6ba9e22 100644 --- a/test/reference/quantize.hpp +++ b/test/reference/quantize.hpp @@ -89,7 +89,7 @@ enum class QuantizationMethod : uint32_t { /// @return The quantized data matrix and the quantization scale matrix. 
template std::tuple, std::vector> quantize_symmetric_per_block( - const void* src, size_t height, size_t width, size_t quant_width); + const void* src, size_t height, size_t width, size_t quant_width, bool is_transposed = true); /// Quantizes each subblock of the matrix using asymmetric quantization method. /// diff --git a/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp b/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..15acde4ed293b514d82c1ffd3230335df2f6d7e0 --- /dev/null +++ b/test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp @@ -0,0 +1,213 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include +#include +#include + +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h" +#include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h" +#include "test/common/bfloat16.hpp" +#include "test/common/int4.hpp" +#include "test/common/memory.hpp" +#include "test/common/round.hpp" +#include "test/common/test_suite.hpp" +#include "test/reference/cast.hpp" +#include "test/reference/fill.hpp" +#include "test/reference/matmul.hpp" +#include "test/reference/pack.hpp" +#include "test/reference/quantize.hpp" + +namespace kai::test { + +static const std::array, 4> + variants_kai_matmul_clamp_f32_qai8dxp_qsi4c32p = {{ + UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod), + UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod), + UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm), + UKERNEL_MATMUL_VARIANT(clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm), + }}; + +class MatMulTest_f32_qmatmul_clamp_f32_qai8dxp_qsi4c32p : public UkernelVariantTest {}; + +TEST_P(MatMulTest_f32_qmatmul_clamp_f32_qai8dxp_qsi4c32p, EndToEnd_RHS_Transposed) { + auto& [variant_index, matmul_shape] = GetParam(); + const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi4c32p.at(variant_index); + + const uint64_t seed = 0; + + const size_t M = matmul_shape.m; + const size_t N = matmul_shape.n; + const size_t K = matmul_shape.k; + + const auto mr = ukernel_variant.interface.get_mr(); + const auto nr = ukernel_variant.interface.get_nr(); + const auto kr = ukernel_variant.interface.get_kr(); + const auto sr = ukernel_variant.interface.get_sr(); + + const size_t bl = 32; + + // Generates input data. + const auto ref_lhs = fill_random(M * K, seed + 0); + const auto ref_rhs = fill_random(N * K, seed + 1); + + // Runs the reference implementation. + // * Quantizes the LHS matrix using 8-bit asymmetric quantization. + // * Quantizes the RHS matrix using 4-bit symmetric quantization. + // * Performs GEMM. 
+ const auto [ref_lhs_qvalues, ref_lhs_scales, ref_lhs_zero_points] = + quantize_asymmetric_per_block(ref_lhs.data(), M, K, K); + const auto [ref_rhs_qsi4, ref_rhs_scales] = + quantize_symmetric_per_block(ref_rhs.data(), N, K, bl); + + const auto ref_dst = matmul_clamp_nt_t( + M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), K, ref_rhs_qsi4.data(), + ref_rhs_scales.data(), nullptr, bl, nullptr, std::numeric_limits::lowest(), + std::numeric_limits::max()); + + // Runs the LHS packing micro-kernel. + const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr); + std::vector imp_packed_lhs(imp_packed_lhs_size); + kai_run_lhs_quant_pack_qai8dxp_f32( + M, K, mr, kr, sr, 0, reinterpret_cast(ref_lhs.data()), K * sizeof(float), imp_packed_lhs.data()); + + // Runs the RHS packing micro-kernel. + // * Generates the 4-bit unsigned symmetric quantized input for the micro-kernel. + // * Packs the RHS matrix. + const auto ref_rhs_qsu4 = cast_qsu4_qsi4(ref_rhs_qsi4.data(), N * K); + + const size_t ref_rhs_qsu4_stride = round_up_division(K, 2); + const size_t ref_rhs_scales_stride = round_up_division(K, bl) * sizeof(uint16_t); + + const auto imp_packed_rhs_size = + kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, bl, kai_datatype::kai_dt_bf16); + std::vector imp_packed_rhs(imp_packed_rhs_size); + const kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params{ + .lhs_zero_point = 1, .rhs_zero_point = 8, .scale_dt = kai_datatype::kai_dt_bf16}; + kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( + 1, N, K, nr, kr, sr, bl, ref_rhs_qsu4.data(), ref_rhs_qsu4_stride, nullptr, + reinterpret_cast(ref_rhs_scales.data()), ref_rhs_scales_stride, imp_packed_rhs.data(), 0, + ¶ms); + + // Runs the GEMM micro-kernel. + const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N); + ASSERT_EQ(imp_dst_size, ref_dst.size()); + std::vector imp_dst(imp_dst_size); + ukernel_variant.interface.run_matmul( + M, N, K, bl, imp_packed_lhs.data(), imp_packed_rhs.data(), reinterpret_cast(imp_dst.data()), + N * sizeof(float), sizeof(float), std::numeric_limits::lowest(), std::numeric_limits::max()); + + // Compares the output of the micro-kernels against the output of the reference implementation. + for (size_t y = 0; y < M; ++y) { + for (size_t x = 0; x < N; ++x) { + const auto imp_value = read_array(imp_dst.data(), y * N + x); + const auto ref_value = read_array(ref_dst.data(), y * N + x); + const auto rel_error = ref_value != 0 ? std::abs((imp_value - ref_value) / ref_value) : std::abs(imp_value); + + if (rel_error > 0.0001F) { + ASSERT_EQ(imp_value, ref_value); + } + } + } +} + +TEST_P(MatMulTest_f32_qmatmul_clamp_f32_qai8dxp_qsi4c32p, EndToEnd_RHS_NonTransposed) { + auto& [variant_index, matmul_shape] = GetParam(); + const auto& ukernel_variant = variants_kai_matmul_clamp_f32_qai8dxp_qsi4c32p.at(variant_index); + + const uint64_t seed = 0; + + const size_t M = matmul_shape.m; + const size_t N = matmul_shape.n; + const size_t K = matmul_shape.k; + + const auto mr = ukernel_variant.interface.get_mr(); + const auto nr = ukernel_variant.interface.get_nr(); + const auto kr = ukernel_variant.interface.get_kr(); + const auto sr = ukernel_variant.interface.get_sr(); + + const size_t bl = 32; + + // Generates input data. + const auto ref_lhs = fill_random(M * K, seed + 0); + const auto ref_rhs = fill_random(N * K, seed + 1); + + // Runs the reference implementation. + // * Quantizes the LHS matrix using 8-bit asymmetric quantization. 
+ // * Quantizes the RHS matrix using 4-bit symmetric quantization. + // * Performs GEMM. + const auto [ref_lhs_qvalues, ref_lhs_scales, ref_lhs_zero_points] = + quantize_asymmetric_per_block(ref_lhs.data(), M, K, K); + const auto [ref_rhs_qsi4, ref_rhs_scales] = + quantize_symmetric_per_block(ref_rhs.data(), N, K, bl, false /* is_transposed */); + + const auto ref_dst = matmul_clamp_nt_nt( + M, N, K, ref_lhs_qvalues.data(), ref_lhs_scales.data(), ref_lhs_zero_points.data(), K, ref_rhs_qsi4.data(), + ref_rhs_scales.data(), nullptr, bl, nullptr, std::numeric_limits::lowest(), + std::numeric_limits::max()); + + // Runs the LHS packing micro-kernel. + const auto imp_packed_lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr); + std::vector imp_packed_lhs(imp_packed_lhs_size); + kai_run_lhs_quant_pack_qai8dxp_f32( + M, K, mr, kr, sr, 0, reinterpret_cast(ref_lhs.data()), K * sizeof(float), imp_packed_lhs.data()); + + // Runs the RHS packing micro-kernel. + // * Generates the 4-bit unsigned symmetric quantized input for the micro-kernel. + // * Packs the RHS matrix. + const auto ref_rhs_qsu4 = cast_qsu4_qsi4(ref_rhs_qsi4.data(), N * K); + + const size_t ref_rhs_qsu4_stride = round_up_division(N, 2); + const size_t ref_rhs_scales_stride = round_up_division(K, bl) * sizeof(uint16_t); + + const auto imp_packed_rhs_size = + kai_get_rhs_packed_size_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, bl, kai_datatype::kai_dt_bf16); + std::vector imp_packed_rhs(imp_packed_rhs_size); + const kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_params params{ + .lhs_zero_point = 1, .rhs_zero_point = 8, .scale_dt = kai_datatype::kai_dt_bf16}; + kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0( + 1, N, K, nr, kr, sr, bl, ref_rhs_qsu4.data(), ref_rhs_qsu4_stride, nullptr, + reinterpret_cast(ref_rhs_scales.data()), ref_rhs_scales_stride, imp_packed_rhs.data(), 0, + ¶ms); + + // Runs the GEMM micro-kernel. + const auto imp_dst_size = ukernel_variant.interface.get_dst_size(M, N); + ASSERT_EQ(imp_dst_size, ref_dst.size()); + std::vector imp_dst(imp_dst_size); + ukernel_variant.interface.run_matmul( + M, N, K, bl, imp_packed_lhs.data(), imp_packed_rhs.data(), reinterpret_cast(imp_dst.data()), + N * sizeof(float), sizeof(float), std::numeric_limits::lowest(), std::numeric_limits::max()); + + // Compares the output of the micro-kernels against the output of the reference implementation. + for (size_t y = 0; y < M; ++y) { + for (size_t x = 0; x < N; ++x) { + const auto imp_value = read_array(imp_dst.data(), y * N + x); + const auto ref_value = read_array(ref_dst.data(), y * N + x); + const auto rel_error = ref_value != 0 ? std::abs((imp_value - ref_value) / ref_value) : std::abs(imp_value); + + if (rel_error > 0.0001F) { + ASSERT_EQ(imp_value, ref_value); + } + } + } +} + +INSTANTIATE_TEST_SUITE_P( + MatMul, MatMulTest_f32_qmatmul_clamp_f32_qai8dxp_qsi4c32p, + testing::Combine( + testing::Range(0, variants_kai_matmul_clamp_f32_qai8dxp_qsi4c32p.size()), + testing::Values(MatMulShape{16, 32, 64}))); + +} // namespace kai::test
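[Editorial note, not part of the patch] For readers wiring the new pieces together outside the test harness, here is a minimal usage sketch of the qai8dxp/qsi4c32p path, following the flow of the test above. It is a hedged illustration, not the example program shipped in `examples/matmul_clamp_f32_qai8dxp_qsi4c32p`: the per-variant getter names (`kai_get_mr_...` and friends) are assumed from the project's naming scheme, the zero-filled input buffers and shapes are placeholders, and the code must be compiled for an Armv8.2-A dotprod+i8mm target and linked against the packing and matmul `.c` files, as in the example's CMakeLists.txt:

```cpp
#include <cfloat>
#include <cstddef>
#include <cstdint>
#include <vector>

#include "kai_lhs_quant_pack_qai8dxp_f32.h"
#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h"
#include "kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h"

int main() {
    const size_t M = 16, N = 32, K = 64;
    const size_t bl = 32;  // block length: K values per quantized block

    // Packing parameters come from the selected matmul variant
    // (getter names assumed from the per-kernel naming scheme).
    const size_t mr = kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm();
    const size_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm();
    const size_t kr = kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm();
    const size_t sr = kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm();

    // Dynamically quantize and pack the F32 LHS in one step.
    std::vector<float> lhs(M * K, 1.0F);
    std::vector<uint8_t> lhs_packed(kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr));
    kai_run_lhs_quant_pack_qai8dxp_f32(
        M, K, mr, kr, sr, 0, lhs.data(), K * sizeof(float), lhs_packed.data());

    // Pack the NxK unsigned-4-bit RHS (two values per byte) with BF16 block scales.
    // Zero-filled placeholders stand in for real quantized weights.
    std::vector<uint8_t> rhs(N * (K / 2));
    std::vector<uint16_t> scales(N * (K / bl));  // one BF16 scale per block
    std::vector<uint8_t> rhs_packed(
        kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, bl, kai_dt_bf16));
    const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params = {
        .lhs_zero_point = 1, .rhs_zero_point = 8, .scale_dt = kai_dt_bf16};
    kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
        1, N, K, nr, kr, sr, bl, rhs.data(), K / 2, /*bias=*/NULL, scales.data(),
        (K / bl) * sizeof(uint16_t), rhs_packed.data(), 0, &params);

    // Run the matmul micro-kernel; +/-FLT_MAX effectively disables clamping.
    std::vector<float> dst(M * N);
    kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm(
        M, N, K, bl, lhs_packed.data(), rhs_packed.data(), dst.data(),
        N * sizeof(float), sizeof(float), -FLT_MAX, FLT_MAX);

    return 0;
}
```

The same flow applies to the kxn layout: swap in `kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0` with an RHS row stride of `N / 2` bytes, as exercised by the `EndToEnd_RHS_NonTransposed` test above.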