From e1001a0d483406ca0a31dad23bbd5d239f67e438 Mon Sep 17 00:00:00 2001 From: Jens Elofsson Date: Tue, 28 Jan 2025 17:23:09 +0100 Subject: [PATCH 1/2] Remove unused scale types from packing functions Remove fp32 and fp16 scale datatypes from - kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0 - kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 since they aren't being used and just create more maintenance work. Signed-off-by: Jens Elofsson --- .../kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c | 25 ++++--------------- .../kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c | 24 ++++-------------- 2 files changed, 10 insertions(+), 39 deletions(-) diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c index 608de410..da771008 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c @@ -61,7 +61,7 @@ size_t kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0( KAI_ASSERT((bl % kr) == 0); KAI_ASSERT((nr % kai_nr_multiple_of) == 0); KAI_ASSERT((bl % kai_bl_multiple_of) == 0); - KAI_ASSERT(scale_dt == kai_dt_f32 || scale_dt == kai_dt_f16 || scale_dt == kai_dt_bf16); + KAI_ASSERT(scale_dt == kai_dt_bf16); KAI_UNUSED(kr); KAI_UNUSED(sr); @@ -86,7 +86,7 @@ size_t kai_get_rhs_packed_offset_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0( KAI_ASSERT((bl % kr) == 0); KAI_ASSERT((nr % kai_nr_multiple_of) == 0); KAI_ASSERT((bl % kai_bl_multiple_of) == 0); - KAI_ASSERT(scale_dt == kai_dt_f32 || scale_dt == kai_dt_f16 || scale_dt == kai_dt_bf16); + KAI_ASSERT(scale_dt == kai_dt_bf16); return (n_idx / nr) * kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(k, nr, kr, sr, bl, scale_dt); } @@ -103,7 +103,7 @@ size_t kai_get_rhs_packed_size_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0( KAI_ASSERT((bl % kr) == 0); KAI_ASSERT((nr % kai_nr_multiple_of) == 0); KAI_ASSERT((bl % kai_bl_multiple_of) == 0); - KAI_ASSERT(scale_dt == kai_dt_f32 || scale_dt == 
kai_dt_f16 || scale_dt == kai_dt_bf16); + KAI_ASSERT(scale_dt == kai_dt_bf16); const size_t num_rows = kai_roundup(n, nr) / nr; @@ -140,7 +140,7 @@ void kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0( KAI_ASSERT((kr % sr) == 0); KAI_ASSERT((nr % kai_nr_multiple_of) == 0); KAI_ASSERT((bl % kai_bl_multiple_of) == 0); - KAI_ASSERT(params->scale_dt == kai_dt_f32 || params->scale_dt == kai_dt_f16 || params->scale_dt == kai_dt_bf16); + KAI_ASSERT(params->scale_dt == kai_dt_bf16); // Note: The input matrix (rhs) is expected with: // "n" columns and "k" rows (kxn) @@ -158,7 +158,6 @@ void kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0( const size_t block_length_in_bytes = kr / sr; const int32_t rhs_zero_point = params->rhs_zero_point; - const enum kai_datatype scale_dt = params->scale_dt; for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; ++dst_row_idx) { // Before packing, it keeps the pointer to the first quantized block @@ -198,21 +197,7 @@ void kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0( // Clamp the index to avoid out-of-bound reads const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1); - float d = 0.0F; - switch (scale_dt) { - case kai_dt_f32: - d = ((float*)rhs_packed_scale)[nr_idx]; - break; - case kai_dt_f16: - d = kai_cast_f32_f16(((uint16_t*)rhs_packed_scale)[nr_idx]); - break; - case kai_dt_bf16: - d = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx]); - break; - default: - KAI_ERROR("Unsupported scale data type"); - break; - } + float d = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx]); const size_t k_adjustment = ((super_kr_block_idx * block_length_in_bytes) / k_interleaved_v) * k_interleaved_v; diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c index 5a86b8a0..1966d765 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c @@ -59,7 +59,7 @@ size_t 
kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( KAI_ASSERT((bl % kr) == 0); KAI_ASSERT((nr % kai_nr_multiple_of) == 0); KAI_ASSERT((bl % kai_bl_multiple_of) == 0); - KAI_ASSERT(scale_dt == kai_dt_f32 || scale_dt == kai_dt_f16 || scale_dt == kai_dt_bf16); + KAI_ASSERT(scale_dt == kai_dt_bf16); KAI_UNUSED(kr); KAI_UNUSED(sr); @@ -84,7 +84,7 @@ size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( KAI_ASSERT((bl % kr) == 0); KAI_ASSERT((nr % kai_nr_multiple_of) == 0); KAI_ASSERT((bl % kai_bl_multiple_of) == 0); - KAI_ASSERT(scale_dt == kai_dt_f32 || scale_dt == kai_dt_f16 || scale_dt == kai_dt_bf16); + KAI_ASSERT(scale_dt == kai_dt_bf16); return (n_idx / nr) * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(k, nr, kr, sr, bl, scale_dt); } @@ -101,7 +101,7 @@ size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( KAI_ASSERT((bl % kr) == 0); KAI_ASSERT((nr % kai_nr_multiple_of) == 0); KAI_ASSERT((bl % kai_bl_multiple_of) == 0); - KAI_ASSERT(scale_dt == kai_dt_f32 || scale_dt == kai_dt_f16 || scale_dt == kai_dt_bf16); + KAI_ASSERT(scale_dt == kai_dt_bf16); const size_t num_rows = kai_roundup(n, nr) / nr; @@ -138,7 +138,7 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( KAI_ASSERT((kr % sr) == 0); KAI_ASSERT((nr % kai_nr_multiple_of) == 0); KAI_ASSERT((bl % kai_bl_multiple_of) == 0); - KAI_ASSERT(params->scale_dt == kai_dt_f32 || params->scale_dt == kai_dt_f16 || params->scale_dt == kai_dt_bf16); + KAI_ASSERT(params->scale_dt == kai_dt_bf16); // Note: The input matrix (rhs) is expected with: // "k" columns and "n" rows (NxK) @@ -191,21 +191,7 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( // Clamp the index to avoid out-of-bound reads const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1); - float d = 0.0F; - switch (scale_dt) { - case kai_dt_f32: - d = ((float*)rhs_packed_scale)[nr_idx]; - break; - case kai_dt_f16: - d = kai_cast_f32_f16(((uint16_t*)rhs_packed_scale)[nr_idx]); - break; - case kai_dt_bf16: - d = 
kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx]); - break; - default: - KAI_ERROR("Unsupported scale data type"); - break; - } + float d = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx]); int32_t partial_sum = 0; -- GitLab From 35d4f5471196eee2868ae50db0f8265ea6977910 Mon Sep 17 00:00:00 2001 From: Jens Elofsson Date: Tue, 28 Jan 2025 17:51:04 +0100 Subject: [PATCH 2/2] Address review comments Add 2025 to copyright header. Signed-off-by: Jens Elofsson --- .../matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c | 2 +- .../matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c index da771008..b292045a 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.c @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c index 1966d765..9af7841d 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // -- GitLab