diff --git a/CHANGELOG.md b/CHANGELOG.md
index c662ed3cf83c398a3de3ba5aaf7d4e920e8d4d8e..80be6e69f082c6082d6e5b941893a7c7db349a00 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -16,6 +16,7 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo
 - Optimizations for FEAT_I8MM.
 - Fixes:
   - Remove "-Weffc++" from build flags
+  - Fix out-of-bound read from LHS packed matrix in `kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa`.
 
 ## v1.4.0
 
diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c
index 66db8dc7c83d2ccf23d455b705824149104d0446..acfa64542915124d1e8e721f342111bab16fc6b4 100644
--- a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c
+++ b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c
@@ -151,16 +151,26 @@ void kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(
         return;
     }
 
-    const size_t lhs_packed_stride = kai_get_lhs_packed_stride(k, bl);
-    const size_t rhs_packed_stride = kai_get_rhs_packed_stride(k, bl);
+    typedef struct {
+        size_t lhs_packed_stride;
+        size_t rhs_packed_stride;
+        size_t mr;
+    } KernelArgs;
+
+    KernelArgs ka;
+
     const size_t num_blocks = kai_get_num_blocks_per_row(k, bl);
     const size_t mr = kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa();
     const size_t nr = kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa();
 
-    const uint16_t* lhs_scales = (const uint16_t*)((const int8_t*)lhs_packed + lhs_packed_stride -
+    ka.mr = mr;
+    ka.lhs_packed_stride = kai_get_lhs_packed_stride(k, bl);
+    ka.rhs_packed_stride = kai_get_rhs_packed_stride(k, bl);
+
+    const uint16_t* lhs_scales = (const uint16_t*)((const int8_t*)lhs_packed + ka.lhs_packed_stride -
                                                    (mr * num_blocks) * kai_num_bytes_multiplier_lhs);
-    const uint16_t* rhs_scales = (const uint16_t*)((const uint8_t*)rhs_packed + rhs_packed_stride -
+    const uint16_t* rhs_scales = (const uint16_t*)((const uint8_t*)rhs_packed + ka.rhs_packed_stride -
                                                    (nr * num_blocks) * kai_num_bytes_multiplier_rhs);
 
     __asm__ volatile(
 
@@ -174,6 +184,11 @@ void kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(
         " ptrue p0.b, all \n"
         " .inst 0x25a07810 // ptrue pn8.s \n"
 
+        // Predicate for loading fp16 scaling factors
+        " ldr x5, [%x[args_ptr], %[offset_mr]]\n"
+        " lsl x5, x5, #1 \n"
+        " whilelt p4.b, xzr, x5 \n"
+
         // Initialize ZT0 (Lookup table)
         " mov x6, %[lut]\n"
         " .inst 0xe11f80c0 // ldr zt0, [x6] \n"
@@ -268,8 +283,8 @@ void kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(
         // Copy destination pointer for store loop
         " mov x25, x24 \n"
 
-        // Load the fp16 scaling factors for the right matrix block
-        " ld1b {z16.b}, p0/z, [x23, x21] \n"
+        // Load the fp16 scaling factors for the left matrix block
+        " ld1b {z16.b}, p4/z, [x23, x21] \n"
         " inch x21, all \n"
 
         // Predicate for the selection of a scaling among the vector
@@ -337,11 +352,13 @@ void kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(
 
         // === End of the K loop ===
 
+        " ldr x5, [%x[args_ptr], %[offset_stride_l]] \n"
+
         // Increment pointer to the quantized values of the right matrix
-        " add x22, x22, %[stride_l] \n"
+        " add x22, x22, x5\n"
 
         // Increment pointer to the scaling factors of the right matrix
-        " add x23, x23, %[stride_l] \n"
+        " add x23, x23, x5 \n"
 
         // Update destination pointer
         " mov x24, x25 \n"
@@ -357,8 +374,10 @@ void kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(
         // Increment output pointer
         " incb %[dst], all, mul #4 \n"
 
-        " add x16, x16, %[stride_r] \n"
-        " add x17, x17, %[stride_r] \n"
+        " ldr x5, [%x[args_ptr], %[offset_stride_r]]\n"
+
+        " add x16, x16, x5 \n"
+        " add x17, x17, x5 \n"
 
         // Increment N loop index
         " incb x8, all \n"
@@ -375,11 +394,12 @@ void kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(
         " .inst 0xd503467f // smstop \n"
         : [dst] "+r"(dst), [rhs_packed] "+r"(rhs_packed), [rhs_scales] "+r"(rhs_scales)
         : [M] "r"(m), [N] "r"(n), [K] "r"(k), [lhs_packed] "r"(lhs_packed), [lhs_scales] "r"(lhs_scales),
-          [stride] "r"(dst_stride_row), [lut] "r"(lut), [stride_l] "r"(lhs_packed_stride),
-          [stride_r] "r"(rhs_packed_stride)
-        : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "z0",
-          "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17",
-          "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31", "x0", "x6",
+          [stride] "r"(dst_stride_row), [lut] "r"(lut), [args_ptr] "r"(&ka),
+          [offset_stride_l] "I"(offsetof(KernelArgs, lhs_packed_stride)),
+          [offset_stride_r] "I"(offsetof(KernelArgs, rhs_packed_stride)), [offset_mr] "I"(offsetof(KernelArgs, mr))
+        : "p0", "p1", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "z0", "z1",
+          "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18",
+          "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31", "x0", "x5", "x6",
           "x8", "x9", "x10", "x11", "x12", "x14", "x15", "x16", "x17", "x20", "x21", "x22", "x23", "x24", "x25",
           "memory", "cc");
 }
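
Context for the patch above (not part of the diff itself): before this fix, the fp16 LHS scaling factors were loaded with the all-true byte predicate `p0`, so the `ld1b` read a full vector of bytes even though only `mr` half-precision scales (`mr * 2` bytes) are valid at that offset, which could read past the end of the packed LHS buffer. The added `whilelt p4.b, xzr, x5` (with `x5 = mr << 1`) bounds the load to the valid bytes. A minimal scalar C sketch of that bounds logic, using hypothetical names, is shown below.

```c
#include <stddef.h>
#include <stdint.h>
#include <string.h>

/* Hypothetical scalar illustration of the predicated load: only the mr fp16
 * scaling factors that actually exist in the packed LHS block are read, and
 * the remaining lanes are zeroed, mirroring "whilelt p4.b, xzr, mr*2" followed
 * by "ld1b {z16.b}, p4/z, [...]". */
static void load_lhs_scales_bounded(uint8_t* dst_vec, size_t vec_len_bytes,
                                    const uint8_t* lhs_scales, size_t mr) {
    const size_t valid_bytes = mr * sizeof(uint16_t); /* x5 = mr << 1 in the kernel */
    const size_t n = valid_bytes < vec_len_bytes ? valid_bytes : vec_len_bytes;

    memset(dst_vec, 0, vec_len_bytes); /* zeroing-predicate semantics (p4/z) */
    memcpy(dst_vec, lhs_scales, n);    /* the old p0 (all-true) load copied vec_len_bytes
                                          unconditionally, over-reading lhs_scales */
}
```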