diff --git a/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c b/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c index 7e1a3adeeb6b47c488cdbeaaa44fb975e1b8737a..0d6e68d2351388bc0cc5763bccd2efe15527875d 100644 --- a/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c +++ b/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.c @@ -8,6 +8,8 @@ #error This file must be compiled for AArch64, FEAT_SVE2. #else // Architectural features check. +#include "kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h" + #include #include #include diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c index b1911981d3616348760e09c34587393905f4d813..b76754f22e05460a86f1ed5c58bc9446dbc4bcca 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c @@ -158,10 +158,10 @@ void kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa( const size_t mr = kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(); const size_t nr = kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa(); - const uint16_t* lhs_scales = - (uint16_t*)((const int8_t*)lhs_packed + lhs_packed_stride - (mr * num_blocks) * kai_num_bytes_multiplier_lhs); - const uint16_t* rhs_scales = - (uint16_t*)((const uint8_t*)rhs_packed + rhs_packed_stride - (nr * num_blocks) * kai_num_bytes_multiplier_rhs); + const uint16_t* lhs_scales = 
(const uint16_t*)((const uint8_t*)lhs_packed + lhs_packed_stride - + (mr * num_blocks) * kai_num_bytes_multiplier_lhs); + const uint16_t* rhs_scales = (const uint16_t*)((const uint8_t*)rhs_packed + rhs_packed_stride - + (nr * num_blocks) * kai_num_bytes_multiplier_rhs); __asm__ volatile( // Switch to streaming mode with ZA enabling diff --git a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c index 002b55fab6a966ca428053c082347c7befafbd71..9c27a588d98b3f709bbc331ecbeb6971c7ea13e4 100644 --- a/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c +++ b/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c @@ -161,10 +161,10 @@ void kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot( const size_t mr = kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(); const size_t nr = kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot(); - const uint16_t* lhs_scales = - (uint16_t*)((const int8_t*)lhs_packed + lhs_packed_stride - (mr * num_blocks) * kai_num_bytes_multiplier_lhs); - const uint16_t* rhs_scales = - (uint16_t*)((const uint8_t*)rhs_packed + rhs_packed_stride - (nr * num_blocks) * kai_num_bytes_multiplier_rhs); + const uint16_t* lhs_scales = (const uint16_t*)((const uint8_t*)lhs_packed + lhs_packed_stride - + (mr * num_blocks) * kai_num_bytes_multiplier_lhs); + const uint16_t* rhs_scales = (const uint16_t*)((const uint8_t*)rhs_packed + rhs_packed_stride - + (nr * num_blocks) * kai_num_bytes_multiplier_rhs); __asm__ volatile( // Switch to streaming mode with ZA enabling diff --git a/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c 
b/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c index 4d6e4f00de85c7e87a6e1d5b8f0ab6cba5ef293c..902b11e79b7c3fe31f43ac6f6af47e0049313637 100644 --- a/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c +++ b/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c @@ -8,6 +8,8 @@ #error This file must be compiled for AArch64, FEAT_SVE2. #else // Architectural features check. +#include "kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h" + #include #include diff --git a/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c b/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c index 7698ce5dd4cb537ff2fc5ca017b7e6007a023b4e..272942dfcc0af00faac650c6c4a8eda939e7353c 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c +++ b/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c @@ -8,6 +8,8 @@ #error This file must be compiled for AArch64, FEAT_SVE2. #else // Architectural features check. 
+#include "kai_lhs_pack_bf16p2vlx2_f32_sme.h" + #include #include @@ -81,7 +83,7 @@ void kai_run_lhs_pack_bf16p2vlx2_f32_sme( void* out = (void*)((char*)lhs_packed + block_y * kai_roundup(k, kai_kr) * sizeof(uint16_t)); for (size_t y = 0; y < height; y++) { - in[y] = (void*)((char*)lhs + (block_y + y) * lhs_stride); + in[y] = (const void*)((const char*)lhs + (block_y + y) * lhs_stride); } __asm__ __volatile__( diff --git a/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p8x4_f16_neon.c b/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p8x4_f16_neon.c index 723de695b88212ab5b2e8ff4207ffb2191967580..989c6e5f11c809916a96d1cde93e650f855682cf 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p8x4_f16_neon.c +++ b/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p8x4_f16_neon.c @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -9,6 +9,8 @@ #error This file must be compiled for AArch64, FEAT_BF16, FEAT_FP16. #else // Architectural features check. 
+#include "kai_lhs_pack_bf16p8x4_f16_neon.h" + #include #include @@ -69,7 +71,7 @@ void kai_run_lhs_pack_bf16p8x4_f16_neon( size_t width = k; for (size_t y = 0; y < height; y++) { - in[y] = (char*)lhs + (block_y + y) * lhs_stride; + in[y] = (const char*)lhs + (block_y + y) * lhs_stride; } __asm__ __volatile__( diff --git a/kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c b/kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c index 97268a642dd91ce12f36b16114915597d4d99baf..8c2dd83e5cbee38a007ba50729f9ba8d2a95a968 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c +++ b/kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.c @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -8,6 +8,8 @@ #error This file must be compiled for AArch64, FEAT_SVE2. #else // Architectural features check. +#include "kai_lhs_pack_f32p2vlx1_f32_sme.h" + #include #include diff --git a/kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme.c b/kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme.c index 6d467a9899c274f6747c8f595fa5b930e68f2c3e..ab131e84a9cf15f08a7e82e8f03f11d4cf3ebf57 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme.c +++ b/kai/ukernels/matmul/pack/kai_lhs_pack_x16p2vlx2_x16_sme.c @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -8,6 +8,8 @@ #error This file must be compiled for AArch64, FEAT_SVE2. #else // Architectural features check. 
+#include "kai_lhs_pack_x16p2vlx2_x16_sme.h" + #include #include diff --git a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p1x4_f32_neon.c b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p1x4_f32_neon.c index a53b00885252da5f0389d9ab7f5f2b0ab3ceb4fe..22fb31607a8eaa721db4396278e6b4dd1950264c 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p1x4_f32_neon.c +++ b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p1x4_f32_neon.c @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -8,6 +8,8 @@ #error This file must be compiled for AArch64, FEAT_BF16. #else // Architectural features check. +#include "kai_lhs_quant_pack_bf16p1x4_f32_neon.h" + #include #include #include @@ -64,7 +66,7 @@ void kai_run_lhs_quant_pack_bf16p1x4_f32_neon( KAI_ASSUME(m_idx_start == 0); - const float* lhs_ptr = (float*)(lhs); + const float* lhs_ptr = (const float*)(lhs); uint16_t* lhs_packed_ptr = (uint16_t*)(lhs_packed); // Unroll two 256-bit loops diff --git a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p8x4_f32_neon.c b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p8x4_f32_neon.c index 6022ac912205894a7ad6bbef8e85bd3f710a7af9..60862dc4752105e42de65eedb9de55553000ccce 100644 --- a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p8x4_f32_neon.c +++ b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p8x4_f32_neon.c @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -10,6 +10,8 @@ #define MAX_MR 8 +#include "kai_lhs_quant_pack_bf16p8x4_f32_neon.h" + #include #include #include @@ -50,8 +52,8 @@ size_t kai_get_lhs_packed_size_lhs_quant_pack_bf16p8x4_f32_neon(size_t m, size_t } void kai_run_lhs_quant_pack_bf16p8x4_f32_neon( - size_t m, 
size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs, size_t lhs_stride, - uint16_t* lhs_packed) { + size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride, + void* lhs_packed) { KAI_ASSUME(mr == kai_mr); KAI_ASSUME(sr == kai_sr); KAI_ASSUME(kr == kai_kr); @@ -73,7 +75,7 @@ void kai_run_lhs_quant_pack_bf16p8x4_f32_neon( size_t width = k; for (size_t y = 0; y < height; y++) { - in[y] = (char*)lhs + (block_y + y) * lhs_stride; + in[y] = (const char*)lhs + (block_y + y) * lhs_stride; } __asm__ __volatile__( diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p12x4biasf16_f16_neon.c b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p12x4biasf16_f16_neon.c index cff2dd0f052905d4296c643017eb0ae7ed20af6a..fae15ccb3508c69083ecf22e86c685a3d57cb602 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p12x4biasf16_f16_neon.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p12x4biasf16_f16_neon.c @@ -69,7 +69,7 @@ void kai_run_rhs_pack_kxn_bf16p12x4biasf16_f16_neon( const void* in = rhs; void* out = rhs_packed; const size_t in_stride = rhs_stride; - const uint16_t* pad_row = (uint16_t*)rhs; + const uint16_t* pad_row = (const uint16_t*)rhs; // Fill zeros if bias is nullptr size_t bias_step = nr * sizeof(uint16_t); @@ -80,7 +80,7 @@ void kai_run_rhs_pack_kxn_bf16p12x4biasf16_f16_neon( bias_step = 0; } - const void* bias_ptr = bias == NULL ? (void*)zero_bias : (void*)bias; + const void* bias_ptr = bias == NULL ? 
(void*)zero_bias : (const void*)bias; size_t out_stride = kai_get_rhs_packed_stride_rhs_pack_kxn_bf16p12x4biasf16_f16_neon(height); diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p12x4biasf16_f16_neon.h b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p12x4biasf16_f16_neon.h index cd3863021696c92ee0f37fb61c8dc39d263ea7d5..27d033469bfc37fb357256ae637a8f131c16cb7e 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p12x4biasf16_f16_neon.h +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p12x4biasf16_f16_neon.h @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -41,6 +41,12 @@ size_t kai_get_bias_offset_rhs_pack_kxn_bf16p12x4biasf16_f16_neon(size_t n_idx); /// @return The offset in bytes to the data element. size_t kai_get_rhs_packed_offset_rhs_pack_kxn_bf16p12x4biasf16_f16_neon(size_t n_idx, size_t k); +/// Get the row stride in bytes to the packed RHS matrix +/// +/// @param[in] k Number of columns. +/// +/// @return The row stride in bytes to the packed RHS matrix. +size_t kai_get_rhs_packed_stride_rhs_pack_kxn_bf16p12x4biasf16_f16_neon(size_t k); /// Gets the size in bytes of the packed RHS buffer. /// /// @param[in] n Number of rows. 
diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p12x4biasf32_f16_neon.c b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p12x4biasf32_f16_neon.c index 2c4d5e5c9b7d8b1b64655cf5dacc77bcb5492234..db2915258635ae4c090c42befe6473474763c295 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p12x4biasf32_f16_neon.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p12x4biasf32_f16_neon.c @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -9,6 +9,8 @@ #error This file must be compiled for AArch64, FEAT_BF16, FEAT_FP16. #else // Architectural features check. +#include "kai_rhs_pack_kxn_bf16p12x4biasf32_f16_neon.h" + #include #include #include @@ -61,7 +63,7 @@ void kai_run_rhs_pack_kxn_bf16p12x4biasf32_f16_neon( const void* in = rhs; void* out = rhs_packed; const size_t in_stride = rhs_stride; - uint16_t* pad_row = (uint16_t*)rhs; + const uint16_t* pad_row = (const uint16_t*)rhs; // Fill zeros if bias is nullptr size_t bias_step = nr * sizeof(float); @@ -72,7 +74,7 @@ void kai_run_rhs_pack_kxn_bf16p12x4biasf32_f16_neon( bias_step = 0; } - const void* bias_ptr = bias == NULL ? (void*)zero_bias : (void*)bias; + const void* bias_ptr = bias == NULL ? 
(void*)zero_bias : (const void*)bias; size_t out_stride = kai_nr * kai_roundup(height, kai_kr) * sizeof(uint16_t) + kai_nr * sizeof(uint32_t); diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.c b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.c index 61b8ba480ec595d630909b915f749a3ad3be6d2b..d85533d9ee920b4290e831acb560d5ddc3e7e3b7 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.c @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -8,6 +8,8 @@ #error This file must be compiled for AArch64. #else // Architectural features check. +#include "kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.h" + #include #include diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.c b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.c index 5ce709fa256fba153637aaee3be76c991b64bb9f..55e4947460fb8493cd36b63f72c68c1c16f4755f 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.c @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -8,6 +8,8 @@ #error This file must be compiled for AArch64, FEAT_SVE2. #else // Architectural features check. 
+#include "kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h" + #include #include diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.c b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.c index 0c6b0074b2c751c414112861a2864d7a34604dee..afa3d8b597cd510e9116933a14c26da1b62ddbb6 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.c @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -8,6 +8,8 @@ #error This file must be compiled for AArch64. #else // Architectural features check. +#include "kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.h" + #include #include diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.c index b8fd3fdcada4f4312066395d1cf1022776d1dc22..29870b4302a4801ee8cc871063bb5668b7def98a 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.c @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -8,6 +8,8 @@ #error This file must be compiled for AArch64, FEAT_SVE2. #else // Architectural features check. 
+#include "kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.h" + #include #include diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.h b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.h index 8bd3f6ee783104b625724cf524b9e471145ecbe2..bfd5c39385fe58eaf83868c614c7bf75fe81d467 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.h +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.h @@ -42,6 +42,13 @@ size_t kai_get_bias_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(size_t n_id /// @return The offset in bytes to the data element. size_t kai_get_rhs_packed_offset_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(size_t n_idx, size_t k); +/// Get the row stride in bytes to the packed RHS matrix +/// +/// @param[in] k The number of columns in the RHS matrix (not packed). +/// +/// @return the stride in bytes to the packed RHS matrix +size_t kai_get_rhs_packed_stride_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(size_t k); + /// Gets the size in bytes of the packed RHS buffer. /// /// @param[in] n Number of rows. 
diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c index 419f5cd0e62812e1d5000b3defac7d6507302ba0..cb32cdc2184b2f5107c4c6ddb9665f318f76b35f 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h index f06e736cb61ed64cdc6858127cc169111ad09b8b..b4da82ca09fcf2ab7dc6c18cdd1317675608bde9 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -50,6 +50,21 @@ size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0( size_t kr, // size_t bl); // +/// Gets the stride in bytes of the quantized and packed RHS matrix. +/// +/// @param[in] k The number of columns in the RHS matrix (not packed). +/// @param[in] nr The number of columns written by the matmul micro-kernel +/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. +/// @param[in] bl The block length, which defines the number of K values stored in a single block. It must be a multiple +/// of 32. 
+/// +/// @return the stride in bytes of the packed RHS matrix +size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0( + size_t k, // + size_t nr, // + size_t kr, // + size_t bl); // + /// Gets the size in bytes for the quantized and packed RHS matrix. /// /// @param[in] n The number of rows in the RHS matrix (not packed) diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme.c index 4a4714414fc7a27d6c2132bb2f913d02195c0823..670dda9da81340c5b3fb119ffc8668b1a593500b 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme.c +++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme.c @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -8,6 +8,8 @@ #error This file must be compiled for AArch64, FEAT_SVE2. #else // Architectural features check. 
+#include "kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme.h" + #include #include diff --git a/kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon.c b/kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon.c index 6711941834312eca4cccc9888427fb572c028043..d22136f068e316dd4000c80183dabc4cfe781210 100644 --- a/kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon.c +++ b/kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon.c @@ -1,5 +1,5 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // @@ -10,6 +10,8 @@ #define MAX_NR 12 +#include "kai_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon.h" + #include #include #include @@ -21,9 +23,8 @@ static const size_t kai_nr = 12; static const size_t kai_kr = 4; static const size_t kai_sr = 1; -size_t kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon(size_t nr) { - KAI_ASSUME(kai_nr == nr); - return nr; +size_t kai_get_n_step_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon(void) { + return kai_nr; } size_t kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon(size_t n_idx) { @@ -66,7 +67,7 @@ void kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon( size_t height = k; const size_t width = n; - const void* in = (void*)rhs; + const void* in = (const void*)rhs; void* out = rhs_packed; const size_t in_stride = rhs_stride; const float* pad_row = rhs; @@ -80,7 +81,7 @@ void kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon( bias_step = 0; } - const void* bias_ptr = bias == NULL ? (void*)zero_bias : (void*)bias; + const void* bias_ptr = bias == NULL ? (void*)zero_bias : (const void*)bias; const size_t out_stride = nr * kai_roundup(height, kr) * sizeof(uint16_t) + nr * sizeof(uint32_t);