diff --git a/CHANGELOG.md b/CHANGELOG.md index feb145f0a90df408e4ded633ed7ec0aa0b066eee..4bd9fbaf8fcd544eda295f9204d83bb52835e4ee 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo - Support compiling the project with the above compilation options enabled. - Remove `-Werror` from default build flags as to not cause integration problems - Expose the rhs_packed_stride in the header file + - Fix validation error when n > nr in kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa ## v1.2.0 diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.c index 9dc417dd41efe5644e4b3db06ac54215ae3cae30..cee8d3119ddf110316d0bf5dfa74e86b3b36324d 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.c @@ -84,17 +84,17 @@ size_t kai_get_sr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(v size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(size_t m_idx, size_t k) { KAI_ASSERT((m_idx % kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa()) == 0); - const size_t k_internal = kai_k_roundedup(k); + const size_t mr = kai_get_mr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(); - return m_idx * (k_internal + kai_num_bytes_offset_lhs + kai_num_bytes_multiplier_lhs); + return (m_idx / mr) * kai_get_lhs_packed_stride(k); } size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(size_t n_idx, size_t k) { KAI_ASSERT((n_idx % kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa()) == 0); - const size_t k_internal = kai_k_roundedup(k); + const size_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(); - return n_idx * ((k_internal / 2) + kai_num_bytes_sum_rhs + kai_num_bytes_multiplier_rhs + kai_num_bytes_bias_rhs); + return (n_idx / nr) * kai_get_rhs_packed_stride(k); } size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa( @@ -123,8 +123,7 @@ void kai_run_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa( uint64_t lhs_stride = kai_get_lhs_packed_stride(k); uint64_t rhs_stride = kai_get_rhs_packed_stride(k); uint64_t m_blk = (uint64_t)kai_k_roundedup(k) * mr; - uint64_t n_blk = (uint64_t)kai_k_roundedup(k) * nr; - uint64_t dst_inc = mr * n; + uint64_t dst_inc = mr * dst_stride_row; float scalar_bounds[2] = {scalar_min, scalar_max}; /* --------------------------------------------------- @@ -254,15 +253,15 @@ void kai_run_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa( // N loop tail " add x8, x8, %[rhs_stride] \n" - " .inst 0x04295089 // ddvl x9, x9, #4 \n" - " addvl x13, x13, #-4 \n" + " .inst 0x04295089 // addvl x9, x9, #4 \n" + " sub x13, x13, %[nr] \n" " .inst 0x256d47f0 //whilelt pn8.h, xzr, x13, vlx2 \n" " b.mi 2b \n" // M loop tail " add x20, x20, %[lhs_stride] \n" - " add x19, x19, %[dst_inc], lsl #2 \n" - " addvl x12, x12, #-1 \n" + " add x19, x19, %[dst_inc] \n" + " sub x12, x12, %[mr] \n" " whilelt p0.s, xzr, x12 \n" " b.mi 1b \n" @@ -270,7 +269,7 @@ void kai_run_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa( " .inst 0xd503467f //smstop \n" : : [m] "r"(m), [n] "r"(n), [k] "r"(k), [lhs_stride] "r"(lhs_stride), [rhs_stride] "r"(rhs_stride), - [dst_stride_row] "r"(dst_stride_row), [lut] "r"(lut), [m_blk] "r"(m_blk), [n_blk] "r"(n_blk), + [dst_stride_row] "r"(dst_stride_row), [lut] "r"(lut), [m_blk] "r"(m_blk), [nr] "r"(nr), [mr] "r"(mr), [lhs] "r"(lhs_packed), [rhs] "r"(rhs_packed), [dst_inc] "r"(dst_inc), [scalar_bounds] "r"(scalar_bounds), [dst] "r"(dst) : "x7", "x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x19", "x20", "p0", "p2", "p8", diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot.c b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot.c index 66e7fd91eea56ef3db8596e303dd8f942b12c273..b8e2f832a6e19f99910fd6996d6c9365a7e37b98 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot.c @@ -28,17 +28,18 @@ static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); -static const size_t kai_num_bytes_bias_rhs = sizeof(int32_t); +static const size_t kai_num_bytes_bias_rhs = sizeof(float); +static const size_t kai_k_multiple_of = 32; inline static size_t kai_k_roundedup(size_t k) { // Round up k to be a multiple of 32. - return kai_roundup(k, 32); + return kai_roundup(k, kai_k_multiple_of); } inline static size_t kai_get_lhs_packed_stride(size_t k) { const size_t k_internal = kai_k_roundedup(k); - KAI_ASSERT((k_internal % 32) == 0); + KAI_ASSERT((k_internal % kai_k_multiple_of) == 0); return kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot() * (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); @@ -47,7 +48,7 @@ inline static size_t kai_get_lhs_packed_stride(size_t k) { inline static size_t kai_get_rhs_packed_stride(size_t k) { const size_t k_internal = kai_k_roundedup(k); - KAI_ASSERT((k_internal % 32) == 0); + KAI_ASSERT((k_internal % kai_k_multiple_of) == 0); return kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot() * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias_rhs); diff --git a/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp b/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp index ae0415c24ff7d61370205f30e07c758a8f25aa12..c0018dd600342fc7f25cac1a42ca72e99d19d96c 100644 --- a/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp +++ b/test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp @@ -783,12 +783,16 @@ INSTANTIATE_TEST_SUITE_P( testing::Combine( testing::Range(0, variants_kai_matmul_clamp_f32_qai8dxp_qsi4cxp.size()), testing::Values( - MatMulShape{16, 32, 64}, // - MatMulShape{16, 32, 36}, // - MatMulShape{15, 35, 65}, // - MatMulShape{8, 32, 64}, // - MatMulShape{15, 31, 45}, // - MatMulShape{1, 35, 65}), + MatMulShape{16, 32, 64}, // + MatMulShape{16, 32, 36}, // + MatMulShape{15, 35, 65}, // + MatMulShape{8, 32, 64}, // + MatMulShape{15, 31, 45}, // + MatMulShape{1, 35, 65}, // + MatMulShape{1, 128, 32}, // + MatMulShape{64, 128, 32}, // + MatMulShape{1, 225, 55}, // + MatMulShape{125, 200, 56}), testing::Values( MatrixPortion(0, 0, 1, 1), // Full matrix. MatrixPortion(0, 0, 1, 0.25), // Leftmost portion.