diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0.c index 462218746eb8ee82736200714953e2233f69f463..7c75764ecd35cded15ed9ceb81d1f18fb0b344da 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0.c @@ -69,7 +69,7 @@ void kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0( KAI_ASSERT(params->lhs_zero_point == 1); KAI_ASSERT(params->rhs_zero_point == 0 || params->rhs_zero_point == 8); - const size_t rhs_zero_point = params->rhs_zero_point; + const uint8_t rhs_zero_point = params->rhs_zero_point; const size_t rhs_packed_stride = kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4cxp_qs4cxs1s0(k, nr, kr, sr); const size_t k_internal = kai_k_roundedup(k); const size_t dst_num_rows = kai_roundup(n, nr) / nr; @@ -106,7 +106,7 @@ void kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0( const size_t shift_right_x0 = (n0_idx % 2) * 4; - if (params->rhs_zero_point == 8) { + if (rhs_zero_point == 8) { uint8_t byte0 = rhs_zero_point | rhs_zero_point << 4; uint8_t byte1 = rhs_zero_point | rhs_zero_point << 4; diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.c index 19887d096e9e0a9f3e8c9650b8c143dfa5b75bbe..898fa442b69a71847ffd42319b089ba972b67ca3 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.c @@ -66,7 +66,7 @@ void kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0( KAI_ASSERT(params->lhs_zero_point == 1); KAI_ASSERT(params->rhs_zero_point == 0 || params->rhs_zero_point == 8); - const size_t rhs_zero_point = params->rhs_zero_point; + const uint8_t rhs_zero_point = params->rhs_zero_point; const size_t rhs_packed_stride = kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(k, nr, kr, sr); const size_t k_internal = kai_k_roundedup(k); const size_t dst_num_rows = kai_roundup(n, nr) / nr; @@ -104,7 +104,7 @@ void kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0( const size_t shift_right_x0 = (k0_idx % 2) * 4; const size_t shift_right_x1 = (k1_idx % 2) * 4; - if (params->rhs_zero_point == 8) { + if (rhs_zero_point == 8) { uint8_t byte0 = rhs_zero_point | rhs_zero_point << 4; uint8_t byte1 = rhs_zero_point | rhs_zero_point << 4;