diff --git a/kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme.c b/kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme.c index d7be3f447a04cbf35019100aa99fd70a73426d06..14a33795c7c1d0134fbbabea14296602d6be8c7c 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme.c +++ b/kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme.c @@ -22,7 +22,7 @@ static size_t kai_get_mr_lhs_pack_x16p2vlx2_x16_sme(void) { } size_t kai_get_m_step_lhs_pack_x16p2vlx2_x16_sme(size_t mr) { - KAI_ASSUME(mr == kai_mr * kai_get_sme_vector_length_u16()); + KAI_ASSUME(mr == kai_get_mr_lhs_pack_x16p2vlx2_x16_sme()); KAI_UNUSED(mr); return kai_get_mr_lhs_pack_x16p2vlx2_x16_sme(); diff --git a/kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme.h b/kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme.h index 6cc275857a0fe70a3d0a77553e1bcb3ed2982680..08043053bb5c42b7b9f848d00fb691fb9916e4e5 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme.h +++ b/kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme.h @@ -61,7 +61,7 @@ size_t kai_get_lhs_packed_size_lhs_pack_x16p2vlx2_x16_sme(size_t m, size_t k, si /// /// @param[in] m Number of rows of the unpacked LHS matrix. /// @param[in] k Common dimension between the LHS and RHS matrix. -/// @param[in] mr Block size in M dimension. It must be 2 * kai_get_sme_vector_length_u16(). +/// @param[in] mr Block size in M dimension. It must be kai_get_m_step_lhs_pack_x16p2vlx2_x16_sme(). /// @param[in] kr Block size in K dimension. It must be 2. /// @param[in] sr Number of kr splits. It must be 1. /// @param[in] m_idx_start Unused. Must be 0. diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme.h b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme.h index 3309ad6f8693589fb247628257cde4c163bd32be..e5cbc7915ca1e83424625bfaedb1f9af1d315cc6 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme.h +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme.h @@ -69,7 +69,7 @@ size_t kai_get_rhs_packed_size_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme(size_t n, siz /// @param[in] num_groups Number of groups. It must be 1. /// @param[in] n Number of columns of the output matrix. /// @param[in] k Common dimension between the LHS and RHS matrix. -/// @param[in] nr Block size in N dimension. It must be 2 * kai_get_sme_vector_length_u16(). +/// @param[in] nr Block size in N dimension. It must match kai_get_n_step_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme(). /// @param[in] kr Block size in K dimension. It must be 2. /// @param[in] sr Number of kr splits. It must be 1. /// @param[in] rhs_stride Row stride in bytes of the RHS matrix.