From 392f2b4f0a315cd0ea921ba47f60878425f8caab Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Tue, 10 Dec 2024 18:04:09 +0000 Subject: [PATCH] Refactor RHS packing function for F32 <- QAI8DXP x QSU4C32 - Rename the packing function to include the bf16 scale factor - Optimize the scalar variant. The new implementation is ~1.5x faster than the previous one Signed-off-by: Gian Marco Iodice --- .../CMakeLists.txt | 4 +- .../matmul_clamp_f32_qai8dxp_qsi4c32p.cpp | 20 +- kai/kai_common.h | 5 + ..._pack_kxn_qsi4c32pscalebf16_qsu4c32s1s0.c} | 90 ++---- ..._pack_kxn_qsi4c32pscalebf16_qsu4c32s1s0.h} | 107 +++---- .../kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c | 289 ------------------ ...s_pack_nxk_qsi4c32pscalebf16_qsu4c32s1s0.c | 266 ++++++++++++++++ ..._pack_nxk_qsi4c32pscalebf16_qsu4c32s1s0.h} | 91 +++--- 8 files changed, 400 insertions(+), 472 deletions(-) rename kai/ukernels/matmul/pack/{kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c => kai_rhs_pack_kxn_qsi4c32pscalebf16_qsu4c32s1s0.c} (69%) rename kai/ukernels/matmul/pack/{kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h => kai_rhs_pack_kxn_qsi4c32pscalebf16_qsu4c32s1s0.h} (58%) delete mode 100644 kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c create mode 100644 kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalebf16_qsu4c32s1s0.c rename kai/ukernels/matmul/pack/{kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h => kai_rhs_pack_nxk_qsi4c32pscalebf16_qsu4c32s1s0.h} (64%) diff --git a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/CMakeLists.txt b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/CMakeLists.txt index d1a81f3d..60750a93 100644 --- a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/CMakeLists.txt +++ b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/CMakeLists.txt @@ -22,8 +22,8 @@ include_directories( add_executable(matmul_clamp_f32_qai8dxp_qsi4c32p matmul_clamp_f32_qai8dxp_qsi4c32p.cpp ${KLEIDIAI_PATH}/kai/kai_common.h - ${MATMUL_PACK_PATH}/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c - 
${MATMUL_PACK_PATH}/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c + ${MATMUL_PACK_PATH}/kai_rhs_pack_kxn_qsi4c32pscalebf16_qsu4c32s1s0.c + ${MATMUL_PACK_PATH}/kai_rhs_pack_nxk_qsi4c32pscalebf16_qsu4c32s1s0.c ${MATMUL_PACK_PATH}/kai_lhs_quant_pack_qai8dxp_f32.c ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c ${MATMUL_PATH}/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.c diff --git a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp index 03f8b332..a10a3f50 100644 --- a/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp +++ b/examples/matmul_clamp_f32_qai8dxp_qsi4c32p/matmul_clamp_f32_qai8dxp_qsi4c32p.cpp @@ -23,8 +23,8 @@ #include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h" #include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h" #include "kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h" -#include "kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h" -#include "kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h" +#include "kai_rhs_pack_kxn_qsi4c32pscalebf16_qsu4c32s1s0.h" +#include "kai_rhs_pack_nxk_qsi4c32pscalebf16_qsu4c32s1s0.h" #define INT4_MIN (-8) #define INT4_MAX (7) @@ -627,11 +627,11 @@ int main() { if (format == rhs_format::nxk) { rhs_packed_size = - kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(n, k, nr, kr, sr, bl, kai_dt_bf16); + kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalebf16_qsu4c32s1s0(n, k, nr, kr, sr, bl); } else { rhs_packed_size = - kai_get_rhs_packed_size_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(n, k, nr, kr, sr, bl, kai_dt_bf16); + kai_get_rhs_packed_size_rhs_pack_kxn_qsi4c32pscalebf16_qsu4c32s1s0(n, k, nr, kr, sr, bl); } const size_t dst_size = ukernel_variants[idx_variant].ukernel.get_dst_size(m, n); @@ -646,12 +646,10 @@ int main() { // If the RHS matrix contains constant values, the packing can be performed // only once if 
(format == rhs_format::nxk) { - kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params params; + kai_rhs_pack_qsu4c32_params params; params.lhs_zero_point = 1; - params.rhs_zero_point = 8; - params.scale_dt = kai_dt_bf16; - kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( + kai_run_rhs_pack_nxk_qsi4c32pscalebf16_qsu4c32s1s0( 1, n, k, // Dimensions nr, kr, sr, // Packing arguments bl, // Block length @@ -664,12 +662,10 @@ int main() { 0, &params); } else { - kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_params params; + kai_rhs_pack_qsu4c32_params params; params.lhs_zero_point = 1; - params.rhs_zero_point = 8; - params.scale_dt = kai_dt_bf16; - kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0( + kai_run_rhs_pack_kxn_qsi4c32pscalebf16_qsu4c32s1s0( 1, n, k, // Dimensions nr, kr, sr, // Packing arguments bl, // Block length diff --git a/kai/kai_common.h b/kai/kai_common.h index 7a28f768..39d1a93a 100644 --- a/kai/kai_common.h +++ b/kai/kai_common.h @@ -170,6 +170,11 @@ struct kai_rhs_pack_qsi8cx_params { float scale_multiplier; ///< Product of input (refers to lhs and rhs) and output quantization scales. 
}; +/// Parameter struct for RHS matrix packing (Quantized Symmetric Unsigned 4-bit with per-block quantization) +struct kai_rhs_pack_qsu4c32_params { + int32_t lhs_zero_point; ///< LHS Matrix quantization zero-point +}; + /// Parameter struct for RHS matrix packing struct kai_rhs_pack_qs4cxs1s0_param { int8_t lhs_zero_point; ///< LHS Matrix quantization zero-point diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32pscalebf16_qsu4c32s1s0.c similarity index 69% rename from kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c rename to kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32pscalebf16_qsu4c32s1s0.c index c9628280..59e5d72b 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32pscalebf16_qsu4c32s1s0.c @@ -3,7 +3,7 @@ // // SPDX-License-Identifier: Apache-2.0 // -#include "kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h" +#include "kai_rhs_pack_kxn_qsi4c32pscalebf16_qsu4c32s1s0.h" #include #include @@ -13,36 +13,37 @@ static const size_t kai_num_bytes_sum_rhs = sizeof(float); static const size_t kai_num_bytes_bias = sizeof(float); +static const size_t kai_num_bytes_scale = sizeof(uint16_t); static const size_t kai_nr_multiple_of = 4; static const size_t kai_bl_multiple_of = 32; -inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) { +inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) { KAI_ASSERT((bl % kai_bl_multiple_of) == 0); return kai_roundup(k, bl) / bl; } -inline static size_t kai_num_bytes_per_block(size_t bl, size_t num_bytes_multiplier_rhs) { +inline static size_t kai_get_num_bytes_per_block(size_t bl, size_t num_bytes_multiplier_rhs) { KAI_ASSERT((bl % kai_bl_multiple_of) == 0); return (bl / 2) + num_bytes_multiplier_rhs; } -inline static size_t kai_rhs_packed_offset_end_of_all_blocks( +inline static size_t kai_get_rhs_packed_offset_end_of_all_blocks( size_t k, 
size_t nr, size_t kr, size_t bl, size_t num_bytes_multiplier_rhs) { KAI_ASSERT((bl % kr) == 0); KAI_ASSERT((nr % kai_nr_multiple_of) == 0); KAI_ASSERT((bl % kai_bl_multiple_of) == 0); - const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); - const size_t num_bytes_per_block = kai_num_bytes_per_block(bl, num_bytes_multiplier_rhs); + const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl); + const size_t num_bytes_per_block = kai_get_num_bytes_per_block(bl, num_bytes_multiplier_rhs); return (nr * num_bytes_per_block * num_blocks_per_row); } -size_t kai_get_n_step_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(size_t nr) { +size_t kai_get_n_step_rhs_pack_kxn_qsi4c32pscalebf16_qsu4c32s1s0(size_t nr) { return nr; } -size_t kai_get_rhs_offset_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0( +size_t kai_get_rhs_offset_rhs_pack_kxn_qsi4c32pscalebf16_qsu4c32s1s0( size_t n_idx, // size_t rhs_stride) { KAI_UNUSED(rhs_stride); @@ -50,64 +51,57 @@ size_t kai_get_rhs_offset_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0( return (n_idx / 2) * sizeof(int8_t); } -size_t kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0( +size_t kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4c32pscalebf16_qsu4c32s1s0( size_t k, // size_t nr, // size_t kr, // size_t sr, // - size_t bl, // - enum kai_datatype scale_dt) { + size_t bl) { KAI_ASSERT((bl % kr) == 0); KAI_ASSERT((nr % kai_nr_multiple_of) == 0); KAI_ASSERT((bl % kai_bl_multiple_of) == 0); - KAI_ASSERT(scale_dt == kai_dt_f32 || scale_dt == kai_dt_f16 || scale_dt == kai_dt_bf16); KAI_UNUSED(kr); KAI_UNUSED(sr); - const size_t num_bytes_multiplier_rhs = kai_get_datatype_size_in_bytes(scale_dt); - const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); - const size_t num_bytes_per_block = kai_num_bytes_per_block(bl, num_bytes_multiplier_rhs); + const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl); + const size_t num_bytes_per_block = kai_get_num_bytes_per_block(bl, kai_num_bytes_scale); return nr * ((num_bytes_per_block * 
num_blocks_per_row) + kai_num_bytes_sum_rhs + kai_num_bytes_bias); } -size_t kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0( +size_t kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4c32pscalebf16_qsu4c32s1s0( size_t n_idx, // size_t k, // size_t nr, // size_t kr, // size_t sr, // - size_t bl, // - enum kai_datatype scale_dt) { + size_t bl) { KAI_ASSERT((n_idx % nr) == 0); KAI_ASSERT((bl % kr) == 0); KAI_ASSERT((nr % kai_nr_multiple_of) == 0); KAI_ASSERT((bl % kai_bl_multiple_of) == 0); - KAI_ASSERT(scale_dt == kai_dt_f32 || scale_dt == kai_dt_f16 || scale_dt == kai_dt_bf16); - return (n_idx / nr) * kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(k, nr, kr, sr, bl, scale_dt); + return (n_idx / nr) * kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4c32pscalebf16_qsu4c32s1s0(k, nr, kr, sr, bl); } -size_t kai_get_rhs_packed_size_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0( +size_t kai_get_rhs_packed_size_rhs_pack_kxn_qsi4c32pscalebf16_qsu4c32s1s0( size_t n, // size_t k, // size_t nr, // size_t kr, // size_t sr, // - size_t bl, // - enum kai_datatype scale_dt) { + size_t bl) { KAI_ASSERT((bl % kr) == 0); KAI_ASSERT((nr % kai_nr_multiple_of) == 0); KAI_ASSERT((bl % kai_bl_multiple_of) == 0); - KAI_ASSERT(scale_dt == kai_dt_f32 || scale_dt == kai_dt_f16 || scale_dt == kai_dt_bf16); const size_t num_rows = kai_roundup(n, nr) / nr; - return num_rows * kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(k, nr, kr, sr, bl, scale_dt); + return num_rows * kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4c32pscalebf16_qsu4c32s1s0(k, nr, kr, sr, bl); } -void kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0( +void kai_run_rhs_pack_kxn_qsi4c32pscalebf16_qsu4c32s1s0( size_t num_groups, // size_t n, // size_t k, // @@ -122,39 +116,35 @@ void kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0( size_t scale_stride, // void* rhs_packed, // size_t extra_bytes, // - const struct kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_params* params) { + const struct kai_rhs_pack_qsu4c32_params* params) { 
KAI_ASSERT(num_groups == 1); KAI_ASSERT(extra_bytes == 0); KAI_ASSERT(rhs != NULL); KAI_ASSERT(scale != NULL); KAI_ASSERT(rhs_packed != NULL); KAI_ASSERT(params != NULL); - KAI_ASSERT(params->rhs_zero_point == 8); KAI_ASSERT(params->lhs_zero_point == 1); KAI_ASSERT((bl % kr) == 0); KAI_ASSERT((kr % sr) == 0); KAI_ASSERT((nr % kai_nr_multiple_of) == 0); KAI_ASSERT((bl % kai_bl_multiple_of) == 0); - KAI_ASSERT(params->scale_dt == kai_dt_f32 || params->scale_dt == kai_dt_f16 || params->scale_dt == kai_dt_bf16); // Note: The input matrix (rhs) is expected with: // "n" columns and "k" rows (kxn) - const size_t num_bytes_multiplier_rhs = kai_get_datatype_size_in_bytes(params->scale_dt); const size_t rhs_packed_stride = - kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(k, nr, kr, sr, bl, params->scale_dt); + kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4c32pscalebf16_qsu4c32s1s0(k, nr, kr, sr, bl); const size_t rhs_packed_offset_end_of_all_blocks = - kai_rhs_packed_offset_end_of_all_blocks(k, nr, kr, bl, num_bytes_multiplier_rhs); - const size_t num_qblocks_per_row = kai_num_blocks_per_row(k, bl); - const size_t num_bytes_per_block = kai_num_bytes_per_block(bl, num_bytes_multiplier_rhs); + kai_get_rhs_packed_offset_end_of_all_blocks(k, nr, kr, bl, kai_num_bytes_scale); + const size_t num_qblocks_per_row = kai_get_num_blocks_per_row(k, bl); + const size_t num_bytes_per_block = kai_get_num_bytes_per_block(bl, kai_num_bytes_scale); const size_t num_bytes_per_block_k = bl / 2; const size_t dst_num_rows = kai_roundup(n, nr) / nr; const size_t k_interleaved_v = 16U; const size_t block_length_in_bytes = kr / sr; - const int32_t rhs_zero_point = params->rhs_zero_point; - const enum kai_datatype scale_dt = params->scale_dt; + const int32_t rhs_zero_point = 8; for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; ++dst_row_idx) { // Before packing, it keeps the pointer to the first quantized block @@ -174,14 +164,14 @@ void kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0( for 
(size_t i = 0; i < nr; ++i) { const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); - void* dst_scales_ptr = rhs_packed_scale + i * num_bytes_multiplier_rhs; - const void* src_scales_ptr = scale_ptr + dst_qblock_idx * num_bytes_multiplier_rhs + // - (src_row_idx * scale_stride); // + void* dst_scales_ptr = rhs_packed_scale + i * kai_num_bytes_scale; + const void* src_scales_ptr = scale_ptr + dst_qblock_idx * kai_num_bytes_scale + // + (src_row_idx * scale_stride); // memcpy( - dst_scales_ptr, // - src_scales_ptr, // - num_bytes_multiplier_rhs); // + dst_scales_ptr, // + src_scales_ptr, // + kai_num_bytes_scale); // } size_t kr_block_idx = 0; @@ -194,21 +184,7 @@ void kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0( // Clamp the index to avoid out-of-bound reads const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1); - float d = 0.0F; - switch (scale_dt) { - case kai_dt_f32: - d = ((float*)rhs_packed_scale)[nr_idx]; - break; - case kai_dt_f16: - d = kai_cast_f32_f16(((uint16_t*)rhs_packed_scale)[nr_idx]); - break; - case kai_dt_bf16: - d = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx]); - break; - default: - KAI_ERROR("Unsupported scale data type"); - break; - } + const float d = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx]); const size_t k_adjustment = ((super_kr_block_idx * block_length_in_bytes) / k_interleaved_v) * k_interleaved_v; diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32pscalebf16_qsu4c32s1s0.h similarity index 58% rename from kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h rename to kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32pscalebf16_qsu4c32s1s0.h index 42b546a0..f61b516e 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32pscalebf16_qsu4c32s1s0.h @@ -13,12 +13,6 @@ extern "C" { #endif -struct kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_params { - 
int8_t lhs_zero_point; - uint8_t rhs_zero_point; - enum kai_datatype scale_dt; -}; - /// Get the n step value. /// The micro-kernel can process any N values. However, the starting N index to /// be processed must be a multiple of n step. @@ -26,7 +20,7 @@ struct kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_params { /// @param[in] nr The number of columns written by the matmul micro-kernel /// /// @return the n step value -size_t kai_get_n_step_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(size_t nr); +size_t kai_get_n_step_rhs_pack_kxn_qsi4c32pscalebf16_qsu4c32s1s0(size_t nr); /// Gets the offset in bytes for the RHS matrix (not packed), which holds /// the int4 values in a K x N matrix, where N is number of columns and K is the number of rows. @@ -39,80 +33,75 @@ size_t kai_get_n_step_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(size_t nr); /// @param[in] rhs_stride The number of bytes in in each row of the RHS matrix (not packed) /// /// @return the offset in bytes to the RHS matrix (not packed) -size_t kai_get_rhs_offset_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0( +size_t kai_get_rhs_offset_rhs_pack_kxn_qsi4c32pscalebf16_qsu4c32s1s0( size_t n_idx, // size_t rhs_stride); // /// Get the row stride in bytes to the packed RHS matrix /// -/// @param[in] k The number of rows in the RHS matrix (not packed). +/// @param[in] k The number of rows in the RHS matrix (not packed). /// @param[in] nr The number of columns written by the matmul micro-kernel. /// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. /// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. /// @param[in] bl The block length, which defines the number of K values stored in a single block. It must be a /// multiple of 32. 
-/// @param[in] scale_dt Block scale data type /// /// @return the stride in bytes to the packed RHS matrix -size_t kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0( - size_t k, // - size_t nr, // - size_t kr, // - size_t sr, // - size_t bl, // - enum kai_datatype scale_dt); // +size_t kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4c32pscalebf16_qsu4c32s1s0( + size_t k, // + size_t nr, // + size_t kr, // + size_t sr, // + size_t bl); // /// Gets the offset in bytes for the packed RHS matrix. /// /// @param[in] n_idx Row index in the RHS matrix (not packed). -/// @param[in] k The number of rows in the RHS matrix (not packed). +/// @param[in] k The number of rows in the RHS matrix (not packed). /// @param[in] nr The number of columns written by the matmul micro-kernel /// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. /// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. /// @param[in] bl The block length, which defines the number of K values stored in a single block. It must be a /// multiple of 32. -/// @param[in] scale_dt Block scale data type /// /// @return the offset in bytes to the packed RHS matrix -size_t kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0( - size_t n_idx, // - size_t k, // - size_t nr, // - size_t kr, // - size_t sr, // - size_t bl, // - enum kai_datatype scale_dt); // +size_t kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4c32pscalebf16_qsu4c32s1s0( + size_t n_idx, // + size_t k, // + size_t nr, // + size_t kr, // + size_t sr, // + size_t bl); // /// Gets the size in bytes for the quantized and packed RHS matrix. /// -/// @param[in] n The number of columns in the RHS matrix (not packed). -/// @param[in] k The number of rows in the RHS matrix (not packed). -/// @param[in] nr The number of columns written by the matmul micro-kernel. 
+/// @param[in] n The number of columns in the RHS matrix (not packed). +/// @param[in] k The number of rows in the RHS matrix (not packed). +/// @param[in] nr The number of columns written by the matmul micro-kernel. /// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. /// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. /// @param[in] bl The block length, which defines the number of K values stored in a single block. It must be a multiple /// of 32. -/// @param[in] scale_dt Block scale data type /// /// @return the packed RHS matrix size in bytes -size_t kai_get_rhs_packed_size_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0( - size_t n, // - size_t k, // - size_t nr, // - size_t kr, // - size_t sr, // - size_t bl, // - enum kai_datatype scale_dt); // +size_t kai_get_rhs_packed_size_rhs_pack_kxn_qsi4c32pscalebf16_qsu4c32s1s0( + size_t n, // + size_t k, // + size_t nr, // + size_t kr, // + size_t sr, // + size_t bl); // /// Runs the RHS packing micro-kernel. /// /// The int4 values are stored in a K x N matrix, where N is number of columns and K is the number of rows. /// Two int4 values are stored in one byte. The lower order part of the byte (low) holds /// the first nibble (K-index + 0). The higher order of the byte holds the second nibble (K-index + 1). +/// The scale factor is stored as Bfloat16 data type /// /// @param[in] num_groups The number of groups. It must be 1. -/// @param[in] n The number of columns in the RHS matrix (not packed). -/// @param[in] k The number of rows in the RHS matrix (not packed). +/// @param[in] n The number of columns in the RHS matrix (not packed). +/// @param[in] k The number of rows in the RHS matrix (not packed). /// @param[in] nr The number of columns written by the matmul micro-kernel. It must be a multiple of 4. /// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. /// @param[in] sr The number of kr splits. 
It can be 1 (no splits) up to kr. @@ -120,32 +109,30 @@ size_t kai_get_rhs_packed_size_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0( /// @param[in] bl The block length, which defines the number of /// K values stored in a single block. It must be a multiple of 32. /// @param[in] rhs The RHS matrix containing the 4-bit values. -/// Size in bytes is expected to be greater than or equal to n/// k/// (sizeof(uint8_t) / 2). +/// Size in bytes is expected to be greater than or equal to n * k * (sizeof(uint8_t) / 2). /// @param[in] rhs_stride The number of bytes per row in bytes of the RHS matrix /// @param[in] bias The biases. /// @param[in] scale The per-block quantization scales. -/// The scale data type must be provided with the params object. -/// Supported scale data types are FP32, FP16 and BF16. /// @param[in] scale_stride The number of bytes per row in bytes of the scale matrix /// @param[out] rhs_packed The packed RHS matrix. /// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix. /// @param[in] params Parameters for the micro-kernel. 
-void kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0( - size_t num_groups, // - size_t n, // - size_t k, // - size_t nr, // - size_t kr, // - size_t sr, // - size_t bl, // - const uint8_t* rhs, // - size_t rhs_stride, // - const float* bias, // - const void* scale, // - size_t scale_stride, // - void* rhs_packed, // - size_t extra_bytes, // - const struct kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_params* params); // +void kai_run_rhs_pack_kxn_qsi4c32pscalebf16_qsu4c32s1s0( + size_t num_groups, // + size_t n, // + size_t k, // + size_t nr, // + size_t kr, // + size_t sr, // + size_t bl, // + const uint8_t* rhs, // + size_t rhs_stride, // + const float* bias, // + const void* scale, // + size_t scale_stride, // + void* rhs_packed, // + size_t extra_bytes, // + const struct kai_rhs_pack_qsu4c32_params* params); // #ifdef __cplusplus } diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c deleted file mode 100644 index 1f73357b..00000000 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c +++ /dev/null @@ -1,289 +0,0 @@ -// -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates -// -// SPDX-License-Identifier: Apache-2.0 -// -#include "kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h" - -#include -#include -#include - -#include "kai/kai_common.h" - -static const size_t kai_num_bytes_sum_rhs = sizeof(float); -static const size_t kai_num_bytes_bias = sizeof(float); -static const size_t kai_nr_multiple_of = 4; -static const size_t kai_bl_multiple_of = 32; - -inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) { - KAI_ASSERT((bl % kai_bl_multiple_of) == 0); - return kai_roundup(k, bl) / bl; -} - -inline static size_t kai_num_bytes_per_block(size_t bl, size_t num_bytes_multiplier_rhs) { - KAI_ASSERT((bl % kai_bl_multiple_of) == 0); - return (bl / 2) + num_bytes_multiplier_rhs; -} - -inline static size_t kai_rhs_packed_offset_end_of_all_blocks( - 
size_t k, size_t nr, size_t kr, size_t bl, size_t num_bytes_multiplier_rhs) { - KAI_ASSERT((bl % kr) == 0); - KAI_ASSERT((nr % kai_nr_multiple_of) == 0); - KAI_ASSERT((bl % kai_bl_multiple_of) == 0); - - const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); - const size_t num_bytes_per_block = kai_num_bytes_per_block(bl, num_bytes_multiplier_rhs); - - return (nr * num_bytes_per_block * num_blocks_per_row); -} - -size_t kai_get_n_step_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(size_t nr) { - return nr; -} - -size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( - size_t n_idx, // - size_t rhs_stride) { - return n_idx * rhs_stride; -} - -size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( - size_t k, // - size_t nr, // - size_t kr, // - size_t sr, // - size_t bl, // - enum kai_datatype scale_dt) { - KAI_ASSERT((bl % kr) == 0); - KAI_ASSERT((nr % kai_nr_multiple_of) == 0); - KAI_ASSERT((bl % kai_bl_multiple_of) == 0); - KAI_ASSERT(scale_dt == kai_dt_f32 || scale_dt == kai_dt_f16 || scale_dt == kai_dt_bf16); - - KAI_UNUSED(kr); - KAI_UNUSED(sr); - - const size_t num_bytes_multiplier_rhs = kai_get_datatype_size_in_bytes(scale_dt); - const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); - const size_t num_bytes_per_block = kai_num_bytes_per_block(bl, num_bytes_multiplier_rhs); - - return nr * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs + kai_num_bytes_bias); -} - -size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( - size_t n_idx, // - size_t k, // - size_t nr, // - size_t kr, // - size_t sr, // - size_t bl, // - enum kai_datatype scale_dt) { - KAI_ASSERT((n_idx % nr) == 0); - KAI_ASSERT((bl % kr) == 0); - KAI_ASSERT((nr % kai_nr_multiple_of) == 0); - KAI_ASSERT((bl % kai_bl_multiple_of) == 0); - KAI_ASSERT(scale_dt == kai_dt_f32 || scale_dt == kai_dt_f16 || scale_dt == kai_dt_bf16); - - return (n_idx / nr) * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(k, nr, kr, sr, bl, 
scale_dt); -} - -size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( - size_t n, // - size_t k, // - size_t nr, // - size_t kr, // - size_t sr, // - size_t bl, // - enum kai_datatype scale_dt) { - KAI_ASSERT((bl % kr) == 0); - KAI_ASSERT((nr % kai_nr_multiple_of) == 0); - KAI_ASSERT((bl % kai_bl_multiple_of) == 0); - KAI_ASSERT(scale_dt == kai_dt_f32 || scale_dt == kai_dt_f16 || scale_dt == kai_dt_bf16); - - const size_t num_rows = kai_roundup(n, nr) / nr; - - return num_rows * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(k, nr, kr, sr, bl, scale_dt); -} - -void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( - size_t num_groups, // - size_t n, // - size_t k, // - size_t nr, // - size_t kr, // - size_t sr, // - size_t bl, // - const uint8_t* rhs, // - size_t rhs_stride, // - const float* bias, // - const void* scale, // - size_t scale_stride, // - void* rhs_packed, // - size_t extra_bytes, // - const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params* params) { - KAI_ASSERT(num_groups == 1); - KAI_ASSERT(extra_bytes == 0); - KAI_ASSERT(rhs != NULL); - KAI_ASSERT(scale != NULL); - KAI_ASSERT(rhs_packed != NULL); - KAI_ASSERT(params != NULL); - KAI_ASSERT(params->rhs_zero_point == 8); - KAI_ASSERT(params->lhs_zero_point == 1); - - KAI_ASSERT((bl % kr) == 0); - KAI_ASSERT((kr % sr) == 0); - KAI_ASSERT((nr % kai_nr_multiple_of) == 0); - KAI_ASSERT((bl % kai_bl_multiple_of) == 0); - KAI_ASSERT(params->scale_dt == kai_dt_f32 || params->scale_dt == kai_dt_f16 || params->scale_dt == kai_dt_bf16); - - // Note: The input matrix (rhs) is expected with: - // "k" columns and "n" rows (NxK) - - const size_t num_bytes_multiplier_rhs = kai_get_datatype_size_in_bytes(params->scale_dt); - const size_t rhs_packed_stride = - kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(k, nr, kr, sr, bl, params->scale_dt); - const size_t rhs_packed_offset_end_of_all_blocks = - kai_rhs_packed_offset_end_of_all_blocks(k, nr, kr, bl, num_bytes_multiplier_rhs); - 
const size_t num_qblocks_per_row = kai_num_blocks_per_row(k, bl); - const size_t num_bytes_per_block = kai_num_bytes_per_block(bl, num_bytes_multiplier_rhs); - const size_t num_bytes_per_block_k = bl / 2; - const size_t dst_num_rows = kai_roundup(n, nr) / nr; - const size_t k_interleaved_v = 16U; - const size_t block_length_in_bytes = kr / sr; - - const int32_t rhs_zero_point = params->rhs_zero_point; - const enum kai_datatype scale_dt = params->scale_dt; - - for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; ++dst_row_idx) { - // Before packing, it keeps the pointer to the first quantized block - uint8_t* dst_row = (uint8_t*)rhs_packed + dst_row_idx * rhs_packed_stride; - - float* sums = (float*)(dst_row + rhs_packed_offset_end_of_all_blocks); - - // Initialize the RHS reduction sums to zero - memset(sums, 0, nr * kai_num_bytes_sum_rhs); - - // Iterate over the quantized blocks - for (size_t dst_qblock_idx = 0; dst_qblock_idx < num_qblocks_per_row; ++dst_qblock_idx) { - // Store the scales after packing all K values - uint8_t* rhs_packed_scale = dst_row + num_bytes_per_block_k * nr; - const uint8_t* scale_ptr = scale; - - for (size_t i = 0; i < nr; ++i) { - const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); - - void* dst_scales_ptr = rhs_packed_scale + i * num_bytes_multiplier_rhs; - const void* src_scales_ptr = - (scale_ptr + dst_qblock_idx * num_bytes_multiplier_rhs + // - (src_row_idx * scale_stride)); - - memcpy( - dst_scales_ptr, // - src_scales_ptr, // - num_bytes_multiplier_rhs); // - } - - size_t kr_block_idx = 0; - for (size_t dst_byte_idx = 0; dst_byte_idx < nr * num_bytes_per_block_k; - dst_byte_idx += block_length_in_bytes) { - const size_t super_kr_block_idx = kr_block_idx / nr; - const size_t nr_idx = kr_block_idx % nr; - const size_t n0_idx = dst_row_idx * nr + nr_idx; - - // Clamp the index to avoid out-of-bound reads - const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1); - - float d = 0.0F; - switch (scale_dt) { - case 
kai_dt_f32: - d = ((float*)rhs_packed_scale)[nr_idx]; - break; - case kai_dt_f16: - d = kai_cast_f32_f16(((uint16_t*)rhs_packed_scale)[nr_idx]); - break; - case kai_dt_bf16: - d = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx]); - break; - default: - KAI_ERROR("Unsupported scale data type"); - break; - } - - const size_t k_adjustment = - ((super_kr_block_idx * block_length_in_bytes) / k_interleaved_v) * k_interleaved_v; - size_t k0_idx = dst_qblock_idx * bl + super_kr_block_idx * block_length_in_bytes + k_adjustment; - size_t k1_idx = k0_idx + k_interleaved_v; - - float partial_sum = 0.0F; - - for (size_t block_byte_idx = 0; block_byte_idx < block_length_in_bytes; ++block_byte_idx) { - const size_t src_addr_byte0 = (k0_idx / 2) + n0_valid_idx * rhs_stride; - const size_t src_addr_byte1 = src_addr_byte0 + k_interleaved_v / 2; - - uint8_t byte0 = rhs_zero_point | rhs_zero_point << 4; - uint8_t byte1 = rhs_zero_point | rhs_zero_point << 4; - - if (k0_idx < k) { - byte0 = rhs[src_addr_byte0]; - } - - if (k1_idx < k) { - byte1 = rhs[src_addr_byte1]; - } - - // The following operations where we extract the values from the bytes - // can be also written in the following and less efficient manner: - /* - uint8_t src_x0_lo = 0; - uint8_t src_x0_hi = 0; - - if ((k0_idx % 2) == 0) { - src_x0_lo = (byte0 & 0x0F); - } else { - src_x0_lo = (byte0 >> 4); - } - - if ((k1_idx % 2) == 0) { - src_x0_hi = (byte1 & 0x0F); - } else { - src_x0_hi = (byte1 >> 4); - } - */ - const size_t shift_right_x0 = (k0_idx % 2) * 4; - - const uint8_t src_x0_lo = (byte0 >> shift_right_x0) & 0x0F; - const uint8_t src_x0_hi = (byte1 >> shift_right_x0) & 0x0F; - - partial_sum += (float)((int32_t)src_x0_lo + (int32_t)src_x0_hi - 2 * rhs_zero_point) * d; - - const uint8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4); - - dst_row[dst_byte_idx + block_byte_idx] = dst_qs0 ^ 0x88; - - k0_idx++; - k1_idx++; - } - sums[nr_idx] += partial_sum; - - // Increment the Kr block index - kr_block_idx++; - } - // Move 
the pointer after K values - dst_row += num_bytes_per_block * nr; - } - - // Move the pointer after the row sum - dst_row += kai_num_bytes_sum_rhs * nr; - - // Set the bias - if (bias == NULL) { - memset(dst_row, 0, nr * kai_num_bytes_bias); - } else { - for (size_t i = 0; i < nr; ++i) { - // Clamp the row index to avoid out-of-bound reads - const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); - ((float*)dst_row)[i] = bias[src_row_idx]; - } - } - } -} diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalebf16_qsu4c32s1s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalebf16_qsu4c32s1s0.c new file mode 100644 index 00000000..9c1ea8d5 --- /dev/null +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalebf16_qsu4c32s1s0.c @@ -0,0 +1,266 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#include "kai_rhs_pack_nxk_qsi4c32pscalebf16_qsu4c32s1s0.h" + +#include +#include +#include +#include +#include + +#include "kai/kai_common.h" + +static const size_t kai_num_bytes_sum_rhs = sizeof(float); +static const size_t kai_num_bytes_bias = sizeof(float); +static const size_t kai_num_bytes_scale = sizeof(uint16_t); +static const size_t kai_nr_multiple_of = 4; +static const size_t kai_bl_multiple_of = 32; + +inline static size_t kai_get_num_blocks_per_row(size_t k, size_t bl) { + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); + return kai_roundup(k, bl) / bl; +} + +inline static size_t kai_get_num_bytes_per_block(size_t bl, size_t num_bytes_multiplier_rhs) { + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); + return (bl / 2) + num_bytes_multiplier_rhs; +} + +inline static size_t kai_get_rhs_packed_offset_end_of_all_blocks( + size_t k, size_t nr, size_t kr, size_t bl, size_t num_bytes_multiplier_rhs) { + KAI_ASSERT((bl % kr) == 0); + KAI_ASSERT((nr % kai_nr_multiple_of) == 0); + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); + + const size_t num_blocks_per_row = 
kai_get_num_blocks_per_row(k, bl); + const size_t num_bytes_per_block = kai_get_num_bytes_per_block(bl, num_bytes_multiplier_rhs); + + return (nr * num_bytes_per_block * num_blocks_per_row); +} + +size_t kai_get_n_step_rhs_pack_nxk_qsi4c32pscalebf16_qsu4c32s1s0(size_t nr) { + return nr; +} + +size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4c32pscalebf16_qsu4c32s1s0( + size_t n_idx, // + size_t rhs_stride) { + return n_idx * rhs_stride; +} + +size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalebf16_qsu4c32s1s0( + size_t k, // + size_t nr, // + size_t kr, // + size_t sr, // + size_t bl) { + KAI_ASSERT((bl % kr) == 0); + KAI_ASSERT((nr % kai_nr_multiple_of) == 0); + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); + + KAI_UNUSED(kr); + KAI_UNUSED(sr); + + const size_t num_blocks_per_row = kai_get_num_blocks_per_row(k, bl); + const size_t num_bytes_per_block = kai_get_num_bytes_per_block(bl, kai_num_bytes_scale); + + return nr * ((num_bytes_per_block * num_blocks_per_row) + kai_num_bytes_sum_rhs + kai_num_bytes_bias); +} + +size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32pscalebf16_qsu4c32s1s0( + size_t n_idx, // + size_t k, // + size_t nr, // + size_t kr, // + size_t sr, // + size_t bl) { + KAI_ASSERT((n_idx % nr) == 0); + KAI_ASSERT((bl % kr) == 0); + KAI_ASSERT((nr % kai_nr_multiple_of) == 0); + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); + + return (n_idx / nr) * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalebf16_qsu4c32s1s0(k, nr, kr, sr, bl); +} + +size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalebf16_qsu4c32s1s0( + size_t n, // + size_t k, // + size_t nr, // + size_t kr, // + size_t sr, // + size_t bl) { + KAI_ASSERT((bl % kr) == 0); + KAI_ASSERT((nr % kai_nr_multiple_of) == 0); + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); + + const size_t num_rows = kai_roundup(n, nr) / nr; + + return num_rows * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalebf16_qsu4c32s1s0(k, nr, kr, sr, bl); +} + +void kai_run_rhs_pack_nxk_qsi4c32pscalebf16_qsu4c32s1s0( 
+ size_t num_groups, // + size_t n, // + size_t k, // + size_t nr, // + size_t kr, // + size_t sr, // + size_t bl, // + const uint8_t* rhs, // + size_t rhs_stride, // + const float* bias, // + const void* scale, // + size_t scale_stride, // + void* rhs_packed, // + size_t extra_bytes, // + const struct kai_rhs_pack_qsu4c32_params* params) { + KAI_ASSERT(num_groups == 1); + KAI_ASSERT(extra_bytes == 0); + KAI_ASSERT(rhs != NULL); + KAI_ASSERT(scale != NULL); + KAI_ASSERT(rhs_packed != NULL); + KAI_ASSERT(params != NULL); + KAI_ASSERT(params->lhs_zero_point == 1); + + KAI_ASSERT((nr % 4) == 0); + KAI_ASSERT((bl % kr) == 0); + KAI_ASSERT((kr % sr) == 0); + KAI_ASSERT((nr % kai_nr_multiple_of) == 0); + KAI_ASSERT((bl % kai_bl_multiple_of) == 0); + + // Note: The input matrix (rhs) is expected with: + // "k" columns and "n" rows (NxK) + + const size_t rhs_packed_stride = + kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalebf16_qsu4c32s1s0(k, nr, kr, sr, bl); + const size_t rhs_packed_offset_end_of_all_blocks = + kai_get_rhs_packed_offset_end_of_all_blocks(k, nr, kr, bl, kai_num_bytes_scale); + const size_t num_qblocks_per_row = kai_get_num_blocks_per_row(k, bl); + const size_t num_bytes_per_block = kai_get_num_bytes_per_block(bl, kai_num_bytes_scale); + const size_t num_bytes_per_block_k = bl / 2; + const size_t dst_num_rows = kai_roundup(n, nr); + const size_t block_length_in_bytes = kr / sr; + + uint8_t* dst_row = (uint8_t*)rhs_packed; + + for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; dst_row_idx += nr) { + float* sums = (float*)(dst_row + rhs_packed_offset_end_of_all_blocks); + + // Initialize the RHS reduction sums to zero + memset(sums, 0, nr * kai_num_bytes_sum_rhs); + + // Iterate over the quantized blocks + for (size_t dst_qblock_idx = 0; dst_qblock_idx < num_qblocks_per_row; ++dst_qblock_idx) { + // Store the scales after packing all K values in the block + uint8_t* rhs_packed_scale = dst_row + num_bytes_per_block_k * nr; + const uint8_t* scale_ptr 
= (const uint8_t*)scale + dst_qblock_idx * kai_num_bytes_scale; + + for (size_t i = 0; i < nr; ++i) { + const size_t src_row_idx = KAI_MIN(dst_row_idx + i, n - 1); + + const void* src_scales_ptr = scale_ptr + src_row_idx * scale_stride; + void* dst_scales_ptr = rhs_packed_scale + i * kai_num_bytes_scale; + + memcpy( + dst_scales_ptr, // + src_scales_ptr, // + kai_num_bytes_scale); // + } + + size_t k0_idx_i = dst_qblock_idx * bl; + + for (size_t dst_byte_idx = 0; dst_byte_idx < num_bytes_per_block_k; dst_byte_idx += 16) { + for (size_t segment_idx = 0; segment_idx < 16 / block_length_in_bytes; ++segment_idx) { + for (size_t nr_idx = 0; nr_idx < nr; ++nr_idx) { + const size_t n0_idx = dst_row_idx + nr_idx; + + size_t k0_idx = k0_idx_i; + size_t k1_idx = k0_idx_i + 16; + + // Clamp the index to avoid out-of-bound reads + const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1); + + const float d = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx]); + + int32_t partial_sum = 0; + + size_t src_addr_byte0 = (k0_idx / 2) + n0_valid_idx * rhs_stride; + + for (size_t block_byte_idx = 0; block_byte_idx < block_length_in_bytes; block_byte_idx += 2) { + // 136 == 0x88: the zero point (8) in both nibbles, used for K indices >= k + + uint8_t byte0 = 136; + uint8_t byte1 = 136; + uint8_t byte2 = 136; + uint8_t byte3 = 136; + + if (k0_idx < k) { + byte0 = rhs[src_addr_byte0]; + } + + if (k1_idx < k) { + byte1 = rhs[src_addr_byte0 + 8]; + } + + if (k0_idx + 1 < k) { + byte2 = byte0; + } + + if (k1_idx + 1 < k) { + byte3 = byte1; + } + + k0_idx += 2; + k1_idx += 2; + + const uint8_t src_x0_lo = byte0 & 0x0F; + const uint8_t src_x0_hi = byte1 & 0x0F; + const uint8_t src_x1_lo = (byte2 >> 4) & 0x0F; + const uint8_t src_x1_hi = (byte3 >> 4) & 0x0F; + + partial_sum += (int32_t)src_x0_lo; + partial_sum += (int32_t)src_x0_hi; + partial_sum += (int32_t)src_x1_lo; + partial_sum += (int32_t)src_x1_hi; + partial_sum -= 32; // 4 * zero_point (8) + + const uint16_t dst_q = + ((src_x0_lo)) | ((src_x0_hi) << 4) |
((src_x1_lo) << 8) | ((src_x1_hi) << 12); + + *((uint16_t*)dst_row) = dst_q ^ 0x8888; + + dst_row += 2; + src_addr_byte0 += 1; + } + + sums[nr_idx] += partial_sum * d; + } + + k0_idx_i += block_length_in_bytes; + } + k0_idx_i += 16; + } + // Move the pointer after scales + dst_row += kai_num_bytes_scale * nr; + } + + // Move the pointer after the row sum + dst_row += kai_num_bytes_sum_rhs * nr; + + // Set the bias + if (bias == NULL) { + memset(dst_row, 0, nr * kai_num_bytes_bias); + } else { + for (size_t i = 0; i < nr; ++i) { + // Clamp the row index to avoid out-of-bound reads + const size_t src_row_idx = KAI_MIN(dst_row_idx + i, n - 1); + ((float*)dst_row)[i] = bias[src_row_idx]; + } + } + // Move the pointer after the bias + dst_row += kai_num_bytes_bias * nr; + } +} diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalebf16_qsu4c32s1s0.h similarity index 64% rename from kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h rename to kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalebf16_qsu4c32s1s0.h index 2822740b..40bad4ad 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalebf16_qsu4c32s1s0.h @@ -13,12 +13,6 @@ extern "C" { #endif -struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params { - int8_t lhs_zero_point; - uint8_t rhs_zero_point; - enum kai_datatype scale_dt; -}; - /// Get the n step value. /// The micro-kernel can process any N values. However, the starting N index to /// be processed must be a multiple of n step.
@@ -26,7 +20,7 @@ struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params { /// @param[in] nr The number of columns written by the matmul micro-kernel /// /// @return the n step value -size_t kai_get_n_step_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(size_t nr); +size_t kai_get_n_step_rhs_pack_nxk_qsi4c32pscalebf16_qsu4c32s1s0(size_t nr); /// Gets the offset in bytes for the RHS matrix (not packed), which holds /// the int4 values in a N x K matrix, where N is number of rows and K is the number of columns. @@ -39,7 +33,7 @@ size_t kai_get_n_step_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(size_t nr); /// @param[in] rhs_stride The number of bytes in in each row of the RHS matrix (not packed) /// /// @return the offset in bytes to the RHS matrix (not packed) -size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( +size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4c32pscalebf16_qsu4c32s1s0( size_t n_idx, // size_t rhs_stride); // @@ -51,16 +45,14 @@ size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( /// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. /// @param[in] bl The block length, which defines the number of K values stored in a single block. It must be a /// multiple of 32. -/// @param[in] scale_dt Block scale data type /// /// @return the stride in bytes to the packed RHS matrix -size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( - size_t k, // - size_t nr, // - size_t kr, // - size_t sr, // - size_t bl, // - enum kai_datatype scale_dt); // +size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalebf16_qsu4c32s1s0( + size_t k, // + size_t nr, // + size_t kr, // + size_t sr, // + size_t bl); // /// Gets the offset in bytes for the packed RHS matrix. /// @@ -71,17 +63,15 @@ size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( /// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. /// @param[in] bl The block length, which defines the number of K values stored in a single block. 
It must be a /// multiple of 32. -/// @param[in] scale_dt Block scale data type /// /// @return the offset in bytes to the packed RHS matrix -size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( - size_t n_idx, // - size_t k, // - size_t nr, // - size_t kr, // - size_t sr, // - size_t bl, // - enum kai_datatype scale_dt); // +size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32pscalebf16_qsu4c32s1s0( + size_t n_idx, // + size_t k, // + size_t nr, // + size_t kr, // + size_t sr, // + size_t bl); // /// Gets the size in bytes for the quantized and packed RHS matrix. /// @@ -92,23 +82,22 @@ size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( /// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. /// @param[in] bl The block length, which defines the number of K values stored in a single block. It must be a multiple /// of 32. -/// @param[in] scale_dt Block scale data type /// /// @return the packed RHS matrix size in bytes -size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( - size_t n, // - size_t k, // - size_t nr, // - size_t kr, // - size_t sr, // - size_t bl, // - enum kai_datatype scale_dt); // +size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalebf16_qsu4c32s1s0( + size_t n, // + size_t k, // + size_t nr, // + size_t kr, // + size_t sr, // + size_t bl); // /// Runs the RHS packing micro-kernel. /// /// The int4 values are stored in a N x K matrix, where N is number of rows and K is the number of columns. /// Two int4 values are stored in one byte. The lower order part of the byte (low) holds /// the first nibble (K-index + 0). The higher order of the byte holds the second nibble (K-index + 1). +/// The scale factor is stored as Bfloat16 data type /// /// @param[in] num_groups The number of groups. It must be 1. /// @param[in] n The number of rows in the RHS matrix (not packed). 
@@ -124,28 +113,26 @@ size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( /// @param[in] rhs_stride The number of bytes per row in bytes of the RHS matrix /// @param[in] bias The biases. /// @param[in] scale The per-block quantization scales. -/// The scale data type must be provided with the params object. -/// Supported scale data types are FP32, FP16 and BF16. /// @param[in] scale_stride The number of bytes per row in bytes of the scale matrix /// @param[out] rhs_packed The packed RHS matrix. /// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix. /// @param[in] params Parameters for the micro-kernel. -void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( - size_t num_groups, // - size_t n, // - size_t k, // - size_t nr, // - size_t kr, // - size_t sr, // - size_t bl, // - const uint8_t* rhs, // - size_t rhs_stride, // - const float* bias, // - const void* scale, // - size_t scale_stride, // - void* rhs_packed, // - size_t extra_bytes, // - const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params* params); // +void kai_run_rhs_pack_nxk_qsi4c32pscalebf16_qsu4c32s1s0( + size_t num_groups, // + size_t n, // + size_t k, // + size_t nr, // + size_t kr, // + size_t sr, // + size_t bl, // + const uint8_t* rhs, // + size_t rhs_stride, // + const float* bias, // + const void* scale, // + size_t scale_stride, // + void* rhs_packed, // + size_t extra_bytes, // + const struct kai_rhs_pack_qsu4c32_params* params); // #ifdef __cplusplus } -- GitLab