diff --git a/CHANGELOG.md b/CHANGELOG.md index 1247ef29981e6b313615168eca23bec690e20b09..6795beff0f3486b20f974d5febce8ab9975de649 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,9 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo ## Upcoming Release +- Breaking changes: + - Add a boolean flag to `kai_rhs_pack_qs4cxs1s0_param` to identify nibble pairing order in the int4 input. + ## v1.6.0 - Add CMake installation and `find_package()` support. diff --git a/examples/matmul_clamp_f32_qai8dxp_qsi4cxp/matmul_clamp_f32_qai8dxp_qsi4cxp.cpp b/examples/matmul_clamp_f32_qai8dxp_qsi4cxp/matmul_clamp_f32_qai8dxp_qsi4cxp.cpp index 53e64664a602565caa8d10643043a74f72d24416..1589154c83749c482731d2645dbb2c2f721bad8d 100644 --- a/examples/matmul_clamp_f32_qai8dxp_qsi4cxp/matmul_clamp_f32_qai8dxp_qsi4cxp.cpp +++ b/examples/matmul_clamp_f32_qai8dxp_qsi4cxp/matmul_clamp_f32_qai8dxp_qsi4cxp.cpp @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -637,6 +637,7 @@ int main(int argc, char** argv) { nxk_params.lhs_zero_point = 1; nxk_params.rhs_zero_point = 8; + nxk_params.is_nibble_order_reversed = false; // RHS packing kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0( 1, n, k, nr, kr, sr, // Packing arguments @@ -649,6 +650,7 @@ int main(int argc, char** argv) { struct kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0_params kxn_params; kxn_params.lhs_zero_point = 1; kxn_params.rhs_zero_point = 8; + kxn_params.is_nibble_order_reversed = false; // RHS packing kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0( 1, n, k, nr, kr, sr, // Packing arguments diff --git a/examples/matmul_clamp_f32_qsi8d32p_qsi4c32p/matmul_clamp_f32_qsi8d32p_qsi4c32p.cpp b/examples/matmul_clamp_f32_qsi8d32p_qsi4c32p/matmul_clamp_f32_qsi8d32p_qsi4c32p.cpp index 
6d992b625e4387f7d5e7230261764ce11119e545..d30231866495c9984983811860b2a1ee48f23d95 100644 --- a/examples/matmul_clamp_f32_qsi8d32p_qsi4c32p/matmul_clamp_f32_qsi8d32p_qsi4c32p.cpp +++ b/examples/matmul_clamp_f32_qsi8d32p_qsi4c32p/matmul_clamp_f32_qsi8d32p_qsi4c32p.cpp @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -364,6 +364,7 @@ int main(int argc, char** argv) { struct kai_rhs_pack_qs4cxs1s0_param params; params.lhs_zero_point = 1; params.rhs_zero_point = 8; + params.is_nibble_order_reversed = false; // RHS packing kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0( diff --git a/kai/kai_common.h b/kai/kai_common.h index c1cb1eca10d750308eb09945d4d71301a2977bb5..67cd5cda7fa4a0ad184aaffa7097ef10800bd9ea 100644 --- a/kai/kai_common.h +++ b/kai/kai_common.h @@ -9,6 +9,7 @@ #include #endif // defined(__ARM_NEON) +#include #include #include #include @@ -183,8 +184,10 @@ struct kai_rhs_pack_qsi8cx_params { /// Parameter struct for RHS matrix packing struct kai_rhs_pack_qs4cxs1s0_param { - int8_t lhs_zero_point; ///< LHS Matrix quantization zero-point - uint8_t rhs_zero_point; ///< RHS Matrix quantization zero-point + int8_t lhs_zero_point; ///< LHS Matrix quantization zero-point + uint8_t rhs_zero_point; ///< RHS Matrix quantization zero-point + bool is_nibble_order_reversed; ///< Default order of int4 in a byte is specified in the kernel name. If this flag is + ///< true, the nibble pairing is reversed in the input argument }; /// Requantization and clamp parameters for GEMM/GEMV output stage. 
diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0.c index 7c75764ecd35cded15ed9ceb81d1f18fb0b344da..6e9116df9fe2755f5fcdfdfb8a7d0b217f988dee 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0.c @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -68,6 +68,7 @@ void kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0( KAI_ASSERT(params != NULL); KAI_ASSERT(params->lhs_zero_point == 1); KAI_ASSERT(params->rhs_zero_point == 0 || params->rhs_zero_point == 8); + KAI_ASSERT(params->is_nibble_order_reversed == false); const uint8_t rhs_zero_point = params->rhs_zero_point; const size_t rhs_packed_stride = kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4cxp_qs4cxs1s0(k, nr, kr, sr); diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c index 419f5cd0e62812e1d5000b3defac7d6507302ba0..f835bd6881724e82e2d02e35419f8edad116263a 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -115,6 +115,7 @@ void kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon( KAI_ASSUME(params != NULL); KAI_ASSUME(params->rhs_zero_point == 8); KAI_ASSUME(params->lhs_zero_point == 1); + KAI_ASSUME(params->is_nibble_order_reversed == false); // Note: The input matrix (rhs) is expected with: // 
"k" columns and "n" rows (NxK) diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c index 34209acf18940619a05bcc6eaace897b4600199b..cb99d41672babc547c1a5e0f7ecafc2a06b78312 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c @@ -93,6 +93,7 @@ void kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0( KAI_ASSUME(params != NULL); KAI_ASSUME(params->rhs_zero_point == 8); KAI_ASSUME(params->lhs_zero_point == 1); + KAI_ASSUME(params->is_nibble_order_reversed == false); // Note: The input matrix (rhs) is expected with: // "k" columns and "n" rows (NxK) diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.c index 898fa442b69a71847ffd42319b089ba972b67ca3..a2624d7dd3b0ac1c1242188c30146000348e4da5 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.c @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -65,6 +65,7 @@ void kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0( KAI_ASSERT(params != NULL); KAI_ASSERT(params->lhs_zero_point == 1); KAI_ASSERT(params->rhs_zero_point == 0 || params->rhs_zero_point == 8); + KAI_ASSERT(params->is_nibble_order_reversed == false); const uint8_t rhs_zero_point = params->rhs_zero_point; const size_t rhs_packed_stride = kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(k, nr, kr, sr); diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon.c index 
d0c66276369fbe141c8bd5c185563c3fb16afe89..e2e2426df150dc64ff733735001cbca09e8127dd 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon.c @@ -74,6 +74,7 @@ void kai_run_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon( KAI_ASSERT(params != NULL); KAI_ASSERT(params->lhs_zero_point == 1); KAI_ASSERT(params->rhs_zero_point == 0 || params->rhs_zero_point == 8); + KAI_ASSERT(params->is_nibble_order_reversed == false); // Note: The input matrix (rhs) is expected with: // "k" columns and "n" rows (NxK) diff --git a/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp b/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp index eaaa6b06acebd7e0382fdf75bc8772b3641a0d87..6a9fefb2d2d50c9ea38795c491be9667780725fa 100644 --- a/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp +++ b/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp @@ -355,7 +355,8 @@ TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_nxk_qsi4cx) { size_t scale_offset = rhs_start_row * sizeof(float); std::vector imp_packed_rhs(imp_packed_rhs_size); - const kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params params{.lhs_zero_point = 1, .rhs_zero_point = 0}; + const kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params params{ + .lhs_zero_point = 1, .rhs_zero_point = 0, .is_nibble_order_reversed = false}; ukernel_variant.run_rhs_pack( 1, rect.width() /* n */, K, nr, kr, sr, ref_rhs_qsi4_padded.data() + rhs_offset, reinterpret_cast(ref_biases.data() + bias_offset), @@ -480,7 +481,8 @@ TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_nxk_qsu4cx) { size_t scale_offset = rhs_start_row * sizeof(float); std::vector imp_packed_rhs(imp_packed_rhs_size); - const kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params params{.lhs_zero_point = 1, .rhs_zero_point = 8}; + const kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params params{ + .lhs_zero_point = 1, .rhs_zero_point = 8, .is_nibble_order_reversed = false}; ukernel_variant.run_rhs_pack( 1, rect.width() /* 
n */, K, nr, kr, sr, ref_rhs_qsu4_padded.data() + rhs_offset, reinterpret_cast(ref_biases.data() + bias_offset), @@ -612,7 +614,8 @@ TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_kxn_qsi4cx) { ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset); std::vector imp_packed_rhs(imp_packed_rhs_size); - const kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0_params params{.lhs_zero_point = 1, .rhs_zero_point = 0}; + const kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0_params params{ + .lhs_zero_point = 1, .rhs_zero_point = 0, .is_nibble_order_reversed = false}; ukernel_variant.run_rhs_pack( 1, N, K, nr, kr, sr, ref_rhs_qsi4_padded.data(), reinterpret_cast(ref_biases.data()), reinterpret_cast(ref_rhs_scales.data()), imp_packed_rhs.data(), 0, &params); @@ -743,7 +746,8 @@ TEST_P(MatMulTest_f32_qai8dxp_qsi4cxp, EndToEnd_RHS_kxn_qsu4cx) { ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset); std::vector imp_packed_rhs(imp_packed_rhs_size); - const kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0_params params{.lhs_zero_point = 1, .rhs_zero_point = 8}; + const kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0_params params{ + .lhs_zero_point = 1, .rhs_zero_point = 8, .is_nibble_order_reversed = false}; ukernel_variant.run_rhs_pack( 1, N, K, nr, kr, sr, ref_rhs_qsu4_padded.data(), reinterpret_cast(ref_biases.data()), reinterpret_cast(ref_rhs_scales.data()), imp_packed_rhs.data(), 0, &params); diff --git a/test/tests/matmul_clamp_f32_qsi8d32p_qsi4c32p_test.cpp b/test/tests/matmul_clamp_f32_qsi8d32p_qsi4c32p_test.cpp index 153decb53e1c839e68eafbadf67c60dcf9c387d5..35c753040658ce9b30b1897a332be852464795ca 100644 --- a/test/tests/matmul_clamp_f32_qsi8d32p_qsi4c32p_test.cpp +++ b/test/tests/matmul_clamp_f32_qsi8d32p_qsi4c32p_test.cpp @@ -235,7 +235,8 @@ TEST_P(MatMulTest_f32_qsi8d32p_qsi4c32p, EndToEnd) { auto rhs_matmul_offset = ukernel_variant.ukernel.interface.get_rhs_packed_offset(rhs_start_row, K, bl); ASSERT_EQ(rhs_packed_offset, rhs_matmul_offset); - const kai_rhs_pack_qs4cxs1s0_param params{.lhs_zero_point = 1, .rhs_zero_point = 8}; + const 
kai_rhs_pack_qs4cxs1s0_param params{ + .lhs_zero_point = 1, .rhs_zero_point = 8, .is_nibble_order_reversed = false}; ukernel_variant.pack_interface.rhs_pack( 1, N, K, nr, kr, sr, bl, ref_rhs_qsu4_scale_f16.data(), nullptr, imp_packed_rhs.data(), 0, &params);