diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.c index b25bfc0f38c3616b82aa8248de9b50200a61b111..26db6275e45f8b0b698149132e5de26708aceb34 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.c @@ -32,25 +32,25 @@ size_t kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(size_t n_idx, size_t r return (n_idx / 2) * sizeof(int8_t); } -size_t kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(size_t k, size_t nr, size_t kr, size_t sr) { +size_t kai_get_rhs_packed_stride_n_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(size_t k, size_t kr, size_t sr) { const size_t k_internal = kai_k_roundedup(k, kr, sr); KAI_ASSERT((k_internal % 2) == 0); - return nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias); + return ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias); } size_t kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0( size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr) { KAI_ASSERT((n_idx % nr) == 0); - return (n_idx / nr) * kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(k, nr, kr, sr); + return n_idx * kai_get_rhs_packed_stride_n_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(k, kr, sr); } size_t kai_get_rhs_packed_size_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(size_t n, size_t k, size_t nr, size_t kr, size_t sr) { const size_t num_rows = kai_roundup(n, nr) / nr; - return num_rows * kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(k, nr, kr, sr); + return num_rows * nr * kai_get_rhs_packed_stride_n_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(k, kr, sr); } void kai_run_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0( @@ -68,7 +68,7 @@ void kai_run_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0( KAI_ASSERT(params->lhs_zero_point == 1); const size_t rhs_zero_point = params->rhs_zero_point; - const size_t rhs_packed_stride = 
kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(k, nr, kr, sr); + const size_t rhs_packed_stride = nr * kai_get_rhs_packed_stride_n_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(k, kr, sr); const size_t k_internal = kai_k_roundedup(k, kr, sr); const size_t dst_num_rows = kai_roundup(n, nr) / nr; const size_t dst_num_bytes_per_row = nr * (kai_k_roundedup(k, kr, sr) / 2); diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.h b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.h index 8cd1c7a345db2102a645f17014e1fe4feb5586c1..4760e75ff3e7869a8f35064ca5cc902345f17db6 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.h +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.h @@ -39,15 +39,22 @@ size_t kai_get_n_step_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(size_t nr); /// @return the offset in bytes to the RHS matrix (not packed) size_t kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(size_t n_idx, size_t rhs_stride); -/// Get the row stride in bytes to the packed RHS matrix +/// Get the N stride in bytes to the packed RHS matrix /// +/// @note This stride should not be confused with the stride of the packed RHS matrix. +/// The stride of the packed RHS matrix is: +/// rhs_packed_stride = nr * kai_get_rhs_packed_stride_n_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0() +/// +/// This stride can be used to calculate the offset in bytes for the N values +/// stored in the packed matrix, such as in the following example: +/// for n = 0; n < N; n+=nr +/// rhs_packed_offset = n * kai_get_rhs_packed_stride_n_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0() /// @param[in] k In the RHS matrix (not packed), K is the number of rows. -/// @param[in] nr The number of columns written by the matmul micro-kernel. /// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. /// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. 
/// /// @return the stride in bytes to the packed RHS matrix -size_t kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(size_t k, size_t nr, size_t kr, size_t sr); +size_t kai_get_rhs_packed_stride_n_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(size_t k, size_t kr, size_t sr); /// Gets the offset in bytes for the packed RHS matrix, which contains the packed 4-bit quantized symmetric per-channel /// (qsu4cx) values. diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c index 359471879f9e26eac7a10f00f3215d2f10c13538..8470df35267a6e24c91a42b9ea3aa2883021a004 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c @@ -30,25 +30,25 @@ size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t n_idx, size_t r return n_idx * rhs_stride; } -size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t k, size_t nr, size_t kr, size_t sr) { +size_t kai_get_rhs_packed_stride_n_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t k, size_t kr, size_t sr) { const size_t k_internal = kai_k_roundedup(k, kr, sr); KAI_ASSERT((k_internal % 2) == 0); - return nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias); + return ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias); } size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr) { KAI_ASSERT((n_idx % nr) == 0); - return (n_idx / nr) * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(k, nr, kr, sr); + return n_idx * kai_get_rhs_packed_stride_n_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(k, kr, sr); } size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t n, size_t k, size_t nr, size_t kr, size_t sr) { const size_t num_rows = kai_roundup(n, nr) / nr; - return num_rows * 
kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(k, nr, kr, sr); + return num_rows * nr * kai_get_rhs_packed_stride_n_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(k, kr, sr); } void kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( @@ -66,7 +66,7 @@ void kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( KAI_ASSERT(params->lhs_zero_point == 1); const size_t rhs_zero_point = params->rhs_zero_point; - const size_t rhs_packed_stride = kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(k, nr, kr, sr); + const size_t rhs_packed_stride = nr * kai_get_rhs_packed_stride_n_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(k, kr, sr); const size_t k_internal = kai_k_roundedup(k, kr, sr); const size_t dst_num_rows = kai_roundup(n, nr) / nr; const size_t dst_num_bytes_per_row = nr * (kai_k_roundedup(k, kr, sr) / 2); diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h index 4fc97ba967dbbfd58619e3d679b89f2b33fe486d..f3306f1b70f1fa464d110b4eec690404c25ac112 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h @@ -38,15 +38,22 @@ size_t kai_get_n_step_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t nr); /// @return the offset in bytes to the RHS matrix (not packed) size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t n_idx, size_t rhs_stride); -/// Get the row stride in bytes to the packed RHS matrix +/// Get the N stride in bytes to the packed RHS matrix /// +/// @note This stride should not be confused with the stride of the packed RHS matrix. 
+/// The stride of the packed RHS matrix is: +/// rhs_packed_stride = nr * kai_get_rhs_packed_stride_n_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0() +/// +/// This stride can be used to calculate the offset in bytes for the N values +/// stored in the packed matrix, such as in the following example: +/// for n = 0; n < N; n+=nr +/// rhs_packed_offset = n * kai_get_rhs_packed_stride_n_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0() /// @param[in] k In the RHS matrix (not packed), K is the number of columns. -/// @param[in] nr The number of columns written by the matmul micro-kernel. /// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. /// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. /// /// @return the stride in bytes to the packed RHS matrix -size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t k, size_t nr, size_t kr, size_t sr); +size_t kai_get_rhs_packed_stride_n_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t k, size_t kr, size_t sr); /// Gets the offset in bytes for the packed RHS matrix, which contains the packed 4-bit quantized symmetric per-channel /// (qsu4cx) values.