From d492f1d9039900fa85cfcd589638048e942010be Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Thu, 25 Jul 2024 16:16:54 +0100 Subject: [PATCH 1/3] Change convention for the stride of the rhs_packed matrix Before the rhs_packed_stride depended on the nr value. With this patch, we remove this dependency so this stride can be used to calculate the offset for the N values stored in the packed matrix without dividing N by nr Resolves COMPMID-7319 Signed-off-by: Gian Marco Iodice --- .../matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.c | 8 ++++---- .../matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.h | 4 ++++ .../matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c | 8 ++++---- .../matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h | 4 ++++ 4 files changed, 16 insertions(+), 8 deletions(-) diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.c index b25bfc0f..d61dd7e6 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.c @@ -37,20 +37,20 @@ size_t kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(size_t k, size_ KAI_ASSERT((k_internal % 2) == 0); - return nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias); + return ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias); } size_t kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0( size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr) { KAI_ASSERT((n_idx % nr) == 0); - return (n_idx / nr) * kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(k, nr, kr, sr); + return n_idx * kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(k, nr, kr, sr); } size_t kai_get_rhs_packed_size_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(size_t n, size_t k, size_t nr, size_t kr, size_t sr) { const size_t num_rows = kai_roundup(n, nr) / nr; - return num_rows * 
kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(k, nr, kr, sr); + return num_rows * nr * kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(k, nr, kr, sr); } void kai_run_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0( @@ -68,7 +68,7 @@ void kai_run_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0( KAI_ASSERT(params->lhs_zero_point == 1); const size_t rhs_zero_point = params->rhs_zero_point; - const size_t rhs_packed_stride = kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(k, nr, kr, sr); + const size_t rhs_packed_stride = nr * kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(k, nr, kr, sr); const size_t k_internal = kai_k_roundedup(k, kr, sr); const size_t dst_num_rows = kai_roundup(n, nr) / nr; const size_t dst_num_bytes_per_row = nr * (kai_k_roundedup(k, kr, sr) / 2); diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.h b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.h index 8cd1c7a3..bb7c0782 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.h +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.h @@ -41,6 +41,10 @@ size_t kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(size_t n_idx, size_t r /// Get the row stride in bytes to the packed RHS matrix /// +/// This stride can be used to calculate the offset in bytes for the N values +/// stored in the packed matrix, such as in the following example: +/// for n = 0; n < N; n+=nr +/// rhs_packed_offset = n * kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0() /// @param[in] k In the RHS matrix (not packed), K is the number of rows. /// @param[in] nr The number of columns written by the matmul micro-kernel. /// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. 
diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c index 35947187..722b68e6 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c @@ -35,20 +35,20 @@ size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t k, size_ KAI_ASSERT((k_internal % 2) == 0); - return nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias); + return ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias); } size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr) { KAI_ASSERT((n_idx % nr) == 0); - return (n_idx / nr) * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(k, nr, kr, sr); + return n_idx * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(k, nr, kr, sr); } size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t n, size_t k, size_t nr, size_t kr, size_t sr) { const size_t num_rows = kai_roundup(n, nr) / nr; - return num_rows * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(k, nr, kr, sr); + return num_rows * nr * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(k, nr, kr, sr); } void kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( @@ -66,7 +66,7 @@ void kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( KAI_ASSERT(params->lhs_zero_point == 1); const size_t rhs_zero_point = params->rhs_zero_point; - const size_t rhs_packed_stride = kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(k, nr, kr, sr); + const size_t rhs_packed_stride = nr * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(k, nr, kr, sr); const size_t k_internal = kai_k_roundedup(k, kr, sr); const size_t dst_num_rows = kai_roundup(n, nr) / nr; const size_t dst_num_bytes_per_row = nr * (kai_k_roundedup(k, kr, sr) / 2); diff --git 
a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h index 4fc97ba9..ad688797 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h @@ -40,6 +40,10 @@ size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t n_idx, size_t r /// Get the row stride in bytes to the packed RHS matrix /// +/// This stride can be used to calculate the offset in bytes for the N values +/// stored in the packed matrix, such as in the following example: +/// for n = 0; n < N; n+=nr +/// rhs_packed_offset = n * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0() /// @param[in] k In the RHS matrix (not packed), K is the number of columns. /// @param[in] nr The number of columns written by the matmul micro-kernel. /// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. -- GitLab From 9fccef0db1974a982667c6807e8bce42b648bd6a Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Thu, 25 Jul 2024 17:18:24 +0100 Subject: [PATCH 2/3] Remove nr from input argument list Signed-off-by: Gian Marco Iodice --- .../matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.c | 8 ++++---- .../matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.h | 3 +-- .../matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c | 8 ++++---- .../matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h | 3 +-- 4 files changed, 10 insertions(+), 12 deletions(-) diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.c index d61dd7e6..ff2a40c7 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.c @@ -32,7 +32,7 @@ size_t kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(size_t n_idx, size_t r return (n_idx / 2) * sizeof(int8_t); } -size_t 
kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(size_t k, size_t nr, size_t kr, size_t sr) { +size_t kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(size_t k, size_t kr, size_t sr) { const size_t k_internal = kai_k_roundedup(k, kr, sr); KAI_ASSERT((k_internal % 2) == 0); @@ -44,13 +44,13 @@ size_t kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0( size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr) { KAI_ASSERT((n_idx % nr) == 0); - return n_idx * kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(k, nr, kr, sr); + return n_idx * kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(k, kr, sr); } size_t kai_get_rhs_packed_size_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(size_t n, size_t k, size_t nr, size_t kr, size_t sr) { const size_t num_rows = kai_roundup(n, nr) / nr; - return num_rows * nr * kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(k, nr, kr, sr); + return num_rows * nr * kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(k, kr, sr); } void kai_run_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0( @@ -68,7 +68,7 @@ void kai_run_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0( KAI_ASSERT(params->lhs_zero_point == 1); const size_t rhs_zero_point = params->rhs_zero_point; - const size_t rhs_packed_stride = nr * kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(k, nr, kr, sr); + const size_t rhs_packed_stride = nr * kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(k, kr, sr); const size_t k_internal = kai_k_roundedup(k, kr, sr); const size_t dst_num_rows = kai_roundup(n, nr) / nr; const size_t dst_num_bytes_per_row = nr * (kai_k_roundedup(k, kr, sr) / 2); diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.h b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.h index bb7c0782..348091be 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.h +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.h @@ -46,12 +46,11 @@ size_t 
kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(size_t n_idx, size_t r /// for n = 0; n < N; n+=nr /// rhs_packed_offset = n * kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0() /// @param[in] k In the RHS matrix (not packed), K is the number of rows. -/// @param[in] nr The number of columns written by the matmul micro-kernel. /// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. /// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. /// /// @return the stride in bytes to the packed RHS matrix -size_t kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(size_t k, size_t nr, size_t kr, size_t sr); +size_t kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(size_t k, size_t kr, size_t sr); /// Gets the offset in bytes for the packed RHS matrix, which contains the packed 4-bit quantized symmetric per-channel /// (qsu4cx) values. diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c index 722b68e6..4e0620c2 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c @@ -30,7 +30,7 @@ size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t n_idx, size_t r return n_idx * rhs_stride; } -size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t k, size_t nr, size_t kr, size_t sr) { +size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t k, size_t kr, size_t sr) { const size_t k_internal = kai_k_roundedup(k, kr, sr); KAI_ASSERT((k_internal % 2) == 0); @@ -42,13 +42,13 @@ size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr) { KAI_ASSERT((n_idx % nr) == 0); - return n_idx * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(k, nr, kr, sr); + return n_idx * 
kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(k, kr, sr); } size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t n, size_t k, size_t nr, size_t kr, size_t sr) { const size_t num_rows = kai_roundup(n, nr) / nr; - return num_rows * nr * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(k, nr, kr, sr); + return num_rows * nr * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(k, kr, sr); } void kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( @@ -66,7 +66,7 @@ void kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( KAI_ASSERT(params->lhs_zero_point == 1); const size_t rhs_zero_point = params->rhs_zero_point; - const size_t rhs_packed_stride = nr * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(k, nr, kr, sr); + const size_t rhs_packed_stride = nr * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(k, kr, sr); const size_t k_internal = kai_k_roundedup(k, kr, sr); const size_t dst_num_rows = kai_roundup(n, nr) / nr; const size_t dst_num_bytes_per_row = nr * (kai_k_roundedup(k, kr, sr) / 2); diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h index ad688797..44b0b538 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h @@ -45,12 +45,11 @@ size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t n_idx, size_t r /// for n = 0; n < N; n+=nr /// rhs_packed_offset = n * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0() /// @param[in] k In the RHS matrix (not packed), K is the number of columns. -/// @param[in] nr The number of columns written by the matmul micro-kernel. /// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. /// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. 
/// /// @return the stride in bytes to the packed RHS matrix -size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t k, size_t nr, size_t kr, size_t sr); +size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t k, size_t kr, size_t sr); /// Gets the offset in bytes for the packed RHS matrix, which contains the packed 4-bit quantized symmetric per-channel /// (qsu4cx) values. -- GitLab From 72e5b39d320037ebb6ab059a06f5c680dfc8794f Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Fri, 26 Jul 2024 12:27:22 +0100 Subject: [PATCH 3/3] Rename to stride_n - The name change was required to reflect the meaning of the bytes returned by the method Signed-off-by: Gian Marco Iodice --- .../matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.c | 8 ++++---- .../matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.h | 10 +++++++--- .../matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c | 8 ++++---- .../matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h | 10 +++++++--- 4 files changed, 22 insertions(+), 14 deletions(-) diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.c index ff2a40c7..26db6275 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.c @@ -32,7 +32,7 @@ size_t kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(size_t n_idx, size_t r return (n_idx / 2) * sizeof(int8_t); } -size_t kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(size_t k, size_t kr, size_t sr) { +size_t kai_get_rhs_packed_stride_n_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(size_t k, size_t kr, size_t sr) { const size_t k_internal = kai_k_roundedup(k, kr, sr); KAI_ASSERT((k_internal % 2) == 0); @@ -44,13 +44,13 @@ size_t kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0( size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr) { KAI_ASSERT((n_idx % nr) == 0); - return n_idx * 
kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(k, kr, sr); + return n_idx * kai_get_rhs_packed_stride_n_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(k, kr, sr); } size_t kai_get_rhs_packed_size_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(size_t n, size_t k, size_t nr, size_t kr, size_t sr) { const size_t num_rows = kai_roundup(n, nr) / nr; - return num_rows * nr * kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(k, kr, sr); + return num_rows * nr * kai_get_rhs_packed_stride_n_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(k, kr, sr); } void kai_run_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0( @@ -68,7 +68,7 @@ void kai_run_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0( KAI_ASSERT(params->lhs_zero_point == 1); const size_t rhs_zero_point = params->rhs_zero_point; - const size_t rhs_packed_stride = nr * kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(k, kr, sr); + const size_t rhs_packed_stride = nr * kai_get_rhs_packed_stride_n_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(k, kr, sr); const size_t k_internal = kai_k_roundedup(k, kr, sr); const size_t dst_num_rows = kai_roundup(n, nr) / nr; const size_t dst_num_bytes_per_row = nr * (kai_k_roundedup(k, kr, sr) / 2); diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.h b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.h index 348091be..4760e75f 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.h +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.h @@ -39,18 +39,22 @@ size_t kai_get_n_step_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(size_t nr); /// @return the offset in bytes to the RHS matrix (not packed) size_t kai_get_rhs_offset_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(size_t n_idx, size_t rhs_stride); -/// Get the row stride in bytes to the packed RHS matrix +/// Get the N stride in bytes to the packed RHS matrix +/// +/// @note This stride should not be confused with the stride of the packed RHS matrix. 
+/// The stride of the packed RHS matrix is: +/// rhs_packed_stride = nr * kai_get_rhs_packed_stride_n_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0() /// /// This stride can be used to calculate the offset in bytes for the N values /// stored in the packed matrix, such as in the following example: /// for n = 0; n < N; n+=nr -/// rhs_packed_offset = n * kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0() +/// rhs_packed_offset = n * kai_get_rhs_packed_stride_n_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0() /// @param[in] k In the RHS matrix (not packed), K is the number of rows. /// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. /// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. /// /// @return the stride in bytes to the packed RHS matrix -size_t kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(size_t k, size_t kr, size_t sr); +size_t kai_get_rhs_packed_stride_n_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0(size_t k, size_t kr, size_t sr); /// Gets the offset in bytes for the packed RHS matrix, which contains the packed 4-bit quantized symmetric per-channel /// (qsu4cx) values. 
diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c index 4e0620c2..8470df35 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c @@ -30,7 +30,7 @@ size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t n_idx, size_t r return n_idx * rhs_stride; } -size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t k, size_t kr, size_t sr) { +size_t kai_get_rhs_packed_stride_n_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t k, size_t kr, size_t sr) { const size_t k_internal = kai_k_roundedup(k, kr, sr); KAI_ASSERT((k_internal % 2) == 0); @@ -42,13 +42,13 @@ size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr) { KAI_ASSERT((n_idx % nr) == 0); - return n_idx * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(k, kr, sr); + return n_idx * kai_get_rhs_packed_stride_n_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(k, kr, sr); } size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t n, size_t k, size_t nr, size_t kr, size_t sr) { const size_t num_rows = kai_roundup(n, nr) / nr; - return num_rows * nr * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(k, kr, sr); + return num_rows * nr * kai_get_rhs_packed_stride_n_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(k, kr, sr); } void kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( @@ -66,7 +66,7 @@ void kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( KAI_ASSERT(params->lhs_zero_point == 1); const size_t rhs_zero_point = params->rhs_zero_point; - const size_t rhs_packed_stride = nr * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(k, kr, sr); + const size_t rhs_packed_stride = nr * kai_get_rhs_packed_stride_n_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(k, kr, sr); const size_t k_internal = kai_k_roundedup(k, kr, sr); const size_t dst_num_rows = kai_roundup(n, nr) / nr; const size_t 
dst_num_bytes_per_row = nr * (kai_k_roundedup(k, kr, sr) / 2); diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h index 44b0b538..f3306f1b 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h @@ -38,18 +38,22 @@ size_t kai_get_n_step_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t nr); /// @return the offset in bytes to the RHS matrix (not packed) size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t n_idx, size_t rhs_stride); -/// Get the row stride in bytes to the packed RHS matrix +/// Get the N stride in bytes to the packed RHS matrix +/// +/// @note This stride should not be confused with the stride of the packed RHS matrix. +/// The stride of the packed RHS matrix is: +/// rhs_packed_stride = nr * kai_get_rhs_packed_stride_n_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0() /// /// This stride can be used to calculate the offset in bytes for the N values /// stored in the packed matrix, such as in the following example: /// for n = 0; n < N; n+=nr -/// rhs_packed_offset = n * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0() +/// rhs_packed_offset = n * kai_get_rhs_packed_stride_n_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0() /// @param[in] k In the RHS matrix (not packed), K is the number of columns. /// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. /// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. /// /// @return the stride in bytes to the packed RHS matrix -size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t k, size_t kr, size_t sr); +size_t kai_get_rhs_packed_stride_n_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t k, size_t kr, size_t sr); /// Gets the offset in bytes for the packed RHS matrix, which contains the packed 4-bit quantized symmetric per-channel /// (qsu4cx) values. -- GitLab