From a8ebe2e39744b150285782bde763e1651938e6eb Mon Sep 17 00:00:00 2001
From: Anitha Raj
Date: Thu, 23 Jan 2025 14:13:39 +0000
Subject: [PATCH 1/4] Fix for Int4 per-channel SME GEMM kernel failing with n > 64

In kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa:

* Fix the offset calculation
* Fix pointer increments in the matmul

Add new shapes to unit tests, to test n > 64

Resolves: #KLEIDIAI-405, #COMPMID-7918

Signed-off-by: Anitha Raj
---
 ...8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.c | 21 ++++++++++-----------
 ..._qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot.c |  9 +++++----
 .../matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp |  4 +++-
 3 files changed, 18 insertions(+), 16 deletions(-)

diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.c
index 9dc417dd..cee8d311 100644
--- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.c
+++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.c
@@ -84,17 +84,17 @@ size_t kai_get_sr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(v
 size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(size_t m_idx, size_t k) {
     KAI_ASSERT((m_idx % kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa()) == 0);
 
-    const size_t k_internal = kai_k_roundedup(k);
+    const size_t mr = kai_get_mr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa();
 
-    return m_idx * (k_internal + kai_num_bytes_offset_lhs + kai_num_bytes_multiplier_lhs);
+    return (m_idx / mr) * kai_get_lhs_packed_stride(k);
 }
 
 size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(size_t n_idx, size_t k) {
     KAI_ASSERT((n_idx % kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa()) == 0);
 
-    const size_t k_internal = kai_k_roundedup(k);
+    const size_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa();
 
-    return n_idx * ((k_internal / 2) + kai_num_bytes_sum_rhs + kai_num_bytes_multiplier_rhs + kai_num_bytes_bias_rhs);
+    return (n_idx / nr) * kai_get_rhs_packed_stride(k);
 }
 
 size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(
@@ -123,8 +123,7 @@ void kai_run_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(
     uint64_t lhs_stride = kai_get_lhs_packed_stride(k);
     uint64_t rhs_stride = kai_get_rhs_packed_stride(k);
     uint64_t m_blk = (uint64_t)kai_k_roundedup(k) * mr;
-    uint64_t n_blk = (uint64_t)kai_k_roundedup(k) * nr;
-    uint64_t dst_inc = mr * n;
+    uint64_t dst_inc = mr * dst_stride_row;
     float scalar_bounds[2] = {scalar_min, scalar_max};
 
     /* ---------------------------------------------------
@@ -254,15 +253,15 @@ void kai_run_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(
 
         // N loop tail
         " add x8, x8, %[rhs_stride] \n"
-        " .inst 0x04295089 // ddvl x9, x9, #4 \n"
-        " addvl x13, x13, #-4 \n"
+        " .inst 0x04295089 // addvl x9, x9, #4 \n"
+        " sub x13, x13, %[nr] \n"
         " .inst 0x256d47f0 //whilelt pn8.h, xzr, x13, vlx2 \n"
         " b.mi 2b \n"
 
         // M loop tail
         " add x20, x20, %[lhs_stride] \n"
-        " add x19, x19, %[dst_inc], lsl #2 \n"
-        " addvl x12, x12, #-1 \n"
+        " add x19, x19, %[dst_inc] \n"
+        " sub x12, x12, %[mr] \n"
         " whilelt p0.s, xzr, x12 \n"
         " b.mi 1b \n"
 
@@ -270,7 +269,7 @@ void kai_run_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(
         " .inst 0xd503467f //smstop \n"
         :
        : [m] "r"(m), [n] "r"(n), [k] "r"(k), [lhs_stride] "r"(lhs_stride), [rhs_stride] "r"(rhs_stride),
-          [dst_stride_row] "r"(dst_stride_row), [lut] "r"(lut), [m_blk] "r"(m_blk), [n_blk] "r"(n_blk),
+          [dst_stride_row] "r"(dst_stride_row), [lut] "r"(lut), [m_blk] "r"(m_blk), [nr] "r"(nr), [mr] "r"(mr),
           [lhs] "r"(lhs_packed), [rhs] "r"(rhs_packed), [dst_inc] "r"(dst_inc), [scalar_bounds] "r"(scalar_bounds),
           [dst] "r"(dst)
         : "x7", "x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x19", "x20", "p0", "p2", "p8",
diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot.c
index 66e7fd91..b8e2f832 100644
--- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot.c
+++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot.c
@@ -28,17 +28,18 @@ static const size_t kai_num_bytes_multiplier_lhs = sizeof(float);
 static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);
 static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t);
 static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t);
-static const size_t kai_num_bytes_bias_rhs = sizeof(int32_t);
+static const size_t kai_num_bytes_bias_rhs = sizeof(float);
+static const size_t kai_k_multiple_of = 32;
 
 inline static size_t kai_k_roundedup(size_t k) {
     // Round up k to be a multiple of 32.
-    return kai_roundup(k, 32);
+    return kai_roundup(k, kai_k_multiple_of);
 }
 
 inline static size_t kai_get_lhs_packed_stride(size_t k) {
     const size_t k_internal = kai_k_roundedup(k);
 
-    KAI_ASSERT((k_internal % 32) == 0);
+    KAI_ASSERT((k_internal % kai_k_multiple_of) == 0);
 
     return kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot() *
         (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs);
@@ -47,7 +48,7 @@ inline static size_t kai_get_lhs_packed_stride(size_t k) {
 inline static size_t kai_get_rhs_packed_stride(size_t k) {
     const size_t k_internal = kai_k_roundedup(k);
 
-    KAI_ASSERT((k_internal % 32) == 0);
+    KAI_ASSERT((k_internal % kai_k_multiple_of) == 0);
 
     return kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot() *
         ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias_rhs);
diff --git a/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp b/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp
index ae0415c2..1b3ac089 100644
--- a/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp
+++ b/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp
@@ -788,7 +788,9 @@ INSTANTIATE_TEST_SUITE_P(
             MatMulShape{15, 35, 65},  //
             MatMulShape{8, 32, 64},   //
             MatMulShape{15, 31, 45},  //
-            MatMulShape{1, 35, 65}),
+            MatMulShape{1, 35, 65},   //
+            MatMulShape{1, 128, 32},  //
+            MatMulShape{64, 128, 32}),
         testing::Values(
             MatrixPortion(0, 0, 1, 1),     // Full matrix.
             MatrixPortion(0, 0, 1, 0.25),  // Leftmost portion.
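For illustration, the offset fix above boils down to one pattern: the packed LHS and RHS buffers are laid out as whole blocks of mr rows and nr columns, so the corrected helpers first convert a row/column index into a block index and then scale it by the packed block stride, instead of multiplying by a hand-computed per-row byte count. The standalone C sketch below only shows that arithmetic; the names (packed_offset, block_stride) and the numbers in main() are hypothetical and are not the kernel's real interface.

    #include <assert.h>
    #include <stddef.h>
    #include <stdio.h>

    /* Sketch only: a packed operand stored as consecutive blocks of
     * `block_size` rows (mr for the LHS) or columns (nr for the RHS),
     * each block occupying `block_stride` bytes. */
    static size_t packed_offset(size_t idx, size_t block_size, size_t block_stride) {
        assert(idx % block_size == 0);            /* callers step in whole blocks */
        return (idx / block_size) * block_stride; /* block index times bytes per block */
    }

    int main(void) {
        const size_t mr = 16;                 /* hypothetical rows per packed LHS block */
        const size_t lhs_block_stride = 1600; /* hypothetical bytes per packed LHS block */

        /* Row 32 starts two blocks into the packed buffer: 2 * 1600 bytes. */
        printf("offset of row 32: %zu bytes\n", packed_offset(32, mr, lhs_block_stride));
        return 0;
    }

Expressing the offset through the packed stride keeps it consistent with whatever layout the packing routine actually emits, which is what the replacement of the per-row formulas achieves.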
-- 
GitLab


From 995b360b3d9541302c6109ce9d128961b675c35d Mon Sep 17 00:00:00 2001
From: Anitha Raj
Date: Thu, 23 Jan 2025 17:39:28 +0000
Subject: [PATCH 2/4] Update changelog

Signed-off-by: Anitha Raj
---
 CHANGELOG.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index feb145f0..089cb4f4 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -16,6 +16,7 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo
   - Support compiling the project with the above compilation options enabled.
   - Remove `-Werror` from default build flags as to not cause integration problems
   - Expose the rhs_packed_stride in the header file
+  - Fix offsets and mr/nr increment operation in QAI8DXP x QSI4CXP (MxN) SME micro-kernel
 
 ## v1.2.0
 
-- 
GitLab


From c450097897bf4ea1efa9f8209fb83623ba11f21c Mon Sep 17 00:00:00 2001
From: Anitha Raj
Date: Fri, 24 Jan 2025 16:49:51 +0000
Subject: [PATCH 3/4] Address review comment

* Add more shapes to unit tests
* Update changelog with kernel name and failure condition

Signed-off-by: Anitha Raj
---
 CHANGELOG.md                                  |  1 +
 .../matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp | 18 ++++++++++--------
 2 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 089cb4f4..9b6225d9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -17,6 +17,7 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo
   - Remove `-Werror` from default build flags as to not cause integration problems
   - Expose the rhs_packed_stride in the header file
   - Fix offsets and mr/nr increment operation in QAI8DXP x QSI4CXP (MxN) SME micro-kernel
+  - Fixes validation error when n > nr in kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa
 
 ## v1.2.0
 
diff --git a/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp b/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp
index 1b3ac089..c0018dd6 100644
--- a/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp
+++ b/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp
@@ -783,14 +783,16 @@ INSTANTIATE_TEST_SUITE_P(
     testing::Combine(
         testing::Range(0, variants_kai_matmul_clamp_f32_qai8dxp_qsi4cxp.size()),
         testing::Values(
-            MatMulShape{16, 32, 64},  //
-            MatMulShape{16, 32, 36},  //
-            MatMulShape{15, 35, 65},  //
-            MatMulShape{8, 32, 64},   //
-            MatMulShape{15, 31, 45},  //
-            MatMulShape{1, 35, 65},   //
-            MatMulShape{1, 128, 32},  //
-            MatMulShape{64, 128, 32}),
+            MatMulShape{16, 32, 64},   //
+            MatMulShape{16, 32, 36},   //
+            MatMulShape{15, 35, 65},   //
+            MatMulShape{8, 32, 64},    //
+            MatMulShape{15, 31, 45},   //
+            MatMulShape{1, 35, 65},    //
+            MatMulShape{1, 128, 32},   //
+            MatMulShape{64, 128, 32},  //
+            MatMulShape{1, 225, 55},   //
+            MatMulShape{125, 200, 56}),
         testing::Values(
             MatrixPortion(0, 0, 1, 1),     // Full matrix.
             MatrixPortion(0, 0, 1, 0.25),  // Leftmost portion.
-- 
GitLab


From 787f935eee12d2a05b010511d5beb5cb96b4a571 Mon Sep 17 00:00:00 2001
From: Anitha Raj
Date: Mon, 27 Jan 2025 13:31:26 +0000
Subject: [PATCH 4/4] Update changelog

Signed-off-by: Anitha Raj
---
 CHANGELOG.md | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9b6225d9..4bd9fbaf 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -16,8 +16,7 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo
   - Support compiling the project with the above compilation options enabled.
   - Remove `-Werror` from default build flags as to not cause integration problems
   - Expose the rhs_packed_stride in the header file
-  - Fix offsets and mr/nr increment operation in QAI8DXP x QSI4CXP (MxN) SME micro-kernel
-  - Fixes validation error when n > nr in kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa
+  - Fix validation error when n > nr in kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa
 
 ## v1.2.0
 
-- 
GitLab
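The pointer-increment side of the fix in PATCH 1/4 is easier to read as a plain C skeleton of the kind of blocked loop the assembly's M and N tails manage: the remaining-row and remaining-column counters are decremented by mr and nr elements per block (the new sub x12/x13 by %[mr]/%[nr] feeding whilelt), and the destination pointer advances by dst_inc = mr * dst_stride_row bytes per block of output rows. The sketch below only mirrors that control flow under those assumptions; the tile computation, predication, and all names are simplified and hypothetical, not the SME2 kernel itself.

    #include <stddef.h>
    #include <stdint.h>

    /* Sketch only: control flow of a blocked matmul over packed operands.
     * Strides are in bytes; the m_rem/n_rem counters are in elements. */
    void blocked_matmul_sketch(
        size_t m, size_t n, size_t mr, size_t nr,
        size_t lhs_stride, size_t rhs_stride, size_t dst_stride_row,
        const uint8_t* lhs_packed, const uint8_t* rhs_packed, uint8_t* dst) {
        const uint8_t* lhs_ptr = lhs_packed;
        uint8_t* dst_row = dst;

        for (size_t m_rem = m; m_rem > 0; m_rem -= (m_rem < mr) ? m_rem : mr) {
            const uint8_t* rhs_ptr = rhs_packed;
            uint8_t* dst_tile = dst_row;

            for (size_t n_rem = n; n_rem > 0; n_rem -= (n_rem < nr) ? n_rem : nr) {
                /* ...compute one (<= mr) x (<= nr) output tile from lhs_ptr and
                 * rhs_ptr, clamp it, and store it at dst_tile... */
                rhs_ptr += rhs_stride;          /* next packed RHS block */
                dst_tile += nr * sizeof(float); /* next nr output columns */
            }

            lhs_ptr += lhs_stride;          /* next packed LHS block */
            dst_row += mr * dst_stride_row; /* dst_inc: mr output rows, in bytes */
        }
    }

Keeping the counters in elements and the destination increment in bytes of dst_stride_row is what lets the kernel handle n greater than one nr-wide block, which is the failure the new test shapes (n = 128, 200, 225) exercise.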