diff --git a/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa_asm.S b/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa_asm.S index 43f12f1dd19d97c7d0c37996fb9aa3640aa02460..a0e70c2071f460991964accf9b8fb853c016ee7f 100644 --- a/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa_asm.S +++ b/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa_asm.S @@ -51,11 +51,15 @@ KAI_ASM_FUNCTION_LABEL(kai_f16_from_float_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_ KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa) KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa) - stp x20, x21, [sp, -80]! + stp x20, x21, [sp, -144]! stp x22, x23, [sp, 16] stp x24, x25, [sp, 32] stp x26, x27, [sp, 48] str x28, [sp, 64] + stp d8, d9, [sp, 72] + stp d10, d11, [sp, 88] + stp d12, d13, [sp, 104] + stp d14, d15, [sp, 120] KAI_ASM_INST(0xd503477f) // SMSTART ZA mov x14, #0x0 ldr x13, [x0, #0x30] @@ -221,7 +225,11 @@ KAI_ASM_LABEL(label_10) // Store to output array: Accumulator loop ldp x24, x25, [sp, 32] ldp x26, x27, [sp, 48] ldr x28, [sp, 64] - ldp x20, x21, [sp], 80 + ldp d8, d9, [sp, 72] + ldp d10, d11, [sp, 88] + ldp d12, d13, [sp, 104] + ldp d14, d15, [sp, 120] + ldp x20, x21, [sp], 144 ret KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa) diff --git a/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.h b/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.h index 7130c4c3b77fd3895cac92a7917d2cddf1f646f1..aea451d61b6a8c316b6366fce8a1e354d0bb8f6f 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.h +++ b/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.h @@ -15,7 +15,8 @@ extern "C" { /// Micro-kernel dependencies /// /// -# kai_lhs_pack_f32p2vlx1_f32_sme to pack the LHS matrix. -/// -# kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme to pack the RHS matrix. +/// -# kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme or kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme to pack the RHS +/// matrix. /// Gets m step value. /// @@ -109,7 +110,7 @@ size_t kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(s /// @param[in] rhs_packed Packed RHS matrix buffer. /// @param[out] dst Output matrix buffer. /// @param[in] dst_stride_row Row stride in bytes of the output matrix. -/// @param[in] dst_stride_col Column stride in bytes of the output matrix. Must be 4 +/// @param[in] dst_stride_col Column stride in bytes of the output matrix. Currently, an unused parameter. /// @param[in] clamp_min Minimum value to clamp the final result. /// @param[in] clamp_max Maximum value to clamp the final result. void kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa( diff --git a/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa_asm.S b/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa_asm.S index 82069d8489b081dd665ff7da1eb05a5c804fd970..e39b91734c58dcc81ee2240e59b298cda1da6286 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa_asm.S +++ b/kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa_asm.S @@ -42,11 +42,15 @@ KAI_ASM_FUNCTION_TYPE(kai_kernel_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa) KAI_ASM_FUNCTION_LABEL(kai_kernel_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa) - stp x20, x21, [sp, -80]! + stp x20, x21, [sp, -144]! stp x22, x23, [sp, 16] stp x24, x25, [sp, 32] stp x26, x27, [sp, 48] str x28, [sp, 64] + stp d8, d9, [sp, 72] + stp d10, d11, [sp, 88] + stp d12, d13, [sp, 104] + stp d14, d15, [sp, 120] KAI_ASM_INST(0xd503477f) // SMSTART ZA mov x15, #0x0 ptrue p2.b @@ -348,7 +352,11 @@ KAI_ASM_LABEL(label_16) // Store to output array: End ldp x24, x25, [sp, 32] ldp x26, x27, [sp, 48] ldr x28, [sp, 64] - ldp x20, x21, [sp], 80 + ldp d8, d9, [sp, 72] + ldp d10, d11, [sp, 88] + ldp d12, d13, [sp, 104] + ldp d14, d15, [sp, 120] + ldp x20, x21, [sp], 144 ret KAI_ASM_FUNCTION_END(kai_kernel_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa)