diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index b9c625c7332a229f0e2d9e51c0df152c7f40de07..e8b0b2f352bedf71aa3b39cd03046bcc6c847d60 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -108,6 +108,7 @@ build-examples:
matrix:
- EXAMPLE:
- matmul_clamp_f16_f16_f16p
+ - matmul_clamp_f32_bf16p_bf16p
- matmul_clamp_f32_qai8dxp_qsi4cxp
- matmul_clamp_f32_qsi8d32p_qsi4c32p
- matmul_clamp_f32_qai8dxp_qsi4c32p
@@ -130,6 +131,7 @@ test-examples:
matrix:
- EXAMPLE:
- matmul_clamp_f16_f16_f16p
+ - matmul_clamp_f32_bf16p_bf16p
- matmul_clamp_f32_qai8dxp_qsi4cxp
- matmul_clamp_f32_qsi8d32p_qsi4c32p
- matmul_clamp_f32_qai8dxp_qsi4c32p
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 642fce8d016c1163d42bc1b427e98102f37d7bae..4d846759531622d4d0962a8b8690fed45239e9d4 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -88,6 +88,12 @@ set(KLEIDIAI_FILES_NEON_FP16
kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.c
)
+set(KLEIDIAI_FILES_NEON_BF16
+ kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p_f32_neon.c
+ kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.c
+ kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.c
+)
+
set(KLEIDIAI_FILES_NEON
kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.c
kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.c
@@ -137,6 +143,7 @@ target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_SCALAR})
if((CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") AND NOT MSVC)
target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON})
target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_FP16})
+ target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_BF16})
target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_DOTPROD})
target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_NEON_I8MM})
target_sources(kleidiai PRIVATE ${KLEIDIAI_FILES_SME})
@@ -145,6 +152,7 @@ if((CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL
set_source_files_properties(${KLEIDIAI_FILES_SCALAR} PROPERTIES COMPILE_OPTIONS -march=armv8-a${KLEIDIAI_INTERNAL_EXTRA_ARCH})
set_source_files_properties(${KLEIDIAI_FILES_NEON} PROPERTIES COMPILE_OPTIONS -march=armv8-a${KLEIDIAI_INTERNAL_EXTRA_ARCH})
set_source_files_properties(${KLEIDIAI_FILES_NEON_FP16} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+fp16${KLEIDIAI_INTERNAL_EXTRA_ARCH})
+ set_source_files_properties(${KLEIDIAI_FILES_NEON_BF16} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+bf16${KLEIDIAI_INTERNAL_EXTRA_ARCH})
set_source_files_properties(${KLEIDIAI_FILES_NEON_DOTPROD} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+dotprod${KLEIDIAI_INTERNAL_EXTRA_ARCH})
set_source_files_properties(${KLEIDIAI_FILES_NEON_I8MM} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+i8mm${KLEIDIAI_INTERNAL_EXTRA_ARCH})
set_source_files_properties(${KLEIDIAI_FILES_SME} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+sve+sve2${KLEIDIAI_INTERNAL_EXTRA_ARCH})
@@ -170,6 +178,7 @@ if(KLEIDIAI_BUILD_TESTS)
test/common/printer.cpp
test/common/int4.cpp
test/common/compare.cpp
+ test/common/matmul_test_common.cpp
test/common/matrix_portion.cpp
test/common/rect.cpp
test/common/round.cpp
@@ -205,6 +214,7 @@ if(KLEIDIAI_BUILD_TESTS)
test/tests/matmul_clamp_f32_qai8dxp_qsi4cxp_test.cpp
test/tests/matmul_clamp_f32_qsi8d32p_qsi4c32p_test.cpp
test/tests/matmul_clamp_f32_qai8dxp_qsi4c32p_test.cpp
+ test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp
)
target_link_libraries(kleidiai_test
diff --git a/README.md b/README.md
index a139d11b1bd7cfc3d05746b4cd42d32ede52482c..2d963d480dc16346ea1ca854c5523f635636288d 100644
--- a/README.md
+++ b/README.md
@@ -101,6 +101,7 @@ Some of the data types currently supported with the KleidiAI library are the fol
|---------------------------------------------------------------------------------------------------------------------| ----------- | ----------- |
| Floating-point 32-bit | f32 | |
| Floating-point 16-bit | f16 | |
+| Brain Floating-point 16-bit | bf16 | |
| Quantized (q) Symmetric (s) Signed (i) 4-bit (4) Per-Channel (cx) quantization parameters | qsi4cx | An fp32 multiplier shared among all values of the same channel. `x` denotes the entirety of the channel |
| Quantized (q) Asymmetric (a) Signed (i) 8-bit (8) Per-Dimension (dx) (for example, Per-Row) quantization parameters | qai8dx | An fp32 multiplier and a int32 zero offset shared among all values of the same dimension. |
@@ -177,6 +178,20 @@ Some of the data types currently supported with the KleidiAI library are the fol
+
+ Matrix-multiplication with LHS packed and RHS packed matrices |
+ matmul_clamp_f32_bf16p_bf16p |
+
+ LHS: bf16p
+ RHS: bf16p
+ DST: f32
+ |
+
+ |
+
+ The packing functions for the LHS and RHS matrices are listed in the header file of the GEMM micro-kernel.
+ |
+
How to build
diff --git a/examples/matmul_clamp_f32_bf16p_bf16p/CMakeLists.txt b/examples/matmul_clamp_f32_bf16p_bf16p/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4b13a183a140c6ec1179ece1befa9c029010f318
--- /dev/null
+++ b/examples/matmul_clamp_f32_bf16p_bf16p/CMakeLists.txt
@@ -0,0 +1,36 @@
+#
+# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+
+cmake_minimum_required(VERSION 3.16)
+
+project(KleidiAI LANGUAGES C CXX)
+
+set(CMAKE_CXX_STANDARD 17)
+set(KLEIDIAI_PATH "${CMAKE_CURRENT_SOURCE_DIR}/../..")
+set(MATMUL_PACK_PATH "${KLEIDIAI_PATH}/kai/ukernels/matmul/pack")
+set(MATMUL_PATH "${KLEIDIAI_PATH}/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p")
+
+# Files required to build the executable: the example driver, the micro-kernel and both packing routines
+add_executable(matmul_clamp_f32_bf16p_bf16p
+    matmul_clamp_f32_bf16p_bf16p.cpp
+    ${MATMUL_PATH}/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.c
+    ${MATMUL_PACK_PATH}/kai_lhs_quant_pack_bf16p_f32_neon.c
+    ${MATMUL_PACK_PATH}/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.c
+)
+
+# KleidiAI include directories, scoped to the example target instead of the whole directory
+target_include_directories(matmul_clamp_f32_bf16p_bf16p PRIVATE
+    ${KLEIDIAI_PATH}
+    ${MATMUL_PACK_PATH}
+    ${MATMUL_PATH})
+
+target_compile_options(matmul_clamp_f32_bf16p_bf16p
+    PRIVATE -march=armv8.2-a+bf16
+)
+
+target_compile_definitions(matmul_clamp_f32_bf16p_bf16p
+    PRIVATE $<$<CONFIG:Debug>:KAI_DEBUG>
+)
diff --git a/examples/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p.cpp b/examples/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..f11900c3a78dce544f1e4b457ae7e2d4c1fb1bcb
--- /dev/null
+++ b/examples/matmul_clamp_f32_bf16p_bf16p/matmul_clamp_f32_bf16p_bf16p.cpp
@@ -0,0 +1,321 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+// Example usage for matrix multiplication of two half-precision brain floating-point (BF16) matrices
+// and the accumulation of the result into an FP32 destination matrix.
+//
+// The activations and the weights, stored in the LHS and RHS matrices respectively, are both non-transposed matrices.
+// The matrix multiplication computation is performed using BF16 matrix multiply (BFMMLA)
+// vector instructions present in the FEAT_BF16 Arm® architecture feature.
+//
+#if !defined(__aarch64__) || !defined(__ARM_FEATURE_BF16_SCALAR_ARITHMETIC) || \
+ !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC)
+#error This file must be compiled for AArch64, FEAT_BF16.
+#else
+#include <arm_neon.h>
+
+#include <algorithm>
+#include <cfloat>
+#include <chrono>
+#include <cmath>
+#include <cstddef>
+#include <cstdint>
+#include <cstring>
+#include <iomanip>
+#include <iostream>
+
+// Include micro-kernel variants
+#include "kai/kai_common.h"
+#include "kai_lhs_quant_pack_bf16p_f32_neon.h"
+#include "kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.h"
+#include "kai_matmul_clamp_f32_bf16p_bf16p_interface.h"
+#include "kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.h"
+
+/// Widen the BF16 bit pattern stored at `v` to FP32 via kai_cast_f32_bf16.
+/// @param v pointer to one 16-bit BF16 value; must not be null
+/// @return the value converted to float
+inline static float bf16_to_float(const uint16_t* v) {
+    return kai_cast_f32_bf16(*v);
+}
+
+namespace {
+/// Micro-kernel interface
+///
+/// Function table that binds the generic matmul_clamp_f32_bf16p_bf16p interface to the
+/// 12x4b_8x12x4 NEON MMLA variant. main() uses it through get_mr(), get_nr(), get_kr(),
+/// get_sr() and run_matmul().
+/// NOTE(review): the initializer order must match the member order declared in
+/// kai_matmul_clamp_f32_bf16p_bf16p_interface.h (not visible here) - confirm when adding entries.
+constexpr kai_matmul_clamp_f32_bf16p_bf16p_ukernel ukernel{
+ kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla,
+ kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla,
+ kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla,
+ kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla,
+ kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla,
+ kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla,
+ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla,
+ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla,
+ kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla,
+ kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla,
+ kai_run_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla};
+
+/// @brief Truncate the 32-bit floating point number's least significant 16 mantissa bits
+///
+/// Keeps the sign, the exponent and the 7 most significant mantissa bits of the binary32
+/// encoding and zeroes the rest; no rounding is applied.
+///
+/// @param x floating-point number
+/// @return truncated floating-point number
+inline static float truncate(float x) {
+    // Copy the representation instead of casting pointers: accessing a float object through a
+    // uint32_t lvalue (and vice versa) violates the strict-aliasing rule.
+    uint32_t bits = 0;
+    std::memcpy(&bits, &x, sizeof(bits));
+    bits &= 0xffff0000U;
+
+    float truncated = 0.0F;
+    std::memcpy(&truncated, &bits, sizeof(truncated));
+    return truncated;
+}
+
+/// Reference implementation of matrix multiplication
+///
+/// Computes dst = clamp(lhs * rhs + bias, scalar_min, scalar_max) in FP32, where every LHS and
+/// RHS element is first passed through truncate(). The bias is used as-is.
+///
+/// @param m          rows of lhs and dst
+/// @param n          columns of rhs and dst, and length of bias
+/// @param k          columns of lhs and rows of rhs
+/// @param lhs        row-major m x k matrix
+/// @param rhs        row-major k x n matrix
+/// @param bias       n values, one per output column
+/// @param dst        row-major m x n output matrix
+/// @param scalar_min lower clamp bound
+/// @param scalar_max upper clamp bound
+static void run_matmul_ref(
+    size_t m, size_t n, size_t k, const float* lhs, const float* rhs, const float* bias, float* dst, float scalar_min,
+    float scalar_max) {
+    for (size_t r = 0; r < m; ++r) {
+        const size_t lhs_row_base = r * k;
+        const size_t dst_row_base = r * n;
+
+        for (size_t c = 0; c < n; ++c) {
+            // Each accumulator starts from the bias of its output column
+            float acc = bias[c];
+
+            for (size_t depth = 0; depth < k; ++depth) {
+                const float lhs_val = truncate(lhs[lhs_row_base + depth]);
+                const float rhs_val = truncate(rhs[c + n * depth]);
+
+                acc += lhs_val * rhs_val;
+            }
+
+            dst[dst_row_base + c] = std::clamp(acc, scalar_min, scalar_max);
+        }
+    }
+}
+
+/// Fills the matrix with incremental values
+///
+/// Element i (0-based, in storage order) receives (i + 1) * weight.
+/// @param num_rows number of rows
+/// @param num_cols number of columns
+/// @param dst      buffer of num_rows * num_cols floats to write
+/// @param weight   multiplier applied to the 1-based element index
+void fill_matrix(size_t num_rows, size_t num_cols, float* dst, const float weight) {
+    const size_t count = num_rows * num_cols;
+    for (size_t idx = 0; idx < count; ++idx) {
+        const size_t ordinal = idx + 1;
+        dst[idx] = float(ordinal * weight);
+    }
+}
+
+/// Print the matrix
+///
+/// Writes a row-major FP32 matrix to std::cout as a bracketed list, two decimals per value.
+/// @param num_rows number of rows
+/// @param num_cols number of columns
+/// @param name     label printed before the values
+/// @param src      num_rows * num_cols floats
+void print_matrix(size_t num_rows, size_t num_cols, const char* name, const float* src) {
+    std::cout << name << " = [\n";
+    for (size_t row = 0; row < num_rows; ++row) {
+        const float* row_ptr = src + row * num_cols;
+        std::cout << " [";
+        for (size_t col = 0; col < num_cols; ++col) {
+            std::cout << std::setprecision(2) << std::fixed << row_ptr[col] << ", ";
+        }
+        std::cout << ("],\n");
+    }
+    std::cout << ("]\n\n");
+}
+
+/// Print a row-major matrix of BF16 values; each element is widened with bf16_to_float
+/// and written to std::cout with two decimals.
+/// @param num_rows number of rows
+/// @param num_cols number of columns
+/// @param name     label printed before the values
+/// @param src      num_rows * num_cols 16-bit BF16 values
+void print_matrix(size_t num_rows, size_t num_cols, const char* name, const uint16_t* src) {
+    std::cout << name << " = [\n";
+    for (size_t row = 0; row < num_rows; ++row) {
+        const uint16_t* row_ptr = src + row * num_cols;
+        std::cout << " [";
+        for (size_t col = 0; col < num_cols; ++col) {
+            std::cout << std::setprecision(2) << std::fixed << bf16_to_float(row_ptr + col) << ", ";
+        }
+        std::cout << ("],\n");
+    }
+    std::cout << ("]\n\n");
+}
+
+/// Print a byte buffer whose rows mix FP32 and BF16 values
+///
+/// Every row starts `stride` bytes after the previous one. The first `nr` columns of a row are
+/// read as FP32 values; the remaining columns are read as BF16 values stored right after them.
+///
+/// @param num_rows number of rows to print
+/// @param num_cols number of values per row (FP32 and BF16 together)
+/// @param name     label printed before the values
+/// @param src      start of the buffer
+/// @param nr       number of leading FP32 values per row; expected to be non-negative
+/// @param stride   distance in bytes between consecutive rows
+void print_mixed_prec_matrix(
+    size_t num_rows, size_t num_cols, const char* name, const uint8_t* src, int nr, int stride) {
+    // Convert once so the column comparison below is unsigned vs. unsigned
+    const size_t num_f32 = static_cast<size_t>(nr);
+
+    std::cout << name << " = [\n";
+    for (size_t y = 0; y < num_rows; ++y) {
+        std::cout << " [";
+        const uint8_t* src_row = src + stride * y;
+        for (size_t x = 0; x < num_cols; ++x) {
+            if (x >= num_f32) {
+                // print bfloat
+                const uint16_t* src_elm = reinterpret_cast<const uint16_t*>(
+                    src_row + num_f32 * sizeof(float) + (x - num_f32) * sizeof(uint16_t));
+                std::cout << std::setprecision(2) << std::fixed << bf16_to_float(src_elm) << ", ";
+            } else {
+                // print float
+                const float* src_elm = reinterpret_cast<const float*>(src_row + x * sizeof(float));
+                std::cout << std::setprecision(2) << std::fixed << *src_elm << ", ";
+            }
+        }
+        std::cout << ("],\n");
+    }
+    std::cout << ("]\n\n");
+}
+
+/// Print an FP32 matrix as it looks after truncate(): each element has its low 16 bits
+/// cleared before being written to std::cout with two decimals.
+/// @param num_rows number of rows
+/// @param num_cols number of columns
+/// @param name     label printed before the values
+/// @param src      num_rows * num_cols floats
+void print_bf_matrix(size_t num_rows, size_t num_cols, const char* name, const float* src) {
+    std::cout << name << " = [\n";
+    for (size_t row = 0; row < num_rows; ++row) {
+        const float* row_ptr = src + row * num_cols;
+        std::cout << " [";
+        for (size_t col = 0; col < num_cols; ++col) {
+            const float shown = truncate(row_ptr[col]);
+            std::cout << std::setprecision(2) << std::fixed << shown << ", ";
+        }
+        std::cout << ("],\n");
+    }
+    std::cout << ("]\n\n");
+}
+
+/// Verify the micro-kernel output matches the reference implementation
+///
+/// An element passes when |ref - act| / (|act| + 1e-10) <= rel_tolerance. Every failing element
+/// is reported on std::cout with its row and column.
+///
+/// @param num_rows      number of rows of both matrices
+/// @param num_cols      number of columns of both matrices
+/// @param rel_tolerance maximum accepted relative error
+/// @param ref           reference values, num_rows * num_cols floats
+/// @param act           actual values, num_rows * num_cols floats
+/// @return true when every element is within tolerance
+bool is_output_correct(
+    size_t num_rows, size_t num_cols, const float rel_tolerance, const float* ref, const float* act) {
+    bool is_valid = true;
+
+    for (size_t i = 0; i < num_rows * num_cols; ++i) {
+        // Normalize by the magnitude of the actual value: dividing by the signed value makes the
+        // ratio negative for negative outputs, which can never exceed the tolerance.
+        const double rel_error = std::fabs(ref[i] - act[i]) / (std::fabs(act[i]) + 1e-10);
+
+        // Written as a negated "<=" so that a NaN error is reported as a mismatch
+        if (!(rel_error <= rel_tolerance)) {
+            const size_t x = i % num_cols;
+            const size_t y = i / num_cols;
+
+            std::cout << std::setprecision(5) << std::fixed << "ERROR![" << y << "][" << x << "]: ref=" << ref[i]
+                      << " vs. act=" << act[i] << "\n";
+
+            is_valid = false;
+        }
+    }
+    return is_valid;
+}
+} // namespace
+
+/// Example driver: builds FP32 inputs, computes a reference result, packs both operands to BF16,
+/// runs the micro-kernel and compares its output against the reference.
+/// @return 0 when the micro-kernel output is within tolerance of the reference, 1 otherwise
+int main() {
+    // Parameters of the matrix multiplication. Change these values to see how the micro-kernels operate on different
+    // sized matrices
+    const size_t M = 25;   // Rows of LHS and DST matrices
+    const size_t N = 28;   // Columns of RHS and DST matrices, and length of the Bias vector.
+    const size_t K = 117;  // Columns of LHS, rows of RHS matrices
+
+    const size_t lhs_size = M * K;
+    const size_t rhs_size = N * K;
+    const size_t bias_size = N;
+    const size_t dst_size = M * N;
+
+    // Allocate the memory
+    float* lhs = new float[lhs_size];
+    float* rhs = new float[rhs_size];
+    float* bias = new float[bias_size];
+
+    fill_matrix(M, K, lhs, 0.1);
+    fill_matrix(K, N, rhs, 0.1);
+    fill_matrix(1, N, bias, 1);
+
+#ifdef KAI_DEBUG
+    // std::cout << "Floats: " << std::endl;
+    print_matrix(M, K, "lhs", lhs);
+    print_matrix(K, N, "rhs", rhs);
+    print_matrix(1, N, "bias", bias);
+
+    // Print bf16 converted values
+    print_bf_matrix(M, K, "lhs_bf", lhs);
+    print_bf_matrix(K, N, "rhs_bf", rhs);
+#endif  // KAI_DEBUG
+
+    //----------- REFERENCE IMPLEMENTATION
+    //------------------------------------
+    //------------------------------------
+    float* dst_ref = new float[dst_size];
+
+    run_matmul_ref(
+        M, N, K,           // Dimensions
+        lhs,               // LHS buffer
+        rhs,               // RHS buffer
+        bias,              // Bias buffer
+        dst_ref,           // DST
+        -FLT_MAX, FLT_MAX  // Min and max for the clamp operation
+    );
+    //----------- END REFERENCE IMPLEMENTATION
+    //------------------------------------
+    //------------------------------------
+
+    //----------- MICRO-KERNELS TESTS
+    //------------------------------------
+    //------------------------------------
+    const size_t mr = ukernel.get_mr();
+    const size_t nr = ukernel.get_nr();
+    const size_t kr = ukernel.get_kr();
+    const size_t sr = ukernel.get_sr();
+
+    // In a single row, we pack nr bias values followed by K rows of nr RHS values
+    const size_t rhs_packed_size = kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon(N, K, nr, kr);
+    uint8_t* rhs_packed = new uint8_t[rhs_packed_size];
+
+    const size_t lhs_stride = K * sizeof(float);
+    const size_t rhs_stride = N * sizeof(float);
+    const size_t dst_stride_row = N * sizeof(float);
+    const size_t dst_stride_col = sizeof(float);
+
+    // NOTE(review): lhs_packed_size is used as a count of uint16_t elements here. If the getter
+    // returns a size in bytes (the RHS size above is used as a byte count), this buffer is twice
+    // as large as needed, which is harmless - confirm against the packing header.
+    const size_t lhs_packed_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p_f32_neon(M, K, mr, kr, sr);
+    uint16_t* lhs_packed = new uint16_t[lhs_packed_size];
+
+    // Packing only needs to be performed once if the contents of the bias and RHS matrices are expected to be constant.
+    kai_run_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon(
+        1, N, K, nr, kr, sr,  // Packing arguments
+        rhs_stride,           // RHS stride
+        rhs,                  // RHS
+        bias,                 // Bias
+        NULL,                 // Scale
+        rhs_packed,           // RHS packed
+        0, NULL);
+
+    // The RHS and Bias buffers can be freed after packing, however we reuse them for the reference test below
+
+#ifdef KAI_DEBUG
+    const size_t rhs_packed_cols = nr + kai_roundup(K, kr) * nr;
+
+    // Each col has nr floats and then K*nr bfloats
+    int rhs_packed_stride = nr * sizeof(float) + kai_roundup(K, kr) * nr * sizeof(uint16_t);
+    const size_t rhs_packed_rows = rhs_packed_size / rhs_packed_stride;
+
+    print_mixed_prec_matrix(rhs_packed_rows, rhs_packed_cols, "rhs_packed", rhs_packed, nr, rhs_packed_stride);
+#endif  // KAI_DEBUG
+
+    float* dst = new float[dst_size];
+
+    // The timed region covers both the LHS packing and the matrix multiplication
+    const auto timer_matmul_start = std::chrono::high_resolution_clock::now();
+
+    kai_run_lhs_quant_pack_bf16p_f32_neon(M, K, mr, kr, sr, 0 /* m_idx_start */, lhs, lhs_stride, lhs_packed);
+
+    ukernel.run_matmul(
+        M, N, K,           // Dimensions
+        lhs_packed,        // LHS packed
+        rhs_packed,        // RHS packed
+        dst,               // DST
+        dst_stride_row,    // DST stride (row)
+        dst_stride_col,    // DST stride (col)
+        -FLT_MAX, FLT_MAX  // Min and max for the clamp operation
+    );
+
+    const auto timer_matmul_end = std::chrono::high_resolution_clock::now();
+    // Nanoseconds, to match the "ns" unit printed in the report below
+    const auto time_matmul =
+        std::chrono::duration_cast<std::chrono::nanoseconds>(timer_matmul_end - timer_matmul_start);
+
+    int ret = 0;
+
+#ifdef KAI_DEBUG
+    int num_lhs_rows = (M + mr - 1) / mr;
+    int num_lhs_cols = mr * kai_roundup(K, kr);
+
+    print_matrix(num_lhs_rows, num_lhs_cols, "lhs_packed", lhs_packed);
+    print_matrix(M, N, "dst", dst);
+    print_matrix(M, N, "ref", dst_ref);
+#endif  // KAI_DEBUG
+
+    constexpr float rel_tolerance = 0.02;  // This value was chosen by experimentation
+    const bool is_valid = is_output_correct(M, N, rel_tolerance, dst_ref, dst);
+
+    std::cout << "TEST[matmul_clamp_f32_bf16p_bf16p]\n";
+    std::cout << "- ukernel: matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla\n";
+    if (is_valid) {
+        std::cout << "- Status: PASSED\n";
+        std::cout << "- Performance: " << time_matmul.count() << "ns\n";
+    } else {
+        std::cout << "- Status: FAILED\n";
+        ret = 1;
+    }
+
+    //----------- END MICRO-KERNELS TESTS
+    //------------------------------------
+    //------------------------------------
+
+    delete[] lhs;
+    delete[] rhs;
+    delete[] bias;
+    delete[] lhs_packed;
+    delete[] rhs_packed;
+    delete[] dst;
+    delete[] dst_ref;
+
+    return ret;
+}
+
+#endif // Architectural features check.
diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel
index 7bcbc1a8cbffd30e9a44c39bc4bb7f2fa1b4fdcf..66a3c3868a97e70828efb60ef3f84d724c0e8ef2 100644
--- a/kai/ukernels/matmul/BUILD.bazel
+++ b/kai/ukernels/matmul/BUILD.bazel
@@ -7,6 +7,7 @@
load(
"//:kai_defs.bzl",
"kai_c_library",
+ "kai_cpu_bf16",
"kai_cpu_dotprod",
"kai_cpu_fp16",
"kai_cpu_i8mm",
@@ -32,6 +33,22 @@ kai_c_library(
],
)
+kai_c_library(
+ name = "clamp_f32_bf16p_bf16p_interface",
+ hdrs = ["matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h"],
+ cpu_uarch = kai_cpu_bf16(),
+)
+
+kai_c_library(
+ name = "clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla",
+ srcs = ["matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.c"],
+ hdrs = ["matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.h"],
+ cpu_uarch = kai_cpu_bf16(),
+ deps = [
+ ":clamp_f32_bf16p_bf16p_interface",
+ ],
+)
+
kai_c_library(
name = "clamp_f32_f32_f32p",
srcs = ["matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.c"],
@@ -159,6 +176,13 @@ kai_c_library(
cpu_uarch = kai_cpu_sme(),
)
+kai_c_library(
+ name = "lhs_quant_pack_bf16p_f32_neon",
+ srcs = ["pack/kai_lhs_quant_pack_bf16p_f32_neon.c"],
+ hdrs = ["pack/kai_lhs_quant_pack_bf16p_f32_neon.h"],
+ cpu_uarch = kai_cpu_bf16(),
+)
+
kai_c_library(
name = "rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon",
srcs = ["pack/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.c"],
@@ -166,6 +190,13 @@ kai_c_library(
cpu_uarch = kai_cpu_fp16(),
)
+kai_c_library(
+ name = "rhs_quant_pack_kxn_bf16pbiasf32_f32_neon",
+ srcs = ["pack/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.c"],
+ hdrs = ["pack/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.h"],
+ cpu_uarch = kai_cpu_bf16(),
+)
+
kai_c_library(
name = "rhs_pack_kxn_f32pbiasf32_f32_f32_neon",
srcs = ["pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon.c"],
@@ -306,6 +337,7 @@ kai_c_library(
name = "matmul",
deps = [
":clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla",
+ ":clamp_f32_bf16p_bf16p12x1biasf32_8x12x4_neon_mmla",
":clamp_f32_f32_f32p",
":clamp_f32_f32_f32pb_1x16vl_sme2_mla",
":clamp_f32_f32p_f32p",
@@ -327,6 +359,7 @@ kai_c_library(
":clamp_f32_qsi8d32p_qsi4c32p_i8mm",
":clamp_f32_qsi8d32p_qsi4c32p_interface",
":lhs_pack_f32p2vlx1_f32_sme",
+ ":lhs_quant_pack_bf16p_f32_neon",
":lhs_quant_pack_qai8dxp_f32",
":lhs_quant_pack_qsi8d32p_f32",
":rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon",
@@ -338,5 +371,6 @@ kai_c_library(
":rhs_pack_nxk_qsi4c32p_qsu4c32s1s0",
":rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0",
":rhs_pack_nxk_qsi4cxp_qs4cxs1s0",
+ ":rhs_quant_pack_kxn_bf16pbiasf32_f32_neon",
],
)
diff --git a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.c b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.c
new file mode 100644
index 0000000000000000000000000000000000000000..929e3753a90c40bb64072e34f2671d97b211bc53
--- /dev/null
+++ b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.c
@@ -0,0 +1,578 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#if !defined(__aarch64__) || !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC)
+#error This file must be compiled for AArch64, FEAT_BF16.
+#else // Architectural features check.
+
+#include "kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.h"
+
+#include
+#include
+#include
+
+#include "kai/kai_common.h"
+
+// Blocking parameters of this micro-kernel variant (the "8x12x4" in its name).
+// kai_mr and kai_nr are also the alignment required of m_idx / n_idx by the offset helpers below,
+// and K is rounded up to a multiple of kai_kr wherever a packed size or offset is computed.
+static const size_t kai_mr = 8;
+static const size_t kai_nr = 12;
+static const size_t kai_kr = 4;
+static const size_t kai_sr = 1;
+
+// Returns the step in the M dimension, which equals kai_mr (8).
+// NOTE(review): presumably the granularity callers must use for m_idx - the offset helpers
+// assert m_idx % kai_mr == 0.
+size_t kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void) {
+ return kai_mr;
+}
+
+// Returns the step in the N dimension, which equals kai_nr (12).
+// NOTE(review): presumably the granularity callers must use for n_idx - the offset helpers
+// assert n_idx % kai_nr == 0.
+size_t kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void) {
+ return kai_nr;
+}
+
+// Returns the mr blocking parameter (8); the example passes it to the LHS packing routine.
+size_t kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void) {
+ return kai_mr;
+}
+
+// Returns the nr blocking parameter (12); the example passes it to the RHS packing routine.
+size_t kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void) {
+ return kai_nr;
+}
+
+// Returns the kr blocking parameter (4); K is rounded up to a multiple of it in packed buffers.
+size_t kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void) {
+ return kai_kr;
+}
+
+// Returns the sr blocking parameter (1); the example forwards it to both packing routines.
+size_t kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void) {
+ return kai_sr;
+}
+
+// Returns the offset in bytes of row m_idx inside the packed LHS buffer.
+// m_idx must be a multiple of kai_mr; each packed row holds K rounded up to kai_kr 16-bit values.
+size_t kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(size_t m_idx, size_t k) {
+    KAI_ASSUME(m_idx % kai_mr == 0);
+
+    const size_t k_padded = kai_roundup(k, kai_kr);
+    const size_t row_bytes = k_padded * sizeof(uint16_t);
+
+    return m_idx * row_bytes;
+}
+
+// Returns the offset in bytes of column n_idx inside the packed RHS buffer.
+// n_idx must be a multiple of kai_nr; each packed column accounts for one float followed by
+// K rounded up to kai_kr 16-bit values.
+size_t kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(size_t n_idx, size_t k) {
+    KAI_ASSUME(n_idx % kai_nr == 0);
+
+    const size_t k_padded = kai_roundup(k, kai_kr);
+    const size_t column_bytes = sizeof(float) + k_padded * sizeof(uint16_t);
+
+    return n_idx * column_bytes;
+}
+
+// Returns the offset in bytes of element (m_idx, n_idx) inside the FP32 destination matrix.
+// m_idx must be a multiple of kai_mr and n_idx a multiple of kai_nr; stride is the distance
+// in bytes between consecutive destination rows.
+size_t kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(
+    size_t m_idx, size_t n_idx, size_t stride) {
+    KAI_ASSUME(m_idx % kai_mr == 0);
+    KAI_ASSUME(n_idx % kai_nr == 0);
+
+    const size_t row_offset = m_idx * stride;
+    const size_t col_offset = n_idx * sizeof(float);
+
+    return row_offset + col_offset;
+}
+
+// Returns the size in bytes of an m x n FP32 destination matrix.
+size_t kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(size_t m, size_t n) {
+    const size_t num_elements = m * n;
+
+    return num_elements * sizeof(float);
+}
+
+void kai_run_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(
+ size_t m, size_t n, size_t k, //
+ const void* lhs_packed, //
+ const void* rhs_packed, //
+ void* dst, size_t dst_stride_row, size_t dst_stride_col, //
+ float clamp_min, float clamp_max) {
+ KAI_ASSERT(dst_stride_col == sizeof(float));
+
+ const void* Apanel = lhs_packed;
+ void* Cpanel = dst;
+ size_t ldc = dst_stride_row / sizeof(float);
+
+ size_t M = m;
+
+ typedef struct {
+ float maxval;
+ float minval;
+ size_t N;
+ size_t K;
+ const void* Bpanel;
+ void* output_ptr;
+ } KernelArgs;
+
+ KernelArgs ka;
+
+ ka.N = n;
+ ka.K = kai_roundup(k, kai_kr) / kai_kr - 1;
+
+ ka.Bpanel = rhs_packed;
+
+ // Direct output.
+ ka.output_ptr = dst;
+
+ // Clamping output.
+ ka.maxval = clamp_max;
+ ka.minval = clamp_min;
+
+ __asm__ __volatile__(
+ "1:" // Height loop
+ "add x11, %x[Cpanel], %x[ldc], LSL #2\n"
+ "add x10, %x[Cpanel], %x[ldc], LSL #1\n"
+ "add x9, x11, %x[ldc], LSL #1\n"
+ "cmp %x[M], #0x8\n"
+ "add x28, %x[Cpanel], %x[ldc], LSL #3\n"
+ "add x27, %x[Cpanel], %x[ldc]\n"
+ "add x26, x10, %x[ldc]\n"
+ "add x25, x11, %x[ldc]\n"
+ "add x24, x9, %x[ldc]\n"
+ "bge 2f\n"
+ "cmp %x[M], #0x2\n"
+ "mov x24, %x[Cpanel]\n"
+ "csel x27, x27, %x[Cpanel], GE\n"
+ "csel x10, x10, %x[Cpanel], GT\n"
+ "cmp %x[M], #0x4\n"
+ "csel x26, x26, %x[Cpanel], GE\n"
+ "csel x11, x11, %x[Cpanel], GT\n"
+ "cmp %x[M], #0x6\n"
+ "csel x25, x25, %x[Cpanel], GE\n"
+ "csel x9, x9, %x[Cpanel], GT\n"
+ "2:" // all rows valid
+ "ldr x23, [%x[args_ptr], %[offsetof_N]]\n"
+ "ldr x22, [%x[args_ptr], %[offsetof_Bpanel]]\n"
+ "mov x21, %x[Apanel]\n"
+ "3:" // Width loop
+ "ldr q4, [x22, #0x0]\n"
+ "ldr q5, [x22, #0x10]\n"
+ "mov %x[Apanel], x21\n"
+ "ldr q6, [x22, #0x20]\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_K]]\n"
+ "add x22, x22, #0x30\n"
+ "ldr q7, [x22, #0x0]\n"
+ "ldr q0, [%x[Apanel], #0x0]\n"
+ "ldr q1, [%x[Apanel], #0x10]\n"
+ "zip1 v8.2d, v4.2d, v4.2d\n"
+ "ldr q2, [%x[Apanel], #0x20]\n"
+ "zip2 v11.2d, v4.2d, v4.2d\n"
+ "ldr q4, [x22, #0x10]\n"
+ "zip1 v9.2d, v5.2d, v5.2d\n"
+ "zip2 v12.2d, v5.2d, v5.2d\n"
+ "cmp x20, #0x2\n"
+ "zip1 v10.2d, v6.2d, v6.2d\n"
+ "zip2 v13.2d, v6.2d, v6.2d\n"
+ "prfm pldl1keep, [%x[Apanel], #0x0]\n"
+ "mov v14.16b, v8.16b\n"
+ "mov v17.16b, v11.16b\n"
+ "prfm pldl1keep, [x22, #0x0]\n"
+ "mov v15.16b, v9.16b\n"
+ "mov v18.16b, v12.16b\n"
+ "prfm pldl1keep, [x22, #0x40]\n"
+ "mov v16.16b, v10.16b\n"
+ "mov v19.16b, v13.16b\n"
+ "prfm pldl1keep, [%x[Apanel], #0x40]\n"
+ "mov v20.16b, v8.16b\n"
+ "mov v21.16b, v9.16b\n"
+ "prfm pldl1keep, [x22, #0x80]\n"
+ "mov v22.16b, v10.16b\n"
+ "mov v23.16b, v11.16b\n"
+ "prfm pldl1keep, [%x[Apanel], #0x80]\n"
+ "mov v24.16b, v12.16b\n"
+ "mov v25.16b, v13.16b\n"
+ "prfm pldl1keep, [x22, #0xc0]\n"
+ "mov v26.16b, v8.16b\n"
+ "mov v27.16b, v9.16b\n"
+ "prfm pldl1keep, [x22, #0x100]\n"
+ "mov v28.16b, v10.16b\n"
+ "mov v29.16b, v11.16b\n"
+ "prfm pldl1keep, [%x[Apanel], #0xc0]\n"
+ "mov v30.16b, v12.16b\n"
+ "mov v31.16b, v13.16b\n"
+ "prfm pldl1keep, [x22, #0x140]\n"
+ "add x22, x22, #0x20\n"
+ "add %x[Apanel], %x[Apanel], #0x30\n"
+ "blt 5f\n"
+ "4:" // main loop head
+ "ldr q3, [%x[Apanel], #0x0]\n"
+ "ldr q5, [x22, #0x0]\n"
+ ".inst 0x6e47ec08 // bfmmla v8.4s, v0.8h, v7.8h\n"
+ "ldr q6, [x22, #0x10]\n"
+ ".inst 0x6e44ec0b // bfmmla v11.4s, v0.8h, v4.8h\n"
+ ".inst 0x6e47ec2e // bfmmla v14.4s, v1.8h, v7.8h\n"
+ ".inst 0x6e44ec31 // bfmmla v17.4s, v1.8h, v4.8h\n"
+ ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n"
+ "sub x20, x20, #0x2\n"
+ ".inst 0x6e44ec57 // bfmmla v23.4s, v2.8h, v4.8h\n"
+ ".inst 0x6e47ec7a // bfmmla v26.4s, v3.8h, v7.8h\n"
+ "ldr q7, [x22, #0x20]\n"
+ ".inst 0x6e44ec7d // bfmmla v29.4s, v3.8h, v4.8h\n"
+ "ldr q4, [x22, #0x30]\n"
+ ".inst 0x6e45ec09 // bfmmla v9.4s, v0.8h, v5.8h\n"
+ ".inst 0x6e46ec0c // bfmmla v12.4s, v0.8h, v6.8h\n"
+ ".inst 0x6e45ec2f // bfmmla v15.4s, v1.8h, v5.8h\n"
+ "cmp x20, #0x2\n"
+ ".inst 0x6e46ec32 // bfmmla v18.4s, v1.8h, v6.8h\n"
+ ".inst 0x6e45ec55 // bfmmla v21.4s, v2.8h, v5.8h\n"
+ "prfm pldl1keep, [%x[Apanel], #0x100]\n"
+ ".inst 0x6e46ec58 // bfmmla v24.4s, v2.8h, v6.8h\n"
+ ".inst 0x6e45ec7b // bfmmla v27.4s, v3.8h, v5.8h\n"
+ "ldr q5, [x22, #0x40]\n"
+ ".inst 0x6e46ec7e // bfmmla v30.4s, v3.8h, v6.8h\n"
+ "ldr q6, [x22, #0x50]\n"
+ ".inst 0x6e47ec0a // bfmmla v10.4s, v0.8h, v7.8h\n"
+ ".inst 0x6e44ec0d // bfmmla v13.4s, v0.8h, v4.8h\n"
+ "ldr q0, [%x[Apanel], #0x10]\n"
+ ".inst 0x6e47ec30 // bfmmla v16.4s, v1.8h, v7.8h\n"
+ ".inst 0x6e44ec33 // bfmmla v19.4s, v1.8h, v4.8h\n"
+ "ldr q1, [%x[Apanel], #0x20]\n"
+ ".inst 0x6e47ec56 // bfmmla v22.4s, v2.8h, v7.8h\n"
+ ".inst 0x6e44ec59 // bfmmla v25.4s, v2.8h, v4.8h\n"
+ "ldr q2, [%x[Apanel], #0x30]\n"
+ ".inst 0x6e47ec7c // bfmmla v28.4s, v3.8h, v7.8h\n"
+ "ldr q7, [x22, #0x60]\n"
+ ".inst 0x6e44ec7f // bfmmla v31.4s, v3.8h, v4.8h\n"
+ "ldr q3, [%x[Apanel], #0x40]\n"
+ "ldr q4, [x22, #0x70]\n"
+ ".inst 0x6e45ec08 // bfmmla v8.4s, v0.8h, v5.8h\n"
+ ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n"
+ ".inst 0x6e45ec2e // bfmmla v14.4s, v1.8h, v5.8h\n"
+ ".inst 0x6e46ec31 // bfmmla v17.4s, v1.8h, v6.8h\n"
+ "prfm pldl1keep, [x22, #0x180]\n"
+ ".inst 0x6e45ec54 // bfmmla v20.4s, v2.8h, v5.8h\n"
+ ".inst 0x6e46ec57 // bfmmla v23.4s, v2.8h, v6.8h\n"
+ "prfm pldl1keep, [x22, #0x1c0]\n"
+ ".inst 0x6e45ec7a // bfmmla v26.4s, v3.8h, v5.8h\n"
+ "ldr q5, [x22, #0x80]\n"
+ ".inst 0x6e46ec7d // bfmmla v29.4s, v3.8h, v6.8h\n"
+ "ldr q6, [x22, #0x90]\n"
+ "prfm pldl1keep, [%x[Apanel], #0x140]\n"
+ ".inst 0x6e47ec09 // bfmmla v9.4s, v0.8h, v7.8h\n"
+ "prfm pldl1keep, [x22, #0x200]\n"
+ ".inst 0x6e44ec0c // bfmmla v12.4s, v0.8h, v4.8h\n"
+ ".inst 0x6e47ec2f // bfmmla v15.4s, v1.8h, v7.8h\n"
+ ".inst 0x6e44ec32 // bfmmla v18.4s, v1.8h, v4.8h\n"
+ ".inst 0x6e47ec55 // bfmmla v21.4s, v2.8h, v7.8h\n"
+ ".inst 0x6e44ec58 // bfmmla v24.4s, v2.8h, v4.8h\n"
+ ".inst 0x6e47ec7b // bfmmla v27.4s, v3.8h, v7.8h\n"
+ "ldr q7, [x22, #0xa0]\n"
+ ".inst 0x6e44ec7e // bfmmla v30.4s, v3.8h, v4.8h\n"
+ "ldr q4, [x22, #0xb0]\n"
+ ".inst 0x6e45ec0a // bfmmla v10.4s, v0.8h, v5.8h\n"
+ ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n"
+ "ldr q0, [%x[Apanel], #0x50]\n"
+ ".inst 0x6e45ec30 // bfmmla v16.4s, v1.8h, v5.8h\n"
+ ".inst 0x6e46ec33 // bfmmla v19.4s, v1.8h, v6.8h\n"
+ "ldr q1, [%x[Apanel], #0x60]\n"
+ ".inst 0x6e45ec56 // bfmmla v22.4s, v2.8h, v5.8h\n"
+ ".inst 0x6e46ec59 // bfmmla v25.4s, v2.8h, v6.8h\n"
+ "ldr q2, [%x[Apanel], #0x70]\n"
+ ".inst 0x6e45ec7c // bfmmla v28.4s, v3.8h, v5.8h\n"
+ ".inst 0x6e46ec7f // bfmmla v31.4s, v3.8h, v6.8h\n"
+ "add %x[Apanel], %x[Apanel], #0x80\n"
+ "add x22, x22, #0xc0\n"
+ "bge 4b\n"
+ "5:" // main loop skip
+ "ldr q3, [%x[Apanel], #0x0]\n"
+ "ldr q5, [x22, #0x0]\n"
+ ".inst 0x6e47ec08 // bfmmla v8.4s, v0.8h, v7.8h\n"
+ "ldr q6, [x22, #0x10]\n"
+ ".inst 0x6e44ec0b // bfmmla v11.4s, v0.8h, v4.8h\n"
+ ".inst 0x6e47ec2e // bfmmla v14.4s, v1.8h, v7.8h\n"
+ ".inst 0x6e44ec31 // bfmmla v17.4s, v1.8h, v4.8h\n"
+ ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n"
+ "add %x[Apanel], %x[Apanel], #0x10\n"
+ ".inst 0x6e44ec57 // bfmmla v23.4s, v2.8h, v4.8h\n"
+ ".inst 0x6e47ec7a // bfmmla v26.4s, v3.8h, v7.8h\n"
+ "ldr q7, [x22, #0x20]\n"
+ ".inst 0x6e44ec7d // bfmmla v29.4s, v3.8h, v4.8h\n"
+ "ldr q4, [x22, #0x30]\n"
+ ".inst 0x6e45ec09 // bfmmla v9.4s, v0.8h, v5.8h\n"
+ ".inst 0x6e46ec0c // bfmmla v12.4s, v0.8h, v6.8h\n"
+ ".inst 0x6e45ec2f // bfmmla v15.4s, v1.8h, v5.8h\n"
+ "add x22, x22, #0x40\n"
+ ".inst 0x6e46ec32 // bfmmla v18.4s, v1.8h, v6.8h\n"
+ ".inst 0x6e45ec55 // bfmmla v21.4s, v2.8h, v5.8h\n"
+ ".inst 0x6e46ec58 // bfmmla v24.4s, v2.8h, v6.8h\n"
+ ".inst 0x6e45ec7b // bfmmla v27.4s, v3.8h, v5.8h\n"
+ ".inst 0x6e46ec7e // bfmmla v30.4s, v3.8h, v6.8h\n"
+ ".inst 0x6e47ec0a // bfmmla v10.4s, v0.8h, v7.8h\n"
+ ".inst 0x6e44ec0d // bfmmla v13.4s, v0.8h, v4.8h\n"
+ ".inst 0x6e47ec30 // bfmmla v16.4s, v1.8h, v7.8h\n"
+ ".inst 0x6e44ec33 // bfmmla v19.4s, v1.8h, v4.8h\n"
+ ".inst 0x6e47ec56 // bfmmla v22.4s, v2.8h, v7.8h\n"
+ ".inst 0x6e44ec59 // bfmmla v25.4s, v2.8h, v4.8h\n"
+ ".inst 0x6e47ec7c // bfmmla v28.4s, v3.8h, v7.8h\n"
+ ".inst 0x6e44ec7f // bfmmla v31.4s, v3.8h, v4.8h\n"
+ "cbz x20, 6f\n"
+ "ldr q5, [x22, #0x0]\n"
+ "ldr q0, [%x[Apanel], #0x0]\n"
+ "ldr q1, [%x[Apanel], #0x10]\n"
+ "ldr q6, [x22, #0x10]\n"
+ "ldr q2, [%x[Apanel], #0x20]\n"
+ "ldr q3, [%x[Apanel], #0x30]\n"
+ "add %x[Apanel], %x[Apanel], #0x40\n"
+ "ldr q7, [x22, #0x20]\n"
+ "ldr q4, [x22, #0x30]\n"
+ ".inst 0x6e45ec08 // bfmmla v8.4s, v0.8h, v5.8h\n"
+ ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n"
+ ".inst 0x6e45ec2e // bfmmla v14.4s, v1.8h, v5.8h\n"
+ ".inst 0x6e46ec31 // bfmmla v17.4s, v1.8h, v6.8h\n"
+ ".inst 0x6e45ec54 // bfmmla v20.4s, v2.8h, v5.8h\n"
+ ".inst 0x6e46ec57 // bfmmla v23.4s, v2.8h, v6.8h\n"
+ ".inst 0x6e45ec7a // bfmmla v26.4s, v3.8h, v5.8h\n"
+ "ldr q5, [x22, #0x40]\n"
+ ".inst 0x6e46ec7d // bfmmla v29.4s, v3.8h, v6.8h\n"
+ "ldr q6, [x22, #0x50]\n"
+ ".inst 0x6e47ec09 // bfmmla v9.4s, v0.8h, v7.8h\n"
+ ".inst 0x6e44ec0c // bfmmla v12.4s, v0.8h, v4.8h\n"
+ ".inst 0x6e47ec2f // bfmmla v15.4s, v1.8h, v7.8h\n"
+ "add x22, x22, #0x60\n"
+ ".inst 0x6e44ec32 // bfmmla v18.4s, v1.8h, v4.8h\n"
+ ".inst 0x6e47ec55 // bfmmla v21.4s, v2.8h, v7.8h\n"
+ ".inst 0x6e44ec58 // bfmmla v24.4s, v2.8h, v4.8h\n"
+ ".inst 0x6e47ec7b // bfmmla v27.4s, v3.8h, v7.8h\n"
+ ".inst 0x6e44ec7e // bfmmla v30.4s, v3.8h, v4.8h\n"
+ ".inst 0x6e45ec0a // bfmmla v10.4s, v0.8h, v5.8h\n"
+ ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n"
+ ".inst 0x6e45ec30 // bfmmla v16.4s, v1.8h, v5.8h\n"
+ ".inst 0x6e46ec33 // bfmmla v19.4s, v1.8h, v6.8h\n"
+ ".inst 0x6e45ec56 // bfmmla v22.4s, v2.8h, v5.8h\n"
+ ".inst 0x6e46ec59 // bfmmla v25.4s, v2.8h, v6.8h\n"
+ ".inst 0x6e45ec7c // bfmmla v28.4s, v3.8h, v5.8h\n"
+ ".inst 0x6e46ec7f // bfmmla v31.4s, v3.8h, v6.8h\n"
+ "6:" // multiply loop done
+ "add x20, %x[args_ptr], %[offset_max]\n"
+ "uzp1 v7.2d, v8.2d, v11.2d\n"
+ "uzp2 v8.2d, v8.2d, v11.2d\n"
+ "ld1r { v1.4s }, [x20]\n"
+ "uzp1 v11.2d, v9.2d, v12.2d\n"
+ "uzp2 v9.2d, v9.2d, v12.2d\n"
+ "uzp1 v12.2d, v10.2d, v13.2d\n"
+ "uzp2 v10.2d, v10.2d, v13.2d\n"
+ "add x20, %x[args_ptr], %[offset_min]\n"
+ "ld1r { v0.4s }, [x20]\n"
+ "uzp1 v13.2d, v14.2d, v17.2d\n"
+ "uzp2 v14.2d, v14.2d, v17.2d\n"
+ "uzp1 v17.2d, v15.2d, v18.2d\n"
+ "uzp2 v15.2d, v15.2d, v18.2d\n"
+ "cmp x23, #0xc\n"
+ "uzp1 v18.2d, v16.2d, v19.2d\n"
+ "uzp2 v16.2d, v16.2d, v19.2d\n"
+ "uzp1 v19.2d, v20.2d, v23.2d\n"
+ "uzp2 v20.2d, v20.2d, v23.2d\n"
+ "uzp1 v23.2d, v21.2d, v24.2d\n"
+ "uzp2 v21.2d, v21.2d, v24.2d\n"
+ "uzp1 v24.2d, v22.2d, v25.2d\n"
+ "uzp2 v22.2d, v22.2d, v25.2d\n"
+ "uzp1 v25.2d, v26.2d, v29.2d\n"
+ "uzp2 v26.2d, v26.2d, v29.2d\n"
+ "uzp1 v29.2d, v27.2d, v30.2d\n"
+ "uzp2 v27.2d, v27.2d, v30.2d\n"
+ "uzp1 v30.2d, v28.2d, v31.2d\n"
+ "uzp2 v28.2d, v28.2d, v31.2d\n"
+ "fmin v7.4s, v7.4s, v1.4s\n"
+ "fmin v11.4s, v11.4s, v1.4s\n"
+ "fmin v12.4s, v12.4s, v1.4s\n"
+ "fmin v8.4s, v8.4s, v1.4s\n"
+ "fmin v9.4s, v9.4s, v1.4s\n"
+ "fmin v10.4s, v10.4s, v1.4s\n"
+ "fmin v13.4s, v13.4s, v1.4s\n"
+ "fmin v17.4s, v17.4s, v1.4s\n"
+ "fmin v18.4s, v18.4s, v1.4s\n"
+ "fmin v14.4s, v14.4s, v1.4s\n"
+ "fmin v15.4s, v15.4s, v1.4s\n"
+ "fmin v16.4s, v16.4s, v1.4s\n"
+ "fmin v19.4s, v19.4s, v1.4s\n"
+ "fmin v23.4s, v23.4s, v1.4s\n"
+ "fmin v24.4s, v24.4s, v1.4s\n"
+ "fmin v20.4s, v20.4s, v1.4s\n"
+ "fmin v21.4s, v21.4s, v1.4s\n"
+ "fmin v22.4s, v22.4s, v1.4s\n"
+ "fmin v25.4s, v25.4s, v1.4s\n"
+ "fmin v29.4s, v29.4s, v1.4s\n"
+ "fmin v30.4s, v30.4s, v1.4s\n"
+ "fmin v26.4s, v26.4s, v1.4s\n"
+ "fmin v27.4s, v27.4s, v1.4s\n"
+ "fmin v28.4s, v28.4s, v1.4s\n"
+ "fmax v7.4s, v7.4s, v0.4s\n"
+ "fmax v11.4s, v11.4s, v0.4s\n"
+ "fmax v12.4s, v12.4s, v0.4s\n"
+ "fmax v8.4s, v8.4s, v0.4s\n"
+ "fmax v9.4s, v9.4s, v0.4s\n"
+ "fmax v10.4s, v10.4s, v0.4s\n"
+ "fmax v13.4s, v13.4s, v0.4s\n"
+ "fmax v17.4s, v17.4s, v0.4s\n"
+ "fmax v18.4s, v18.4s, v0.4s\n"
+ "fmax v14.4s, v14.4s, v0.4s\n"
+ "fmax v15.4s, v15.4s, v0.4s\n"
+ "fmax v16.4s, v16.4s, v0.4s\n"
+ "fmax v19.4s, v19.4s, v0.4s\n"
+ "fmax v23.4s, v23.4s, v0.4s\n"
+ "fmax v24.4s, v24.4s, v0.4s\n"
+ "fmax v20.4s, v20.4s, v0.4s\n"
+ "fmax v21.4s, v21.4s, v0.4s\n"
+ "fmax v22.4s, v22.4s, v0.4s\n"
+ "fmax v25.4s, v25.4s, v0.4s\n"
+ "fmax v29.4s, v29.4s, v0.4s\n"
+ "fmax v30.4s, v30.4s, v0.4s\n"
+ "fmax v26.4s, v26.4s, v0.4s\n"
+ "fmax v27.4s, v27.4s, v0.4s\n"
+ "fmax v28.4s, v28.4s, v0.4s\n"
+ "blt 7f\n"
+ "str q26, [x24, #0x0]\n"
+ "str q27, [x24, #0x10]\n"
+ "str q28, [x24, #0x20]\n"
+ "add x24, x24, #0x30\n"
+ "str q25, [x9, #0x0]\n"
+ "str q29, [x9, #0x10]\n"
+ "str q30, [x9, #0x20]\n"
+ "add x9, x9, #0x30\n"
+ "str q20, [x25, #0x0]\n"
+ "str q21, [x25, #0x10]\n"
+ "str q22, [x25, #0x20]\n"
+ "add x25, x25, #0x30\n"
+ "str q19, [x11, #0x0]\n"
+ "str q23, [x11, #0x10]\n"
+ "str q24, [x11, #0x20]\n"
+ "add x11, x11, #0x30\n"
+ "str q14, [x26, #0x0]\n"
+ "str q15, [x26, #0x10]\n"
+ "str q16, [x26, #0x20]\n"
+ "add x26, x26, #0x30\n"
+ "str q13, [x10, #0x0]\n"
+ "str q17, [x10, #0x10]\n"
+ "str q18, [x10, #0x20]\n"
+ "add x10, x10, #0x30\n"
+ "str q8, [x27, #0x0]\n"
+ "str q9, [x27, #0x10]\n"
+ "str q10, [x27, #0x20]\n"
+ "add x27, x27, #0x30\n"
+ "str q7, [%x[Cpanel], #0x0]\n"
+ "str q11, [%x[Cpanel], #0x10]\n"
+ "str q12, [%x[Cpanel], #0x20]\n"
+ "add %x[Cpanel], %x[Cpanel], #0x30\n"
+ "b 14f\n"
+ "7:" // partial output
+ "tbz x23, #3, 9f\n"
+ "st1 { v26.4s }, [x24], #0x10\n"
+ "st1 { v27.4s }, [x24], #0x10\n"
+ "st1 { v25.4s }, [x9], #0x10\n"
+ "st1 { v29.4s }, [x9], #0x10\n"
+ "st1 { v20.4s }, [x25], #0x10\n"
+ "st1 { v21.4s }, [x25], #0x10\n"
+ "st1 { v19.4s }, [x11], #0x10\n"
+ "st1 { v23.4s }, [x11], #0x10\n"
+ "st1 { v14.4s }, [x26], #0x10\n"
+ "st1 { v15.4s }, [x26], #0x10\n"
+ "st1 { v13.4s }, [x10], #0x10\n"
+ "st1 { v17.4s }, [x10], #0x10\n"
+ "st1 { v8.4s }, [x27], #0x10\n"
+ "st1 { v9.4s }, [x27], #0x10\n"
+ "st1 { v7.4s }, [%x[Cpanel]], #0x10\n"
+ "st1 { v11.4s }, [%x[Cpanel]], #0x10\n"
+ "tbz x23, #1, 8f\n"
+ "str d28, [x24], #0x8\n"
+ "str d30, [x9], #0x8\n"
+ "str d22, [x25], #0x8\n"
+ "str d24, [x11], #0x8\n"
+ "str d16, [x26], #0x8\n"
+ "str d18, [x10], #0x8\n"
+ "str d10, [x27], #0x8\n"
+ "str d12, [%x[Cpanel]], #0x8\n"
+ "tbz x23, #0, 13f\n"
+ "st1 { v28.s }[2], [x24]\n"
+ "st1 { v30.s }[2], [x9]\n"
+ "st1 { v22.s }[2], [x25]\n"
+ "st1 { v24.s }[2], [x11]\n"
+ "st1 { v16.s }[2], [x26]\n"
+ "st1 { v18.s }[2], [x10]\n"
+ "st1 { v10.s }[2], [x27]\n"
+ "st1 { v12.s }[2], [%x[Cpanel]]\n"
+ "b 13f\n"
+ "8:" // partial result store: partial_1_8
+ "tbz x23, #0, 13f\n"
+ "str s28, [x24, #0x0]\n"
+ "str s30, [x9, #0x0]\n"
+ "str s22, [x25, #0x0]\n"
+ "str s24, [x11, #0x0]\n"
+ "str s16, [x26, #0x0]\n"
+ "str s18, [x10, #0x0]\n"
+ "str s10, [x27, #0x0]\n"
+ "str s12, [%x[Cpanel], #0x0]\n"
+ "b 13f\n"
+ "9:" // partial result store: partial_4_0
+ "tbz x23, #2, 11f\n"
+ "st1 { v26.4s }, [x24], #0x10\n"
+ "st1 { v25.4s }, [x9], #0x10\n"
+ "st1 { v20.4s }, [x25], #0x10\n"
+ "st1 { v19.4s }, [x11], #0x10\n"
+ "st1 { v14.4s }, [x26], #0x10\n"
+ "st1 { v13.4s }, [x10], #0x10\n"
+ "st1 { v8.4s }, [x27], #0x10\n"
+ "st1 { v7.4s }, [%x[Cpanel]], #0x10\n"
+ "tbz x23, #1, 10f\n"
+ "str d27, [x24], #0x8\n"
+ "str d29, [x9], #0x8\n"
+ "str d21, [x25], #0x8\n"
+ "str d23, [x11], #0x8\n"
+ "str d15, [x26], #0x8\n"
+ "str d17, [x10], #0x8\n"
+ "str d9, [x27], #0x8\n"
+ "str d11, [%x[Cpanel]], #0x8\n"
+ "tbz x23, #0, 13f\n"
+ "st1 { v27.s }[2], [x24]\n"
+ "st1 { v29.s }[2], [x9]\n"
+ "st1 { v21.s }[2], [x25]\n"
+ "st1 { v23.s }[2], [x11]\n"
+ "st1 { v15.s }[2], [x26]\n"
+ "st1 { v17.s }[2], [x10]\n"
+ "st1 { v9.s }[2], [x27]\n"
+ "st1 { v11.s }[2], [%x[Cpanel]]\n"
+ "b 13f\n"
+ "10:" // partial result store: partial_1_4
+ "tbz x23, #0, 13f\n"
+ "str s27, [x24, #0x0]\n"
+ "str s29, [x9, #0x0]\n"
+ "str s21, [x25, #0x0]\n"
+ "str s23, [x11, #0x0]\n"
+ "str s15, [x26, #0x0]\n"
+ "str s17, [x10, #0x0]\n"
+ "str s9, [x27, #0x0]\n"
+ "str s11, [%x[Cpanel], #0x0]\n"
+ "b 13f\n"
+ "11:" // partial result store: partial_2_0
+ "tbz x23, #1, 12f\n"
+ "str d26, [x24], #0x8\n"
+ "str d25, [x9], #0x8\n"
+ "str d20, [x25], #0x8\n"
+ "str d19, [x11], #0x8\n"
+ "str d14, [x26], #0x8\n"
+ "str d13, [x10], #0x8\n"
+ "str d8, [x27], #0x8\n"
+ "str d7, [%x[Cpanel]], #0x8\n"
+ "tbz x23, #0, 13f\n"
+ "st1 { v26.s }[2], [x24]\n"
+ "st1 { v25.s }[2], [x9]\n"
+ "st1 { v20.s }[2], [x25]\n"
+ "st1 { v19.s }[2], [x11]\n"
+ "st1 { v14.s }[2], [x26]\n"
+ "st1 { v13.s }[2], [x10]\n"
+ "st1 { v8.s }[2], [x27]\n"
+ "st1 { v7.s }[2], [%x[Cpanel]]\n"
+ "b 13f\n"
+ "12:" // partial result store: partial_1_0
+ "str s26, [x24, #0x0]\n"
+ "str s25, [x9, #0x0]\n"
+ "str s20, [x25, #0x0]\n"
+ "str s19, [x11, #0x0]\n"
+ "str s14, [x26, #0x0]\n"
+ "str s13, [x10, #0x0]\n"
+ "str s8, [x27, #0x0]\n"
+ "str s7, [%x[Cpanel], #0x0]\n"
+ "13:" // partial result store: Done
+ "14:" // store done
+ "subs x23, x23, #0xc\n"
+ "bgt 3b\n"
+ "subs %x[M], %x[M], #0x8\n"
+ "mov %x[Cpanel], x28\n"
+ "bgt 1b\n"
+ : [Apanel] "+&r"(Apanel), [Cpanel] "+&r"(Cpanel), [M] "+&r"(M)
+ : [args_ptr] "r"(&ka), [ldc] "r"(ldc * sizeof(float)), [offset_max] "I"(offsetof(KernelArgs, maxval)),
+ [offset_min] "I"(offsetof(KernelArgs, minval)), [offsetof_Bpanel] "I"(offsetof(KernelArgs, Bpanel)),
+ [offsetof_K] "I"(offsetof(KernelArgs, K)), [offsetof_N] "I"(offsetof(KernelArgs, N))
+ : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
+ "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
+ "v30", "v31", "x9", "x10", "x11", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28");
+}
+
+#endif // Architectural features check.
diff --git a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.h b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.h
new file mode 100644
index 0000000000000000000000000000000000000000..e870fb2a09706234cb1c82c0c837fa606d10b044
--- /dev/null
+++ b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.h
@@ -0,0 +1,127 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#if !defined(__aarch64__) || !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC)
+#error This file must be compiled for AArch64, FEAT_BF16.
+#else // Architectural features check.
+
+#include
+#include
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+/// --------------------------------------------------
+
+/// Gets m step value.
+///
+/// The starting row index must be divisible by `m_step`.
+///
+/// @return The m step value.
+size_t kai_get_m_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void);
+
+/// Gets n step value.
+///
+/// The starting column index must be divisible by `n_step`.
+///
+/// @return The n step value.
+size_t kai_get_n_step_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void);
+
+/// Gets mr value.
+///
+/// This is the packing parameter which must be used to pack the LHS matrix.
+///
+/// @return The mr value.
+size_t kai_get_mr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void);
+
+/// Gets nr value.
+///
+/// This is the packing parameter which must be used to pack the RHS matrix.
+///
+/// @return The nr value.
+size_t kai_get_nr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void);
+
+/// Gets kr value.
+///
+/// This is the packing parameter which must be used to pack the LHS & RHS matrices.
+///
+/// @return The kr value.
+size_t kai_get_kr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void);
+
+/// Gets sr value.
+///
+/// This is the packing parameter which must be used to pack the RHS matrix.
+///
+/// @return The sr value.
+size_t kai_get_sr_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(void);
+
+/// Gets the offset in bytes to the data element in the packed LHS matrix buffer.
+///
+/// @param[in] m_idx Row index.
+/// @param[in] k Number of columns.
+///
+/// @return The offset in bytes to the data element.
+size_t kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(size_t m_idx, size_t k);
+
+/// Gets the offset in bytes to the data element in the packed RHS matrix buffer.
+///
+/// @param[in] n_idx Row index.
+/// @param[in] k Number of columns.
+///
+/// @return The offset in bytes to the data element.
+size_t kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(size_t n_idx, size_t k);
+
+/// Gets the offset in bytes to the data element in the destination matrix buffer.
+///
+/// @param[in] m_idx Row index.
+/// @param[in] n_idx Column index.
+/// @param[in] stride Row stride in bytes.
+///
+/// @return The offset in bytes to the data element.
+size_t kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(size_t m_idx, size_t n_idx, size_t stride);
+
+/// Gets the size in bytes of the destination matrix buffer.
+///
+/// @param[in] m Number of rows.
+/// @param[in] n Number of columns.
+///
+/// @return The size in bytes of the destination matrix buffer.
+size_t kai_get_dst_size_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(size_t m, size_t n);
+
+/// Runs the matrix multiplication microkernel followed by a clamp operation.
+///
+/// The pointer to each buffer (packed LHS, packed RHS and output) must be advanced by the offset
+/// calculated using the following functions:
+///
+/// * Packed LHS: @ref kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.
+/// * Packed RHS: @ref kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.
+/// * Output: @ref kai_get_dst_offset_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla.
+///
+/// @param[in] m Number of output rows to be computed.
+/// @param[in] n Number of output columns to be computed.
+/// @param[in] k Common dimension of the LHS and RHS operand.
+/// @param[in] lhs_packed Packed LHS buffer.
+/// @param[in] rhs_packed Packed RHS buffer.
+/// @param[out] dst Output matrix buffer.
+/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix.
+/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(float)
+/// @param[in] clamp_min Minimum value to clamp the final result.
+/// @param[in] clamp_max Maximum value to clamp the final result.
+void kai_run_matmul_clamp_f32_bf16p_bf16p12x4b_8x12x4_neon_mmla(
+ size_t m, size_t n, size_t k, //
+ const void* lhs_packed, //
+ const void* rhs_packed, //
+ void* dst, size_t dst_stride_row, size_t dst_stride_col, //
+ float clamp_min, float clamp_max);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // Architectural features check.
diff --git a/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h
new file mode 100644
index 0000000000000000000000000000000000000000..62f89279181dadb1e2dd58aae033fcc893b96ec1
--- /dev/null
+++ b/kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h
@@ -0,0 +1,57 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#pragma once
+
+#if !defined(__aarch64__) || !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC)
+#error This file must be compiled for AArch64, FEAT_BF16
+#else // Architectural features check.
+
+#include
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// All micro-kernels variants of the same type share the same interfaces
+// In this case, the micro-kernel type is: matmul_clamp_f32_bf16p_bf16p
+
+/// Function pointer types for the micro-kernel helper functions ("get" methods), one type per query.
+typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_m_step_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_n_step_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_mr_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_nr_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_kr_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_sr_func_t)(void);
+typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_lhs_packed_offset_func_t)(size_t m_idx, size_t k);
+typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_rhs_packed_offset_func_t)(size_t n_idx, size_t k);
+typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_dst_offset_func_t)(size_t m_idx, size_t n_idx, size_t dst_stride);
+typedef size_t (*kai_matmul_clamp_f32_bf16p_bf16p_get_dst_size_func_t)(size_t m, size_t n);
+
+/// Function pointer type for the micro-kernel core function ("run" method); takes the clamp bounds as its last two arguments.
+typedef void (*kai_matmul_clamp_f32_bf16p_bf16p_run_matmul_func_t)(
+ size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row,
+ size_t dst_stride_col, float scalar_min, float scalar_max);
+
+/// Micro-kernel interface: table of function pointers describing one matmul_clamp_f32_bf16p_bf16p variant.
+struct kai_matmul_clamp_f32_bf16p_bf16p_ukernel {
+ kai_matmul_clamp_f32_bf16p_bf16p_get_m_step_func_t get_m_step;
+ kai_matmul_clamp_f32_bf16p_bf16p_get_n_step_func_t get_n_step;
+ kai_matmul_clamp_f32_bf16p_bf16p_get_mr_func_t get_mr;
+ kai_matmul_clamp_f32_bf16p_bf16p_get_nr_func_t get_nr;
+ kai_matmul_clamp_f32_bf16p_bf16p_get_kr_func_t get_kr;
+ kai_matmul_clamp_f32_bf16p_bf16p_get_sr_func_t get_sr;
+ kai_matmul_clamp_f32_bf16p_bf16p_get_lhs_packed_offset_func_t get_lhs_packed_offset;
+ kai_matmul_clamp_f32_bf16p_bf16p_get_rhs_packed_offset_func_t get_rhs_packed_offset;
+ kai_matmul_clamp_f32_bf16p_bf16p_get_dst_offset_func_t get_dst_offset;
+ kai_matmul_clamp_f32_bf16p_bf16p_get_dst_size_func_t get_dst_size;
+ kai_matmul_clamp_f32_bf16p_bf16p_run_matmul_func_t run_matmul;  // Core "run" method of the variant.
+};
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // Architectural features check.
diff --git a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p_f32_neon.c b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p_f32_neon.c
new file mode 100644
index 0000000000000000000000000000000000000000..cea886ed29a9212eae88cc3ab75f7aa646ed0e76
--- /dev/null
+++ b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p_f32_neon.c
@@ -0,0 +1,200 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#if !defined(__aarch64__) || !defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC)
+#error This file must be compiled for AArch64, FEAT_BF16.
+#else // Architectural features check.
+
+#define MAX_MR 8
+
+#include
+#include
+#include
+
+#include "kai/kai_common.h"
+
+size_t kai_get_m_step_lhs_quant_pack_bf16p_f32_neon(size_t mr) {
+ return mr;  // The step in the M dimension is the `mr` value passed in, unchanged.
+}
+
+size_t kai_get_lhs_offset_lhs_quant_pack_bf16p_f32_neon(size_t m_idx, size_t lhs_stride) {
+ return m_idx * lhs_stride;  // Offset of row `m_idx`: the row index scaled by the caller-supplied row stride.
+}
+
+size_t kai_get_lhs_packed_offset_lhs_quant_pack_bf16p_f32_neon(
+ size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) {
+ KAI_ASSUME(m_idx % mr == 0);  // Offsets are only defined at the start of an `mr`-row block.
+ KAI_UNUSED(sr);
+ const size_t packed_row_bytes = kai_roundup(k, kr) * sizeof(uint16_t);  // One packed row of 16-bit elements.
+ return m_idx * packed_row_bytes;
+}
+
+size_t kai_get_lhs_packed_size_lhs_quant_pack_bf16p_f32_neon(size_t m, size_t k, size_t mr, size_t kr, size_t sr) {
+ KAI_UNUSED(sr);
+ const size_t padded_rows = kai_roundup(m, mr);  // `m` adjusted to the `mr` granularity by kai_roundup.
+ return padded_rows * kai_roundup(k, kr) * sizeof(uint16_t);  // Rows x columns x bytes per 16-bit element.
+}
+
+void kai_run_lhs_quant_pack_bf16p_f32_neon(
+ size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride,
+ void* lhs_packed) {
+ // Converts the f32 LHS matrix to bf16 and writes it to lhs_packed in blocks of MAX_MR interleaved rows.
+ KAI_UNUSED(sr);
+ KAI_ASSUME(lhs != NULL);
+ KAI_ASSUME(lhs_packed != NULL);
+ KAI_ASSUME(m_idx_start == 0);
+ KAI_ASSUME(mr == MAX_MR);  // The assembly always stores MAX_MR interleaved rows per block.
+ KAI_ASSUME(kr == 4);  // The assembly always packs columns in zero-padded groups of 4.
+
+ const size_t block_height = mr;
+ const size_t row_offset = 0;
+ const void* in[MAX_MR];  // Row pointers of the current block; only the first `height` entries are set.
+
+ for (size_t block_y = 0; block_y < m; block_y += block_height) {
+ const size_t height = KAI_MIN(m - block_y, block_height);
+ void* out = (char*)lhs_packed + block_y * kai_roundup(k, kr) * sizeof(uint16_t);
+ size_t width = k;
+
+ for (size_t y = 0; y < height; y++) {
+ in[y] = (const char*)lhs + (block_y + y) * lhs_stride;
+ }
+
+ __asm__ __volatile__(
+ "ldr x28, [%x[in], #0x0]\n"
+ "ldr x27, [%x[in], #0x8]\n"
+ "cmp %x[height], #0x8\n"
+ "ldr x26, [%x[in], #0x10]\n"
+ "ldr x25, [%x[in], #0x18]\n"
+ "ldr x24, [%x[in], #0x20]\n"
+ "ldr x23, [%x[in], #0x28]\n"
+ "ldr x22, [%x[in], #0x30]\n"
+ "ldr x21, [%x[in], #0x38]\n"
+ "add x28, x28, %x[row_offset], LSL #2\n"
+ "add x27, x27, %x[row_offset], LSL #2\n"
+ "add x26, x26, %x[row_offset], LSL #2\n"
+ "add x25, x25, %x[row_offset], LSL #2\n"
+ "add x24, x24, %x[row_offset], LSL #2\n"
+ "add x23, x23, %x[row_offset], LSL #2\n"
+ "add x22, x22, %x[row_offset], LSL #2\n"
+ "add x21, x21, %x[row_offset], LSL #2\n"
+ "beq 1f\n"  // height == 8: all eight row pointers are used as loaded.
+ "cmp %x[height], #0x2\n"
+ "mov x21, x28\n"  // height < 8: pointers for rows at or beyond `height` are redirected to row 0.
+ "csel x27, x27, x28, GE\n"
+ "csel x26, x26, x28, GT\n"
+ "cmp %x[height], #0x4\n"
+ "csel x25, x25, x28, GE\n"
+ "csel x24, x24, x28, GT\n"
+ "cmp %x[height], #0x6\n"
+ "csel x23, x23, x28, GE\n"
+ "csel x22, x22, x28, GT\n"
+ "1:" // no_pointer_adj
+ "cmp %x[width], #0x4\n"
+ "prfm pldl1keep, [x28, #0x0]\n"
+ "prfm pldl1keep, [x27, #0x0]\n"
+ "prfm pldl1keep, [x26, #0x0]\n"
+ "prfm pldl1keep, [x25, #0x0]\n"
+ "prfm pldl1keep, [x24, #0x0]\n"
+ "prfm pldl1keep, [x23, #0x0]\n"
+ "prfm pldl1keep, [x22, #0x0]\n"
+ "prfm pldl1keep, [x21, #0x0]\n"
+ "prfm pldl1keep, [x28, #0x40]\n"
+ "prfm pldl1keep, [x27, #0x40]\n"
+ "prfm pldl1keep, [x26, #0x40]\n"
+ "prfm pldl1keep, [x25, #0x40]\n"
+ "prfm pldl1keep, [x24, #0x40]\n"
+ "prfm pldl1keep, [x23, #0x40]\n"
+ "prfm pldl1keep, [x22, #0x40]\n"
+ "prfm pldl1keep, [x21, #0x40]\n"
+ "blt 3f\n"
+ "2:" // Main loop head
+ "ldr q19, [x28], #0x10\n"
+ "ldr q18, [x26], #0x10\n"
+ "subs %x[width], %x[width], #0x4\n"
+ "ldr q17, [x24], #0x10\n"
+ "ldr q16, [x22], #0x10\n"
+ "cmp %x[width], #0x4\n"
+ "ldr q23, [x27], #0x10\n"
+ "ldr q22, [x25], #0x10\n"
+ "ldr q21, [x23], #0x10\n"
+ "ldr q20, [x21], #0x10\n"
+ ".inst 0x0ea16a73 // bfcvtn v19.4h, v19.4s\n"
+ ".inst 0x0ea16a52 // bfcvtn v18.4h, v18.4s\n"
+ ".inst 0x0ea16a31 // bfcvtn v17.4h, v17.4s\n"
+ ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n"
+ "prfm pldl1keep, [x28, #0x70]\n"
+ "prfm pldl1keep, [x27, #0x70]\n"
+ "prfm pldl1keep, [x26, #0x70]\n"
+ "prfm pldl1keep, [x25, #0x70]\n"
+ "prfm pldl1keep, [x24, #0x70]\n"
+ "prfm pldl1keep, [x23, #0x70]\n"
+ ".inst 0x4ea16af3 // bfcvtn2 v19.8h, v23.4s\n"
+ ".inst 0x4ea16ad2 // bfcvtn2 v18.8h, v22.4s\n"
+ "prfm pldl1keep, [x22, #0x70]\n"
+ "prfm pldl1keep, [x21, #0x70]\n"
+ ".inst 0x4ea16ab1 // bfcvtn2 v17.8h, v21.4s\n"
+ ".inst 0x4ea16a90 // bfcvtn2 v16.8h, v20.4s\n"
+ "str q19, [%x[out_ptr], #0x0]\n"
+ "str q18, [%x[out_ptr], #0x10]\n"
+ "str q17, [%x[out_ptr], #0x20]\n"
+ "str q16, [%x[out_ptr], #0x30]\n"
+ "add %x[out_ptr], %x[out_ptr], #0x40\n"
+ "bge 2b\n"
+ "3:" // Main loop skip
+ "cbz %x[width], 6f\n"
+ "tbz %x[width], #1, 4f\n"
+ "ldr d19, [x28], #0x8\n"
+ "ldr d23, [x27], #0x8\n"
+ "mov x20, #0x1\n"
+ "ldr d18, [x26], #0x8\n"
+ "ldr d22, [x25], #0x8\n"
+ "ldr d17, [x24], #0x8\n"
+ "ldr d21, [x23], #0x8\n"
+ "ldr d16, [x22], #0x8\n"
+ "ldr d20, [x21], #0x8\n"
+ "tbz %x[width], #0, 5f\n"
+ "ld1 { v19.s }[2], [x28]\n"
+ "ld1 { v23.s }[2], [x27]\n"
+ "ld1 { v18.s }[2], [x26]\n"
+ "ld1 { v22.s }[2], [x25]\n"
+ "ld1 { v17.s }[2], [x24]\n"
+ "ld1 { v21.s }[2], [x23]\n"
+ "ld1 { v16.s }[2], [x22]\n"
+ "ld1 { v20.s }[2], [x21]\n"
+ "b 5f\n"
+ "4:" // odd_loads_1_0
+ "ldr s19, [x28, #0x0]\n"
+ "ldr s23, [x27, #0x0]\n"
+ "mov x20, #0x1\n"
+ "ldr s18, [x26, #0x0]\n"
+ "ldr s22, [x25, #0x0]\n"
+ "ldr s17, [x24, #0x0]\n"
+ "ldr s21, [x23, #0x0]\n"
+ "ldr s16, [x22, #0x0]\n"
+ "ldr s20, [x21, #0x0]\n"
+ "5:" // Odd load end
+ ".inst 0x0ea16a73 // bfcvtn v19.4h, v19.4s\n"
+ ".inst 0x0ea16a52 // bfcvtn v18.4h, v18.4s\n"
+ ".inst 0x0ea16a31 // bfcvtn v17.4h, v17.4s\n"
+ ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n"
+ ".inst 0x4ea16af3 // bfcvtn2 v19.8h, v23.4s\n"
+ ".inst 0x4ea16ad2 // bfcvtn2 v18.8h, v22.4s\n"
+ ".inst 0x4ea16ab1 // bfcvtn2 v17.8h, v21.4s\n"
+ ".inst 0x4ea16a90 // bfcvtn2 v16.8h, v20.4s\n"
+ "str q19, [%x[out_ptr], #0x0]\n"
+ "str q18, [%x[out_ptr], #0x10]\n"
+ "str q17, [%x[out_ptr], #0x20]\n"
+ "str q16, [%x[out_ptr], #0x30]\n"
+ "add %x[out_ptr], %x[out_ptr], #0x40\n"
+ "6:" // Odds skip
+ : [out_ptr] "+&r"(out), [width] "+&r"(width)
+ : [height] "r"(height), [in] "r"(in), [row_offset] "r"(row_offset)
+ : "cc", "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "x20", "x21", "x22", "x23", "x24",
+ "x25", "x26", "x27", "x28");
+ }
+}
+
+#endif // Architectural features check.
diff --git a/kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p_f32_neon.h b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p_f32_neon.h
new file mode 100644
index 0000000000000000000000000000000000000000..200a66c46ab394d8939fc8a8af55d0e4fa6219bd
--- /dev/null
+++ b/kai/ukernels/matmul/pack/kai_lhs_quant_pack_bf16p_f32_neon.h
@@ -0,0 +1,79 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#pragma once
+
+#include
+#include
+
+#include "kai/kai_common.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+/// Gets m step value.
+///
+/// The starting row index must be divisible by `m_step`.
+///
+/// @param[in] mr Number of rows to be interleaved.
+///
+/// @return The m step value.
+size_t kai_get_m_step_lhs_quant_pack_bf16p_f32_neon(size_t mr);
+
+/// Gets the offset in bytes to the data element in the LHS buffer.
+///
+/// @param[in] m_idx Row index.
+/// @param[in] lhs_stride Row stride in bytes.
+///
+/// @return The offset in bytes to the data element.
+size_t kai_get_lhs_offset_lhs_quant_pack_bf16p_f32_neon(size_t m_idx, size_t lhs_stride);
+
+/// Gets the offset in bytes to the data element in the packed LHS buffer.
+///
+/// @param[in] m_idx Row index in the unpacked LHS matrix.
+/// @param[in] k Number of columns in the unpacked LHS matrix.
+/// @param[in] mr Number of rows to be interleaved.
+/// @param[in] kr Number of columns to be interleaved.
+/// @param[in] sr Unused. Must be 1.
+///
+/// @return The offset in bytes to the data element.
+size_t kai_get_lhs_packed_offset_lhs_quant_pack_bf16p_f32_neon(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr);
+
+/// Gets the size in bytes of the packed LHS buffer.
+///
+/// @param[in] m Number of rows in the unpacked LHS matrix.
+/// @param[in] k Number of columns in the unpacked LHS matrix.
+/// @param[in] mr Number of rows to be interleaved.
+/// @param[in] kr Number of columns to be interleaved.
+/// @param[in] sr Unused. Must be 1.
+///
+/// @return The size in bytes of the packed LHS buffer.
+size_t kai_get_lhs_packed_size_lhs_quant_pack_bf16p_f32_neon(size_t m, size_t k, size_t mr, size_t kr, size_t sr);
+
+/// Runs the LHS packing function for matrix multiplication.
+///
+/// The pointer to each buffer (LHS and packed LHS) must be advanced by the offset
+/// calculated using the following functions:
+///
+/// * LHS: @ref kai_get_lhs_offset_lhs_quant_pack_bf16p_f32_neon.
+/// * Packed LHS: @ref kai_get_lhs_packed_offset_lhs_quant_pack_bf16p_f32_neon.
+///
+/// @param[in] m Number of rows of the unpacked LHS matrix.
+/// @param[in] k Common dimension between the LHS and RHS matrix.
+/// @param[in] mr Block size in M dimension. It must be 8.
+/// @param[in] kr Block size in K dimension. It must be 4.
+/// @param[in] sr Number of kr splits. It must be 1.
+/// @param[in] m_idx_start Unused. Must be 0.
+/// @param[in] lhs LHS matrix data buffer.
+/// @param[in] lhs_stride Row stride in bytes of the LHS matrix.
+/// @param[out] lhs_packed Packed LHS matrix.
+void kai_run_lhs_quant_pack_bf16p_f32_neon(
+ size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride,
+ void* lhs_packed);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
diff --git a/kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.c b/kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.c
new file mode 100644
index 0000000000000000000000000000000000000000..bf9eb10720dcda6a0339f16ca50ee217094fe5ad
--- /dev/null
+++ b/kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.c
@@ -0,0 +1,464 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#if !defined(__aarch64__)
+#error This file must be compiled for AArch64.
+#else // Architectural features check.
+
+#define MAX_NR 12
+
+#include
+#include
+#include
+#include
+
+#include "kai/kai_common.h"
+
+size_t kai_get_n_step_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon(void) {  // Takes no argument, as declared in the header.
+ return MAX_NR;  // The packing assembly always works on blocks of MAX_NR (12) columns.
+}
+
+size_t kai_get_rhs_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon(size_t n_idx) {
+ return n_idx * sizeof(float);  // n_idx float elements, expressed in bytes.
+}
+
+size_t kai_get_bias_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon(size_t n_idx) {
+ return n_idx * sizeof(uint32_t);  // Each bias entry occupies one 32-bit word.
+}
+
+size_t kai_get_rhs_packed_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon(
+ size_t n_idx, size_t k, size_t nr, size_t kr) {
+ KAI_ASSUME(n_idx % nr == 0);  // Offsets are only defined at nr-aligned N indices.
+ // Per N index: one 32-bit word plus k (rounded up to a multiple of kr) 16-bit values.
+ return n_idx * (sizeof(uint32_t) + kai_roundup(k, kr) * sizeof(uint16_t));
+}
+
+size_t kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon(size_t n, size_t k, size_t nr, size_t kr) {
+ return kai_get_rhs_packed_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon(kai_roundup(n, nr), k, nr, kr);  // n padded to nr.
+}
+
+void kai_run_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon(
+ size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs,
+ const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params) {
+ KAI_ASSUME(num_groups == 1);
+ KAI_ASSUME(sr == 1);
+ KAI_ASSUME(rhs != NULL);
+ KAI_ASSUME(scale == NULL);
+ KAI_ASSUME(rhs_packed != NULL);
+ KAI_ASSUME(extra_bytes == 0);
+ KAI_ASSUME(params == NULL);
+ KAI_ASSUME(nr == MAX_NR);  // The assembly below is hard-coded for 12-column blocks (0x30 bytes of bias each).
+ KAI_ASSUME(kr == 4);  // The assembly interleaves K in groups of 4 rows (0x60 output bytes per group).
+ size_t height = k;
+ const size_t width = n;
+ const void* in = (void*)rhs;
+ void* out = rhs_packed;
+ const size_t in_stride = rhs_stride;
+ const float* pad_row = rhs;  // Source for missing rows of a partial 4-row tail. NOTE(review): this is row 0, not zeros.
+
+ // When no bias is supplied, pack zeros instead: point at a zeroed block and re-read it (step 0).
+ size_t bias_step = nr * sizeof(float);
+ uint8_t zero_bias[MAX_NR * sizeof(float)];
+
+ if (bias == NULL) {
+ memset(zero_bias, 0, MAX_NR * sizeof(float));
+ bias_step = 0;
+ }
+
+ const void* bias_ptr = bias == NULL ? (void*)zero_bias : (void*)bias;
+
+ const size_t out_stride = nr * kai_roundup(height, kr) * sizeof(uint16_t) + nr * sizeof(uint32_t);
+ // The assembly first copies the bias words, then converts RHS rows to BF16: 8 rows per pass, then 4-row tails.
+ __asm__ __volatile__(
+ "mov x22, %x[width]\n"
+ "mov x21, %x[out]\n"
+ "cmp x22, #0xc\n"
+ "blt 2f\n"
+ "1:" // Bias: Full loop
+ "ldr q16, [%x[bias], #0x0]\n"
+ "ldr q26, [%x[bias], #0x10]\n"
+ "sub x22, x22, #0xc\n"
+ "ldr q8, [%x[bias], #0x20]\n"
+ "cmp x22, #0xc\n"
+ "add %x[bias], %x[bias], %x[bias_step]\n"
+ "str q16, [x21, #0x0]\n"
+ "str q26, [x21, #0x10]\n"
+ "str q8, [x21, #0x20]\n"
+ "add x21, x21, %x[out_stride]\n"
+ "bge 1b\n"
+ "cbz x22, 3f\n"
+ "2:" // Bias: Tail loop
+ "ldr w20, [%x[bias], #0x0]\n"
+ "sub x22, x22, #0x1\n"
+ "add %x[bias], %x[bias], #0x4\n"
+ "cmp x22, #0x0\n"
+ "str w20, [x21]\n"
+ "add x21, x21, #0x4\n"
+ "bgt 2b\n"
+ "3:" // Bias: Done
+ "cmp %x[height], #0x8\n"
+ "add %x[out], %x[out], #0x30\n"
+ "blt 12f\n"
+ "4:" // Main row loop: Head
+ "mov x9, %x[in]\n"
+ "mov x28, %x[width]\n"
+ "mov x27, %x[out]\n"
+ "sub %x[height], %x[height], #0x8\n"
+ "add x26, x9, %x[in_stride]\n"
+ "add x25, x26, %x[in_stride]\n"
+ "add x24, x25, %x[in_stride]\n"
+ "cmp x28, #0xc\n"
+ "add x23, x24, %x[in_stride]\n"
+ "add x22, x23, %x[in_stride]\n"
+ "add x21, x22, %x[in_stride]\n"
+ "add x20, x21, %x[in_stride]\n"
+ "add %x[in], x20, %x[in_stride]\n"
+ "blt 6f\n"
+ "5:" // Main row loop: Column loop
+ "ldr q28, [x9], #0x10\n"
+ "ldr q27, [x26], #0x10\n"
+ "sub x28, x28, #0xc\n"
+ "ldr q11, [x25], #0x10\n"
+ "ldr q5, [x24], #0x10\n"
+ "cmp x28, #0xc\n"
+ "ldr q14, [x23], #0x10\n"
+ "ldr q6, [x22], #0x10\n"
+ "ldr q2, [x21], #0x10\n"
+ "ldr q18, [x20], #0x10\n"
+ "ldr q1, [x9], #0x10\n"
+ "ldr q7, [x26], #0x10\n"
+ "zip1 v15.4s, v28.4s, v11.4s\n"
+ "zip1 v8.4s, v27.4s, v5.4s\n"
+ "ldr q3, [x25], #0x10\n"
+ "ldr q23, [x24], #0x10\n"
+ "zip2 v17.4s, v28.4s, v11.4s\n"
+ "zip2 v27.4s, v27.4s, v5.4s\n"
+ "ldr q5, [x23], #0x10\n"
+ "ldr q30, [x22], #0x10\n"
+ "zip1 v26.4s, v14.4s, v2.4s\n"
+ "zip1 v31.4s, v6.4s, v18.4s\n"
+ "ldr q20, [x21], #0x10\n"
+ "ldr q16, [x20], #0x10\n"
+ "zip2 v12.4s, v14.4s, v2.4s\n"
+ "zip2 v24.4s, v6.4s, v18.4s\n"
+ "ldr q29, [x9], #0x10\n"
+ "ldr q6, [x26], #0x10\n"
+ "zip1 v18.4s, v1.4s, v3.4s\n"
+ "zip1 v4.4s, v7.4s, v23.4s\n"
+ "ldr q22, [x25], #0x10\n"
+ "ldr q0, [x24], #0x10\n"
+ "zip2 v3.4s, v1.4s, v3.4s\n"
+ "zip2 v1.4s, v7.4s, v23.4s\n"
+ "ldr q2, [x23], #0x10\n"
+ "ldr q10, [x22], #0x10\n"
+ "zip1 v28.4s, v5.4s, v20.4s\n"
+ "zip1 v14.4s, v30.4s, v16.4s\n"
+ "ldr q9, [x21], #0x10\n"
+ "ldr q23, [x20], #0x10\n"
+ "zip2 v13.4s, v5.4s, v20.4s\n"
+ "zip2 v30.4s, v30.4s, v16.4s\n"
+ "zip1 v16.4s, v29.4s, v22.4s\n"
+ "zip1 v5.4s, v6.4s, v0.4s\n"
+ "zip2 v22.4s, v29.4s, v22.4s\n"
+ "zip2 v0.4s, v6.4s, v0.4s\n"
+ "zip1 v7.4s, v2.4s, v9.4s\n"
+ "zip1 v19.4s, v10.4s, v23.4s\n"
+ "zip2 v21.4s, v2.4s, v9.4s\n"
+ "zip2 v25.4s, v10.4s, v23.4s\n"
+ "zip1 v11.4s, v15.4s, v8.4s\n"
+ "zip1 v9.4s, v17.4s, v27.4s\n"
+ "zip1 v6.4s, v18.4s, v4.4s\n"
+ "zip1 v2.4s, v3.4s, v1.4s\n"
+ "zip1 v29.4s, v16.4s, v5.4s\n"
+ "zip1 v20.4s, v22.4s, v0.4s\n"
+ "zip1 v10.4s, v26.4s, v31.4s\n"
+ "zip1 v23.4s, v12.4s, v24.4s\n"
+ ".inst 0x0ea1696b // bfcvtn v11.4h, v11.4s\n"
+ "zip2 v8.4s, v15.4s, v8.4s\n"
+ "zip1 v15.4s, v28.4s, v14.4s\n"
+ ".inst 0x0ea16929 // bfcvtn v9.4h, v9.4s\n"
+ "zip2 v27.4s, v17.4s, v27.4s\n"
+ "zip1 v17.4s, v13.4s, v30.4s\n"
+ ".inst 0x0ea168c6 // bfcvtn v6.4h, v6.4s\n"
+ "zip2 v4.4s, v18.4s, v4.4s\n"
+ "zip1 v18.4s, v7.4s, v19.4s\n"
+ ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n"
+ "zip2 v1.4s, v3.4s, v1.4s\n"
+ "zip1 v3.4s, v21.4s, v25.4s\n"
+ ".inst 0x0ea16bbd // bfcvtn v29.4h, v29.4s\n"
+ "zip2 v5.4s, v16.4s, v5.4s\n"
+ ".inst 0x0ea16a94 // bfcvtn v20.4h, v20.4s\n"
+ "zip2 v0.4s, v22.4s, v0.4s\n"
+ ".inst 0x0ea16956 // bfcvtn v22.4h, v10.4s\n"
+ "zip2 v31.4s, v26.4s, v31.4s\n"
+ ".inst 0x0ea16aea // bfcvtn v10.4h, v23.4s\n"
+ "zip2 v26.4s, v12.4s, v24.4s\n"
+ ".inst 0x0ea169ef // bfcvtn v15.4h, v15.4s\n"
+ "zip2 v12.4s, v28.4s, v14.4s\n"
+ ".inst 0x0ea16a2e // bfcvtn v14.4h, v17.4s\n"
+ "zip2 v24.4s, v13.4s, v30.4s\n"
+ ".inst 0x0ea16a57 // bfcvtn v23.4h, v18.4s\n"
+ "zip2 v18.4s, v7.4s, v19.4s\n"
+ ".inst 0x0ea16871 // bfcvtn v17.4h, v3.4s\n"
+ "zip2 v16.4s, v21.4s, v25.4s\n"
+ ".inst 0x4ea1690b // bfcvtn2 v11.8h, v8.4s\n"
+ ".inst 0x4ea16b69 // bfcvtn2 v9.8h, v27.4s\n"
+ ".inst 0x4ea16886 // bfcvtn2 v6.8h, v4.4s\n"
+ ".inst 0x4ea16822 // bfcvtn2 v2.8h, v1.4s\n"
+ ".inst 0x4ea168bd // bfcvtn2 v29.8h, v5.4s\n"
+ ".inst 0x4ea16814 // bfcvtn2 v20.8h, v0.4s\n"
+ ".inst 0x4ea16bf6 // bfcvtn2 v22.8h, v31.4s\n"
+ ".inst 0x4ea16b4a // bfcvtn2 v10.8h, v26.4s\n"
+ "str q11, [x27, #0x0]\n"
+ ".inst 0x4ea1698f // bfcvtn2 v15.8h, v12.4s\n"
+ ".inst 0x4ea16b0e // bfcvtn2 v14.8h, v24.4s\n"
+ "str q9, [x27, #0x10]\n"
+ ".inst 0x4ea16a57 // bfcvtn2 v23.8h, v18.4s\n"
+ ".inst 0x4ea16a11 // bfcvtn2 v17.8h, v16.4s\n"
+ "str q6, [x27, #0x20]\n"
+ "str q2, [x27, #0x30]\n"
+ "str q29, [x27, #0x40]\n"
+ "str q20, [x27, #0x50]\n"
+ "str q22, [x27, #0x60]\n"
+ "str q10, [x27, #0x70]\n"
+ "str q15, [x27, #0x80]\n"
+ "str q14, [x27, #0x90]\n"
+ "str q23, [x27, #0xa0]\n"
+ "str q17, [x27, #0xb0]\n"
+ "add x27, x27, %x[out_stride]\n"
+ "bge 5b\n"
+ "6:" // Main row loop: Column loop skip
+ "cbz x28, 11f\n"
+ "cmp x28, #0x4\n"
+ "movi v16.16b, #0x0\n"
+ "str q16, [x27, #0x0]\n"
+ "str q16, [x27, #0x10]\n"
+ "str q16, [x27, #0x20]\n"
+ "str q16, [x27, #0x30]\n"
+ "str q16, [x27, #0x40]\n"
+ "str q16, [x27, #0x50]\n"
+ "str q16, [x27, #0x60]\n"
+ "str q16, [x27, #0x70]\n"
+ "str q16, [x27, #0x80]\n"
+ "str q16, [x27, #0x90]\n"
+ "str q16, [x27, #0xa0]\n"
+ "str q16, [x27, #0xb0]\n"
+ "blt 8f\n"
+ "7:" // Main row loop: width 4 loop: loop
+ "ldr q25, [x9], #0x10\n"
+ "ldr q24, [x26], #0x10\n"
+ "sub x28, x28, #0x4\n"
+ "ldr q21, [x25], #0x10\n"
+ "ldr q20, [x24], #0x10\n"
+ "cmp x28, #0x4\n"
+ "ldr q23, [x23], #0x10\n"
+ "ldr q19, [x22], #0x10\n"
+ "ldr q18, [x21], #0x10\n"
+ "ldr q17, [x20], #0x10\n"
+ "zip1 v22.4s, v25.4s, v21.4s\n"
+ "zip1 v16.4s, v24.4s, v20.4s\n"
+ "zip2 v21.4s, v25.4s, v21.4s\n"
+ "zip2 v20.4s, v24.4s, v20.4s\n"
+ "zip1 v27.4s, v23.4s, v18.4s\n"
+ "zip1 v26.4s, v19.4s, v17.4s\n"
+ "zip2 v25.4s, v23.4s, v18.4s\n"
+ "zip2 v24.4s, v19.4s, v17.4s\n"
+ "zip1 v19.4s, v22.4s, v16.4s\n"
+ "zip1 v18.4s, v21.4s, v20.4s\n"
+ "zip1 v17.4s, v27.4s, v26.4s\n"
+ "zip2 v23.4s, v22.4s, v16.4s\n"
+ "zip1 v16.4s, v25.4s, v24.4s\n"
+ "zip2 v22.4s, v21.4s, v20.4s\n"
+ ".inst 0x0ea16a75 // bfcvtn v21.4h, v19.4s\n"
+ ".inst 0x0ea16a54 // bfcvtn v20.4h, v18.4s\n"
+ ".inst 0x0ea16a33 // bfcvtn v19.4h, v17.4s\n"
+ "zip2 v18.4s, v27.4s, v26.4s\n"
+ ".inst 0x0ea16a11 // bfcvtn v17.4h, v16.4s\n"
+ "zip2 v16.4s, v25.4s, v24.4s\n"
+ ".inst 0x4ea16af5 // bfcvtn2 v21.8h, v23.4s\n"
+ ".inst 0x4ea16ad4 // bfcvtn2 v20.8h, v22.4s\n"
+ ".inst 0x4ea16a53 // bfcvtn2 v19.8h, v18.4s\n"
+ ".inst 0x4ea16a11 // bfcvtn2 v17.8h, v16.4s\n"
+ "str q21, [x27, #0x0]\n"
+ "str q20, [x27, #0x10]\n"
+ "str q19, [x27, #0x60]\n"
+ "str q17, [x27, #0x70]\n"
+ "add x27, x27, #0x20\n"
+ "bge 7b\n"
+ "8:" // Main row loop: width 4 loop: skip
+ "cmp x28, #0x1\n"
+ "blt 10f\n"
+ "9:" // Main row loop: width 1 loop: loop
+ "ldr s23, [x9], #0x4\n"
+ "ldr s22, [x26], #0x4\n"
+ "sub x28, x28, #0x1\n"
+ "ldr s19, [x25], #0x4\n"
+ "ldr s17, [x24], #0x4\n"
+ "cmp x28, #0x1\n"
+ "ldr s21, [x23], #0x4\n"
+ "ldr s20, [x22], #0x4\n"
+ "ldr s18, [x21], #0x4\n"
+ "ldr s16, [x20], #0x4\n"
+ "zip1 v19.4s, v23.4s, v19.4s\n"
+ "zip1 v17.4s, v22.4s, v17.4s\n"
+ "zip1 v18.4s, v21.4s, v18.4s\n"
+ "zip1 v16.4s, v20.4s, v16.4s\n"
+ "zip1 v17.4s, v19.4s, v17.4s\n"
+ "zip1 v16.4s, v18.4s, v16.4s\n"
+ ".inst 0x0ea16a31 // bfcvtn v17.4h, v17.4s\n"
+ ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n"
+ "str d17, [x27, #0x0]\n"
+ "str d16, [x27, #0x60]\n"
+ "add x27, x27, #0x8\n"
+ "bge 9b\n"
+ "10:" // Main row loop: width 1 loop: skip
+ "11:" // Main row loop: odd col skip
+ "cmp %x[height], #0x8\n"
+ "add %x[out], %x[out], #0xc0\n"
+ "bge 4b\n"
+ "cbz %x[height], 21f\n"
+ "12:" // Main loop skip
+ "13:" // Tail row loop: Head
+ "mov x9, %x[in]\n"
+ "mov x20, %x[width]\n"
+ "cmp %x[height], #0x3\n"
+ "mov x27, %x[out]\n"
+ "add x26, x9, %x[in_stride]\n"
+ "add x25, x26, %x[in_stride]\n"
+ "add x24, x25, %x[in_stride]\n"
+ "csel x25, x25, %x[pad_row], GE\n"
+ "add %x[in], x24, %x[in_stride]\n"
+ "csel x24, x24, %x[pad_row], GT\n"
+ "cmp %x[height], #0x1\n"
+ "sub %x[height], %x[height], #0x4\n"
+ "csel x26, x26, %x[pad_row], GT\n"
+ "cmp x20, #0xc\n"
+ "blt 15f\n"
+ "14:" // Tail row loop: Column loop
+ "ldr q24, [x9], #0x10\n"
+ "ldr q23, [x26], #0x10\n"
+ "sub x20, x20, #0xc\n"
+ "ldr q22, [x25], #0x10\n"
+ "ldr q16, [x24], #0x10\n"
+ "cmp x20, #0xc\n"
+ "ldr q28, [x9], #0x10\n"
+ "ldr q27, [x26], #0x10\n"
+ "ldr q21, [x25], #0x10\n"
+ "ldr q20, [x24], #0x10\n"
+ "ldr q19, [x9], #0x10\n"
+ "zip1 v26.4s, v24.4s, v22.4s\n"
+ "zip1 v25.4s, v23.4s, v16.4s\n"
+ "ldr q18, [x26], #0x10\n"
+ "ldr q17, [x25], #0x10\n"
+ "zip2 v24.4s, v24.4s, v22.4s\n"
+ "zip2 v23.4s, v23.4s, v16.4s\n"
+ "ldr q16, [x24], #0x10\n"
+ "zip1 v2.4s, v28.4s, v21.4s\n"
+ "zip1 v22.4s, v27.4s, v20.4s\n"
+ "zip2 v1.4s, v28.4s, v21.4s\n"
+ "zip2 v0.4s, v27.4s, v20.4s\n"
+ "zip1 v31.4s, v19.4s, v17.4s\n"
+ "zip1 v30.4s, v18.4s, v16.4s\n"
+ "zip2 v29.4s, v19.4s, v17.4s\n"
+ "zip2 v28.4s, v18.4s, v16.4s\n"
+ "zip1 v21.4s, v26.4s, v25.4s\n"
+ "zip1 v20.4s, v24.4s, v23.4s\n"
+ "zip1 v19.4s, v2.4s, v22.4s\n"
+ "zip1 v18.4s, v1.4s, v0.4s\n"
+ "zip1 v17.4s, v31.4s, v30.4s\n"
+ "zip1 v16.4s, v29.4s, v28.4s\n"
+ ".inst 0x0ea16abb // bfcvtn v27.4h, v21.4s\n"
+ "zip2 v26.4s, v26.4s, v25.4s\n"
+ ".inst 0x0ea16a99 // bfcvtn v25.4h, v20.4s\n"
+ "zip2 v24.4s, v24.4s, v23.4s\n"
+ ".inst 0x0ea16a77 // bfcvtn v23.4h, v19.4s\n"
+ "zip2 v22.4s, v2.4s, v22.4s\n"
+ ".inst 0x0ea16a55 // bfcvtn v21.4h, v18.4s\n"
+ "zip2 v20.4s, v1.4s, v0.4s\n"
+ ".inst 0x0ea16a33 // bfcvtn v19.4h, v17.4s\n"
+ "zip2 v18.4s, v31.4s, v30.4s\n"
+ ".inst 0x0ea16a11 // bfcvtn v17.4h, v16.4s\n"
+ "zip2 v16.4s, v29.4s, v28.4s\n"
+ ".inst 0x4ea16b5b // bfcvtn2 v27.8h, v26.4s\n"
+ ".inst 0x4ea16b19 // bfcvtn2 v25.8h, v24.4s\n"
+ ".inst 0x4ea16ad7 // bfcvtn2 v23.8h, v22.4s\n"
+ ".inst 0x4ea16a95 // bfcvtn2 v21.8h, v20.4s\n"
+ ".inst 0x4ea16a53 // bfcvtn2 v19.8h, v18.4s\n"
+ ".inst 0x4ea16a11 // bfcvtn2 v17.8h, v16.4s\n"
+ "str q27, [x27, #0x0]\n"
+ "str q25, [x27, #0x10]\n"
+ "str q23, [x27, #0x20]\n"
+ "str q21, [x27, #0x30]\n"
+ "str q19, [x27, #0x40]\n"
+ "str q17, [x27, #0x50]\n"
+ "add x27, x27, %x[out_stride]\n"
+ "bge 14b\n"
+ "15:" // Tail row loop: Column loop skip
+ "cbz x20, 20f\n"
+ "cmp x20, #0x4\n"
+ "movi v16.16b, #0x0\n"
+ "str q16, [x27, #0x0]\n"
+ "str q16, [x27, #0x10]\n"
+ "str q16, [x27, #0x20]\n"
+ "str q16, [x27, #0x30]\n"
+ "str q16, [x27, #0x40]\n"
+ "str q16, [x27, #0x50]\n"
+ "blt 17f\n"
+ "16:" // Tail row loop: width 4 loop: loop
+ "ldr q21, [x9], #0x10\n"
+ "ldr q20, [x26], #0x10\n"
+ "sub x20, x20, #0x4\n"
+ "ldr q19, [x25], #0x10\n"
+ "ldr q17, [x24], #0x10\n"
+ "cmp x20, #0x4\n"
+ "zip1 v18.4s, v21.4s, v19.4s\n"
+ "zip1 v16.4s, v20.4s, v17.4s\n"
+ "zip2 v21.4s, v21.4s, v19.4s\n"
+ "zip2 v20.4s, v20.4s, v17.4s\n"
+ "zip1 v17.4s, v18.4s, v16.4s\n"
+ "zip2 v19.4s, v18.4s, v16.4s\n"
+ "zip1 v16.4s, v21.4s, v20.4s\n"
+ ".inst 0x0ea16a32 // bfcvtn v18.4h, v17.4s\n"
+ "zip2 v17.4s, v21.4s, v20.4s\n"
+ ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n"
+ ".inst 0x4ea16a72 // bfcvtn2 v18.8h, v19.4s\n"
+ ".inst 0x4ea16a30 // bfcvtn2 v16.8h, v17.4s\n"
+ "str q18, [x27, #0x0]\n"
+ "str q16, [x27, #0x10]\n"
+ "add x27, x27, #0x20\n"
+ "bge 16b\n"
+ "17:" // Tail row loop: width 4 loop: skip
+ "cmp x20, #0x1\n"
+ "blt 19f\n"
+ "18:" // Tail row loop: width 1 loop: loop
+ "ldr s19, [x9], #0x4\n"
+ "ldr s18, [x26], #0x4\n"
+ "sub x20, x20, #0x1\n"
+ "ldr s17, [x25], #0x4\n"
+ "ldr s16, [x24], #0x4\n"
+ "cmp x20, #0x1\n"
+ "zip1 v17.4s, v19.4s, v17.4s\n"
+ "zip1 v16.4s, v18.4s, v16.4s\n"
+ "zip1 v16.4s, v17.4s, v16.4s\n"
+ ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n"
+ "str d16, [x27, #0x0]\n"
+ "add x27, x27, #0x8\n"
+ "bge 18b\n"
+ "19:" // Tail row loop: width 1 loop: skip
+ "20:" // Tail row loop: odd col skip
+ "cmp %x[height], #0x1\n"
+ "add %x[out], %x[out], #0x60\n"
+ "bge 13b\n"
+ "21:" // Done
+ : [bias] "+&r"(bias_ptr), [height] "+&r"(height), [in] "+&r"(in), [out] "+&r"(out)
+ : [bias_step] "r"(bias_step), [in_stride] "r"(in_stride), [out_stride] "r"(out_stride), [pad_row] "r"(pad_row),
+ [width] "r"(width)
+ : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
+ "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
+ "v30", "v31", "x9", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28");
+}
+
+#endif // Architectural features check.
diff --git a/kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.h b/kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.h
new file mode 100644
index 0000000000000000000000000000000000000000..f786c7a7c92865d84184ee079072250ff3c5ce08
--- /dev/null
+++ b/kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.h
@@ -0,0 +1,84 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+/// Gets n step value.
+///
+/// The starting index in the N dimension must be divisible by `n_step`.
+///
+/// @return The n step value.
+size_t kai_get_n_step_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon(void);
+
+/// Gets the offset in bytes to the data element in the RHS matrix buffer.
+///
+/// @param[in] n_idx Column index.
+///
+/// @return The offset in bytes to the data element.
+size_t kai_get_rhs_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon(size_t n_idx);
+
+/// Gets the offset in bytes to the data element in the bias buffer.
+///
+/// @param[in] n_idx Column index.
+///
+/// @return The offset in bytes to the data element.
+size_t kai_get_bias_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon(size_t n_idx);
+
+/// Gets the offset in bytes to the data element in the packed RHS buffer.
+///
+/// @param[in] n_idx Row index. It must be a multiple of nr.
+/// @param[in] k Number of columns.
+/// @param[in] nr Block size in N dimension.
+/// @param[in] kr Block size in K dimension.
+///
+/// @return The offset in bytes to the data element.
+size_t kai_get_rhs_packed_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon(size_t n_idx, size_t k, size_t nr, size_t kr);
+
+/// Gets the size in bytes of the packed RHS buffer.
+///
+/// @param[in] n Number of rows.
+/// @param[in] k Number of columns.
+/// @param[in] nr Block size in N dimension.
+/// @param[in] kr Block size in K dimension.
+///
+/// @return The size in bytes of the packed RHS buffer.
+size_t kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon(size_t n, size_t k, size_t nr, size_t kr);
+
+/// Runs the RHS packing function for matrix multiplication.
+///
+/// The pointers to the RHS, bias and packed RHS buffers must each be advanced by the offset
+/// calculated using the following functions:
+///
+/// * RHS: @ref kai_get_rhs_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.
+/// * Bias: @ref kai_get_bias_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.
+/// * Output: @ref kai_get_rhs_packed_offset_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon.
+///
+/// @param[in] num_groups Number of groups. It must be 1.
+/// @param[in] n Number of columns of the output matrix.
+/// @param[in] k Common dimension between the LHS and RHS matrices.
+/// @param[in] nr Block size in N dimension. The packing routine is written for 12.
+/// @param[in] kr Block size in K dimension. The packing routine is written for 4.
+/// @param[in] sr Number of kr splits. It must be 1.
+/// @param[in] rhs_stride Row stride in bytes of the RHS matrix.
+/// @param[in] rhs RHS matrix data buffer.
+/// @param[in] bias Bias matrix data buffer. If NULL, zeros are packed as the bias.
+/// @param[in] scale Scale data buffer. It must be NULL.
+/// @param[out] rhs_packed Packed RHS matrix.
+/// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix. It must be 0.
+/// @param[in] params Extra packing parameters. It must be NULL.
+void kai_run_rhs_quant_pack_kxn_bf16pbiasf32_f32_neon(
+ size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs,
+ const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
diff --git a/test/common/bfloat16.hpp b/test/common/bfloat16.hpp
index ca0e0b371f6bad9f6b00ca55fd6a3e09c331171a..9291918a638499b94545e6b9ff70c30310c46874 100644
--- a/test/common/bfloat16.hpp
+++ b/test/common/bfloat16.hpp
@@ -82,6 +82,14 @@ public:
return _data != rhs._data;
}
+ uint16_t data() const {  // Returns the stored 16-bit value unchanged.
+ return _data;
+ }
+
+ void set_data(uint16_t data) {  // Overwrites the stored 16-bit value with `data`.
+ _data = data;
+ }
+
/// Writes the value to the output stream.
///
/// @param[in] os Output stream to be written to.
diff --git a/test/common/compare.cpp b/test/common/compare.cpp
index 54af776fd6937a5dd4deaa3469977a7ec1db3a83..b000f3f5efaa7e06295fd4539e7d05b021806008 100644
--- a/test/common/compare.cpp
+++ b/test/common/compare.cpp
@@ -213,6 +213,9 @@ bool compare(
case DataType::FP16:
return compare_raw(imp_data, ref_data, format, full_height, full_width, rect, handler);
+ case DataType::BF16:
+ return compare_raw(imp_data, ref_data, format, full_height, full_width, rect, handler);
+
default:
break;
}
diff --git a/test/common/matmul_test_common.cpp b/test/common/matmul_test_common.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..73d41c09e28cab56db5c95df6e78e1e0c757c319
--- /dev/null
+++ b/test/common/matmul_test_common.cpp
@@ -0,0 +1,24 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "matmul_test_common.hpp"
+
+#include
+
+namespace kai::test {
+void PrintTo(const MatMulTestParams& param, std::ostream* os) {
+ const auto& [method, shape, portion] = param;
+ // Writes the method name, shape and portion; portion values are scaled by 1000 before the cast.
+ // NOLINTBEGIN(cppcoreguidelines-pro-bounds-constant-array-index)
+ *os << "Method_" << method.name //
+ << "__M_" << shape.m << "__N_" << shape.n << "__K_" << shape.k //
+ << "__PortionStartRow_" << static_cast(portion.start_row() * 1000) //
+ << "__PortionStartCol_" << static_cast(portion.start_col() * 1000) //
+ << "__PortionHeight_" << static_cast(portion.height() * 1000) //
+ << "__PortionWidth_" << static_cast(portion.width() * 1000);
+ // NOLINTEND(cppcoreguidelines-pro-bounds-constant-array-index)
+}
+} // namespace kai::test
diff --git a/test/common/matmul_test_common.hpp b/test/common/matmul_test_common.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..21f6e244c434d6789321c69ddf5b4a2471eec3d2
--- /dev/null
+++ b/test/common/matmul_test_common.hpp
@@ -0,0 +1,359 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+
+#include "kai/kai_common.h"
+#include "test/common/data_format.hpp"
+#include "test/common/float16.hpp"
+#include "test/common/matrix_portion.hpp"
+
+namespace kai::test {
+/// Matrix multiplication shape.
+struct MatMulShape {
+ size_t m; ///< LHS height.
+ size_t n; ///< RHS width.
+ size_t k; ///< LHS width and RHS height.
+};
+
+// NOLINTBEGIN(misc-non-private-member-variables-in-classes)
+
+/// Matrix multiplication method.
+struct MatMulMethod {
+ std::string_view name; ///< Name of matmul method.
+
+ size_t m0{0}; ///< Block size in M dimension.
+ size_t n0{0}; ///< Block size in N dimension.
+ size_t k0{0}; ///< Block size in K dimension.
+
+ bool lhs_transposed; ///< LHS matrix is transposed.
+ bool rhs_transposed; ///< RHS matrix is transposed.
+
+ DataFormat dst_format; ///< Data format of the destination matrix.
+ DataFormat lhs_format; ///< Data format of the LHS matrix.
+ DataFormat packed_lhs_format; ///< Data format of the packed LHS matrix.
+ DataFormat rhs_format; ///< Data format of the RHS matrix.
+ DataFormat packed_rhs_format; ///< Data format of the packed RHS matrix.
+ DataFormat bias_format; ///< Data format of the bias vector.
+
+ /// Check if CPU supports required features.
+ ///
+ /// @return Supported (true) or not supported (false).
+ std::function fn_is_supported;
+
+ /// Gets mr value.
+ ///
+ /// This is the packing parameter which must be used to pack the LHS matrix (if necessary).
+ ///
+ /// @return The mr value.
+ std::function fn_get_mr;
+
+ /// Gets nr value.
+ ///
+ /// This is the packing parameter which must be used to pack the RHS matrix (if necessary).
+ ///
+ /// @return The nr value.
+ std::function fn_get_nr;
+
+ /// Gets kr value.
+ ///
+ /// This is the packing parameter which must be used to pack the LHS and RHS matrix (if necessary).
+ ///
+ /// @return The kr value.
+ std::function fn_get_kr;
+
+ /// Gets sr value.
+ ///
+ /// This is the packing parameter which must be used to pack the RHS matrix.
+ ///
+ /// @return The sr value.
+ std::function fn_get_sr;
+
+ /// Gets m step value for main kernel.
+ ///
+ /// The starting row index must be divisible by `m_step`.
+ ///
+ /// @return The m step value.
+ std::function fn_get_main_m_step;
+
+ /// Gets n step value for RHS packing kernel.
+ ///
+ /// The starting index in the N dimension must be divisible by `n_step`.
+ ///
+ /// @return The n step value.
+ std::function fn_get_pack_rhs_n_step;
+
+ /// Gets n step value for main kernel.
+ ///
+ /// The starting column index must be divisible by `n_step`.
+ ///
+ /// @return The n step value.
+ std::function fn_get_main_n_step;
+
+ /// Gets the offset in bytes of the LHS matrix.
+ ///
+ /// @param[in] m_idx Coordinate of the matrix in M dimension.
+ /// @param[in] stride Row stride in bytes.
+ ///
+ /// @return The offset in bytes.
+ std::function fn_get_lhs_offset;
+
+ /// Gets the size in bytes of the packed LHS matrix.
+ ///
+ /// @param[in] m Number of rows in the unpacked LHS matrix.
+ /// @param[in] k Number of columns in the unpacked LHS matrix.
+ /// @param[in] mr Number of rows to be interleaved.
+ /// @param[in] kr Number of columns to be interleaved.
+ /// @param[in] sr Unused. Must be 1.
+ ///
+ /// @return The size in bytes.
+ std::function fn_get_packed_lhs_size;
+
+ /// Gets the offset in bytes of the packed LHS matrix.
+ ///
+ /// @param[in] m_idx Coordinate of the matrix in M dimension.
+ /// @param[in] k Size of the matrix in K dimension.
+ ///
+ /// @return The offset in bytes.
+ std::function fn_get_packed_lhs_offset;
+
+ /// Preprocesses the LHS matrix.
+ ///
+ /// @param[in] m Number of rows of the unpacked LHS matrix.
+ /// @param[in] k Common dimension between the LHS and RHS matrix.
+ /// @param[in] mr Block size in M dimension.
+ /// @param[in] kr Block size in K dimension.
+ /// @param[in] sr Number of kr splits. It must be 1.
+ /// @param[in] m_idx_start Unused. Must be 0.
+ /// @param[in] lhs LHS matrix data buffer.
+ /// @param[in] lhs_stride Row stride in bytes of the LHS matrix.
+ /// @param[out] lhs_packed Packed LHS matrix.
+ std::function
+ fn_pack_lhs;
+
+ /// Gets a value indicating whether LHS packing is needed.
+ [[nodiscard]] bool is_pack_lhs_needed() const {
+ return fn_pack_lhs != nullptr;
+ }
+
+ /// Gets the offset in bytes of the RHS matrix.
+ ///
+ /// @param[in] n_idx Coordinate of the matrix in N dimension.
+ ///
+ /// @return The offset in bytes.
+ std::function fn_get_rhs_offset;
+
+ /// Gets the size in bytes of the packed RHS matrix.
+ ///
+ /// @param[in] n Size of the matrix in N dimension.
+ /// @param[in] k Size of the matrix in K dimension.
+ ///
+ /// @return The size in bytes.
+ std::function fn_get_packed_rhs_size;
+
+ /// Gets the size in bytes of the packed RHS matrix.
+ ///
+ /// @param[in] n Size of the matrix in N dimension.
+ /// @param[in] k Size of the matrix in K dimension.
+ /// @param[in] nr Block size in N dimension.
+ /// @param[in] kr Block size in K dimension.
+ ///
+ /// @return The size in bytes.
+ std::function fn_get_packed_rhs_size_generic_block_size = nullptr;
+
+ /// Gets the offset in bytes of the packed RHS matrix in the RHS packing kernel
+ ///
+ /// @param[in] n_idx Coordinate of the matrix in N dimension.
+ /// @param[in] k Size of the matrix in K dimension.
+ ///
+ /// @return The offset in bytes.
+ std::function fn_get_pack_rhs_packed_rhs_offset;
+
+ /// Gets the offset in bytes of the packed RHS matrix in the main kernel.
+ ///
+ /// @param[in] n_idx Coordinate of the matrix in N dimension.
+ /// @param[in] k Size of the matrix in K dimension.
+ ///
+ /// @return The offset in bytes.
+ std::function fn_get_main_packed_rhs_offset;
+
+ std::function
+ fn_pack_rhs;
+
+ /// Gets the offset in bytes to the data element in the bias buffer.
+ ///
+ /// @param[in] n_idx Column index.
+ ///
+ /// @return The offset in bytes to the data element.
+ std::function fn_get_bias_offset;
+
+ /// Gets the offset in bytes to the data element in the destination matrix buffer.
+ ///
+ /// @param[in] m_idx Row index.
+ /// @param[in] n_idx Column index.
+ /// @param[in] stride Row stride in bytes.
+ ///
+ /// @return The offset in bytes to the data element.
+ std::function fn_get_dst_offset;
+
+ /// Gets the size in bytes of the destination matrix buffer.
+ ///
+ /// @param[in] m Number of rows.
+ /// @param[in] n Number of columns.
+ ///
+ /// @return The size in bytes of the destination matrix buffer.
+ std::function fn_get_dst_size;
+
+ /// Performs F16 or F32 matrix multiplication with RHS packing
+ /// followed by clamp operation.
+ ///
+ /// @param[in] m Size of the matrix in M dimension.
+ /// @param[in] n Size of the matrix in N dimension.
+ /// @param[in] k Size of the matrix in K dimension.
+ /// @param[in] lhs LHS data buffer.
+ /// @param[in] packed_rhs Packed RHS data buffer.
+ /// @param[out] dst Output data buffer.
+ /// @param[in] lhs_stride LHS row stride.
+ /// @param[in] dst_stride_row Output row stride.
+ /// @param[in] dst_stride_col Output column stride.
+ /// @param[in] clamp_min Lower bound of the output data.
+ /// @param[in] clamp_max Upper bound of the output data.
+ std::function
+ fn_matmul_f16_f16_f16p = nullptr;
+
+ std::function
+ fn_matmul_f32_f32_f32p = nullptr;
+
+ /// Performs BF16 matrix multiplication with LHS and RHS packing
+ /// followed by clamp operation.
+ ///
+ /// @param[in] m Size of the matrix in M dimension.
+ /// @param[in] n Size of the matrix in N dimension.
+ /// @param[in] k Size of the matrix in K dimension.
+ /// @param[in] packed_lhs Packed LHS data buffer.
+ /// @param[in] packed_rhs Packed RHS data buffer.
+ /// @param[out] dst Output data buffer.
+ /// @param[in] dst_stride_row Output row stride.
+ /// @param[in] dst_stride_col Output column stride.
+ /// @param[in] clamp_min Lower bound of the output data.
+ /// @param[in] clamp_max Upper bound of the output data.
+ std::function
+ fn_matmul_f32_bf16p_bf16p = nullptr;
+
+ /// Performs F32 matrix multiplication with LHS & RHS packing
+ /// followed by clamp operation.
+ ///
+ /// @param[in] m Number of output rows to be computed.
+ /// @param[in] n Number of output columns to be computed.
+ /// @param[in] k Common dimension of the LHS and RHS operands.
+ /// @param[in] packed_lhs Packed LHS matrix buffer.
+ /// @param[in] packed_rhs Packed RHS matrix buffer.
+ /// @param[out] dst Output matrix buffer.
+ /// @param[in] dst_stride_row Row stride in bytes of the output matrix.
+ /// @param[in] dst_stride_col Column stride in bytes of the output matrix.
+ /// @param[in] clamp_min Minimum value to clamp the final result.
+ /// @param[in] clamp_max Maximum value to clamp the final result.
+ std::function
+ fn_matmul_f32_f32p_f32p = nullptr;
+
+ /// Gets a value indicating whether pre-processing the RHS matrix is needed.
+ [[nodiscard]] bool is_pack_rhs_needed() const {
+ return fn_pack_rhs != nullptr;
+ }
+
+ /// Preprocesses the RHS matrix.
+ ///
+ /// @param[in] n Size of the matrix in N dimension.
+ /// @param[in] k Size of the matrix in K dimension.
+ /// @param[in] rhs RHS data buffer.
+ /// @param[in] rhs_row_stride RHS row stride.
+ /// @param[in] bias Bias data buffer.
+ /// @param[in] scale Quantization scales data buffer.
+ /// @param[out] packed_rhs Packed RHS data buffer.
+ void pack_rhs(
+ size_t n, size_t k, const void* rhs, size_t rhs_row_stride, const void* bias, const void* scale,
+ void* packed_rhs) const {
+ KAI_UNUSED(n);
+ KAI_UNUSED(k);
+ KAI_UNUSED(rhs);
+ KAI_UNUSED(rhs_row_stride);
+ KAI_UNUSED(bias);
+ KAI_UNUSED(scale);
+
+ if (fn_pack_rhs != nullptr) {
+ fn_pack_rhs(
+ 1, n, k, fn_get_nr(), fn_get_kr(), fn_get_sr(), rhs_row_stride, rhs, bias, nullptr, packed_rhs, 0,
+ nullptr);
+ } else {
+ KAI_ERROR("RHS pre-processing is not supported!");
+ }
+ }
+
+ [[nodiscard]] bool has_main_kernel() const {
+ return fn_matmul_f16_f16_f16p != nullptr || fn_matmul_f32_f32p_f32p != nullptr ||
+ fn_matmul_f32_f32_f32p != nullptr || fn_matmul_f32_bf16p_bf16p != nullptr;
+ }
+
+ void main_kernel(
+ size_t m, size_t n, size_t k, const void* lhs, const void* rhs, const void* bias, void* dst, size_t lhs_stride,
+ size_t rhs_stride, size_t dst_stride, float clamp_min, float clamp_max) const {
+ KAI_UNUSED(bias);
+ KAI_UNUSED(rhs_stride);
+
+ if (fn_matmul_f16_f16_f16p) {
+ fn_matmul_f16_f16_f16p(
+ m, n, k, lhs, lhs_stride, rhs, dst, dst_stride, sizeof(Float16), clamp_min,
+ static_cast(clamp_max));
+ } else if (fn_matmul_f32_f32_f32p) {
+ fn_matmul_f32_f32_f32p(
+ m, n, k, lhs, lhs_stride, rhs, dst, dst_stride, sizeof(float), clamp_min,
+ static_cast(clamp_max));
+ } else if (fn_matmul_f32_f32p_f32p) {
+ fn_matmul_f32_f32p_f32p(m, n, k, lhs, rhs, dst, dst_stride, sizeof(float), clamp_min, clamp_max);
+ } else if (fn_matmul_f32_bf16p_bf16p) {
+ fn_matmul_f32_bf16p_bf16p(
+ m, n, k, reinterpret_cast(lhs), rhs, reinterpret_cast(dst), dst_stride,
+ sizeof(float), clamp_min, clamp_max);
+ } else {
+ KAI_ERROR("Main kernel is not available!");
+ }
+ }
+};
+
+// NOLINTEND(misc-non-private-member-variables-in-classes)
+
+/// Matrix multiplication test information.
+using MatMulTestParams = std::tuple;
+
+/// Prints the test information.
+void PrintTo(const MatMulTestParams& param, std::ostream* os);
+} // namespace kai::test
diff --git a/test/common/memory.hpp b/test/common/memory.hpp
index bf5fbb017ed3d4cca5895a895c795c08ff30721f..c856218f6351eefa6c30fdf715ec6f875f756037 100644
--- a/test/common/memory.hpp
+++ b/test/common/memory.hpp
@@ -7,8 +7,11 @@
#pragma once
#include
+#include
#include
+#include "kai/kai_common.h"
+#include "test/common/bfloat16.hpp"
#include "test/common/int4.hpp"
namespace kai::test {
@@ -39,6 +42,9 @@ T read_array(const void* array, size_t index) {
} else if constexpr (std::is_same_v) {
const auto [lo, hi] = Int4::unpack_u8(reinterpret_cast(array)[index / 2]);
return index % 2 == 0 ? lo : hi;
+ } else if constexpr (std::is_same_v) {
+ uint16_t raw_value = reinterpret_cast(array)[index];
+ return BFloat16(kai_cast_f32_bf16(raw_value));
} else {
return reinterpret_cast(array)[index];
}
diff --git a/test/reference/pack.cpp b/test/reference/pack.cpp
index 221ba36079ec8d6c6bf59289826a37ec2525b300..ad123762d0852f007c943674e571a6b5a3414a7e 100644
--- a/test/reference/pack.cpp
+++ b/test/reference/pack.cpp
@@ -10,29 +10,31 @@
#include
#include
#include
-#include
#include
#include "kai/kai_common.h"
+#include "test/common/bfloat16.hpp"
#include "test/common/data_format.hpp"
#include "test/common/data_type.hpp"
-#include "test/common/float16.hpp"
#include "test/common/memory.hpp"
#include "test/common/round.hpp"
-#include "test/reference/quantize.hpp"
namespace kai::test {
namespace {
+uint16_t convert(const uint8_t* src_ptr_elm, DataType src_dtype, DataType dst_dtype) {
+ KAI_ASSUME(src_dtype == DataType::FP32 && dst_dtype == DataType::BF16);  // Only FP32 -> BF16 is accepted.
+ return BFloat16(*reinterpret_cast(src_ptr_elm)).data();  // presumably data() is the raw BF16 bits - confirm
+}
+
std::vector pack_block(
- const void* src, size_t data_esize, size_t full_height, size_t full_width, size_t block_height, size_t block_width,
- size_t subblock_height, size_t subblock_width) {
+ const void* src, DataType src_dtype, DataType dst_dtype, size_t src_esize, size_t dst_esize, size_t full_height,
+ size_t full_width, size_t block_height, size_t block_width, size_t subblock_height, size_t subblock_width) {
const auto dst_bytes =
- round_up_multiple(full_height, block_height) * round_up_multiple(full_width, block_width) * data_esize;
+ round_up_multiple(full_height, block_height) * round_up_multiple(full_width, block_width) * dst_esize;
- std::vector dst;
- dst.resize(dst_bytes);
+ std::vector dst(dst_bytes, 0);
const auto* src_ptr = reinterpret_cast(src);
auto* dst_ptr = dst.data();
@@ -42,18 +44,38 @@ std::vector pack_block(
for (size_t y_subblock = 0; y_subblock < block_height; y_subblock += subblock_height) {
for (size_t x_subblock = 0; x_subblock < block_width; x_subblock += subblock_width) {
for (size_t y_element = 0; y_element < subblock_height; ++y_element) {
- if (y_block + y_subblock + y_element < full_height) {
- const auto len = std::min(subblock_width, full_width - x_block - x_subblock);
-
- memcpy(
- dst_ptr,
- src_ptr +
- ((y_block + y_subblock + y_element) * full_width + x_block + x_subblock) *
- data_esize,
- len * data_esize);
- }
+ if (src_dtype == dst_dtype) {
+ const size_t esize = dst_esize;
+
+ if (y_block + y_subblock + y_element < full_height) {
+ const auto len = std::min(subblock_width, full_width - x_block - x_subblock);
+
+ memcpy(
+ dst_ptr,
+ src_ptr +
+ ((y_block + y_subblock + y_element) * full_width + x_block + x_subblock) *
+ esize,
+ len * esize);
+ }
- dst_ptr += subblock_width * data_esize;
+ dst_ptr += subblock_width * esize;
+ } else if (dst_esize == 2 /* 16 bits */) {
+ for (size_t x_element = 0; x_element < subblock_width; ++x_element) {
+ if (y_block + y_subblock + y_element < full_height) {
+ if (x_block + x_subblock + x_element < full_width) {
+ const uint8_t* src_ptr_elm = src_ptr +
+ ((y_block + y_subblock + y_element) * full_width + x_block + x_subblock +
+ x_element) *
+ src_esize;
+
+ uint16_t src_value = convert(src_ptr_elm, src_dtype, dst_dtype);
+ memcpy(dst_ptr, &src_value, dst_esize);
+ }
+ }
+
+ dst_ptr += dst_esize;
+ }
+ }
}
}
}
@@ -67,43 +89,65 @@ std::vector pack_block(
/// Packs the matrix from raw to per-row bias format.
std::vector pack_bias_per_row(
- size_t data_esize, size_t zero_point_esize, const void* src, const void* bias, size_t height, size_t width,
- size_t block_height, size_t block_width, size_t subblock_height, size_t subblock_width) {
+ DataType src_dtype, DataType bias_dtype, DataType dst_dtype, size_t src_esize, size_t bias_esize, size_t dst_esize,
+ const void* src, const void* bias, size_t height, size_t width, size_t block_height, size_t block_width,
+ size_t subblock_height, size_t subblock_width) {
+ KAI_ASSUME(src_dtype == bias_dtype);
+
const auto num_groups = (height + block_height - 1) / block_height;
const auto group_num_blocks = (width + block_width - 1) / block_width;
-
- const auto group_zero_points_bytes = block_height * zero_point_esize;
- const auto block_data_bytes = block_height * block_width * data_esize;
- const auto group_bytes = group_zero_points_bytes + group_num_blocks * block_data_bytes;
+ const auto group_bias_bytes = block_height * bias_esize;
+ const auto block_data_bytes = block_height * block_width * dst_esize;
+ const auto group_bytes = group_bias_bytes + group_num_blocks * block_data_bytes;
const auto dst_bytes = num_groups * group_bytes;
- std::vector dst;
- dst.resize(dst_bytes);
+ std::vector dst(dst_bytes, 0);
const auto* src_ptr = reinterpret_cast(src);
const auto* bias_ptr = reinterpret_cast(bias);
auto* dst_ptr = dst.data();
for (size_t y_block = 0; y_block < height; y_block += block_height) {
- // Packs the zero points.
+ // Packs the bias.
const auto bias_len = std::min(block_height, height - y_block);
- memcpy(dst_ptr, bias_ptr, bias_len * zero_point_esize);
- bias_ptr += block_height * zero_point_esize;
- dst_ptr += block_height * zero_point_esize;
+ memcpy(dst_ptr, bias_ptr, bias_len * bias_esize);
+ bias_ptr += block_height * bias_esize;
+ dst_ptr += block_height * bias_esize;
for (size_t x_block = 0; x_block < width; x_block += block_width) {
for (size_t y_subblock = 0; y_subblock < block_height; y_subblock += subblock_height) {
for (size_t x_subblock = 0; x_subblock < block_width; x_subblock += subblock_width) {
for (size_t y_element = 0; y_element < subblock_height; ++y_element) {
- if (y_block + y_subblock + y_element < height) {
- const auto len = std::min(subblock_width, width - x_block - x_subblock);
- memcpy(
- dst_ptr,
- src_ptr +
- ((y_block + y_subblock + y_element) * width + x_block + x_subblock) * data_esize,
- len * data_esize);
+ if (src_dtype == dst_dtype) {
+ const size_t esize = dst_esize;
+ if (y_block + y_subblock + y_element < height) {
+ const auto len = std::min(subblock_width, width - x_block - x_subblock);
+
+ memcpy(
+ dst_ptr,
+ src_ptr +
+ ((y_block + y_subblock + y_element) * width + x_block + x_subblock) * esize,
+ len * esize);
+ }
+
+ dst_ptr += subblock_width * esize;
+ } else if (dst_esize == 2 /* 16 bits */) {
+ for (size_t x_element = 0; x_element < subblock_width; ++x_element) {
+ if (y_block + y_subblock + y_element < height) {
+ if (x_block + x_subblock + x_element < width) {
+ const uint8_t* src_ptr_elm = src_ptr +
+ ((y_block + y_subblock + y_element) * width + x_block + x_subblock +
+ x_element) *
+ src_esize;
+
+ const uint16_t dst_value = convert(src_ptr_elm, src_dtype, dst_dtype);
+ memcpy(dst_ptr, &dst_value, dst_esize);
+ }
+ }
+
+ dst_ptr += dst_esize;
+ }
}
- dst_ptr += subblock_width * data_esize;
}
}
}
@@ -118,7 +162,7 @@ std::vector pack_bias_per_row(
} // namespace
std::vector pack(
- const DataFormat& dst_format, const void* src, [[maybe_unused]] const void* scales, const void* zero_points,
+ const DataFormat& dst_format, const void* src, [[maybe_unused]] const void* scales, const void* bias,
const DataFormat& src_format, size_t height, size_t width) {
const auto dst_dt = dst_format.data_type();
const auto dst_qf = dst_format.pack_format();
@@ -131,27 +175,31 @@ std::vector pack(
const auto subblock_width = dst_format.actual_subblock_width(width);
if (src_qf == DataFormat::PackFormat::NONE && dst_qf == DataFormat::PackFormat::BIAS_PER_ROW) {
- KAI_ASSUME(src_dt == dst_dt);
+ KAI_ASSUME((src_dt == dst_dt) || (src_dt == DataType::FP32 && dst_dt == DataType::BF16));
- const auto data_esize = data_type_size_in_bits(dst_dt);
- const auto zero_point_esize = data_type_size_in_bits(dst_format.zero_point_data_type());
+ const auto src_esize = data_type_size_in_bits(src_dt);
+ const auto dst_esize = data_type_size_in_bits(dst_dt);
+ const auto bias_esize = data_type_size_in_bits(dst_format.zero_point_data_type());
+ const auto bias_dt = dst_format.zero_point_data_type();
- if (data_esize % 8 == 0 && zero_point_esize % 8 == 0) {
- return pack_bias_per_row(
- data_esize / 8, zero_point_esize / 8, src, zero_points, height, width, block_height, block_width,
- subblock_height, subblock_width);
- }
+ KAI_ASSUME(dst_esize % 8 == 0 && bias_esize % 8 == 0 && src_esize % 8 == 0);
+
+ return pack_bias_per_row(
+ src_dt, bias_dt, dst_dt, src_esize / 8, bias_esize / 8, dst_esize / 8, src, bias, height, width,
+ block_height, block_width, subblock_height, subblock_width);
}
if (src_qf == DataFormat::PackFormat::NONE && dst_qf == DataFormat::PackFormat::NONE) {
- KAI_ASSUME(src_dt == dst_dt);
+ KAI_ASSUME((src_dt == dst_dt) || (src_dt == DataType::FP32 && dst_dt == DataType::BF16));
- const auto data_esize = data_type_size_in_bits(dst_dt);
+ const auto dst_esize = data_type_size_in_bits(dst_dt);
+ const auto src_esize = data_type_size_in_bits(src_dt);
- if (data_esize % 8 == 0) {
- return pack_block(
- src, data_esize / 8, height, width, block_height, block_width, subblock_height, subblock_width);
- }
+ KAI_ASSUME(src_esize % 8 == 0 && dst_esize % 8 == 0);
+
+ return pack_block(
+ src, src_dt, dst_dt, src_esize / 8, dst_esize / 8, height, width, block_height, block_width,
+ subblock_height, subblock_width);
}
KAI_ERROR("Unsupported operation!");
diff --git a/test/reference/pack.hpp b/test/reference/pack.hpp
index 10d76a7f6a289fea0a5217c391cea77b703cedaa..128ad0400b09a4a3ae8e1d503fc8ae0824837ebe 100644
--- a/test/reference/pack.hpp
+++ b/test/reference/pack.hpp
@@ -22,8 +22,8 @@ class DataFormat;
/// @param[in] height Number of rows of the source matrix.
/// @param[in] width Number of columns of the source matrix.
std::vector pack(
- const DataFormat& dst_format, const void* src, const void* scales, const void* zero_points,
- const DataFormat& src_format, size_t height, size_t width);
+ const DataFormat& dst_format, const void* src, const void* scales, const void* bias, const DataFormat& src_format,
+ size_t height, size_t width);
/// Packs the quantized data and the quantization scale into a single buffer.
///
diff --git a/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..730ff5aed286f79ab87bac3014fa9b150000f4c6
--- /dev/null
+++ b/test/tests/matmul_clamp_f32_bf16p_bf16p_test.cpp
@@ -0,0 +1,374 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include