diff --git a/kai/kai_common.h b/kai/kai_common.h index cbd03a95c6bf16e8ca9183ef04c8979f86caed59..63d02f3a9e279747af5e1d7f989ed41dbdb5e23b 100644 --- a/kai/kai_common.h +++ b/kai/kai_common.h @@ -166,19 +166,14 @@ inline static int8_t kai_ext_sign_i8_i4(int8_t value) { /// Parameter struct for RHS matrix packing (Quantized Symmetric Integer 8-bit with per-channel quantization) struct kai_rhs_pack_qsi8cx_params { - int8_t lhs_zero_point; /**< LHS Matrix quantization zero-point */ + int32_t lhs_zero_point; ///< LHS Matrix quantization zero-point + float scale_multiplier; ///< Product of input (refers to lhs and rhs) and output quantization scales. }; /// Parameter struct for RHS matrix packing struct kai_rhs_pack_qs4cxs1s0_param { - int8_t lhs_zero_point; /**< LHS Matrix quantization zero-point */ - uint8_t rhs_zero_point; /**< RHS Matrix quantization zero-point */ -}; - -/// RHS packing parameter for 8-bit quantization. -struct kai_rhs_pack_qsi8_params { - int32_t lhs_zero_point; ///< LHS quantization zero point. - float scale_multiplier; ///< Product of input (refers to lhs and rhs) and output quantization scales. + int8_t lhs_zero_point; ///< LHS Matrix quantization zero-point + uint8_t rhs_zero_point; ///< RHS Matrix quantization zero-point }; /// Requantization and clamp parameters for GEMM/GEMV output stage. diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c index c50ea6250bc5bf83759e08e569f504ff3eecdcaf..3241c578316f5504cd8544ad91d7ad0a07df59b2 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.c @@ -61,7 +61,7 @@ size_t kai_get_rhs_packed_size_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(siz void kai_run_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, - const struct kai_rhs_pack_qsi8_params* params) { + const struct kai_rhs_pack_qsi8cx_params* params) { KAI_ASSUME(num_groups == 1); KAI_ASSUME(nr == kai_nr * kai_get_sme_vector_length_u8() / kai_kr); KAI_ASSUME(kr == kai_kr); diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h index effb8da9b9926b582f8e7a864cd195d8714dd5ac..2e97d72317a52b8b3500d30ea7536004f86abec0 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h @@ -84,7 +84,7 @@ size_t kai_get_rhs_packed_size_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(siz void kai_run_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme( size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, - const struct kai_rhs_pack_qsi8_params* params); + const struct kai_rhs_pack_qsi8cx_params* params); #ifdef __cplusplus } // extern "C" diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon.c b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon.c index bedf69d12d76ed7e5784b69301749d8390a62e6a..02604e89fe92a88098fe1ed959dd0787465f044d 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon.c @@ -71,7 +71,7 @@ void kai_run_rhs_pack_kxn_qsi8cxp_qsi8cx_neon( KAI_ASSERT(rhs_packed != NULL); KAI_ASSERT(params != NULL); - const int32_t lhs_zero_point = (int32_t)params->lhs_zero_point; + const int32_t lhs_zero_point = params->lhs_zero_point; const size_t rhs_packed_stride = kai_get_rhs_packed_stride_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(k, nr, kr, sr); const size_t k_internal = kai_k_roundedup(k); const size_t dst_num_rows = kai_roundup(n, nr) / nr; diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.c index a8f9a06e43ae51e1e56dda4855fec81636d2a2aa..8a73e03bf1ddc7f40e8f810c9ec3c016abf65474 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.c @@ -70,7 +70,7 @@ void kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon( KAI_ASSERT(rhs_packed != NULL); KAI_ASSERT(params != NULL); - const int32_t lhs_zero_point = (int32_t)params->lhs_zero_point; + const int32_t lhs_zero_point = params->lhs_zero_point; const size_t rhs_packed_stride = kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon(k, nr, kr, sr); const size_t k_internal = kai_k_roundedup(k); const size_t dst_num_rows = kai_roundup(n, nr) / nr; diff --git a/test/tests/matmul_clamp_f32_qai8dxp_qsi8cxp_test.cpp b/test/tests/matmul_clamp_f32_qai8dxp_qsi8cxp_test.cpp index db3c530f386daf5640a448edfc12d2a8188ad86d..52b94f7438294a25e866bbb0fe6ecc15d99a1094 100644 --- a/test/tests/matmul_clamp_f32_qai8dxp_qsi8cxp_test.cpp +++ b/test/tests/matmul_clamp_f32_qai8dxp_qsi8cxp_test.cpp @@ -94,7 +94,7 @@ TEST_P(MatMulTest_f32_qai8dxp_qsi8cxp, EndToEnd_RHS_nxk_qsi8cx) { const auto imp_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon(N, K, nr, kr, sr); std::vector imp_packed_rhs(imp_packed_rhs_size); - const kai_rhs_pack_qsi8cx_params params{.lhs_zero_point = 1}; + const kai_rhs_pack_qsi8cx_params params{.lhs_zero_point = 1, .scale_multiplier = 1.0f}; kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon( 1, N, K, nr, kr, sr, reinterpret_cast(ref_rhs_qsi8.data()), reinterpret_cast(ref_biases.data()), reinterpret_cast(ref_rhs_scales.data()), @@ -183,7 +183,7 @@ TEST_P(MatMulTest_f32_qai8dxp_qsi8cxp, EndToEnd_RHS_kxn_qsi8cx) { const auto imp_packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon(N, K, nr, kr, sr); std::vector imp_packed_rhs(imp_packed_rhs_size); - const kai_rhs_pack_qsi8cx_params params{.lhs_zero_point = 1}; + const kai_rhs_pack_qsi8cx_params params{.lhs_zero_point = 1, .scale_multiplier = 1.0f}; kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon( 1, N, K, nr, kr, sr, reinterpret_cast(ref_rhs_qsi8.data()), reinterpret_cast(ref_biases.data()), reinterpret_cast(ref_rhs_scales.data()), diff --git a/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp b/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp index 985e81669ca4dab4cee049721b51952f12a50f15..e8011f4f3480a4bb7f08a4aeab3cb6f1124e3b43 100644 --- a/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp +++ b/test/tests/matmul_clamp_qai8_qai8p_qsi8cxp_test.cpp @@ -58,7 +58,7 @@ struct GemmVariant { void (*fn_pack_rhs_run)( size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, - const struct kai_rhs_pack_qsi8_params* params); + const struct kai_rhs_pack_qsi8cx_params* params); size_t (*fn_main_get_m_step)(); size_t (*fn_main_get_n_step)(); @@ -295,7 +295,7 @@ void run_test(const GemmShape& shape, const GemmVariant& variant, const MatrixPo const auto imp_scale_offset = variant.fn_pack_rhs_get_scale_offset(output_area.start_col()); const auto imp_packed_rhs_offset = variant.fn_pack_rhs_get_packed_rhs_offset(output_area.start_col(), shape.k); - const kai_rhs_pack_qsi8_params imp_pack_rhs_params{ + const kai_rhs_pack_qsi8cx_params imp_pack_rhs_params{ .lhs_zero_point = lhs_zero_point, .scale_multiplier = lhs_scale / dst_scale, };